@@ -1222,6 +1222,215 @@ def test_mmm_linear_trend_different_dimensions_original_scale(
12221222 }
12231223
12241224
def test_set_xarray_data_preserves_dtypes(multi_dim_data, mock_pymc_sample):
    """Check that ``_set_xarray_data`` hands the model data in its original dtypes.

    An MMM with a ``country`` dimension and no control columns is built (not
    fitted). float32 channel data, and later float32 target data, is pushed
    through ``_posterior_predictive_data_transformation`` and
    ``_set_xarray_data``; the resulting ``channel_data`` / ``target_data``
    values must carry the dtypes recorded right after ``build_model`` and have
    the expected shapes.

    ``mock_pymc_sample`` is requested as a fixture but is not referenced in
    the body.
    """
    X, y = multi_dim_data
    channels = ("channel_1", "channel_2", "channel_3")

    # Build the model only; no sampling happens in this test.
    mmm = MMM(
        adstock=GeometricAdstock(l_max=2),
        saturation=LogisticSaturation(),
        date_column="date",
        target_column="target",
        channel_columns=list(channels),
        dims=("country",),
        control_columns=None,  # control columns are not part of this test
    )
    mmm.build_model(X, y)

    # Dtypes the model variables were created with.
    named_vars = mmm.model.named_vars
    original_channel_dtype = named_vars["channel_data"].type.dtype
    original_target_dtype = named_vars["target_data"].type.dtype

    # New frame whose channel columns are float32.
    X_new = X.astype({name: np.float32 for name in channels})

    # --- Prediction scenario: no target supplied -------------------------
    dataset_xarray = mmm._posterior_predictive_data_transformation(
        X=X_new,
        y=None,
        include_last_observations=False,
    )

    # The transformed input must still be float32, otherwise nothing is tested.
    assert dataset_xarray._channel.dtype == np.float32

    model = mmm._set_xarray_data(dataset_xarray, clone_model=True)
    channel_values = model.named_vars["channel_data"].get_value()

    # dtype is converted back to what the model was built with ...
    assert channel_values.dtype == original_channel_dtype

    # ... and the array is laid out as (date, country, channel).
    n_dates = len(X_new[mmm.date_column].unique())
    n_countries = len(mmm.xarray_dataset.coords["country"])
    assert channel_values.shape == (n_dates, n_countries, len(mmm.channel_columns))

    # --- Scenario with a float32 target ----------------------------------
    # Attach the target to the feature frame so it can be indexed by
    # (date, country), then cast it to float32.
    df_with_target = X_new.assign(target=y.values).astype({"target": np.float32})
    y_new = df_with_target.set_index(["date", "country"])["target"]

    dataset_xarray_with_target = mmm._posterior_predictive_data_transformation(
        X=X_new,
        y=y_new,
        include_last_observations=False,
    )

    # Confirm the target really enters as float32.
    assert dataset_xarray_with_target._target.dtype == np.float32

    model_with_target = mmm._set_xarray_data(
        dataset_xarray_with_target, clone_model=True
    )
    target_values = model_with_target.named_vars["target_data"].get_value()

    assert target_values.dtype == original_target_dtype
    assert target_values.shape == (
        len(X_new[mmm.date_column].unique()),
        len(mmm.xarray_dataset.coords["country"]),
    )
