# Plot Input (`sketch.plot_input`)
This module provides functions for plotting input data characteristics, such as metrics over time and correlation matrices.
## Key functions
- `plot_all_metrics(input_data, output_dir, suffix)`: Plots the time series for all media volumes, media costs, extra features, and the target in a single vertical figure.
  - Inputs: an `InputData` object, an output directory, and a filename suffix.
  - Saves: `metrics_{suffix}.png`.
- `plot_correlation_matrix(input_data, per_observation_df)`: Returns a Plotly heatmap figure and a DataFrame holding the correlation matrix of the target column and every column whose name includes “volume”.
  - Inputs: an `InputData` object and a per-observation DataFrame (typically produced by preprocessing).
  - Returns: `(fig: plotly.graph_objects.Figure, corr_df: pd.DataFrame)`.
- `plot_all_media_spend(input_data, per_observation_df)`: Returns a Plotly line chart of the target series over time, which is useful for quick trend inspection. Despite its name, the chart shows the target rather than media spend.
  - Inputs: an `InputData` object and a per-observation DataFrame.
  - Returns: `fig: plotly.graph_objects.Figure`.
## Usage example
```python
from src.sketch.plot_input import plot_all_metrics, plot_correlation_matrix, plot_all_media_spend

# Assumes input_data (an InputData object), per_observation_df, and results_dir already exist.

# Plot all metrics to a single PNG
plot_all_metrics(input_data, output_dir=results_dir, suffix="train")

# Correlation heatmap (volumes + target)
corr_fig, corr_df = plot_correlation_matrix(input_data, per_observation_df)
corr_fig.write_image(f"{results_dir}/correlation.png")  # optional; static export requires the kaleido package

# Target over time (Plotly figure)
spend_fig = plot_all_media_spend(input_data, per_observation_df)
spend_fig.write_html(f"{results_dir}/target_over_time.html")  # optional
```
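The last call in the example charts the target over time. The following is a minimal sketch of that kind of chart, assuming the per-observation data carries a date column and a target column. The column names `week` and `sales`, the helper `target_series`, and the sample rows are hypothetical. Rows are sorted by date, and the target values become the y-axis of a single Plotly line trace. The Plotly step runs only if `plotly` is installed, so the data preparation can be tried on its own.

```python
from datetime import date


def target_series(rows, date_key, target_key):
    """Hypothetical helper: return (dates, values) for the target, sorted by date."""
    pairs = sorted((row[date_key], row[target_key]) for row in rows)
    dates = [d for d, _ in pairs]
    values = [v for _, v in pairs]
    return dates, values


# Toy per-observation rows, deliberately out of order (hypothetical data).
rows = [
    {"week": date(2024, 1, 15), "sales": 130.0},
    {"week": date(2024, 1, 1), "sales": 100.0},
    {"week": date(2024, 1, 8), "sales": 120.0},
]
dates, values = target_series(rows, date_key="week", target_key="sales")

try:
    import plotly.graph_objects as go
except ImportError:  # plotting is optional for this sketch
    go = None

if go is not None:
    # One line trace: date on the x-axis, target on the y-axis.
    fig = go.Figure(data=go.Scatter(x=dates, y=values, mode="lines", name="sales"))
    fig.update_layout(title="Target over time", xaxis_title="week", yaxis_title="sales")
```

Sorting before plotting matters because a line trace connects points in the order given, so unsorted rows would draw a zig-zag rather than a time series.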