Guide: Performing Cross-Validation and Model Evaluation
This guide explains common methods for evaluating how well your ammm model generalises to unseen data, focusing on Leave-One-Out Cross-Validation (LOO-CV) and hold-out set validation.
Overview
Evaluating model performance on data it wasn’t trained on is crucial for understanding its predictive capabilities and avoiding overfitting. Two primary methods are discussed:
- Leave-One-Out Cross-Validation (LOO-CV): A robust method for estimating out-of-sample predictive performance, approximated efficiently using Pareto Smoothed Importance Sampling (PSIS-LOO) via the ArviZ library.
- Hold-out Set Validation: A simpler approach where a portion of the data is withheld during training and used for evaluation afterwards.
Leave-One-Out Cross-Validation (LOO-CV) with ArviZ
LOO-CV approximates the process of fitting the model N times, leaving out each data point once, and evaluating the prediction for the left-out point. This provides a good estimate of generalisation performance but is computationally expensive to perform exactly. ArviZ provides the az.loo() function, which uses PSIS to approximate this process efficiently using the fitted model’s InferenceData.
Steps:
- Fit your model: Run the ammm model driver (e.g., `runme.py`) to fit your model and generate the `model.nc` file containing the ArviZ `InferenceData` (NetCDF).
- Calculate LOO-CV: In a separate Python script or notebook, load the fitted model and calculate LOO-CV using ArviZ:
```python
import arviz as az
from pathlib import Path

from src.core.mmm_model_v2 import DelayedSaturatedMMMv2

model_path = Path("results/model.nc")
if not model_path.exists():
    print("No saved model found at results/model.nc. Run the pipeline first.")
else:
    model = DelayedSaturatedMMMv2.load(str(model_path))
    idata = getattr(model, "idata", None)
    if idata is None:
        raise RuntimeError("Loaded model has no 'idata'. Re-fit the model with sampling enabled.")

    # Calculate LOO-CV
    loo_results = az.loo(idata)
    print(loo_results)
```
- Interpret Results: The `loo_results` object contains key metrics (see the diagnostic sketch after this list):
  - `elpd_loo`: Expected Log Pointwise Predictive Density. Higher values are better, indicating better out-of-sample predictive accuracy. This is the primary metric for comparison.
  - `p_loo`: Estimated effective number of parameters. Acts as a penalty for model complexity.
  - `loo_se`: Standard error of the `elpd_loo` estimate.
  - `pareto_k`: Diagnostic values. High values (> 0.7) indicate unreliable LOO-CV estimates for specific data points, suggesting the approximation might be poor.
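For a quick numeric check of these diagnostics, a minimal sketch is shown below. It assumes the `idata` loaded above and re-runs `az.loo()` with `pointwise=True` so that per-observation Pareto k values are available; the attribute names used here are those exposed by recent ArviZ releases.

```python
import arviz as az
import numpy as np

# Assumes `idata` is the InferenceData loaded in the previous step.
# pointwise=True is needed to obtain per-observation Pareto k diagnostics.
loo_results = az.loo(idata, pointwise=True)

# Headline metrics (attribute names as exposed by recent ArviZ releases).
print(f"elpd_loo: {loo_results.elpd_loo:.2f}")
print(f"p_loo:    {loo_results.p_loo:.2f}")

# Flag observations where the PSIS approximation may be unreliable (k > 0.7).
pareto_k = np.asarray(loo_results.pareto_k)
n_high = int((pareto_k > 0.7).sum())
print(f"{n_high} of {pareto_k.size} observations have Pareto k > 0.7")
```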
Refer to the LOO-CV Explanation for more details on interpretation.
Using LOO-CV for Model Comparison
LOO-CV is particularly useful for comparing different model specifications (e.g., different priors, inclusion/exclusion of control variables, different seasonality settings).
- Fit multiple models using the exact same training dataset for each variation. Keep each run’s outputs in a separate results directory (e.g., `results_model_a/`, `results_model_b/`).
- Load the fitted models and extract their `InferenceData`.
- Use ArviZ’s `az.compare()` function:

```python
import arviz as az
from pathlib import Path

from src.core.mmm_model_v2 import DelayedSaturatedMMMv2

model_a_path = Path("results_model_a/model.nc")
model_b_path = Path("results_model_b/model.nc")
if not (model_a_path.exists() and model_b_path.exists()):
    print("Missing one or both model files. Ensure each run saved model.nc into its results directory.")
else:
    model_a = DelayedSaturatedMMMv2.load(str(model_a_path))
    model_b = DelayedSaturatedMMMv2.load(str(model_b_path))
    idata_a = model_a.idata
    idata_b = model_b.idata

    model_comparison = az.compare({"model_a": idata_a, "model_b": idata_b})
    print(model_comparison)
    az.plot_compare(model_comparison)  # Visualise comparison
```

- Interpret Comparison: The `az.compare()` output ranks models based on `elpd_loo`, providing differences and standard errors of the differences, helping you determine if one model is significantly better than another in terms of predictive accuracy (see the sketch below).
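As a rough guide for reading that table, the sketch below compares each model’s `elpd_diff` against its standard error `dse`. It assumes the `model_comparison` DataFrame from the block above and the column names used by recent ArviZ releases (older versions use slightly different names, e.g. `d_loo` instead of `elpd_diff`).

```python
# Assumes `model_comparison` is the DataFrame returned by az.compare() above.
# Rows are sorted best-first, so the first index label is the top-ranked model.
best = model_comparison.index[0]

for name, row in model_comparison.iterrows():
    if name == best:
        print(f"{name}: top-ranked (elpd_loo = {row['elpd_loo']:.1f})")
    else:
        # Rough heuristic: if the difference is only a small multiple of its
        # standard error, the models are hard to distinguish on predictive accuracy.
        print(f"{name}: elpd_diff = {row['elpd_diff']:.1f} ± {row['dse']:.1f}")
```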
Hold-out Set Validation
Setting train_test_ratio in the configuration file to a value less than 1.0 (e.g., 0.8) reserves the later portion of the data as a hold-out test set. This provides a simpler, intuitive way to assess generalisation.
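For intuition, the snippet below sketches what a ratio of 0.8 means for a time-ordered dataset: the earliest 80% of rows is used for fitting and the latest 20% is held out. This illustrates the split logic only, not the pipeline’s actual implementation; the DataFrame and column names are hypothetical.

```python
import pandas as pd

# Hypothetical weekly dataset, sorted by date (illustration only).
df = pd.DataFrame({
    "date": pd.date_range("2022-01-03", periods=100, freq="W-MON"),
    "revenue": range(100),
})

train_test_ratio = 0.8
split_idx = int(len(df) * train_test_ratio)

train_df = df.iloc[:split_idx]  # earliest 80% of rows: used to fit the model
test_df = df.iloc[split_idx:]   # latest 20% of rows: held out for evaluation
print(len(train_df), len(test_df))  # 80 20
```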
Steps:
- Configure: Set `train_test_ratio` in your YAML config (e.g., `0.8`).
- Fit Model: Run the ammm driver. The model will be trained only on the initial portion (e.g., 80%) of the data.
- Evaluate:
  - Examine the `model_fit_predictions.png` plot. If generated with test data, it will show predictions against actuals for the hold-out period. Visually assess the fit.
  - (Recommended) In a separate script/notebook, load the full dataset and the fitted model (`model.nc`). Use the model’s `.predict()` method (if available, or manually calculate predictions using posterior means) on the hold-out portion of the features (`X_test`). Compare these predictions against the actual hold-out target values (`y_test`) using standard regression metrics like Mean Absolute Percentage Error (MAPE), Root Mean Squared Error (RMSE), or Mean Absolute Error (MAE), as in the sketch below.
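The sketch below illustrates that final evaluation step with standard regression metrics. The arrays `y_test` and `y_pred` are placeholders: `y_test` is the actual hold-out target, and `y_pred` is whatever mean prediction you obtain for the same period (via `.predict()` or posterior means).

```python
import numpy as np

def holdout_metrics(y_test: np.ndarray, y_pred: np.ndarray) -> dict:
    """MAE, RMSE, and MAPE for the hold-out period."""
    errors = y_test - y_pred
    return {
        "MAE": float(np.mean(np.abs(errors))),
        "RMSE": float(np.sqrt(np.mean(errors ** 2))),
        # MAPE is undefined where the actual value is zero, so those points are skipped.
        "MAPE_%": float(np.nanmean(np.abs(errors / np.where(y_test == 0, np.nan, y_test))) * 100),
    }

# Placeholder arrays; replace with the real hold-out actuals and predictions.
y_test = np.array([120.0, 135.0, 150.0])
y_pred = np.array([118.0, 140.0, 145.0])
print(holdout_metrics(y_test, y_pred))
```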