@@ -1412,3 +1412,54 @@ def test_arbitrary_date_column_with_control_variables(
14121412
14131413 idata = mmm_controls .fit (X_with_controls , y , draws = 50 , tune = 25 , chains = 1 )
14141414 assert isinstance (idata , az .InferenceData )
1415+
1416+
1417+ @pytest .mark .parametrize (
1418+ "model_config, expected_config, expected_rv" ,
1419+ [
1420+ pytest .param (
1421+ {"intercept_tvp_config" : {"ls_lower" : 0.1 , "ls_upper" : None }},
1422+ None ,
1423+ dict (name = "intercept_latent_process_raw_ls_raw" , kind = "WeibullBetaRV" ),
1424+ id = "weibull" ,
1425+ ),
1426+ pytest .param (
1427+ {"intercept_tvp_config" : {"ls_lower" : 1 , "ls_upper" : 10 }},
1428+ None ,
1429+ dict (name = "intercept_latent_process_raw_ls" , kind = "InvGammaRV" ),
1430+ id = "inversegamma" ,
1431+ ),
1432+ ],
1433+ )
1434+ def test_specify_time_varying_configuration (
1435+ single_dim_data ,
1436+ model_config ,
1437+ expected_config ,
1438+ expected_rv ,
1439+ ) -> None :
1440+ X , y = single_dim_data
1441+ expected_config = expected_config or model_config
1442+
1443+ mmm = MMM (
1444+ date_column = "date" ,
1445+ target_column = "target" ,
1446+ channel_columns = ["channel_1" , "channel_2" ],
1447+ control_columns = ["control_1" , "control_2" ],
1448+ adstock = GeometricAdstock (l_max = 2 ),
1449+ saturation = LogisticSaturation (),
1450+ model_config = model_config ,
1451+ time_varying_intercept = True ,
1452+ )
1453+
1454+ assert isinstance (mmm .model_config ["intercept_tvp_config" ], dict )
1455+ assert (
1456+ mmm .model_config ["intercept_tvp_config" ]
1457+ == expected_config ["intercept_tvp_config" ]
1458+ )
1459+
1460+ mmm .build_model (X , y )
1461+
1462+ assert (
1463+ mmm .model [expected_rv ["name" ]].owner .op .__class__ .__name__
1464+ == expected_rv ["kind" ]
1465+ )
0 commit comments