Skip to content

Commit 893a7f2

Browse files
authored
Automatic casting based on model types for posterior predictive (#1781)
* Automatic casting based on model types for posterior predictive * Adding tests on the autocast. * pre commit
1 parent 578ffb1 commit 893a7f2

File tree

2 files changed

+231
-7
lines changed

2 files changed

+231
-7
lines changed

pymc_marketing/mmm/multidimensional.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,18 +1306,26 @@ def _set_xarray_data(
13061306
"""
13071307
model = cm(self.model) if clone_model else self.model
13081308

1309-
data = {
1310-
"channel_data": dataset_xarray._channel.transpose(
1311-
"date", *self.dims, "channel"
1312-
)
1313-
}
1309+
# Get channel data and handle dtype conversion
1310+
channel_values = dataset_xarray._channel.transpose(
1311+
"date", *self.dims, "channel"
1312+
)
1313+
if "channel_data" in model.named_vars:
1314+
original_dtype = model.named_vars["channel_data"].type.dtype
1315+
channel_values = channel_values.astype(original_dtype)
1316+
1317+
data = {"channel_data": channel_values}
13141318
coords = self.model.coords.copy()
13151319
coords["date"] = dataset_xarray["date"].to_numpy()
13161320

13171321
if "_control" in dataset_xarray:
1318-
data["control_data"] = dataset_xarray["_control"].transpose(
1322+
control_values = dataset_xarray["_control"].transpose(
13191323
"date", *self.dims, "control"
13201324
)
1325+
if "control_data" in model.named_vars:
1326+
original_dtype = model.named_vars["control_data"].type.dtype
1327+
control_values = control_values.astype(original_dtype)
1328+
data["control_data"] = control_values
13211329
coords["control"] = dataset_xarray["control"].to_numpy()
13221330
if self.yearly_seasonality is not None:
13231331
data["dayofyear"] = dataset_xarray["date"].dt.dayofyear.to_numpy()
@@ -1330,7 +1338,14 @@ def _set_xarray_data(
13301338
)
13311339

13321340
if "_target" in dataset_xarray:
1333-
data["target_data"] = dataset_xarray._target.transpose("date", *self.dims)
1341+
target_values = dataset_xarray._target.transpose("date", *self.dims)
1342+
# Get the original dtype from the model's shared variable
1343+
if "target_data" in model.named_vars:
1344+
original_dtype = model.named_vars["target_data"].type.dtype
1345+
# Convert to the original dtype to avoid precision loss errors
1346+
data["target_data"] = target_values.astype(original_dtype)
1347+
else:
1348+
data["target_data"] = target_values
13341349

13351350
self.new_updated_data = data
13361351
self.new_updated_coords = coords

tests/mmm/test_multidimensional.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,215 @@ def test_mmm_linear_trend_different_dimensions_original_scale(
12221222
}
12231223

12241224

1225+
def test_set_xarray_data_preserves_dtypes(multi_dim_data, mock_pymc_sample):
1226+
"""Test that _set_xarray_data preserves the original data types from the model."""
1227+
X, y = multi_dim_data
1228+
1229+
# Build and fit the model
1230+
mmm = MMM(
1231+
adstock=GeometricAdstock(l_max=2),
1232+
saturation=LogisticSaturation(),
1233+
date_column="date",
1234+
target_column="target",
1235+
channel_columns=["channel_1", "channel_2", "channel_3"],
1236+
dims=("country",),
1237+
control_columns=None, # Testing without control columns first
1238+
)
1239+
1240+
mmm.build_model(X, y)
1241+
1242+
# Store original dtypes from the model
1243+
original_channel_dtype = mmm.model.named_vars["channel_data"].type.dtype
1244+
original_target_dtype = mmm.model.named_vars["target_data"].type.dtype
1245+
1246+
# Create new data with different dtypes
1247+
X_new = X.copy()
1248+
# Convert channel columns to float32 (different from typical float64)
1249+
for col in ["channel_1", "channel_2", "channel_3"]:
1250+
X_new[col] = X_new[col].astype(np.float32)
1251+
1252+
# Transform to xarray dataset without target (prediction scenario)
1253+
dataset_xarray = mmm._posterior_predictive_data_transformation(
1254+
X=X_new,
1255+
y=None, # Don't pass y for prediction
1256+
include_last_observations=False,
1257+
)
1258+
1259+
# Verify that the input data has different dtypes
1260+
assert dataset_xarray._channel.dtype == np.float32
1261+
1262+
# Apply _set_xarray_data
1263+
model = mmm._set_xarray_data(dataset_xarray, clone_model=True)
1264+
1265+
# Check that the data in the model has been converted to the original dtypes
1266+
assert model.named_vars["channel_data"].get_value().dtype == original_channel_dtype
1267+
1268+
# Also verify the data shapes are preserved
1269+
assert model.named_vars["channel_data"].get_value().shape == (
1270+
len(X_new[mmm.date_column].unique()),
1271+
len(mmm.xarray_dataset.coords["country"]),
1272+
len(mmm.channel_columns),
1273+
)
1274+
1275+
# Now test with target data - create properly structured y data
1276+
# Combine X and y to create a proper DataFrame structure
1277+
df_with_target = X_new.copy()
1278+
df_with_target["target"] = y.values # Add target column
1279+
1280+
# Convert target to float32 to test dtype conversion
1281+
df_with_target["target"] = df_with_target["target"].astype(np.float32)
1282+
1283+
# Extract y as a properly indexed Series
1284+
y_new = df_with_target.set_index(["date", "country"])["target"]
1285+
1286+
# Transform to xarray dataset with target
1287+
dataset_xarray_with_target = mmm._posterior_predictive_data_transformation(
1288+
X=X_new,
1289+
y=y_new,
1290+
include_last_observations=False,
1291+
)
1292+
1293+
# Verify that the target has different dtype
1294+
assert dataset_xarray_with_target._target.dtype == np.float32
1295+
1296+
# Apply _set_xarray_data with target
1297+
model_with_target = mmm._set_xarray_data(
1298+
dataset_xarray_with_target, clone_model=True
1299+
)
1300+
1301+
# Check that target dtype is preserved
1302+
assert (
1303+
model_with_target.named_vars["target_data"].get_value().dtype
1304+
== original_target_dtype
1305+
)
1306+
assert model_with_target.named_vars["target_data"].get_value().shape == (
1307+
len(X_new[mmm.date_column].unique()),
1308+
len(mmm.xarray_dataset.coords["country"]),
1309+
)
1310+
1311+
1312+
def test_set_xarray_data_with_control_columns_preserves_dtypes(multi_dim_data):
1313+
"""Test that _set_xarray_data preserves dtypes when control columns are present."""
1314+
X, y = multi_dim_data
1315+
1316+
# Add control columns with specific dtypes
1317+
X["control_1"] = np.random.randn(len(X)).astype(np.float64)
1318+
X["control_2"] = np.random.randn(len(X)).astype(np.float64)
1319+
1320+
# Build model with control columns
1321+
mmm = MMM(
1322+
adstock=GeometricAdstock(l_max=2),
1323+
saturation=LogisticSaturation(),
1324+
date_column="date",
1325+
target_column="target",
1326+
channel_columns=["channel_1", "channel_2", "channel_3"],
1327+
dims=("country",),
1328+
control_columns=["control_1", "control_2"],
1329+
)
1330+
1331+
mmm.build_model(X, y)
1332+
1333+
# Store original dtypes
1334+
original_channel_dtype = mmm.model.named_vars["channel_data"].type.dtype
1335+
original_control_dtype = mmm.model.named_vars["control_data"].type.dtype
1336+
original_target_dtype = mmm.model.named_vars["target_data"].type.dtype
1337+
1338+
# Create new data with different dtypes
1339+
X_new = X.copy()
1340+
# Convert all numeric columns to float32
1341+
for col in X_new.select_dtypes(include=[np.number]).columns:
1342+
X_new[col] = X_new[col].astype(np.float32)
1343+
1344+
# First test without target (prediction scenario)
1345+
dataset_xarray = mmm._posterior_predictive_data_transformation(
1346+
X=X_new,
1347+
y=None,
1348+
include_last_observations=False,
1349+
)
1350+
1351+
# Apply _set_xarray_data
1352+
model = mmm._set_xarray_data(dataset_xarray, clone_model=True)
1353+
1354+
# Check that data types are preserved
1355+
assert model.named_vars["channel_data"].get_value().dtype == original_channel_dtype
1356+
assert model.named_vars["control_data"].get_value().dtype == original_control_dtype
1357+
1358+
# Now test with target data - create properly structured y data
1359+
df_with_target = X_new.copy()
1360+
df_with_target["target"] = y.values
1361+
df_with_target["target"] = df_with_target["target"].astype(np.float32)
1362+
1363+
# Extract y as a properly indexed Series
1364+
y_new = df_with_target.set_index(["date", "country"])["target"]
1365+
1366+
# Transform to xarray dataset with target
1367+
dataset_xarray_with_target = mmm._posterior_predictive_data_transformation(
1368+
X=X_new,
1369+
y=y_new,
1370+
include_last_observations=False,
1371+
)
1372+
1373+
# Apply _set_xarray_data with target
1374+
model_with_target = mmm._set_xarray_data(
1375+
dataset_xarray_with_target, clone_model=True
1376+
)
1377+
1378+
# Check that all data types are preserved
1379+
assert (
1380+
model_with_target.named_vars["channel_data"].get_value().dtype
1381+
== original_channel_dtype
1382+
)
1383+
assert (
1384+
model_with_target.named_vars["control_data"].get_value().dtype
1385+
== original_control_dtype
1386+
)
1387+
assert (
1388+
model_with_target.named_vars["target_data"].get_value().dtype
1389+
== original_target_dtype
1390+
)
1391+
1392+
1393+
def test_set_xarray_data_without_target_preserves_dtypes(multi_dim_data):
1394+
"""Test that _set_xarray_data preserves dtypes when target is not provided."""
1395+
X, y = multi_dim_data
1396+
1397+
# Build the model
1398+
mmm = MMM(
1399+
adstock=GeometricAdstock(l_max=2),
1400+
saturation=LogisticSaturation(),
1401+
date_column="date",
1402+
target_column="target",
1403+
channel_columns=["channel_1", "channel_2", "channel_3"],
1404+
dims=("country",),
1405+
)
1406+
1407+
mmm.build_model(X, y)
1408+
1409+
# Store original dtype
1410+
original_channel_dtype = mmm.model.named_vars["channel_data"].type.dtype
1411+
1412+
# Create new data without target
1413+
X_new = X.copy()
1414+
for col in ["channel_1", "channel_2", "channel_3"]:
1415+
X_new[col] = X_new[col].astype(np.float32)
1416+
1417+
# Transform to xarray dataset without y
1418+
dataset_xarray = mmm._posterior_predictive_data_transformation(
1419+
X=X_new,
1420+
y=None, # No target provided
1421+
include_last_observations=False,
1422+
)
1423+
1424+
# Apply _set_xarray_data
1425+
model = mmm._set_xarray_data(dataset_xarray, clone_model=True)
1426+
1427+
# Check that channel data type is preserved
1428+
assert model.named_vars["channel_data"].get_value().dtype == original_channel_dtype
1429+
1430+
# Target data should remain unchanged from the original model
1431+
# (no new target data was provided)
1432+
1433+
12251434
@pytest.mark.parametrize(
12261435
"date_col_name",
12271436
["date_week", "week", "period", "timestamp", "time_period"],

0 commit comments

Comments
 (0)