Skip to content

Commit a807820

Browse files
committed
Final fixes to failing tests
1 parent 3ec6683 commit a807820

File tree

4 files changed

+35
-11
lines changed

4 files changed

+35
-11
lines changed

homepy/blocks/gp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def make_covariance_kernel(kernel_spec: str) -> pm.gp.cov.Covariance:
140140

141141
def eval_(node):
142142
if isinstance(node, ast.Constant):
143-
return node.n
143+
return node.value
144144
elif isinstance(node, ast.BinOp):
145145
return operators[type(node.op)](eval_(node.left), eval_(node.right))
146146
elif isinstance(node, ast.UnaryOp):

homepy/models/base.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -261,21 +261,26 @@ def batched_sample_posterior_predictive(
261261
compile_kwargs.setdefault("accept_inplace", True)
262262

263263
constant_data: Dict[str, np.ndarray] = {}
264-
trace_coords: Dict[str, np.ndarray] = {}
265264
_constant_data = getattr(idata, "constant_data", None)
266265
if _constant_data is not None:
267-
trace_coords.update(
268-
{str(k): v.data for k, v in _constant_data.coords.items()}
269-
)
270266
constant_data.update({str(k): v.data for k, v in _constant_data.items()})
267+
idata_coords = {key: v.data for key, v in idata.posterior.coords.items()}
268+
if "observed_data" in idata.groups():
269+
idata_coords.update(
270+
{key: v.data for key, v in idata.observed_data.coords.items()}
271+
)
272+
if "constant_data" in idata.groups():
273+
idata_coords.update(
274+
{key: v.data for key, v in idata.constant_data.coords.items()}
275+
)
271276

272277
constant_coords = set()
273-
for dim, coord in trace_coords.items():
278+
for dim, coord in idata_coords.items():
274279
current_coord = self.model.coords.get(dim, None)
275280
if (
276281
current_coord is not None
277282
and len(coord) == len(current_coord)
278-
and np.all(coord == current_coord)
283+
and np.all(coord == np.asarray(current_coord))
279284
):
280285
constant_coords.add(dim)
281286

@@ -1094,10 +1099,28 @@ def compile_between_and_within_subjects(
10941099
# Compile the forward sampling function that computes the expected values
10951100
# The second argument that is returned by compile_forward_sampling_function
10961101
# is the resampled variables, so we ignore it
1102+
idata_coords = {key: v.data for key, v in idata.posterior.coords.items()}
1103+
if "observed_data" in idata.groups():
1104+
idata_coords.update(
1105+
{key: v.data for key, v in idata.observed_data.coords.items()}
1106+
)
1107+
if "constant_data" in idata.groups():
1108+
idata_coords.update(
1109+
{key: v.data for key, v in idata.constant_data.coords.items()}
1110+
)
1111+
constant_coords = {
1112+
dim
1113+
for dim, coord in self.model.coords.items()
1114+
if dim not in idata_coords
1115+
or np.array_equal(np.asarray(coord), idata_coords[dim])
1116+
}
1117+
10971118
return compile_forward_sampling_function(
10981119
expected_values,
10991120
vars_in_trace=vars_in_trace,
11001121
basic_rvs=basic_RVs,
1122+
constant_data=getattr(idata, "constant_data", None),
1123+
constant_coords=constant_coords,
11011124
**kwargs,
11021125
)[0]
11031126

homepy/pytensorf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pymc.pytensorf import find_rng_nodes
2222
from pytensor.graph import FunctionGraph
2323
from pytensor.graph.basic import clone_get_equiv
24+
from pytensor.tensor.exceptions import NotScalarConstantError
2425
from pytensor.tensor.random.basic import NormalRV
2526

2627
try:
@@ -46,16 +47,16 @@ def make_normal_not_centered(fgraph, node):
4647
nonlocal resampled_vars
4748
nonlocal free_RVs
4849
nonlocal free_RV_names
49-
rng, size, dtype, loc, scale = node.inputs
50+
rng, size, loc, scale = node.inputs
5051

5152
try:
5253
loc_is_zero = get_underlying_scalar_constant_value(loc) == 0
53-
except ValueError:
54+
except NotScalarConstantError:
5455
loc_is_zero = False
5556

5657
name = getattr(node.outputs[1], "name", None)
5758
if not loc_is_zero and name in resampled_vars_mapping:
58-
raw = pt.random.normal(0, 1, rng=rng, size=size, dtype=dtype)
59+
raw = pt.random.normal(0, 1, rng=rng, size=size)
5960
raw.name = name + "_raw"
6061
og_index = resampled_vars_mapping[name]
6162
resampled_vars[og_index[0]][og_index[1]] = raw

homepy/tests/blocks/test_likelihoods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def test_parse_observed_features_and_dims_error():
205205
zero_rate_name="zero_rate",
206206
zero_rate_prior=None,
207207
),
208-
("add", "sigmoid", "halfnormal_rv", "uniform_rv", "MarginalMixtureRV"),
208+
("add", "sigmoid", "halfnormal_rv", "uniform_rv", "MixtureRV"),
209209
("exp", "sqr", "exponential_rv", "log", "beta_rv"),
210210
),
211211
(

0 commit comments

Comments
 (0)