@@ -89,7 +89,7 @@ def build_model(self, X, y, coords) -> None:
8989 """Build the model, must be implemented by subclass."""
9090 raise NotImplementedError ("This method must be implemented by a subclass" )
9191
92- def _data_setter (self , X ) -> None :
92+ def _data_setter (self , X : xr . DataArray ) -> None :
9393 """
9494 Set data for the model.
9595
@@ -105,6 +105,9 @@ def _data_setter(self, X) -> None:
105105 """
106106 new_no_of_observations = X .shape [0 ]
107107
108+ # Use integer indices for obs_ind to avoid datetime compatibility issues with PyMC
109+ obs_coords = np .arange (new_no_of_observations )
110+
108111 # Check if this model has multiple treated units
109112 if hasattr (self , "idata" ) and self .idata is not None :
110113 posterior = self .idata .posterior
@@ -125,13 +128,13 @@ def _data_setter(self, X) -> None:
125128 # Multi-unit case or single unit with treated_units dimension
126129 pm .set_data (
127130 {"X" : X , "y" : np .zeros ((new_no_of_observations , n_treated_units ))},
128- coords = {"obs_ind" : np . arange ( new_no_of_observations ) },
131+ coords = {"obs_ind" : obs_coords },
129132 )
130133 else :
131134 # Other model types (e.g., LinearRegression) without treated_units dimension
132135 pm .set_data (
133136 {"X" : X , "y" : np .zeros (new_no_of_observations )},
134- coords = {"obs_ind" : np . arange ( new_no_of_observations ) },
137+ coords = {"obs_ind" : obs_coords },
135138 )
136139
137140 def fit (self , X , y , coords : Optional [Dict [str , Any ]] = None ) -> None :
@@ -154,7 +157,7 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
154157 )
155158 return self .idata
156159
157- def predict (self , X ):
160+ def predict (self , X : xr . DataArray ):
158161 """
159162 Predict data given input data `X`
160163
@@ -166,16 +169,19 @@ def predict(self, X):
166169 # sample_posterior_predictive() if provided in sample_kwargs.
167170 random_seed = self .sample_kwargs .get ("random_seed" , None )
168171 self ._data_setter (X )
169- with self : # sample with new input data
172+ with self :
170173 pp = pm .sample_posterior_predictive (
171174 self .idata ,
172175 var_names = ["y_hat" , "mu" ],
173176 progressbar = False ,
174177 random_seed = random_seed ,
175178 )
176179
177- # TODO: This is a bit of a hack. Maybe it could be done properly in _data_setter?
178- if isinstance (X , xr .DataArray ):
180+ # Assign coordinates from input X to ensure xarray operations work correctly
181+ # This is necessary because PyMC uses integer indices internally, but we need
182+ # to preserve the original coordinates (e.g., datetime indices) for proper
183+ # alignment with other xarray operations like calculate_impact()
184+ if isinstance (X , xr .DataArray ) and "obs_ind" in X .coords :
179185 pp ["posterior_predictive" ] = pp ["posterior_predictive" ].assign_coords (
180186 obs_ind = X .obs_ind
181187 )
0 commit comments