Skip to content

Commit 0a165fd

Browse files
authored
Allow for media scaling settings (#1573)
* specify the channel scaling * implement the channel scaling specification * add test for the scaling * add additional checks * set nans to zero * add nans to the test
1 parent 19f0e42 commit 0a165fd

File tree

3 files changed

+112
-20
lines changed

3 files changed

+112
-20
lines changed

pymc_marketing/mmm/multidimensional.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import json
1919
import warnings
20+
from copy import deepcopy
2021
from typing import Any, Literal, Protocol
2122

2223
import arviz as az
@@ -267,17 +268,30 @@ def __init__(
267268
self.dims = dims
268269

269270
if isinstance(scaling, dict):
271+
scaling = deepcopy(scaling)
272+
273+
if "channel" not in scaling:
274+
scaling["channel"] = VariableScaling(method="max", dims=self.dims)
275+
if "target" not in scaling:
276+
scaling["target"] = VariableScaling(method="max", dims=self.dims)
277+
270278
scaling = Scaling(**scaling)
271279

272280
self.scaling: Scaling = scaling or Scaling(
273-
target=VariableScaling(method="max", dims=self.dims)
281+
target=VariableScaling(method="max", dims=self.dims),
282+
channel=VariableScaling(method="max", dims=self.dims),
274283
)
275284

276285
if set(self.scaling.target.dims).difference([*self.dims, "date"]):
277286
raise ValueError(
278287
f"Target scaling dims {self.scaling.target.dims} must contain {self.dims} and 'date'"
279288
)
280289

290+
if set(self.scaling.channel.dims).difference([*self.dims, "channel", "date"]):
291+
raise ValueError(
292+
f"Channel scaling dims {self.scaling.channel.dims} must contain {self.dims}, 'channel', and 'date'"
293+
)
294+
281295
model_config = model_config if model_config is not None else {}
282296
sampler_config = sampler_config
283297
model_config = parse_model_config(
@@ -864,11 +878,21 @@ def forward_pass(
864878

865879
def _compute_scales(self) -> None:
866880
"""Compute and save scaling factors for channels and target."""
867-
method = getattr(self.xarray_dataset, self.scaling.target.method)
868-
self.scalers = method(dim=("date", *self.dims))
869-
self.scalers["_target"] = method(dim=("date", *self.scaling.target.dims))[
870-
"_target"
871-
]
881+
self.scalers = xr.Dataset()
882+
883+
channel_method = getattr(
884+
self.xarray_dataset["_channel"],
885+
self.scaling.channel.method,
886+
)
887+
self.scalers["_channel"] = channel_method(
888+
dim=("date", *self.scaling.channel.dims)
889+
)
890+
891+
target_method = getattr(
892+
self.xarray_dataset["_target"],
893+
self.scaling.target.method,
894+
)
895+
self.scalers["_target"] = target_method(dim=("date", *self.scaling.target.dims))
872896

873897
def get_scales_as_xarray(self) -> dict[str, xr.DataArray]:
874898
"""Return the saved scaling factors as xarray DataArrays.
@@ -1023,7 +1047,7 @@ def build_model(
10231047
_channel_scale = pm.Data(
10241048
"channel_scale",
10251049
self.scalers._channel.values,
1026-
dims="channel",
1050+
dims=self.scalers._channel.dims,
10271051
)
10281052
_target_scale = pm.Data(
10291053
"target_scale",
@@ -1048,7 +1072,12 @@ def build_model(
10481072
)
10491073

10501074
# Scale `channel_data` and `target`
1051-
channel_data_ = _channel_data / _channel_scale
1075+
channel_dim_handler = create_dim_handler(("date", *self.dims, "channel"))
1076+
channel_data_ = _channel_data / channel_dim_handler(
1077+
_channel_scale,
1078+
self.scalers._channel.dims,
1079+
)
1080+
channel_data_ = pt.switch(pt.isnan(channel_data_), 0.0, channel_data_)
10521081
channel_data_.name = "channel_data_scaled"
10531082
channel_data_.dims = ("date", *self.dims, "channel")
10541083

pymc_marketing/mmm/scaling.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,7 @@ class Scaling(BaseModel):
7171
...,
7272
description="The scaling for the target variable.",
7373
)
74+
channel: VariableScaling = Field(
75+
...,
76+
description="The scaling for the channel variable.",
77+
)

tests/mmm/test_multidimensional.py

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
MMM,
2929
create_event_mu_effect,
3030
)
31-
from pymc_marketing.mmm.scaling import VariableScaling
31+
from pymc_marketing.mmm.scaling import Scaling, VariableScaling
3232
from pymc_marketing.prior import Prior
3333

3434

@@ -123,12 +123,14 @@ def single_dim_data():
123123
# Generate random channel data
124124
channel_1 = np.random.randint(100, 500, size=len(date_range))
125125
channel_2 = np.random.randint(100, 500, size=len(date_range))
126+
channel_3 = np.nan
126127

127128
df = pd.DataFrame(
128129
{
129130
"date": date_range,
130131
"channel_1": channel_1,
131132
"channel_2": channel_2,
133+
"channel_3": channel_3,
132134
}
133135
)
134136
# Target is sum of channels with noise
@@ -137,7 +139,7 @@ def single_dim_data():
137139
+ df["channel_2"]
138140
+ np.random.randint(100, 300, size=len(date_range))
139141
)
140-
X = df[["date", "channel_1", "channel_2"]].copy()
142+
X = df[["date", "channel_1", "channel_2", "channel_3"]].copy()
141143

142144
return X, df.set_index(["date"])["target"].copy()
143145

@@ -161,14 +163,16 @@ def multi_dim_data():
161163
for date in date_range:
162164
channel_1 = np.random.randint(100, 500)
163165
channel_2 = np.random.randint(100, 500)
166+
channel_3 = np.nan
164167
target = channel_1 + channel_2 + np.random.randint(50, 150)
165-
records.append((date, country, channel_1, channel_2, target))
168+
records.append((date, country, channel_1, channel_2, channel_3, target))
166169

167170
df = pd.DataFrame(
168-
records, columns=["date", "country", "channel_1", "channel_2", "target"]
171+
records,
172+
columns=["date", "country", "channel_1", "channel_2", "channel_3", "target"],
169173
)
170174

171-
X = df[["date", "country", "channel_1", "channel_2"]].copy()
175+
X = df[["date", "country", "channel_1", "channel_2", "channel_3"]].copy()
172176

173177
return X, df["target"].copy()
174178

@@ -208,7 +212,7 @@ def test_fit(
208212
mmm = MMM(
209213
date_column="date",
210214
target_column="target",
211-
channel_columns=["channel_1", "channel_2"],
215+
channel_columns=["channel_1", "channel_2", "channel_3"],
212216
dims=dims,
213217
adstock=adstock,
214218
saturation=saturation,
@@ -296,7 +300,7 @@ def test_sample_posterior_predictive_new_data(single_dim_data, mock_pymc_sample)
296300
mmm = MMM(
297301
date_column="date",
298302
target_column="target",
299-
channel_columns=["channel_1", "channel_2"],
303+
channel_columns=["channel_1", "channel_2", "channel_3"],
300304
adstock=adstock,
301305
saturation=saturation,
302306
)
@@ -307,6 +311,11 @@ def test_sample_posterior_predictive_new_data(single_dim_data, mock_pymc_sample)
307311

308312
mmm.sample_posterior_predictive(X_train, extend_idata=True, random_seed=42)
309313

314+
def no_null_values(ds):
315+
return ds.y.isnull().mean()
316+
317+
np.testing.assert_allclose(no_null_values(mmm.idata.posterior_predictive), 0)
318+
310319
# Sample posterior predictive on new data
311320
out_of_sample_idata = mmm.sample_posterior_predictive(
312321
X_new, extend_idata=False, random_seed=42
@@ -318,6 +327,8 @@ def test_sample_posterior_predictive_new_data(single_dim_data, mock_pymc_sample)
318327
"there should be a 'posterior_predictive' group in the inference data."
319328
)
320329

330+
np.testing.assert_allclose(no_null_values(out_of_sample_idata), 0)
331+
321332
# Check the shape of that group. We expect the new date dimension to match X_new length
322333
# plus no addition if we didn't set include_last_observations (which is False by default).
323334
assert "date" in out_of_sample_idata.dims, (
@@ -349,7 +360,7 @@ def test_sample_posterior_predictive_same_data(single_dim_data, mock_pymc_sample
349360
mmm = MMM(
350361
date_column="date",
351362
target_column="target",
352-
channel_columns=["channel_1", "channel_2"],
363+
channel_columns=["channel_1", "channel_2", "channel_3"],
353364
adstock=adstock,
354365
saturation=saturation,
355366
)
@@ -587,7 +598,7 @@ def test_check_for_incompatible_dims(adstock, saturation, dims) -> None:
587598
kwargs = dict(
588599
date_column="date",
589600
target_column="target",
590-
channel_columns=["channel_1", "channel_2"],
601+
channel_columns=["channel_1", "channel_2", "channel_3"],
591602
)
592603
with pytest.raises(ValueError):
593604
MMM(
@@ -608,7 +619,7 @@ def test_different_target_scaling(method, multi_dim_data, mock_pymc_sample) -> N
608619
scaling=scaling,
609620
date_column="date",
610621
target_column="target",
611-
channel_columns=["channel_1", "channel_2"],
622+
channel_columns=["channel_1", "channel_2", "channel_3"],
612623
dims=("country",),
613624
)
614625
assert mmm.scaling.target == VariableScaling(method=method, dims=())
@@ -645,7 +656,7 @@ def test_target_scaling_raises() -> None:
645656
scaling=scaling,
646657
date_column="date",
647658
target_column="target",
648-
channel_columns=["channel_1", "channel_2"],
659+
channel_columns=["channel_1", "channel_2", "channel_3"],
649660
)
650661

651662

@@ -664,7 +675,7 @@ def test_target_scaling_and_contributions(
664675
scaling=scaling,
665676
date_column="date",
666677
target_column="target",
667-
channel_columns=["channel_1", "channel_2"],
678+
channel_columns=["channel_1", "channel_2", "channel_3"],
668679
dims=("country",),
669680
)
670681

@@ -680,3 +691,51 @@ def test_target_scaling_and_contributions(
680691
mmm.fit(X, y)
681692
except Exception as e:
682693
pytest.fail(f"Unexpected error: {e}")
694+
695+
696+
@pytest.mark.parametrize(
697+
"dims, expected_dims",
698+
[
699+
((), ("country", "channel")),
700+
(("country",), ("channel",)),
701+
(("channel",), ("country",)),
702+
],
703+
ids=["country-channel", "country", "channel"],
704+
)
705+
def test_channel_scaling(multi_dim_data, dims, expected_dims, mock_pymc_sample) -> None:
706+
X, y = multi_dim_data
707+
708+
scaling = {"channel": {"method": "mean", "dims": dims}}
709+
mmm = MMM(
710+
adstock=GeometricAdstock(l_max=2),
711+
saturation=LogisticSaturation(),
712+
scaling=scaling,
713+
date_column="date",
714+
target_column="target",
715+
channel_columns=["channel_1", "channel_2", "channel_3"],
716+
dims=("country",),
717+
)
718+
719+
mmm.fit(X, y)
720+
721+
assert mmm.scalers._channel.dims == expected_dims
722+
723+
724+
def test_scaling_dict_doesnt_mutate() -> None:
725+
scaling = {}
726+
dims = ("country",)
727+
mmm = MMM(
728+
adstock=GeometricAdstock(l_max=2),
729+
saturation=LogisticSaturation(),
730+
scaling=scaling,
731+
date_column="date",
732+
target_column="target",
733+
channel_columns=["channel_1", "channel_2", "channel_3"],
734+
dims=dims,
735+
)
736+
737+
assert scaling == {}
738+
assert mmm.scaling == Scaling(
739+
target=VariableScaling(method="max", dims=dims),
740+
channel=VariableScaling(method="max", dims=dims),
741+
)

0 commit comments

Comments
 (0)