Pipeline: Predict
Purpose
The predict stage scores a single (season_year, init_date) pair against a fitted model and writes county-level predictions to disk. The module is deliberately split into two functions (predict and write_walk_forward_outputs) so that walk-forward CV can accumulate K init-dates' predictions in memory and write once per fold, while single-init callers (CLI predict, forecast pipeline) compose both functions via the thin run_predict wrapper. The separation prevents the blind-overwrite semantics of write_walk_forward_outputs from destroying earlier init-date rows during walk-forward iteration (run_predict.py:13–25).
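The difference between writing after every init-date and accumulating first can be illustrated with a minimal sketch. It uses JSON files as a stand-in for parquet, and `toy_predict` and `write_outputs` are hypothetical stand-ins, not the functions in `run_predict.py`.

```python
import json
import tempfile
from pathlib import Path


def write_outputs(rows: list[dict], dest: Path) -> None:
    """Blind overwrite: whatever was at dest is replaced, never merged."""
    dest.parent.mkdir(parents=True, exist_ok=True)
    dest.write_text(json.dumps(rows))


def toy_predict(init_date: str) -> list[dict]:
    """Hypothetical stand-in for predict(): one row per init-date."""
    return [{"init_date": init_date, "sim_yield_kg_ha": 1000.0}]


init_dates = ["2024-06-01", "2024-07-01", "2024-08-01"]

with tempfile.TemporaryDirectory() as tmp:
    # Writing after every init-date: each call destroys the previous rows.
    per_call = Path(tmp) / "per_call" / "preds.json"
    for d in init_dates:
        write_outputs(toy_predict(d), per_call)
    overwritten = json.loads(per_call.read_text())

    # Walk-forward pattern: accumulate in memory, write once per fold.
    once = Path(tmp) / "once" / "preds.json"
    accumulated: list[dict] = []
    for d in init_dates:
        accumulated.extend(toy_predict(d))
    write_outputs(accumulated, once)
    kept = json.loads(once.read_text())

print(len(overwritten), len(kept))
```

Only the last init-date survives the per-call writes, while the accumulate-then-write path keeps all three.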
Inputs
| Input | Path | Format | Producer |
|---|---|---|---|
| `run_dir` | Filesystem path or `s3://` URI | Directory | run_hindcast._create_run_root |
| `detrender.pkl` | `{run_dir}/models/{experiment_key}/{fold_label}/detrender.pkl` | Pickle (joblib) | run_fit.train |
| Regressor file(s) | `{run_dir}/models/{experiment_key}/{fold_label}/` | joblib / XGBoost JSON | run_fit.train |
| `feature_fill_values.parquet` | `{run_dir}/models/{experiment_key}/{fold_label}/feature_fill_values.parquet` | Parquet | run_fit.train |
| `pred.parquet` (forecast path) | `{run_dir}/forecast/{season_year}/{init_date}/features/pred.parquet` | Parquet | run_forecast.run_features |
| `pred.parquet` (canonical path) | `{features_dir}/{experiment_key}/pred.parquet` | Parquet | build_features / assemble |
| `included_geo_identifiers` | `{run_dir}/included_geo_identifiers.json` | JSON | run_hindcast._persist_included |
Outputs
| Output | Path | Format | Consumer |
|---|---|---|---|
| `walk_forward_preds.parquet` | `{preds_dir}/{experiment_key}/{fold_label}/walk_forward_preds.parquet` | Parquet | run_meta_models.postprocess_experiment |
| `year_data.parquet` | `{preds_dir}/{experiment_key}/{fold_label}/year_data.parquet` | Parquet | Diagnostics, dashboard |
Step-by-step flow
predict(run_dir, *, season_year, init_date): pure compute (run_predict.py:282)
1. Resolve fold label: `_resolve_fold_label(run_dir, experiment_key, season_year)` (run_predict.py:86) checks `{run_dir}/models/{experiment_key}/production/` first; if present, returns `"production"`. Otherwise tries `{run_dir}/models/{experiment_key}/{season_year}/`. Raises `FileNotFoundError` if neither directory exists.
2. Load model artefacts: `_load_prediction_inputs(run_dir, season_year=..., init_date=...)` (run_predict.py:172):
   - (a) `ExperimentResult.from_run_dir(run_dir)`: load the experiment handle.
   - (b) `fold.load_detrender(config)`: rehydrate the fitted `AbstractDetrend` from `detrender.pkl`.
   - (c) `fold.load_model()`: load the concrete `AbstractRegressionImpl`.
   - (d) `fold.load_feature_fill_values()` + `MedianImputer.from_fill_values(fill_values)`: reconstruct the imputer from training medians.
   - (e) `result.load_included_geo_identifiers()`: load the county frozenset used for filtering.
3. Forecast vs canonical parquet routing (run_predict.py:204–233):
   - Prefer `{run_dir}/forecast/{season_year}/{init_date}/features/pred.parquet` when it exists (forecast mode; exact `init_date`, no grid snap).
   - Fall back to `{features_dir}/{experiment_key}/pred.parquet` with grid-snap: if `init_date` is not on the hindcast init grid, `config.commodity.nearest_init_date(season_year, init_date)` selects the closest configured init (run_predict.py:221–229).
   - Raises `FileNotFoundError` if neither path exists.
4. Load prediction slice: `_load_prediction_slice(pred_parquet, season_year=..., init_date=..., included_geo_identifiers=...)` (run_predict.py:120) uses Polars predicate pushdown (`pl.scan_parquet(str(pred_parquet))`) to filter to `year == season_year`, `init_date == init_date`, and counties in `included_geo_identifiers`. Raises `ValueError` listing the available init-dates if the result is empty.
5. Four-step inverse pipeline: `_predict(inputs)` (run_predict.py:255):
   - Step 1, detrend: `inputs.detrender.transform(inputs.pred_slice)` removes the trend from the prediction slice, adding `target_detrended_col` (run_predict.py:263).
   - Step 2, score: `predict_kernel(model, pred_slice_detrended, detrender, config, ...)` from `models/regression/runtime.py:140` applies the regressor to the detrended feature matrix (run_predict.py:264).
   - Step 3, weather-correct: `apply_weather_correction_postprocess` inside `predict_kernel` scales by `weather_correction_weight`, applies per-DOY weights from `season_doy_weather_weight`, and clips by `max_abs_weather_correction_bu_ac` (runtime.py:40).
   - Step 4, retrend: `predict_kernel` calls `detrender.inverse_transform(data_regression, sim_detrended)`, yielding `sim_yield_kg_ha` in absolute kg/ha units (runtime.py:176).
6. Assemble canonical frame: `build_wide_prediction_frame(county_forecast, target_col, area_col, commodity)` assembles the 9-column canonical frame (run_predict.py:313–318).
7. Return: `(wide_prediction_frame, inputs)`, where `inputs.pred_slice` is the `year_data` for this init.
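The forecast-versus-canonical routing can be sketched as follows. This is an illustration under stated assumptions, not the code in `run_predict.py`: `route_pred_parquet` is a hypothetical helper, and `nearest_init_date` is assumed to pick the grid date with the smallest absolute day gap, which the source only describes as "the closest configured init".

```python
import tempfile
from datetime import date
from pathlib import Path


def nearest_init_date(grid: list[date], init_date: date) -> date:
    """Assumed semantics: grid date with the smallest absolute day gap."""
    return min(grid, key=lambda g: abs((g - init_date).days))


def route_pred_parquet(
    run_dir: Path,
    features_dir: Path,
    experiment_key: str,
    season_year: int,
    init_date: date,
    grid: list[date],
) -> tuple[Path, date]:
    """Return (parquet path, init_date to filter on)."""
    forecast = (
        run_dir / "forecast" / str(season_year) / init_date.isoformat()
        / "features" / "pred.parquet"
    )
    if forecast.exists():
        return forecast, init_date  # forecast mode: exact date, no snap
    canonical = features_dir / experiment_key / "pred.parquet"
    if canonical.exists():
        on_grid = init_date in grid
        snapped = init_date if on_grid else nearest_init_date(grid, init_date)
        return canonical, snapped
    raise FileNotFoundError(
        f"No pred.parquet found at either {forecast} or {canonical}"
    )


grid = [date(2024, 6, 1), date(2024, 7, 1), date(2024, 8, 1)]
off_grid = date(2024, 7, 10)

with tempfile.TemporaryDirectory() as tmp:
    run_dir, features_dir = Path(tmp) / "run", Path(tmp) / "features"
    canonical = features_dir / "exp_a" / "pred.parquet"
    canonical.parent.mkdir(parents=True)
    canonical.touch()
    # Only the canonical parquet exists: the off-grid date is snapped.
    canonical_route = route_pred_parquet(
        run_dir, features_dir, "exp_a", 2024, off_grid, grid
    )

    forecast = (
        run_dir / "forecast" / "2024" / "2024-07-10" / "features" / "pred.parquet"
    )
    forecast.parent.mkdir(parents=True)
    forecast.touch()
    # The forecast parquet now exists and wins, with the exact init_date.
    forecast_route = route_pred_parquet(
        run_dir, features_dir, "exp_a", 2024, off_grid, grid
    )

print(canonical_route[1], forecast_route[1])
```

In this sketch a tie between two equally distant grid dates resolves to the earlier one because `min` returns the first minimum; the real tie-breaking rule is not stated in this page.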
write_walk_forward_outputs(wide_prediction_frame, *, year_data, fold_preds_dir): pure persistence (run_predict.py:322)
1. Create directory: `fold_preds_dir.mkdir(parents=True, exist_ok=True)` (run_predict.py:335).
2. Write `walk_forward_preds.parquet`: blind overwrite at `fold_preds_dir / "walk_forward_preds.parquet"` (run_predict.py:336–338).
3. Write `year_data.parquet`: blind overwrite at `fold_preds_dir / "year_data.parquet"` (run_predict.py:337, 339).
run_predict(run_dir, *, season_year, init_date): single-init wrapper (run_predict.py:346)
Composes all ten steps above (seven in `predict`, three in `write_walk_forward_outputs`) for a single init-date. Writes under `{preds_dir}/{experiment_key}/{fold_label}/`. Used by CLI predict and stages/run_forecast. NOT used by walk-forward CV.
Mermaid flow diagram
flowchart LR
RD["run_dir + season_year + init_date"]
RL["_resolve_fold_label\nrun_predict.py:86\n(production → season_year → error)"]
LP["_load_prediction_inputs\nrun_predict.py:172\n(detrender + model + imputer + slice)"]
PR["pred.parquet routing\nrun_predict.py:204\n(forecast path → canonical + snap)"]
PS["_load_prediction_slice\npolars pushdown\nrun_predict.py:120"]
subgraph KERNEL["_predict() — four-step inverse pipeline\nrun_predict.py:255"]
DT["Step 1: detrender.transform\n(remove trend)"]
SC["Step 2: predict_kernel\n(regressor score)"]
WC["Step 3: weather correction\napply_weather_correction_postprocess"]
RT["Step 4: detrender.inverse_transform\n(retrend → sim_yield_kg_ha)"]
DT --> SC --> WC --> RT
end
WF["build_wide_prediction_frame\nrun_predict.py:313"]
WR["write_walk_forward_outputs\nrun_predict.py:322\nwalk_forward_preds.parquet\nyear_data.parquet"]
RD --> RL --> LP --> PR --> PS --> KERNEL --> WF --> WR
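The four-step kernel in the diagram can be illustrated with a toy, single-row sketch. Everything here is an assumption made for illustration: `LinearDetrender`, `toy_regressor`, and `weather_correct` are hypothetical, the trend is linear in year, the per-DOY weights are omitted, and the clip is expressed in kg/ha rather than bu/ac.

```python
from dataclasses import dataclass


@dataclass
class LinearDetrender:
    """Toy fitted detrender: trend = intercept + slope * year (kg/ha)."""

    intercept: float
    slope: float

    def trend(self, year: int) -> float:
        return self.intercept + self.slope * year

    def transform(self, row: dict) -> dict:
        # Step 1: attach the trend; detrend the target when one is present.
        out = dict(row, trend=self.trend(row["year"]))
        if "target" in row:
            out["target_detrended"] = row["target"] - out["trend"]
        return out

    def inverse_transform(self, row: dict, detrended: float) -> float:
        # Step 4: add the trend back to return to absolute units.
        return detrended + self.trend(row["year"])


def toy_regressor(row: dict) -> float:
    """Step 2: hypothetical model predicting the detrended anomaly."""
    return 400.0 * row["rain_anomaly"]


def weather_correct(anomaly: float, weight: float, max_abs: float) -> float:
    """Step 3: scale the anomaly, then clip it to +/- max_abs."""
    scaled = weight * anomaly
    return max(-max_abs, min(max_abs, scaled))


detrender = LinearDetrender(intercept=-90000.0, slope=50.0)
row = detrender.transform({"year": 2024, "rain_anomaly": -2.0})
raw = toy_regressor(row)                                 # -800.0
corrected = weather_correct(raw, weight=0.5, max_abs=300.0)  # -300.0
sim_yield_kg_ha = detrender.inverse_transform(row, corrected)

print(row["trend"], raw, corrected, sim_yield_kg_ha)
```

The order matters in this sketch: the correction acts on the detrended anomaly, and only the corrected anomaly is retrended.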
Invariants and contracts
DESIGN.md predict contract (verbatim):
"WHEN running the forecast stage (`cli run forecast`) for a production prediction, THE SYSTEM SHALL invoke the atomic predict operation exactly once for the CLI-provided `(--season-year, --init-date)` pair and produce exactly one row per geography per aggregation level."
Blind-overwrite contract (from module docstring, run_predict.py:13–14):
"Each call [to `write_walk_forward_outputs`] OVERWRITES the destination; there is no read-merge-write accumulator."
Walk-forward bypass contract (from module docstring, run_predict.py:18–25):
"Walk-forward CV … deliberately does NOT route through `run_predict`: it needs to accumulate K init_dates' worth of predictions per fold and write once."
Fold resolution priority (run_predict.py:86–98): production fold takes precedence over any season-year fold. This means after run fit-production, the production model is used for all subsequent cli predict and forecast calls regardless of season_year.
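A minimal sketch of this priority rule, assuming the fold label for a season-year directory is simply the year as a string; `resolve_fold_label` below is an illustrative reimplementation, not the function at run_predict.py:86.

```python
import tempfile
from pathlib import Path


def resolve_fold_label(run_dir: Path, experiment_key: str, season_year: int) -> str:
    """Production fold wins; otherwise fall back to the season-year fold."""
    base = run_dir / "models" / experiment_key
    if (base / "production").is_dir():
        return "production"
    if (base / str(season_year)).is_dir():
        return str(season_year)
    raise FileNotFoundError(
        f"No model found under {base}: tried 'production' and '{season_year}'"
    )


with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp)
    models = run_dir / "models" / "exp_a"
    (models / "2023").mkdir(parents=True)

    # Before a production fit, only the matching season-year fold resolves.
    before = resolve_fold_label(run_dir, "exp_a", 2023)
    try:
        resolve_fold_label(run_dir, "exp_a", 2030)
        missing_raised = False
    except FileNotFoundError:
        missing_raised = True

    # After a production fit, every season_year resolves to "production".
    (models / "production").mkdir()
    after = [resolve_fold_label(run_dir, "exp_a", y) for y in (2023, 2030)]

print(before, missing_raised, after)
```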
DESIGN.md Clause 34 (artefact contract, verbatim):
"PREDICT → `preds/{experiment_key}/{label}/walk_forward_preds.parquet`."
DESIGN.md Clause 29 (Polars path, verbatim):
"WHEN passing a path to `polars.scan_parquet` … the system SHALL convert via `str(path)`."
Applied at run_predict.py:136: pl.scan_parquet(str(pred_parquet)).
Failure modes and recovery
| Symptom | Likely cause | Recovery |
|---|---|---|
| `FileNotFoundError: No model found under … tried 'production' and '{year}'` | FIT stage not run | Run `cli run hindcast` or `cli run fit-production` |
| `ValueError: No rows in pred.parquet for season_year=N, init_date=…` | Year not covered in feature parquet or init-date off-grid in forecast mode | Re-run features with `feature_end_year` raised; or check `init_date` format |
| `FileNotFoundError: No pred.parquet found at either …` | Neither forecast nor canonical parquet present | Run `cli run features` or `cli run forecast-features` |
| `walk_forward_preds.parquet` has only one init-date after CV | Walk-forward caller used `run_predict` (overwrite) instead of accumulator | Walk-forward must use `run.runner._predict_fold_rolling`, not `run_predict` |
| Negative or extreme `sim_yield_kg_ha` values | `inverse_transform` retrend applied to uncorrected detrended values | Check detrender and weather correction config |
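The empty-slice `ValueError` in the second row is easier to diagnose when the message lists which init-dates are present. The sketch below shows that pattern on plain Python rows instead of a Polars scan; `load_slice` and the `geo_identifier` column name are assumptions for illustration.

```python
def load_slice(
    rows: list[dict],
    *,
    season_year: int,
    init_date: str,
    included: frozenset[str],
) -> list[dict]:
    """Filter to one (season_year, init_date) and the included counties."""
    matched = [
        r
        for r in rows
        if r["year"] == season_year
        and r["init_date"] == init_date
        and r["geo_identifier"] in included
    ]
    if not matched:
        # Report what is available so the caller can spot an off-grid date.
        available = sorted({r["init_date"] for r in rows if r["year"] == season_year})
        raise ValueError(
            f"No rows in pred.parquet for season_year={season_year}, "
            f"init_date={init_date}; available init_dates: {available}"
        )
    return matched


rows = [
    {"year": 2024, "init_date": "2024-06-01", "geo_identifier": "19001"},
    {"year": 2024, "init_date": "2024-07-01", "geo_identifier": "19001"},
    {"year": 2024, "init_date": "2024-07-01", "geo_identifier": "99999"},
]
included = frozenset({"19001"})

ok = load_slice(rows, season_year=2024, init_date="2024-07-01", included=included)

try:
    load_slice(rows, season_year=2024, init_date="2024-07-10", included=included)
    message = ""
except ValueError as exc:
    message = str(exc)

print(len(ok))
print(message)
```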
Cross-references
- ExperimentConfig: `commodity.hindcast_init_dates`, `commodity.nearest_init_date`, `preds_dir`
- HindcastSlice: `load_detrender`, `load_model`, `load_feature_fill_values`
- ExperimentResult: `from_run_dir`, `load_included_geo_identifiers`
- Concept: walk-forward CV (accumulation semantics that bypass `run_predict`)
- Pipeline: fit (upstream producer of all three artefacts consumed here)
- Pipeline: postprocess (downstream consumer of `walk_forward_preds.parquet`)
- Source: regression (`runtime.predict`, `apply_weather_correction_postprocess`, `inverse_transform`)
- Source: detrend (`AbstractDetrend.transform`, `inverse_transform`)
- Source: stages (full `run_predict.py` summary)
- Source: DESIGN.md (Clauses 29, 34; forecast isolation; walk-forward predict contract)
PRs that materially changed this stage
- PR #369 (`f5399b96`): restructured the forecast pred.parquet path from `forecast/{init_date}/` to `forecast/{season_year}/{init_date}/`, fixing a skip-if-exists collision that caused `_load_prediction_slice` to raise `ValueError` for `season_year=N+1`. Changed the routing check at `run_predict.py:204–213`.
- PR #345 (`tl/fix-path-issues`): switched `pl.scan_parquet` to accept `str(pred_parquet)` instead of a `CloudPath` directly (DESIGN.md Clause 29); fixed `AnyPath`-based path resolution for S3 targets.
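The collision fixed by PR #369 can be pictured with a small sketch of the two directory layouts. It assumes the same `init_date` string can be requested for two consecutive season years; `old_layout` and `new_layout` are hypothetical helpers that only mirror the path templates quoted above.

```python
from pathlib import PurePosixPath


def old_layout(run_dir: str, season_year: int, init_date: str) -> PurePosixPath:
    """Pre-#369 template: forecast/{init_date}/ (season_year is ignored)."""
    return PurePosixPath(run_dir) / "forecast" / init_date


def new_layout(run_dir: str, season_year: int, init_date: str) -> PurePosixPath:
    """Post-#369 template: forecast/{season_year}/{init_date}/."""
    return PurePosixPath(run_dir) / "forecast" / str(season_year) / init_date


init_date = "2024-08-01"

# Old layout: both season years share one directory, so a skip-if-exists
# check would reuse season N's features when season N+1 is requested.
old_n = old_layout("runs/r1", 2024, init_date)
old_n1 = old_layout("runs/r1", 2025, init_date)

# New layout: each season year gets its own directory.
new_n = new_layout("runs/r1", 2024, init_date)
new_n1 = new_layout("runs/r1", 2025, init_date)

print(old_n == old_n1, new_n == new_n1)
```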