Skip to content

Commit 1a39424

Browse files
cetagostiniCopilot
andauthored
Preparing plots and notebook for old API deprecation (#2055)
* Preparing for old API deprecation * Changes in test. * Update pymc_marketing/model_builder.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update test_plot.py * Refactor prior_predictive to accept single variable name Changed the prior_predictive method in MMMPlotSuite to accept a single variable name (str) instead of a list of variable names. Updated related logic and documentation to reflect this change for consistency and simplicity. Also added posterior_predictive_constant_data to idata groups in multidimensional.py. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 4983c1e commit 1a39424

File tree

11 files changed

+14791
-116
lines changed

11 files changed

+14791
-116
lines changed

docs/source/notebooks/mmm/dev/mmm_example_new.ipynb

Lines changed: 12613 additions & 0 deletions
Large diffs are not rendered by default.

pymc_marketing/customer_choice/mnl_logit.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,13 @@ class MNLogit(RegressionModelBuilder):
8282
8383
Example `utility_equations` list:
8484
85-
>>> utility_equations = [
86-
... "alt_1 ~ X1_alt1 + X2_alt1 | income",
87-
... "alt_2 ~ X1_alt2 + X2_alt2 | income",
88-
... "alt_3 ~ X1_alt3 + X2_alt3 | income",
89-
... ]
85+
.. code-block:: python
86+
87+
utility_equations = [
88+
"alt_1 ~ X1_alt1 + X2_alt1 | income",
89+
"alt_2 ~ X1_alt2 + X2_alt2 | income",
90+
"alt_3 ~ X1_alt3 + X2_alt3 | income",
91+
]
9092
9193
"""
9294

pymc_marketing/customer_choice/nested_logit.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,22 @@ class NestedLogit(RegressionModelBuilder):
9191
9292
Example `utility_equations` list:
9393
94-
>>> utility_equations = [
95-
... "alt_1 ~ X1_alt1 + X2_alt1 | income",
96-
... "alt_2 ~ X1_alt2 + X2_alt2 | income",
97-
... "alt_3 ~ X1_alt3 + X2_alt3 | income",
98-
... ]
94+
.. code-block:: python
95+
96+
utility_equations = [
97+
"alt_1 ~ X1_alt1 + X2_alt1 | income",
98+
"alt_2 ~ X1_alt2 + X2_alt2 | income",
99+
"alt_3 ~ X1_alt3 + X2_alt3 | income",
100+
]
99101
100102
Example nesting structure:
101103
102-
>>> nesting_structure = {
103-
... "Nest1": ["alt1"],
104-
... "Nest2": {"Nest2_1": ["alt_2", "alt_3"], "Nest_2_2": ["alt_4", "alt_5"]},
105-
... }
104+
.. code-block:: python
105+
106+
nesting_structure = {
107+
"Nest1": ["alt1"],
108+
"Nest2": {"Nest2_1": ["alt_2", "alt_3"], "Nest_2_2": ["alt_4", "alt_5"]},
109+
}
106110
107111
"""
108112

pymc_marketing/mmm/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,10 @@ def preprocess(
243243
244244
Example
245245
-------
246-
>>> data = pd.DataFrame({"x1": [1, 2, 3], "y": [4, 5, 6]})
247-
>>> self.preprocess("X", data)
246+
.. code-block:: python
247+
248+
data = pd.DataFrame({"x1": [1, 2, 3], "y": [4, 5, 6]})
249+
self.preprocess("X", data)
248250
249251
"""
250252
data_cp = data.copy()

pymc_marketing/mmm/mmm.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,28 +2121,26 @@ def format_recovered_transformation_parameters(
21212121
21222122
Example
21232123
-------
2124-
>>> self.format_recovered_transformation_parameters(quantile=0.5)
2125-
>>> Output:
2126-
{
2127-
'x1': {
2128-
'saturation_params': {
2129-
'lam': 2.4761893929757077,
2130-
'beta': 0.360226791880304
2124+
.. code-block:: python
2125+
2126+
self.format_recovered_transformation_parameters(quantile=0.5)
2127+
# Output:
2128+
{
2129+
"x1": {
2130+
"saturation_params": {
2131+
"lam": 2.4761893929757077,
2132+
"beta": 0.360226791880304,
2133+
},
2134+
"adstock_params": {"alpha": 0.39910387900504796},
21312135
},
2132-
'adstock_params': {
2133-
'alpha': 0.39910387900504796
2134-
}
2135-
},
2136-
'x2': {
2137-
'saturation_params': {
2138-
'lam': 2.6485978655163436,
2139-
'beta': 0.2399381337197204
2136+
"x2": {
2137+
"saturation_params": {
2138+
"lam": 2.6485978655163436,
2139+
"beta": 0.2399381337197204,
2140+
},
2141+
"adstock_params": {"alpha": 0.18859423763437405},
21402142
},
2141-
'adstock_params': {
2142-
'alpha': 0.18859423763437405
2143-
}
21442143
}
2145-
}
21462144
21472145
"""
21482146
# Retrieve channel names

pymc_marketing/mmm/multidimensional.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,11 +1013,14 @@ def forward_pass(
10131013
10141014
Examples
10151015
--------
1016-
>>> mmm = MMM(
1017-
date_column="date_week",
1018-
channel_columns=["channel_1", "channel_2"],
1019-
target_column="target",
1020-
)
1016+
.. code-block:: python
1017+
1018+
mmm = MMM(
1019+
date_column="date_week",
1020+
channel_columns=["channel_1", "channel_2"],
1021+
target_column="target",
1022+
)
1023+
10211024
"""
10221025
first, second = (
10231026
(self.adstock, self.saturation)
@@ -1055,13 +1058,16 @@ def get_scales_as_xarray(self) -> dict[str, xr.DataArray]:
10551058
10561059
Examples
10571060
--------
1058-
>>> mmm = MMM(
1059-
date_column="date_week",
1060-
channel_columns=["channel_1", "channel_2"],
1061-
target_column="target",
1062-
)
1063-
>>> mmm.build_model(X, y)
1064-
>>> mmm.get_scales_as_xarray()
1061+
.. code-block:: python
1062+
1063+
mmm = MMM(
1064+
date_column="date_week",
1065+
channel_columns=["channel_1", "channel_2"],
1066+
target_column="target",
1067+
)
1068+
mmm.build_model(X, y)
1069+
mmm.get_scales_as_xarray()
1070+
10651071
"""
10661072
if not hasattr(self, "scalers"):
10671073
raise ValueError(
@@ -1100,9 +1106,12 @@ def add_original_scale_contribution_variable(self, var: list[str]) -> None:
11001106
11011107
Examples
11021108
--------
1103-
>>> model.add_original_scale_contribution_variable(
1104-
>>> var=["channel_contribution", "total_media_contribution", "y"]
1105-
>>> )
1109+
.. code-block:: python
1110+
1111+
model.add_original_scale_contribution_variable(
1112+
var=["channel_contribution", "total_media_contribution", "y"]
1113+
)
1114+
11061115
"""
11071116
self._validate_model_was_built()
11081117
target_dims = self.scalers._target.dims
@@ -1695,8 +1704,11 @@ def sample_posterior_predictive(
16951704
self.idata, **sample_posterior_predictive_kwargs
16961705
)
16971706

1698-
if extend_idata:
1699-
self.idata.extend(post_pred, join="right") # type: ignore
1707+
if extend_idata and self.idata is not None:
1708+
self.idata.add_groups(
1709+
posterior_predictive=post_pred.posterior_predictive,
1710+
posterior_predictive_constant_data=post_pred.constant_data,
1711+
) # type: ignore
17001712

17011713
group = "posterior_predictive"
17021714
posterior_predictive_samples = az.extract(post_pred, group, combined=combined)
@@ -1723,11 +1735,14 @@ def sensitivity(self) -> SensitivityAnalysis:
17231735
17241736
Examples
17251737
--------
1726-
>>> mmm.sensitivity.run_sweep(
1727-
... var_names=["channel_1", "channel_2"],
1728-
... sweep_values=np.linspace(0.5, 2.0, 10),
1729-
... sweep_type="multiplicative",
1730-
... )
1738+
.. code-block:: python
1739+
1740+
mmm.sensitivity.run_sweep(
1741+
var_names=["channel_1", "channel_2"],
1742+
sweep_values=np.linspace(0.5, 2.0, 10),
1743+
sweep_type="multiplicative",
1744+
)
1745+
17311746
"""
17321747
# Provide the underlying PyMC model, the model's inference data, and dims
17331748
return SensitivityAnalysis(
@@ -2145,7 +2160,10 @@ def create_fit_data(
21452160
21462161
Examples
21472162
--------
2148-
>>> ds = mmm.create_fit_data(X, y)
2163+
.. code-block:: python
2164+
2165+
ds = mmm.create_fit_data(X, y)
2166+
21492167
"""
21502168
# --- Coerce X to DataFrame ---
21512169
if isinstance(X, xr.Dataset):
@@ -2234,7 +2252,10 @@ def build_from_idata(self, idata: az.InferenceData) -> None:
22342252
22352253
Examples
22362254
--------
2237-
>>> mmm.build_from_idata(idata)
2255+
.. code-block:: python
2256+
2257+
mmm.build_from_idata(idata)
2258+
22382259
"""
22392260
dataset = idata.fit_data.to_dataframe()
22402261

0 commit comments

Comments
 (0)