@@ -261,21 +261,26 @@ def batched_sample_posterior_predictive(
261261 compile_kwargs .setdefault ("accept_inplace" , True )
262262
263263 constant_data : Dict [str , np .ndarray ] = {}
264- trace_coords : Dict [str , np .ndarray ] = {}
265264 _constant_data = getattr (idata , "constant_data" , None )
266265 if _constant_data is not None :
267- trace_coords .update (
268- {str (k ): v .data for k , v in _constant_data .coords .items ()}
269- )
270266 constant_data .update ({str (k ): v .data for k , v in _constant_data .items ()})
267+ idata_coords = {key : v .data for key , v in idata .posterior .coords .items ()}
268+ if "observed_data" in idata .groups ():
269+ idata_coords .update (
270+ {key : v .data for key , v in idata .observed_data .coords .items ()}
271+ )
272+ if "constant_data" in idata .groups ():
273+ idata_coords .update (
274+ {key : v .data for key , v in idata .constant_data .coords .items ()}
275+ )
271276
272277 constant_coords = set ()
273- for dim , coord in trace_coords .items ():
278+ for dim , coord in idata_coords .items ():
274279 current_coord = self .model .coords .get (dim , None )
275280 if (
276281 current_coord is not None
277282 and len (coord ) == len (current_coord )
278- and np .all (coord == current_coord )
283+ and np .all (coord == np . asarray ( current_coord ) )
279284 ):
280285 constant_coords .add (dim )
281286
@@ -1094,10 +1099,28 @@ def compile_between_and_within_subjects(
10941099 # Compile the forward sampling function that computes the expected values
10951100 # The second argument that is returned by compile_forward_sampling_function
10961101 # is the resampled variables, so we ignore it
1102+ idata_coords = {key : v .data for key , v in idata .posterior .coords .items ()}
1103+ if "observed_data" in idata .groups ():
1104+ idata_coords .update (
1105+ {key : v .data for key , v in idata .observed_data .coords .items ()}
1106+ )
1107+ if "constant_data" in idata .groups ():
1108+ idata_coords .update (
1109+ {key : v .data for key , v in idata .constant_data .coords .items ()}
1110+ )
1111+ constant_coords = {
1112+ dim
1113+ for dim , coord in self .model .coords .items ()
1114+ if dim not in idata_coords
1115+ or np .array_equal (np .asarray (coord ), idata_coords [dim ])
1116+ }
1117+
10971118 return compile_forward_sampling_function (
10981119 expected_values ,
10991120 vars_in_trace = vars_in_trace ,
11001121 basic_rvs = basic_RVs ,
1122+ constant_data = getattr (idata , "constant_data" , None ),
1123+ constant_coords = constant_coords ,
11011124 ** kwargs ,
11021125 )[0 ]
11031126
0 commit comments