Skip to content

Commit 387f748

Browse files
committed
get tests passing
1 parent d06b0c6 commit 387f748

File tree

2 files changed

+67
-41
lines changed

2 files changed

+67
-41
lines changed

causalpy/pymc_models.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,9 +1718,12 @@ def build_model(
17181718
# Build coordinates for the model
17191719
coordinates = self.ss_mod.coords.copy()
17201720
if coords:
1721-
# Merge with user-provided coords (excluding datetime_index which was extracted)
1721+
# Merge with user-provided coords (excluding datetime_index and obs_ind which are handled separately)
17221722
coords_copy = coords.copy()
17231723
coords_copy.pop("datetime_index", None)
1724+
coords_copy.pop(
1725+
"obs_ind", None
1726+
) # obs_ind handled by state-space model's time dimension
17241727
coordinates.update(coords_copy)
17251728

17261729
# Build model
@@ -1920,13 +1923,12 @@ def score(
19201923
**kwargs: Any,
19211924
) -> pd.Series:
19221925
"""
1923-
Compute R^2 between observed and mean forecast.
1926+
Score the Bayesian R^2 given inputs X and outputs y.
19241927
19251928
Parameters
19261929
----------
19271930
X : xr.DataArray, optional
1928-
Input features with dims ["obs_ind", "coeffs"]. Not used by state-space
1929-
models, kept for API compatibility.
1931+
Input features. Not used by state-space models, but kept for API compatibility.
19301932
y : xr.DataArray
19311933
Target variable with dims ["obs_ind", "treated_units"].
19321934
coords : dict, optional
@@ -1937,5 +1939,5 @@ def score(
19371939
pd.Series
19381940
R² score and standard deviation for each treated unit.
19391941
"""
1940-
# Use base class score method now that we have treated_units dimension
1941-
return super().score(X, y, coords=coords, **kwargs)
1942+
# Use base class implementation - X is accepted but not used by predict()
1943+
return super().score(X, y, coords, **kwargs)

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 59 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -945,21 +945,32 @@ def test_bayesian_structural_time_series():
945945
assert isinstance(score_empty_x, pd.Series)
946946

947947
# --- Test Case 4: Model with incorrect coord/data setup (ValueErrors) --- #
948+
# Test that X must have datetime coordinates
948949
with pytest.raises(
949950
ValueError,
950-
match=r"coords must contain 'datetime_index' of type pd\.DatetimeIndex",
951+
match=r"X\.coords\['obs_ind'\] must contain datetime values",
951952
):
952953
model_error_idx = cp.pymc_models.BayesianBasisExpansionTimeSeries(
953954
sample_kwargs=bsts_sample_kwargs
954955
)
955-
bad_dt_idx_coords = coords_with_x.copy()
956-
bad_dt_idx_coords["datetime_index"] = np.arange(n_obs) # Not a DatetimeIndex
957-
958-
# Using DataArrays here too for consistency, though check happens on coords dict
956+
# Create X with non-datetime obs_ind coordinates
957+
bad_X = xr.DataArray(
958+
data_with_x[["x1"]].values,
959+
dims=["obs_ind", "coeffs"],
960+
coords={
961+
"obs_ind": np.arange(n_obs),
962+
"coeffs": ["x1"],
963+
}, # integers not datetime
964+
)
965+
bad_y = xr.DataArray(
966+
data_with_x["y"].values[:, None],
967+
dims=["obs_ind", "treated_units"],
968+
coords={"obs_ind": np.arange(n_obs), "treated_units": ["unit_0"]},
969+
)
959970
model_error_idx.fit(
960-
X=X_da,
961-
y=y_da,
962-
coords=bad_dt_idx_coords.copy(), # Pass a copy
971+
X=bad_X,
972+
y=bad_y,
973+
coords=coords_with_x.copy(),
963974
)
964975

965976
with pytest.raises(ValueError, match="Model was built with exogenous variables"):
@@ -969,7 +980,7 @@ def test_bayesian_structural_time_series():
969980

970981
with pytest.raises(
971982
ValueError,
972-
match=r"Mismatch: X_exog_array has 2 columns, but 1 names provided",
983+
match=r"Exogenous variable names mismatch",
973984
):
974985
wrong_shape_x_pred_vals = np.hstack(
975986
[data_with_x[["x1"]].values, data_with_x[["x1"]].values]
@@ -1038,12 +1049,6 @@ def test_state_space_time_series():
10381049
"random_seed": 42,
10391050
}
10401051

1041-
# Coordinates for the model
1042-
coords = {
1043-
"obs_ind": np.arange(n_obs),
1044-
"datetime_index": dates,
1045-
}
1046-
10471052
# Create DataArray for y to support score() which requires xarray
10481053
# Use dates as obs_ind coordinate (datetime values required by new API)
10491054
y_da = xr.DataArray(
@@ -1062,10 +1067,16 @@ def test_state_space_time_series():
10621067

10631068
# Test the complete workflow
10641069
# --- Test Case 1: Model fitting --- #
1070+
# Create dummy X (state-space doesn't use exogenous vars but we pass empty array for API consistency)
1071+
dummy_X = xr.DataArray(
1072+
np.zeros((len(dates), 0)),
1073+
dims=["obs_ind", "coeffs"],
1074+
coords={"obs_ind": dates, "coeffs": []},
1075+
)
1076+
# StateSpaceTimeSeries extracts datetime from xarray coords, no separate coords dict needed
10651077
idata = model.fit(
1066-
X=None, # No exogenous variables for state-space model
1078+
X=dummy_X,
10671079
y=y_da,
1068-
coords=coords.copy(),
10691080
)
10701081

10711082
# Verify inference data structure
@@ -1089,9 +1100,14 @@ def test_state_space_time_series():
10891100
assert "mu" in idata.posterior_predictive
10901101

10911102
# --- Test Case 2: In-sample prediction --- #
1103+
# Create dummy X for in-sample prediction (state-space doesn't use it but API requires it for consistency)
1104+
dummy_X_insample = xr.DataArray(
1105+
np.zeros((len(dates), 0)),
1106+
dims=["obs_ind", "coeffs"],
1107+
coords={"obs_ind": dates, "coeffs": []},
1108+
)
10921109
predictions_in_sample = model.predict(
1093-
X=None,
1094-
coords=coords,
1110+
X=dummy_X_insample,
10951111
out_of_sample=False,
10961112
)
10971113
assert isinstance(predictions_in_sample, az.InferenceData)
@@ -1101,9 +1117,6 @@ def test_state_space_time_series():
11011117

11021118
# --- Test Case 3: Out-of-sample prediction (forecasting) --- #
11031119
future_dates = pd.date_range(start="2020-04-01", end="2020-04-07", freq="D")
1104-
future_coords = {
1105-
"datetime_index": future_dates,
1106-
}
11071120
# Create dummy X for forecasting (needs time index)
11081121
future_X = xr.DataArray(
11091122
np.zeros((len(future_dates), 0)),
@@ -1113,7 +1126,6 @@ def test_state_space_time_series():
11131126

11141127
predictions_out_sample = model.predict(
11151128
X=future_X,
1116-
coords=future_coords,
11171129
out_of_sample=True,
11181130
)
11191131
# Note: predict now returns InferenceData, not Dataset!
@@ -1134,10 +1146,15 @@ def test_state_space_time_series():
11341146
)
11351147

11361148
# --- Test Case 4: Model scoring --- #
1149+
# Create dummy X for score (state-space doesn't use it but API requires it)
1150+
dummy_X_for_score = xr.DataArray(
1151+
np.zeros((len(dates), 0)),
1152+
dims=["obs_ind", "coeffs"],
1153+
coords={"obs_ind": dates, "coeffs": []},
1154+
)
11371155
score = model.score(
1138-
X=None,
1156+
X=dummy_X_for_score,
11391157
y=y_da,
1140-
coords=coords,
11411158
)
11421159
assert isinstance(score, pd.Series)
11431160
assert "unit_0_r2" in score.index
@@ -1164,30 +1181,37 @@ def test_state_space_time_series():
11641181
assert model.mode == "FAST_COMPILE"
11651182

11661183
# --- Test Case 6: Error handling --- #
1167-
# Test with invalid datetime_index
1184+
# Test that y must have datetime coordinates
11681185
with pytest.raises(
11691186
ValueError,
1170-
match=r"coords must contain 'datetime_index' of type pd\.DatetimeIndex",
1187+
match=r"y\.coords\['obs_ind'\] must contain datetime values",
11711188
):
11721189
model_error = cp.pymc_models.StateSpaceTimeSeries(
11731190
sample_kwargs=ss_sample_kwargs
11741191
)
1175-
bad_coords = coords.copy()
1176-
bad_coords["datetime_index"] = np.arange(n_obs) # Not a DatetimeIndex
1192+
# Create y with non-datetime coords (integers instead)
1193+
bad_y = xr.DataArray(
1194+
data["y"].values.reshape(-1, 1),
1195+
dims=["obs_ind", "treated_units"],
1196+
coords={"obs_ind": np.arange(n_obs), "treated_units": ["unit_0"]},
1197+
)
1198+
bad_X = xr.DataArray(
1199+
np.zeros((n_obs, 0)),
1200+
dims=["obs_ind", "coeffs"],
1201+
coords={"obs_ind": np.arange(n_obs), "coeffs": []},
1202+
)
11771203
model_error.fit(
1178-
X=None,
1179-
y=data["y"].values.reshape(-1, 1),
1180-
coords=bad_coords,
1204+
X=bad_X,
1205+
y=bad_y,
11811206
)
11821207

1183-
# Test prediction with invalid coords (missing X)
1208+
# Test prediction with missing X for out-of-sample
11841209
with pytest.raises(
11851210
ValueError,
1186-
match="X must have 'obs_ind' coordinate with datetime values",
1211+
match="X must be provided for out-of-sample predictions",
11871212
):
11881213
model.predict(
11891214
X=None,
1190-
coords={"invalid": "coords"},
11911215
out_of_sample=True,
11921216
)
11931217

0 commit comments

Comments
 (0)