Pipeline: Predict

Purpose

The predict stage scores a single (season_year, init_date) pair against a fitted model and writes county-level predictions to disk. The module is deliberately split into two functions (predict and write_walk_forward_outputs) so that walk-forward CV can accumulate K init-dates' predictions in memory and write once per fold, while single-init callers (CLI predict, forecast pipeline) compose both functions via the thin run_predict wrapper. The separation prevents the blind-overwrite semantics of write_walk_forward_outputs from destroying earlier init-date rows during walk-forward iteration (run_predict.py:13–25).

Inputs

| Input | Path | Format | Producer |
| --- | --- | --- | --- |
| run_dir | Filesystem path or s3:// URI | Directory | run_hindcast._create_run_root |
| detrender.pkl | {run_dir}/models/{experiment_key}/{fold_label}/detrender.pkl | Pickle (joblib) | run_fit.train |
| Regressor file(s) | {run_dir}/models/{experiment_key}/{fold_label}/ | joblib / XGBoost JSON | run_fit.train |
| feature_fill_values.parquet | {run_dir}/models/{experiment_key}/{fold_label}/feature_fill_values.parquet | Parquet | run_fit.train |
| pred.parquet (forecast path) | {run_dir}/forecast/{season_year}/{init_date}/features/pred.parquet | Parquet | run_forecast.run_features |
| pred.parquet (canonical path) | {features_dir}/{experiment_key}/pred.parquet | Parquet | build_features / assemble |
| included_geo_identifiers | {run_dir}/included_geo_identifiers.json | JSON | run_hindcast._persist_included |

Outputs

| Output | Path | Format | Consumer |
| --- | --- | --- | --- |
| walk_forward_preds.parquet | {preds_dir}/{experiment_key}/{fold_label}/walk_forward_preds.parquet | Parquet | run_meta_models.postprocess_experiment |
| year_data.parquet | {preds_dir}/{experiment_key}/{fold_label}/year_data.parquet | Parquet | Diagnostics, dashboard |

Step-by-step flow

predict(run_dir, *, season_year, init_date) — pure compute (run_predict.py:282)

  1. Resolve fold label: _resolve_fold_label(run_dir, experiment_key, season_year) (run_predict.py:86) checks {run_dir}/models/{experiment_key}/production/ first; if present, returns "production". Otherwise tries {run_dir}/models/{experiment_key}/{season_year}/. Raises FileNotFoundError if neither directory exists.

  2. Load model artefacts: _load_prediction_inputs(run_dir, season_year=..., init_date=...) (run_predict.py:172):
       a. ExperimentResult.from_run_dir(run_dir): load the experiment handle.
       b. fold.load_detrender(config): rehydrate the fitted AbstractDetrend from detrender.pkl.
       c. fold.load_model(): load the concrete AbstractRegressionImpl.
       d. fold.load_feature_fill_values() + MedianImputer.from_fill_values(fill_values): reconstruct the imputer from training medians.
       e. result.load_included_geo_identifiers(): load the county frozenset used for filtering.

  3. Forecast vs canonical parquet routing, at run_predict.py:204–233:
       • Prefer {run_dir}/forecast/{season_year}/{init_date}/features/pred.parquet when it exists (forecast mode; exact init_date, no grid snap).
       • Fall back to {features_dir}/{experiment_key}/pred.parquet with grid-snap: if init_date is not on the hindcast init grid, config.commodity.nearest_init_date(season_year, init_date) selects the closest configured init (run_predict.py:221–229).
       • Raises FileNotFoundError if neither path exists.

  4. Load prediction slice: _load_prediction_slice(pred_parquet, season_year=..., init_date=..., included_geo_identifiers=...) (run_predict.py:120) uses Polars predicate pushdown (pl.scan_parquet(str(pred_parquet))) to filter to year == season_year, init_date == init_date, and counties in included_geo_identifiers. Raises ValueError listing the available init-dates if the result is empty.

  5. Four-step inverse pipeline: _predict(inputs) (run_predict.py:255):
       • Step 1, Detrend: inputs.detrender.transform(inputs.pred_slice) removes the trend from the prediction slice, adding target_detrended_col (run_predict.py:263).
       • Step 2, Score: predict_kernel(model, pred_slice_detrended, detrender, config, ...) from models/regression/runtime.py:140 applies the regressor to the detrended feature matrix (run_predict.py:264).
       • Step 3, Weather-correct: apply_weather_correction_postprocess inside predict_kernel scales by weather_correction_weight, applies per-DOY weights from season_doy_weather_weight, and clips by max_abs_weather_correction_bu_ac (runtime.py:40).
       • Step 4, Retrend: predict_kernel calls detrender.inverse_transform(data_regression, sim_detrended), yielding sim_yield_kg_ha in absolute kg/ha units (runtime.py:176).

  6. Assemble canonical frame: build_wide_prediction_frame(county_forecast, target_col, area_col, commodity) assembles the 9-column canonical frame (run_predict.py:313–318).

  7. Return: (wide_prediction_frame, inputs), where inputs.pred_slice is the year_data for this init.
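
The order of operations in the four-step inverse pipeline described above can be illustrated with a toy example. This is a minimal sketch, not the project's code: `LinearTrend`, `toy_regressor`, `weather_correct`, `toy_predict`, and every numeric value are hypothetical, and the exact way the weight, the per-DOY weight, and the clip combine is an assumption. Only the sequence (detrend, score, weather-correct, retrend) follows the description.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LinearTrend:
    """Hypothetical stand-in for a fitted detrender (not AbstractDetrend)."""

    base_year: int
    intercept: float
    slope: float

    def trend(self, year: int) -> float:
        return self.intercept + self.slope * (year - self.base_year)

    def inverse_transform(self, year: int, anomaly: float) -> float:
        # Retrend: add the trend back to an anomaly to get an absolute value.
        return self.trend(year) + anomaly


def toy_regressor(feature: float) -> float:
    """Hypothetical model: predicts a yield anomaly from one weather feature."""
    return 200.0 * feature


def weather_correct(anomaly: float, *, weight: float, doy_weight: float, max_abs: float) -> float:
    """Scale the anomaly, then clip it to [-max_abs, +max_abs] (assumed semantics)."""
    scaled = anomaly * weight * doy_weight
    return max(-max_abs, min(max_abs, scaled))


def toy_predict(year, feature, detrender, *, weight, doy_weight, max_abs):
    trend = detrender.trend(year)                    # Step 1: trend component that is removed
    anomaly = toy_regressor(feature)                 # Step 2: score in detrended space
    corrected = weather_correct(                     # Step 3: weight, then clip
        anomaly, weight=weight, doy_weight=doy_weight, max_abs=max_abs
    )
    sim_yield = detrender.inverse_transform(year, corrected)  # Step 4: retrend
    return {"trend": trend, "anomaly": anomaly, "corrected": corrected, "sim_yield": sim_yield}


result = toy_predict(
    2024, -3.0, LinearTrend(base_year=2000, intercept=5000.0, slope=100.0),
    weight=0.5, doy_weight=1.0, max_abs=250.0,
)
print(result)
```

In this toy run the raw anomaly of -600 is halved to -300 and then clipped to -250 before the trend of 7400 is added back, which shows why the clip has to happen in detrended space, before the retrend step.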

write_walk_forward_outputs(wide_prediction_frame, *, year_data, fold_preds_dir) — pure persistence (run_predict.py:322)

  1. Create directory: fold_preds_dir.mkdir(parents=True, exist_ok=True) (run_predict.py:335).
  2. Write walk_forward_preds.parquet — blind overwrite at fold_preds_dir / "walk_forward_preds.parquet" (run_predict.py:336–338).
  3. Write year_data.parquet — blind overwrite at fold_preds_dir / "year_data.parquet" (run_predict.py:337, 339).
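
To make the blind-overwrite behaviour concrete, the sketch below uses JSON files as a stand-in for Parquet so that it runs with the standard library alone. `write_outputs` and `read_outputs` are hypothetical helpers, not the project's `write_walk_forward_outputs`; they only mimic the "create directory, then overwrite" semantics listed above.

```python
import json
import tempfile
from pathlib import Path


def write_outputs(rows: list, fold_preds_dir: Path) -> Path:
    """Hypothetical writer: create the directory, then overwrite the file blindly."""
    fold_preds_dir.mkdir(parents=True, exist_ok=True)
    out = fold_preds_dir / "walk_forward_preds.json"  # JSON stand-in for Parquet
    out.write_text(json.dumps(rows))  # no read-merge-write: old content is lost
    return out


def read_outputs(path: Path) -> list:
    return json.loads(path.read_text())


init_dates = ["2024-05-01", "2024-06-01", "2024-07-01"]

with tempfile.TemporaryDirectory() as tmp:
    fold_dir = Path(tmp) / "preds" / "exp" / "2024"

    # Anti-pattern for walk-forward: one write per init-date.
    # Each call replaces the file written by the previous one.
    for d in init_dates:
        path = write_outputs([{"init_date": d, "sim_yield_kg_ha": 7000.0}], fold_dir)
    per_init_result = read_outputs(path)

    # Walk-forward pattern: accumulate all init-dates in memory, write once per fold.
    accumulated = [{"init_date": d, "sim_yield_kg_ha": 7000.0} for d in init_dates]
    once_result = read_outputs(write_outputs(accumulated, fold_dir))

print(len(per_init_result), len(once_result))
```

Writing per init-date leaves only the last init-date on disk, while accumulating first keeps all three. That is the reason walk-forward CV calls the writer once per fold.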

run_predict(run_dir, *, season_year, init_date) — single-init wrapper (run_predict.py:346)

Composes predict and write_walk_forward_outputs for a single init-date. Writes under {preds_dir}/{experiment_key}/{fold_label}/. Used by CLI predict and stages/run_forecast. NOT used by walk-forward CV.

Mermaid flow diagram

flowchart LR
    RD["run_dir + season_year + init_date"]
    RL["_resolve_fold_label\nrun_predict.py:86\n(production → season_year → error)"]
    LP["_load_prediction_inputs\nrun_predict.py:172\n(detrender + model + imputer + slice)"]
    PR["pred.parquet routing\nrun_predict.py:204\n(forecast path → canonical + snap)"]
    PS["_load_prediction_slice\npolars pushdown\nrun_predict.py:120"]

    subgraph KERNEL["_predict() — four-step inverse pipeline\nrun_predict.py:255"]
        DT["Step 1: detrender.transform\n(remove trend)"]
        SC["Step 2: predict_kernel\n(regressor score)"]
        WC["Step 3: weather correction\napply_weather_correction_postprocess"]
        RT["Step 4: detrender.inverse_transform\n(retrend → sim_yield_kg_ha)"]
        DT --> SC --> WC --> RT
    end

    WF["build_wide_prediction_frame\nrun_predict.py:313"]
    WR["write_walk_forward_outputs\nrun_predict.py:322\nwalk_forward_preds.parquet\nyear_data.parquet"]

    RD --> RL --> LP --> PR --> PS --> KERNEL --> WF --> WR
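
The pred.parquet routing node in the diagram prefers the forecast path and falls back to the canonical path with a grid snap. A minimal sketch of that decision follows, under stated assumptions: `resolve_pred_parquet` and this `nearest_init_date` are hypothetical helpers, not the project's functions; the init-date is assumed to appear in the path in ISO format; and ties between two equally close grid dates are broken toward the earlier date, which the source does not specify.

```python
import tempfile
from datetime import date
from pathlib import Path


def nearest_init_date(grid, init_date):
    """Pick the grid date closest to init_date (earlier date wins ties; an assumption)."""
    return min(grid, key=lambda d: (abs((d - init_date).days), d))


def resolve_pred_parquet(run_dir, features_dir, experiment_key, season_year, init_date, grid):
    """Return (path, effective_init_date) using the forecast-then-canonical order."""
    forecast = (
        run_dir / "forecast" / str(season_year) / init_date.isoformat()
        / "features" / "pred.parquet"
    )
    if forecast.exists():
        return forecast, init_date  # forecast mode: exact init_date, no grid snap
    canonical = features_dir / experiment_key / "pred.parquet"
    if canonical.exists():
        snapped = init_date if init_date in grid else nearest_init_date(grid, init_date)
        return canonical, snapped
    raise FileNotFoundError(f"No pred.parquet found at either {forecast} or {canonical}")


with tempfile.TemporaryDirectory() as tmp:
    run_dir, features_dir = Path(tmp) / "run", Path(tmp) / "features"
    grid = [date(2024, 5, 1), date(2024, 6, 1), date(2024, 7, 1)]
    requested = date(2024, 6, 10)  # not on the grid

    # Only the canonical parquet exists: the init-date is snapped to the grid.
    canonical = features_dir / "exp" / "pred.parquet"
    canonical.parent.mkdir(parents=True)
    canonical.touch()
    path_a, init_a = resolve_pred_parquet(run_dir, features_dir, "exp", 2024, requested, grid)

    # Once a forecast parquet exists it takes precedence and the date is kept exact.
    forecast = run_dir / "forecast" / "2024" / "2024-06-10" / "features" / "pred.parquet"
    forecast.parent.mkdir(parents=True)
    forecast.touch()
    path_b, init_b = resolve_pred_parquet(run_dir, features_dir, "exp", 2024, requested, grid)

print(init_a, init_b)
```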

Invariants and contracts

DESIGN.md predict contract (verbatim):

"WHEN running the forecast stage (cli run forecast) for a production prediction, THE SYSTEM SHALL invoke the atomic predict operation exactly once for the CLI-provided (--season-year, --init-date) pair and produce exactly one row per geography per aggregation level."

Blind-overwrite contract (from module docstring, run_predict.py:13–14):

"Each call [to write_walk_forward_outputs] OVERWRITES the destination; there is no read-merge-write accumulator."

Walk-forward bypass contract (from module docstring, run_predict.py:18–25):

"Walk-forward CV … deliberately does NOT route through run_predict: it needs to accumulate K init_dates' worth of predictions per fold and write once."

Fold resolution priority (run_predict.py:86–98): the production fold takes precedence over any season-year fold. Consequently, once cli run fit-production has written a production fold, every subsequent cli predict and forecast call uses the production model, regardless of season_year.
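
The priority rule can be sketched with plain directory checks. This is an illustrative approximation under assumptions, not the body of `_resolve_fold_label`: the function name `resolve_fold_label` and the exact error wording are hypothetical, while the lookup order and the directory layout follow the description above.

```python
import tempfile
from pathlib import Path


def resolve_fold_label(run_dir: Path, experiment_key: str, season_year: int) -> str:
    """Return 'production' if that fold exists, else the season-year label."""
    models_dir = run_dir / "models" / experiment_key
    for label in ("production", str(season_year)):
        if (models_dir / label).is_dir():
            return label
    raise FileNotFoundError(
        f"No model found under {models_dir}: tried 'production' and '{season_year}'"
    )


with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp)
    (run_dir / "models" / "exp" / "2023").mkdir(parents=True)
    before = resolve_fold_label(run_dir, "exp", 2023)      # only the season fold exists

    (run_dir / "models" / "exp" / "production").mkdir()
    after = resolve_fold_label(run_dir, "exp", 2023)       # production now wins
    other_year = resolve_fold_label(run_dir, "exp", 2031)  # no 2031 fold is needed

print(before, after, other_year)
```

Note the side effect visible in the last call: with a production fold present, a season year that has no fold of its own still resolves successfully.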

DESIGN.md Clause 34 (artefact contract, verbatim):

"PREDICT → preds/{experiment_key}/{label}/walk_forward_preds.parquet."

DESIGN.md Clause 29 (Polars path, verbatim):

"WHEN passing a path to polars.scan_parquet … the system SHALL convert via str(path)."

Applied at run_predict.py:136: pl.scan_parquet(str(pred_parquet)).

Failure modes and recovery

| Symptom | Likely cause | Recovery |
| --- | --- | --- |
| FileNotFoundError: No model found under … tried 'production' and '{year}' | FIT stage not run | Run cli run hindcast or cli run fit-production |
| ValueError: No rows in pred.parquet for season_year=N, init_date=… | Year not covered in feature parquet or init-date off-grid in forecast mode | Re-run features with feature_end_year raised; or check init_date format |
| FileNotFoundError: No pred.parquet found at either … | Neither forecast nor canonical parquet present | Run cli run features or cli run forecast-features |
| walk_forward_preds.parquet has only one init-date after CV | Walk-forward caller used run_predict (overwrite) instead of accumulator | Walk-forward must use run.runner._predict_fold_rolling, not run_predict |
| Negative or extreme sim_yield_kg_ha values | inverse_transform retrend applied to uncorrected detrended values | Check detrender and weather correction config |

Cross-references

PRs that materially changed this stage

  • PR #369 (f5399b96) — restructured forecast pred.parquet path from forecast/{init_date}/ to forecast/{season_year}/{init_date}/, fixing a skip-if-exists collision that caused _load_prediction_slice to raise ValueError for season_year=N+1. Changed routing check at run_predict.py:204–213.
  • PR #345 (tl/fix-path-issues): changed the pl.scan_parquet call to pass str(pred_parquet) instead of a CloudPath directly (DESIGN.md Clause 29); fixed AnyPath-based path resolution for S3 targets.