1310+
1311+
def test_set_xarray_data_with_control_columns_preserves_dtypes(multi_dim_data):
    """Test that _set_xarray_data preserves dtypes when control columns are present.

    Two float64 control columns are added to a copy of the fixture data, a
    model is built from it, and then float32 versions of every numeric column
    (and of the target) are passed through
    ``_posterior_predictive_data_transformation`` and ``_set_xarray_data``.
    ``channel_data``, ``control_data`` and ``target_data`` in the returned
    model must carry the dtypes recorded right after ``build_model``.
    """
    X, y = multi_dim_data

    # Work on a copy so the frame handed out by the fixture is never mutated,
    # and use a seeded generator so the control values are reproducible.
    X = X.copy()
    rng = np.random.default_rng(0)
    X["control_1"] = rng.standard_normal(len(X))  # float64 by default
    X["control_2"] = rng.standard_normal(len(X))

    # Build model with control columns
    mmm = MMM(
        adstock=GeometricAdstock(l_max=2),
        saturation=LogisticSaturation(),
        date_column="date",
        target_column="target",
        channel_columns=["channel_1", "channel_2", "channel_3"],
        dims=("country",),
        control_columns=["control_1", "control_2"],
    )

    mmm.build_model(X, y)

    # Store original dtypes
    original_channel_dtype = mmm.model.named_vars["channel_data"].type.dtype
    original_control_dtype = mmm.model.named_vars["control_data"].type.dtype
    original_target_dtype = mmm.model.named_vars["target_data"].type.dtype

    # Create new data with different dtypes: every numeric column -> float32
    X_new = X.copy()
    for col in X_new.select_dtypes(include=[np.number]).columns:
        X_new[col] = X_new[col].astype(np.float32)

    # First test without target (prediction scenario)
    dataset_xarray = mmm._posterior_predictive_data_transformation(
        X=X_new,
        y=None,
        include_last_observations=False,
    )

    # Make sure the input really differs from the model dtype; otherwise the
    # assertions below would pass without any conversion taking place.
    assert dataset_xarray._channel.dtype == np.float32

    # Apply _set_xarray_data
    model = mmm._set_xarray_data(dataset_xarray, clone_model=True)

    # Check that data types are preserved
    assert model.named_vars["channel_data"].get_value().dtype == original_channel_dtype
    assert model.named_vars["control_data"].get_value().dtype == original_control_dtype

    # Now test with target data - create properly structured y data
    df_with_target = X_new.copy()
    df_with_target["target"] = y.values
    df_with_target["target"] = df_with_target["target"].astype(np.float32)

    # Extract y as a Series indexed by (date, country)
    y_new = df_with_target.set_index(["date", "country"])["target"]

    # Transform to xarray dataset with target
    dataset_xarray_with_target = mmm._posterior_predictive_data_transformation(
        X=X_new,
        y=y_new,
        include_last_observations=False,
    )

    # The target must also enter as float32 for the check to be meaningful
    assert dataset_xarray_with_target._target.dtype == np.float32

    # Apply _set_xarray_data with target
    model_with_target = mmm._set_xarray_data(
        dataset_xarray_with_target, clone_model=True
    )

    # Check that all data types are preserved
    assert (
        model_with_target.named_vars["channel_data"].get_value().dtype
        == original_channel_dtype
    )
    assert (
        model_with_target.named_vars["control_data"].get_value().dtype
        == original_control_dtype
    )
    assert (
        model_with_target.named_vars["target_data"].get_value().dtype
        == original_target_dtype
    )
1391+
1392+
def test_set_xarray_data_without_target_preserves_dtypes(multi_dim_data):
    """Test that _set_xarray_data preserves dtypes when target is not provided.

    float32 channel data is transformed with ``y=None`` and applied to a clone
    of the built model. ``channel_data`` must come back in the dtype the model
    was built with, and ``target_data`` must still carry its original dtype.
    """
    X, y = multi_dim_data

    # Build the model
    mmm = MMM(
        adstock=GeometricAdstock(l_max=2),
        saturation=LogisticSaturation(),
        date_column="date",
        target_column="target",
        channel_columns=["channel_1", "channel_2", "channel_3"],
        dims=("country",),
    )

    mmm.build_model(X, y)

    # Store original dtypes
    original_channel_dtype = mmm.model.named_vars["channel_data"].type.dtype
    original_target_dtype = mmm.model.named_vars["target_data"].type.dtype

    # Create new data without target
    X_new = X.copy()
    for col in ["channel_1", "channel_2", "channel_3"]:
        X_new[col] = X_new[col].astype(np.float32)

    # Transform to xarray dataset without y
    dataset_xarray = mmm._posterior_predictive_data_transformation(
        X=X_new,
        y=None,  # No target provided
        include_last_observations=False,
    )

    # Make sure the input really is float32 so a conversion is required
    assert dataset_xarray._channel.dtype == np.float32

    # Apply _set_xarray_data
    model = mmm._set_xarray_data(dataset_xarray, clone_model=True)

    # Check that channel data type is preserved
    assert model.named_vars["channel_data"].get_value().dtype == original_channel_dtype

    # No new target was supplied; only the dtype of target_data is checked
    # here, its values are not compared against the original model.
    assert model.named_vars["target_data"].get_value().dtype == original_target_dtype
1432+
1433+
12251434@pytest .mark .parametrize (
12261435 "date_col_name" ,
12271436 ["date_week" , "week" , "period" , "timestamp" , "time_period" ],
0 commit comments