
Plot Input (`sketch.plot_input`)

This module provides functions for plotting input data characteristics, such as metrics over time and correlation matrices.

Key functions

  • plot_all_metrics(input_data, output_dir, suffix)

    • Plots the time series of all media volumes, media costs, extra features, and the target as vertically stacked panels in a single figure.
    • Inputs: InputData object, output directory, filename suffix.
    • Saves: metrics_{suffix}.png in the output directory.
  • plot_correlation_matrix(input_data, per_observation_df)

    • Computes the correlation matrix of the target column and every column whose name contains “volume”, and returns it both as a Plotly heatmap Figure and as a DataFrame.
    • Inputs: InputData and a per-observation DataFrame (typically from preprocessing).
    • Returns: (fig: plotly.graph_objects.Figure, corr_df: pd.DataFrame).
  • plot_all_media_spend(input_data, per_observation_df)

    • Despite its name, returns a Plotly line chart of the target series over time, which is useful for quick trend inspection.
    • Inputs: InputData and per-observation DataFrame.
    • Returns: fig: plotly.graph_objects.Figure.
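
The stacked layout described for plot_all_metrics can be sketched as follows. This is an illustration, not the module's code: the helper name plot_stacked_series is hypothetical, it assumes matplotlib for the PNG output, and because InputData is not documented here it takes a plain pandas DataFrame with a DatetimeIndex plus an explicit list of columns to draw.

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")  # non-interactive backend: render straight to a file
import matplotlib.pyplot as plt
import pandas as pd


def plot_stacked_series(df: pd.DataFrame, columns: list[str], output_dir: str, suffix: str) -> str:
    """Hypothetical helper: one panel per column, stacked vertically, sharing the time axis."""
    if not columns:
        raise ValueError("no columns to plot")
    os.makedirs(output_dir, exist_ok=True)
    n = len(columns)
    # squeeze=False keeps `axes` 2-D (shape n x 1) even when n == 1
    fig, axes = plt.subplots(n, 1, sharex=True, figsize=(10, 2.5 * n), squeeze=False)
    for ax, col in zip(axes[:, 0], columns):
        ax.plot(df.index, df[col])
        ax.set_title(col)
    fig.tight_layout()
    path = os.path.join(output_dir, f"metrics_{suffix}.png")  # same naming pattern as documented
    fig.savefig(path)
    plt.close(fig)  # release the figure so repeated calls do not accumulate memory
    return path


# Example usage with made-up weekly data (illustrative only)
dates = pd.date_range("2024-01-07", periods=8, freq="W")
example = pd.DataFrame(
    {
        "tv_volume": [120, 135, 150, 140, 160, 170, 165, 180],
        "search_cost": [50.0, 52.5, 49.0, 55.0, 58.0, 60.5, 59.0, 63.0],
        "sales": [1000, 1040, 1100, 1080, 1150, 1210, 1190, 1260],
    },
    index=dates,
)
out_path = plot_stacked_series(example, list(example.columns), tempfile.mkdtemp(), "train")
print(out_path)
```

Passing the column list explicitly keeps the sketch independent of how the real module groups volumes, costs, extra features, and the target.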

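The column-selection rule described for plot_correlation_matrix can be sketched in pandas. This is not the module's implementation: the helper name volume_target_correlation is hypothetical, and it assumes a case-sensitive substring match on “volume” and pandas' default Pearson correlation.

```python
import pandas as pd


def volume_target_correlation(df: pd.DataFrame, target_col: str) -> pd.DataFrame:
    """Hypothetical helper: correlation of every '*volume*' column plus the target column."""
    # Case-sensitive substring match (an assumption); the target is appended last.
    cols = [c for c in df.columns if "volume" in c and c != target_col]
    cols.append(target_col)
    return df[cols].corr()  # Pearson by default


example = pd.DataFrame(
    {
        "tv_volume": [1.0, 2.0, 3.0, 4.0],
        "radio_volume": [4.0, 3.0, 2.0, 1.0],
        "tv_cost": [10.0, 20.0, 25.0, 40.0],  # no "volume" in the name, so it is excluded
        "sales": [2.0, 4.0, 6.0, 8.0],
    }
)
corr_df = volume_target_correlation(example, target_col="sales")
print(corr_df.round(2))
```

A DataFrame in this shape could then be drawn as a heatmap, for example with plotly.express.imshow, which uses the index and column labels as axis labels.
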
Usage example

from src.sketch.plot_input import plot_all_metrics, plot_correlation_matrix, plot_all_media_spend
# Assumes input_data (an InputData object), per_observation_df (from preprocessing)
# and results_dir (an existing output directory) are already defined.
# Plot all metrics to a single PNG: writes {results_dir}/metrics_train.png
plot_all_metrics(input_data, output_dir=results_dir, suffix="train")
# Correlation heatmap ("volume" columns + target)
corr_fig, corr_df = plot_correlation_matrix(input_data, per_observation_df)
corr_fig.write_image(f"{results_dir}/correlation.png")  # optional; static image export requires the kaleido package
# Target over time (Plotly figure)
spend_fig = plot_all_media_spend(input_data, per_observation_df)
spend_fig.write_html(f"{results_dir}/target_over_time.html")  # optional
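
Because the second value returned by plot_correlation_matrix is a plain DataFrame, it can also be screened programmatically, for instance to spot media volumes that move together closely enough to be hard to separate in a model. The sketch below is a hypothetical helper (high_correlation_pairs is not part of the module), and the 0.9 threshold is an arbitrary assumption; it builds its own small matrix so it runs on its own.

```python
import pandas as pd


def high_correlation_pairs(corr_df: pd.DataFrame, threshold: float = 0.8) -> list[tuple[str, str, float]]:
    """Hypothetical helper: column pairs whose |correlation| is at or above `threshold`."""
    cols = list(corr_df.columns)
    pairs = []
    for i, a in enumerate(cols):
        for b in cols[i + 1:]:  # upper triangle only: each pair once, diagonal skipped
            r = corr_df.loc[a, b]
            if abs(r) >= threshold:  # NaN (e.g. constant columns) compares False and is skipped
                pairs.append((a, b, float(r)))
    return pairs


# Example with made-up data standing in for corr_df
example = pd.DataFrame(
    {
        "tv_volume": [1.0, 2.0, 3.0, 4.0, 5.0],
        "radio_volume": [2.0, 4.1, 5.9, 8.0, 10.0],  # nearly proportional to tv_volume
        "search_volume": [3.0, 1.0, 4.0, 1.0, 5.0],
        "sales": [10.0, 12.0, 15.0, 16.0, 19.0],
    }
)
pairs = high_correlation_pairs(example.corr(), threshold=0.9)
for a, b, r in pairs:
    print(f"{a} ~ {b}: {r:.3f}")
```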