Skip to content

Commit 900926b

Browse files
Bug: Sample posterior predictive raise error if dates overlap (#1778)
* Bug: Sample posterior predictive raise error if dates overlap * Automatic casting based on target type * William feedback --------- Co-authored-by: Will Dean <57733339+williambdean@users.noreply.github.com>
1 parent 2e963b1 commit 900926b

File tree

2 files changed

+171
-1
lines changed

2 files changed

+171
-1
lines changed

pymc_marketing/mmm/multidimensional.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1159,6 +1159,43 @@ def create_deterministic(x: pt.TensorVariable) -> None:
11591159
observed=target_data_scaled,
11601160
)
11611161

1162+
def _validate_date_overlap_with_include_last_observations(
1163+
self, X: pd.DataFrame, include_last_observations: bool
1164+
) -> None:
1165+
"""Validate that include_last_observations is not used with overlapping dates.
1166+
1167+
Parameters
1168+
----------
1169+
X : pd.DataFrame
1170+
The input data for prediction.
1171+
include_last_observations : bool
1172+
Whether to include the last observations of the training data.
1173+
1174+
Raises
1175+
------
1176+
ValueError
1177+
If include_last_observations=True and input dates overlap with training dates.
1178+
"""
1179+
if not include_last_observations:
1180+
return
1181+
1182+
# Get training dates and input dates
1183+
training_dates = pd.to_datetime(self.model_coords["date"])
1184+
input_dates = pd.to_datetime(X[self.date_column].unique())
1185+
1186+
# Check for overlap
1187+
overlapping_dates = set(training_dates).intersection(set(input_dates))
1188+
1189+
if overlapping_dates:
1190+
overlapping_dates_str = ", ".join(
1191+
sorted([str(d.date()) for d in overlapping_dates])
1192+
)
1193+
raise ValueError(
1194+
f"Cannot use include_last_observations=True when input dates overlap with training dates. "
1195+
f"Overlapping dates found: {overlapping_dates_str}. "
1196+
f"Either set include_last_observations=False or use input dates that don't overlap with training data."
1197+
)
1198+
11621199
def _posterior_predictive_data_transformation(
11631200
self,
11641201
X: pd.DataFrame,
@@ -1181,6 +1218,11 @@ def _posterior_predictive_data_transformation(
11811218
xr.Dataset
11821219
The transformed data in xarray format.
11831220
"""
1221+
# Validate that include_last_observations is not used with overlapping dates
1222+
self._validate_date_overlap_with_include_last_observations(
1223+
X, include_last_observations
1224+
)
1225+
11841226
dataarrays = []
11851227
if include_last_observations:
11861228
last_obs = self.xarray_dataset.isel(date=slice(-self.adstock.l_max, None))
@@ -1220,13 +1262,15 @@ def _posterior_predictive_data_transformation(
12201262
)
12211263
else:
12221264
# Return empty xarray with same dimensions as the target but full of zeros
1265+
# Use the same dtype as the existing target data to avoid dtype mismatches
1266+
target_dtype = self.xarray_dataset._target.dtype
12231267
y_xarray = xr.DataArray(
12241268
np.zeros(
12251269
(
12261270
X[self.date_column].nunique(),
12271271
*[len(self.xarray_dataset.coords[dim]) for dim in self.dims],
12281272
),
1229-
dtype=np.int32,
1273+
dtype=target_dtype,
12301274
),
12311275
dims=("date", *self.dims),
12321276
coords={

tests/mmm/test_multidimensional.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,132 @@ def test_sample_posterior_predictive_same_data(single_dim_data, mock_pymc_sample
418418
)
419419

420420

421+
def test_sample_posterior_predictive_same_data_with_include_last_observations(
422+
single_dim_data, mock_pymc_sample
423+
):
424+
"""
425+
Test that using include_last_observations=True with training data (overlapping dates)
426+
raises a ValueError with a clear error message.
427+
"""
428+
X, y = single_dim_data
429+
X_train = X.iloc[:-5]
430+
y_train = y.iloc[:-5]
431+
432+
# Build and fit the model
433+
adstock = GeometricAdstock(l_max=2)
434+
saturation = LogisticSaturation()
435+
436+
mmm = MMM(
437+
date_column="date",
438+
target_column="target",
439+
channel_columns=["channel_1", "channel_2", "channel_3"],
440+
adstock=adstock,
441+
saturation=saturation,
442+
)
443+
444+
mmm.build_model(X_train, y_train)
445+
mmm.fit(X_train, y_train, draws=200, tune=100, chains=1, random_seed=123)
446+
447+
# Try to use include_last_observations=True with the same training data
448+
# This should raise a ValueError
449+
with pytest.raises(
450+
ValueError,
451+
match="Cannot use include_last_observations=True when input dates overlap with training dates",
452+
):
453+
mmm.sample_posterior_predictive(
454+
X_train, # Same training data
455+
include_last_observations=True, # This should trigger the error
456+
extend_idata=False,
457+
random_seed=123,
458+
)
459+
460+
461+
def test_sample_posterior_predictive_partial_overlap_with_include_last_observations(
462+
single_dim_data, mock_pymc_sample
463+
):
464+
"""
465+
Test that even partial date overlap with include_last_observations=True raises ValueError.
466+
"""
467+
X, y = single_dim_data
468+
X_train = X.iloc[:-5]
469+
y_train = y.iloc[:-5]
470+
471+
# Build and fit the model
472+
adstock = GeometricAdstock(l_max=2)
473+
saturation = LogisticSaturation()
474+
475+
mmm = MMM(
476+
date_column="date",
477+
target_column="target",
478+
channel_columns=["channel_1", "channel_2", "channel_3"],
479+
adstock=adstock,
480+
saturation=saturation,
481+
)
482+
483+
mmm.build_model(X_train, y_train)
484+
mmm.fit(X_train, y_train, draws=200, tune=100, chains=1, random_seed=123)
485+
486+
# Create data that partially overlaps with training data
487+
# Take the last 3 training dates + 3 new future dates
488+
overlap_data = X.iloc[-8:-2] # This will include some training dates
489+
490+
# This should raise a ValueError due to partial overlap
491+
with pytest.raises(
492+
ValueError,
493+
match="Cannot use include_last_observations=True when input dates overlap with training dates",
494+
):
495+
mmm.sample_posterior_predictive(
496+
overlap_data,
497+
include_last_observations=True,
498+
extend_idata=False,
499+
random_seed=123,
500+
)
501+
502+
503+
def test_sample_posterior_predictive_no_overlap_with_include_last_observations(
504+
single_dim_data, mock_pymc_sample
505+
):
506+
"""
507+
Test that include_last_observations=True works correctly when there's no date overlap.
508+
"""
509+
X, y = single_dim_data
510+
X_train = X.iloc[:-5]
511+
X_new = X.iloc[-5:] # Non-overlapping future dates
512+
y_train = y.iloc[:-5]
513+
514+
# Build and fit the model
515+
adstock = GeometricAdstock(l_max=2)
516+
saturation = LogisticSaturation()
517+
518+
mmm = MMM(
519+
date_column="date",
520+
target_column="target",
521+
channel_columns=["channel_1", "channel_2", "channel_3"],
522+
adstock=adstock,
523+
saturation=saturation,
524+
)
525+
526+
mmm.build_model(X_train, y_train)
527+
mmm.fit(X_train, y_train, draws=200, tune=100, chains=1, random_seed=123)
528+
529+
# This should work fine since dates don't overlap
530+
try:
531+
result = mmm.sample_posterior_predictive(
532+
X_new, # Non-overlapping dates
533+
include_last_observations=True, # Should work fine
534+
extend_idata=False,
535+
random_seed=123,
536+
)
537+
538+
# Verify that the result includes the expected dates
539+
# (should be l_max training dates + new prediction dates, then sliced to remove l_max)
540+
expected_dates = X_new["date"].values
541+
np.testing.assert_array_equal(result.coords["date"].values, expected_dates)
542+
543+
except ValueError as e:
544+
pytest.fail(f"Unexpected error when using non-overlapping dates: {e}")
545+
546+
421547
@pytest.fixture
422548
def df_events() -> pd.DataFrame:
423549
return pd.DataFrame(

0 commit comments

Comments
 (0)