2828 MMM ,
2929 create_event_mu_effect ,
3030)
31- from pymc_marketing .mmm .scaling import VariableScaling
31+ from pymc_marketing .mmm .scaling import Scaling , VariableScaling
3232from pymc_marketing .prior import Prior
3333
3434
@@ -123,12 +123,14 @@ def single_dim_data():
123123 # Generate random channel data
124124 channel_1 = np .random .randint (100 , 500 , size = len (date_range ))
125125 channel_2 = np .random .randint (100 , 500 , size = len (date_range ))
126+ channel_3 = np .nan
126127
127128 df = pd .DataFrame (
128129 {
129130 "date" : date_range ,
130131 "channel_1" : channel_1 ,
131132 "channel_2" : channel_2 ,
133+ "channel_3" : channel_3 ,
132134 }
133135 )
134136 # Target is sum of channels with noise
@@ -137,7 +139,7 @@ def single_dim_data():
137139 + df ["channel_2" ]
138140 + np .random .randint (100 , 300 , size = len (date_range ))
139141 )
140- X = df [["date" , "channel_1" , "channel_2" ]].copy ()
142+ X = df [["date" , "channel_1" , "channel_2" , "channel_3" ]].copy ()
141143
142144 return X , df .set_index (["date" ])["target" ].copy ()
143145
@@ -161,14 +163,16 @@ def multi_dim_data():
161163 for date in date_range :
162164 channel_1 = np .random .randint (100 , 500 )
163165 channel_2 = np .random .randint (100 , 500 )
166+ channel_3 = np .nan
164167 target = channel_1 + channel_2 + np .random .randint (50 , 150 )
165- records .append ((date , country , channel_1 , channel_2 , target ))
168+ records .append ((date , country , channel_1 , channel_2 , channel_3 , target ))
166169
167170 df = pd .DataFrame (
168- records , columns = ["date" , "country" , "channel_1" , "channel_2" , "target" ]
171+ records ,
172+ columns = ["date" , "country" , "channel_1" , "channel_2" , "channel_3" , "target" ],
169173 )
170174
171- X = df [["date" , "country" , "channel_1" , "channel_2" ]].copy ()
175+ X = df [["date" , "country" , "channel_1" , "channel_2" , "channel_3" ]].copy ()
172176
173177 return X , df ["target" ].copy ()
174178
@@ -208,7 +212,7 @@ def test_fit(
208212 mmm = MMM (
209213 date_column = "date" ,
210214 target_column = "target" ,
211- channel_columns = ["channel_1" , "channel_2" ],
215+ channel_columns = ["channel_1" , "channel_2" , "channel_3" ],
212216 dims = dims ,
213217 adstock = adstock ,
214218 saturation = saturation ,
@@ -296,7 +300,7 @@ def test_sample_posterior_predictive_new_data(single_dim_data, mock_pymc_sample)
296300 mmm = MMM (
297301 date_column = "date" ,
298302 target_column = "target" ,
299- channel_columns = ["channel_1" , "channel_2" ],
303+ channel_columns = ["channel_1" , "channel_2" , "channel_3" ],
300304 adstock = adstock ,
301305 saturation = saturation ,
302306 )
@@ -307,6 +311,11 @@ def test_sample_posterior_predictive_new_data(single_dim_data, mock_pymc_sample)
307311
308312 mmm .sample_posterior_predictive (X_train , extend_idata = True , random_seed = 42 )
309313
314+ def no_null_values (ds ):
315+ return ds .y .isnull ().mean ()
316+
317+ np .testing .assert_allclose (no_null_values (mmm .idata .posterior_predictive ), 0 )
318+
310319 # Sample posterior predictive on new data
311320 out_of_sample_idata = mmm .sample_posterior_predictive (
312321 X_new , extend_idata = False , random_seed = 42
@@ -318,6 +327,8 @@ def test_sample_posterior_predictive_new_data(single_dim_data, mock_pymc_sample)
318327 "there should be a 'posterior_predictive' group in the inference data."
319328 )
320329
330+ np .testing .assert_allclose (no_null_values (out_of_sample_idata ), 0 )
331+
321332 # Check the shape of that group. We expect the new date dimension to match X_new length
322333 # plus no addition if we didn't set include_last_observations (which is False by default).
323334 assert "date" in out_of_sample_idata .dims , (
@@ -349,7 +360,7 @@ def test_sample_posterior_predictive_same_data(single_dim_data, mock_pymc_sample
349360 mmm = MMM (
350361 date_column = "date" ,
351362 target_column = "target" ,
352- channel_columns = ["channel_1" , "channel_2" ],
363+ channel_columns = ["channel_1" , "channel_2" , "channel_3" ],
353364 adstock = adstock ,
354365 saturation = saturation ,
355366 )
@@ -587,7 +598,7 @@ def test_check_for_incompatible_dims(adstock, saturation, dims) -> None:
587598 kwargs = dict (
588599 date_column = "date" ,
589600 target_column = "target" ,
590- channel_columns = ["channel_1" , "channel_2" ],
601+ channel_columns = ["channel_1" , "channel_2" , "channel_3" ],
591602 )
592603 with pytest .raises (ValueError ):
593604 MMM (
@@ -608,7 +619,7 @@ def test_different_target_scaling(method, multi_dim_data, mock_pymc_sample) -> N
608619 scaling = scaling ,
609620 date_column = "date" ,
610621 target_column = "target" ,
611- channel_columns = ["channel_1" , "channel_2" ],
622+ channel_columns = ["channel_1" , "channel_2" , "channel_3" ],
612623 dims = ("country" ,),
613624 )
614625 assert mmm .scaling .target == VariableScaling (method = method , dims = ())
@@ -645,7 +656,7 @@ def test_target_scaling_raises() -> None:
645656 scaling = scaling ,
646657 date_column = "date" ,
647658 target_column = "target" ,
648- channel_columns = ["channel_1" , "channel_2" ],
659+ channel_columns = ["channel_1" , "channel_2" , "channel_3" ],
649660 )
650661
651662
@@ -664,7 +675,7 @@ def test_target_scaling_and_contributions(
664675 scaling = scaling ,
665676 date_column = "date" ,
666677 target_column = "target" ,
667- channel_columns = ["channel_1" , "channel_2" ],
678+ channel_columns = ["channel_1" , "channel_2" , "channel_3" ],
668679 dims = ("country" ,),
669680 )
670681
@@ -680,3 +691,51 @@ def test_target_scaling_and_contributions(
680691 mmm .fit (X , y )
681692 except Exception as e :
682693 pytest .fail (f"Unexpected error: { e } " )
694+
695+
696+ @pytest .mark .parametrize (
697+ "dims, expected_dims" ,
698+ [
699+ ((), ("country" , "channel" )),
700+ (("country" ,), ("channel" ,)),
701+ (("channel" ,), ("country" ,)),
702+ ],
703+ ids = ["country-channel" , "country" , "channel" ],
704+ )
705+ def test_channel_scaling (multi_dim_data , dims , expected_dims , mock_pymc_sample ) -> None :
706+ X , y = multi_dim_data
707+
708+ scaling = {"channel" : {"method" : "mean" , "dims" : dims }}
709+ mmm = MMM (
710+ adstock = GeometricAdstock (l_max = 2 ),
711+ saturation = LogisticSaturation (),
712+ scaling = scaling ,
713+ date_column = "date" ,
714+ target_column = "target" ,
715+ channel_columns = ["channel_1" , "channel_2" , "channel_3" ],
716+ dims = ("country" ,),
717+ )
718+
719+ mmm .fit (X , y )
720+
721+ assert mmm .scalers ._channel .dims == expected_dims
722+
723+
724+ def test_scaling_dict_doesnt_mutate () -> None :
725+ scaling = {}
726+ dims = ("country" ,)
727+ mmm = MMM (
728+ adstock = GeometricAdstock (l_max = 2 ),
729+ saturation = LogisticSaturation (),
730+ scaling = scaling ,
731+ date_column = "date" ,
732+ target_column = "target" ,
733+ channel_columns = ["channel_1" , "channel_2" , "channel_3" ],
734+ dims = dims ,
735+ )
736+
737+ assert scaling == {}
738+ assert mmm .scaling == Scaling (
739+ target = VariableScaling (method = "max" , dims = dims ),
740+ channel = VariableScaling (method = "max" , dims = dims ),
741+ )
0 commit comments