
Plot Diagnostics (`sketch.plot_diagnostics`)

This module provides functions for plotting model diagnostics, such as posterior predictions, parameter traces, posterior distributions, and model structure graphs.

Key functions

  • plot_posterior_predictions(X, y, X_test, y_test, mmm, config, output_dir, n_points=52, show_oos_r2=False)

    • Two-panel figure:
      • Top: actuals vs. the posterior predictive median with 50% and 90% HDI bands; optionally, the out-of-sample (OOS) median and HDI, with the holdout period shaded.
      • Bottom: Residuals over time (±2σ band).
    • Inputs: in-sample (X, y), optional OOS (X_test, y_test), fitted model mmm, config (uses date_col, target_col), output directory.
    • Saves: model_fit_predictions.png.
  • plot_model_trace(model, results_dir)

    • Parameter trace plots via ArviZ for intercept, beta_channel, alpha, and lam, plus likelihood_sigma, gamma_control, and gamma_fourier when they are present.
    • Inputs: a fitted model whose fit_result attribute is an ArviZ InferenceData, and the output directory.
    • Saves: model_trace.png.
  • plot_posterior_distributions(idata, results_dir, filename='posterior_distributions.png')

    • Small-multiples grid of posterior distributions for all parameters in idata.posterior.
    • Inputs: ArviZ InferenceData, output directory, optional filename.
    • Saves: posterior_distributions.png (default).
  • plot_model_structure(model)

    • Returns a Graphviz graph of the PyMC model; requires Graphviz to be installed.
    • Input: object with a .model attribute of type pymc.Model.
    • Returns: graphviz.Digraph or None if Graphviz is unavailable.
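The quantities behind the two-panel figure of plot_posterior_predictions can be sketched in NumPy. This is a minimal illustration, not the module's code: it assumes the posterior predictive draws are already on the original target scale as an array of shape (n_draws, n_time), takes residuals as actuals minus the posterior median, and estimates σ for the ±2σ band as the sample standard deviation of those residuals. The hdi helper uses the shortest-interval method, which is also what ArviZ's hdi uses for unimodal samples; hdi and fit_panels are hypothetical names.

```python
import numpy as np


def hdi(draws: np.ndarray, prob: float) -> np.ndarray:
    """Shortest interval holding `prob` of the draws, per time point.

    draws: shape (n_draws, n_time). Returns shape (2, n_time): lower, upper.
    """
    if not 0.0 < prob < 1.0:
        raise ValueError("prob must be in (0, 1)")
    sorted_draws = np.sort(draws, axis=0)
    n_draws, n_time = sorted_draws.shape
    window = int(np.floor(prob * n_draws))
    # Width of every candidate interval spanning `window` sorted steps
    widths = sorted_draws[window:] - sorted_draws[: n_draws - window]
    start = np.argmin(widths, axis=0)
    cols = np.arange(n_time)
    return np.stack([sorted_draws[start, cols], sorted_draws[start + window, cols]])


def fit_panels(y: np.ndarray, pp_draws: np.ndarray) -> dict:
    """Top-panel fans and bottom-panel residuals for a fit plot (hypothetical helper)."""
    median = np.median(pp_draws, axis=0)
    residuals = y - median  # assumption: residual = actual - posterior median
    sigma = residuals.std(ddof=1)  # assumption: sigma estimated from the residuals
    return {
        "median": median,
        "hdi_50": hdi(pp_draws, 0.50),
        "hdi_90": hdi(pp_draws, 0.90),
        "residuals": residuals,
        "band": (-2.0 * sigma, 2.0 * sigma),
    }


# Usage on synthetic data: 1000 draws over 52 weekly points
rng = np.random.default_rng(0)
y = np.linspace(10.0, 20.0, 52)
pp_draws = y + rng.normal(0.0, 1.0, size=(1000, 52))
panels = fit_panels(y, pp_draws)
```

The median and HDI bounds map directly onto matplotlib's Axes.plot and Axes.fill_between for the top panel, and the residuals and band onto the bottom panel.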

Usage example

from src.sketch.plot_diagnostics import (
    plot_posterior_predictions,
    plot_model_trace,
    plot_posterior_distributions,
    plot_model_structure,
)

# Posterior predictive diagnostics
plot_posterior_predictions(X, y, X_test, y_test, mmm, config, output_dir=results_dir, show_oos_r2=True)

# Trace and posterior summaries
plot_model_trace(mmm, results_dir)
plot_posterior_distributions(mmm.fit_result, results_dir)

# Model structure (requires Graphviz)
graph = plot_model_structure(mmm)
if graph is not None:
    graph.render("model_structure", directory=results_dir, format="png", cleanup=True)
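plot_model_trace plots a fixed set of parameters plus three that only exist in some model configurations. One way to build that list is sketched below; it assumes the optional names are simply looked up among the posterior's variables, is not a copy of the module's implementation, and trace_var_names is a hypothetical helper.

```python
from typing import Iterable, List

# Parameter names taken from the plot_model_trace description
ALWAYS_PLOTTED = ["intercept", "beta_channel", "alpha", "lam"]
PLOTTED_IF_PRESENT = ["likelihood_sigma", "gamma_control", "gamma_fourier"]


def trace_var_names(available: Iterable[str]) -> List[str]:
    """Return the core parameters plus whichever optional ones exist."""
    present = set(available)
    optional = [name for name in PLOTTED_IF_PRESENT if name in present]
    return ALWAYS_PLOTTED + optional


# Example: a model without control variables or Fourier terms
names = trace_var_names(["intercept", "beta_channel", "alpha", "lam", "likelihood_sigma"])
```

With ArviZ, the list could then be passed as az.plot_trace(mmm.fit_result, var_names=trace_var_names(mmm.fit_result.posterior.data_vars)), so that missing optional parameters are skipped instead of raising an error.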

Notes

  • Date handling: dates are read from the column named in config['date_col'], for both the in-sample and OOS frames.
  • Scaling: posterior predictive outputs are transformed back to the original scale using the in-graph scaling parameters or, where applicable, the target transformer.
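The back-transformation described in the Scaling note can be sketched as follows. Both branches are assumptions for illustration: the in-graph case is modelled as a single multiplicative factor (as with max-abs scaling), passed in as the hypothetical target_scale, and the transformer case assumes a scikit-learn-style object whose inverse_transform takes a 2-D (n_samples, 1) array. to_original_scale and AffineTransformer are hypothetical names.

```python
from typing import Optional, Protocol

import numpy as np


class TargetTransformer(Protocol):
    """Anything with a scikit-learn-style inverse_transform (assumption)."""

    def inverse_transform(self, X: np.ndarray) -> np.ndarray: ...


def to_original_scale(
    pp_scaled: np.ndarray,
    target_scale: Optional[float] = None,
    transformer: Optional[TargetTransformer] = None,
) -> np.ndarray:
    """Map posterior predictive draws back to the target's original units."""
    if target_scale is not None:
        # In-graph scaling assumed to be y_scaled = y / target_scale
        return pp_scaled * target_scale
    if transformer is not None:
        # Transformers expect a column vector: flatten draws, invert, restore shape
        flat = pp_scaled.reshape(-1, 1)
        return np.asarray(transformer.inverse_transform(flat)).reshape(pp_scaled.shape)
    # No scaling information: draws are assumed to be on the original scale already
    return pp_scaled


class AffineTransformer:
    """Toy stand-in for a fitted scaler with y_scaled = (y - 1) / 10."""

    def inverse_transform(self, X: np.ndarray) -> np.ndarray:
        return X * 10.0 + 1.0


draws = np.array([[0.5, 1.0], [0.25, 0.75]])
from_graph = to_original_scale(draws, target_scale=200.0)
from_transformer = to_original_scale(draws, transformer=AffineTransformer())
```

Checking the in-graph parameter first mirrors the order in the note; a model that stores its own scale does not need an external transformer at plotting time.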