@@ -945,21 +945,32 @@ def test_bayesian_structural_time_series():
945945 assert isinstance (score_empty_x , pd .Series )
946946
947947 # --- Test Case 4: Model with incorrect coord/data setup (ValueErrors) --- #
948+ # Test that X must have datetime coordinates
948949 with pytest .raises (
949950 ValueError ,
950- match = r"coords must contain 'datetime_index' of type pd\.DatetimeIndex " ,
951+ match = r"X\. coords\['obs_ind'\] must contain datetime values " ,
951952 ):
952953 model_error_idx = cp .pymc_models .BayesianBasisExpansionTimeSeries (
953954 sample_kwargs = bsts_sample_kwargs
954955 )
955- bad_dt_idx_coords = coords_with_x .copy ()
956- bad_dt_idx_coords ["datetime_index" ] = np .arange (n_obs ) # Not a DatetimeIndex
957-
958- # Using DataArrays here too for consistency, though check happens on coords dict
956+ # Create X with non-datetime obs_ind coordinates
957+ bad_X = xr .DataArray (
958+ data_with_x [["x1" ]].values ,
959+ dims = ["obs_ind" , "coeffs" ],
960+ coords = {
961+ "obs_ind" : np .arange (n_obs ),
962+ "coeffs" : ["x1" ],
963+ }, # integers not datetime
964+ )
965+ bad_y = xr .DataArray (
966+ data_with_x ["y" ].values [:, None ],
967+ dims = ["obs_ind" , "treated_units" ],
968+ coords = {"obs_ind" : np .arange (n_obs ), "treated_units" : ["unit_0" ]},
969+ )
959970 model_error_idx .fit (
960- X = X_da ,
961- y = y_da ,
962- coords = bad_dt_idx_coords .copy (), # Pass a copy
971+ X = bad_X ,
972+ y = bad_y ,
973+ coords = coords_with_x .copy (),
963974 )
964975
965976 with pytest .raises (ValueError , match = "Model was built with exogenous variables" ):
@@ -969,7 +980,7 @@ def test_bayesian_structural_time_series():
969980
970981 with pytest .raises (
971982 ValueError ,
972- match = r"Mismatch: X_exog_array has 2 columns, but 1 names provided " ,
983+ match = r"Exogenous variable names mismatch " ,
973984 ):
974985 wrong_shape_x_pred_vals = np .hstack (
975986 [data_with_x [["x1" ]].values , data_with_x [["x1" ]].values ]
@@ -1038,12 +1049,6 @@ def test_state_space_time_series():
10381049 "random_seed" : 42 ,
10391050 }
10401051
1041- # Coordinates for the model
1042- coords = {
1043- "obs_ind" : np .arange (n_obs ),
1044- "datetime_index" : dates ,
1045- }
1046-
10471052 # Create DataArray for y to support score() which requires xarray
10481053 # Use dates as obs_ind coordinate (datetime values required by new API)
10491054 y_da = xr .DataArray (
@@ -1062,10 +1067,16 @@ def test_state_space_time_series():
10621067
10631068 # Test the complete workflow
10641069 # --- Test Case 1: Model fitting --- #
1070+ # Create dummy X (state-space doesn't use exogenous vars but we pass empty array for API consistency)
1071+ dummy_X = xr .DataArray (
1072+ np .zeros ((len (dates ), 0 )),
1073+ dims = ["obs_ind" , "coeffs" ],
1074+ coords = {"obs_ind" : dates , "coeffs" : []},
1075+ )
1076+ # StateSpaceTimeSeries extracts datetime from xarray coords, no separate coords dict needed
10651077 idata = model .fit (
1066- X = None , # No exogenous variables for state-space model
1078+ X = dummy_X ,
10671079 y = y_da ,
1068- coords = coords .copy (),
10691080 )
10701081
10711082 # Verify inference data structure
@@ -1089,9 +1100,14 @@ def test_state_space_time_series():
10891100 assert "mu" in idata .posterior_predictive
10901101
10911102 # --- Test Case 2: In-sample prediction --- #
1103+ # Create dummy X for in-sample prediction (state-space doesn't use it but API requires it for consistency)
1104+ dummy_X_insample = xr .DataArray (
1105+ np .zeros ((len (dates ), 0 )),
1106+ dims = ["obs_ind" , "coeffs" ],
1107+ coords = {"obs_ind" : dates , "coeffs" : []},
1108+ )
10921109 predictions_in_sample = model .predict (
1093- X = None ,
1094- coords = coords ,
1110+ X = dummy_X_insample ,
10951111 out_of_sample = False ,
10961112 )
10971113 assert isinstance (predictions_in_sample , az .InferenceData )
@@ -1101,9 +1117,6 @@ def test_state_space_time_series():
11011117
11021118 # --- Test Case 3: Out-of-sample prediction (forecasting) --- #
11031119 future_dates = pd .date_range (start = "2020-04-01" , end = "2020-04-07" , freq = "D" )
1104- future_coords = {
1105- "datetime_index" : future_dates ,
1106- }
11071120 # Create dummy X for forecasting (needs time index)
11081121 future_X = xr .DataArray (
11091122 np .zeros ((len (future_dates ), 0 )),
@@ -1113,7 +1126,6 @@ def test_state_space_time_series():
11131126
11141127 predictions_out_sample = model .predict (
11151128 X = future_X ,
1116- coords = future_coords ,
11171129 out_of_sample = True ,
11181130 )
11191131 # Note: predict now returns InferenceData, not Dataset!
@@ -1134,10 +1146,15 @@ def test_state_space_time_series():
11341146 )
11351147
11361148 # --- Test Case 4: Model scoring --- #
1149+ # Create dummy X for score (state-space doesn't use it but API requires it)
1150+ dummy_X_for_score = xr .DataArray (
1151+ np .zeros ((len (dates ), 0 )),
1152+ dims = ["obs_ind" , "coeffs" ],
1153+ coords = {"obs_ind" : dates , "coeffs" : []},
1154+ )
11371155 score = model .score (
1138- X = None ,
1156+ X = dummy_X_for_score ,
11391157 y = y_da ,
1140- coords = coords ,
11411158 )
11421159 assert isinstance (score , pd .Series )
11431160 assert "unit_0_r2" in score .index
@@ -1164,30 +1181,37 @@ def test_state_space_time_series():
11641181 assert model .mode == "FAST_COMPILE"
11651182
11661183 # --- Test Case 6: Error handling --- #
1167- # Test with invalid datetime_index
1184+ # Test that y must have datetime coordinates
11681185 with pytest .raises (
11691186 ValueError ,
1170- match = r"coords must contain 'datetime_index' of type pd\.DatetimeIndex " ,
1187+ match = r"y\. coords\['obs_ind'\] must contain datetime values " ,
11711188 ):
11721189 model_error = cp .pymc_models .StateSpaceTimeSeries (
11731190 sample_kwargs = ss_sample_kwargs
11741191 )
1175- bad_coords = coords .copy ()
1176- bad_coords ["datetime_index" ] = np .arange (n_obs ) # Not a DatetimeIndex
1192+ # Create y with non-datetime coords (integers instead)
1193+ bad_y = xr .DataArray (
1194+ data ["y" ].values .reshape (- 1 , 1 ),
1195+ dims = ["obs_ind" , "treated_units" ],
1196+ coords = {"obs_ind" : np .arange (n_obs ), "treated_units" : ["unit_0" ]},
1197+ )
1198+ bad_X = xr .DataArray (
1199+ np .zeros ((n_obs , 0 )),
1200+ dims = ["obs_ind" , "coeffs" ],
1201+ coords = {"obs_ind" : np .arange (n_obs ), "coeffs" : []},
1202+ )
11771203 model_error .fit (
1178- X = None ,
1179- y = data ["y" ].values .reshape (- 1 , 1 ),
1180- coords = bad_coords ,
1204+ X = bad_X ,
1205+ y = bad_y ,
11811206 )
11821207
1183- # Test prediction with invalid coords ( missing X)
1208+ # Test prediction with missing X for out-of-sample
11841209 with pytest .raises (
11851210 ValueError ,
1186- match = "X must have 'obs_ind' coordinate with datetime values " ,
1211+ match = "X must be provided for out-of-sample predictions " ,
11871212 ):
11881213 model .predict (
11891214 X = None ,
1190- coords = {"invalid" : "coords" },
11911215 out_of_sample = True ,
11921216 )
11931217
0 commit comments