Skip to content

Commit 7bd84d8

Browse files
committed
Fixed expected call tests
1 parent f389283 commit 7bd84d8

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

homepy/tests/models/test_base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,12 +334,13 @@ def test_model_batched_sample_posterior_predictive(
334334
for var_name in model.get_var_names_to_resample(idata, var_names=var_names)
335335
]
336336
vars_in_trace_ = [model.model[var_name] for var_name in idata.posterior.data_vars]
337+
constant_coords = set(model.model.coords)
337338
mock_compile_forward_sampling_function.assert_called_once_with(
338339
outputs=outputs_,
339340
vars_in_trace=vars_in_trace_,
340341
basic_rvs=model.model.basic_RVs,
341342
constant_data={},
342-
constant_coords=set(),
343+
constant_coords=constant_coords,
343344
allow_input_downcast=True,
344345
accept_inplace=True,
345346
)
@@ -392,7 +393,10 @@ def test_new_model_batched_sample_posterior_predictive(
392393
vars_in_trace=vars_in_trace_,
393394
basic_rvs=new_model.model.basic_RVs,
394395
constant_data={},
395-
constant_coords=set(),
396+
constant_coords={
397+
dim for dim in new_model.model.coords
398+
if dim not in ['C_99_ids', 'observation', 'C_cage']
399+
},
396400
allow_input_downcast=True,
397401
accept_inplace=True,
398402
)

0 commit comments

Comments
 (0)