Skip to content

Commit 33b9087

Browse files
authored
[DOCS] Improve MMM Case Study Notebook (#2116)
* custom saturation plots
* fix scaler
* rm wrong cells
* training optimization plots
* reorder
* bug
* rerun
1 parent 0e861cb commit 33b9087

File tree

4 files changed

+5232
-5821
lines changed

4 files changed

+5232
-5821
lines changed

docs/source/notebooks/mmm/mmm_case_study.ipynb

Lines changed: 2821 additions & 3665 deletions
Large diffs are not rendered by default.

docs/source/notebooks/mmm/mmm_example.ipynb

Lines changed: 2361 additions & 2131 deletions
Large diffs are not rendered by default.

pymc_marketing/mmm/base.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,25 @@ def _get_group_predictive_data(
288288
f"Make sure the model has been fitted and the {group} has been sampled!"
289289
) from e
290290

291-
if original_scale:
291+
has_been_scaled = (
292+
hasattr(self, "_posterior_predictive_samples_original_scale")
293+
and self._posterior_predictive_samples_original_scale
294+
)
295+
296+
if original_scale and not has_been_scaled:
292297
group_data = apply_sklearn_transformer_across_dim(
293298
data=group_data,
294299
func=self.get_target_transformer().inverse_transform,
295300
dim_name="date",
296301
)
302+
303+
if not original_scale and has_been_scaled:
304+
group_data = apply_sklearn_transformer_across_dim(
305+
data=group_data,
306+
func=self.get_target_transformer().transform,
307+
dim_name="date",
308+
)
309+
297310
return group_data
298311

299312
def _get_prior_predictive_data(self, original_scale: bool = False) -> Dataset:
@@ -707,38 +720,44 @@ def get_errors(self, original_scale: bool = False) -> DataArray:
707720
"Make sure the model has been fitted and the posterior_predictive has been sampled!"
708721
) from e
709722

710-
target_array = np.asarray(
711-
transform_1d_array(self.get_target_transformer().transform, self.y)
712-
)
713-
714-
if len(target_array) != len(posterior_predictive_data.date):
715-
raise ValueError(
716-
"The length of the target variable doesn't match the length of the date column. "
717-
"If you are computing out-of-sample errors, please overwrite `self.y` with the "
718-
"corresponding (non-transformed) target variable."
723+
target_array = (
724+
np.asarray(self.y)
725+
if original_scale
726+
else np.asarray(
727+
transform_1d_array(self.get_target_transformer().transform, self.y)
719728
)
729+
)
720730

721731
target = (
722732
pd.Series(target_array, index=self.posterior_predictive.date)
723733
.rename_axis("date")
724734
.to_xarray()
725735
)
726736

727-
errors = (
737+
if original_scale:
738+
# If posterior predictive data is not in original scale, transform it:
739+
if not hasattr(self, "_posterior_predictive_samples_original_scale"):
740+
posterior_predictive_data = apply_sklearn_transformer_across_dim(
741+
data=posterior_predictive_data,
742+
func=self.get_target_transformer().inverse_transform,
743+
dim_name="date",
744+
)
745+
else:
746+
# If posterior predictive data is in original scale, transform it back
747+
# to the scaled space:
748+
if hasattr(self, "_posterior_predictive_samples_original_scale"):
749+
posterior_predictive_data = apply_sklearn_transformer_across_dim(
750+
data=posterior_predictive_data,
751+
func=self.get_target_transformer().transform,
752+
dim_name="date",
753+
)
754+
755+
return (
728756
(target - posterior_predictive_data)[self.output_var]
729757
.rename("errors")
730758
.transpose(..., "date")
731759
)
732760

733-
if original_scale:
734-
return apply_sklearn_transformer_across_dim(
735-
data=errors,
736-
func=self.get_target_transformer().inverse_transform,
737-
dim_name="date",
738-
)
739-
740-
return errors
741-
742761
def plot_errors(
743762
self, original_scale: bool = False, ax: plt.Axes = None, **plt_kwargs: Any
744763
) -> plt.Figure:

pymc_marketing/mmm/mmm.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2389,21 +2389,26 @@ def legend_title_func(channel):
23892389
fig.suptitle("Direct response curves", fontsize=16)
23902390
return fig
23912391

2392-
def _transform_to_original_scale_new(self, samples: DataArray) -> DataArray:
2392+
def _transform_to_original_scale_new(
2393+
self, samples: DataArray, var_names: list[str] | None = None
2394+
) -> DataArray:
23932395
"""Transform samples to original scale using new scaling approach.
23942396
23952397
Parameters
23962398
----------
23972399
samples : DataArray
23982400
Samples in scaled space
2401+
var_names : list[str] | None
2402+
Variable names requested in sampling.
23992403
24002404
Returns
24012405
-------
24022406
DataArray
24032407
Samples in original scale
24042408
"""
2405-
if self.output_var in samples:
2406-
samples[self.output_var] = samples[self.output_var] * self.target_scale
2409+
vars_to_transform = var_names if var_names is not None else [self.output_var]
2410+
for var_name in (v for v in vars_to_transform if v in samples):
2411+
samples[var_name] *= self.target_scale
24072412
return samples
24082413

24092414
def _transform_to_original_scale_legacy(
@@ -2527,12 +2532,13 @@ def sample_posterior_predictive(
25272532

25282533
# Transform to original scale if requested
25292534
if original_scale:
2535+
self._posterior_predictive_samples_original_scale = True
2536+
var_names = sample_posterior_predictive_kwargs.get("var_names")
25302537
if self._has_new_scaling():
25312538
posterior_predictive_samples = self._transform_to_original_scale_new(
2532-
posterior_predictive_samples
2539+
posterior_predictive_samples, var_names
25332540
)
25342541
else:
2535-
var_names = sample_posterior_predictive_kwargs.get("var_names")
25362542
posterior_predictive_samples = self._transform_to_original_scale_legacy(
25372543
posterior_predictive_samples, var_names
25382544
)

0 commit comments

Comments
 (0)