Plot Diagnostics (`sketch.plot_diagnostics`)
This module provides functions for plotting model diagnostics, such as posterior predictions, parameter traces, posterior distributions, and model structure graphs.
Key functions
- `plot_posterior_predictions(X, y, X_test, y_test, mmm, config, output_dir, n_points=52, show_oos_r2=False)`
  - Two-panel figure:
    - Top: actuals vs posterior predictive (median) with 50%/90% HDI fans; optional OOS median + HDI and holdout shading.
    - Bottom: residuals over time (±2σ band).
  - Inputs: in-sample (`X`, `y`), optional OOS (`X_test`, `y_test`), fitted model `mmm`, `config` (uses `date_col`, `target_col`), output directory.
  - Saves: `model_fit_predictions.png`.
- `plot_model_trace(model, results_dir)`
  - Parameter trace plots via ArviZ for `intercept`, `beta_channel`, `alpha`, `lam`, and, if present, `likelihood_sigma`, `gamma_control`, `gamma_fourier`.
  - Input: fitted model with `fit_result` (ArviZ InferenceData), output directory.
  - Saves: `model_trace.png`.
- `plot_posterior_distributions(idata, results_dir, filename='posterior_distributions.png')`
  - Small-multiples grid of posterior distributions for all parameters in `idata.posterior`.
  - Inputs: ArviZ InferenceData, output directory, optional filename.
  - Saves: `posterior_distributions.png` (default).
- `plot_model_structure(model)`
  - Returns a Graphviz graph of the PyMC model; requires Graphviz installed.
  - Input: object with a `.model` attribute of type `pymc.Model`.
  - Returns: `graphviz.Digraph`, or `None` if Graphviz is unavailable.
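
The median line, HDI fans, and residual band described for `plot_posterior_predictions` can be illustrated without the plotting layer. The following is a minimal sketch, not the module's implementation: it assumes the posterior predictive draws are available as a list of per-draw series, the helper names `hdi_1d` and `summarize_draws` are hypothetical, and both the choice of σ (population standard deviation of the residuals) and the zero-centered band are assumptions. It uses only the standard library so it runs anywhere; in practice ArviZ's `az.hdi` would likely be used for the interval step.

```python
import math
import statistics


def hdi_1d(samples, prob):
    """Narrowest interval containing roughly `prob` of the samples (hypothetical helper)."""
    ordered = sorted(samples)
    n = len(ordered)
    span = int(math.floor(prob * n))  # index offset between the two interval ends
    if span < 1 or span >= n:
        return ordered[0], ordered[-1]
    # Slide a window with a fixed sample count and keep the narrowest one.
    start = min(range(n - span), key=lambda i: ordered[i + span] - ordered[i])
    return ordered[start], ordered[start + span]


def summarize_draws(draws, actuals):
    """Per-time-point median, 50%/90% HDI, residuals, and a ±2σ residual band.

    draws:   list of posterior predictive draws, each a list over time points
    actuals: observed values, one per time point
    """
    per_time = list(zip(*draws))  # transpose from draw-major to time-major
    median = [statistics.median(col) for col in per_time]
    hdi50 = [hdi_1d(col, 0.50) for col in per_time]
    hdi90 = [hdi_1d(col, 0.90) for col in per_time]
    residuals = [a - m for a, m in zip(actuals, median)]
    sigma = statistics.pstdev(residuals)  # assumption: σ is the residual std dev
    band = (-2 * sigma, 2 * sigma)        # assumption: band is centered on zero
    return median, hdi50, hdi90, residuals, band
```

Unlike an equal-tailed percentile interval, the HDI follows the densest region of the draws, so for skewed predictive distributions the fan need not be symmetric around the median.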
Usage example

    from src.sketch.plot_diagnostics import (
        plot_posterior_predictions,
        plot_model_trace,
        plot_posterior_distributions,
        plot_model_structure,
    )

    # Posterior predictive diagnostics
    plot_posterior_predictions(X, y, X_test, y_test, mmm, config, output_dir=results_dir, show_oos_r2=True)

    # Trace and posterior summaries
    plot_model_trace(mmm, results_dir)
    plot_posterior_distributions(mmm.fit_result, results_dir)

    # Model structure (requires Graphviz)
    graph = plot_model_structure(mmm)
    if graph is not None:
        graph.render("model_structure", directory=results_dir, format="png", cleanup=True)

Notes
- Date handling: dates are read from the column specified in `config['date_col']` for both in-sample and OOS frames.
- Scaling: posterior predictive outputs are transformed back to the original scale using in-graph scaling parameters or the target transformer, as applicable.
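
The scaling note can be made concrete with a small sketch of the target-transformer path. Everything named here is an assumption for illustration: `MaxAbsTargetScaler` is a hypothetical stand-in for whatever transformer the model actually uses, and `to_original_scale` is not part of this module. The point is only that each draw is mapped back to the original units before medians and intervals are computed or plotted.

```python
class MaxAbsTargetScaler:
    """Hypothetical target transformer: divides by the largest absolute value."""

    def fit(self, y):
        # Fall back to 1.0 so an all-zero target does not cause division by zero.
        self.scale_ = max(abs(v) for v in y) or 1.0
        return self

    def transform(self, y):
        return [v / self.scale_ for v in y]

    def inverse_transform(self, y_scaled):
        return [v * self.scale_ for v in y_scaled]


def to_original_scale(draws_scaled, transformer):
    """Back-transform every posterior predictive draw (hypothetical helper)."""
    return [transformer.inverse_transform(draw) for draw in draws_scaled]


# Example: a target fitted on three observations, then two scaled draws.
scaler = MaxAbsTargetScaler().fit([50.0, 200.0, 100.0])
draws_original = to_original_scale([[0.25, 0.5], [0.5, 1.0]], scaler)
```

Back-transforming the draws first keeps the residuals and the HDI fans in the same units as the actuals shown in the top panel.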