Skip to content

Commit 19f0e42

Browse files
authored
1570 target scaled and variable need to be broadcasted with non global scaling (#1571)
1 parent 20113f3 commit 19f0e42

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

pymc_marketing/mmm/multidimensional.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -933,10 +933,17 @@ def add_original_scale_contribution_variable(self, var: list[str]) -> None:
933933
with self.model:
934934
for v in var:
935935
self._validate_contribution_variable(v)
936+
dims = self.model.named_vars_to_dims[v]
937+
dim_handler = create_dim_handler(dims)
938+
936939
pm.Deterministic(
937940
name=v + "_original_scale",
938-
var=self.model[v] * self.model["target_scale"],
939-
dims=self.model.named_vars_to_dims[v],
941+
var=self.model[v]
942+
* dim_handler(
943+
self.model["target_scale"],
944+
self.scalers._target.dims,
945+
),
946+
dims=dims,
940947
)
941948

942949
def build_model(

tests/mmm/test_multidimensional.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,3 +647,36 @@ def test_target_scaling_raises() -> None:
647647
target_column="target",
648648
channel_columns=["channel_1", "channel_2"],
649649
)
650+
651+
652+
@pytest.mark.parametrize("dims", [(), ("country",)], ids=["country-level", "global"])
653+
def test_target_scaling_and_contributions(
654+
multi_dim_data,
655+
dims,
656+
mock_pymc_sample,
657+
) -> None:
658+
X, y = multi_dim_data
659+
660+
scaling = {"target": {"method": "mean", "dims": dims}}
661+
mmm = MMM(
662+
adstock=GeometricAdstock(l_max=2),
663+
saturation=LogisticSaturation(),
664+
scaling=scaling,
665+
date_column="date",
666+
target_column="target",
667+
channel_columns=["channel_1", "channel_2"],
668+
dims=("country",),
669+
)
670+
671+
var_names = ["channel_contribution", "intercept_contribution", "y"]
672+
mmm.build_model(X, y)
673+
mmm.add_original_scale_contribution_variable(var=var_names)
674+
675+
for var in var_names:
676+
new_var_name = f"{var}_original_scale"
677+
assert new_var_name in mmm.model.named_vars
678+
679+
try:
680+
mmm.fit(X, y)
681+
except Exception as e:
682+
pytest.fail(f"Unexpected error: {e}")

0 commit comments

Comments
 (0)