Skip to content

Commit 5e6ffe1

Browse files
authored
Allow for tvp priors in multidimensional.MMM (#1785)
* parse the tvp keys as is * test that the tvp config is used
1 parent 900926b commit 5e6ffe1

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

pymc_marketing/mmm/multidimensional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def __init__(
206206
sampler_config = sampler_config
207207
model_config = parse_model_config(
208208
model_config, # type: ignore
209-
hsgp_kwargs_fields=["intercept_tvp_config", "media_tvp_config"],
209+
non_distributions=["intercept_tvp_config", "media_tvp_config"],
210210
)
211211

212212
if model_config is not None:

tests/mmm/test_multidimensional.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1412,3 +1412,54 @@ def test_arbitrary_date_column_with_control_variables(
14121412

14131413
idata = mmm_controls.fit(X_with_controls, y, draws=50, tune=25, chains=1)
14141414
assert isinstance(idata, az.InferenceData)
1415+
1416+
1417+
@pytest.mark.parametrize(
1418+
"model_config, expected_config, expected_rv",
1419+
[
1420+
pytest.param(
1421+
{"intercept_tvp_config": {"ls_lower": 0.1, "ls_upper": None}},
1422+
None,
1423+
dict(name="intercept_latent_process_raw_ls_raw", kind="WeibullBetaRV"),
1424+
id="weibull",
1425+
),
1426+
pytest.param(
1427+
{"intercept_tvp_config": {"ls_lower": 1, "ls_upper": 10}},
1428+
None,
1429+
dict(name="intercept_latent_process_raw_ls", kind="InvGammaRV"),
1430+
id="inversegamma",
1431+
),
1432+
],
1433+
)
1434+
def test_specify_time_varying_configuration(
1435+
single_dim_data,
1436+
model_config,
1437+
expected_config,
1438+
expected_rv,
1439+
) -> None:
1440+
X, y = single_dim_data
1441+
expected_config = expected_config or model_config
1442+
1443+
mmm = MMM(
1444+
date_column="date",
1445+
target_column="target",
1446+
channel_columns=["channel_1", "channel_2"],
1447+
control_columns=["control_1", "control_2"],
1448+
adstock=GeometricAdstock(l_max=2),
1449+
saturation=LogisticSaturation(),
1450+
model_config=model_config,
1451+
time_varying_intercept=True,
1452+
)
1453+
1454+
assert isinstance(mmm.model_config["intercept_tvp_config"], dict)
1455+
assert (
1456+
mmm.model_config["intercept_tvp_config"]
1457+
== expected_config["intercept_tvp_config"]
1458+
)
1459+
1460+
mmm.build_model(X, y)
1461+
1462+
assert (
1463+
mmm.model[expected_rv["name"]].owner.op.__class__.__name__
1464+
== expected_rv["kind"]
1465+
)

0 commit comments

Comments
 (0)