diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 000000000..ebab1a835
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,274 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Project Overview
+
+PyMC-Marketing is a Bayesian marketing analytics library built on PyMC, providing three main modeling capabilities:
+
+- **Marketing Mix Modeling (MMM)**: Measure marketing channel effectiveness with adstock, saturation, and budget optimization
+- **Customer Lifetime Value (CLV)**: Predict customer value using probabilistic models (BG/NBD, Pareto/NBD, Gamma-Gamma, etc.)
+- **Customer Choice Analysis**: Understand product selection with Multivariate Interrupted Time Series (MVITS) and discrete choice models
+
+## Development Commands
+
+### Environment Setup
+```bash
+# Create and activate conda environment (recommended)
+conda env create -f environment.yml
+conda activate pymc-marketing-dev
+
+# Install package in editable mode
+make init
+```
+
+### Testing and Quality
+To use pytest you first need to activate the environment:
+```bash
+# Try to initialize conda (works if conda is in PATH or common locations)
+eval "$(conda shell.bash hook 2>/dev/null)" && conda activate pymc-marketing-dev || \
+source "$(conda info --base 2>/dev/null)/etc/profile.d/conda.sh" && conda activate pymc-marketing-dev
+```
+
+Running tests (with the environment activated):
+```bash
+# Run all tests with coverage
+make test
+
+# Run specific test file
+pytest tests/path/to/test_file.py
+
+# Run specific test function
+pytest tests/path/to/test_file.py::test_function_name
+
+# Check linting (ruff + mypy)
+make check_lint
+
+# Auto-fix linting issues
+make lint
+
+# Check code formatting
+make check_format
+
+# Auto-format code
+make format
+```
+
+### Documentation
+```bash
+# Build HTML documentation
+make html
+
+# Clean docs and rebuild from scratch
+make cleandocs && make html
+
+# Run notebooks to verify examples
+make run_notebooks  # All notebooks
+make run_notebooks_mmm  # MMM notebooks only
+make run_notebooks_other  # Non-MMM notebooks
+```
+
+### Other Utilities
+```bash
+# Generate UML diagrams for architecture
+make uml
+
+# Start MLflow tracking server
+make mlflow_server
+```
+
+## High-Level Architecture
+
+### Core Base Classes
+
+**ModelBuilder** ([pymc_marketing/model_builder.py](pymc_marketing/model_builder.py))
+- Abstract base class for all PyMC-Marketing models
+- Defines the model lifecycle: `build_model()` → `fit()` → `predict()`
+- Provides save/load functionality via NetCDF and InferenceData
+- Manages `model_config` (priors) and `sampler_config` (MCMC settings)
+
+**RegressionModelBuilder** (extends ModelBuilder)
+- Adds scikit-learn-like API: `fit(X, y)`, `predict(X)`
+- Base class for MMM and some customer choice models
+- Handles prior/posterior predictive sampling
+
+**CLVModel** ([pymc_marketing/clv/models/basic.py](pymc_marketing/clv/models/basic.py))
+- Base class for CLV models (BetaGeo, ParetoNBD, GammaGamma, etc.)
+- Takes data in constructor, not fit method: `model = BetaGeoModel(data=df)`
+- Supports multiple inference methods: `method="mcmc"` (default), `"map"`, `"advi"`, etc.
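+
+A minimal lifecycle sketch using the scikit-learn-like API (illustrative only; `X`/`y`
+are a pandas DataFrame/Series shaped as in the quickstart notebooks):
+
+```python
+from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation
+
+mmm = MMM(
+    date_column="date",
+    channel_columns=["C1", "C2"],
+    adstock=GeometricAdstock(l_max=10),
+    saturation=LogisticSaturation(),
+)
+mmm.fit(X, y)  # builds the PyMC model, runs inference, stores results in mmm.idata
+predictions = mmm.predict(X)  # posterior predictive for new data
+
+mmm.save("model.nc")  # serialize model + data + config via NetCDF
+mmm = MMM.load("model.nc")  # reconstruct from file
+```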
+ +### Module 1: MMM Architecture + +**Class Hierarchy:** +``` +RegressionModelBuilder + └── MMMModelBuilder (mmm/base.py) + ├── BaseMMM/MMM (mmm/mmm.py) - Single market + └── MMM (mmm/multidimensional.py) - Panel/hierarchical data +``` + +**Component-Based Design:** + +MMM uses composable transformation components: + +1. **Adstock Transformations** ([pymc_marketing/mmm/components/adstock.py](pymc_marketing/mmm/components/adstock.py)) + - Model carryover effects of advertising + - Built-in: GeometricAdstock, DelayedAdstock, WeibullCDFAdstock, WeibullPDFAdstock + - All extend `AdstockTransformation` base class + +2. **Saturation Transformations** ([pymc_marketing/mmm/components/saturation.py](pymc_marketing/mmm/components/saturation.py)) + - Model diminishing returns + - Built-in: LogisticSaturation, HillSaturation, MichaelisMentenSaturation, TanhSaturation + - All extend `SaturationTransformation` base class + +3. **Transformation Protocol** ([pymc_marketing/mmm/components/base.py](pymc_marketing/mmm/components/base.py)) + - Base class defining transformation interface + - Requires: `function()`, `prefix`, `default_priors` + - Custom transformations should extend this + +**Validation and Preprocessing System:** + +MMM models use a decorator-based system: +- Methods tagged with `_tags = {"validation_X": True}` run during `fit(X, y)` +- Methods tagged with `_tags = {"preprocessing_y": True}` transform data before modeling +- Built-in validators in [pymc_marketing/mmm/validating.py](pymc_marketing/mmm/validating.py) +- Built-in preprocessors in [pymc_marketing/mmm/preprocessing.py](pymc_marketing/mmm/preprocessing.py) + +**Key MMM Features:** +- Time-varying parameters via HSGP (Hilbert Space Gaussian Process) +- Lift test calibration for experiments +- Budget optimization ([pymc_marketing/mmm/budget_optimizer.py](pymc_marketing/mmm/budget_optimizer.py)) +- Causal DAG support ([pymc_marketing/mmm/causal.py](pymc_marketing/mmm/causal.py)) +- Additive effects system ([pymc_marketing/mmm/additive_effect.py](pymc_marketing/mmm/additive_effect.py)) for custom components + +**Multidimensional MMM vs Base MMM:** +- Base MMM ([pymc_marketing/mmm/mmm.py](pymc_marketing/mmm/mmm.py)): Single market, simpler API +- Multidimensional MMM ([pymc_marketing/mmm/multidimensional.py](pymc_marketing/mmm/multidimensional.py)): Panel data, per-channel transformations via `MediaConfigList`, more flexible + +### Module 2: CLV Architecture + +**Available Models:** +- BetaGeoModel: Beta-Geometric/NBD for continuous non-contractual settings +- ParetoNBDModel: Pareto/NBD alternative formulation +- GammaGammaModel: Monetary value prediction +- ShiftedBetaGeoModel, ModifiedBetaGeoModel: Variants +- BetaGeoBetaBinomModel: Discrete time variant + +**CLV Pattern:** +```python +# Data passed to constructor, not fit() +model = clv.BetaGeoModel(data=df) + +# Fit with various inference methods +model.fit(method="mcmc") # or "map", "advi", "fullrank_advi" + +# Predict for known customers +model.expected_purchases(customer_id, t) +model.probability_alive(customer_id) +``` + +**Custom Distributions:** +CLV models use custom distributions in [pymc_marketing/clv/distributions.py](pymc_marketing/clv/distributions.py) + +### Module 3: Customer Choice + +- **MVITS** ([pymc_marketing/customer_choice/mv_its.py](pymc_marketing/customer_choice/mv_its.py)): Multivariate Interrupted Time Series for product launch incrementality +- **Discrete Choice Models**: Logit models in [pymc_marketing/customer_choice/](pymc_marketing/customer_choice/) + 
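+**MVITS Pattern (illustrative sketch):**
+```python
+# Hedged sketch: the exact constructor arguments are an assumption here;
+# check pymc_marketing/customer_choice/mv_its.py for the authoritative signature.
+from pymc_marketing.customer_choice import MVITS
+
+model = MVITS(
+    existing_sales=["product_a", "product_b"],  # hypothetical product columns
+    saturated_market=True,  # assume the new product only cannibalizes existing ones
+)
+model.fit(X, y)  # X: existing-product sales, y: new-product sales
+```
+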
+### Cross-Cutting Systems + +**Prior Configuration System** ([pymc_marketing/prior.py](pymc_marketing/prior.py), now in pymc_extras) +- Declarative prior specification outside PyMC context +- Example: `Prior("Normal", mu=0, sigma=1)` +- Supports hierarchical priors, non-centered parameterization, transformations +- Used in all `model_config` dictionaries + +**Model Configuration** ([pymc_marketing/model_config.py](pymc_marketing/model_config.py)) +- `parse_model_config()` converts dicts to Prior objects +- Handles nested priors for hierarchical models +- Supports HSGP kwargs for Gaussian processes + +**Save/Load Infrastructure** +- Models save to NetCDF via ArviZ InferenceData +- `model.save("filename.nc")` serializes model + data + config +- `Model.load("filename.nc")` reconstructs from file +- Training data stored in `idata.fit_data` group + +**MLflow Integration** ([pymc_marketing/mlflow.py](pymc_marketing/mlflow.py)) +- `autolog()` patches PyMC and PyMC-Marketing functions +- Automatically logs: model structure, diagnostics (r_hat, ESS, divergences), MMM/CLV configs +- Start server with: `make mlflow_server` + +## Code Style and Testing + +**Linting:** +- Uses Ruff for linting and formatting +- Uses mypy for type checking +- Config in [pyproject.toml](pyproject.toml) under `[tool.ruff]` and `[tool.mypy]` +- Docstrings follow NumPy style guide + +**Testing:** +- pytest with coverage reporting +- Config in [pyproject.toml](pyproject.toml) under `[tool.pytest.ini_options]` +- Test files mirror package structure in [tests/](tests/) + +**Pre-commit Hooks:** +```bash +pre-commit install # Set up hooks +pre-commit run --all-files # Run manually +``` + +## Important Patterns and Conventions + +### Adding a New MMM Transformation + +1. Extend `AdstockTransformation` or `SaturationTransformation` from [pymc_marketing/mmm/components/base.py](pymc_marketing/mmm/components/base.py) +2. Implement: `function()`, `prefix` property, `default_priors` property +3. Add to [pymc_marketing/mmm/components/adstock.py](pymc_marketing/mmm/components/adstock.py) or [saturation.py](pymc_marketing/mmm/components/saturation.py) +4. Export in [pymc_marketing/mmm/__init__.py](pymc_marketing/mmm/__init__.py) + +### Adding a New CLV Model + +1. Extend `CLVModel` from [pymc_marketing/clv/models/basic.py](pymc_marketing/clv/models/basic.py) +2. Implement: `build_model()`, prediction methods (e.g., `expected_purchases()`) +3. Define required data columns in `__init__` +4. Add tests in [tests/clv/models/](tests/clv/models/) + +### Adding a New Additive Effect (MMM) + +1. Implement `MuEffect` protocol from [pymc_marketing/mmm/additive_effect.py](pymc_marketing/mmm/additive_effect.py) +2. Required methods: `create_data()`, `create_effect()`, `set_data()` +3. See FourierEffect, LinearTrendEffect as examples + +### Model Lifecycle + +All models follow this pattern: +1. **Configuration**: Store data and config in `__init__` +2. **Build**: `build_model()` creates PyMC model, attaches to `self.model` +3. **Fit**: `fit()` calls `pm.sample()` or alternative inference +4. **Store**: Results stored in `self.idata` (ArviZ InferenceData) +5. 
**Predict**: `sample_posterior_predictive()` with new data + +## Documentation and Examples + +**Notebooks:** +- MMM examples: [docs/source/notebooks/mmm/](docs/source/notebooks/mmm/) +- CLV examples: [docs/source/notebooks/clv/](docs/source/notebooks/clv/) +- Customer choice: [docs/source/notebooks/customer_choice/](docs/source/notebooks/customer_choice/) + +**Gallery Generation:** +- [scripts/generate_gallery.py](scripts/generate_gallery.py) creates notebook gallery for docs +- Run with `make html` + +**UML Diagrams:** +- Architecture diagrams in [docs/source/uml/](docs/source/uml/) +- Generate with `make uml` +- See [CONTRIBUTING.md](CONTRIBUTING.md) for package/class diagrams + +## Community and Support + +- [GitHub Issues](https://github.com/pymc-labs/pymc-marketing/issues) for bugs/features +- [PyMC Discourse](https://discourse.pymc.io/) for general discussion +- [PyMC-Marketing Discussions](https://github.com/pymc-labs/pymc-marketing/discussions) for Q&A diff --git a/environment.yml b/environment.yml index 829f07d6e..b2e69b551 100644 --- a/environment.yml +++ b/environment.yml @@ -62,3 +62,4 @@ dependencies: - pip: - roadmapper - labs-sphinx-theme + - arviz-plots[matplotlib] diff --git a/pymc_marketing/mmm/__init__.py b/pymc_marketing/mmm/__init__.py index 503c128f1..b43abd980 100644 --- a/pymc_marketing/mmm/__init__.py +++ b/pymc_marketing/mmm/__init__.py @@ -38,6 +38,7 @@ TanhSaturationBaselined, saturation_from_dict, ) +from pymc_marketing.mmm.config import mmm_plot_config from pymc_marketing.mmm.fourier import MonthlyFourier, WeeklyFourier, YearlyFourier from pymc_marketing.mmm.hsgp import ( HSGP, @@ -109,6 +110,7 @@ "create_eta_prior", "create_m_and_L_recommendations", "mmm", + "mmm_plot_config", "preprocessing", "preprocessing_method_X", "preprocessing_method_y", diff --git a/pymc_marketing/mmm/config.py b/pymc_marketing/mmm/config.py new file mode 100644 index 000000000..b67a16dcc --- /dev/null +++ b/pymc_marketing/mmm/config.py @@ -0,0 +1,167 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configuration management for MMM plotting.""" + +import warnings + +VALID_BACKENDS = {"matplotlib", "plotly", "bokeh"} + + +class MMMPlotConfig(dict): + """Configuration dictionary for MMM plotting settings. + + Global configuration object that controls MMM plotting behavior including + backend selection and version control. Modeled after ArviZ's rcParams pattern. + + Available Configuration Keys + ---------------------------- + + **plot.backend** : str, default="matplotlib" + Plotting backend to use for all plots in MMMPlotSuite. Options: + + * ``"matplotlib"`` - Static plots, publication-quality, widest compatibility + * ``"plotly"`` - Interactive plots with hover tooltips and zoom + * ``"bokeh"`` - Interactive plots with rich interactions + + Can be overridden per-method using the ``backend`` parameter. + + .. 
versionadded:: 0.18.0
+
+    **plot.show_warnings** : bool, default=True
+        Whether to show deprecation and other warnings from the plotting suite.
+
+        .. versionadded:: 0.18.0
+
+    **plot.use_v2** : bool, default=False
+        Whether to use the new arviz_plots-based plotting suite instead of the
+        legacy suite.
+
+        * ``False`` (default in v0.18.0): Use legacy matplotlib-only suite
+        * ``True``: Use new multi-backend arviz_plots-based suite
+
+        This flag controls which suite is returned by the ``MMM.plot`` property.
+
+        .. versionadded:: 0.18.0
+
+        .. versionchanged:: 0.19.0
+            Default will change to True (new suite becomes default).
+
+        .. deprecated:: 0.20.0
+            This flag will be removed when the legacy suite is removed.
+
+    Examples
+    --------
+    Set plotting backend globally:
+
+    .. code-block:: python
+
+        from pymc_marketing.mmm import mmm_plot_config
+
+        mmm_plot_config["plot.backend"] = "plotly"
+        # All plots now use plotly by default
+        mmm = MMM(...)
+        mmm.fit(X, y)
+        pc = mmm.plot.posterior_predictive()  # Uses plotly
+        pc.show()
+
+    Enable new plotting suite (v2):
+
+    .. code-block:: python
+
+        mmm_plot_config["plot.use_v2"] = True
+        # Now using arviz_plots-based multi-backend suite
+        mmm = MMM(...)
+        mmm.fit(X, y)
+        pc = mmm.plot.contributions_over_time(var=["intercept"])
+        pc.show()
+
+    Suppress warnings:
+
+    .. code-block:: python
+
+        mmm_plot_config["plot.show_warnings"] = False
+
+    Reset to defaults:
+
+    .. code-block:: python
+
+        mmm_plot_config.reset()
+        mmm_plot_config["plot.backend"]
+        # 'matplotlib'
+
+    Context manager pattern for temporary config changes:
+
+    .. code-block:: python
+
+        original = mmm_plot_config["plot.backend"]
+        try:
+            mmm_plot_config["plot.backend"] = "plotly"
+            # Use plotly for this section
+            pc = mmm.plot.posterior_predictive()
+            pc.show()
+        finally:
+            mmm_plot_config["plot.backend"] = original
+
+    See Also
+    --------
+    MMM.plot : Property that returns appropriate plot suite based on config
+    MMMPlotSuite : New multi-backend plotting suite
+    LegacyMMMPlotSuite : Legacy matplotlib-only suite
+
+    Notes
+    -----
+    Configuration changes affect all subsequent plot calls globally unless
+    overridden at the method level using the ``backend`` parameter.
+
+    The configuration is a singleton - changes affect all MMM instances in
+    the current Python session.
+    """
+
+    _defaults = {
+        "plot.backend": "matplotlib",
+        "plot.show_warnings": True,
+        "plot.use_v2": False,  # False = legacy suite (backward compatible); True = new arviz_plots-based suite
+    }
+
+    VALID_KEYS = set(_defaults.keys())
+
+    def __init__(self):
+        super().__init__(self._defaults)
+
+    def __setitem__(self, key, value):
+        """Set config value with validation for key and backend."""
+        if key not in self.VALID_KEYS:
+            warnings.warn(
+                f"Invalid config key '{key}'. Valid keys are: {sorted(self.VALID_KEYS)}. "
+                f"Setting anyway, but this key may not be recognized.",
+                UserWarning,
+                stacklevel=2,
+            )
+        if key == "plot.backend":
+            if value not in VALID_BACKENDS:
+                warnings.warn(
+                    f"Invalid backend '{value}'. Valid backends are: {VALID_BACKENDS}. 
" + f"Setting anyway, but plotting may fail.", + UserWarning, + stacklevel=2, + ) + super().__setitem__(key, value) + + def reset(self): + """Reset all configuration to default values.""" + self.clear() + self.update(self._defaults) + + +# Global config instance +mmm_plot_config = MMMPlotConfig() diff --git a/pymc_marketing/mmm/legacy_plot.py b/pymc_marketing/mmm/legacy_plot.py new file mode 100644 index 000000000..bcf6689d8 --- /dev/null +++ b/pymc_marketing/mmm/legacy_plot.py @@ -0,0 +1,1937 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MMM related plotting class. + +Examples +-------- +Quickstart with MMM: + +.. code-block:: python + + from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation + from pymc_marketing.mmm.multidimensional import MMM + import pandas as pd + + # Minimal dataset + X = pd.DataFrame( + { + "date": pd.date_range("2025-01-01", periods=12, freq="W-MON"), + "C1": [100, 120, 90, 110, 105, 115, 98, 102, 108, 111, 97, 109], + "C2": [80, 70, 95, 85, 90, 88, 92, 94, 91, 89, 93, 87], + } + ) + y = pd.Series( + [230, 260, 220, 240, 245, 255, 235, 238, 242, 246, 233, 249], name="y" + ) + + mmm = MMM( + date_column="date", + channel_columns=["C1", "C2"], + target_column="y", + adstock=GeometricAdstock(l_max=10), + saturation=LogisticSaturation(), + ) + mmm.fit(X, y) + mmm.sample_posterior_predictive(X) + + # Posterior predictive time series + _ = mmm.plot.posterior_predictive(var=["y"], hdi_prob=0.9) + + # Posterior contributions over time (e.g., channel_contribution) + _ = mmm.plot.contributions_over_time(var=["channel_contribution"], hdi_prob=0.9) + + # Channel saturation scatter plot (scaled space by default) + _ = mmm.plot.saturation_scatterplot(original_scale=False) + +Wrap a custom PyMC model +-------- + +Requirements + +- posterior_predictive plots: an `az.InferenceData` with a `posterior_predictive` group + containing the variable(s) you want to plot with a `date` coordinate. +- contributions_over_time plots: a `posterior` group with time‑series variables (with `date`). +- saturation plots: a `constant_data` dataset with variables: + - `channel_data`: dims include `("date", "channel", ...)` + - `channel_scale`: dims include `("channel", ...)` + - `target_scale`: scalar or broadcastable to the curve dims + and a `posterior` variable named `channel_contribution` (or + `channel_contribution_original_scale` if plotting `original_scale=True`). + +.. 
code-block:: python + + import numpy as np + import pandas as pd + import pymc as pm + from pymc_marketing.mmm.plot import MMMPlotSuite + + dates = pd.date_range("2025-01-01", periods=30, freq="D") + y_obs = np.random.normal(size=len(dates)) + + with pm.Model(coords={"date": dates}): + sigma = pm.HalfNormal("sigma", 1.0) + pm.Normal("y", 0.0, sigma, observed=y_obs, dims="date") + + idata = pm.sample_prior_predictive(random_seed=1) + idata.extend(pm.sample(draws=200, chains=2, tune=200, random_seed=1)) + idata.extend(pm.sample_posterior_predictive(idata, random_seed=1)) + + plot = MMMPlotSuite(idata) + _ = plot.posterior_predictive(var=["y"], hdi_prob=0.9) + +Custom contributions_over_time +-------- + +.. code-block:: python + + import numpy as np + import pandas as pd + import pymc as pm + from pymc_marketing.mmm.plot import MMMPlotSuite + + dates = pd.date_range("2025-01-01", periods=30, freq="D") + x = np.linspace(0, 2 * np.pi, len(dates)) + series = np.sin(x) + + with pm.Model(coords={"date": dates}): + pm.Deterministic("component", series, dims="date") + idata = pm.sample_prior_predictive(random_seed=2) + idata.extend(pm.sample(draws=50, chains=1, tune=0, random_seed=2)) + + plot = MMMPlotSuite(idata) + _ = plot.contributions_over_time(var=["component"], hdi_prob=0.9) + +Saturation plots with a custom model +-------- + +.. code-block:: python + + import numpy as np + import pandas as pd + import xarray as xr + import pymc as pm + from pymc_marketing.mmm.plot import MMMPlotSuite + + dates = pd.date_range("2025-01-01", periods=20, freq="W-MON") + channels = ["C1", "C2"] + + # Create constant_data required for saturation plots + channel_data = xr.DataArray( + np.random.rand(len(dates), len(channels)), + dims=("date", "channel"), + coords={"date": dates, "channel": channels}, + name="channel_data", + ) + channel_scale = xr.DataArray( + np.ones(len(channels)), + dims=("channel",), + coords={"channel": channels}, + name="channel_scale", + ) + target_scale = xr.DataArray(1.0, name="target_scale") + + # Build a toy model that yields a matching posterior var + with pm.Model(coords={"date": dates, "channel": channels}): + # A fake contribution over time per channel (dims must include date & channel) + contrib = pm.Normal("channel_contribution", 0.0, 1.0, dims=("date", "channel")) + + idata = pm.sample_prior_predictive(random_seed=3) + idata.extend(pm.sample(draws=50, chains=1, tune=0, random_seed=3)) + + # Attach constant_data to idata + idata.constant_data = xr.Dataset( + { + "channel_data": channel_data, + "channel_scale": channel_scale, + "target_scale": target_scale, + } + ) + + plot = MMMPlotSuite(idata) + _ = plot.saturation_scatterplot(original_scale=False) + +Notes +----- +- `MMM` exposes this suite via the `mmm.plot` property, which internally passes the model's + `idata` into `MMMPlotSuite`. +- Any PyMC model can use `MMMPlotSuite` directly if its `InferenceData` contains the needed + groups/variables described above. +""" + +import itertools +from collections.abc import Iterable +from typing import Any + +import arviz as az +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from numpy.typing import NDArray + +WIDTH_PER_COL: float = 10.0 +HEIGHT_PER_ROW: float = 4.0 + + +class LegacyMMMPlotSuite: + """Legacy matplotlib-based MMM plotting suite. + + .. deprecated:: 0.18.0 + This class will be removed in v0.20.0. 
Use MMMPlotSuite with + mmm_plot_config["plot.use_v2"] = True for the new arviz_plots-based suite. + + This class is maintained for backward compatibility but will be removed + in a future release. The new MMMPlotSuite supports multiple backends + (matplotlib, plotly, bokeh) and returns PlotCollection objects. + + Provides methods for visualizing the posterior predictive distribution, + contributions over time, and saturation curves for a Media Mix Model. + """ + + def __init__( + self, + idata: xr.Dataset | az.InferenceData, + ): + self.idata = idata + + def _init_subplots( + self, + n_subplots: int, + ncols: int = 1, + width_per_col: float = 10.0, + height_per_row: float = 4.0, + ) -> tuple[Figure, NDArray[Axes]]: + """Initialize a grid of subplots. + + Parameters + ---------- + n_subplots : int + Number of rows (if ncols=1) or total subplots. + ncols : int + Number of columns in the subplot grid. + width_per_col : float + Width (in inches) for each column of subplots. + height_per_row : float + Height (in inches) for each row of subplots. + + Returns + ------- + fig : matplotlib.figure.Figure + The created Figure object. + axes : np.ndarray of matplotlib.axes.Axes + 2D array of axes of shape (n_subplots, ncols). + """ + fig, axes = plt.subplots( + nrows=n_subplots, + ncols=ncols, + figsize=(width_per_col * ncols, height_per_row * n_subplots), + squeeze=False, + ) + return fig, axes + + def _build_subplot_title( + self, + dims: list[str], + combo: tuple, + fallback_title: str = "Time Series", + ) -> str: + """Build a subplot title string from dimension names and their values.""" + if dims: + title_parts = [f"{d}={v}" for d, v in zip(dims, combo, strict=False)] + return ", ".join(title_parts) + return fallback_title + + def _get_additional_dim_combinations( + self, + data: xr.Dataset, + variable: str, + ignored_dims: set[str], + ) -> tuple[list[str], list[tuple]]: + """Identify dimensions to plot over and get their coordinate combinations.""" + if variable not in data: + raise ValueError(f"Variable '{variable}' not found in the dataset.") + + all_dims = list(data[variable].dims) + additional_dims = [d for d in all_dims if d not in ignored_dims] + + if additional_dims: + additional_coords = [data.coords[d].values for d in additional_dims] + dim_combinations = list(itertools.product(*additional_coords)) + else: + # If no extra dims, just treat as a single combination + dim_combinations = [()] + + return additional_dims, dim_combinations + + def _reduce_and_stack( + self, data: xr.DataArray, dims_to_ignore: set[str] | None = None + ) -> xr.DataArray: + """Sum over leftover dims and stack chain+draw into sample if present.""" + if dims_to_ignore is None: + dims_to_ignore = {"date", "chain", "draw", "sample"} + + leftover_dims = [d for d in data.dims if d not in dims_to_ignore] + if leftover_dims: + data = data.sum(dim=leftover_dims) + + # Combine chain+draw into 'sample' if both exist + if "chain" in data.dims and "draw" in data.dims: + data = data.stack(sample=("chain", "draw")) + + return data + + def _get_posterior_predictive_data( + self, + idata: xr.Dataset | None, + ) -> xr.Dataset: + """Retrieve the posterior_predictive group from either provided or self.idata.""" + if idata is not None: + return idata + + # Otherwise, check if self.idata has posterior_predictive + if ( + not hasattr(self.idata, "posterior_predictive") # type: ignore + or self.idata.posterior_predictive is None # type: ignore + ): + raise ValueError( + "No posterior_predictive data found in 'self.idata'. 
" + "Please run 'MMM.sample_posterior_predictive()' or provide " + "an external 'idata' argument." + ) + return self.idata.posterior_predictive # type: ignore + + def _add_median_and_hdi( + self, ax: Axes, data: xr.DataArray, var: str, hdi_prob: float = 0.85 + ) -> Axes: + """Add median and HDI to the given axis.""" + median = data.median(dim="sample") if "sample" in data.dims else data.median() + hdi = az.hdi( + data, + hdi_prob=hdi_prob, + input_core_dims=[["sample"]] if "sample" in data.dims else None, + ) + + if "date" not in data.dims: + raise ValueError(f"Expected 'date' dimension in {var}, but none found.") + dates = data.coords["date"].values + # Add median and HDI to the plot + ax.plot(dates, median, label=var, alpha=0.9) + ax.fill_between(dates, hdi[var][..., 0], hdi[var][..., 1], alpha=0.2) + return ax + + def _validate_dims( + self, + dims: dict[str, str | int | list], + all_dims: list[str], + ) -> None: + """Validate that provided dims exist in the model's dimensions and values.""" + if dims: + for key, val in dims.items(): + if key not in all_dims: + raise ValueError( + f"Dimension '{key}' not found in idata dimensions." + ) + valid_values = self.idata.posterior.coords[key].values + if isinstance(val, (list, tuple, np.ndarray)): + for v in val: + if v not in valid_values: + raise ValueError( + f"Value '{v}' not found in dimension '{key}'." + ) + else: + if val not in valid_values: + raise ValueError( + f"Value '{val}' not found in dimension '{key}'." + ) + + def _dim_list_handler( + self, dims: dict[str, str | int | list] | None + ) -> tuple[list[str], list[tuple]]: + """Extract keys, values, and all combinations for list-valued dims.""" + dims_lists = { + k: v + for k, v in (dims or {}).items() + if isinstance(v, (list, tuple, np.ndarray)) + } + if dims_lists: + dims_keys = list(dims_lists.keys()) + dims_values = [ + v if isinstance(v, (list, tuple, np.ndarray)) else [v] + for v in dims_lists.values() + ] + dims_combos = list(itertools.product(*dims_values)) + else: + dims_keys = [] + dims_combos = [()] + return dims_keys, dims_combos + + def _resolve_backend(self, backend: str | None) -> str: + """Resolve backend parameter to actual backend string.""" + from pymc_marketing.mmm.config import mmm_plot_config + + return backend or mmm_plot_config["plot.backend"] + + # ------------------------------------------------------------------------ + # Main Plotting Methods + # ------------------------------------------------------------------------ + + def posterior_predictive( + self, + var: list[str] | None = None, + idata: xr.Dataset | None = None, + hdi_prob: float = 0.85, + ) -> tuple[Figure, NDArray[Axes]]: + """Plot time series from the posterior predictive distribution. + + By default, if both `var` and `idata` are not provided, uses + `self.idata.posterior_predictive` and defaults the variable to `["y"]`. + + Parameters + ---------- + var : list of str, optional + A list of variable names to plot. Default is ["y"] if not provided. + idata : xarray.Dataset, optional + The posterior predictive dataset to plot. If not provided, tries to + use `self.idata.posterior_predictive`. + hdi_prob: float, optional + The probability mass of the highest density interval to be displayed. Default is 0.85. + + Returns + ------- + fig : matplotlib.figure.Figure + The Figure object containing the subplots. + axes : np.ndarray of matplotlib.axes.Axes + Array of Axes objects corresponding to each subplot row. 
+
+        Raises
+        ------
+        ValueError
+            If no `idata` is provided and `self.idata.posterior_predictive` does
+            not exist, instructing the user to run `MMM.sample_posterior_predictive()`.
+            If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value.
+        """
+        if not 0 < hdi_prob < 1:
+            raise ValueError("HDI probability must be between 0 and 1.")
+        # 1. Retrieve or validate posterior_predictive data
+        pp_data = self._get_posterior_predictive_data(idata)
+
+        # 2. Determine variables to plot
+        if var is None:
+            var = ["y"]
+        main_var = var[0]
+
+        # 3. Identify additional dims & get all combos
+        ignored_dims = {"chain", "draw", "date", "sample"}
+        additional_dims, dim_combinations = self._get_additional_dim_combinations(
+            data=pp_data, variable=main_var, ignored_dims=ignored_dims
+        )
+
+        # 4. Prepare subplots
+        fig, axes = self._init_subplots(n_subplots=len(dim_combinations), ncols=1)
+
+        # 5. Loop over dimension combinations
+        for row_idx, combo in enumerate(dim_combinations):
+            ax = axes[row_idx][0]
+
+            # Build indexers
+            indexers = (
+                dict(zip(additional_dims, combo, strict=False))
+                if additional_dims
+                else {}
+            )
+
+            # 6. Plot each requested variable
+            for v in var:
+                if v not in pp_data:
+                    raise ValueError(
+                        f"Variable '{v}' not in the posterior_predictive dataset."
+                    )
+
+                data = pp_data[v].sel(**indexers)
+                # Sum leftover dims, stack chain+draw if needed
+                data = self._reduce_and_stack(data, ignored_dims)
+                ax = self._add_median_and_hdi(ax, data, v, hdi_prob=hdi_prob)
+
+            # 7. Subplot title & labels
+            title = self._build_subplot_title(
+                dims=additional_dims,
+                combo=combo,
+                fallback_title="Posterior Predictive Time Series",
+            )
+            ax.set_title(title)
+            ax.set_xlabel("Date")
+            ax.set_ylabel("Posterior Predictive")
+            ax.legend(loc="best")
+
+        return fig, axes
+
+    def contributions_over_time(
+        self,
+        var: list[str],
+        hdi_prob: float = 0.85,
+        dims: dict[str, str | int | list] | None = None,
+    ) -> tuple[Figure, NDArray[Axes]]:
+        """Plot the time-series contributions for each variable in `var`.
+
+        Shows the median and the credible interval (default 85%).
+        Creates one subplot per combination of non-(chain/draw/date) dimensions
+        and places all variables on the same subplot.
+
+        Parameters
+        ----------
+        var : list of str
+            A list of variable names to plot from the posterior.
+        hdi_prob: float, optional
+            The probability mass of the highest density interval to be displayed. Default is 0.85.
+        dims : dict[str, str | int | list], optional
+            Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}.
+            If provided, only the selected slice(s) will be plotted.
+
+        Returns
+        -------
+        fig : matplotlib.figure.Figure
+            The Figure object containing the subplots.
+        axes : np.ndarray of matplotlib.axes.Axes
+            Array of Axes objects corresponding to each subplot row.
+
+        Raises
+        ------
+        ValueError
+            If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value.
+        """
+        if not 0 < hdi_prob < 1:
+            raise ValueError("HDI probability must be between 0 and 1.")
+
+        if not hasattr(self.idata, "posterior"):
+            raise ValueError(
+                "No posterior data found in 'self.idata'. "
+                "Please ensure 'self.idata' contains a 'posterior' group."
+ ) + + main_var = var[0] + all_dims = list(self.idata.posterior[main_var].dims) # type: ignore + ignored_dims = {"chain", "draw", "date"} + additional_dims = [d for d in all_dims if d not in ignored_dims] + + coords = { + key: value.to_numpy() + for key, value in self.idata.posterior[var].coords.items() + } + + # Apply user-specified filters (`dims`) + if dims: + self._validate_dims(dims=dims, all_dims=all_dims) + # Remove filtered dims from the combinations + additional_dims = [d for d in additional_dims if d not in dims] + else: + self._validate_dims({}, all_dims) + # additional_dims = [d for d in additional_dims if d not in dims] + + # Identify combos for remaining dims + if additional_dims: + additional_coords = [ + self.idata.posterior.coords[dim].values # type: ignore + for dim in additional_dims + ] + dim_combinations = list(itertools.product(*additional_coords)) + else: + dim_combinations = [()] + + # If dims contains lists, build all combinations for those as well + dims_keys, dims_combos = self._dim_list_handler(dims) + + # Prepare subplots: one for each combo of dims_lists and additional_dims + total_combos = list(itertools.product(dims_combos, dim_combinations)) + fig, axes = self._init_subplots(len(total_combos), ncols=1) + + for row_idx, (dims_combo, addl_combo) in enumerate(total_combos): + ax = axes[row_idx][0] + # Build indexers for dims and additional_dims + indexers = ( + dict(zip(additional_dims, addl_combo, strict=False)) + if additional_dims + else {} + ) + if dims: + # For dims with lists, use the current value from dims_combo + for i, k in enumerate(dims_keys): + indexers[k] = dims_combo[i] + # For dims with single values, use as is + for k, v in (dims or {}).items(): + if k not in dims_keys: + indexers[k] = v + + # Plot posterior median and HDI for each var + for v in var: + data = self.idata.posterior[v] + missing_coords = { + key: value for key, value in coords.items() if key not in data.dims + } + data = data.expand_dims(**missing_coords) + data = data.sel(**indexers) # apply slice + data = self._reduce_and_stack( + data, dims_to_ignore={"date", "chain", "draw", "sample"} + ) + ax = self._add_median_and_hdi(ax, data, v, hdi_prob=hdi_prob) + + # Title includes both fixed and combo dims + title_dims = ( + list(dims.keys()) + additional_dims if dims else additional_dims + ) + title_combo = tuple(indexers[k] for k in title_dims) + + title = self._build_subplot_title( + dims=title_dims, combo=title_combo, fallback_title="Time Series" + ) + ax.set_title(title) + ax.set_xlabel("Date") + ax.set_ylabel("Posterior Value") + ax.legend(loc="best") + + return fig, axes + + def saturation_scatterplot( + self, + original_scale: bool = False, + dims: dict[str, str | int | list] | None = None, + **kwargs, + ) -> tuple[Figure, NDArray[Axes]]: + """Plot the saturation curves for each channel. + + Creates a grid of subplots for each combination of channel and non-(date/channel) dimensions. + Optionally, subset by dims (single values or lists). + Each channel will have a consistent color across all subplots. + """ + if not hasattr(self.idata, "constant_data"): + raise ValueError( + "No 'constant_data' found in 'self.idata'. " + "Please ensure 'self.idata' contains the constant_data group." 
+            )
+
+        # Identify additional dimensions beyond 'date' and 'channel'
+        cdims = self.idata.constant_data.channel_data.dims
+        additional_dims = [dim for dim in cdims if dim not in ("date", "channel")]
+
+        # Validate dims and remove filtered dims from additional_dims
+        if dims:
+            self._validate_dims(dims, list(self.idata.constant_data.channel_data.dims))
+            additional_dims = [d for d in additional_dims if d not in dims]
+        else:
+            self._validate_dims({}, list(self.idata.constant_data.channel_data.dims))
+
+        # Build all combinations for dims with lists
+        dims_keys, dims_combos = self._dim_list_handler(dims)
+
+        # Build all combinations for remaining dims
+        if additional_dims:
+            additional_coords = [
+                self.idata.constant_data.coords[d].values for d in additional_dims
+            ]
+            additional_combinations = list(itertools.product(*additional_coords))
+        else:
+            additional_combinations = [()]
+
+        channels = self.idata.constant_data.coords["channel"].values
+        n_channels = len(channels)
+        n_addl = len(additional_combinations)
+        n_dims = len(dims_combos)
+
+        # For most use cases, n_dims will be 1, so grid is channels x additional_combinations
+        # If dims_combos > 1, treat as extra axis (rare, but possible)
+        nrows = n_channels
+        ncols = n_addl * n_dims
+        total_combos = list(
+            itertools.product(channels, dims_combos, additional_combinations)
+        )
+        n_subplots = len(total_combos)
+
+        # Assign a color to each channel
+        channel_colors = {ch: f"C{i}" for i, ch in enumerate(channels)}
+
+        # Prepare subplots as a grid
+        fig, axes = plt.subplots(
+            nrows=nrows,
+            ncols=ncols,
+            figsize=(
+                kwargs.get("width_per_col", 8) * ncols,
+                kwargs.get("height_per_row", 4) * nrows,
+            ),
+            squeeze=False,
+        )
+
+        channel_contribution = (
+            "channel_contribution_original_scale"
+            if original_scale
+            else "channel_contribution"
+        )
+
+        if original_scale and not hasattr(self.idata.posterior, channel_contribution):
+            raise ValueError(
+                f"No posterior.{channel_contribution} data found in 'self.idata'.\n"
+                "Add an original scale deterministic:\n"
+                "mmm.add_original_scale_contribution_variable(\n"
+                "    var=[\n"
+                '        "channel_contribution",\n'
+                "        ...\n"
+                "    ]\n"
+                ")"
+            )
+
+        for _idx, (channel, dims_combo, addl_combo) in enumerate(total_combos):
+            # Compute subplot position
+            row = list(channels).index(channel)
+            # If dims_combos > 1, treat as extra axis (columns: addl * dims)
+            if n_dims > 1:
+                col = list(additional_combinations).index(addl_combo) * n_dims + list(
+                    dims_combos
+                ).index(dims_combo)
+            else:
+                col = list(additional_combinations).index(addl_combo)
+            ax = axes[row][col]
+
+            # Build indexers for dims and additional_dims
+            indexers = (
+                dict(zip(additional_dims, addl_combo, strict=False))
+                if additional_dims
+                else {}
+            )
+            if dims:
+                for i, k in enumerate(dims_keys):
+                    indexers[k] = dims_combo[i]
+                for k, v in (dims or {}).items():
+                    if k not in dims_keys:
+                        indexers[k] = v
+            indexers["channel"] = channel
+
+            # Select X data (constant_data)
+            x_data = self.idata.constant_data.channel_data.sel(**indexers)
+            # Select Y data (posterior contributions) and scale if needed
+            y_data = self.idata.posterior[channel_contribution].sel(**indexers)
+            y_data = y_data.mean(dim=[d for d in y_data.dims if d in ("chain", "draw")])
+            x_data = x_data.broadcast_like(y_data)
+            y_data = y_data.broadcast_like(x_data)
+            ax.scatter(
+                x_data.values.flatten(),
+                y_data.values.flatten(),
+                alpha=0.8,
+                color=channel_colors[channel],
+                label=str(channel),
+            )
+            # Build subplot title
+            title_dims = (
+                ["channel"] + (list(dims.keys()) if dims else []) + additional_dims
+            )
+            title_combo = (
+                channel,
+                *[indexers[k] for k in title_dims if k != "channel"],
+            )
+            title = self._build_subplot_title(
+                dims=title_dims,
+                combo=title_combo,
+                fallback_title="Channel Saturation Curve",
+            )
+            ax.set_title(title)
+            ax.set_xlabel("Channel Data (X)")
+            ax.set_ylabel("Channel Contributions (Y)")
+            ax.legend(loc="best")
+
+        # Hide any unused axes (if grid is larger than needed)
+        for i in range(nrows):
+            for j in range(ncols):
+                if i * ncols + j >= n_subplots:
+                    axes[i][j].set_visible(False)
+
+        return fig, axes
+
+    def saturation_curves(
+        self,
+        curve: xr.DataArray,
+        original_scale: bool = False,
+        n_samples: int = 10,
+        hdi_probs: float | list[float] | None = None,
+        random_seed: np.random.Generator | None = None,
+        colors: Iterable[str] | None = None,
+        subplot_kwargs: dict | None = None,
+        rc_params: dict | None = None,
+        dims: dict[str, str | int | list] | None = None,
+        **plot_kwargs,
+    ) -> tuple[plt.Figure, np.ndarray]:
+        """Overlay saturation‑curve scatter‑plots with posterior‑predictive sample curves and HDI bands.
+
+        Allows you to customize figsize and font sizes.
+
+        Parameters
+        ----------
+        curve : xr.DataArray
+            Posterior‑predictive curves (e.g. dims `("chain","draw","x","channel","geo")`).
+        original_scale : bool, default=False
+            Plot `channel_contribution_original_scale` if True, else `channel_contribution`.
+        n_samples : int, default=10
+            Number of sample‑curves per subplot.
+        hdi_probs : float or list of float, optional
+            Credible interval probabilities (e.g. 0.94 or [0.5, 0.94]).
+            If None, uses ArviZ's default (0.94).
+        random_seed : np.random.Generator, optional
+            RNG for reproducible sampling. If None, uses `np.random.default_rng()`.
+        colors : iterable of str, optional
+            Colors for the sample & HDI plots.
+        subplot_kwargs : dict, optional
+            Passed to `plt.subplots` (e.g. `{"figsize": (10,8)}`).
+            Merged with the function's own default sizing.
+        rc_params : dict, optional
+            Temporary `matplotlib.rcParams` for this plot.
+            Example keys: `"xtick.labelsize"`, `"ytick.labelsize"`,
+            `"axes.labelsize"`, `"axes.titlesize"`.
+        dims : dict[str, str | int | list], optional
+            Dimension filters to apply. Example: {"country": ["US", "UK"], "region": "X"}.
+            If provided, only the selected slice(s) will be plotted.
+        **plot_kwargs
+            Any other kwargs forwarded to `plot_curve`
+            (for instance `same_axes=True`, `legend=True`, etc.).
+
+        Returns
+        -------
+        fig : plt.Figure
+            Matplotlib figure with your grid.
+        axes : np.ndarray of plt.Axes
+            Array of shape `(n_channels, n_geo)`.
+        """
+        from pymc_marketing.plot import plot_hdi, plot_samples
+
+        if not hasattr(self.idata, "constant_data"):
+            raise ValueError(
+                "No 'constant_data' found in 'self.idata'. "
+                "Please ensure 'self.idata' contains the constant_data group."
+            )
+
+        contrib_var = (
+            "channel_contribution_original_scale"
+            if original_scale
+            else "channel_contribution"
+        )
+
+        if original_scale and not hasattr(self.idata.posterior, contrib_var):
+            raise ValueError(
+                f"No posterior.{contrib_var} data found in 'self.idata'.\n"
+                "Add an original scale deterministic:\n"
+                "mmm.add_original_scale_contribution_variable(\n"
+                "    var=[\n"
+                "        'channel_contribution',\n"
+                "        ...\n"
+                "    ]\n"
+                ")"
+            )
+        curve_data = (
+            curve * self.idata.constant_data.target_scale if original_scale else curve
+        )
+        curve_data = curve_data.rename("saturation_curve")
+
+        # — 1. figure out grid shape based on scatter data dimensions / identify dims and combos
+        cdims = self.idata.constant_data.channel_data.dims
+        all_dims = list(cdims)
+        additional_dims = [d for d in cdims if d not in ("date", "channel")]
+        # Validate dims and remove filtered dims from additional_dims
+        if dims:
+            self._validate_dims(dims, all_dims)
+            additional_dims = [d for d in additional_dims if d not in dims]
+        else:
+            self._validate_dims({}, all_dims)
+        # Build all combinations for dims with lists
+        dims_keys, dims_combos = self._dim_list_handler(dims)
+        # Build all combinations for remaining dims
+        if additional_dims:
+            additional_coords = [
+                self.idata.constant_data.coords[d].values for d in additional_dims
+            ]
+            additional_combinations = list(itertools.product(*additional_coords))
+        else:
+            additional_combinations = [()]
+        channels = self.idata.constant_data.coords["channel"].values
+        n_channels = len(channels)
+        n_addl = len(additional_combinations)
+        n_dims = len(dims_combos)
+        nrows = n_channels
+        ncols = n_addl * n_dims
+        total_combos = list(
+            itertools.product(channels, dims_combos, additional_combinations)
+        )
+        n_subplots = len(total_combos)
+
+        # — 2. merge subplot_kwargs —
+        user_subplot = subplot_kwargs or {}
+
+        # Handle user-specified ncols/nrows
+        if "ncols" in user_subplot:
+            # User specified ncols, calculate nrows
+            ncols = user_subplot["ncols"]
+            nrows = int(np.ceil(n_subplots / ncols))
+            user_subplot.pop("ncols")  # Remove to avoid conflict
+        elif "nrows" in user_subplot:
+            # User specified nrows, calculate ncols
+            nrows = user_subplot["nrows"]
+            ncols = int(np.ceil(n_subplots / nrows))
+            user_subplot.pop("nrows")  # Remove to avoid conflict
+        default_subplot = {"figsize": (ncols * 4, nrows * 3)}
+        subkw = {**default_subplot, **user_subplot}
+        # — 3. 
create subplots ourselves — + rc_params = rc_params or {} + with plt.rc_context(rc_params): + fig, axes = plt.subplots(nrows=nrows, ncols=ncols, **subkw) + # ensure a 2D array + if nrows == 1 and ncols == 1: + axes = np.array([[axes]]) + elif nrows == 1: + axes = axes.reshape(1, -1) + elif ncols == 1: + axes = axes.reshape(-1, 1) + # Flatten axes for easier iteration + axes_flat = axes.flatten() + if colors is None: + colors = [f"C{i}" for i in range(n_channels)] + elif not isinstance(colors, list): + colors = list(colors) + subplot_idx = 0 + for _idx, (ch, dims_combo, addl_combo) in enumerate(total_combos): + if subplot_idx >= len(axes_flat): + break + ax = axes_flat[subplot_idx] + subplot_idx += 1 + # Build indexers for dims and additional_dims + indexers = ( + dict(zip(additional_dims, addl_combo, strict=False)) + if additional_dims + else {} + ) + if dims: + for i, k in enumerate(dims_keys): + indexers[k] = dims_combo[i] + for k, v in (dims or {}).items(): + if k not in dims_keys: + indexers[k] = v + indexers["channel"] = ch + # Select and broadcast curve data for this channel + curve_idx = { + dim: val for dim, val in indexers.items() if dim in curve_data.dims + } + subplot_curve = curve_data.sel(**curve_idx) + if original_scale: + valid_idx = { + k: v + for k, v in indexers.items() + if k in self.idata.constant_data.channel_scale.dims + } + channel_scale = self.idata.constant_data.channel_scale.sel(**valid_idx) + x_original = subplot_curve.coords["x"] * channel_scale + subplot_curve = subplot_curve.assign_coords(x=x_original) + if n_samples > 0: + plot_samples( + subplot_curve, + non_grid_names="x", + n=n_samples, + rng=random_seed, + axes=np.array([[ax]]), + colors=[colors[list(channels).index(ch)]], + same_axes=False, + legend=False, + **plot_kwargs, + ) + if hdi_probs is not None: + # Robustly handle hdi_probs as float, list, tuple, or np.ndarray + if isinstance(hdi_probs, (float, int)): + hdi_probs_iter = [hdi_probs] + elif isinstance(hdi_probs, (list, tuple, np.ndarray)): + hdi_probs_iter = hdi_probs + else: + raise TypeError( + "hdi_probs must be a float, list, tuple, or np.ndarray" + ) + for hdi_prob in hdi_probs_iter: + plot_hdi( + subplot_curve, + non_grid_names="x", + hdi_prob=hdi_prob, + axes=np.array([[ax]]), + colors=[colors[list(channels).index(ch)]], + same_axes=False, + legend=False, + **plot_kwargs, + ) + x_data = self.idata.constant_data.channel_data.sel(**indexers) + y = ( + self.idata.posterior[contrib_var] + .sel(**indexers) + .mean( + dim=[ + d + for d in self.idata.posterior[contrib_var].dims + if d in ("chain", "draw") + ] + ) + ) + x_data, y = x_data.broadcast_like(y), y.broadcast_like(x_data) + ax.scatter( + x_data.values.flatten(), + y.values.flatten(), + alpha=0.8, + color=colors[list(channels).index(ch)], + ) + title_dims = ( + ["channel"] + (list(dims.keys()) if dims else []) + additional_dims + ) + title_combo = ( + ch, + *[indexers[k] for k in title_dims if k != "channel"], + ) + title = self._build_subplot_title( + dims=title_dims, + combo=title_combo, + fallback_title="Channel Saturation Curves", + ) + ax.set_title(title) + ax.set_xlabel("Channel Data (X)") + ax.set_ylabel("Channel Contribution (Y)") + for ax_idx in range(subplot_idx, len(axes_flat)): + axes_flat[ax_idx].set_visible(False) + return fig, axes + + def saturation_curves_scatter( + self, original_scale: bool = False, **kwargs + ) -> tuple[Figure, NDArray[Axes]]: + """ + Plot scatter plots of channel contributions vs. channel data. + + .. 
deprecated:: 0.1.0
+            Will be removed in version 0.2.0. Use :meth:`saturation_scatterplot` instead.
+
+        Parameters
+        ----------
+        original_scale : bool, optional
+            Whether to plot the contributions in the original scale of the target.
+        **kwargs
+            Additional keyword arguments forwarded to :meth:`saturation_scatterplot`
+            (e.g. ``width_per_col``, ``height_per_row``).
+
+        Returns
+        -------
+        fig : plt.Figure
+            The matplotlib figure.
+        axes : np.ndarray
+            Array of matplotlib axes.
+        """
+        import warnings
+
+        warnings.warn(
+            "saturation_curves_scatter is deprecated and will be removed in version 0.2.0. "
+            "Use saturation_scatterplot instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        # Note: the legacy channel_contribution, additional_dims, and
+        # additional_combinations arguments are not used by saturation_scatterplot,
+        # so we don't pass them
+        return self.saturation_scatterplot(original_scale=original_scale, **kwargs)
+
+    def budget_allocation(
+        self,
+        samples: xr.Dataset,
+        scale_factor: float | None = None,
+        figsize: tuple[float, float] = (12, 6),
+        ax: plt.Axes | None = None,
+        original_scale: bool = True,
+        dims: dict[str, str | int | list] | None = None,
+    ) -> tuple[Figure, plt.Axes] | tuple[Figure, np.ndarray]:
+        """Plot the budget allocation and channel contributions.
+
+        Creates a bar chart comparing allocated spend and channel contributions
+        for each channel. If additional dimensions besides 'channel' are present,
+        creates a subplot for each combination of these dimensions.
+
+        Parameters
+        ----------
+        samples : xr.Dataset
+            The dataset containing the channel contributions and allocation values.
+            Expected to have 'channel_contribution' and 'allocation' variables.
+        scale_factor : float, optional
+            Scale factor to convert to original scale, if original_scale=True.
+            If None and original_scale=True, assumes scale_factor=1.
+        figsize : tuple[float, float], optional
+            The size of the figure to be created. Default is (12, 6).
+        ax : plt.Axes, optional
+            The axis to plot on. If None, a new figure and axis will be created.
+            Only used when no extra dimensions are present.
+        original_scale : bool, optional
+            A boolean flag to determine if the values should be plotted in their
+            original scale. Default is True.
+        dims : dict[str, str | int | list], optional
+            Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}.
+            If provided, only the selected slice(s) will be plotted.
+
+        Returns
+        -------
+        fig : matplotlib.figure.Figure
+            The Figure object containing the plot.
+        axes : matplotlib.axes.Axes or numpy.ndarray of matplotlib.axes.Axes
+            The Axes object with the plot, or array of Axes for multiple subplots.
+        """
+        # Get the channels from samples
+        if "channel" not in samples.dims:
+            raise ValueError(
+                "Expected 'channel' dimension in samples dataset, but none found."
+            )
+
+        # Check for required variables in samples
+        if not any(
+            "channel_contribution" in var_name for var_name in samples.data_vars
+        ):
+            raise ValueError(
+                "Expected a variable containing 'channel_contribution' in samples, but none found."
+            )
+        if "allocation" not in samples:
+            raise ValueError(
+                "Expected 'allocation' variable in samples, but none found."
+ ) + + # Find the variable containing 'channel_contribution' in its name + channel_contrib_var = next( + var_name + for var_name in samples.data_vars + if "channel_contribution" in var_name + ) + + all_dims = list(samples.dims) + # Validate dims + if dims: + self._validate_dims(dims=dims, all_dims=all_dims) + else: + self._validate_dims({}, all_dims) + + # Handle list-valued dims: build all combinations + dims_keys, dims_combos = self._dim_list_handler(dims) + + # After filtering with dims, only use extra dims not in dims and not ignored for subplotting + ignored_dims = {"channel", "date", "sample", "chain", "draw"} + channel_contribution_dims = list(samples[channel_contrib_var].dims) + extra_dims = [ + d + for d in channel_contribution_dims + if d not in ignored_dims and d not in (dims or {}) + ] + + # Identify combos for remaining dims + if extra_dims: + extra_coords = [samples.coords[dim].values for dim in extra_dims] + extra_combos = list(itertools.product(*extra_coords)) + else: + extra_combos = [()] + + # Prepare subplots: one for each combo of dims_lists and extra_dims + total_combos = list(itertools.product(dims_combos, extra_combos)) + n_subplots = len(total_combos) + if n_subplots == 1 and ax is not None: + axes = np.array([[ax]]) + fig = ax.get_figure() + else: + fig, axes = self._init_subplots( + n_subplots=n_subplots, + ncols=1, + width_per_col=figsize[0], + height_per_row=figsize[1], + ) + + for row_idx, (dims_combo, extra_combo) in enumerate(total_combos): + ax_ = axes[row_idx][0] + # Build indexers for dims and extra_dims + indexers = ( + dict(zip(extra_dims, extra_combo, strict=False)) if extra_dims else {} + ) + if dims: + # For dims with lists, use the current value from dims_combo + for i, k in enumerate(dims_keys): + indexers[k] = dims_combo[i] + # For dims with single values, use as is + for k, v in (dims or {}).items(): + if k not in dims_keys: + indexers[k] = v + + # Select channel contributions for this subplot + channel_contrib_data = samples[channel_contrib_var].sel(**indexers) + allocation_data = samples.allocation + # Only select dims that exist in allocation + allocation_indexers = { + k: v for k, v in indexers.items() if k in allocation_data.dims + } + allocation_data = allocation_data.sel(**allocation_indexers) + + # Average over all dims except channel (and those used for this subplot) + used_dims = set(indexers.keys()) | {"channel"} + reduction_dims = [ + dim for dim in channel_contrib_data.dims if dim not in used_dims + ] + channel_contribution = channel_contrib_data.mean( + dim=reduction_dims + ).to_numpy() + if channel_contribution.ndim > 1: + channel_contribution = channel_contribution.flatten() + if original_scale and scale_factor is not None: + channel_contribution *= scale_factor + + allocation_used_dims = set(allocation_indexers.keys()) | {"channel"} + allocation_reduction_dims = [ + dim for dim in allocation_data.dims if dim not in allocation_used_dims + ] + if allocation_reduction_dims: + allocated_spend = allocation_data.mean( + dim=allocation_reduction_dims + ).to_numpy() + else: + allocated_spend = allocation_data.to_numpy() + if allocated_spend.ndim > 1: + allocated_spend = allocated_spend.flatten() + + self._plot_budget_allocation_bars( + ax_, + samples.coords["channel"].values, + allocated_spend, + channel_contribution, + ) + + # Build subplot title + title_dims = (list(dims.keys()) if dims else []) + extra_dims + title_combo = tuple(indexers[k] for k in title_dims) + title = self._build_subplot_title( + dims=title_dims, + 
combo=title_combo,
+                fallback_title="Budget Allocation",
+            )
+            ax_.set_title(title)
+
+        fig.tight_layout()
+        # Parenthesize so a single-subplot call returns (fig, ax), not a nested tuple
+        return (fig, axes) if n_subplots > 1 else (fig, axes[0][0])
+
+    def _plot_budget_allocation_bars(
+        self,
+        ax: plt.Axes,
+        channels: NDArray,
+        allocated_spend: NDArray,
+        channel_contribution: NDArray,
+    ) -> None:
+        """Plot budget allocation bars on a given axis.
+
+        Parameters
+        ----------
+        ax : plt.Axes
+            The axis to plot on.
+        channels : NDArray
+            Array of channel names.
+        allocated_spend : NDArray
+            Array of allocated spend values.
+        channel_contribution : NDArray
+            Array of channel contribution values.
+        """
+        bar_width = 0.35
+        opacity = 0.7
+        index = range(len(channels))
+
+        # Plot allocated spend
+        bars1 = ax.bar(
+            index,
+            allocated_spend,
+            bar_width,
+            color="C0",
+            alpha=opacity,
+            label="Allocated Spend",
+        )
+
+        # Create twin axis for contributions
+        ax2 = ax.twinx()
+
+        # Plot contributions
+        bars2 = ax2.bar(
+            [i + bar_width for i in index],
+            channel_contribution,
+            bar_width,
+            color="C1",
+            alpha=opacity,
+            label="Channel Contribution",
+        )
+
+        # Labels and formatting
+        ax.set_xlabel("Channels")
+        ax.set_ylabel("Allocated Spend", color="C0", labelpad=10)
+        ax2.set_ylabel("Channel Contributions", color="C1", labelpad=10)
+
+        # Set x-ticks in the middle of the bars
+        ax.set_xticks([i + bar_width / 2 for i in index])
+        ax.set_xticklabels(channels)
+        ax.tick_params(axis="x", rotation=90)
+
+        # Turn off grid and add legend
+        ax.grid(False)
+        ax2.grid(False)
+
+        bars = [bars1, bars2]
+        labels = ["Allocated Spend", "Channel Contributions"]
+        ax.legend(bars, labels, loc="best")
+
+    def allocated_contribution_by_channel_over_time(
+        self,
+        samples: xr.Dataset,
+        scale_factor: float | None = None,
+        lower_quantile: float = 0.025,
+        upper_quantile: float = 0.975,
+        original_scale: bool = True,
+        figsize: tuple[float, float] = (10, 6),
+        ax: plt.Axes | None = None,
+    ) -> tuple[Figure, plt.Axes | NDArray[Axes]]:
+        """Plot the allocated contribution by channel with uncertainty intervals.
+
+        This function visualizes the mean allocated contributions by channel along with
+        the uncertainty intervals defined by the lower and upper quantiles.
+        If additional dimensions besides 'channel', 'date', and 'sample' are present,
+        creates a subplot for each combination of these dimensions.
+
+        Parameters
+        ----------
+        samples : xr.Dataset
+            The dataset containing the samples of channel contributions.
+            Expected to have 'channel_contribution' variable with dimensions
+            'channel', 'date', and 'sample'.
+        scale_factor : float, optional
+            Scale factor to convert to original scale, if original_scale=True.
+            If None and original_scale=True, assumes scale_factor=1.
+        lower_quantile : float, optional
+            The lower quantile for the uncertainty interval. Default is 0.025.
+        upper_quantile : float, optional
+            The upper quantile for the uncertainty interval. Default is 0.975.
+        original_scale : bool, optional
+            If True, the contributions are plotted on the original scale. Default is True.
+        figsize : tuple[float, float], optional
+            The size of the figure to be created. Default is (10, 6).
+        ax : plt.Axes, optional
+            The axis to plot on. If None, a new figure and axis will be created.
+            Only used when no extra dimensions are present.
+
+        Returns
+        -------
+        fig : matplotlib.figure.Figure
+            The Figure object containing the plot.
+        axes : matplotlib.axes.Axes or numpy.ndarray of matplotlib.axes.Axes
+            The Axes object with the plot, or array of Axes for multiple subplots.
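+
+        Examples
+        --------
+        A minimal sketch, assuming ``samples`` is an ``xr.Dataset`` holding a
+        ``channel_contribution`` variable with ``sample``, ``date``, and
+        ``channel`` dims (e.g. the output of a budget-allocation helper):
+
+        .. code-block:: python
+
+            fig, ax = mmm.plot.allocated_contribution_by_channel_over_time(
+                samples=samples,
+                lower_quantile=0.05,
+                upper_quantile=0.95,
+            )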
+ """ + # Check for expected dimensions and variables + if "channel" not in samples.dims: + raise ValueError( + "Expected 'channel' dimension in samples dataset, but none found." + ) + if "date" not in samples.dims: + raise ValueError( + "Expected 'date' dimension in samples dataset, but none found." + ) + if "sample" not in samples.dims: + raise ValueError( + "Expected 'sample' dimension in samples dataset, but none found." + ) + # Check if any variable contains channel contributions + if not any( + "channel_contribution" in var_name for var_name in samples.data_vars + ): + raise ValueError( + "Expected a variable containing 'channel_contribution' in samples, but none found." + ) + + # Get channel contributions data + channel_contrib_var = next( + var_name + for var_name in samples.data_vars + if "channel_contribution" in var_name + ) + + # Identify extra dimensions beyond 'channel', 'date', and 'sample' + all_dims = list(samples[channel_contrib_var].dims) + ignored_dims = {"channel", "date", "sample"} + extra_dims = [dim for dim in all_dims if dim not in ignored_dims] + + # If no extra dimensions or using provided axis, create a single plot + if not extra_dims or ax is not None: + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + else: + fig = ax.get_figure() + + channel_contribution = samples[channel_contrib_var] + + # Apply scale factor if in original scale + if original_scale and scale_factor is not None: + channel_contribution = channel_contribution * scale_factor + + # Plot mean values by channel + channel_contribution.mean(dim="sample").plot(hue="channel", ax=ax) + + # Add uncertainty intervals for each channel + for channel in samples.coords["channel"].values: + ax.fill_between( + x=channel_contribution.date.values, + y1=channel_contribution.sel(channel=channel).quantile( + lower_quantile, dim="sample" + ), + y2=channel_contribution.sel(channel=channel).quantile( + upper_quantile, dim="sample" + ), + alpha=0.1, + ) + + ax.set_xlabel("Date") + ax.set_ylabel("Channel Contribution") + ax.set_title("Allocated Contribution by Channel Over Time") + + fig.tight_layout() + return fig, ax + + # For multiple dimensions, create a grid of subplots + # Determine layout based on number of extra dimensions + if len(extra_dims) == 1: + # One extra dimension: use for rows + dim_values = [samples.coords[extra_dims[0]].values] + nrows = len(dim_values[0]) + ncols = 1 + subplot_dims = [extra_dims[0], None] + elif len(extra_dims) == 2: + # Two extra dimensions: one for rows, one for columns + dim_values = [ + samples.coords[extra_dims[0]].values, + samples.coords[extra_dims[1]].values, + ] + nrows = len(dim_values[0]) + ncols = len(dim_values[1]) + subplot_dims = extra_dims + else: + # Three or more: use first two for rows/columns, average over the rest + dim_values = [ + samples.coords[extra_dims[0]].values, + samples.coords[extra_dims[1]].values, + ] + nrows = len(dim_values[0]) + ncols = len(dim_values[1]) + subplot_dims = [extra_dims[0], extra_dims[1]] + + # Calculate figure size based on number of subplots + subplot_figsize = (figsize[0] * max(1, ncols), figsize[1] * max(1, nrows)) + fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=subplot_figsize) + + # Make axes indexable even for 1x1 grid + if nrows == 1 and ncols == 1: + axes = np.array([[axes]]) + elif nrows == 1: + axes = axes.reshape(1, -1) + elif ncols == 1: + axes = axes.reshape(-1, 1) + + # Create a subplot for each combination of dimension values + for i, row_val in enumerate(dim_values[0]): + for j, col_val in 
+            for j, col_val in enumerate(
+                dim_values[1] if len(dim_values) > 1 else [None]
+            ):
+                ax = axes[i, j]
+
+                # Select data for this subplot
+                selection = {subplot_dims[0]: row_val}
+                if col_val is not None:
+                    selection[subplot_dims[1]] = col_val
+
+                # Select channel contributions for this subplot
+                subset = samples[channel_contrib_var].sel(**selection)
+
+                # Average over any extra dims beyond the two used for the grid
+                leftover_dims = [
+                    d for d in subset.dims if d not in ("channel", "date", "sample")
+                ]
+                if leftover_dims:
+                    subset = subset.mean(dim=leftover_dims)
+
+                # Apply scale factor if needed
+                if original_scale and scale_factor is not None:
+                    subset = subset * scale_factor
+
+                # Plot mean values by channel for this subset
+                subset.mean(dim="sample").plot(hue="channel", ax=ax)
+
+                # Add uncertainty intervals for each channel
+                for channel in samples.coords["channel"].values:
+                    channel_data = subset.sel(channel=channel)
+                    ax.fill_between(
+                        x=channel_data.date.values,
+                        y1=channel_data.quantile(lower_quantile, dim="sample"),
+                        y2=channel_data.quantile(upper_quantile, dim="sample"),
+                        alpha=0.1,
+                    )
+
+                # Add subplot title based on dimension values
+                title_parts = []
+                if subplot_dims[0] is not None:
+                    title_parts.append(f"{subplot_dims[0]}={row_val}")
+                if subplot_dims[1] is not None:
+                    title_parts.append(f"{subplot_dims[1]}={col_val}")
+
+                base_title = "Allocated Contribution by Channel Over Time"
+                if title_parts:
+                    ax.set_title(f"{base_title} - {', '.join(title_parts)}")
+                else:
+                    ax.set_title(base_title)
+
+                ax.set_xlabel("Date")
+                ax.set_ylabel("Channel Contribution")
+
+        fig.tight_layout()
+        return fig, axes
+
+    def sensitivity_analysis(
+        self,
+        hdi_prob: float = 0.94,
+        ax: plt.Axes | None = None,
+        aggregation: dict[str, tuple[str, ...] | list[str]] | None = None,
+        subplot_kwargs: dict[str, Any] | None = None,
+        *,
+        plot_kwargs: dict[str, Any] | None = None,
+        ylabel: str = "Effect",
+        xlabel: str = "Sweep",
+        title: str | None = None,
+        add_figure_title: bool = False,
+        subplot_title_fallback: str = "Sensitivity Analysis",
+    ) -> tuple[Figure, NDArray[Axes]] | plt.Axes:
+        """Plot sensitivity analysis results.
+
+        Parameters
+        ----------
+        hdi_prob : float, default 0.94
+            HDI probability mass.
+        ax : plt.Axes, optional
+            The axis to plot on.
+        aggregation : dict, optional
+            Aggregation to apply to the data before plotting. Keys are the
+            operation ("sum", "mean", or "median"); values are the dimensions
+            to reduce, e.g. {"sum": ("channel",)} to sum over the channel
+            dimension.
+
+        Other Parameters
+        ----------------
+        plot_kwargs : dict, optional
+            Keyword arguments forwarded to the underlying line plot. Defaults include
+            ``{"color": "C0"}``.
+        ylabel : str, optional
+            Y-axis label. Defaults to "Effect".
+        xlabel : str, optional
+            X-axis label. Defaults to "Sweep".
+        title : str, optional
+            Figure-level title to add when ``add_figure_title=True``.
+        add_figure_title : bool, optional
+            Whether to add a figure-level title. Defaults to ``False``.
+        subplot_title_fallback : str, optional
+            Fallback title used for subplot titles when no plotting dims exist. Defaults
+            to "Sensitivity Analysis".
+
+        Examples
+        --------
+        Basic run using stored results in `idata`:
+
+        .. code-block:: python
+
+            # Assuming you already ran a sweep and stored results
+            # under idata.sensitivity_analysis via SensitivityAnalysis.run_sweep(..., extend_idata=True)
+            ax = mmm.plot.sensitivity_analysis(hdi_prob=0.9)
+
+        With aggregation over dimensions (e.g., sum over channels):
+
+        .. code-block:: python
+
+            ax = mmm.plot.sensitivity_analysis(
+                hdi_prob=0.9,
+                aggregation={"sum": ("channel",)},
+            )
+        """
+        if not hasattr(self.idata, "sensitivity_analysis"):
+            raise ValueError(
+                "No sensitivity analysis results found. Run run_sweep() first."
+            )
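+        # Results may be stored either as a Dataset holding the sweep output
+        # under the variable "x", or directly as a DataArray.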
+        sa = self.idata.sensitivity_analysis  # type: ignore
+        x = sa["x"] if isinstance(sa, xr.Dataset) else sa
+        # Coerce numeric dtype
+        try:
+            x = x.astype(float)
+        except Exception as err:
+            import warnings
+
+            warnings.warn(
+                f"Failed to cast sensitivity analysis data to float: {err}",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+        # Apply aggregations
+        if aggregation:
+            for op, dims in aggregation.items():
+                dims_list = [d for d in dims if d in x.dims]
+                if not dims_list:
+                    continue
+                if op == "sum":
+                    x = x.sum(dim=dims_list)
+                elif op == "mean":
+                    x = x.mean(dim=dims_list)
+                elif op == "median":
+                    x = x.median(dim=dims_list)
+                else:
+                    raise ValueError(
+                        f"Unsupported aggregation '{op}'. "
+                        "Expected 'sum', 'mean', or 'median'."
+                    )
+        # Determine plotting dimensions (excluding sample & sweep)
+        plot_dims = [d for d in x.dims if d not in {"sample", "sweep"}]
+        if plot_dims:
+            dim_combinations = list(
+                itertools.product(*[x.coords[d].values for d in plot_dims])
+            )
+        else:
+            dim_combinations = [()]
+
+        n_panels = len(dim_combinations)
+
+        # Handle axis/grid creation
+        subplot_kwargs = {**(subplot_kwargs or {})}
+        nrows_user = subplot_kwargs.pop("nrows", None)
+        ncols_user = subplot_kwargs.pop("ncols", None)
+        if nrows_user is not None and ncols_user is not None:
+            raise ValueError(
+                "Specify only one of 'nrows' or 'ncols' in subplot_kwargs."
+            )
+
+        if n_panels > 1:
+            if ax is not None:
+                raise ValueError(
+                    "Multiple sensitivity panels detected; please omit 'ax' and use 'subplot_kwargs' instead."
+                )
+            if ncols_user is not None:
+                ncols = ncols_user
+                nrows = int(np.ceil(n_panels / ncols))
+            elif nrows_user is not None:
+                nrows = nrows_user
+                ncols = int(np.ceil(n_panels / nrows))
+            else:
+                ncols = max(1, int(np.ceil(np.sqrt(n_panels))))
+                nrows = int(np.ceil(n_panels / ncols))
+            subplot_kwargs.setdefault("figsize", (ncols * 4.0, nrows * 3.0))
+            fig, axes_grid = plt.subplots(
+                nrows=nrows,
+                ncols=ncols,
+                **subplot_kwargs,
+            )
+            if isinstance(axes_grid, plt.Axes):
+                axes_grid = np.array([[axes_grid]])
+            elif axes_grid.ndim == 1:
+                axes_grid = axes_grid.reshape(1, -1)
+            axes_array = axes_grid
+        else:
+            if ax is not None:
+                axes_array = np.array([[ax]])
+                fig = ax.figure
+            else:
+                if ncols_user is not None or nrows_user is not None:
+                    subplot_kwargs.setdefault("figsize", (4.0, 3.0))
+                    fig, single_ax = plt.subplots(
+                        nrows=1,
+                        ncols=1,
+                        **subplot_kwargs,
+                    )
+                else:
+                    fig, single_ax = plt.subplots()
+                axes_array = np.array([[single_ax]])
+
+        # Merge plotting kwargs with defaults
+        _plot_kwargs = {"color": "C0"}
+        if plot_kwargs:
+            _plot_kwargs.update(plot_kwargs)
+        _line_color = _plot_kwargs.get("color", "C0")
+
+        axes_flat = axes_array.flatten()
+        for idx, combo in enumerate(dim_combinations):
+            current_ax = axes_flat[idx]
+            indexers = dict(zip(plot_dims, combo, strict=False)) if plot_dims else {}
+            subset = x.sel(**indexers) if indexers else x
+            subset = subset.squeeze(drop=True)
+            subset = subset.astype(float)
+
+            if "sweep" in subset.dims:
+                sweep_dim = "sweep"
+            else:
+                cand = [d for d in subset.dims if d != "sample"]
+                if not cand:
+                    raise ValueError(
+                        "Expected 'sweep' (or a non-sample) dimension in sensitivity results."
+                    )
+                sweep_dim = cand[0]
+
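+            # x-axis: use the sweep coordinate values when present, otherwise
+            # fall back to a positional index.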
+            sweep = (
+                np.asarray(subset.coords[sweep_dim].values)
+                if sweep_dim in subset.coords
+                else np.arange(subset.sizes[sweep_dim])
+            )
+
+            mean = subset.mean("sample") if "sample" in subset.dims else subset
+            reduce_dims = [d for d in mean.dims if d != sweep_dim]
+            if reduce_dims:
+                mean = mean.sum(dim=reduce_dims)
+
+            if "sample" in subset.dims:
+                hdi = az.hdi(subset, hdi_prob=hdi_prob, input_core_dims=[["sample"]])
+                if isinstance(hdi, xr.Dataset):
+                    hdi = hdi[next(iter(hdi.data_vars))]
+            else:
+                hdi = xr.concat([mean, mean], dim="hdi").assign_coords(
+                    hdi=np.array([0, 1])
+                )
+
+            reduce_hdi = [d for d in hdi.dims if d not in (sweep_dim, "hdi")]
+            if reduce_hdi:
+                hdi = hdi.sum(dim=reduce_hdi)
+            if set(hdi.dims) == {sweep_dim, "hdi"} and list(hdi.dims) != [
+                sweep_dim,
+                "hdi",
+            ]:
+                hdi = hdi.transpose(sweep_dim, "hdi")  # type: ignore
+
+            current_ax.plot(sweep, np.asarray(mean.values, dtype=float), **_plot_kwargs)
+            az.plot_hdi(
+                x=sweep,
+                hdi_data=np.asarray(hdi.values, dtype=float),
+                hdi_prob=hdi_prob,
+                color=_line_color,
+                ax=current_ax,
+            )
+
+            # Per-subplot title; the ``title`` parameter is reserved for the
+            # optional figure-level suptitle below.
+            subplot_title = self._build_subplot_title(
+                dims=plot_dims,
+                combo=combo,
+                fallback_title=subplot_title_fallback,
+            )
+            current_ax.set_title(subplot_title)
+            current_ax.set_xlabel(xlabel)
+            current_ax.set_ylabel(ylabel)
+
+        # Hide any unused axes (happens if grid > panels)
+        for ax_extra in axes_flat[n_panels:]:
+            ax_extra.set_visible(False)
+
+        # Optional figure-level title: only for multi-panel layouts, default color (black)
+        if add_figure_title and title is not None and n_panels > 1:
+            fig.suptitle(title)
+
+        if n_panels == 1:
+            return axes_array[0, 0]
+
+        fig.tight_layout()
+        return fig, axes_array
+
+    def uplift_curve(
+        self,
+        hdi_prob: float = 0.94,
+        ax: plt.Axes | None = None,
+        aggregation: dict[str, tuple[str, ...] | list[str]] | None = None,
+        subplot_kwargs: dict[str, Any] | None = None,
+        *,
+        plot_kwargs: dict[str, Any] | None = None,
+        ylabel: str = "Uplift",
+        xlabel: str = "Sweep",
+        title: str | None = "Uplift curve",
+        add_figure_title: bool = True,
+    ) -> tuple[Figure, NDArray[Axes]] | plt.Axes:
+        """
+        Plot precomputed uplift curves stored under `idata.sensitivity_analysis['uplift_curve']`.
+
+        Parameters
+        ----------
+        hdi_prob : float, default 0.94
+            HDI probability mass.
+        ax : plt.Axes, optional
+            The axis to plot on.
+        aggregation : dict, optional
+            Aggregation to apply to the data.
+            E.g., {"sum": ("channel",)} to sum over the channel dimension.
+            Supported operations are "sum", "mean", and "median".
+        subplot_kwargs : dict, optional
+            Additional subplot configuration forwarded to :meth:`sensitivity_analysis`.
+        plot_kwargs : dict, optional
+            Keyword arguments forwarded to the underlying line plot. If not provided, defaults
+            are used by :meth:`sensitivity_analysis` (e.g., color "C0").
+        ylabel : str, optional
+            Y-axis label. Defaults to "Uplift".
+        xlabel : str, optional
+            X-axis label. Defaults to "Sweep".
+        title : str, optional
+            Figure-level title to add when ``add_figure_title=True``. Defaults to "Uplift curve".
+        add_figure_title : bool, optional
+            Whether to add a figure-level title. Defaults to ``True``.
+
+        Examples
+        --------
+        Persist uplift curve and plot:
+
+        .. 
code-block:: python + + from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis + + sweeps = np.linspace(0.5, 1.5, 11) + sa = SensitivityAnalysis(mmm.model, mmm.idata) + results = sa.run_sweep( + var_input="channel_data", + sweep_values=sweeps, + var_names="channel_contribution", + sweep_type="multiplicative", + ) + uplift = sa.compute_uplift_curve_respect_to_base( + results, ref=1.0, extend_idata=True + ) + _ = mmm.plot.uplift_curve(hdi_prob=0.9) + """ + if not hasattr(self.idata, "sensitivity_analysis"): + raise ValueError( + "No sensitivity analysis results found in 'self.idata'. " + "Run 'mmm.sensitivity.run_sweep()' first." + ) + + sa_group = self.idata.sensitivity_analysis # type: ignore + if isinstance(sa_group, xr.Dataset): + if "uplift_curve" not in sa_group: + raise ValueError( + "Expected 'uplift_curve' in idata.sensitivity_analysis. " + "Use SensitivityAnalysis.compute_uplift_curve_respect_to_base(..., extend_idata=True)." + ) + data_var = sa_group["uplift_curve"] + else: + raise ValueError( + "sensitivity_analysis does not contain 'uplift_curve'. Did you persist it to idata?" + ) + + # Delegate to a thin wrapper by temporarily constructing a Dataset + tmp_idata = xr.Dataset({"x": data_var}) + # Monkey-patch minimal attributes needed + tmp_idata["x"].attrs.update(getattr(sa_group, "attrs", {})) # type: ignore + # Temporarily swap + original_group = self.idata.sensitivity_analysis # type: ignore + try: + self.idata.sensitivity_analysis = tmp_idata # type: ignore + return self.sensitivity_analysis( + hdi_prob=hdi_prob, + ax=ax, + aggregation=aggregation, + subplot_kwargs=subplot_kwargs, + subplot_title_fallback="Uplift curve", + plot_kwargs=plot_kwargs, + ylabel=ylabel, + xlabel=xlabel, + title=title, + add_figure_title=add_figure_title, + ) + finally: + self.idata.sensitivity_analysis = original_group # type: ignore + + def marginal_curve( + self, + hdi_prob: float = 0.94, + ax: plt.Axes | None = None, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + subplot_kwargs: dict[str, Any] | None = None, + *, + plot_kwargs: dict[str, Any] | None = None, + ylabel: str = "Marginal effect", + xlabel: str = "Sweep", + title: str | None = "Marginal effects", + add_figure_title: bool = True, + ) -> tuple[Figure, NDArray[Axes]] | plt.Axes: + """ + Plot precomputed marginal effects stored under `idata.sensitivity_analysis['marginal_effects']`. + + Parameters + ---------- + hdi_prob : float, default 0.94 + HDI probability mass. + ax : plt.Axes, optional + The axis to plot on. + aggregation : dict, optional + Aggregation to apply to the data. + E.g., {"sum": ("channel",)} to sum over the channel dimension. + subplot_kwargs : dict, optional + Additional subplot configuration forwarded to :meth:`sensitivity_analysis`. + plot_kwargs : dict, optional + Keyword arguments forwarded to the underlying line plot. Defaults to ``{"color": "C1"}``. + ylabel : str, optional + Y-axis label. Defaults to "Marginal effect". + xlabel : str, optional + X-axis label. Defaults to "Sweep". + title : str, optional + Figure-level title to add when ``add_figure_title=True``. Defaults to "Marginal effects". + add_figure_title : bool, optional + Whether to add a figure-level title. Defaults to ``True``. + + Examples + -------- + Persist marginal effects and plot: + + .. 
code-block:: python + + from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis + + sweeps = np.linspace(0.5, 1.5, 11) + sa = SensitivityAnalysis(mmm.model, mmm.idata) + results = sa.run_sweep( + var_input="channel_data", + sweep_values=sweeps, + var_names="channel_contribution", + sweep_type="multiplicative", + ) + me = sa.compute_marginal_effects(results, extend_idata=True) + _ = mmm.plot.marginal_curve(hdi_prob=0.9) + """ + if not hasattr(self.idata, "sensitivity_analysis"): + raise ValueError( + "No sensitivity analysis results found in 'self.idata'. " + "Run 'mmm.sensitivity.run_sweep()' first." + ) + + sa_group = self.idata.sensitivity_analysis # type: ignore + if isinstance(sa_group, xr.Dataset): + if "marginal_effects" not in sa_group: + raise ValueError( + "Expected 'marginal_effects' in idata.sensitivity_analysis. " + "Use SensitivityAnalysis.compute_marginal_effects(..., extend_idata=True)." + ) + data_var = sa_group["marginal_effects"] + else: + raise ValueError( + "sensitivity_analysis does not contain 'marginal_effects'. Did you persist it to idata?" + ) + + # We want a different y-label and color + # Temporarily swap group to reuse plotting logic + tmp = xr.Dataset({"x": data_var}) + tmp["x"].attrs.update(getattr(sa_group, "attrs", {})) # type: ignore + original = self.idata.sensitivity_analysis # type: ignore + try: + self.idata.sensitivity_analysis = tmp # type: ignore + # Reuse core plotting; percentage=False by definition + # Merge defaults for plot_kwargs if not provided + _plot_kwargs = {"color": "C1"} + if plot_kwargs: + _plot_kwargs.update(plot_kwargs) + return self.sensitivity_analysis( + hdi_prob=hdi_prob, + ax=ax, + aggregation=aggregation, + subplot_kwargs=subplot_kwargs, + subplot_title_fallback="Marginal effects", + plot_kwargs=_plot_kwargs, + ylabel=ylabel, + xlabel=xlabel, + title=title, + add_figure_title=add_figure_title, + ) + finally: + self.idata.sensitivity_analysis = original # type: ignore diff --git a/pymc_marketing/mmm/multidimensional.py b/pymc_marketing/mmm/multidimensional.py index 5b2614d89..faf06d239 100644 --- a/pymc_marketing/mmm/multidimensional.py +++ b/pymc_marketing/mmm/multidimensional.py @@ -183,9 +183,11 @@ SaturationTransformation, saturation_from_dict, ) +from pymc_marketing.mmm.config import mmm_plot_config from pymc_marketing.mmm.events import EventEffect from pymc_marketing.mmm.fourier import YearlyFourier from pymc_marketing.mmm.hsgp import HSGPBase, hsgp_from_dict +from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite from pymc_marketing.mmm.lift_test import ( add_cost_per_target_potentials, add_lift_measurements_to_likelihood_from_saturation, @@ -616,11 +618,59 @@ def attrs_to_init_kwargs(cls, attrs: dict[str, str]) -> dict[str, Any]: } @property - def plot(self) -> MMMPlotSuite: - """Use the MMMPlotSuite to plot the results.""" + def plot(self): + """Use the MMMPlotSuite to plot the results. + + The plot suite version is controlled by mmm_plot_config["plot.use_v2"]: + - False (default): Uses legacy matplotlib-based suite (will be deprecated) + - True: Uses new arviz_plots-based suite with multi-backend support + + .. versionchanged:: 0.18.0 + Added version control via mmm_plot_config["plot.use_v2"]. + The legacy suite will be removed in v0.20.0. 
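+
+        Setting ``mmm_plot_config["plot.show_warnings"] = False`` suppresses
+        the FutureWarning emitted while the legacy suite is selected.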
+ + Examples + -------- + Use new plot suite: + + >>> from pymc_marketing.mmm import mmm_plot_config + >>> mmm_plot_config["plot.use_v2"] = True + >>> pc = mmm.plot.posterior_predictive() + >>> pc.show() + + Use legacy plot suite: + + >>> mmm_plot_config["plot.use_v2"] = False + >>> fig, ax = mmm.plot.posterior_predictive() + >>> fig.savefig("plot.png") + + Returns + ------- + MMMPlotSuite or LegacyMMMPlotSuite + Plot suite instance for creating MMM visualizations. + """ self._validate_model_was_built() self._validate_idata_exists() - return MMMPlotSuite(idata=self.idata) + + # Check version flag + if mmm_plot_config.get("plot.use_v2", False): + return MMMPlotSuite(idata=self.idata) + else: + # Show deprecation warning for legacy suite + if mmm_plot_config.get("plot.show_warnings", True): + warnings.warn( + "The current MMMPlotSuite will be deprecated in v0.20.0. " + "The new version uses arviz_plots and supports multiple backends " + "(matplotlib, plotly, bokeh). " + "To use the new version:\n" + " from pymc_marketing.mmm import mmm_plot_config\n" + " mmm_plot_config['plot.use_v2'] = True\n" + "To suppress this warning: mmm_plot_config['plot.show_warnings'] = False\n" + "See migration guide: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html", + FutureWarning, + stacklevel=2, + ) + return LegacyMMMPlotSuite(idata=self.idata) @property def default_model_config(self) -> dict: diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 18537a175..c8fbf8098 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -46,7 +46,7 @@ mmm.sample_posterior_predictive(X) # Posterior predictive time series - _ = mmm.plot.posterior_predictive(var=["y"], hdi_prob=0.9) + _ = mmm.plot.posterior_predictive(var="y", hdi_prob=0.9) # Posterior contributions over time (e.g., channel_contribution) _ = mmm.plot.contributions_over_time(var=["channel_contribution"], hdi_prob=0.9) @@ -88,7 +88,7 @@ idata.extend(pm.sample_posterior_predictive(idata, random_seed=1)) plot = MMMPlotSuite(idata) - _ = plot.posterior_predictive(var=["y"], hdi_prob=0.9) + _ = plot.posterior_predictive(var="y", hdi_prob=0.9) Custom contributions_over_time -------- @@ -170,19 +170,21 @@ """ import itertools -from collections.abc import Iterable -from typing import Any import arviz as az -import matplotlib.pyplot as plt +import arviz_plots as azp import numpy as np import xarray as xr -from matplotlib.axes import Axes -from matplotlib.figure import Figure -from numpy.typing import NDArray +from arviz_base.labels import DimCoordLabeller, NoVarLabeller, mix_labellers +from arviz_plots import PlotCollection + +from pymc_marketing.mmm.config import mmm_plot_config __all__ = ["MMMPlotSuite"] +WIDTH_PER_COL: float = 10.0 +HEIGHT_PER_ROW: float = 4.0 + class MMMPlotSuite: """Media Mix Model Plot Suite. @@ -197,53 +199,6 @@ def __init__( ): self.idata = idata - def _init_subplots( - self, - n_subplots: int, - ncols: int = 1, - width_per_col: float = 10.0, - height_per_row: float = 4.0, - ) -> tuple[Figure, NDArray[Axes]]: - """Initialize a grid of subplots. - - Parameters - ---------- - n_subplots : int - Number of rows (if ncols=1) or total subplots. - ncols : int - Number of columns in the subplot grid. - width_per_col : float - Width (in inches) for each column of subplots. - height_per_row : float - Height (in inches) for each row of subplots. - - Returns - ------- - fig : matplotlib.figure.Figure - The created Figure object. 
- axes : np.ndarray of matplotlib.axes.Axes - 2D array of axes of shape (n_subplots, ncols). - """ - fig, axes = plt.subplots( - nrows=n_subplots, - ncols=ncols, - figsize=(width_per_col * ncols, height_per_row * n_subplots), - squeeze=False, - ) - return fig, axes - - def _build_subplot_title( - self, - dims: list[str], - combo: tuple, - fallback_title: str = "Time Series", - ) -> str: - """Build a subplot title string from dimension names and their values.""" - if dims: - title_parts = [f"{d}={v}" for d, v in zip(dims, combo, strict=False)] - return ", ".join(title_parts) - return fallback_title - def _get_additional_dim_combinations( self, data: xr.Dataset, @@ -266,23 +221,6 @@ def _get_additional_dim_combinations( return additional_dims, dim_combinations - def _reduce_and_stack( - self, data: xr.DataArray, dims_to_ignore: set[str] | None = None - ) -> xr.DataArray: - """Sum over leftover dims and stack chain+draw into sample if present.""" - if dims_to_ignore is None: - dims_to_ignore = {"date", "chain", "draw", "sample"} - - leftover_dims = [d for d in data.dims if d not in dims_to_ignore] - if leftover_dims: - data = data.sum(dim=leftover_dims) - - # Combine chain+draw into 'sample' if both exist - if "chain" in data.dims and "draw" in data.dims: - data = data.stack(sample=("chain", "draw")) - - return data - def _get_posterior_predictive_data( self, idata: xr.Dataset | None, @@ -303,25 +241,6 @@ def _get_posterior_predictive_data( ) return self.idata.posterior_predictive # type: ignore - def _add_median_and_hdi( - self, ax: Axes, data: xr.DataArray, var: str, hdi_prob: float = 0.85 - ) -> Axes: - """Add median and HDI to the given axis.""" - median = data.median(dim="sample") if "sample" in data.dims else data.median() - hdi = az.hdi( - data, - hdi_prob=hdi_prob, - input_core_dims=[["sample"]] if "sample" in data.dims else None, - ) - - if "date" not in data.dims: - raise ValueError(f"Expected 'date' dimension in {var}, but none found.") - dates = data.coords["date"].values - # Add median and HDI to the plot - ax.plot(dates, median, label=var, alpha=0.9) - ax.fill_between(dates, hdi[var][..., 0], hdi[var][..., 1], alpha=0.2) - return ax - def _validate_dims( self, dims: dict[str, str | int | list], @@ -368,293 +287,579 @@ def _dim_list_handler( dims_combos = [()] return dims_keys, dims_combos + def _resolve_backend(self, backend: str | None) -> str: + """Resolve backend parameter to actual backend string.""" + return backend or mmm_plot_config["plot.backend"] + + def _get_data_or_fallback( + self, + data: xr.Dataset | None, + idata_attr: str, + data_name: str, + ) -> xr.Dataset: + """Get data from parameter or fall back to self.idata attribute. + + Parameters + ---------- + data : xr.Dataset or None + Data provided by user. + idata_attr : str + Attribute name on self.idata to use as fallback (e.g., "posterior"). + data_name : str + Human-readable name for error messages (e.g., "posterior data"). + + Returns + ------- + xr.Dataset + The data to use. + + Raises + ------ + ValueError + If data is None and self.idata doesn't have the required attribute. + """ + if data is None: + if not hasattr(self.idata, idata_attr): + raise ValueError( + f"No {data_name} found in 'self.idata' and no 'data' argument provided. " + f"Please ensure 'self.idata' contains a '{idata_attr}' group or provide 'data' explicitly." 
+ ) + data = getattr(self.idata, idata_attr) + return data + # ------------------------------------------------------------------------ # Main Plotting Methods # ------------------------------------------------------------------------ def posterior_predictive( self, - var: list[str] | None = None, + var: str | None = None, idata: xr.Dataset | None = None, hdi_prob: float = 0.85, - ) -> tuple[Figure, NDArray[Axes]]: - """Plot time series from the posterior predictive distribution. + backend: str | None = None, + ) -> PlotCollection: + """Plot posterior predictive distributions over time. - By default, if both `var` and `idata` are not provided, uses - `self.idata.posterior_predictive` and defaults the variable to `["y"]`. + Visualizes posterior predictive samples as time series, showing the median + line and highest density interval (HDI) bands. Useful for checking model fit + and understanding prediction uncertainty. Parameters ---------- - var : list of str, optional - A list of variable names to plot. Default is ["y"] if not provided. - idata : xarray.Dataset, optional - The posterior predictive dataset to plot. If not provided, tries to - use `self.idata.posterior_predictive`. - hdi_prob: float, optional - The probability mass of the highest density interval to be displayed. Default is 0.85. + var : str, optional + Variable name to plot from posterior_predictive group. If None, uses "y". + idata : xr.Dataset, optional + Dataset containing posterior predictive samples with a "date" coordinate. + If None, uses self.idata.posterior_predictive. + + This parameter allows: + - Testing with mock data without modifying self.idata + - Plotting external posterior predictive samples + - Comparing different model fits side-by-side + hdi_prob : float, default 0.85 + Probability mass for HDI interval (between 0 and 1). + backend : str, optional + Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". + If None, uses global config via mmm_plot_config["plot.backend"]. + Default is "matplotlib". Returns ------- - fig : matplotlib.figure.Figure - The Figure object containing the subplots. - axes : np.ndarray of matplotlib.axes.Axes - Array of Axes objects corresponding to each subplot row. + PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)``, + this provides a unified interface across all backends. Raises ------ ValueError - If no `idata` is provided and `self.idata.posterior_predictive` does - not exist, instructing the user to run `MMM.sample_posterior_predictive()`. - If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value. + If no posterior_predictive data found in self.idata and no idata provided. + ValueError + If hdi_prob is not between 0 and 1. + + See Also + -------- + LegacyMMMPlotSuite.posterior_predictive : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) + - Different interface for saving and displaying plots + + Examples + -------- + Basic usage: + + .. code-block:: python + + mmm.sample_posterior_predictive(X) + pc = mmm.plot.posterior_predictive() + pc.show() + + Plot with different HDI probability: + + .. code-block:: python + + pc = mmm.plot.posterior_predictive(hdi_prob=0.94) + pc.show() + + Save to file: + + .. 
code-block:: python + + pc = mmm.plot.posterior_predictive() + pc.save("posterior_predictive.png") + + Use different backend: + + .. code-block:: python + + pc = mmm.plot.posterior_predictive(backend="plotly") + pc.show() + + Provide explicit data: + + .. code-block:: python + + external_pp = xr.Dataset(...) # Custom posterior predictive + pc = mmm.plot.posterior_predictive(idata=external_pp) + pc.show() + + Direct instantiation pattern: + + .. code-block:: python + + from pymc_marketing.mmm.plot import MMMPlotSuite + + mps = MMMPlotSuite(custom_idata) + pc = mps.posterior_predictive() + pc.show() """ if not 0 < hdi_prob < 1: raise ValueError("HDI probability must be between 0 and 1.") + + # Resolve backend + backend = self._resolve_backend(backend) + # 1. Retrieve or validate posterior_predictive data pp_data = self._get_posterior_predictive_data(idata) - # 2. Determine variables to plot + # 2. Determine variable to plot if var is None: - var = ["y"] - main_var = var[0] + var = "y" + main_var = var # 3. Identify additional dims & get all combos ignored_dims = {"chain", "draw", "date", "sample"} - additional_dims, dim_combinations = self._get_additional_dim_combinations( + additional_dims, _ = self._get_additional_dim_combinations( data=pp_data, variable=main_var, ignored_dims=ignored_dims ) # 4. Prepare subplots - fig, axes = self._init_subplots(n_subplots=len(dim_combinations), ncols=1) - - # 5. Loop over dimension combinations - for row_idx, combo in enumerate(dim_combinations): - ax = axes[row_idx][0] - - # Build indexers - indexers = ( - dict(zip(additional_dims, combo, strict=False)) - if additional_dims - else {} - ) - - # 6. Plot each requested variable - for v in var: - if v not in pp_data: - raise ValueError( - f"Variable '{v}' not in the posterior_predictive dataset." - ) + pc = azp.PlotCollection.wrap( + pp_data[main_var].to_dataset(), + cols=additional_dims, + col_wrap=1, + figure_kwargs={ + "sharex": True, + }, + backend=backend, + ) - data = pp_data[v].sel(**indexers) - # Sum leftover dims, stack chain+draw if needed - data = self._reduce_and_stack(data, ignored_dims) - ax = self._add_median_and_hdi(ax, data, v, hdi_prob=hdi_prob) + # plot hdi + hdi = pp_data.azstats.hdi(hdi_prob) + pc.map( + azp.visuals.fill_between_y, + x=pp_data["date"], + y_bottom=hdi.sel(ci_bound="lower"), + y_top=hdi.sel(ci_bound="upper"), + alpha=0.2, + color="C0", + ) - # 7. Subplot title & labels - title = self._build_subplot_title( - dims=additional_dims, - combo=combo, - fallback_title="Posterior Predictive Time Series", - ) - ax.set_title(title) - ax.set_xlabel("Date") - ax.set_ylabel("Posterior Predictive") - ax.legend(loc="best") + # plot median line + pc.map( + azp.visuals.line_xy, + x=pp_data["date"], + y=pp_data.median(dim=["chain", "draw"]), + color="C0", + ) - return fig, axes + # add labels + pc.map(azp.visuals.labelled_x, text="Date") + pc.map(azp.visuals.labelled_y, text="Posterior Predictive") + pc.map( + azp.visuals.labelled_title, + subset_info=True, + labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), + ) + return pc def contributions_over_time( self, var: list[str], + data: xr.Dataset | None = None, hdi_prob: float = 0.85, dims: dict[str, str | int | list] | None = None, - ) -> tuple[Figure, NDArray[Axes]]: - """Plot the time-series contributions for each variable in `var`. + backend: str | None = None, + ) -> PlotCollection: + """Plot time-series contributions for specified variables. - showing the median and the credible interval (default 85%). 
- Creates one subplot per combination of non-(chain/draw/date) dimensions - and places all variables on the same subplot. + Visualizes how variables contribute over time, showing the median line and + HDI bands. Useful for understanding channel contributions, intercepts, or + other time-varying effects in your model. Parameters ---------- var : list of str - A list of variable names to plot from the posterior. - hdi_prob: float, optional - The probability mass of the highest density interval to be displayed. Default is 0.85. + Variable names to plot from the posterior group. Must have a "date" dimension. + Examples: ["channel_contribution"], ["intercept"], ["channel_contribution", "intercept"]. + data : xr.Dataset, optional + Dataset containing posterior data with variables in `var`. + If None, uses self.idata.posterior. + This parameter allows: + - Testing with mock data without modifying self.idata + - Plotting external results not stored in self.idata + - Comparing different posterior samples side-by-side + - Avoiding unintended side effects on self.idata + hdi_prob : float, default 0.85 + Probability mass for HDI interval (between 0 and 1). dims : dict[str, str | int | list], optional - Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}. + Dimension filters to apply. Keys are dimension names, values are either: + - Single value: {"country": "US", "user_type": "new"} + - List of values: {"country": ["US", "UK"]} + If provided, only the selected slice(s) will be plotted. + backend : str, optional + Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". + If None, uses global config via mmm_plot_config["plot.backend"]. + Default is "matplotlib". Returns ------- - fig : matplotlib.figure.Figure - The Figure object containing the subplots. - axes : np.ndarray of matplotlib.axes.Axes - Array of Axes objects corresponding to each subplot row. + PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)``, + this provides a unified interface across all backends. Raises ------ ValueError - If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value. + If hdi_prob is not between 0 and 1. + ValueError + If no posterior data found in self.idata and no data argument provided. + ValueError + If any variable in `var` not found in data. + + See Also + -------- + LegacyMMMPlotSuite.contributions_over_time : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) + - Variable names must be passed in a list (was already list in legacy) + + Examples + -------- + Basic usage - plot channel contributions: + + .. code-block:: python + + mmm.fit(X, y) + pc = mmm.plot.contributions_over_time(var=["channel_contribution"]) + pc.show() + + Plot multiple variables together: + + .. code-block:: python + + pc = mmm.plot.contributions_over_time( + var=["channel_contribution", "intercept"] + ) + pc.show() + + Filter by dimension: + + .. code-block:: python + + pc = mmm.plot.contributions_over_time( + var=["channel_contribution"], dims={"geo": "US"} + ) + pc.show() + + Filter with multiple dimension values: + + .. code-block:: python + + pc = mmm.plot.contributions_over_time( + var=["channel_contribution"], dims={"geo": ["US", "UK"]} + ) + pc.show() + + Use different backend: + + .. 
code-block:: python + + pc = mmm.plot.contributions_over_time( + var=["channel_contribution"], backend="plotly" + ) + pc.show() + + Provide explicit data (option 1 - via data parameter): + + .. code-block:: python + + custom_posterior = xr.Dataset(...) + pc = mmm.plot.contributions_over_time( + var=["my_contribution"], data=custom_posterior + ) + pc.show() + + Provide explicit data (option 2 - direct instantiation): + + .. code-block:: python + + from pymc_marketing.mmm.plot import MMMPlotSuite + + mps = MMMPlotSuite(custom_idata) + pc = mps.contributions_over_time(var=["my_contribution"]) + pc.show() """ if not 0 < hdi_prob < 1: raise ValueError("HDI probability must be between 0 and 1.") - if not hasattr(self.idata, "posterior"): + # Get data with fallback to self.idata.posterior + data = self._get_data_or_fallback(data, "posterior", "posterior data") + + # Validate data has the required variables + missing_vars = [v for v in var if v not in data] + if missing_vars: raise ValueError( - "No posterior data found in 'self.idata'. " - "Please ensure 'self.idata' contains a 'posterior' group." + f"Variables {missing_vars} not found in data. " + f"Available variables: {list(data.data_vars)}" ) + # Resolve backend + backend = self._resolve_backend(backend) + main_var = var[0] - all_dims = list(self.idata.posterior[main_var].dims) # type: ignore ignored_dims = {"chain", "draw", "date"} - additional_dims = [d for d in all_dims if d not in ignored_dims] - - coords = { - key: value.to_numpy() - for key, value in self.idata.posterior[var].coords.items() - } + da = data[var] - # Apply user-specified filters (`dims`) + # Apply dims filtering if provided if dims: - self._validate_dims(dims=dims, all_dims=all_dims) - # Remove filtered dims from the combinations - additional_dims = [d for d in additional_dims if d not in dims] - else: - self._validate_dims({}, all_dims) - # additional_dims = [d for d in additional_dims if d not in dims] - - # Identify combos for remaining dims - if additional_dims: - additional_coords = [ - self.idata.posterior.coords[dim].values # type: ignore - for dim in additional_dims - ] - dim_combinations = list(itertools.product(*additional_coords)) - else: - dim_combinations = [()] + self._validate_dims(dims, list(da[main_var].dims)) + for dim_name, dim_value in dims.items(): + if isinstance(dim_value, (list, tuple, np.ndarray)): + da = da.sel({dim_name: dim_value}) + else: + da = da.sel({dim_name: dim_value}) - # If dims contains lists, build all combinations for those as well - dims_keys, dims_combos = self._dim_list_handler(dims) + additional_dims, _ = self._get_additional_dim_combinations( + data=da, variable=main_var, ignored_dims=ignored_dims + ) - # Prepare subplots: one for each combo of dims_lists and additional_dims - total_combos = list(itertools.product(dims_combos, dim_combinations)) - fig, axes = self._init_subplots(len(total_combos), ncols=1) + # 4. 
Prepare subplots + pc = azp.PlotCollection.wrap( + da, + cols=additional_dims, + col_wrap=1, + figure_kwargs={ + "sharex": True, + }, + backend=backend, + ) - for row_idx, (dims_combo, addl_combo) in enumerate(total_combos): - ax = axes[row_idx][0] - # Build indexers for dims and additional_dims - indexers = ( - dict(zip(additional_dims, addl_combo, strict=False)) - if additional_dims - else {} - ) - if dims: - # For dims with lists, use the current value from dims_combo - for i, k in enumerate(dims_keys): - indexers[k] = dims_combo[i] - # For dims with single values, use as is - for k, v in (dims or {}).items(): - if k not in dims_keys: - indexers[k] = v - - # Plot posterior median and HDI for each var - for v in var: - data = self.idata.posterior[v] - missing_coords = { - key: value for key, value in coords.items() if key not in data.dims - } - data = data.expand_dims(**missing_coords) - data = data.sel(**indexers) # apply slice - data = self._reduce_and_stack( - data, dims_to_ignore={"date", "chain", "draw", "sample"} - ) - ax = self._add_median_and_hdi(ax, data, v, hdi_prob=hdi_prob) + # plot hdi + hdi = da.azstats.hdi(hdi_prob) + pc.map( + azp.visuals.fill_between_y, + x=da["date"], + y_bottom=hdi.sel(ci_bound="lower"), + y_top=hdi.sel(ci_bound="upper"), + alpha=0.2, + color="C0", + ) - # Title includes both fixed and combo dims - title_dims = ( - list(dims.keys()) + additional_dims if dims else additional_dims - ) - title_combo = tuple(indexers[k] for k in title_dims) + # plot median line + pc.map( + azp.visuals.line_xy, + x=da["date"], + y=da.median(dim=["chain", "draw"]), + color="C0", + ) - title = self._build_subplot_title( - dims=title_dims, combo=title_combo, fallback_title="Time Series" - ) - ax.set_title(title) - ax.set_xlabel("Date") - ax.set_ylabel("Posterior Value") - ax.legend(loc="best") + # add labels + pc.map(azp.visuals.labelled_x, text="Date") + pc.map(azp.visuals.labelled_y, text="Posterior Value") + pc.map( + azp.visuals.labelled_title, + subset_info=True, + labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), + ) - return fig, axes + return pc def saturation_scatterplot( self, original_scale: bool = False, + constant_data: xr.Dataset | None = None, + posterior_data: xr.Dataset | None = None, dims: dict[str, str | int | list] | None = None, - **kwargs, - ) -> tuple[Figure, NDArray[Axes]]: - """Plot the saturation curves for each channel. + backend: str | None = None, + ) -> PlotCollection: + """Plot saturation scatter plot showing channel spend vs contributions. + + Creates scatter plots of actual channel spend (X-axis) against channel + contributions (Y-axis), one subplot per channel. Useful for understanding + the saturation behavior and diminishing returns of each marketing channel. + + Parameters + ---------- + original_scale : bool, default False + Whether to plot in original scale (True) or scaled space (False). + If True, requires channel_contribution_original_scale in posterior. + constant_data : xr.Dataset, optional + Dataset containing constant_data group with required variables: + - 'channel_data': Channel spend data (dims include "date", "channel") + - 'channel_scale': Scaling factor per channel (if original_scale=True) + - 'target_scale': Target scaling factor (if original_scale=True) + + If None, uses self.idata.constant_data. 
+ This parameter allows: + - Testing with mock constant data + - Plotting with alternative scaling factors + - Comparing different data scenarios + posterior_data : xr.Dataset, optional + Dataset containing posterior group with channel contribution variables. + Must contain 'channel_contribution' or 'channel_contribution_original_scale'. + If None, uses self.idata.posterior. + This parameter allows: + - Testing with mock posterior samples + - Plotting external inference results + - Comparing different model fits + dims : dict[str, str | int | list], optional + Dimension filters to apply. Examples: + - {"geo": "US"} - Single value + - {"geo": ["US", "UK"]} - Multiple values + + If provided, only the selected slice(s) will be plotted. + backend : str, optional + Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". + If None, uses global config via mmm_plot_config["plot.backend"]. + Default is "matplotlib". + + Returns + ------- + PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)``, + this provides a unified interface across all backends. + + Raises + ------ + ValueError + If required data not found in self.idata and not provided explicitly. + ValueError + If 'channel_data' not found in constant_data. + ValueError + If original_scale=True but channel_contribution_original_scale not in posterior. + + See Also + -------- + saturation_curves : Add posterior predictive curves to this scatter plot + LegacyMMMPlotSuite.saturation_scatterplot : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) + - Lost **kwargs for matplotlib customization (use backend-specific methods) + - Different grid layout algorithm + + Examples + -------- + Basic usage (scaled space): + + .. code-block:: python + + mmm.fit(X, y) + pc = mmm.plot.saturation_scatterplot() + pc.show() + + Plot in original scale: + + .. code-block:: python + + mmm.add_original_scale_contribution_variable(var=["channel_contribution"]) + pc = mmm.plot.saturation_scatterplot(original_scale=True) + pc.show() + + Filter by dimension: + + .. code-block:: python + + pc = mmm.plot.saturation_scatterplot(dims={"geo": "US"}) + pc.show() + + Use different backend: + + .. code-block:: python + + pc = mmm.plot.saturation_scatterplot(backend="plotly") + pc.show() + + Provide explicit data: + + .. code-block:: python - Creates a grid of subplots for each combination of channel and non-(date/channel) dimensions. - Optionally, subset by dims (single values or lists). - Each channel will have a consistent color across all subplots. + custom_constant = xr.Dataset(...) + custom_posterior = xr.Dataset(...) + pc = mmm.plot.saturation_scatterplot( + constant_data=custom_constant, posterior_data=custom_posterior + ) + pc.show() """ - if not hasattr(self.idata, "constant_data"): + # Resolve backend + backend = self._resolve_backend(backend) + + # Get constant_data and posterior_data with fallback + constant_data = self._get_data_or_fallback( + constant_data, "constant_data", "constant data" + ) + posterior_data = self._get_data_or_fallback( + posterior_data, "posterior", "posterior data" + ) + + # Validate required variables exist + if "channel_data" not in constant_data: raise ValueError( - "No 'constant_data' found in 'self.idata'. " - "Please ensure 'self.idata' contains the constant_data group." 
+ "'channel_data' variable not found in constant_data. " + f"Available variables: {list(constant_data.data_vars)}" ) # Identify additional dimensions beyond 'date' and 'channel' - cdims = self.idata.constant_data.channel_data.dims + cdims = constant_data.channel_data.dims additional_dims = [dim for dim in cdims if dim not in ("date", "channel")] # Validate dims and remove filtered dims from additional_dims if dims: - self._validate_dims(dims, list(self.idata.constant_data.channel_data.dims)) + self._validate_dims(dims, list(constant_data.channel_data.dims)) additional_dims = [d for d in additional_dims if d not in dims] else: - self._validate_dims({}, list(self.idata.constant_data.channel_data.dims)) - - # Build all combinations for dims with lists - dims_keys, dims_combos = self._dim_list_handler(dims) - - # Build all combinations for remaining dims - if additional_dims: - additional_coords = [ - self.idata.constant_data.coords[d].values for d in additional_dims - ] - additional_combinations = list(itertools.product(*additional_coords)) - else: - additional_combinations = [()] - - channels = self.idata.constant_data.coords["channel"].values - n_channels = len(channels) - n_addl = len(additional_combinations) - n_dims = len(dims_combos) - - # For most use cases, n_dims will be 1, so grid is channels x additional_combinations - # If dims_combos > 1, treat as extra axis (rare, but possible) - nrows = n_channels - ncols = n_addl * n_dims - total_combos = list( - itertools.product(channels, dims_combos, additional_combinations) - ) - n_subplots = len(total_combos) - - # Assign a color to each channel - channel_colors = {ch: f"C{i}" for i, ch in enumerate(channels)} - - # Prepare subplots as a grid - fig, axes = plt.subplots( - nrows=nrows, - ncols=ncols, - figsize=( - kwargs.get("width_per_col", 8) * ncols, - kwargs.get("height_per_row", 4) * nrows, - ), - squeeze=False, - ) + self._validate_dims({}, list(constant_data.channel_data.dims)) channel_contribution = ( "channel_contribution_original_scale" @@ -662,9 +867,9 @@ def saturation_scatterplot( else "channel_contribution" ) - if original_scale and not hasattr(self.idata.posterior, channel_contribution): + if channel_contribution not in posterior_data: raise ValueError( - f"""No posterior.{channel_contribution} data found in 'self.idata'. \n + f"""No posterior.{channel_contribution} data found in posterior_data. 
\n Add a original scale deterministic:\n mmm.add_original_scale_contribution_variable(\n var=[\n @@ -675,160 +880,210 @@ def saturation_scatterplot( """ ) - for _idx, (channel, dims_combo, addl_combo) in enumerate(total_combos): - # Compute subplot position - row = list(channels).index(channel) - # If dims_combos > 1, treat as extra axis (columns: addl * dims) - if n_dims > 1: - col = list(additional_combinations).index(addl_combo) * n_dims + list( - dims_combos - ).index(dims_combo) - else: - col = list(additional_combinations).index(addl_combo) - ax = axes[row][col] - - # Build indexers for dims and additional_dims - indexers = ( - dict(zip(additional_dims, addl_combo, strict=False)) - if additional_dims - else {} - ) - if dims: - for i, k in enumerate(dims_keys): - indexers[k] = dims_combo[i] - for k, v in (dims or {}).items(): - if k not in dims_keys: - indexers[k] = v - indexers["channel"] = channel - - # Select X data (constant_data) - x_data = self.idata.constant_data.channel_data.sel(**indexers) - # Select Y data (posterior contributions) and scale if needed - y_data = self.idata.posterior[channel_contribution].sel(**indexers) - y_data = y_data.mean(dim=[d for d in y_data.dims if d in ("chain", "draw")]) - x_data = x_data.broadcast_like(y_data) - y_data = y_data.broadcast_like(x_data) - ax.scatter( - x_data.values.flatten(), - y_data.values.flatten(), - alpha=0.8, - color=channel_colors[channel], - label=str(channel), - ) - # Build subplot title - title_dims = ( - ["channel"] + (list(dims.keys()) if dims else []) + additional_dims - ) - title_combo = ( - channel, - *[indexers[k] for k in title_dims if k != "channel"], - ) - title = self._build_subplot_title( - dims=title_dims, - combo=title_combo, - fallback_title="Channel Saturation Curve", - ) - ax.set_title(title) - ax.set_xlabel("Channel Data (X)") - ax.set_ylabel("Channel Contributions (Y)") - ax.legend(loc="best") + # Apply dims filtering to channel_data and channel_contribution + channel_data = constant_data.channel_data + channel_contrib = posterior_data[channel_contribution] - # Hide any unused axes (if grid is larger than needed) - for i in range(nrows): - for j in range(ncols): - if i * ncols + j >= n_subplots: - axes[i][j].set_visible(False) + if dims: + for dim_name, dim_value in dims.items(): + if isinstance(dim_value, (list, tuple, np.ndarray)): + channel_data = channel_data.sel({dim_name: dim_value}) + channel_contrib = channel_contrib.sel({dim_name: dim_value}) + else: + channel_data = channel_data.sel({dim_name: dim_value}) + channel_contrib = channel_contrib.sel({dim_name: dim_value}) + + pc = azp.PlotCollection.grid( + channel_contrib.mean(dim=["chain", "draw"]).to_dataset(), + cols=additional_dims, + rows=["channel"], + aes={"color": ["channel"]}, + backend=backend, + ) + pc.map( + azp.visuals.scatter_xy, + x=channel_data, + ) + pc.map(azp.visuals.labelled_x, text="Channel Data", ignore_aes={"color"}) + pc.map( + azp.visuals.labelled_y, text="Channel Contributions", ignore_aes={"color"} + ) + pc.map( + azp.visuals.labelled_title, + subset_info=True, + labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), + ignore_aes={"color"}, + ) - return fig, axes + return pc def saturation_curves( self, curve: xr.DataArray, original_scale: bool = False, + constant_data: xr.Dataset | None = None, + posterior_data: xr.Dataset | None = None, n_samples: int = 10, hdi_probs: float | list[float] | None = None, random_seed: np.random.Generator | None = None, - colors: Iterable[str] | None = None, - subplot_kwargs: dict | 
None = None, - rc_params: dict | None = None, dims: dict[str, str | int | list] | None = None, - **plot_kwargs, - ) -> tuple[plt.Figure, np.ndarray]: - """ - Overlay saturation‑curve scatter‑plots with posterior‑predictive sample curves and HDI bands. + backend: str | None = None, + ) -> PlotCollection: + """Overlay saturation scatter plots with posterior predictive curves and HDI bands. - **allowing** you to customize figsize and font sizes. + Builds on saturation_scatterplot() by adding: + - Sample curves from the posterior distribution + - HDI bands showing uncertainty + - Smooth saturation curves over the scatter plot Parameters ---------- curve : xr.DataArray - Posterior‑predictive curves (e.g. dims `("chain","draw","x","channel","geo")`). - original_scale : bool, default=False - Plot `channel_contribution_original_scale` if True, else `channel_contribution`. - n_samples : int, default=10 - Number of sample‑curves per subplot. + Posterior predictive saturation curves with required dimensions: + - "chain", "draw": MCMC samples + - "x": Input values for curve evaluation + - "channel": Channel names + + Generate using: ``mmm.saturation.sample_curve(...)`` + original_scale : bool, default False + Plot in original scale (True) or scaled space (False). + If True, requires channel_contribution_original_scale in posterior. + constant_data : xr.Dataset, optional + Dataset containing constant_data group. If None, uses self.idata.constant_data. + This parameter allows testing with mock data and plotting alternative scenarios. + posterior_data : xr.Dataset, optional + Dataset containing posterior group. If None, uses self.idata.posterior. + This parameter allows testing with mock posterior samples and comparing model fits. + n_samples : int, default 10 + Number of sample curves to draw per subplot. + Set to 0 to show only HDI bands without individual samples. hdi_probs : float or list of float, optional - Credible interval probabilities (e.g. 0.94 or [0.5, 0.94]). - If None, uses ArviZ's default (0.94). + HDI probability levels for credible intervals. + Examples: 0.94 (single band), [0.5, 0.94] (multiple bands). + If None, no HDI bands are drawn. random_seed : np.random.Generator, optional - RNG for reproducible sampling. If None, uses `np.random.default_rng()`. - colors : iterable of str, optional - Colors for the sample & HDI plots. - subplot_kwargs : dict, optional - Passed to `plt.subplots` (e.g. `{"figsize": (10,8)}`). - Merged with the function's own default sizing. - rc_params : dict, optional - Temporary `matplotlib.rcParams` for this plot. - Example keys: `"xtick.labelsize"`, `"ytick.labelsize"`, - `"axes.labelsize"`, `"axes.titlesize"`. + Random number generator for reproducible curve sampling. + If None, uses ``np.random.default_rng()``. dims : dict[str, str | int | list], optional - Dimension filters to apply. Example: {"country": ["US", "UK"], "region": "X"}. + Dimension filters to apply. Examples: + - {"geo": "US"} + - {"geo": ["US", "UK"]} + If provided, only the selected slice(s) will be plotted. - **plot_kwargs - Any other kwargs forwarded to `plot_curve` - (for instance `same_axes=True`, `legend=True`, etc.). + backend : str, optional + Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". + If None, uses global config via mmm_plot_config["plot.backend"]. + Default is "matplotlib". Returns ------- - fig : plt.Figure - Matplotlib figure with your grid. - axes : np.ndarray of plt.Axes - Array of shape `(n_channels, n_geo)`. 
+ PlotCollection + arviz_plots PlotCollection object containing the plot. - """ - from pymc_marketing.plot import plot_hdi, plot_samples + Use ``.show()`` to display or ``.save("filename")`` to save. - if not hasattr(self.idata, "constant_data"): - raise ValueError( - "No 'constant_data' found in 'self.idata'. " - "Please ensure 'self.idata' contains the constant_data group." - ) + Raises + ------ + ValueError + If curve is missing required dimensions ("x" or "channel"). + ValueError + If original_scale=True but channel_contribution_original_scale not in posterior. - contrib_var = ( - "channel_contribution_original_scale" - if original_scale - else "channel_contribution" - ) + See Also + -------- + saturation_scatterplot : Base scatter plot without curves + LegacyMMMPlotSuite.saturation_curves : Legacy matplotlib-only implementation - if original_scale and not hasattr(self.idata.posterior, contrib_var): - raise ValueError( - f"""No posterior.{contrib_var} data found in 'self.idata'.\n" - "Add a original scale deterministic:\n" - " mmm.add_original_scale_contribution_variable(\n" - " var=[\n" - " 'channel_contribution',\n" - " ...\n" - " ]\n" - " )\n" - """ - ) - curve_data = ( - curve * self.idata.constant_data.target_scale if original_scale else curve - ) - curve_data = curve_data.rename("saturation_curve") + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) + - Lost colors, subplot_kwargs, rc_params parameters + - Different HDI calculation (uses arviz_plots instead of custom) + + Examples + -------- + Generate and plot saturation curves: + + .. code-block:: python + + # Generate curves using saturation transformation + curve = mmm.saturation.sample_curve( + idata=mmm.idata.posterior[["saturation_beta", "saturation_lam"]], + max_value=2.0, + ) + pc = mmm.plot.saturation_curves(curve) + pc.show() + + Add HDI bands: + + .. code-block:: python + + pc = mmm.plot.saturation_curves(curve, hdi_probs=[0.5, 0.94]) + pc.show() + + Original scale with custom seed: + + .. code-block:: python + + import numpy as np + + rng = np.random.default_rng(42) + mmm.add_original_scale_contribution_variable(var=["channel_contribution"]) + pc = mmm.plot.saturation_curves( + curve, original_scale=True, n_samples=15, random_seed=rng + ) + pc.show() + + Filter by dimension: + + .. 
code-block:: python + + pc = mmm.plot.saturation_curves(curve, dims={"geo": "US"}) + pc.show() + """ + # Get constant_data and posterior_data with fallback + constant_data = self._get_data_or_fallback( + constant_data, "constant_data", "constant data" + ) + posterior_data = self._get_data_or_fallback( + posterior_data, "posterior", "posterior data" + ) + + contrib_var = ( + "channel_contribution_original_scale" + if original_scale + else "channel_contribution" + ) + + if original_scale and contrib_var not in posterior_data: + raise ValueError( + f"No posterior.{contrib_var} data found in posterior_data.\n" + "Add an original scale deterministic:\n" + " mmm.add_original_scale_contribution_variable(\n" + " var=[\n" + " 'channel_contribution',\n" + " ...\n" + " ]\n" + " )" + ) + # Validate curve dimensions + if "x" not in curve.dims: + raise ValueError("curve must have an 'x' dimension") + if "channel" not in curve.dims: + raise ValueError("curve must have a 'channel' dimension") + + if original_scale: + curve_data = curve * constant_data.target_scale + curve_data["x"] = curve_data["x"] * constant_data.channel_scale + else: + curve_data = curve + curve_data = curve_data.rename("saturation_curve") # — 1. figure out grid shape based on scatter data dimensions / identify dims and combos - cdims = self.idata.constant_data.channel_data.dims + cdims = constant_data.channel_data.dims all_dims = list(cdims) additional_dims = [d for d in cdims if d not in ("date", "channel")] # Validate dims and remove filtered dims from additional_dims @@ -837,244 +1092,151 @@ def saturation_curves( additional_dims = [d for d in additional_dims if d not in dims] else: self._validate_dims({}, all_dims) - # Build all combinations for dims with lists - dims_keys, dims_combos = self._dim_list_handler(dims) - # Build all combinations for remaining dims - if additional_dims: - additional_coords = [ - self.idata.constant_data.coords[d].values for d in additional_dims - ] - additional_combinations = list(itertools.product(*additional_coords)) - else: - additional_combinations = [()] - channels = self.idata.constant_data.coords["channel"].values - n_channels = len(channels) - n_addl = len(additional_combinations) - n_dims = len(dims_combos) - nrows = n_channels - ncols = n_addl * n_dims - total_combos = list( - itertools.product(channels, dims_combos, additional_combinations) + + # create the saturation scatterplot + pc = self.saturation_scatterplot( + original_scale=original_scale, + constant_data=constant_data, + posterior_data=posterior_data, + dims=dims, + backend=backend, ) - n_subplots = len(total_combos) - - # — 2. merge subplot_kwargs — - user_subplot = subplot_kwargs or {} - - # Handle user-specified ncols/nrows - if "ncols" in user_subplot: - # User specified ncols, calculate nrows - ncols = user_subplot["ncols"] - nrows = int(np.ceil(n_subplots / ncols)) - user_subplot.pop("ncols") # Remove to avoid conflict - elif "nrows" in user_subplot: - # User specified nrows, calculate ncols - nrows = user_subplot["nrows"] - ncols = int(np.ceil(n_subplots / nrows)) - user_subplot.pop("nrows") # Remove to avoid conflict - default_subplot = {"figsize": (ncols * 4, nrows * 3)} - subkw = {**default_subplot, **user_subplot} - # — 3. 
create subplots ourselves — - rc_params = rc_params or {} - with plt.rc_context(rc_params): - fig, axes = plt.subplots(nrows=nrows, ncols=ncols, **subkw) - # ensure a 2D array - if nrows == 1 and ncols == 1: - axes = np.array([[axes]]) - elif nrows == 1: - axes = axes.reshape(1, -1) - elif ncols == 1: - axes = axes.reshape(-1, 1) - # Flatten axes for easier iteration - axes_flat = axes.flatten() - if colors is None: - colors = [f"C{i}" for i in range(n_channels)] - elif not isinstance(colors, list): - colors = list(colors) - subplot_idx = 0 - for _idx, (ch, dims_combo, addl_combo) in enumerate(total_combos): - if subplot_idx >= len(axes_flat): - break - ax = axes_flat[subplot_idx] - subplot_idx += 1 - # Build indexers for dims and additional_dims - indexers = ( - dict(zip(additional_dims, addl_combo, strict=False)) - if additional_dims - else {} - ) - if dims: - for i, k in enumerate(dims_keys): - indexers[k] = dims_combo[i] - for k, v in (dims or {}).items(): - if k not in dims_keys: - indexers[k] = v - indexers["channel"] = ch - # Select and broadcast curve data for this channel - curve_idx = { - dim: val for dim, val in indexers.items() if dim in curve_data.dims - } - subplot_curve = curve_data.sel(**curve_idx) - if original_scale: - valid_idx = { - k: v - for k, v in indexers.items() - if k in self.idata.constant_data.channel_scale.dims - } - channel_scale = self.idata.constant_data.channel_scale.sel(**valid_idx) - x_original = subplot_curve.coords["x"] * channel_scale - subplot_curve = subplot_curve.assign_coords(x=x_original) - if n_samples > 0: - plot_samples( - subplot_curve, - non_grid_names="x", - n=n_samples, - rng=random_seed, - axes=np.array([[ax]]), - colors=[colors[list(channels).index(ch)]], - same_axes=False, - legend=False, - **plot_kwargs, - ) - if hdi_probs is not None: - # Robustly handle hdi_probs as float, list, tuple, or np.ndarray - if isinstance(hdi_probs, (float, int)): - hdi_probs_iter = [hdi_probs] - elif isinstance(hdi_probs, (list, tuple, np.ndarray)): - hdi_probs_iter = hdi_probs - else: - raise TypeError( - "hdi_probs must be a float, list, tuple, or np.ndarray" - ) - for hdi_prob in hdi_probs_iter: - plot_hdi( - subplot_curve, - non_grid_names="x", - hdi_prob=hdi_prob, - axes=np.array([[ax]]), - colors=[colors[list(channels).index(ch)]], - same_axes=False, - legend=False, - **plot_kwargs, - ) - x_data = self.idata.constant_data.channel_data.sel(**indexers) - y = ( - self.idata.posterior[contrib_var] - .sel(**indexers) - .mean( - dim=[ - d - for d in self.idata.posterior[contrib_var].dims - if d in ("chain", "draw") - ] + + # add the hdi bands + if hdi_probs is not None: + # Robustly handle hdi_probs as float, list, tuple, or np.ndarray + if isinstance(hdi_probs, (float, int)): + hdi_probs_iter = [hdi_probs] + elif isinstance(hdi_probs, (list, tuple, np.ndarray)): + hdi_probs_iter = hdi_probs + else: + raise TypeError("hdi_probs must be a float, list, tuple, or np.ndarray") + for hdi_prob in hdi_probs_iter: + hdi = curve_data.azstats.hdi(hdi_prob) + pc.map( + azp.visuals.fill_between_y, + x=curve_data["x"], + y_bottom=hdi.sel(ci_bound="lower"), + y_top=hdi.sel(ci_bound="upper"), + alpha=0.2, ) - ) - x_data, y = x_data.broadcast_like(y), y.broadcast_like(x_data) - ax.scatter( - x_data.values.flatten(), - y.values.flatten(), - alpha=0.8, - color=colors[list(channels).index(ch)], - ) - title_dims = ( - ["channel"] + (list(dims.keys()) if dims else []) + additional_dims - ) - title_combo = ( - ch, - *[indexers[k] for k in title_dims if k != "channel"], - ) - 
title = self._build_subplot_title( - dims=title_dims, - combo=title_combo, - fallback_title="Channel Saturation Curves", - ) - ax.set_title(title) - ax.set_xlabel("Channel Data (X)") - ax.set_ylabel("Channel Contribution (Y)") - for ax_idx in range(subplot_idx, len(axes_flat)): - axes_flat[ax_idx].set_visible(False) - return fig, axes - - def saturation_curves_scatter( - self, original_scale: bool = False, **kwargs - ) -> tuple[Figure, NDArray[Axes]]: - """ - Plot scatter plots of channel contributions vs. channel data. - .. deprecated:: 0.1.0 - Will be removed in version 0.2.0. Use :meth:`saturation_scatterplot` instead. + if n_samples > 0: + # sample the curves + rng = np.random.default_rng(random_seed) - Parameters - ---------- - channel_contribution : str, optional - Name of the channel contribution variable in the InferenceData. - additional_dims : list[str], optional - Additional dimensions to consider beyond 'channel'. - additional_combinations : list[tuple], optional - Specific combinations of additional dimensions to plot. - **kwargs - Additional keyword arguments passed to _init_subplots. + # Stack chain and draw into a single "sample" dimension + stacked = curve_data.stack(sample=("chain", "draw")) - Returns - ------- - fig : plt.Figure - The matplotlib figure. - axes : np.ndarray - Array of matplotlib axes. - """ - import warnings + # Sample from the stacked dimension + idx = rng.choice(stacked.sizes["sample"], size=n_samples, replace=False) - warnings.warn( - "saturation_curves_scatter is deprecated and will be removed in version 0.2.0. " - "Use saturation_scatterplot instead.", - DeprecationWarning, - stacklevel=2, - ) - # Note: channel_contribution, additional_dims, and additional_combinations - # are not used by saturation_scatterplot, so we don't pass them - return self.saturation_scatterplot(original_scale=original_scale, **kwargs) + # Select the sampled draws + sampled_curves = stacked.isel(sample=idx) + + # plot the sampled curves + pc.map( + azp.visuals.multiple_lines, x_dim="x", data=sampled_curves, alpha=0.2 + ) - def budget_allocation( + def budget_allocation_roas( self, samples: xr.Dataset, - scale_factor: float | None = None, - figsize: tuple[float, float] = (12, 6), - ax: plt.Axes | None = None, - original_scale: bool = True, dims: dict[str, str | int | list] | None = None, - ) -> tuple[Figure, plt.Axes] | tuple[Figure, np.ndarray]: - """Plot the budget allocation and channel contributions. + dims_to_group_by: list[str] | str | None = None, + backend: str | None = None, + ) -> PlotCollection: + """Plot ROAS (Return on Ad Spend) distributions for budget allocation scenarios. - Creates a bar chart comparing allocated spend and channel contributions - for each channel. If additional dimensions besides 'channel' are present, - creates a subplot for each combination of these dimensions. + Visualizes the posterior distribution of ROAS for each channel given a budget + allocation. Useful for comparing ROAS across channels and understanding + optimization trade-offs. Parameters ---------- samples : xr.Dataset - The dataset containing the channel contributions and allocation values. - Expected to have 'channel_contribution' and 'allocation' variables. - scale_factor : float, optional - Scale factor to convert to original scale, if original_scale=True. - If None and original_scale=True, assumes scale_factor=1. - figsize : tuple[float, float], optional - The size of the figure to be created. Default is (12, 6). - ax : plt.Axes, optional - The axis to plot on. 
If None, a new figure and axis will be created. - Only used when no extra dimensions are present. - original_scale : bool, optional - A boolean flag to determine if the values should be plotted in their - original scale. Default is True. + Dataset from budget allocation optimization containing: + - 'channel_contribution_original_scale': Channel contributions + - 'allocation': Allocated budget per channel + - 'channel' dimension + + Typically obtained from: ``mmm.allocate_budget_to_maximize_response(...)`` dims : dict[str, str | int | list], optional - Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}. + Dimension filters to apply. Examples: + - {"geo": "US"} + - {"geo": ["US", "UK"]} + If provided, only the selected slice(s) will be plotted. + dims_to_group_by : list[str] | str | None, optional + Dimension(s) to group by for overlaying distributions. + When specified, all ROAS distributions for each coordinate of that + dimension will be plotted together for comparison. + + - None (default): Each distribution plotted separately + - Single string: Group by that dimension (e.g., "geo") + - List of strings: Group by multiple dimensions (e.g., ["geo", "segment"]) + backend : str | None, optional + Backend to use for plotting. If None, uses global backend configuration. Returns ------- - fig : matplotlib.figure.Figure - The Figure object containing the plot. - axes : matplotlib.axes.Axes or numpy.ndarray of matplotlib.axes.Axes - The Axes object with the plot, or array of Axes for multiple subplots. + PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + + Raises + ------ + ValueError + If 'channel' dimension not found in samples. + ValueError + If required variables not found in samples. + + See Also + -------- + LegacyMMMPlotSuite.budget_allocation : Legacy bar chart method (different purpose) + + Notes + ----- + This method is NEW in MMMPlotSuite v2 and serves a different purpose + than the legacy ``budget_allocation()`` method: + + - **New method** (this): Shows ROAS distributions (KDE plots) + - **Legacy method**: Shows bar charts comparing spend vs contributions + + To use the legacy method, set: ``mmm_plot_config["plot.use_v2"] = False`` + + Examples + -------- + Basic usage with budget optimization results: + + .. code-block:: python + + allocation_results = mmm.allocate_budget_to_maximize_response( + total_budget=100_000, budget_bounds={"lower": 0.5, "upper": 2.0} + ) + pc = mmm.plot.budget_allocation_roas(allocation_results) + pc.show() + + Group by geography to compare ROAS across regions: + + .. code-block:: python + + pc = mmm.plot.budget_allocation_roas( + allocation_results, dims_to_group_by="geo" + ) + pc.show() + + Filter and group: + + .. code-block:: python + + pc = mmm.plot.budget_allocation_roas( + allocation_results, dims={"segment": "premium"}, dims_to_group_by="geo" + ) + pc.show() """ # Get the channels from samples if "channel" not in samples.dims: @@ -1083,11 +1245,9 @@ def budget_allocation( ) # Check for required variables in samples - if not any( - "channel_contribution" in var_name for var_name in samples.data_vars - ): + if "channel_contribution_original_scale" not in samples.data_vars: raise ValueError( - "Expected a variable containing 'channel_contribution' in samples, but none found." + "Expected variable 'channel_contribution_original_scale' in samples, but none found." 
) if "allocation" not in samples: raise ValueError( @@ -1095,11 +1255,7 @@ def budget_allocation( ) # Find the variable containing 'channel_contribution' in its name - channel_contrib_var = next( - var_name - for var_name in samples.data_vars - if "channel_contribution" in var_name - ) + channel_contrib_var = "channel_contribution_original_scale" all_dims = list(samples.dims) # Validate dims @@ -1108,218 +1264,144 @@ def budget_allocation( else: self._validate_dims({}, all_dims) - # Handle list-valued dims: build all combinations - dims_keys, dims_combos = self._dim_list_handler(dims) - - # After filtering with dims, only use extra dims not in dims and not ignored for subplotting - ignored_dims = {"channel", "date", "sample", "chain", "draw"} - channel_contribution_dims = list(samples[channel_contrib_var].dims) - extra_dims = [ - d - for d in channel_contribution_dims - if d not in ignored_dims and d not in (dims or {}) - ] - - # Identify combos for remaining dims - if extra_dims: - extra_coords = [samples.coords[dim].values for dim in extra_dims] - extra_combos = list(itertools.product(*extra_coords)) - else: - extra_combos = [()] - - # Prepare subplots: one for each combo of dims_lists and extra_dims - total_combos = list(itertools.product(dims_combos, extra_combos)) - n_subplots = len(total_combos) - if n_subplots == 1 and ax is not None: - axes = np.array([[ax]]) - fig = ax.get_figure() - else: - fig, axes = self._init_subplots( - n_subplots=n_subplots, - ncols=1, - width_per_col=figsize[0], - height_per_row=figsize[1], - ) + channel_contribution = samples[channel_contrib_var].sum(dim="date") + channel_contribution.name = "channel_contribution" + + from arviz_base import convert_to_datatree + + roa_da = channel_contribution / samples.allocation + roa_dt = convert_to_datatree(roa_da) + if isinstance(dims_to_group_by, str): + dims_to_group_by = [dims_to_group_by] + if dims_to_group_by: + grouped = {"all": roa_dt.copy()} + for dim in dims_to_group_by: + new_grouped = {} + for curr_k, curr_group in grouped.items(): + curr_coords = curr_group.posterior.coords[dim].values + new_grouped.update( + { + f"{curr_k}, {dim}: {key}": curr_group.sel({dim: key}) + for key in curr_coords + } + ) + grouped = new_grouped - for row_idx, (dims_combo, extra_combo) in enumerate(total_combos): - ax_ = axes[row_idx][0] - # Build indexers for dims and extra_dims - indexers = ( - dict(zip(extra_dims, extra_combo, strict=False)) if extra_dims else {} - ) - if dims: - # For dims with lists, use the current value from dims_combo - for i, k in enumerate(dims_keys): - indexers[k] = dims_combo[i] - # For dims with single values, use as is - for k, v in (dims or {}).items(): - if k not in dims_keys: - indexers[k] = v - - # Select channel contributions for this subplot - channel_contrib_data = samples[channel_contrib_var].sel(**indexers) - allocation_data = samples.allocation - # Only select dims that exist in allocation - allocation_indexers = { - k: v for k, v in indexers.items() if k in allocation_data.dims - } - allocation_data = allocation_data.sel(**allocation_indexers) - - # Average over all dims except channel (and those used for this subplot) - used_dims = set(indexers.keys()) | {"channel"} - reduction_dims = [ - dim for dim in channel_contrib_data.dims if dim not in used_dims - ] - channel_contribution = channel_contrib_data.mean( - dim=reduction_dims - ).to_numpy() - if channel_contribution.ndim > 1: - channel_contribution = channel_contribution.flatten() - if original_scale and scale_factor is not None: - 
channel_contribution *= scale_factor - - allocation_used_dims = set(allocation_indexers.keys()) | {"channel"} - allocation_reduction_dims = [ - dim for dim in allocation_data.dims if dim not in allocation_used_dims - ] - if allocation_reduction_dims: - allocated_spend = allocation_data.mean( - dim=allocation_reduction_dims - ).to_numpy() - else: - allocated_spend = allocation_data.to_numpy() - if allocated_spend.ndim > 1: - allocated_spend = allocated_spend.flatten() - - self._plot_budget_allocation_bars( - ax_, - samples.coords["channel"].values, - allocated_spend, - channel_contribution, - ) + grouped_roa_dt = {} + prefix = "all, " + for k, v in grouped.items(): + if k.startswith(prefix): + grouped_roa_dt[k[len(prefix) :]] = v + else: + grouped_roa_dt[k] = v + else: + grouped_roa_dt = roa_dt + + pc = azp.plot_dist( + grouped_roa_dt, + kind="kde", + sample_dims=["sample"], + backend=backend, + labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), + ) - # Build subplot title - title_dims = (list(dims.keys()) if dims else []) + extra_dims - title_combo = tuple(indexers[k] for k in title_dims) - title = self._build_subplot_title( - dims=title_dims, - combo=title_combo, - fallback_title="Budget Allocation", - ) - ax_.set_title(title) + if dims_to_group_by: + pc.add_legend(dim="model", title="") - fig.tight_layout() - return fig, axes if n_subplots > 1 else (fig, axes[0][0]) + return pc - def _plot_budget_allocation_bars( + def allocated_contribution_by_channel_over_time( self, - ax: plt.Axes, - channels: NDArray, - allocated_spend: NDArray, - channel_contribution: NDArray, - ) -> None: - """Plot budget allocation bars on a given axis. + samples: xr.Dataset, + hdi_prob: float = 0.85, + backend: str | None = None, + ) -> PlotCollection: + """Plot channel contributions over time from budget allocation optimization. + + Visualizes how contributions from each channel evolve over time given an + optimized budget allocation. Shows mean contribution lines per channel with + HDI uncertainty bands. Parameters ---------- - ax : plt.Axes - The axis to plot on. - channels : NDArray - Array of channel names. - allocated_spend : NDArray - Array of allocated spend values. - channel_contribution : NDArray - Array of channel contribution values. - """ - bar_width = 0.35 - opacity = 0.7 - index = range(len(channels)) - - # Plot allocated spend - bars1 = ax.bar( - index, - allocated_spend, - bar_width, - color="C0", - alpha=opacity, - label="Allocated Spend", - ) + samples : xr.Dataset + Dataset from budget allocation optimization containing channel + contributions over time. Required dimensions: + - 'channel': Channel names + - 'date': Time dimension + - 'sample': MCMC samples + + Required variables: + - Variable containing 'channel_contribution' (e.g., 'channel_contribution' + or 'channel_contribution_original_scale') + + Typically obtained from: ``mmm.allocate_budget_to_maximize_response(...)`` + hdi_prob : float, default 0.85 + Probability mass for HDI interval (between 0 and 1). + backend : str | None, optional + Backend to use for plotting. If None, uses global backend configuration. - # Create twin axis for contributions - ax2 = ax.twinx() - - # Plot contributions - bars2 = ax2.bar( - [i + bar_width for i in index], - channel_contribution, - bar_width, - color="C1", - alpha=opacity, - label="Channel Contribution", - ) + Returns + ------- + PlotCollection + arviz_plots PlotCollection object containing the plot. 
- # Labels and formatting - ax.set_xlabel("Channels") - ax.set_ylabel("Allocated Spend", color="C0", labelpad=10) - ax2.set_ylabel("Channel Contributions", color="C1", labelpad=10) + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)``, + this provides a unified interface across all backends. - # Set x-ticks in the middle of the bars - ax.set_xticks([i + bar_width / 2 for i in index]) - ax.set_xticklabels(channels) - ax.tick_params(axis="x", rotation=90) + Raises + ------ + ValueError + If required dimensions ('channel', 'date', 'sample') not found in samples. + ValueError + If no variable containing 'channel_contribution' found in samples. - # Turn off grid and add legend - ax.grid(False) - ax2.grid(False) + See Also + -------- + budget_allocation_roas : Plot ROAS distributions from the same allocation results + LegacyMMMPlotSuite.allocated_contribution_by_channel_over_time : Legacy implementation - bars = [bars1, bars2] - labels = ["Allocated Spend", "Channel Contributions"] - ax.legend(bars, labels, loc="best") + Notes + ----- + Breaking changes from legacy implementation: - def allocated_contribution_by_channel_over_time( - self, - samples: xr.Dataset, - scale_factor: float | None = None, - lower_quantile: float = 0.025, - upper_quantile: float = 0.975, - original_scale: bool = True, - figsize: tuple[float, float] = (10, 6), - ax: plt.Axes | None = None, - ) -> tuple[Figure, plt.Axes | NDArray[Axes]]: - """Plot the allocated contribution by channel with uncertainty intervals. - - This function visualizes the mean allocated contributions by channel along with - the uncertainty intervals defined by the lower and upper quantiles. - If additional dimensions besides 'channel', 'date', and 'sample' are present, - creates a subplot for each combination of these dimensions. + - Returns PlotCollection instead of (Figure, Axes) + - Lost scale_factor, lower_quantile, upper_quantile, figsize, ax parameters + - Now uses HDI instead of quantiles for uncertainty + - Automatic handling of extra dimensions (creates subplots) - Parameters - ---------- - samples : xr.Dataset - The dataset containing the samples of channel contributions. - Expected to have 'channel_contribution' variable with dimensions - 'channel', 'date', and 'sample'. - scale_factor : float, optional - Scale factor to convert to original scale, if original_scale=True. - If None and original_scale=True, assumes scale_factor=1. - lower_quantile : float, optional - The lower quantile for the uncertainty interval. Default is 0.025. - upper_quantile : float, optional - The upper quantile for the uncertainty interval. Default is 0.975. - original_scale : bool, optional - If True, the contributions are plotted on the original scale. Default is True. - figsize : tuple[float, float], optional - The size of the figure to be created. Default is (10, 6). - ax : plt.Axes, optional - The axis to plot on. If None, a new figure and axis will be created. - Only used when no extra dimensions are present. + Examples + -------- + Basic usage with budget optimization results: + + .. 
code-block:: python + + allocation_results = mmm.allocate_budget_to_maximize_response( + total_budget=100_000, budget_bounds={"lower": 0.5, "upper": 2.0} + ) + pc = mmm.plot.allocated_contribution_by_channel_over_time( + allocation_results + ) + pc.show() + + Custom HDI probability: + + .. code-block:: python + + pc = mmm.plot.allocated_contribution_by_channel_over_time( + allocation_results, hdi_prob=0.94 + ) + pc.show() + + Use different backend: + + .. code-block:: python + + pc = mmm.plot.allocated_contribution_by_channel_over_time( + allocation_results, backend="plotly" + ) + pc.show() """ # Check for expected dimensions and variables if "channel" not in samples.dims: @@ -1354,200 +1436,123 @@ def allocated_contribution_by_channel_over_time( ignored_dims = {"channel", "date", "sample"} extra_dims = [dim for dim in all_dims if dim not in ignored_dims] - # If no extra dimensions or using provided axis, create a single plot - if not extra_dims or ax is not None: - if ax is None: - fig, ax = plt.subplots(figsize=figsize) - else: - fig = ax.get_figure() - - channel_contribution = samples[channel_contrib_var] - - # Apply scale factor if in original scale - if original_scale and scale_factor is not None: - channel_contribution = channel_contribution * scale_factor - - # Plot mean values by channel - channel_contribution.mean(dim="sample").plot(hue="channel", ax=ax) - - # Add uncertainty intervals for each channel - for channel in samples.coords["channel"].values: - ax.fill_between( - x=channel_contribution.date.values, - y1=channel_contribution.sel(channel=channel).quantile( - lower_quantile, dim="sample" - ), - y2=channel_contribution.sel(channel=channel).quantile( - upper_quantile, dim="sample" - ), - alpha=0.1, - ) - - ax.set_xlabel("Date") - ax.set_ylabel("Channel Contribution") - ax.set_title("Allocated Contribution by Channel Over Time") - - fig.tight_layout() - return fig, ax - - # For multiple dimensions, create a grid of subplots - # Determine layout based on number of extra dimensions - if len(extra_dims) == 1: - # One extra dimension: use for rows - dim_values = [samples.coords[extra_dims[0]].values] - nrows = len(dim_values[0]) - ncols = 1 - subplot_dims = [extra_dims[0], None] - elif len(extra_dims) == 2: - # Two extra dimensions: one for rows, one for columns - dim_values = [ - samples.coords[extra_dims[0]].values, - samples.coords[extra_dims[1]].values, - ] - nrows = len(dim_values[0]) - ncols = len(dim_values[1]) - subplot_dims = extra_dims - else: - # Three or more: use first two for rows/columns, average over the rest - dim_values = [ - samples.coords[extra_dims[0]].values, - samples.coords[extra_dims[1]].values, - ] - nrows = len(dim_values[0]) - ncols = len(dim_values[1]) - subplot_dims = [extra_dims[0], extra_dims[1]] - - # Calculate figure size based on number of subplots - subplot_figsize = (figsize[0] * max(1, ncols), figsize[1] * max(1, nrows)) - fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=subplot_figsize) - - # Make axes indexable even for 1x1 grid - if nrows == 1 and ncols == 1: - axes = np.array([[axes]]) - elif nrows == 1: - axes = axes.reshape(1, -1) - elif ncols == 1: - axes = axes.reshape(-1, 1) - - # Create a subplot for each combination of dimension values - for i, row_val in enumerate(dim_values[0]): - for j, col_val in enumerate( - dim_values[1] if len(dim_values) > 1 else [None] - ): - ax = axes[i, j] - - # Select data for this subplot - selection = {subplot_dims[0]: row_val} - if col_val is not None: - selection[subplot_dims[1]] = col_val 
- - # Select channel contributions for this subplot - subset = samples[channel_contrib_var].sel(**selection) - - # Apply scale factor if needed - if original_scale and scale_factor is not None: - subset = subset * scale_factor - - # Plot mean values by channel for this subset - subset.mean(dim="sample").plot(hue="channel", ax=ax) - - # Add uncertainty intervals for each channel - for channel in samples.coords["channel"].values: - channel_data = subset.sel(channel=channel) - ax.fill_between( - x=channel_data.date.values, - y1=channel_data.quantile(lower_quantile, dim="sample"), - y2=channel_data.quantile(upper_quantile, dim="sample"), - alpha=0.1, - ) + pc = azp.PlotCollection.wrap( + samples[channel_contrib_var].to_dataset(), + cols=extra_dims, + aes={"color": ["channel"]}, + col_wrap=1, + figure_kwargs={ + "sharex": True, + }, + backend=backend, + ) - # Add subplot title based on dimension values - title_parts = [] - if subplot_dims[0] is not None: - title_parts.append(f"{subplot_dims[0]}={row_val}") - if subplot_dims[1] is not None: - title_parts.append(f"{subplot_dims[1]}={col_val}") + # plot hdi + hdi = samples[channel_contrib_var].azstats.hdi(hdi_prob, dim="sample") + pc.map( + azp.visuals.fill_between_y, + x=samples[channel_contrib_var]["date"], + y_bottom=hdi.sel(ci_bound="lower"), + y_top=hdi.sel(ci_bound="upper"), + alpha=0.2, + ) - base_title = "Allocated Contribution by Channel Over Time" - if title_parts: - ax.set_title(f"{base_title} - {', '.join(title_parts)}") - else: - ax.set_title(base_title) + # plot mean contribution line + pc.map( + azp.visuals.line_xy, + x=samples[channel_contrib_var]["date"], + y=samples[channel_contrib_var].mean(dim="sample"), + ) - ax.set_xlabel("Date") - ax.set_ylabel("Channel Contribution") + pc.map(azp.visuals.labelled_x, text="Date", ignore_aes={"color"}) + pc.map( + azp.visuals.labelled_y, text="Channel Contribution", ignore_aes={"color"} + ) + pc.map( + azp.visuals.labelled_title, + subset_info=True, + labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), + ignore_aes={"color"}, + ) - fig.tight_layout() - return fig, axes + pc.add_legend(dim="channel") + return pc - def sensitivity_analysis( + def _sensitivity_analysis_plot( self, + data: xr.DataArray | xr.Dataset, hdi_prob: float = 0.94, - ax: plt.Axes | None = None, aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - subplot_kwargs: dict[str, Any] | None = None, - *, - plot_kwargs: dict[str, Any] | None = None, - ylabel: str = "Effect", - xlabel: str = "Sweep", - title: str | None = None, - add_figure_title: bool = False, - subplot_title_fallback: str = "Sensitivity Analysis", - ) -> tuple[Figure, NDArray[Axes]] | plt.Axes: - """Plot sensitivity analysis results. + backend: str | None = None, + ) -> PlotCollection: + """Private helper for plotting sensitivity analysis results. + + This is an internal method that performs the core plotting logic for + sensitivity analysis visualizations. Public methods (sensitivity_analysis, + uplift_curve, marginal_curve) handle data retrieval and call this helper. Parameters ---------- - hdi_prob : float, default 0.94 - HDI probability mass. - ax : plt.Axes, optional - The axis to plot on. - aggregation : dict, optional - Aggregation to apply to the data. - E.g., {"sum": ("channel",)} to sum over the channel dimension. - - Other Parameters - ---------------- - plot_kwargs : dict, optional - Keyword arguments forwarded to the underlying line plot. Defaults include - ``{"color": "C0"}``. - ylabel : str, optional - Y-axis label. 
Defaults to "Effect". - xlabel : str, optional - X-axis label. Defaults to "Sweep". - title : str, optional - Figure-level title to add when ``add_figure_title=True``. - add_figure_title : bool, optional - Whether to add a figure-level title. Defaults to ``False``. - subplot_title_fallback : str, optional - Fallback title used for subplot titles when no plotting dims exist. Defaults - to "Sensitivity Analysis". - - Examples - -------- - Basic run using stored results in `idata`: + data : xr.DataArray or xr.Dataset + Sensitivity analysis data to plot. Must have required dimensions: + - 'sample': MCMC samples + - 'sweep': Sweep values (e.g., multipliers or input values) - .. code-block:: python + If Dataset, should contain 'x' variable. - # Assuming you already ran a sweep and stored results - # under idata.sensitivity_analysis via SensitivityAnalysis.run_sweep(..., extend_idata=True) - ax = mmm.plot.sensitivity_analysis(hdi_prob=0.9) + IMPORTANT: This parameter is REQUIRED with no fallback to self.idata. + This design maintains separation of concerns - public methods handle + data retrieval, this helper handles pure plotting. + hdi_prob : float, default 0.94 + HDI probability mass (between 0 and 1). + aggregation : dict, optional + Aggregations to apply before plotting. + Keys are operations ("sum", "mean", "median"), values are dimension tuples. + Example: {"sum": ("channel",)} sums over the channel dimension. + backend : str | None, optional + Backend to use for plotting. If None, uses global backend configuration. - With aggregation over dimensions (e.g., sum over channels): + Returns + ------- + PlotCollection + arviz_plots PlotCollection object containing the plot. - .. code-block:: python + Note: Y-axis label is NOT set by this helper. Public methods calling + this helper should set appropriate labels (e.g., "Contribution", + "Uplift (%)", "Marginal Effect"). - ax = mmm.plot.sensitivity_analysis( - hdi_prob=0.9, - aggregation={"sum": ("channel",)}, - ) + Raises + ------ + ValueError + If data is missing required dimensions ('sample', 'sweep'). + + Notes + ----- + Design rationale for REQUIRED data parameter: + + - **Separation of concerns**: Public methods handle data location/retrieval + (from self.idata.sensitivity_analysis, self.idata.posterior, etc.), + this helper handles pure visualization logic. + - **Testability**: Easy to test plotting logic with mock data. + - **Cleaner implementation**: No monkey-patching or state manipulation. + - **Flexibility**: Can be reused for different data sources without + coupling to self.idata structure. + + This is a PRIVATE method (starts with _) and should not be called directly + by users. Use public methods instead: + - sensitivity_analysis(): General sensitivity analysis plots + - uplift_curve(): Uplift percentage plots + - marginal_curve(): Marginal effects plots """ - if not hasattr(self.idata, "sensitivity_analysis"): + # Handle Dataset or DataArray + x = data["x"] if isinstance(data, xr.Dataset) else data + + # Validate dimensions + required_dims = {"sample", "sweep"} + if not required_dims.issubset(set(x.dims)): raise ValueError( - "No sensitivity analysis results found. Run run_sweep() first." 
+ f"Data must have dimensions {required_dims}, got {set(x.dims)}" ) - sa = self.idata.sensitivity_analysis # type: ignore - x = sa["x"] if isinstance(sa, xr.Dataset) else sa # Coerce numeric dtype try: x = x.astype(float) @@ -1571,199 +1576,247 @@ def sensitivity_analysis( x = x.mean(dim=dims_list) else: x = x.median(dim=dims_list) + # Determine plotting dimensions (excluding sample & sweep) - plot_dims = [d for d in x.dims if d not in {"sample", "sweep"}] - if plot_dims: - dim_combinations = list( - itertools.product(*[x.coords[d].values for d in plot_dims]) - ) - else: - dim_combinations = [()] + plot_dims = set(x.dims) - {"sample", "sweep"} + + pc = azp.PlotCollection.wrap( + x.to_dataset(), + cols=plot_dims, + col_wrap=2, + figure_kwargs={ + "sharex": True, + }, + backend=backend, + ) - n_panels = len(dim_combinations) + # plot hdi + hdi = x.azstats.hdi(hdi_prob, dim="sample") + pc.map( + azp.visuals.fill_between_y, + x=x["sweep"], + y_bottom=hdi.sel(ci_bound="lower"), + y_top=hdi.sel(ci_bound="upper"), + alpha=0.4, + color="C0", + ) + # plot aggregated line + pc.map( + azp.visuals.line_xy, + x=x["sweep"], + y=x.mean(dim="sample"), + color="C0", + ) + # add labels + pc.map(azp.visuals.labelled_x, text="Sweep") + pc.map( + azp.visuals.labelled_title, + subset_info=True, + labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), + ) + return pc - # Handle axis/grid creation - subplot_kwargs = {**(subplot_kwargs or {})} - nrows_user = subplot_kwargs.pop("nrows", None) - ncols_user = subplot_kwargs.pop("ncols", None) - if nrows_user is not None and ncols_user is not None: - raise ValueError( - "Specify only one of 'nrows' or 'ncols' in subplot_kwargs." - ) + def sensitivity_analysis( + self, + data: xr.DataArray | xr.Dataset | None = None, + hdi_prob: float = 0.94, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + backend: str | None = None, + ) -> PlotCollection: + """Plot sensitivity analysis results showing response to input changes. - if n_panels > 1: - if ax is not None: - raise ValueError( - "Multiple sensitivity panels detected; please omit 'ax' and use 'subplot_kwargs' instead." 
- ) - if ncols_user is not None: - ncols = ncols_user - nrows = int(np.ceil(n_panels / ncols)) - elif nrows_user is not None: - nrows = nrows_user - ncols = int(np.ceil(n_panels / nrows)) - else: - ncols = max(1, int(np.ceil(np.sqrt(n_panels)))) - nrows = int(np.ceil(n_panels / ncols)) - subplot_kwargs.setdefault("figsize", (ncols * 4.0, nrows * 3.0)) - fig, axes_grid = plt.subplots( - nrows=nrows, - ncols=ncols, - **subplot_kwargs, - ) - if isinstance(axes_grid, plt.Axes): - axes_grid = np.array([[axes_grid]]) - elif axes_grid.ndim == 1: - axes_grid = axes_grid.reshape(1, -1) - axes_array = axes_grid - else: - if ax is not None: - axes_array = np.array([[ax]]) - fig = ax.figure - else: - if ncols_user is not None or nrows_user is not None: - subplot_kwargs.setdefault("figsize", (4.0, 3.0)) - fig, single_ax = plt.subplots( - nrows=1, - ncols=1, - **subplot_kwargs, - ) - else: - fig, single_ax = plt.subplots() - axes_array = np.array([[single_ax]]) - - # Merge plotting kwargs with defaults - _plot_kwargs = {"color": "C0"} - if plot_kwargs: - _plot_kwargs.update(plot_kwargs) - _line_color = _plot_kwargs.get("color", "C0") - - axes_flat = axes_array.flatten() - for idx, combo in enumerate(dim_combinations): - current_ax = axes_flat[idx] - indexers = dict(zip(plot_dims, combo, strict=False)) if plot_dims else {} - subset = x.sel(**indexers) if indexers else x - subset = subset.squeeze(drop=True) - subset = subset.astype(float) - - if "sweep" in subset.dims: - sweep_dim = "sweep" - else: - cand = [d for d in subset.dims if d != "sample"] - if not cand: - raise ValueError( - "Expected 'sweep' (or a non-sample) dimension in sensitivity results." - ) - sweep_dim = cand[0] + Visualizes how model outputs (e.g., channel contributions) change as inputs + (e.g., channel spend) are varied. Shows mean response line and HDI bands + across sweep values. - sweep = ( - np.asarray(subset.coords[sweep_dim].values) - if sweep_dim in subset.coords - else np.arange(subset.sizes[sweep_dim]) - ) + Parameters + ---------- + data : xr.DataArray or xr.Dataset, optional + Sensitivity analysis data with required dimensions: + - 'sample': MCMC samples + - 'sweep': Sweep values (e.g., multipliers) + + If Dataset, should contain 'x' variable. + If None, uses self.idata.sensitivity_analysis. + This parameter allows: + - Testing with mock sensitivity analysis results + - Plotting external sweep results + - Comparing different sensitivity analyses + hdi_prob : float, default 0.94 + HDI probability mass (between 0 and 1). + aggregation : dict, optional + Aggregations to apply before plotting. + Keys: "sum", "mean", or "median" + Values: tuple of dimension names - mean = subset.mean("sample") if "sample" in subset.dims else subset - reduce_dims = [d for d in mean.dims if d != sweep_dim] - if reduce_dims: - mean = mean.sum(dim=reduce_dims) + Example: ``{"sum": ("channel",)}`` sums over channels before plotting. + backend : str | None, optional + Backend to use for plotting. If None, uses global backend configuration. - if "sample" in subset.dims: - hdi = az.hdi(subset, hdi_prob=hdi_prob, input_core_dims=[["sample"]]) - if isinstance(hdi, xr.Dataset): - hdi = hdi[next(iter(hdi.data_vars))] - else: - hdi = xr.concat([mean, mean], dim="hdi").assign_coords( - hdi=np.array([0, 1]) - ) + Returns + ------- + PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. 
+ Unlike the legacy suite which returned ``(Figure, Axes)`` or ``Axes``, + this provides a unified interface across all backends. + + Raises + ------ + ValueError + If no sensitivity analysis data found in self.idata and no data provided. + + See Also + -------- + uplift_curve : Plot uplift percentages (derived from sensitivity analysis) + marginal_curve : Plot marginal effects (derived from sensitivity analysis) + LegacyMMMPlotSuite.sensitivity_analysis : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) or Axes + - Lost ax, subplot_kwargs, plot_kwargs parameters (use backend methods) + - Cleaner implementation without monkey-patching + - Data parameter for explicit data passing (no side effects on self.idata) - reduce_hdi = [d for d in hdi.dims if d not in (sweep_dim, "hdi")] - if reduce_hdi: - hdi = hdi.sum(dim=reduce_hdi) - if set(hdi.dims) == {sweep_dim, "hdi"} and list(hdi.dims) != [ - sweep_dim, - "hdi", - ]: - hdi = hdi.transpose(sweep_dim, "hdi") # type: ignore - - current_ax.plot(sweep, np.asarray(mean.values, dtype=float), **_plot_kwargs) - az.plot_hdi( - x=sweep, - hdi_data=np.asarray(hdi.values, dtype=float), - hdi_prob=hdi_prob, - color=_line_color, - ax=current_ax, + Examples + -------- + Run sweep and plot results: + + .. code-block:: python + + from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis + + # Run sensitivity sweep + sweeps = np.linspace(0.5, 1.5, 11) + sa = SensitivityAnalysis(mmm.model, mmm.idata) + results = sa.run_sweep( + var_input="channel_data", + sweep_values=sweeps, + var_names="channel_contribution", + sweep_type="multiplicative", + extend_idata=True, # Store in idata ) - title = self._build_subplot_title( - dims=plot_dims, - combo=combo, - fallback_title=subplot_title_fallback, + # Plot stored results + pc = mmm.plot.sensitivity_analysis(hdi_prob=0.9) + pc.show() + + Aggregate over channels: + + .. code-block:: python + + pc = mmm.plot.sensitivity_analysis( + hdi_prob=0.9, aggregation={"sum": ("channel",)} ) - current_ax.set_title(title) - current_ax.set_xlabel(xlabel) - current_ax.set_ylabel(ylabel) + pc.show() - # Hide any unused axes (happens if grid > panels) - for ax_extra in axes_flat[n_panels:]: - ax_extra.set_visible(False) + Use different backend: - # Optional figure-level title: only for multi-panel layouts, default color (black) - if add_figure_title and title is not None and n_panels > 1: - fig.suptitle(title) + .. code-block:: python + + pc = mmm.plot.sensitivity_analysis(backend="plotly") + pc.show() + + Provide explicit data: - if n_panels == 1: - return axes_array[0, 0] + .. code-block:: python + + external_results = sa.run_sweep(...) # Not stored in idata + pc = mmm.plot.sensitivity_analysis(data=external_results) + pc.show() + """ + # Retrieve data if not provided + data = self._get_data_or_fallback( + data, "sensitivity_analysis", "sensitivity analysis results" + ) - fig.tight_layout() - return fig, axes_array + pc = self._sensitivity_analysis_plot( + data=data, hdi_prob=hdi_prob, aggregation=aggregation, backend=backend + ) + pc.map(azp.visuals.labelled_y, text="Contribution") + return pc def uplift_curve( self, + data: xr.DataArray | xr.Dataset | None = None, hdi_prob: float = 0.94, - ax: plt.Axes | None = None, aggregation: dict[str, tuple[str, ...] 
| list[str]] | None = None, - subplot_kwargs: dict[str, Any] | None = None, - *, - plot_kwargs: dict[str, Any] | None = None, - ylabel: str = "Uplift", - xlabel: str = "Sweep", - title: str | None = "Uplift curve", - add_figure_title: bool = True, - ) -> tuple[Figure, NDArray[Axes]] | plt.Axes: - """ - Plot precomputed uplift curves stored under `idata.sensitivity_analysis['uplift_curve']`. + backend: str | None = None, + ) -> PlotCollection: + """Plot uplift curves showing percentage change relative to baseline. + + Visualizes relative percentage changes in model outputs (e.g., channel + contributions) as inputs are varied, compared to a reference point. + Shows mean uplift line and HDI bands. Parameters ---------- + data : xr.DataArray or xr.Dataset, optional + Uplift curve data computed from sensitivity analysis. + If Dataset, should contain 'uplift_curve' variable. + If None, uses self.idata.sensitivity_analysis['uplift_curve']. + + Must be precomputed using: + ``SensitivityAnalysis.compute_uplift_curve_respect_to_base(...)`` + This parameter allows: + - Testing with mock uplift curve data + - Plotting externally computed uplift curves + - Comparing uplift curves from different models hdi_prob : float, default 0.94 - HDI probability mass. - ax : plt.Axes, optional - The axis to plot on. + HDI probability mass (between 0 and 1). aggregation : dict, optional - Aggregation to apply to the data. - E.g., {"sum": ("channel",)} to sum over the channel dimension. - subplot_kwargs : dict, optional - Additional subplot configuration forwarded to :meth:`sensitivity_analysis`. - plot_kwargs : dict, optional - Keyword arguments forwarded to the underlying line plot. If not provided, defaults - are used by :meth:`sensitivity_analysis` (e.g., color "C0"). - ylabel : str, optional - Y-axis label. Defaults to "Uplift". - xlabel : str, optional - X-axis label. Defaults to "Sweep". - title : str, optional - Figure-level title to add when ``add_figure_title=True``. Defaults to "Uplift curve". - add_figure_title : bool, optional - Whether to add a figure-level title. Defaults to ``True``. + Aggregations to apply before plotting. + Keys: "sum", "mean", or "median" + Values: tuple of dimension names + + Example: ``{"sum": ("channel",)}`` sums over channels before plotting. + backend : str | None, optional + Backend to use for plotting. If None, uses global backend configuration. + + Returns + ------- + PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)`` or ``Axes``, + this provides a unified interface across all backends. + + Raises + ------ + ValueError + If no uplift curve data found in self.idata and no data provided. + ValueError + If 'uplift_curve' variable not found in sensitivity_analysis group. + + See Also + -------- + sensitivity_analysis : Plot raw sensitivity analysis results + marginal_curve : Plot marginal effects (absolute changes) + LegacyMMMPlotSuite.uplift_curve : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) or Axes + - Cleaner implementation without monkey-patching + - No longer modifies self.idata.sensitivity_analysis temporarily + - Data parameter for explicit data passing Examples -------- - Persist uplift curve and plot: + Compute and plot uplift curve: .. 
code-block:: python from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis + # Run sensitivity sweep sweeps = np.linspace(0.5, 1.5, 11) sa = SensitivityAnalysis(mmm.model, mmm.idata) results = sa.run_sweep( @@ -1772,99 +1825,156 @@ def uplift_curve( var_names="channel_contribution", sweep_type="multiplicative", ) + + # Compute uplift relative to baseline (ref=1.0) uplift = sa.compute_uplift_curve_respect_to_base( - results, ref=1.0, extend_idata=True + results, + ref=1.0, + extend_idata=True, # Store in idata ) - _ = mmm.plot.uplift_curve(hdi_prob=0.9) + + # Plot stored uplift curve + pc = mmm.plot.uplift_curve(hdi_prob=0.9) + pc.show() + + Aggregate over channels: + + .. code-block:: python + + pc = mmm.plot.uplift_curve(aggregation={"sum": ("channel",)}) + pc.show() + + Use different backend: + + .. code-block:: python + + pc = mmm.plot.uplift_curve(backend="plotly") + pc.show() + + Provide explicit data: + + .. code-block:: python + + uplift_data = sa.compute_uplift_curve_respect_to_base(results, ref=1.0) + pc = mmm.plot.uplift_curve(data=uplift_data) + pc.show() """ - if not hasattr(self.idata, "sensitivity_analysis"): - raise ValueError( - "No sensitivity analysis results found in 'self.idata'. " - "Run 'mmm.sensitivity.run_sweep()' first." + # Retrieve data if not provided + if data is None: + sa_group = self._get_data_or_fallback( + None, "sensitivity_analysis", "sensitivity analysis results" ) - - sa_group = self.idata.sensitivity_analysis # type: ignore - if isinstance(sa_group, xr.Dataset): - if "uplift_curve" not in sa_group: + if isinstance(sa_group, xr.Dataset): + if "uplift_curve" not in sa_group: + raise ValueError( + "Expected 'uplift_curve' in idata.sensitivity_analysis. " + "Use SensitivityAnalysis.compute_uplift_curve_respect_to_base(..., extend_idata=True)." + ) + data = sa_group["uplift_curve"] + else: raise ValueError( - "Expected 'uplift_curve' in idata.sensitivity_analysis. " - "Use SensitivityAnalysis.compute_uplift_curve_respect_to_base(..., extend_idata=True)." + "sensitivity_analysis does not contain 'uplift_curve'. Did you persist it to idata?" ) - data_var = sa_group["uplift_curve"] - else: - raise ValueError( - "sensitivity_analysis does not contain 'uplift_curve'. Did you persist it to idata?" - ) - # Delegate to a thin wrapper by temporarily constructing a Dataset - tmp_idata = xr.Dataset({"x": data_var}) - # Monkey-patch minimal attributes needed - tmp_idata["x"].attrs.update(getattr(sa_group, "attrs", {})) # type: ignore - # Temporarily swap - original_group = self.idata.sensitivity_analysis # type: ignore - try: - self.idata.sensitivity_analysis = tmp_idata # type: ignore - return self.sensitivity_analysis( - hdi_prob=hdi_prob, - ax=ax, - aggregation=aggregation, - subplot_kwargs=subplot_kwargs, - subplot_title_fallback="Uplift curve", - plot_kwargs=plot_kwargs, - ylabel=ylabel, - xlabel=xlabel, - title=title, - add_figure_title=add_figure_title, - ) - finally: - self.idata.sensitivity_analysis = original_group # type: ignore + # Handle Dataset input + if isinstance(data, xr.Dataset): + if "uplift_curve" in data: + data = data["uplift_curve"] + elif "x" in data: + data = data["x"] + else: + raise ValueError("Dataset must contain 'uplift_curve' or 'x' variable.") + + # Call helper with data (no more monkey-patching!) 
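+ # The helper validates the ("sample", "sweep") dims, applies any requested
+ # aggregation, and draws the posterior-mean line with HDI bands; each public
+ # caller only sets its own y-axis label afterwards.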
+ pc = self._sensitivity_analysis_plot( + data=data, + hdi_prob=hdi_prob, + aggregation=aggregation, + backend=backend, + ) + pc.map(azp.visuals.labelled_y, text="Uplift (%)") + return pc def marginal_curve( self, + data: xr.DataArray | xr.Dataset | None = None, hdi_prob: float = 0.94, - ax: plt.Axes | None = None, aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - subplot_kwargs: dict[str, Any] | None = None, - *, - plot_kwargs: dict[str, Any] | None = None, - ylabel: str = "Marginal effect", - xlabel: str = "Sweep", - title: str | None = "Marginal effects", - add_figure_title: bool = True, - ) -> tuple[Figure, NDArray[Axes]] | plt.Axes: - """ - Plot precomputed marginal effects stored under `idata.sensitivity_analysis['marginal_effects']`. + backend: str | None = None, + ) -> PlotCollection: + """Plot marginal effects showing absolute rate of change. + + Visualizes the instantaneous rate of change (derivative) of model outputs + with respect to inputs. Shows how much output changes per unit change in + input at each sweep value. Parameters ---------- + data : xr.DataArray or xr.Dataset, optional + Marginal effects data computed from sensitivity analysis. + If Dataset, should contain 'marginal_effects' variable. + If None, uses self.idata.sensitivity_analysis['marginal_effects']. + + Must be precomputed using: + ``SensitivityAnalysis.compute_marginal_effects(...)`` + This parameter allows: + - Testing with mock marginal effects data + - Plotting externally computed marginal effects + - Comparing marginal effects from different models hdi_prob : float, default 0.94 - HDI probability mass. - ax : plt.Axes, optional - The axis to plot on. + HDI probability mass (between 0 and 1). aggregation : dict, optional - Aggregation to apply to the data. - E.g., {"sum": ("channel",)} to sum over the channel dimension. - subplot_kwargs : dict, optional - Additional subplot configuration forwarded to :meth:`sensitivity_analysis`. - plot_kwargs : dict, optional - Keyword arguments forwarded to the underlying line plot. Defaults to ``{"color": "C1"}``. - ylabel : str, optional - Y-axis label. Defaults to "Marginal effect". - xlabel : str, optional - X-axis label. Defaults to "Sweep". - title : str, optional - Figure-level title to add when ``add_figure_title=True``. Defaults to "Marginal effects". - add_figure_title : bool, optional - Whether to add a figure-level title. Defaults to ``True``. + Aggregations to apply before plotting. + Keys: "sum", "mean", or "median" + Values: tuple of dimension names + + Example: ``{"sum": ("channel",)}`` sums over channels before plotting. + backend : str | None, optional + Backend to use for plotting. If None, uses global backend configuration. + + Returns + ------- + PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)`` or ``Axes``, + this provides a unified interface across all backends. + + Raises + ------ + ValueError + If no marginal effects data found in self.idata and no data provided. + ValueError + If 'marginal_effects' variable not found in sensitivity_analysis group. 
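Here "marginal effect" means the slope of the sensitivity curve along the sweep. As a minimal, self-contained sketch of that idea (synthetic data, finite differences via ``xarray.DataArray.differentiate``; this only illustrates the concept and is not the implementation of ``SensitivityAnalysis.compute_marginal_effects``):

```python
import numpy as np
import xarray as xr

# Synthetic sweep results: one response value per posterior sample and sweep
# point, mirroring the ("sample", "sweep") layout the plotting helper expects.
rng = np.random.default_rng(0)
sweep = np.linspace(0.5, 1.5, 11)
response = xr.DataArray(
    rng.normal(loc=100 * np.sqrt(sweep), scale=5.0, size=(200, sweep.size)),
    dims=("sample", "sweep"),
    coords={"sweep": sweep},
)

# Marginal effect = d(response)/d(sweep), estimated by central differences.
# Diminishing returns show up as a marginal effect that shrinks along the sweep.
marginal = response.differentiate("sweep")
mean_marginal = marginal.mean(dim="sample")
```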
+ + See Also + -------- + sensitivity_analysis : Plot raw sensitivity analysis results + uplift_curve : Plot uplift percentages (relative changes) + LegacyMMMPlotSuite.marginal_curve : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) or Axes + - Cleaner implementation without monkey-patching + - No longer modifies self.idata.sensitivity_analysis temporarily + - Data parameter for explicit data passing + + Marginal effects show the **slope** of the sensitivity curve, helping + identify where returns are diminishing most rapidly. Examples -------- - Persist marginal effects and plot: + Compute and plot marginal effects: .. code-block:: python from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis + # Run sensitivity sweep sweeps = np.linspace(0.5, 1.5, 11) sa = SensitivityAnalysis(mmm.model, mmm.idata) results = sa.run_sweep( @@ -1873,51 +1983,142 @@ def marginal_curve( var_names="channel_contribution", sweep_type="multiplicative", ) - me = sa.compute_marginal_effects(results, extend_idata=True) - _ = mmm.plot.marginal_curve(hdi_prob=0.9) + + # Compute marginal effects (derivatives) + me = sa.compute_marginal_effects( + results, + extend_idata=True, # Store in idata + ) + + # Plot stored marginal effects + pc = mmm.plot.marginal_curve(hdi_prob=0.9) + pc.show() + + Aggregate over channels: + + .. code-block:: python + + pc = mmm.plot.marginal_curve(aggregation={"sum": ("channel",)}) + pc.show() + + Use different backend: + + .. code-block:: python + + pc = mmm.plot.marginal_curve(backend="plotly") + pc.show() + + Provide explicit data: + + .. code-block:: python + + marginal_data = sa.compute_marginal_effects(results) + pc = mmm.plot.marginal_curve(data=marginal_data) + pc.show() """ - if not hasattr(self.idata, "sensitivity_analysis"): - raise ValueError( - "No sensitivity analysis results found in 'self.idata'. " - "Run 'mmm.sensitivity.run_sweep()' first." + # Retrieve data if not provided + if data is None: + sa_group = self._get_data_or_fallback( + None, "sensitivity_analysis", "sensitivity analysis results" ) + if isinstance(sa_group, xr.Dataset): + if "marginal_effects" not in sa_group: + raise ValueError( + "Expected 'marginal_effects' in idata.sensitivity_analysis. " + "Use SensitivityAnalysis.compute_marginal_effects(..., extend_idata=True)." + ) + data = sa_group["marginal_effects"] + else: + raise ValueError( + "sensitivity_analysis does not contain 'marginal_effects'. Did you persist it to idata?" + ) - sa_group = self.idata.sensitivity_analysis # type: ignore - if isinstance(sa_group, xr.Dataset): - if "marginal_effects" not in sa_group: + # Handle Dataset input + if isinstance(data, xr.Dataset): + if "marginal_effects" in data: + data = data["marginal_effects"] + elif "x" in data: + data = data["x"] + else: raise ValueError( - "Expected 'marginal_effects' in idata.sensitivity_analysis. " - "Use SensitivityAnalysis.compute_marginal_effects(..., extend_idata=True)." + "Dataset must contain 'marginal_effects' or 'x' variable." ) - data_var = sa_group["marginal_effects"] - else: - raise ValueError( - "sensitivity_analysis does not contain 'marginal_effects'. Did you persist it to idata?" 
- ) - # We want a different y-label and color - # Temporarily swap group to reuse plotting logic - tmp = xr.Dataset({"x": data_var}) - tmp["x"].attrs.update(getattr(sa_group, "attrs", {})) # type: ignore - original = self.idata.sensitivity_analysis # type: ignore - try: - self.idata.sensitivity_analysis = tmp # type: ignore - # Reuse core plotting; percentage=False by definition - # Merge defaults for plot_kwargs if not provided - _plot_kwargs = {"color": "C1"} - if plot_kwargs: - _plot_kwargs.update(plot_kwargs) - return self.sensitivity_analysis( - hdi_prob=hdi_prob, - ax=ax, - aggregation=aggregation, - subplot_kwargs=subplot_kwargs, - subplot_title_fallback="Marginal effects", - plot_kwargs=_plot_kwargs, - ylabel=ylabel, - xlabel=xlabel, - title=title, - add_figure_title=add_figure_title, - ) - finally: - self.idata.sensitivity_analysis = original # type: ignore + # Call helper with data (no more monkey-patching!) + pc = self._sensitivity_analysis_plot( + data=data, + hdi_prob=hdi_prob, + aggregation=aggregation, + backend=backend, + ) + pc.map(azp.visuals.labelled_y, text="Marginal Effect") + return pc + + def budget_allocation(self, *args, **kwargs): + """ + Create bar chart comparing allocated spend and channel contributions. + + .. deprecated:: 0.18.0 + This method was removed in MMMPlotSuite v2. The arviz_plots library + used in v2 doesn't support this specific chart type. See alternatives below. + + Raises + ------ + NotImplementedError + This method is not available in MMMPlotSuite v2. + + Notes + ----- + Alternatives: + + 1. **For ROAS distributions**: Use :meth:`budget_allocation_roas` + (different purpose but related to budget allocation) + + 2. **To use the old method**: Switch to legacy suite: + + .. code-block:: python + + from pymc_marketing.mmm import mmm_plot_config + + mmm_plot_config["plot.use_v2"] = False + mmm.plot.budget_allocation(samples) + + 3. **Custom implementation**: Create bar chart using samples data: + + .. code-block:: python + + import matplotlib.pyplot as plt + + channel_contrib = samples["channel_contribution"].mean(...) + allocated_spend = samples["allocation"] + # Create custom bar chart with matplotlib + + See Also + -------- + budget_allocation_roas : Plot ROAS distributions by channel + + Examples + -------- + Use legacy suite temporarily: + + .. code-block:: python + + from pymc_marketing.mmm import mmm_plot_config + + original = mmm_plot_config.get("plot.use_v2") + try: + mmm_plot_config["plot.use_v2"] = False + fig, ax = mmm.plot.budget_allocation(samples) + fig.savefig("budget.png") + finally: + mmm_plot_config["plot.use_v2"] = original + """ + raise NotImplementedError( + "budget_allocation() was removed in MMMPlotSuite v2.\n\n" + "The new arviz_plots-based implementation doesn't support this chart type.\n\n" + "Alternatives:\n" + " 1. For ROAS distributions: use budget_allocation_roas()\n" + " 2. To use old method: set mmm_plot_config['plot.use_v2'] = False\n" + " 3. 
Implement custom bar chart using the samples data\n\n" + "See documentation: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html#budget-allocation" + ) diff --git a/pyproject.toml b/pyproject.toml index 92e3aadb4..379bb90e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "pyprojroot", "pymc-extras>=0.4.0", "preliz>=0.20.0", + "arviz_plots[matplotlib]>=0.7.0" ] [project.optional-dependencies] @@ -92,6 +93,7 @@ test = [ "osqp<1.0.0,>=0.6.2", "pygraphviz", "preliz>=0.20.0", + "arviz_plots[plotly,bokeh]>=0.7.0" ] [tool.hatch.build.targets.sdist] diff --git a/tests/mmm/conftest.py b/tests/mmm/conftest.py new file mode 100644 index 000000000..1c36e609d --- /dev/null +++ b/tests/mmm/conftest.py @@ -0,0 +1,307 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared fixtures for MMM plotting tests.""" + +import arviz as az +import numpy as np +import pandas as pd +import pytest +import xarray as xr + + +@pytest.fixture +def mock_posterior_data(): + """Mock posterior Dataset for testing data parameters.""" + rng = np.random.default_rng(42) + return xr.Dataset( + { + "intercept": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + }, + ) + } + ) + + +@pytest.fixture +def mock_constant_data(): + """Mock constant_data Dataset for saturation plots.""" + rng = np.random.default_rng(42) + n_dates = 52 + n_channels = 3 + + return xr.Dataset( + { + "channel_data": xr.DataArray( + rng.uniform(0, 100, size=(n_dates, n_channels)), + dims=("date", "channel"), + coords={ + "date": pd.date_range("2025-01-01", periods=n_dates, freq="W"), + "channel": ["TV", "Radio", "Digital"], + }, + ), + "channel_scale": xr.DataArray( + rng.uniform(0.5, 2.0, size=(n_channels,)), + dims=("channel",), + coords={"channel": ["TV", "Radio", "Digital"]}, + ), + "target_scale": xr.DataArray(1.0), + } + ) + + +@pytest.fixture +def mock_sensitivity_data(): + """Mock sensitivity analysis data.""" + rng = np.random.default_rng(42) + return xr.Dataset( + { + "x": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + + +@pytest.fixture +def mock_idata_with_posterior(): + """Mock InferenceData with posterior data.""" + rng = np.random.default_rng(42) + posterior = xr.Dataset( + { + "intercept": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + }, + ) + } + ) + return az.InferenceData(posterior=posterior) + + +@pytest.fixture +def mock_idata_with_uplift_curve(): + """Mock InferenceData with uplift_curve in sensitivity_analysis.""" + rng = np.random.default_rng(42) + + posterior = xr.Dataset( + { + "intercept": xr.DataArray( + 
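+                    # Minimal posterior: 4 chains x 100 draws; no date dim is needed here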
rng.normal(size=(4, 100)), + dims=("chain", "draw"), + ) + } + ) + + sensitivity_analysis = xr.Dataset( + { + "uplift_curve": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + ) + } + ) + + return az.InferenceData( + posterior=posterior, sensitivity_analysis=sensitivity_analysis + ) + + +@pytest.fixture +def mock_idata_with_sensitivity(): + """Mock InferenceData with sensitivity_analysis group.""" + rng = np.random.default_rng(42) + + posterior = xr.Dataset( + { + "intercept": xr.DataArray( + rng.normal(size=(4, 100)), + dims=("chain", "draw"), + ) + } + ) + + sensitivity_analysis = xr.Dataset( + { + "x": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + + return az.InferenceData( + posterior=posterior, sensitivity_analysis=sensitivity_analysis + ) + + +@pytest.fixture +def mock_idata_for_legacy(): + """Mock InferenceData for legacy suite tests.""" + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + + posterior_predictive = xr.Dataset( + { + "y": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + }, + ) + } + ) + + return az.InferenceData(posterior_predictive=posterior_predictive) + + +@pytest.fixture +def mock_idata(): + """Mock InferenceData for compatibility testing.""" + rng = np.random.default_rng(42) + + posterior = xr.Dataset( + { + "intercept": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + }, + ), + "channel_contribution": xr.DataArray( + rng.normal(size=(4, 100, 52, 3)), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + "channel": ["TV", "Radio", "Digital"], + }, + ), + } + ) + + posterior_predictive = xr.Dataset( + { + "y": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + }, + ), + } + ) + + constant_data = xr.Dataset( + { + "channel_data": xr.DataArray( + rng.uniform(0, 100, size=(52, 3)), + dims=("date", "channel"), + coords={ + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + "channel": ["TV", "Radio", "Digital"], + }, + ), + "channel_scale": xr.DataArray( + rng.uniform(0.5, 2.0, size=(3,)), + dims=("channel",), + coords={"channel": ["TV", "Radio", "Digital"]}, + ), + "target_scale": xr.DataArray(1.0), + } + ) + + return az.InferenceData( + posterior=posterior, + posterior_predictive=posterior_predictive, + constant_data=constant_data, + ) + + +@pytest.fixture +def mock_mmm(mock_idata): + """Mock MMM instance with idata for compatibility testing.""" + from unittest.mock import Mock + + from pymc_marketing.mmm.multidimensional import MMM + + mmm = Mock(spec=MMM) + mmm.idata = mock_idata + mmm._validate_model_was_built = Mock() + mmm._validate_idata_exists = Mock() + + # Make .plot property work with actual implementation + type(mmm).plot = MMM.plot + + return mmm + + +@pytest.fixture +def mock_mmm_fitted(mock_mmm): + """Mock fitted MMM instance for compatibility testing.""" + # Same as mock_mmm, just clearer name for tests that need fitted model + 
return mock_mmm + + +@pytest.fixture +def mock_allocation_samples(): + """Mock samples dataset for budget allocation tests.""" + rng = np.random.default_rng(42) + + return xr.Dataset( + { + "channel_contribution_original_scale": xr.DataArray( + rng.normal(size=(4, 100, 52, 3)), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + "channel": ["TV", "Radio", "Digital"], + }, + ), + "allocation": xr.DataArray( + rng.uniform(100, 1000, size=(3,)), + dims=("channel",), + coords={"channel": ["TV", "Radio", "Digital"]}, + ), + } + ) diff --git a/tests/mmm/test_legacy_plot.py b/tests/mmm/test_legacy_plot.py new file mode 100644 index 000000000..6a879a9c5 --- /dev/null +++ b/tests/mmm/test_legacy_plot.py @@ -0,0 +1,1054 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings + +import arviz as az +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation + +with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + from pymc_marketing.mmm.multidimensional import MMM + +from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + + +@pytest.fixture +def mmm(): + return MMM( + date_column="date", + channel_columns=["C1", "C2"], + dims=("country",), + target_column="y", + adstock=GeometricAdstock(l_max=10), + saturation=LogisticSaturation(), + ) + + +@pytest.fixture +def df() -> pd.DataFrame: + dates = pd.date_range("2025-01-01", periods=3, freq="W-MON").rename("date") + df = pd.DataFrame( + { + ("A", "C1"): [1, 2, 3], + ("B", "C1"): [4, 5, 6], + ("A", "C2"): [7, 8, 9], + ("B", "C2"): [10, 11, 12], + }, + index=dates, + ) + df.columns.names = ["country", "channel"] + + y = pd.DataFrame( + { + ("A", "y"): [1, 2, 3], + ("B", "y"): [4, 5, 6], + }, + index=dates, + ) + y.columns.names = ["country", "channel"] + + return pd.concat( + [ + df.stack("country", future_stack=True), + y.stack("country", future_stack=True), + ], + axis=1, + ).reset_index() + + +@pytest.fixture +def fit_mmm_with_channel_original_scale(df, mmm, mock_pymc_sample): + X = df.drop(columns=["y"]) + y = df["y"] + + mmm.build_model(X, y) + mmm.add_original_scale_contribution_variable( + var=[ + "channel_contribution", + ] + ) + + mmm.fit(X, y) + + return mmm + + +@pytest.fixture +def fit_mmm_without_channel_original_scale(df, mmm, mock_pymc_sample): + X = df.drop(columns=["y"]) + y = df["y"] + + mmm.fit(X, y) + + return mmm + + +def test_saturation_curves_scatter_original_scale(fit_mmm_with_channel_original_scale): + fig, ax = fit_mmm_with_channel_original_scale.plot.saturation_curves_scatter( + original_scale=True + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + assert all(isinstance(a, Axes) for a in ax.flat) + + +def 
test_saturation_curves_scatter_original_scale_fails_if_no_deterministic( + fit_mmm_without_channel_original_scale, +): + with pytest.raises(ValueError): + fit_mmm_without_channel_original_scale.plot.saturation_curves_scatter( + original_scale=True + ) + + +def test_contributions_over_time(fit_mmm_with_channel_original_scale): + fig, ax = fit_mmm_with_channel_original_scale.plot.contributions_over_time( + var=["channel_contribution"], + hdi_prob=0.95, + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + assert all(isinstance(a, Axes) for a in ax.flat) + + +def test_contributions_over_time_with_dim(mock_suite: LegacyMMMPlotSuite): + # Test with explicit dim argument + fig, ax = mock_suite.contributions_over_time( + var=["intercept", "linear_trend"], + dims={"country": "A"}, + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + assert all(isinstance(a, Axes) for a in ax.flat) + # Optionally, check axes shape if known + if hasattr(ax, "shape"): + # When filtering to a single country, shape[-1] should be 1 + assert ax.shape[-1] == 1 + + +def test_contributions_over_time_with_dims_list(mock_suite: LegacyMMMPlotSuite): + """Test that passing a list to dims creates a subplot for each value.""" + fig, ax = mock_suite.contributions_over_time( + var=["intercept"], + dims={"country": ["A", "B"]}, + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + # Should create one subplot per value in the list (here: 2 countries) + assert ax.shape[0] == 2 + # Optionally, check subplot titles contain the correct country + for i, country in enumerate(["A", "B"]): + assert country in ax[i, 0].get_title() + + +def test_contributions_over_time_with_multiple_dims_lists( + mock_suite: LegacyMMMPlotSuite, +): + """Test that passing multiple lists to dims creates a subplot for each combination.""" + # Add a fake 'region' dim to the mock posterior for this test if not present + idata = mock_suite.idata + if "region" not in idata.posterior["intercept"].dims: + idata.posterior["intercept"] = idata.posterior["intercept"].expand_dims( + region=["X", "Y"] + ) + fig, ax = mock_suite.contributions_over_time( + var=["intercept"], + dims={"country": ["A", "B"], "region": ["X", "Y"]}, + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + # Should create one subplot per combination (2 countries x 2 regions = 4) + assert ax.shape[0] == 4 + combos = [("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")] + for i, (country, region) in enumerate(combos): + title = ax[i, 0].get_title() + assert country in title + assert region in title + + +def test_posterior_predictive(fit_mmm_with_channel_original_scale, df): + fit_mmm_with_channel_original_scale.sample_posterior_predictive( + df.drop(columns=["y"]) + ) + fig, ax = fit_mmm_with_channel_original_scale.plot.posterior_predictive( + hdi_prob=0.95, + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + assert all(isinstance(a, Axes) for a in ax.flat) + + +@pytest.fixture(scope="module") +def mock_idata() -> az.InferenceData: + seed = sum(map(ord, "Fake posterior")) + rng = np.random.default_rng(seed) + normal = rng.normal + + dates = pd.date_range("2025-01-01", periods=52, freq="W-MON") + return az.InferenceData( + posterior=xr.Dataset( + { + "intercept": xr.DataArray( + normal(size=(4, 100, 52, 3)), + dims=("chain", "draw", "date", "country"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + "country": ["A", "B", "C"], + }, + ), + "linear_trend": xr.DataArray( + 
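+                    # Note: unlike intercept above, linear_trend carries no country dim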
normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + }, + ), + } + ) + ) + + +@pytest.fixture(scope="module") +def mock_idata_with_sensitivity(mock_idata): + # Copy the mock_idata so we don't mutate the shared fixture + idata = mock_idata.copy() + n_sample, n_sweep = 40, 5 + sweep = np.linspace(0.5, 1.5, n_sweep) + regions = ["A", "B"] + + samples = xr.DataArray( + np.random.normal(0, 1, size=(n_sample, n_sweep, len(regions))), + dims=("sample", "sweep", "region"), + coords={ + "sample": np.arange(n_sample), + "sweep": sweep, + "region": regions, + }, + name="x", + ) + + marginal_effects = xr.DataArray( + np.random.normal(0, 1, size=(n_sample, n_sweep, len(regions))), + dims=("sample", "sweep", "region"), + coords={ + "sample": np.arange(n_sample), + "sweep": sweep, + "region": regions, + }, + name="marginal_effects", + ) + + uplift_curve = xr.DataArray( + np.random.normal(0, 1, size=(n_sample, n_sweep, len(regions))), + dims=("sample", "sweep", "region"), + coords={ + "sample": np.arange(n_sample), + "sweep": sweep, + "region": regions, + }, + name="uplift_curve", + ) + + sensitivity_analysis = xr.Dataset( + { + "x": samples, + "marginal_effects": marginal_effects, + "uplift_curve": uplift_curve, + }, + coords={"sweep": sweep, "region": regions}, + attrs={"sweep_type": "multiplicative", "var_names": "test_var"}, + ) + + idata.sensitivity_analysis = sensitivity_analysis + return idata + + +@pytest.fixture(scope="module") +def mock_suite(mock_idata): + """Fixture to create a mock LegacyMMMPlotSuite with a mocked posterior.""" + return LegacyMMMPlotSuite(idata=mock_idata) + + +@pytest.fixture(scope="module") +def mock_suite_with_sensitivity(mock_idata_with_sensitivity): + """Fixture to create a mock LegacyMMMPlotSuite with sensitivity analysis.""" + return LegacyMMMPlotSuite(idata=mock_idata_with_sensitivity) + + +def test_contributions_over_time_expand_dims(mock_suite: LegacyMMMPlotSuite): + fig, ax = mock_suite.contributions_over_time( + var=[ + "intercept", + "linear_trend", + ] + ) + + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + assert all(isinstance(a, Axes) for a in ax.flat) + + +@pytest.fixture(scope="module") +def mock_idata_with_constant_data() -> az.InferenceData: + """Create mock InferenceData with constant_data and posterior for saturation tests.""" + seed = sum(map(ord, "Saturation tests")) + rng = np.random.default_rng(seed) + normal = rng.normal + + dates = pd.date_range("2025-01-01", periods=52, freq="W-MON") + channels = ["channel_1", "channel_2"] + countries = ["A", "B"] + + # Create posterior data + posterior = xr.Dataset( + { + "channel_contribution": xr.DataArray( + normal(size=(4, 100, 52, 2, 2)), + dims=("chain", "draw", "date", "channel", "country"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + "channel": channels, + "country": countries, + }, + ), + "channel_contribution_original_scale": xr.DataArray( + normal(size=(4, 100, 52, 2, 2)) * 100, # scaled up for original scale + dims=("chain", "draw", "date", "channel", "country"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + "channel": channels, + "country": countries, + }, + ), + } + ) + + # Create constant_data + constant_data = xr.Dataset( + { + "channel_data": xr.DataArray( + rng.uniform(0, 10, size=(52, 2, 2)), + dims=("date", "channel", "country"), + coords={ + "date": dates, + "channel": channels, + "country": countries, + }, + ), 
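+            # Per-(country, channel) scale factors used by the original-scale plots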
+ "channel_scale": xr.DataArray( + [[100.0, 200.0], [150.0, 250.0]], + dims=("country", "channel"), + coords={"country": countries, "channel": channels}, + ), + "target_scale": xr.DataArray( + [1000.0], + dims="target", + coords={"target": ["y"]}, + ), + } + ) + + return az.InferenceData(posterior=posterior, constant_data=constant_data) + + +@pytest.fixture(scope="module") +def mock_suite_with_constant_data(mock_idata_with_constant_data): + """Fixture to create a LegacyMMMPlotSuite with constant_data for saturation tests.""" + return LegacyMMMPlotSuite(idata=mock_idata_with_constant_data) + + +@pytest.fixture(scope="module") +def mock_saturation_curve() -> xr.DataArray: + """Create mock saturation curve data for testing saturation_curves method.""" + seed = sum(map(ord, "Saturation curve")) + rng = np.random.default_rng(seed) + + # Create curve data with typical saturation curve shape + x_values = np.linspace(0, 1, 100) + channels = ["channel_1", "channel_2"] + countries = ["A", "B"] + + curve_data = [] + for _ in range(4): # chains + for _ in range(100): # draws + for _ in channels: + for _ in countries: + # Simple saturation curve: y = x / (1 + x) + y_values = x_values / (1 + x_values) + rng.normal( + 0, 0.01, size=x_values.shape + ) + curve_data.append(y_values) + + curve_array = np.array(curve_data).reshape( + 4, 100, len(channels), len(countries), len(x_values) + ) + + return xr.DataArray( + curve_array, + dims=("chain", "draw", "channel", "country", "x"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "channel": channels, + "country": countries, + "x": x_values, + }, + ) + + +class TestSaturationScatterplot: + def test_saturation_scatterplot_basic(self, mock_suite_with_constant_data): + """Test basic functionality of saturation_scatterplot.""" + fig, axes = mock_suite_with_constant_data.saturation_scatterplot() + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_scatterplot_original_scale(self, mock_suite_with_constant_data): + """Test saturation_scatterplot with original_scale=True.""" + fig, axes = mock_suite_with_constant_data.saturation_scatterplot( + original_scale=True + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_scatterplot_custom_kwargs(self, mock_suite_with_constant_data): + """Test saturation_scatterplot with custom kwargs.""" + fig, axes = mock_suite_with_constant_data.saturation_scatterplot( + width_per_col=8.0, height_per_row=5.0 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_scatterplot_no_constant_data(self, mock_suite): + """Test that saturation_scatterplot raises error without constant_data.""" + with pytest.raises(ValueError, match=r"No 'constant_data' found"): + mock_suite.saturation_scatterplot() + + def test_saturation_scatterplot_no_original_scale_contribution( + self, mock_suite_with_constant_data + ): + """Test that saturation_scatterplot raises error when original_scale=True but no original scale data.""" + # Remove the original scale contribution from the mock data + idata_copy = mock_suite_with_constant_data.idata.copy() + idata_copy.posterior = idata_copy.posterior.drop_vars( + "channel_contribution_original_scale" + ) + suite_without_original_scale = LegacyMMMPlotSuite(idata=idata_copy) + + with pytest.raises( + ValueError, 
match=r"No posterior.channel_contribution_original_scale" + ): + suite_without_original_scale.saturation_scatterplot(original_scale=True) + + +class TestSaturationScatterplotDims: + def test_saturation_scatterplot_with_dim(self, mock_suite_with_constant_data): + """Test saturation_scatterplot with a single value in dims.""" + fig, axes = mock_suite_with_constant_data.saturation_scatterplot( + dims={"country": "A"} + ) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + # Should create one column (n_channels, 1) + assert axes.shape[1] == 1 + for row in range(axes.shape[0]): + assert "country=A" in axes[row, 0].get_title() + + def test_saturation_scatterplot_with_dims_list(self, mock_suite_with_constant_data): + """Test saturation_scatterplot with a list in dims (should create subplots for each value).""" + fig, axes = mock_suite_with_constant_data.saturation_scatterplot( + dims={"country": ["A", "B"]} + ) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + # Should create two columns (n_channels, 2) + assert axes.shape[1] == 2 + for col, country in enumerate(["A", "B"]): + for row in range(axes.shape[0]): + assert f"country={country}" in axes[row, col].get_title() + + def test_saturation_scatterplot_with_multiple_dims_lists( + self, mock_suite_with_constant_data + ): + """Test saturation_scatterplot with multiple lists in dims (should create subplots for each combination).""" + # Add a fake 'region' dim to the mock constant_data for this test if not present + idata = mock_suite_with_constant_data.idata + if "region" not in idata.constant_data.channel_data.dims: + # Expand channel_data and posterior to add region + new_regions = ["X", "Y"] + channel_data = idata.constant_data.channel_data.expand_dims( + region=new_regions + ) + idata.constant_data["channel_data"] = channel_data + for var in ["channel_contribution", "channel_contribution_original_scale"]: + if var in idata.posterior: + idata.posterior[var] = idata.posterior[var].expand_dims( + region=new_regions + ) + fig, axes = mock_suite_with_constant_data.saturation_scatterplot( + dims={"country": ["A", "B"], "region": ["X", "Y"]} + ) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + # Should create 4 columns (n_channels, 4) + assert axes.shape[1] == 4 + combos = [("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")] + for col, (country, region) in enumerate(combos): + for row in range(axes.shape[0]): + title = axes[row, col].get_title() + assert f"country={country}" in title + assert f"region={region}" in title + + +class TestSaturationCurves: + def test_saturation_curves_basic( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test basic functionality of saturation_curves.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=5 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_original_scale( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with original_scale=True.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, original_scale=True, n_samples=3 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_with_hdi( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test 
saturation_curves with HDI intervals.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=3, hdi_probs=[0.5, 0.94] + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_single_hdi( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with single HDI probability.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=3, hdi_probs=0.85 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_custom_colors( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with custom colors.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=3, colors=["red", "blue"] + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_subplot_kwargs( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with custom subplot_kwargs.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, + n_samples=3, + subplot_kwargs={"figsize": (12, 8)}, + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + # Check that figsize was applied + assert fig.get_size_inches()[0] == 12 + assert fig.get_size_inches()[1] == 8 + + def test_saturation_curves_rc_params( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with rc_params.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=3, rc_params={"font.size": 14} + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_no_samples( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with n_samples=0.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=0, hdi_probs=0.85 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_no_constant_data( + self, mock_suite, mock_saturation_curve + ): + """Test that saturation_curves raises error without constant_data.""" + with pytest.raises(ValueError, match=r"No 'constant_data' found"): + mock_suite.saturation_curves(curve=mock_saturation_curve) + + def test_saturation_curves_no_original_scale_contribution( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test that saturation_curves raises error when original_scale=True but no original scale data.""" + # Remove the original scale contribution from the mock data + idata_copy = mock_suite_with_constant_data.idata.copy() + idata_copy.posterior = idata_copy.posterior.drop_vars( + "channel_contribution_original_scale" + ) + suite_without_original_scale = LegacyMMMPlotSuite(idata=idata_copy) + + with pytest.raises( + ValueError, match=r"No posterior.channel_contribution_original_scale" + ): + suite_without_original_scale.saturation_curves( + curve=mock_saturation_curve, 
original_scale=True + ) + + +class TestSaturationCurvesDims: + def test_saturation_curves_with_dim( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with a single value in dims.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=3, dims={"country": "A"} + ) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + + for row in range(axes.shape[0]): + assert "country=A" in axes[row, 0].get_title() + + def test_saturation_curves_with_dims_list( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with a list in dims (should create subplots for each value).""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=3, dims={"country": ["A", "B"]} + ) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + + def test_saturation_curves_with_multiple_dims_lists( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with multiple lists in dims (should create subplots for each combination).""" + # Add a fake 'region' dim to the mock constant_data for this test if not present + idata = mock_suite_with_constant_data.idata + if "region" not in idata.constant_data.channel_data.dims: + # Expand channel_data and posterior to add region + new_regions = ["X", "Y"] + channel_data = idata.constant_data.channel_data.expand_dims( + region=new_regions + ) + idata.constant_data["channel_data"] = channel_data + for var in ["channel_contribution", "channel_contribution_original_scale"]: + if var in idata.posterior: + idata.posterior[var] = idata.posterior[var].expand_dims( + region=new_regions + ) + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, + n_samples=3, + dims={"country": ["A", "B"], "region": ["X", "Y"]}, + ) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + combos = [("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")] + + for col, (country, region) in enumerate(combos): + for row in range(axes.shape[0]): + title = axes[row, col].get_title() + assert f"country={country}" in title + assert f"region={region}" in title + + +def test_saturation_curves_scatter_deprecation_warning(mock_suite_with_constant_data): + """Test that saturation_curves_scatter shows deprecation warning.""" + with pytest.warns( + DeprecationWarning, match=r"saturation_curves_scatter is deprecated" + ): + fig, axes = mock_suite_with_constant_data.saturation_curves_scatter() + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + +@pytest.fixture(scope="module") +def mock_idata_with_constant_data_single_dim() -> az.InferenceData: + """Mock InferenceData where channel_data has only ('date','channel') dims.""" + seed = sum(map(ord, "Saturation single-dim tests")) + rng = np.random.default_rng(seed) + normal = rng.normal + + dates = pd.date_range("2025-01-01", periods=12, freq="W-MON") + channels = ["channel_1", "channel_2", "channel_3"] + + posterior = xr.Dataset( + { + "channel_contribution": xr.DataArray( + normal(size=(2, 10, 12, 3)), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(2), + "draw": np.arange(10), + "date": dates, + "channel": channels, + }, + ), + "channel_contribution_original_scale": xr.DataArray( + normal(size=(2, 10, 12, 3)) * 100.0, + dims=("chain", "draw", "date", "channel"), + coords={ + 
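+                    # Small sizes (2 chains x 10 draws) keep this fixture fast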
"chain": np.arange(2), + "draw": np.arange(10), + "date": dates, + "channel": channels, + }, + ), + } + ) + + constant_data = xr.Dataset( + { + "channel_data": xr.DataArray( + rng.uniform(0, 10, size=(12, 3)), + dims=("date", "channel"), + coords={"date": dates, "channel": channels}, + ), + "channel_scale": xr.DataArray( + [100.0, 150.0, 200.0], dims=("channel",), coords={"channel": channels} + ), + "target_scale": xr.DataArray( + [1000.0], dims="target", coords={"target": ["y"]} + ), + } + ) + + return az.InferenceData(posterior=posterior, constant_data=constant_data) + + +@pytest.fixture(scope="module") +def mock_suite_with_constant_data_single_dim(mock_idata_with_constant_data_single_dim): + return LegacyMMMPlotSuite(idata=mock_idata_with_constant_data_single_dim) + + +@pytest.fixture(scope="module") +def mock_saturation_curve_single_dim() -> xr.DataArray: + """Saturation curve with dims ('chain','draw','channel','x').""" + seed = sum(map(ord, "Saturation curve single-dim")) + rng = np.random.default_rng(seed) + x_values = np.linspace(0, 1, 50) + channels = ["channel_1", "channel_2", "channel_3"] + + # shape: (chains=2, draws=10, channel=3, x=50) + curve_array = np.empty((2, 10, len(channels), len(x_values))) + for ci in range(2): + for di in range(10): + for c in range(len(channels)): + curve_array[ci, di, c, :] = x_values / (1 + x_values) + rng.normal( + 0, 0.02, size=x_values.shape + ) + + return xr.DataArray( + curve_array, + dims=("chain", "draw", "channel", "x"), + coords={ + "chain": np.arange(2), + "draw": np.arange(10), + "channel": channels, + "x": x_values, + }, + name="saturation_curve", + ) + + +def test_saturation_curves_single_dim_axes_shape( + mock_suite_with_constant_data_single_dim, mock_saturation_curve_single_dim +): + """When there are no extra dims, columns should default to 1 (no ncols=0).""" + fig, axes = mock_suite_with_constant_data_single_dim.saturation_curves( + curve=mock_saturation_curve_single_dim, n_samples=3 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + # Expect (n_channels, 1) + assert axes.shape[1] == 1 + assert axes.shape[0] == mock_saturation_curve_single_dim.sizes["channel"] + + +def test_saturation_curves_multi_dim_axes_shape( + mock_suite_with_constant_data, mock_saturation_curve +): + """With an extra dim (e.g., 'country'), expect (n_channels, n_countries).""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=2 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + + +def test_sensitivity_analysis_basic(mock_suite_with_sensitivity): + fig, axes = mock_suite_with_sensitivity.sensitivity_analysis() + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert axes.ndim == 2 + expected_panels = len( + mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] + ) # type: ignore + assert axes.size >= expected_panels + assert all(isinstance(ax, Axes) for ax in axes.flat[:expected_panels]) + + +def test_sensitivity_analysis_with_aggregation(mock_suite_with_sensitivity): + ax = mock_suite_with_sensitivity.sensitivity_analysis( + aggregation={"sum": ("region",)} + ) + assert isinstance(ax, Axes) + + +def test_marginal_curve(mock_suite_with_sensitivity): + fig, axes = mock_suite_with_sensitivity.marginal_curve() + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert axes.ndim == 2 + regions = mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] # type: ignore + assert 
axes.size >= len(regions)
+    assert all(isinstance(ax, Axes) for ax in axes.flat[: len(regions)])
+
+
+def test_uplift_curve(mock_suite_with_sensitivity):
+    fig, axes = mock_suite_with_sensitivity.uplift_curve()
+
+    assert isinstance(fig, Figure)
+    assert isinstance(axes, np.ndarray)
+    assert axes.ndim == 2
+    regions = mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"]  # type: ignore
+    assert axes.size >= len(regions)
+    assert all(isinstance(ax, Axes) for ax in axes.flat[: len(regions)])
+
+
+def test_sensitivity_analysis_multi_panel(mock_suite_with_sensitivity):
+    # The fixture provides an extra 'region' dimension, so multiple panels should be produced
+    fig, axes = mock_suite_with_sensitivity.sensitivity_analysis(
+        subplot_kwargs={"ncols": 2}
+    )
+
+    assert isinstance(fig, Figure)
+    assert isinstance(axes, np.ndarray)
+    assert axes.ndim == 2
+    # One panel per region, so at least two panels here
+    expected_panels = len(
+        mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"]
    )  # type: ignore
+    assert axes.size >= expected_panels
+    assert all(isinstance(ax, Axes) for ax in axes.flat[:expected_panels])
+
+
+def test_sensitivity_analysis_error_on_missing_results(mock_idata):
+    suite = LegacyMMMPlotSuite(idata=mock_idata)
+    # Each call must raise on its own; inside a single `with` block the
+    # second call would never run after the first one raises.
+    with pytest.raises(ValueError, match=r"No sensitivity analysis results found"):
+        suite.sensitivity_analysis()
+    with pytest.raises(ValueError, match=r"No sensitivity analysis results found"):
+        suite.plot_sensitivity_analysis()
+
+
+def test_budget_allocation_with_dims(mock_suite_with_constant_data):
+    # Use dims to filter to a single country
+    samples = mock_suite_with_constant_data.idata.posterior
+    # Add a fake 'allocation' variable for testing
+    samples = samples.copy()
+    samples["allocation"] = (
+        samples["channel_contribution"].dims,
+        np.abs(samples["channel_contribution"].values),
+    )
+    plot_suite = mock_suite_with_constant_data
+    fig, _ax = plot_suite.budget_allocation(
+        samples=samples,
+        dims={"country": "A"},
+    )
+    assert isinstance(fig, Figure)
+
+
+def test_budget_allocation_with_dims_list(mock_suite_with_constant_data):
+    """Test that passing a list to dims creates a subplot for each value."""
+    samples = mock_suite_with_constant_data.idata.posterior.copy()
+    # Add a fake 'allocation' variable for testing
+    samples["allocation"] = (
+        samples["channel_contribution"].dims,
+        np.abs(samples["channel_contribution"].values),
+    )
+    plot_suite = mock_suite_with_constant_data
+    fig, ax = plot_suite.budget_allocation(
+        samples=samples,
+        dims={"country": ["A", "B"]},
+    )
+    assert isinstance(fig, Figure)
+    assert isinstance(ax, np.ndarray)
+
+
+def test__validate_dims_valid():
+    """Test _validate_dims with valid dims and values."""
+    suite = LegacyMMMPlotSuite(idata=None)
+
+    # Patch suite.idata.posterior.coords to simulate valid dims
+    class DummyCoord:
+        def __init__(self, values):
+            self.values = values
+
+    class DummyCoords:
+        def __init__(self):
+            self._coords = {
+                "country": DummyCoord(["A", "B"]),
+                "region": DummyCoord(["X", "Y"]),
+            }
+
+        def __getitem__(self, key):
+            return self._coords[key]
+
+    class DummyPosterior:
+        coords = DummyCoords()
+
+    suite.idata = type("idata", (), {"posterior": DummyPosterior()})()
+    # Should not raise
+    suite._validate_dims({"country": "A", "region": "X"}, ["country", "region"])
+    suite._validate_dims({"country": ["A", "B"]}, ["country", "region"])
+
+
+def test__validate_dims_invalid_dim():
+    """Test _validate_dims raises for invalid dim name."""
+    suite = LegacyMMMPlotSuite(idata=None)
+
+    class DummyCoord:
+        def __init__(self, values):
+            self.values = 
values + + class DummyCoords: + def __init__(self): + self.country = DummyCoord(["A", "B"]) + + def __getitem__(self, key): + return getattr(self, key) + + class DummyPosterior: + coords = DummyCoords() + + suite.idata = type("idata", (), {"posterior": DummyPosterior()})() + with pytest.raises(ValueError, match=r"Dimension 'region' not found"): + suite._validate_dims({"region": "X"}, ["country"]) + + +def test__validate_dims_invalid_value(): + """Test _validate_dims raises for invalid value.""" + suite = LegacyMMMPlotSuite(idata=None) + + class DummyCoord: + def __init__(self, values): + self.values = values + + class DummyCoords: + def __init__(self): + self.country = DummyCoord(["A", "B"]) + + def __getitem__(self, key): + return getattr(self, key) + + class DummyPosterior: + coords = DummyCoords() + + suite.idata = type("idata", (), {"posterior": DummyPosterior()})() + with pytest.raises(ValueError, match=r"Value 'C' not found in dimension 'country'"): + suite._validate_dims({"country": "C"}, ["country"]) + + +def test__dim_list_handler_none(): + """Test _dim_list_handler with None input.""" + suite = LegacyMMMPlotSuite(idata=None) + keys, combos = suite._dim_list_handler(None) + assert keys == [] + assert combos == [()] + + +def test__dim_list_handler_single(): + """Test _dim_list_handler with a single list-valued dim.""" + suite = LegacyMMMPlotSuite(idata=None) + keys, combos = suite._dim_list_handler({"country": ["A", "B"]}) + assert keys == ["country"] + assert set(combos) == {("A",), ("B",)} + + +def test__dim_list_handler_multiple(): + """Test _dim_list_handler with multiple list-valued dims.""" + suite = LegacyMMMPlotSuite(idata=None) + keys, combos = suite._dim_list_handler( + {"country": ["A", "B"], "region": ["X", "Y"]} + ) + assert set(keys) == {"country", "region"} + assert set(combos) == {("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")} + + +def test__dim_list_handler_mixed(): + """Test _dim_list_handler with mixed single and list values.""" + suite = LegacyMMMPlotSuite(idata=None) + keys, combos = suite._dim_list_handler({"country": ["A", "B"], "region": "X"}) + assert keys == ["country"] + assert set(combos) == {("A",), ("B",)} diff --git a/tests/mmm/test_legacy_plot_regression.py b/tests/mmm/test_legacy_plot_regression.py new file mode 100644 index 000000000..7c42baace --- /dev/null +++ b/tests/mmm/test_legacy_plot_regression.py @@ -0,0 +1,56 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
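+#
+# See test_legacy_plot.py for the full behavioral tests of the legacy suite.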
+"""Regression tests for legacy plot suite.""" + +import numpy as np +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + + +def test_legacy_suite_all_methods_exist(): + """Test all legacy suite methods still exist after rename.""" + expected_methods = [ + "posterior_predictive", + "contributions_over_time", + "saturation_scatterplot", + "saturation_curves", + "saturation_curves_scatter", # Deprecated but still in legacy + "budget_allocation", + "allocated_contribution_by_channel_over_time", + "sensitivity_analysis", + "uplift_curve", + "marginal_curve", + ] + + for method_name in expected_methods: + assert hasattr(LegacyMMMPlotSuite, method_name), ( + f"LegacyMMMPlotSuite missing method: {method_name}" + ) + + +def test_legacy_suite_returns_tuple(mock_idata_for_legacy): + """Test legacy suite returns tuple, not PlotCollection.""" + suite = LegacyMMMPlotSuite(idata=mock_idata_for_legacy) + result = suite.posterior_predictive() + + assert isinstance(result, tuple) + assert len(result) == 2 + assert isinstance(result[0], Figure) + # result[1] can be Axes or ndarray of Axes + if isinstance(result[1], np.ndarray): + assert all(isinstance(ax, Axes) for ax in result[1].flat) + else: + assert isinstance(result[1], Axes) diff --git a/tests/mmm/test_plot.py b/tests/mmm/test_plot.py index ea41b44ce..b0c8c885b 100644 --- a/tests/mmm/test_plot.py +++ b/tests/mmm/test_plot.py @@ -11,187 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import warnings +# +"""Tests for new MMMPlotSuite with multi-backend support (arviz_plots-based). + +This file tests the new arviz_plots-based MMMPlotSuite that supports +matplotlib, plotly, and bokeh backends. + +For tests of the legacy matplotlib-only suite, see test_legacy_plot.py. + +Test Organization: +- Parametrized backend tests: Each plotting method tested with all backends +- Backend behavior tests: Config override, invalid backends +- Data parameter tests: Explicit data parameter functionality +- Integration tests: Multiple plots, backend switching + +.. versionadded:: 0.18.0 + New test suite for arviz_plots-based MMMPlotSuite. 
+""" import arviz as az import numpy as np import pandas as pd import pytest import xarray as xr -from matplotlib.axes import Axes -from matplotlib.figure import Figure - -from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation - -with warnings.catch_warnings(): - warnings.simplefilter("ignore", FutureWarning) - from pymc_marketing.mmm.multidimensional import MMM from pymc_marketing.mmm.plot import MMMPlotSuite - -@pytest.fixture -def mmm(): - return MMM( - date_column="date", - channel_columns=["C1", "C2"], - dims=("country",), - target_column="y", - adstock=GeometricAdstock(l_max=10), - saturation=LogisticSaturation(), - ) - - -@pytest.fixture -def df() -> pd.DataFrame: - dates = pd.date_range("2025-01-01", periods=3, freq="W-MON").rename("date") - df = pd.DataFrame( - { - ("A", "C1"): [1, 2, 3], - ("B", "C1"): [4, 5, 6], - ("A", "C2"): [7, 8, 9], - ("B", "C2"): [10, 11, 12], - }, - index=dates, - ) - df.columns.names = ["country", "channel"] - - y = pd.DataFrame( - { - ("A", "y"): [1, 2, 3], - ("B", "y"): [4, 5, 6], - }, - index=dates, - ) - y.columns.names = ["country", "channel"] - - return pd.concat( - [ - df.stack("country", future_stack=True), - y.stack("country", future_stack=True), - ], - axis=1, - ).reset_index() - - -@pytest.fixture -def fit_mmm_with_channel_original_scale(df, mmm, mock_pymc_sample): - X = df.drop(columns=["y"]) - y = df["y"] - - mmm.build_model(X, y) - mmm.add_original_scale_contribution_variable( - var=[ - "channel_contribution", - ] - ) - - mmm.fit(X, y) - - return mmm - - -@pytest.fixture -def fit_mmm_without_channel_original_scale(df, mmm, mock_pymc_sample): - X = df.drop(columns=["y"]) - y = df["y"] - - mmm.fit(X, y) - - return mmm - - -def test_saturation_curves_scatter_original_scale(fit_mmm_with_channel_original_scale): - fig, ax = fit_mmm_with_channel_original_scale.plot.saturation_curves_scatter( - original_scale=True - ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - assert all(isinstance(a, Axes) for a in ax.flat) - - -def test_saturation_curves_scatter_original_scale_fails_if_no_deterministic( - fit_mmm_without_channel_original_scale, -): - with pytest.raises(ValueError): - fit_mmm_without_channel_original_scale.plot.saturation_curves_scatter( - original_scale=True - ) - - -def test_contributions_over_time(fit_mmm_with_channel_original_scale): - fig, ax = fit_mmm_with_channel_original_scale.plot.contributions_over_time( - var=["channel_contribution"], - hdi_prob=0.95, - ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - assert all(isinstance(a, Axes) for a in ax.flat) - - -def test_contributions_over_time_with_dim(mock_suite: MMMPlotSuite): - # Test with explicit dim argument - fig, ax = mock_suite.contributions_over_time( - var=["intercept", "linear_trend"], - dims={"country": "A"}, - ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - assert all(isinstance(a, Axes) for a in ax.flat) - # Optionally, check axes shape if known - if hasattr(ax, "shape"): - # When filtering to a single country, shape[-1] should be 1 - assert ax.shape[-1] == 1 - - -def test_contributions_over_time_with_dims_list(mock_suite: MMMPlotSuite): - """Test that passing a list to dims creates a subplot for each value.""" - fig, ax = mock_suite.contributions_over_time( - var=["intercept"], - dims={"country": ["A", "B"]}, - ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - # Should create one subplot per value in the list (here: 2 countries) - assert ax.shape[0] == 2 - # 
Optionally, check subplot titles contain the correct country - for i, country in enumerate(["A", "B"]): - assert country in ax[i, 0].get_title() - - -def test_contributions_over_time_with_multiple_dims_lists(mock_suite: MMMPlotSuite): - """Test that passing multiple lists to dims creates a subplot for each combination.""" - # Add a fake 'region' dim to the mock posterior for this test if not present - idata = mock_suite.idata - if "region" not in idata.posterior["intercept"].dims: - idata.posterior["intercept"] = idata.posterior["intercept"].expand_dims( - region=["X", "Y"] - ) - fig, ax = mock_suite.contributions_over_time( - var=["intercept"], - dims={"country": ["A", "B"], "region": ["X", "Y"]}, - ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - # Should create one subplot per combination (2 countries x 2 regions = 4) - assert ax.shape[0] == 4 - combos = [("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")] - for i, (country, region) in enumerate(combos): - title = ax[i, 0].get_title() - assert country in title - assert region in title - - -def test_posterior_predictive(fit_mmm_with_channel_original_scale, df): - fit_mmm_with_channel_original_scale.sample_posterior_predictive( - df.drop(columns=["y"]) - ) - fig, ax = fit_mmm_with_channel_original_scale.plot.posterior_predictive( - hdi_prob=0.95, - ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - assert all(isinstance(a, Axes) for a in ax.flat) +# ============================================================================= +# Fixtures +# ============================================================================= @pytest.fixture(scope="module") @@ -215,12 +63,13 @@ def mock_idata() -> az.InferenceData: }, ), "linear_trend": xr.DataArray( - normal(size=(4, 100, 52)), - dims=("chain", "draw", "date"), + normal(size=(4, 100, 52, 3)), + dims=("chain", "draw", "date", "country"), coords={ "chain": np.arange(4), "draw": np.arange(100), "date": dates, + "country": ["A", "B", "C"], }, ), } @@ -295,19 +144,6 @@ def mock_suite_with_sensitivity(mock_idata_with_sensitivity): return MMMPlotSuite(idata=mock_idata_with_sensitivity) -def test_contributions_over_time_expand_dims(mock_suite: MMMPlotSuite): - fig, ax = mock_suite.contributions_over_time( - var=[ - "intercept", - "linear_trend", - ] - ) - - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - assert all(isinstance(a, Axes) for a in ax.flat) - - @pytest.fixture(scope="module") def mock_idata_with_constant_data() -> az.InferenceData: """Create mock InferenceData with constant_data and posterior for saturation tests.""" @@ -420,633 +256,1212 @@ def mock_saturation_curve() -> xr.DataArray: ) -class TestSaturationScatterplot: - def test_saturation_scatterplot_basic(self, mock_suite_with_constant_data): - """Test basic functionality of saturation_scatterplot.""" - fig, axes = mock_suite_with_constant_data.saturation_scatterplot() +# ============================================================================= +# Basic Functionality Tests +# ============================================================================= - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - def test_saturation_scatterplot_original_scale(self, mock_suite_with_constant_data): - """Test saturation_scatterplot with original_scale=True.""" - fig, axes = mock_suite_with_constant_data.saturation_scatterplot( - original_scale=True +def test_contributions_over_time_expand_dims(mock_suite: 
MMMPlotSuite): + from arviz_plots import PlotCollection + + pc = mock_suite.contributions_over_time( + var=[ + "intercept", + "linear_trend", + ] + ) + + assert isinstance(pc, PlotCollection) + assert hasattr(pc, "backend") + assert hasattr(pc, "show") + + +# ============================================================================= +# Comprehensive Backend Tests (Milestone 3) +# ============================================================================= +# These tests verify that all plotting methods work correctly across all +# supported backends (matplotlib, plotly, bokeh). +# ============================================================================= + + +class TestPosteriorPredictiveBackends: + """Test posterior_predictive method across all backends.""" + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_posterior_predictive_all_backends(self, mock_suite, backend): + """Test posterior_predictive works with all backends.""" + from arviz_plots import PlotCollection + + # Create idata with posterior_predictive + idata = mock_suite.idata.copy() + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + idata.posterior_predictive = xr.Dataset( + { + "y": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + }, + ) + } ) + suite = MMMPlotSuite(idata=idata) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) + pc = suite.posterior_predictive(backend=backend) - def test_saturation_scatterplot_custom_kwargs(self, mock_suite_with_constant_data): - """Test saturation_scatterplot with custom kwargs.""" - fig, axes = mock_suite_with_constant_data.saturation_scatterplot( - width_per_col=8.0, height_per_row=5.0 + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - def test_saturation_scatterplot_no_constant_data(self, mock_suite): - """Test that saturation_scatterplot raises error without constant_data.""" - with pytest.raises(ValueError, match=r"No 'constant_data' found"): - mock_suite.saturation_scatterplot() +class TestContributionsOverTimeBackends: + """Test contributions_over_time method across all backends.""" - def test_saturation_scatterplot_no_original_scale_contribution( - self, mock_suite_with_constant_data + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_contributions_over_time_all_backends(self, mock_suite, backend): + """Test contributions_over_time works with all backends.""" + from arviz_plots import PlotCollection + + pc = mock_suite.contributions_over_time(var=["intercept"], backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" + ) + + +class TestSaturationPlotBackends: + """Test saturation plot methods across all backends.""" + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_saturation_scatterplot_all_backends( + self, mock_suite_with_constant_data, backend ): - """Test that saturation_scatterplot raises error when original_scale=True but no original scale data.""" - # Remove the original scale contribution from the mock data - idata_copy = mock_suite_with_constant_data.idata.copy() - 
idata_copy.posterior = idata_copy.posterior.drop_vars( - "channel_contribution_original_scale" + """Test saturation_scatterplot works with all backends.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_scatterplot(backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" ) - suite_without_original_scale = MMMPlotSuite(idata=idata_copy) - with pytest.raises( - ValueError, match=r"No posterior.channel_contribution_original_scale" - ): - suite_without_original_scale.saturation_scatterplot(original_scale=True) + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_saturation_curves_all_backends( + self, mock_suite_with_constant_data, mock_saturation_curve, backend + ): + """Test saturation_curves works with all backends.""" + from arviz_plots import PlotCollection + pc = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, backend=backend, n_samples=3 + ) -class TestSaturationScatterplotDims: - def test_saturation_scatterplot_with_dim(self, mock_suite_with_constant_data): - """Test saturation_scatterplot with a single value in dims.""" - fig, axes = mock_suite_with_constant_data.saturation_scatterplot( - dims={"country": "A"} + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - # Should create one column (n_channels, 1) - assert axes.shape[1] == 1 - for row in range(axes.shape[0]): - assert "country=A" in axes[row, 0].get_title() - def test_saturation_scatterplot_with_dims_list(self, mock_suite_with_constant_data): - """Test saturation_scatterplot with a list in dims (should create subplots for each value).""" - fig, axes = mock_suite_with_constant_data.saturation_scatterplot( - dims={"country": ["A", "B"]} + +class TestBudgetAllocationBackends: + """Test budget allocation methods across all backends.""" + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_budget_allocation_roas_all_backends(self, mock_suite, backend): + """Test budget_allocation_roas works with all backends.""" + from arviz_plots import PlotCollection + + # Create proper allocation samples with required variables and dimensions + rng = np.random.default_rng(42) + channels = ["TV", "Radio", "Digital"] + dates = pd.date_range("2025-01-01", periods=52, freq="W") + samples = xr.Dataset( + { + "channel_contribution_original_scale": xr.DataArray( + rng.normal(loc=1000, scale=100, size=(100, 52, 3)), + dims=("sample", "date", "channel"), + coords={ + "sample": np.arange(100), + "date": dates, + "channel": channels, + }, + ), + "allocation": xr.DataArray( + rng.uniform(100, 1000, size=(3,)), + dims=("channel",), + coords={"channel": channels}, + ), + } ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - # Should create two columns (n_channels, 2) - assert axes.shape[1] == 2 - for col, country in enumerate(["A", "B"]): - for row in range(axes.shape[0]): - assert f"country={country}" in axes[row, col].get_title() - - def test_saturation_scatterplot_with_multiple_dims_lists( - self, mock_suite_with_constant_data - ): - """Test saturation_scatterplot with multiple lists in dims (should create subplots for each combination).""" - # Add a fake 'region' dim to the mock constant_data for this test if not present - idata = mock_suite_with_constant_data.idata - if "region" not in 
idata.constant_data.channel_data.dims: - # Expand channel_data and posterior to add region - new_regions = ["X", "Y"] - channel_data = idata.constant_data.channel_data.expand_dims( - region=new_regions - ) - idata.constant_data["channel_data"] = channel_data - for var in ["channel_contribution", "channel_contribution_original_scale"]: - if var in idata.posterior: - idata.posterior[var] = idata.posterior[var].expand_dims( - region=new_regions - ) - fig, axes = mock_suite_with_constant_data.saturation_scatterplot( - dims={"country": ["A", "B"], "region": ["X", "Y"]} + + pc = mock_suite.budget_allocation_roas(samples=samples, backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - # Should create 4 columns (n_channels, 4) - assert axes.shape[1] == 4 - combos = [("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")] - for col, (country, region) in enumerate(combos): - for row in range(axes.shape[0]): - title = axes[row, col].get_title() - assert f"country={country}" in title - assert f"region={region}" in title - - -class TestSaturationCurves: - def test_saturation_curves_basic( - self, mock_suite_with_constant_data, mock_saturation_curve + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_allocated_contribution_by_channel_over_time_all_backends( + self, mock_suite, backend ): - """Test basic functionality of saturation_curves.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=5 + """Test allocated_contribution_by_channel_over_time works with all backends.""" + from arviz_plots import PlotCollection + + # Create proper samples with 'sample', 'date', and 'channel' dimensions + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + channels = ["TV", "Radio", "Digital"] + samples = xr.Dataset( + { + "channel_contribution": xr.DataArray( + rng.normal(size=(100, 52, 3)), + dims=("sample", "date", "channel"), + coords={ + "sample": np.arange(100), + "date": dates, + "channel": channels, + }, + ) + } ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) + pc = mock_suite.allocated_contribution_by_channel_over_time( + samples=samples, backend=backend + ) - def test_saturation_curves_original_scale( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with original_scale=True.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, original_scale=True, n_samples=3 + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - def test_saturation_curves_with_hdi( - self, mock_suite_with_constant_data, mock_saturation_curve +class TestSensitivityAnalysisBackends: + """Test sensitivity analysis methods across all backends.""" + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_sensitivity_analysis_all_backends( + self, mock_suite_with_sensitivity, backend ): - """Test saturation_curves with HDI intervals.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=3, hdi_probs=[0.5, 0.94] + """Test sensitivity_analysis works with all 
backends.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_sensitivity.sensitivity_analysis(backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_uplift_curve_all_backends(self, mock_suite_with_sensitivity, backend): + """Test uplift_curve works with all backends.""" + from arviz_plots import PlotCollection - def test_saturation_curves_single_hdi( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with single HDI probability.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=3, hdi_probs=0.85 + pc = mock_suite_with_sensitivity.uplift_curve(backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_marginal_curve_all_backends(self, mock_suite_with_sensitivity, backend): + """Test marginal_curve works with all backends.""" + from arviz_plots import PlotCollection - def test_saturation_curves_custom_colors( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with custom colors.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=3, colors=["red", "blue"] + pc = mock_suite_with_sensitivity.marginal_curve(backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - def test_saturation_curves_subplot_kwargs( - self, mock_suite_with_constant_data, mock_saturation_curve +class TestBackendBehavior: + """Test backend configuration and override behavior.""" + + def test_backend_overrides_global_config(self, mock_suite): + """Test that method backend parameter overrides global config.""" + from arviz_plots import PlotCollection + + from pymc_marketing.mmm import mmm_plot_config + + original = mmm_plot_config.get("plot.backend", "matplotlib") + + try: + # Set global to matplotlib + mmm_plot_config["plot.backend"] = "matplotlib" + + # Override with plotly + pc_plotly = mock_suite.contributions_over_time( + var=["intercept"], backend="plotly" + ) + assert isinstance(pc_plotly, PlotCollection) + + # Default should still be matplotlib + pc_default = mock_suite.contributions_over_time(var=["intercept"]) + assert isinstance(pc_default, PlotCollection) + + finally: + mmm_plot_config["plot.backend"] = original + + @pytest.mark.parametrize("config_backend", ["matplotlib", "plotly", "bokeh"]) + def test_backend_parameter_none_uses_config(self, mock_suite, config_backend): + """Test that backend=None uses global config.""" + from arviz_plots import PlotCollection + + from pymc_marketing.mmm import mmm_plot_config + + original = mmm_plot_config.get("plot.backend", "matplotlib") + + try: + mmm_plot_config["plot.backend"] = config_backend + + pc = mock_suite.contributions_over_time( + var=["intercept"], + backend=None, # Explicitly None + ) + + assert 
isinstance(pc, PlotCollection) + + finally: + mmm_plot_config["plot.backend"] = original + + def test_invalid_backend_raises_error(self, mock_suite): + """Test that invalid backend raises an appropriate error.""" + # Invalid backend should raise an error (arviz_plots behavior) + with pytest.raises((ModuleNotFoundError, ImportError, ValueError)): + _pc = mock_suite.contributions_over_time( + var=["intercept"], backend="invalid_backend" + ) + + +class TestDataParameters: + """Test explicit data parameter functionality.""" + + def test_contributions_over_time_with_explicit_data(self, mock_posterior_data): + """Test contributions_over_time accepts explicit data parameter.""" + from arviz_plots import PlotCollection + + # Create suite without idata + suite = MMMPlotSuite(idata=None) + + # Should work with explicit data parameter + pc = suite.contributions_over_time(var=["intercept"], data=mock_posterior_data) + + assert isinstance(pc, PlotCollection) + + def test_saturation_scatterplot_with_explicit_data( + self, mock_constant_data, mock_posterior_data ): - """Test saturation_curves with custom subplot_kwargs.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, - n_samples=3, - subplot_kwargs={"figsize": (12, 8)}, + """Test saturation_scatterplot accepts explicit data parameters.""" + from arviz_plots import PlotCollection + + suite = MMMPlotSuite(idata=None) + + # Create posterior data with channel_contribution matching constant_data channels + rng = np.random.default_rng(42) + n_channels = len(mock_constant_data.coords["channel"]) + posterior_data = xr.Dataset( + { + "channel_contribution": xr.DataArray( + rng.normal(size=(4, 100, 52, n_channels)), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": mock_constant_data.coords["date"], + "channel": mock_constant_data.coords["channel"], + }, + ) + } ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - # Check that figsize was applied - assert fig.get_size_inches()[0] == 12 - assert fig.get_size_inches()[1] == 8 + pc = suite.saturation_scatterplot( + constant_data=mock_constant_data, posterior_data=posterior_data + ) - def test_saturation_curves_rc_params( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with rc_params.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=3, rc_params={"font.size": 14} + assert isinstance(pc, PlotCollection) + + +class TestIntegration: + """Integration tests for multiple plots and backend switching.""" + + def test_multiple_plots_same_suite_instance(self, mock_suite_with_constant_data): + """Test that same suite instance can create multiple plots.""" + from arviz_plots import PlotCollection + + suite = mock_suite_with_constant_data + + # Create multiple different plots + pc1 = suite.contributions_over_time(var=["channel_contribution"]) + pc2 = suite.saturation_scatterplot() + + assert isinstance(pc1, PlotCollection) + assert isinstance(pc2, PlotCollection) + + # All should be independent PlotCollection objects + assert pc1 is not pc2 + + def test_backend_switching_same_method(self, mock_suite): + """Test that backends can be switched for same method.""" + from arviz_plots import PlotCollection + + suite = mock_suite + + # Create same plot with different backends + pc_mpl = suite.contributions_over_time(var=["intercept"], 
backend="matplotlib") + pc_plotly = suite.contributions_over_time(var=["intercept"], backend="plotly") + pc_bokeh = suite.contributions_over_time(var=["intercept"], backend="bokeh") + + assert isinstance(pc_mpl, PlotCollection) + assert isinstance(pc_plotly, PlotCollection) + assert isinstance(pc_bokeh, PlotCollection) + + +# ============================================================================= +# Validation Error Tests +# ============================================================================= + + +class TestValidationErrors: + """Test validation and error handling.""" + + def test_posterior_predictive_invalid_hdi_prob(self, mock_suite): + """Test that invalid hdi_prob raises ValueError.""" + # Create idata with posterior_predictive + idata = mock_suite.idata.copy() + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + idata.posterior_predictive = xr.Dataset( + { + "y": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + }, + ) + } ) + suite = MMMPlotSuite(idata=idata) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) + with pytest.raises(ValueError, match="HDI probability must be between 0 and 1"): + suite.posterior_predictive(hdi_prob=1.5) - def test_saturation_curves_no_samples( - self, mock_suite_with_constant_data, mock_saturation_curve + with pytest.raises(ValueError, match="HDI probability must be between 0 and 1"): + suite.posterior_predictive(hdi_prob=0.0) + + def test_contributions_over_time_invalid_hdi_prob(self, mock_suite): + """Test that invalid hdi_prob raises ValueError.""" + with pytest.raises(ValueError, match="HDI probability must be between 0 and 1"): + mock_suite.contributions_over_time(var=["intercept"], hdi_prob=2.0) + + def test_contributions_over_time_missing_variable(self, mock_suite): + """Test that missing variable raises ValueError.""" + with pytest.raises(ValueError, match="not found in data"): + mock_suite.contributions_over_time(var=["nonexistent_var"]) + + def test_posterior_predictive_no_data(self): + """Test that missing posterior_predictive data raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + with pytest.raises(ValueError, match="No posterior_predictive data found"): + suite.posterior_predictive() + + def test_contributions_over_time_no_posterior(self): + """Test that missing posterior data raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + with pytest.raises(ValueError, match="No posterior data found"): + suite.contributions_over_time(var=["intercept"]) + + def test_saturation_scatterplot_no_constant_data(self): + """Test that missing constant_data raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + with pytest.raises(ValueError, match="No constant data found"): + suite.saturation_scatterplot() + + def test_saturation_scatterplot_missing_channel_data(self, mock_posterior_data): + """Test that missing channel_data variable raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + # Create constant_data without channel_data + constant_data = xr.Dataset({"other_var": xr.DataArray([1, 2, 3])}) + + with pytest.raises(ValueError, match="'channel_data' variable not found"): + suite.saturation_scatterplot( + constant_data=constant_data, posterior_data=mock_posterior_data + ) + + def test_saturation_scatterplot_missing_channel_contribution( + self, mock_constant_data ): - """Test 
saturation_curves with n_samples=0.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=0, hdi_probs=0.85 + """Test that missing channel_contribution raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + # Create posterior without channel_contribution + posterior = xr.Dataset({"other_var": xr.DataArray([1, 2, 3])}) + + with pytest.raises(ValueError, match=r"No posterior\.channel_contribution"): + suite.saturation_scatterplot( + constant_data=mock_constant_data, posterior_data=posterior + ) + + def test_saturation_curves_missing_x_dimension(self, mock_suite_with_constant_data): + """Test that curve without 'x' dimension raises ValueError.""" + # Create curve without 'x' dimension + bad_curve = xr.DataArray( + np.random.rand(10, 2), + dims=("time", "channel"), + coords={"time": np.arange(10), "channel": ["A", "B"]}, ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) + with pytest.raises(ValueError, match="curve must have an 'x' dimension"): + mock_suite_with_constant_data.saturation_curves(curve=bad_curve) - def test_saturation_curves_no_constant_data( - self, mock_suite, mock_saturation_curve + def test_saturation_curves_missing_channel_dimension( + self, mock_suite_with_constant_data ): - """Test that saturation_curves raises error without constant_data.""" - with pytest.raises(ValueError, match=r"No 'constant_data' found"): - mock_suite.saturation_curves(curve=mock_saturation_curve) + """Test that curve without 'channel' dimension raises ValueError.""" + # Create curve without 'channel' dimension + bad_curve = xr.DataArray( + np.random.rand(10, 20), + dims=("time", "x"), + coords={"time": np.arange(10), "x": np.linspace(0, 1, 20)}, + ) - def test_saturation_curves_no_original_scale_contribution( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test that saturation_curves raises error when original_scale=True but no original scale data.""" - # Remove the original scale contribution from the mock data - idata_copy = mock_suite_with_constant_data.idata.copy() - idata_copy.posterior = idata_copy.posterior.drop_vars( - "channel_contribution_original_scale" + with pytest.raises(ValueError, match="curve must have a 'channel' dimension"): + mock_suite_with_constant_data.saturation_curves(curve=bad_curve) + + def test_budget_allocation_roas_missing_channel_dim(self, mock_suite): + """Test that samples without channel dimension raises ValueError.""" + # Create samples without channel dimension + samples = xr.Dataset({"some_var": xr.DataArray([1, 2, 3])}) + + with pytest.raises(ValueError, match="Expected 'channel' dimension"): + mock_suite.budget_allocation_roas(samples=samples) + + def test_budget_allocation_roas_missing_contribution(self, mock_suite): + """Test that samples without contribution variable raises ValueError.""" + # Create samples with channel but missing contribution + samples = xr.Dataset( + { + "other_var": xr.DataArray( + [1, 2, 3], dims=("channel",), coords={"channel": ["A", "B", "C"]} + ) + } + ) + + with pytest.raises( + ValueError, + match="Expected a variable containing 'channel_contribution_original_scale'", + ): + mock_suite.budget_allocation_roas(samples=samples) + + def test_budget_allocation_roas_missing_allocation(self, mock_suite): + """Test that samples without allocation raises ValueError.""" + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + channels = ["A", "B", 
"C"] + + # Create samples with contribution but missing allocation + samples = xr.Dataset( + { + "channel_contribution_original_scale": xr.DataArray( + rng.normal(size=(100, 52, 3)), + dims=("sample", "date", "channel"), + coords={ + "sample": np.arange(100), + "date": dates, + "channel": channels, + }, + ) + } + ) + + with pytest.raises(ValueError, match="Expected 'allocation' variable"): + mock_suite.budget_allocation_roas(samples=samples) + + def test_allocated_contribution_missing_channel(self, mock_suite): + """Test that samples without channel dimension raises ValueError.""" + samples = xr.Dataset({"some_var": xr.DataArray([1, 2, 3])}) + + with pytest.raises(ValueError, match="Expected 'channel' dimension"): + mock_suite.allocated_contribution_by_channel_over_time(samples=samples) + + def test_allocated_contribution_missing_date(self, mock_suite): + """Test that samples without date dimension raises ValueError.""" + samples = xr.Dataset( + { + "channel_contribution": xr.DataArray( + [[1, 2], [3, 4]], + dims=("sample", "channel"), + coords={"sample": [0, 1], "channel": ["A", "B"]}, + ) + } + ) + + with pytest.raises(ValueError, match="Expected 'date' dimension"): + mock_suite.allocated_contribution_by_channel_over_time(samples=samples) + + def test_allocated_contribution_missing_sample(self, mock_suite): + """Test that samples without sample dimension raises ValueError.""" + dates = pd.date_range("2025-01-01", periods=10, freq="W") + samples = xr.Dataset( + { + "channel_contribution": xr.DataArray( + [[1, 2], [3, 4]], + dims=("date", "channel"), + coords={"date": dates[:2], "channel": ["A", "B"]}, + ) + } + ) + + with pytest.raises(ValueError, match="Expected 'sample' dimension"): + mock_suite.allocated_contribution_by_channel_over_time(samples=samples) + + def test_allocated_contribution_missing_contribution_var(self, mock_suite): + """Test that samples without channel_contribution variable raises ValueError.""" + dates = pd.date_range("2025-01-01", periods=10, freq="W") + samples = xr.Dataset( + { + "other_var": xr.DataArray( + [[[1, 2]]], + dims=("sample", "date", "channel"), + coords={"sample": [0], "date": dates[:1], "channel": ["A", "B"]}, + ) + } ) - suite_without_original_scale = MMMPlotSuite(idata=idata_copy) with pytest.raises( - ValueError, match=r"No posterior.channel_contribution_original_scale" + ValueError, match="Expected a variable containing 'channel_contribution'" ): - suite_without_original_scale.saturation_curves( - curve=mock_saturation_curve, original_scale=True + mock_suite.allocated_contribution_by_channel_over_time(samples=samples) + + def test_sensitivity_analysis_invalid_dimensions(self, mock_suite): + """Test that data without required dimensions raises ValueError.""" + # Create data without required dimensions + bad_data = xr.DataArray( + np.random.rand(10, 20), dims=("time", "space"), name="x" + ) + + with pytest.raises(ValueError, match="Data must have dimensions"): + mock_suite._sensitivity_analysis_plot(data=bad_data) + + def test_sensitivity_analysis_no_data(self): + """Test that missing sensitivity_analysis group raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + with pytest.raises(ValueError, match="No sensitivity analysis results found"): + suite.sensitivity_analysis() + + def test_uplift_curve_missing_data(self): + """Test that missing uplift_curve raises ValueError.""" + # Create idata with sensitivity_analysis but without uplift_curve + rng = np.random.default_rng(42) + idata = az.InferenceData( + posterior=xr.Dataset( + {"intercept": 
xr.DataArray(rng.normal(size=(4, 100)))} + ), + sensitivity_analysis=xr.Dataset( + { + "x": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ), + ) + suite = MMMPlotSuite(idata=idata) + + with pytest.raises(ValueError, match="Expected 'uplift_curve'"): + suite.uplift_curve() + + def test_marginal_curve_missing_data(self): + """Test that missing marginal_effects raises ValueError.""" + # Create idata with sensitivity_analysis but without marginal_effects + rng = np.random.default_rng(42) + idata = az.InferenceData( + posterior=xr.Dataset( + {"intercept": xr.DataArray(rng.normal(size=(4, 100)))} + ), + sensitivity_analysis=xr.Dataset( + { + "x": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ), + ) + suite = MMMPlotSuite(idata=idata) + + with pytest.raises(ValueError, match="Expected 'marginal_effects'"): + suite.marginal_curve() + + def test_get_additional_dim_combinations_missing_variable(self, mock_suite): + """Test that missing variable in dataset raises ValueError.""" + with pytest.raises(ValueError, match="Variable 'nonexistent' not found"): + mock_suite._get_additional_dim_combinations( + data=mock_suite.idata.posterior, + variable="nonexistent", + ignored_dims={"chain", "draw"}, ) + def test_validate_dims_invalid_dimension(self, mock_suite): + """Test that invalid dimension raises ValueError.""" + with pytest.raises(ValueError, match="Dimension 'invalid_dim' not found"): + mock_suite._validate_dims( + dims={"invalid_dim": "A"}, all_dims=["chain", "draw", "country"] + ) -class TestSaturationCurvesDims: - def test_saturation_curves_with_dim( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with a single value in dims.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=3, dims={"country": "A"} + def test_validate_dims_invalid_value(self, mock_suite): + """Test that invalid dimension value raises ValueError.""" + with pytest.raises(ValueError, match="Value 'Z' not found in dimension"): + mock_suite._validate_dims( + dims={"country": "Z"}, all_dims=["chain", "draw", "country"] + ) + + def test_validate_dims_invalid_list_value(self, mock_suite): + """Test that invalid value in list raises ValueError.""" + with pytest.raises(ValueError, match="Value 'Z' not found in dimension"): + mock_suite._validate_dims( + dims={"country": ["A", "Z"]}, all_dims=["chain", "draw", "country"] + ) + + +# ============================================================================= +# Edge Case Tests +# ============================================================================= + + +class TestEdgeCases: + """Test edge cases and special scenarios.""" + + def test_contributions_over_time_with_dims_filtering(self, mock_suite): + """Test contributions_over_time with dims parameter.""" + from arviz_plots import PlotCollection + + # Filter to specific country + pc = mock_suite.contributions_over_time( + var=["intercept"], dims={"country": "A"} + ) + assert isinstance(pc, PlotCollection) + + def test_contributions_over_time_with_list_dims(self, mock_suite): + """Test contributions_over_time with list-valued dims.""" + from arviz_plots import PlotCollection + + # Filter to multiple countries + pc = mock_suite.contributions_over_time( + var=["intercept"], dims={"country": ["A", "B"]} ) - assert 
isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) + assert isinstance(pc, PlotCollection) - for row in range(axes.shape[0]): - assert "country=A" in axes[row, 0].get_title() + def test_saturation_scatterplot_with_dims_single_value( + self, mock_suite_with_constant_data + ): + """Test saturation_scatterplot with single-value dims.""" + from arviz_plots import PlotCollection - def test_saturation_curves_with_dims_list( + pc = mock_suite_with_constant_data.saturation_scatterplot(dims={"country": "A"}) + assert isinstance(pc, PlotCollection) + + def test_saturation_scatterplot_with_dims_list(self, mock_suite_with_constant_data): + """Test saturation_scatterplot with list-valued dims.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_scatterplot( + dims={"country": ["A", "B"]} + ) + assert isinstance(pc, PlotCollection) + + def test_saturation_curves_with_hdi_probs_float( self, mock_suite_with_constant_data, mock_saturation_curve ): - """Test saturation_curves with a list in dims (should create subplots for each value).""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=3, dims={"country": ["A", "B"]} + """Test saturation_curves with float hdi_probs.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, hdi_probs=0.9, n_samples=3 ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) + assert isinstance(pc, PlotCollection) - def test_saturation_curves_with_multiple_dims_lists( + def test_saturation_curves_with_hdi_probs_list( self, mock_suite_with_constant_data, mock_saturation_curve ): - """Test saturation_curves with multiple lists in dims (should create subplots for each combination).""" - # Add a fake 'region' dim to the mock constant_data for this test if not present - idata = mock_suite_with_constant_data.idata - if "region" not in idata.constant_data.channel_data.dims: - # Expand channel_data and posterior to add region - new_regions = ["X", "Y"] - channel_data = idata.constant_data.channel_data.expand_dims( - region=new_regions - ) - idata.constant_data["channel_data"] = channel_data - for var in ["channel_contribution", "channel_contribution_original_scale"]: - if var in idata.posterior: - idata.posterior[var] = idata.posterior[var].expand_dims( - region=new_regions - ) - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, - n_samples=3, - dims={"country": ["A", "B"], "region": ["X", "Y"]}, + """Test saturation_curves with list of hdi_probs.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, hdi_probs=[0.5, 0.9], n_samples=3 ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - combos = [("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")] + assert isinstance(pc, PlotCollection) - for col, (country, region) in enumerate(combos): - for row in range(axes.shape[0]): - title = axes[row, col].get_title() - assert f"country={country}" in title - assert f"region={region}" in title + def test_saturation_curves_with_hdi_probs_tuple( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with tuple of hdi_probs.""" + from arviz_plots import PlotCollection + pc = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, hdi_probs=(0.5, 0.9), n_samples=3 + ) + assert isinstance(pc, PlotCollection) 
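+    # Illustrative sketch only (not an executable test): `suite` and `curve`
+    # below stand in for the mock_suite_with_constant_data and
+    # mock_saturation_curve fixtures used by the surrounding tests. The
+    # hdi_probs parameter accepts a single float or any sequence of floats,
+    # so the tests above and below exercise the equivalent call styles:
+    #
+    #     suite.saturation_curves(curve=curve, hdi_probs=0.9)
+    #     suite.saturation_curves(curve=curve, hdi_probs=[0.5, 0.9])
+    #     suite.saturation_curves(curve=curve, hdi_probs=(0.5, 0.9))
+    #     suite.saturation_curves(curve=curve, hdi_probs=np.array([0.5, 0.9]))
+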
-def test_saturation_curves_scatter_deprecation_warning(mock_suite_with_constant_data): - """Test that saturation_curves_scatter shows deprecation warning.""" - with pytest.warns( - DeprecationWarning, match=r"saturation_curves_scatter is deprecated" + def test_saturation_curves_with_hdi_probs_array( + self, mock_suite_with_constant_data, mock_saturation_curve ): - fig, axes = mock_suite_with_constant_data.saturation_curves_scatter() + """Test saturation_curves with numpy array of hdi_probs.""" + from arviz_plots import PlotCollection - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) + pc = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, + hdi_probs=np.array([0.5, 0.9]), + n_samples=3, + ) + assert isinstance(pc, PlotCollection) + def test_budget_allocation_roas_with_dims_to_group_by_string(self, mock_suite): + """Test budget_allocation_roas with dims_to_group_by as string.""" + from arviz_plots import PlotCollection -@pytest.fixture(scope="module") -def mock_idata_with_constant_data_single_dim() -> az.InferenceData: - """Mock InferenceData where channel_data has only ('date','channel') dims.""" - seed = sum(map(ord, "Saturation single-dim tests")) - rng = np.random.default_rng(seed) - normal = rng.normal + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + channels = ["TV", "Radio", "Digital"] + regions = ["East", "West"] - dates = pd.date_range("2025-01-01", periods=12, freq="W-MON") - channels = ["channel_1", "channel_2", "channel_3"] + samples = xr.Dataset( + { + "channel_contribution_original_scale": xr.DataArray( + rng.normal(loc=1000, scale=100, size=(100, 52, 3, 2)), + dims=("sample", "date", "channel", "region"), + coords={ + "sample": np.arange(100), + "date": dates, + "channel": channels, + "region": regions, + }, + ), + "allocation": xr.DataArray( + rng.uniform(100, 1000, size=(3, 2)), + dims=("channel", "region"), + coords={"channel": channels, "region": regions}, + ), + } + ) - posterior = xr.Dataset( - { - "channel_contribution": xr.DataArray( - normal(size=(2, 10, 12, 3)), - dims=("chain", "draw", "date", "channel"), - coords={ - "chain": np.arange(2), - "draw": np.arange(10), - "date": dates, - "channel": channels, - }, - ), - "channel_contribution_original_scale": xr.DataArray( - normal(size=(2, 10, 12, 3)) * 100.0, - dims=("chain", "draw", "date", "channel"), - coords={ - "chain": np.arange(2), - "draw": np.arange(10), - "date": dates, - "channel": channels, - }, - ), - } - ) + pc = mock_suite.budget_allocation_roas( + samples=samples, dims_to_group_by="region" + ) + assert isinstance(pc, PlotCollection) - constant_data = xr.Dataset( - { - "channel_data": xr.DataArray( - rng.uniform(0, 10, size=(12, 3)), - dims=("date", "channel"), - coords={"date": dates, "channel": channels}, - ), - "channel_scale": xr.DataArray( - [100.0, 150.0, 200.0], dims=("channel",), coords={"channel": channels} - ), - "target_scale": xr.DataArray( - [1000.0], dims="target", coords={"target": ["y"]} - ), - } - ) + def test_budget_allocation_roas_with_dims_to_group_by_list(self, mock_suite): + """Test budget_allocation_roas with dims_to_group_by as list.""" + from arviz_plots import PlotCollection - return az.InferenceData(posterior=posterior, constant_data=constant_data) + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + channels = ["TV", "Radio"] + regions = ["East", "West"] + samples = xr.Dataset( + { 
+ "channel_contribution_original_scale": xr.DataArray( + rng.normal(loc=1000, scale=100, size=(100, 52, 2, 2)), + dims=("sample", "date", "channel", "region"), + coords={ + "sample": np.arange(100), + "date": dates, + "channel": channels, + "region": regions, + }, + ), + "allocation": xr.DataArray( + rng.uniform(100, 1000, size=(2, 2)), + dims=("channel", "region"), + coords={"channel": channels, "region": regions}, + ), + } + ) -@pytest.fixture(scope="module") -def mock_suite_with_constant_data_single_dim(mock_idata_with_constant_data_single_dim): - return MMMPlotSuite(idata=mock_idata_with_constant_data_single_dim) + pc = mock_suite.budget_allocation_roas( + samples=samples, dims_to_group_by=["channel", "region"] + ) + assert isinstance(pc, PlotCollection) + def test_sensitivity_analysis_with_aggregation_sum(self, mock_sensitivity_data): + """Test sensitivity_analysis_plot with sum aggregation.""" + from arviz_plots import PlotCollection -@pytest.fixture(scope="module") -def mock_saturation_curve_single_dim() -> xr.DataArray: - """Saturation curve with dims ('chain','draw','channel','x').""" - seed = sum(map(ord, "Saturation curve single-dim")) - rng = np.random.default_rng(seed) - x_values = np.linspace(0, 1, 50) - channels = ["channel_1", "channel_2", "channel_3"] - - # shape: (chains=2, draws=10, channel=3, x=50) - curve_array = np.empty((2, 10, len(channels), len(x_values))) - for ci in range(2): - for di in range(10): - for c in range(len(channels)): - curve_array[ci, di, c, :] = x_values / (1 + x_values) + rng.normal( - 0, 0.02, size=x_values.shape + # Add a dimension to aggregate over + data_with_dim = xr.Dataset( + { + "x": xr.DataArray( + np.random.rand(100, 20, 3), + dims=("sample", "sweep", "channel"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + "channel": ["A", "B", "C"], + }, ) + } + ) - return xr.DataArray( - curve_array, - dims=("chain", "draw", "channel", "x"), - coords={ - "chain": np.arange(2), - "draw": np.arange(10), - "channel": channels, - "x": x_values, - }, - name="saturation_curve", - ) + suite = MMMPlotSuite(idata=None) + pc = suite._sensitivity_analysis_plot( + data=data_with_dim, aggregation={"sum": ("channel",)} + ) + assert isinstance(pc, PlotCollection) + def test_sensitivity_analysis_with_aggregation_mean(self, mock_sensitivity_data): + """Test sensitivity_analysis_plot with mean aggregation.""" + from arviz_plots import PlotCollection -def test_saturation_curves_single_dim_axes_shape( - mock_suite_with_constant_data_single_dim, mock_saturation_curve_single_dim -): - """When there are no extra dims, columns should default to 1 (no ncols=0).""" - fig, axes = mock_suite_with_constant_data_single_dim.saturation_curves( - curve=mock_saturation_curve_single_dim, n_samples=3 - ) + data_with_dim = xr.Dataset( + { + "x": xr.DataArray( + np.random.rand(100, 20, 3), + dims=("sample", "sweep", "channel"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + "channel": ["A", "B", "C"], + }, + ) + } + ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - # Expect (n_channels, 1) - assert axes.shape[1] == 1 - assert axes.shape[0] == mock_saturation_curve_single_dim.sizes["channel"] + suite = MMMPlotSuite(idata=None) + pc = suite._sensitivity_analysis_plot( + data=data_with_dim, aggregation={"mean": ("channel",)} + ) + assert isinstance(pc, PlotCollection) + def test_sensitivity_analysis_with_aggregation_median(self, mock_sensitivity_data): + """Test sensitivity_analysis_plot with median 
aggregation.""" + from arviz_plots import PlotCollection -def test_saturation_curves_multi_dim_axes_shape( - mock_suite_with_constant_data, mock_saturation_curve -): - """With an extra dim (e.g., 'country'), expect (n_channels, n_countries).""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=2 - ) + data_with_dim = xr.Dataset( + { + "x": xr.DataArray( + np.random.rand(100, 20, 3), + dims=("sample", "sweep", "channel"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + "channel": ["A", "B", "C"], + }, + ) + } + ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) + suite = MMMPlotSuite(idata=None) + pc = suite._sensitivity_analysis_plot( + data=data_with_dim, aggregation={"median": ("channel",)} + ) + assert isinstance(pc, PlotCollection) + def test_uplift_curve_with_dataset_containing_uplift_curve(self): + """Test uplift_curve when data is Dataset with uplift_curve variable.""" + from arviz_plots import PlotCollection -def test_sensitivity_analysis_basic(mock_suite_with_sensitivity): - fig, axes = mock_suite_with_sensitivity.sensitivity_analysis() + rng = np.random.default_rng(42) + data = xr.Dataset( + { + "uplift_curve": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert axes.ndim == 2 - expected_panels = len( - mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] - ) # type: ignore - assert axes.size >= expected_panels - assert all(isinstance(ax, Axes) for ax in axes.flat[:expected_panels]) + suite = MMMPlotSuite(idata=None) + pc = suite.uplift_curve(data=data) + assert isinstance(pc, PlotCollection) + def test_uplift_curve_with_dataset_containing_x(self): + """Test uplift_curve when data is Dataset with x variable.""" + from arviz_plots import PlotCollection -def test_sensitivity_analysis_with_aggregation(mock_suite_with_sensitivity): - ax = mock_suite_with_sensitivity.sensitivity_analysis( - aggregation={"sum": ("region",)} - ) - assert isinstance(ax, Axes) + rng = np.random.default_rng(42) + data = xr.Dataset( + { + "x": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + suite = MMMPlotSuite(idata=None) + pc = suite.uplift_curve(data=data) + assert isinstance(pc, PlotCollection) -def test_marginal_curve(mock_suite_with_sensitivity): - fig, axes = mock_suite_with_sensitivity.marginal_curve() + def test_marginal_curve_with_dataset_containing_marginal_effects(self): + """Test marginal_curve when data is Dataset with marginal_effects variable.""" + from arviz_plots import PlotCollection - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert axes.ndim == 2 - regions = mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] # type: ignore - assert axes.size >= len(regions) - assert all(isinstance(ax, Axes) for ax in axes.flat[: len(regions)]) + rng = np.random.default_rng(42) + data = xr.Dataset( + { + "marginal_effects": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + suite = MMMPlotSuite(idata=None) + pc = suite.marginal_curve(data=data) + assert isinstance(pc, PlotCollection) -def 
test_uplift_curve(mock_suite_with_sensitivity): - fig, axes = mock_suite_with_sensitivity.uplift_curve() + def test_marginal_curve_with_dataset_containing_x(self): + """Test marginal_curve when data is Dataset with x variable.""" + from arviz_plots import PlotCollection - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert axes.ndim == 2 - regions = mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] # type: ignore - assert axes.size >= len(regions) - assert all(isinstance(ax, Axes) for ax in axes.flat[: len(regions)]) + rng = np.random.default_rng(42) + data = xr.Dataset( + { + "x": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + suite = MMMPlotSuite(idata=None) + pc = suite.marginal_curve(data=data) + assert isinstance(pc, PlotCollection) -def test_sensitivity_analysis_multi_panel(mock_suite_with_sensitivity): - # The fixture provides an extra 'region' dimension, so multiple panels should be produced - fig, axes = mock_suite_with_sensitivity.sensitivity_analysis( - subplot_kwargs={"ncols": 2} - ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert axes.ndim == 2 - # There should be two regions, therefore exactly two panels - expected_panels = len( - mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] - ) # type: ignore - assert axes.size >= expected_panels - assert all(isinstance(ax, Axes) for ax in axes.flat[:expected_panels]) - - -def test_sensitivity_analysis_error_on_missing_results(mock_idata): - suite = MMMPlotSuite(idata=mock_idata) - with pytest.raises(ValueError, match=r"No sensitivity analysis results found"): - suite.sensitivity_analysis() - suite.plot_sensitivity_analysis() - - -def test_budget_allocation_with_dims(mock_suite_with_constant_data): - # Use dims to filter to a single country - samples = mock_suite_with_constant_data.idata.posterior - # Add a fake 'allocation' variable for testing - samples = samples.copy() - samples["allocation"] = ( - samples["channel_contribution"].dims, - np.abs(samples["channel_contribution"].values), - ) - plot_suite = mock_suite_with_constant_data - fig, _ax = plot_suite.budget_allocation( - samples=samples, - dims={"country": "A"}, - ) - assert isinstance(fig, Figure) +# ============================================================================= +# Original Scale Tests +# ============================================================================= -def test_budget_allocation_with_dims_list(mock_suite_with_constant_data): - """Test that passing a list to dims creates a subplot for each value.""" - samples = mock_suite_with_constant_data.idata.posterior.copy() - # Add a fake 'allocation' variable for testing - samples["allocation"] = ( - samples["channel_contribution"].dims, - np.abs(samples["channel_contribution"].values), - ) - plot_suite = mock_suite_with_constant_data - fig, ax = plot_suite.budget_allocation( - samples=samples, - dims={"country": ["A", "B"]}, - ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) +class TestOriginalScale: + """Test original_scale parameter functionality.""" + def test_saturation_scatterplot_original_scale_true( + self, mock_suite_with_constant_data + ): + """Test saturation_scatterplot with original_scale=True.""" + from arviz_plots import PlotCollection -def test__validate_dims_valid(): - """Test _validate_dims with valid dims and values.""" - suite = MMMPlotSuite(idata=None) 
+ pc = mock_suite_with_constant_data.saturation_scatterplot(original_scale=True) + assert isinstance(pc, PlotCollection) - # Patch suite.idata.posterior.coords to simulate valid dims - class DummyCoord: - def __init__(self, values): - self.values = values + def test_saturation_scatterplot_original_scale_missing_variable( + self, mock_constant_data + ): + """Test that original_scale=True without variable raises ValueError.""" + suite = MMMPlotSuite(idata=None) - class DummyCoords: - def __init__(self): - self._coords = { - "country": DummyCoord(["A", "B"]), - "region": DummyCoord(["X", "Y"]), + # Create posterior without channel_contribution_original_scale + posterior = xr.Dataset( + { + "channel_contribution": xr.DataArray( + np.random.rand(4, 100, 52, 3), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": mock_constant_data.coords["date"], + "channel": mock_constant_data.coords["channel"], + }, + ) } + ) + + with pytest.raises( + ValueError, match=r"No posterior\.channel_contribution_original_scale" + ): + suite.saturation_scatterplot( + original_scale=True, + constant_data=mock_constant_data, + posterior_data=posterior, + ) - def __getitem__(self, key): - return self._coords[key] + def test_saturation_curves_original_scale_missing_variable( + self, mock_constant_data, mock_saturation_curve + ): + """Test that original_scale=True without variable raises ValueError.""" + suite = MMMPlotSuite(idata=None) - class DummyPosterior: - coords = DummyCoords() + # Create posterior without channel_contribution_original_scale + posterior = xr.Dataset( + { + "channel_contribution": xr.DataArray( + np.random.rand(4, 100, 52, 3), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": mock_constant_data.coords["date"], + "channel": mock_constant_data.coords["channel"], + }, + ) + } + ) - suite.idata = type("idata", (), {"posterior": DummyPosterior()})() - # Should not raise - suite._validate_dims({"country": "A", "region": "X"}, ["country", "region"]) - suite._validate_dims({"country": ["A", "B"]}, ["country", "region"]) + with pytest.raises(ValueError, match=r"No posterior\.channel_contribution"): + suite.saturation_curves( + curve=mock_saturation_curve, + original_scale=True, + constant_data=mock_constant_data, + posterior_data=posterior, + ) -def test__validate_dims_invalid_dim(): - """Test _validate_dims raises for invalid dim name.""" - suite = MMMPlotSuite(idata=None) +# ============================================================================= +# Deprecated Method Tests +# ============================================================================= - class DummyCoord: - def __init__(self, values): - self.values = values - class DummyCoords: - def __init__(self): - self.country = DummyCoord(["A", "B"]) +class TestDeprecatedMethods: + """Test deprecated methods raise appropriate errors.""" - def __getitem__(self, key): - return getattr(self, key) + def test_budget_allocation_raises_not_implemented(self, mock_suite): + """Test that budget_allocation() raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match=r"budget_allocation.*removed"): + mock_suite.budget_allocation() - class DummyPosterior: - coords = DummyCoords() - suite.idata = type("idata", (), {"posterior": DummyPosterior()})() - with pytest.raises(ValueError, match=r"Dimension 'region' not found"): - suite._validate_dims({"region": "X"}, ["country"]) +# 
============================================================================= +# Additional Coverage Tests +# ============================================================================= -def test__validate_dims_invalid_value(): - """Test _validate_dims raises for invalid value.""" - suite = MMMPlotSuite(idata=None) +class TestAdditionalCoverage: + """Additional tests to reach >95% coverage.""" - class DummyCoord: - def __init__(self, values): - self.values = values + def test_posterior_predictive_with_explicit_idata(self): + """Test posterior_predictive with explicit idata parameter.""" + from arviz_plots import PlotCollection - class DummyCoords: - def __init__(self): - self.country = DummyCoord(["A", "B"]) + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") - def __getitem__(self, key): - return getattr(self, key) + # Create posterior_predictive dataset + pp_data = xr.Dataset( + { + "y": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + }, + ) + } + ) - class DummyPosterior: - coords = DummyCoords() + # Create suite without idata + suite = MMMPlotSuite(idata=None) - suite.idata = type("idata", (), {"posterior": DummyPosterior()})() - with pytest.raises(ValueError, match=r"Value 'C' not found in dimension 'country'"): - suite._validate_dims({"country": "C"}, ["country"]) + # Should work with explicit idata parameter + pc = suite.posterior_predictive(var="y", idata=pp_data) + assert isinstance(pc, PlotCollection) + def test_saturation_curves_with_invalid_hdi_probs_type( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test that invalid hdi_probs type raises TypeError.""" + with pytest.raises(TypeError, match="hdi_probs must be a float"): + mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, hdi_probs={"invalid": "type"} + ) -def test__dim_list_handler_none(): - """Test _dim_list_handler with None input.""" - suite = MMMPlotSuite(idata=None) - keys, combos = suite._dim_list_handler(None) - assert keys == [] - assert combos == [()] + def test_uplift_curve_with_dataset_missing_both_variables(self): + """Test uplift_curve when Dataset has neither uplift_curve nor x.""" + rng = np.random.default_rng(42) + data = xr.Dataset( + { + "other_var": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + suite = MMMPlotSuite(idata=None) + with pytest.raises(ValueError, match="must contain 'uplift_curve' or 'x'"): + suite.uplift_curve(data=data) -def test__dim_list_handler_single(): - """Test _dim_list_handler with a single list-valued dim.""" - suite = MMMPlotSuite(idata=None) - keys, combos = suite._dim_list_handler({"country": ["A", "B"]}) - assert keys == ["country"] - assert set(combos) == {("A",), ("B",)} + def test_marginal_curve_with_dataset_missing_both_variables(self): + """Test marginal_curve when Dataset has neither marginal_effects nor x.""" + rng = np.random.default_rng(42) + data = xr.Dataset( + { + "other_var": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + suite = MMMPlotSuite(idata=None) + with pytest.raises(ValueError, match="must contain 'marginal_effects' or 'x'"): + suite.marginal_curve(data=data) -def test__dim_list_handler_multiple(): - """Test 
_dim_list_handler with multiple list-valued dims.""" - suite = MMMPlotSuite(idata=None) - keys, combos = suite._dim_list_handler( - {"country": ["A", "B"], "region": ["X", "Y"]} - ) - assert set(keys) == {"country", "region"} - assert set(combos) == {("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")} + def test_sensitivity_analysis_with_aggregation_no_matching_dims(self): + """Test sensitivity_analysis_plot with aggregation but no matching dims.""" + from arviz_plots import PlotCollection + # Create data without the dimension to aggregate + data = xr.Dataset( + { + "x": xr.DataArray( + np.random.rand(100, 20), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) -def test__dim_list_handler_mixed(): - """Test _dim_list_handler with mixed single and list values.""" - suite = MMMPlotSuite(idata=None) - keys, combos = suite._dim_list_handler({"country": ["A", "B"], "region": "X"}) - assert keys == ["country"] - assert set(combos) == {("A",), ("B",)} + suite = MMMPlotSuite(idata=None) + # Should work even though "channel" doesn't exist in data + pc = suite._sensitivity_analysis_plot( + data=data, aggregation={"sum": ("channel",)} + ) + assert isinstance(pc, PlotCollection) diff --git a/tests/mmm/test_plot_compatibility.py b/tests/mmm/test_plot_compatibility.py new file mode 100644 index 000000000..1dbdaceeb --- /dev/null +++ b/tests/mmm/test_plot_compatibility.py @@ -0,0 +1,562 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
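+
+# Overview (illustrative sketch mirroring the behavior exercised below, not
+# authoritative API documentation): the active plot suite is selected through
+# the mmm_plot_config mapping, where `mmm` stands in for any fitted MMM model:
+#
+#     from pymc_marketing.mmm import mmm_plot_config
+#
+#     mmm_plot_config["plot.use_v2"] = True   # mmm.plot -> MMMPlotSuite (v2)
+#     mmm_plot_config["plot.use_v2"] = False  # mmm.plot -> LegacyMMMPlotSuite,
+#                                             # with a FutureWarning (removal
+#                                             # planned for v0.20.0)
+#     mmm_plot_config["plot.show_warnings"] = False  # suppress that warning
+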
+"""Compatibility tests for plot suite version switching.""" + +import warnings + +import numpy as np +import pytest +from arviz_plots import PlotCollection +from matplotlib.axes import Axes +from matplotlib.figure import Figure + + +class TestVersionSwitching: + """Test that mmm_plot_config['plot.use_v2'] controls which suite is returned.""" + + def test_use_v2_false_returns_legacy_suite(self, mock_mmm): + """Test that use_v2=False returns LegacyMMMPlotSuite.""" + from pymc_marketing.mmm import mmm_plot_config + from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + + original = mmm_plot_config.get("plot.use_v2", False) + try: + mmm_plot_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning, match="deprecated in v0.20.0"): + plot_suite = mock_mmm.plot + + assert isinstance(plot_suite, LegacyMMMPlotSuite) + assert plot_suite.__class__.__name__ == "LegacyMMMPlotSuite" + finally: + mmm_plot_config["plot.use_v2"] = original + + def test_use_v2_true_returns_new_suite(self, mock_mmm): + """Test that use_v2=True returns MMMPlotSuite.""" + from pymc_marketing.mmm import mmm_plot_config + from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + from pymc_marketing.mmm.plot import MMMPlotSuite + + original = mmm_plot_config.get("plot.use_v2", False) + try: + mmm_plot_config["plot.use_v2"] = True + + # Should not warn + with warnings.catch_warnings(): + warnings.simplefilter("error") + plot_suite = mock_mmm.plot + + assert isinstance(plot_suite, MMMPlotSuite) + assert not isinstance(plot_suite, LegacyMMMPlotSuite) + assert plot_suite.__class__.__name__ == "MMMPlotSuite" + finally: + mmm_plot_config["plot.use_v2"] = original + + def test_default_is_legacy_suite(self, mock_mmm): + """Test that default behavior uses legacy suite (backward compatible).""" + from pymc_marketing.mmm import mmm_plot_config + from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + + # Ensure default state + if "plot.use_v2" in mmm_plot_config: + del mmm_plot_config["plot.use_v2"] + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm.plot + + assert isinstance(plot_suite, LegacyMMMPlotSuite) + + def test_config_flag_persists_across_calls(self, mock_mmm): + """Test that setting config flag affects all subsequent calls.""" + from pymc_marketing.mmm import mmm_plot_config + from pymc_marketing.mmm.plot import MMMPlotSuite + + original = mmm_plot_config.get("plot.use_v2", False) + try: + # Set once + mmm_plot_config["plot.use_v2"] = True + + # Multiple calls should all use new suite + plot_suite1 = mock_mmm.plot + plot_suite2 = mock_mmm.plot + plot_suite3 = mock_mmm.plot + + assert isinstance(plot_suite1, MMMPlotSuite) + assert isinstance(plot_suite2, MMMPlotSuite) + assert isinstance(plot_suite3, MMMPlotSuite) + finally: + mmm_plot_config["plot.use_v2"] = original + + def test_switching_between_v2_true_and_false(self, mock_mmm): + """Test that switching from use_v2=True to False and back works correctly.""" + from pymc_marketing.mmm import mmm_plot_config + from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + from pymc_marketing.mmm.plot import MMMPlotSuite + + original = mmm_plot_config.get("plot.use_v2", False) + try: + # Start with use_v2 = True + mmm_plot_config["plot.use_v2"] = True + + # Should return new suite without warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") + plot_suite_v2 = mock_mmm.plot + + assert isinstance(plot_suite_v2, MMMPlotSuite) + + # Switch to use_v2 = False + mmm_plot_config["plot.use_v2"] = False + + # Should return 
legacy suite with deprecation warning + with pytest.warns(FutureWarning, match="deprecated in v0.20.0"): + plot_suite_legacy = mock_mmm.plot + + assert isinstance(plot_suite_legacy, LegacyMMMPlotSuite) + + # Switch back to use_v2 = True + mmm_plot_config["plot.use_v2"] = True + + # Should return new suite again without warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") + plot_suite_v2_again = mock_mmm.plot + + assert isinstance(plot_suite_v2_again, MMMPlotSuite) + finally: + mmm_plot_config["plot.use_v2"] = original + + +class TestDeprecationWarnings: + """Test deprecation warnings shown correctly with helpful information.""" + + def test_deprecation_warning_shown_by_default(self, mock_mmm): + """Test that deprecation warning is shown when using legacy suite.""" + from pymc_marketing.mmm import mmm_plot_config + + original_use_v2 = mmm_plot_config.get("plot.use_v2", False) + original_warnings = mmm_plot_config.get("plot.show_warnings", True) + + try: + mmm_plot_config["plot.use_v2"] = False + mmm_plot_config["plot.show_warnings"] = True + + with pytest.warns(FutureWarning, match=r"deprecated in v0\.20\.0"): + plot_suite = mock_mmm.plot + + assert plot_suite is not None + finally: + mmm_plot_config["plot.use_v2"] = original_use_v2 + mmm_plot_config["plot.show_warnings"] = original_warnings + + def test_deprecation_warning_suppressible(self, mock_mmm): + """Test that deprecation warning can be suppressed.""" + from pymc_marketing.mmm import mmm_plot_config + + original_use_v2 = mmm_plot_config.get("plot.use_v2", False) + original_warnings = mmm_plot_config.get("plot.show_warnings", True) + + try: + mmm_plot_config["plot.use_v2"] = False + mmm_plot_config["plot.show_warnings"] = False + + # Should not warn + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + plot_suite = mock_mmm.plot + + assert plot_suite is not None + finally: + mmm_plot_config["plot.use_v2"] = original_use_v2 + mmm_plot_config["plot.show_warnings"] = original_warnings + + def test_warning_message_includes_migration_info(self, mock_mmm): + """Test that warning provides clear migration instructions.""" + from pymc_marketing.mmm import mmm_plot_config + + original_use_v2 = mmm_plot_config.get("plot.use_v2", False) + + try: + mmm_plot_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning) as warning_list: + _ = mock_mmm.plot + + warning_msg = str(warning_list[0].message) + + # Check for key information + assert "v0.20.0" in warning_msg, "Should mention removal version" + assert "plot.use_v2" in warning_msg, "Should show how to enable v2" + assert "True" in warning_msg, "Should show value to set" + assert any( + word in warning_msg.lower() + for word in ["migration", "guide", "documentation", "docs"] + ), "Should reference migration guide" + finally: + mmm_plot_config["plot.use_v2"] = original_use_v2 + + def test_no_warning_when_using_new_suite(self, mock_mmm): + """Test that no warning shown when using new suite.""" + from pymc_marketing.mmm import mmm_plot_config + + original = mmm_plot_config.get("plot.use_v2", False) + + try: + mmm_plot_config["plot.use_v2"] = True + + with warnings.catch_warnings(): + warnings.simplefilter("error") + plot_suite = mock_mmm.plot + + assert plot_suite is not None + finally: + mmm_plot_config["plot.use_v2"] = original + + +class TestReturnTypeCompatibility: + """Test both suites return correct, expected types.""" + + def test_legacy_suite_returns_tuple(self, mock_mmm_fitted): + """Test legacy suite returns 
(Figure, Axes) tuple."""
+        from pymc_marketing.mmm import mmm_plot_config
+
+        original = mmm_plot_config.get("plot.use_v2", False)
+
+        try:
+            mmm_plot_config["plot.use_v2"] = False
+
+            with pytest.warns(FutureWarning):
+                plot_suite = mock_mmm_fitted.plot
+            result = plot_suite.posterior_predictive()
+
+            assert isinstance(result, tuple), f"Expected tuple, got {type(result)}"
+            assert len(result) == 2, f"Expected 2-tuple, got length {len(result)}"
+            assert isinstance(result[0], Figure), (
+                f"Expected Figure, got {type(result[0])}"
+            )
+
+            # result[1] can be Axes or ndarray of Axes
+            if isinstance(result[1], np.ndarray):
+                assert all(isinstance(ax, Axes) for ax in result[1].flat)
+            else:
+                assert isinstance(result[1], Axes)
+        finally:
+            mmm_plot_config["plot.use_v2"] = original
+
+    def test_new_suite_returns_plot_collection(self, mock_mmm_fitted):
+        """Test that the new suite returns a PlotCollection."""
+        from pymc_marketing.mmm import mmm_plot_config
+
+        original = mmm_plot_config.get("plot.use_v2", False)
+
+        try:
+            mmm_plot_config["plot.use_v2"] = True
+
+            plot_suite = mock_mmm_fitted.plot
+            result = plot_suite.posterior_predictive()
+
+            assert isinstance(result, PlotCollection), (
+                f"Expected PlotCollection, got {type(result)}"
+            )
+            assert hasattr(result, "backend"), (
+                "PlotCollection should have backend attribute"
+            )
+            assert hasattr(result, "show"), "PlotCollection should have show method"
+        finally:
+            mmm_plot_config["plot.use_v2"] = original
+
+    def test_both_suites_produce_valid_plots(self, mock_mmm_fitted):
+        """Test that both suites can successfully create plots."""
+        from pymc_marketing.mmm import mmm_plot_config
+
+        original = mmm_plot_config.get("plot.use_v2", False)
+
+        try:
+            # Legacy suite
+            mmm_plot_config["plot.use_v2"] = False
+            with pytest.warns(FutureWarning):
+                legacy_result = mock_mmm_fitted.plot.contributions_over_time(
+                    var=["intercept"]
+                )
+            assert legacy_result is not None
+
+            # New suite
+            mmm_plot_config["plot.use_v2"] = True
+            new_result = mock_mmm_fitted.plot.contributions_over_time(var=["intercept"])
+            assert new_result is not None
+        finally:
+            mmm_plot_config["plot.use_v2"] = original
+
+
+class TestDeprecatedMethodRemoval:
+    """Test that saturation_curves_scatter() exists only in the legacy suite."""
+
+    def test_saturation_curves_scatter_removed_from_new_suite(self, mock_mmm_fitted):
+        """Test that saturation_curves_scatter is removed from the new MMMPlotSuite."""
+        from pymc_marketing.mmm import mmm_plot_config
+
+        original = mmm_plot_config.get("plot.use_v2", False)
+
+        try:
+            mmm_plot_config["plot.use_v2"] = True
+            plot_suite = mock_mmm_fitted.plot
+
+            assert not hasattr(plot_suite, "saturation_curves_scatter"), (
+                "saturation_curves_scatter should not exist in new MMMPlotSuite"
+            )
+        finally:
+            mmm_plot_config["plot.use_v2"] = original
+
+    def test_saturation_curves_scatter_exists_in_legacy_suite(self, mock_mmm_fitted):
+        """Test that saturation_curves_scatter still exists in LegacyMMMPlotSuite."""
+        from pymc_marketing.mmm import mmm_plot_config
+
+        original = mmm_plot_config.get("plot.use_v2", False)
+
+        try:
+            mmm_plot_config["plot.use_v2"] = False
+
+            with pytest.warns(FutureWarning):
+                plot_suite = mock_mmm_fitted.plot
+
+            assert hasattr(plot_suite, "saturation_curves_scatter"), (
+                "saturation_curves_scatter should exist in LegacyMMMPlotSuite"
+            )
+        finally:
+            mmm_plot_config["plot.use_v2"] = original
+
+
+class TestMissingMethods:
+    """Test that methods existing in only one suite fail gracefully in the other."""
+
+    def 
test_budget_allocation_exists_in_legacy_suite( + self, mock_mmm_fitted, mock_allocation_samples + ): + """Test that budget_allocation() works in legacy suite.""" + from pymc_marketing.mmm import mmm_plot_config + + original = mmm_plot_config.get("plot.use_v2", False) + + try: + mmm_plot_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm_fitted.plot + + # Should work (not raise AttributeError) + result = plot_suite.budget_allocation(samples=mock_allocation_samples) + assert isinstance(result, tuple) + assert len(result) == 2 + finally: + mmm_plot_config["plot.use_v2"] = original + + def test_budget_allocation_raises_in_new_suite(self, mock_mmm_fitted): + """Test that budget_allocation() raises helpful error in new suite.""" + from pymc_marketing.mmm import mmm_plot_config + + original = mmm_plot_config.get("plot.use_v2", False) + + try: + mmm_plot_config["plot.use_v2"] = True + plot_suite = mock_mmm_fitted.plot + + with pytest.raises(NotImplementedError, match="removed in MMMPlotSuite v2"): + plot_suite.budget_allocation(samples=None) + finally: + mmm_plot_config["plot.use_v2"] = original + + def test_budget_allocation_roas_exists_in_new_suite(self, mock_mmm_fitted): + """Test that budget_allocation_roas() exists in new suite.""" + from pymc_marketing.mmm import mmm_plot_config + + original = mmm_plot_config.get("plot.use_v2", False) + + try: + mmm_plot_config["plot.use_v2"] = True + plot_suite = mock_mmm_fitted.plot + + # Just check that the method exists (not AttributeError) + assert hasattr(plot_suite, "budget_allocation_roas"), ( + "budget_allocation_roas should exist in new MMMPlotSuite" + ) + assert callable(plot_suite.budget_allocation_roas), ( + "budget_allocation_roas should be callable" + ) + finally: + mmm_plot_config["plot.use_v2"] = original + + def test_budget_allocation_roas_missing_in_legacy_suite(self, mock_mmm_fitted): + """Test that budget_allocation_roas() doesn't exist in legacy suite.""" + from pymc_marketing.mmm import mmm_plot_config + + original = mmm_plot_config.get("plot.use_v2", False) + + try: + mmm_plot_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm_fitted.plot + + with pytest.raises(AttributeError): + plot_suite.budget_allocation_roas(samples=None) + finally: + mmm_plot_config["plot.use_v2"] = original + + +class TestConfigValidation: + """Test MMMPlotConfig key validation.""" + + def test_invalid_key_warns_but_allows_setting(self): + """Test that setting an invalid config key warns but still sets the value.""" + from pymc_marketing.mmm import mmm_plot_config + + # Store original state + original_invalid = mmm_plot_config.get("invalid.key", None) + try: + # Try to set an invalid key + with pytest.warns(UserWarning, match="Invalid config key"): + mmm_plot_config["invalid.key"] = "some_value" + + # Verify the warning message contains valid keys + with pytest.warns(UserWarning) as warning_list: + mmm_plot_config["another.invalid.key"] = "another_value" + + warning_msg = str(warning_list[0].message) + assert "Invalid config key" in warning_msg + assert "another.invalid.key" in warning_msg + assert "plot.backend" in warning_msg or "plot.show_warnings" in warning_msg + + # Verify the invalid key was still set (allows setting but warns) + assert mmm_plot_config["invalid.key"] == "some_value" + assert mmm_plot_config["another.invalid.key"] == "another_value" + finally: + # Clean up invalid keys + if "invalid.key" in mmm_plot_config: + del mmm_plot_config["invalid.key"] + if 
"another.invalid.key" in mmm_plot_config: + del mmm_plot_config["another.invalid.key"] + if original_invalid is not None: + mmm_plot_config["invalid.key"] = original_invalid + + def test_valid_keys_do_not_warn(self): + """Test that setting valid config keys does not warn.""" + from pymc_marketing.mmm import mmm_plot_config + + original_backend = mmm_plot_config.get("plot.backend", "matplotlib") + original_use_v2 = mmm_plot_config.get("plot.use_v2", False) + original_warnings = mmm_plot_config.get("plot.show_warnings", True) + + try: + # Setting valid keys should not warn + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + mmm_plot_config["plot.backend"] = "plotly" + mmm_plot_config["plot.use_v2"] = True + mmm_plot_config["plot.show_warnings"] = False + + # Verify values were set + assert mmm_plot_config["plot.backend"] == "plotly" + assert mmm_plot_config["plot.use_v2"] is True + assert mmm_plot_config["plot.show_warnings"] is False + finally: + mmm_plot_config["plot.backend"] = original_backend + mmm_plot_config["plot.use_v2"] = original_use_v2 + mmm_plot_config["plot.show_warnings"] = original_warnings + + def test_reset_restores_defaults(self): + """Test that reset() restores all configuration to default values.""" + from pymc_marketing.mmm import mmm_plot_config + + # Store original state + original_backend = mmm_plot_config.get("plot.backend", "matplotlib") + original_use_v2 = mmm_plot_config.get("plot.use_v2", False) + original_warnings = mmm_plot_config.get("plot.show_warnings", True) + + try: + # Change all config values + mmm_plot_config["plot.backend"] = "plotly" + mmm_plot_config["plot.use_v2"] = True + mmm_plot_config["plot.show_warnings"] = False + + # Verify they were changed + assert mmm_plot_config["plot.backend"] == "plotly" + assert mmm_plot_config["plot.use_v2"] is True + assert mmm_plot_config["plot.show_warnings"] is False + + # Reset to defaults + mmm_plot_config.reset() + + # Verify all values are back to defaults + assert mmm_plot_config["plot.backend"] == "matplotlib" + assert mmm_plot_config["plot.use_v2"] is False + assert mmm_plot_config["plot.show_warnings"] is True + + # Verify reset clears any invalid keys that were set + mmm_plot_config["invalid.key"] = "test" + assert "invalid.key" in mmm_plot_config + mmm_plot_config.reset() + assert "invalid.key" not in mmm_plot_config + finally: + # Restore original state + mmm_plot_config["plot.backend"] = original_backend + mmm_plot_config["plot.use_v2"] = original_use_v2 + mmm_plot_config["plot.show_warnings"] = original_warnings + + def test_invalid_backend_warns_but_allows_setting(self): + """Test that setting an invalid backend warns but still sets the value.""" + from pymc_marketing.mmm import mmm_plot_config + + original_backend = mmm_plot_config.get("plot.backend", "matplotlib") + + try: + # Try to set an invalid backend + with pytest.warns(UserWarning, match="Invalid backend"): + mmm_plot_config["plot.backend"] = "invalid_backend" + + # Verify the warning message contains valid backends + with pytest.warns(UserWarning) as warning_list: + mmm_plot_config["plot.backend"] = "another_invalid" + + warning_msg = str(warning_list[0].message) + assert "Invalid backend" in warning_msg + assert "another_invalid" in warning_msg + assert ( + "matplotlib" in warning_msg + or "plotly" in warning_msg + or "bokeh" in warning_msg + ) + + # Verify the invalid backend was still set (allows setting but warns) + assert mmm_plot_config["plot.backend"] == "another_invalid" + finally: + 
mmm_plot_config["plot.backend"] = original_backend + + def test_valid_backends_do_not_warn(self): + """Test that setting valid backend values does not warn.""" + from pymc_marketing.mmm import mmm_plot_config + + original_backend = mmm_plot_config.get("plot.backend", "matplotlib") + + try: + # Setting valid backends should not warn + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + mmm_plot_config["plot.backend"] = "matplotlib" + mmm_plot_config["plot.backend"] = "plotly" + mmm_plot_config["plot.backend"] = "bokeh" + + # Verify values were set + assert mmm_plot_config["plot.backend"] == "bokeh" + finally: + mmm_plot_config["plot.backend"] = original_backend diff --git a/tests/mmm/test_plot_data_parameters.py b/tests/mmm/test_plot_data_parameters.py new file mode 100644 index 000000000..033feec0a --- /dev/null +++ b/tests/mmm/test_plot_data_parameters.py @@ -0,0 +1,120 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for data parameter standardization across plotting methods.""" + +import arviz_plots +import pytest +import xarray as xr + +from pymc_marketing.mmm.plot import MMMPlotSuite + + +@pytest.mark.parametrize( + "use_explicit_data", + [ + pytest.param(True, id="explicit_data_parameter"), + pytest.param(False, id="fallback_to_idata"), + ], +) +def test_contributions_over_time_data_parameter( + use_explicit_data, mock_posterior_data, mock_idata_with_posterior +): + """Test contributions_over_time with explicit data or fallback to idata.posterior.""" + if use_explicit_data: + suite = MMMPlotSuite(idata=None) + pc = suite.contributions_over_time(var=["intercept"], data=mock_posterior_data) + else: + suite = MMMPlotSuite(idata=mock_idata_with_posterior) + pc = suite.contributions_over_time(var=["intercept"]) + + assert isinstance(pc, arviz_plots.PlotCollection) + + +def test_contributions_over_time_no_data_raises_clear_error(): + """Test clear error when no data available.""" + suite = MMMPlotSuite(idata=None) + + with pytest.raises( + ValueError, match=r"No posterior data found.*and no 'data' argument provided" + ): + suite.contributions_over_time(var=["intercept"]) + + +def test_saturation_scatterplot_accepts_data_parameters( + mock_constant_data, mock_posterior_data +): + """Test saturation_scatterplot accepts data parameters.""" + import numpy as np + + # Need to add channel_contribution to mock_posterior_data + # Replicate the data across the channel dimension (3 channels) + intercept_values = mock_posterior_data["intercept"].values + channel_contrib_values = np.repeat(intercept_values[:, :, :, np.newaxis], 3, axis=3) + + mock_posterior_data["channel_contribution"] = xr.DataArray( + channel_contrib_values, + dims=("chain", "draw", "date", "channel"), + coords={ + **{k: v for k, v in mock_posterior_data.coords.items()}, + "channel": ["TV", "Radio", "Digital"], + }, + ) + + suite = MMMPlotSuite(idata=None) + + pc = suite.saturation_scatterplot( + constant_data=mock_constant_data, 
posterior_data=mock_posterior_data + ) + + assert isinstance(pc, arviz_plots.PlotCollection) + + +def test_sensitivity_analysis_plot_requires_data_parameter(mock_sensitivity_data): + """Test _sensitivity_analysis_plot requires data parameter (no fallback).""" + suite = MMMPlotSuite(idata=None) + + # Should work with data parameter + pc = suite._sensitivity_analysis_plot(data=mock_sensitivity_data) + + assert isinstance(pc, arviz_plots.PlotCollection) + + +def test_sensitivity_analysis_plot_no_fallback_to_self_idata( + mock_idata_with_sensitivity, +): + """Test _sensitivity_analysis_plot doesn't use self.idata even if available.""" + suite = MMMPlotSuite(idata=mock_idata_with_sensitivity) + + # Should raise error even though self.idata has sensitivity_analysis + with pytest.raises(TypeError, match=r"missing.*required.*argument.*data"): + suite._sensitivity_analysis_plot() + + +def test_uplift_curve_passes_data_to_helper_no_monkey_patch( + mock_idata_with_uplift_curve, +): + """Test uplift_curve passes data directly, no monkey-patching.""" + suite = MMMPlotSuite(idata=mock_idata_with_uplift_curve) + + # Store original idata reference + original_idata = suite.idata + original_sa_group = original_idata.sensitivity_analysis + + # Call uplift_curve + pc = suite.uplift_curve() + + # Verify no monkey-patching occurred + assert suite.idata is original_idata + assert suite.idata.sensitivity_analysis is original_sa_group + assert isinstance(pc, arviz_plots.PlotCollection)