Skip to content

Commit 420f61a

Browse files
Finalize inital point function inside _make_funcs
1 parent 762401c commit 420f61a

File tree

1 file changed

+40
-18
lines changed

1 file changed

+40
-18
lines changed

python/nutpie/compile_pymc.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def seeded_array_fn(seed: SeedType = None):
6666
flat_array = np.empty(total_size, dtype="float64", order="C")
6767
cursor = 0
6868

69-
for name, shape in zip(names, shapes):
69+
for name, shape in zip(names, shapes, strict=True):
7070
initial_value = initial_value_dict[name]
7171
n = int(np.prod(initial_value.shape))
7272
if initial_value.shape != shape:
@@ -217,7 +217,7 @@ def make_user_data(shared_vars, shared_data):
217217

218218
def _compile_pymc_model_numba(
219219
model: "pm.Model",
220-
initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
220+
pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
221221
**kwargs,
222222
) -> CompiledPyMCModel:
223223
if find_spec("numba") is None:
@@ -234,8 +234,15 @@ def _compile_pymc_model_numba(
234234
n_expanded,
235235
logp_fn_pt,
236236
expand_fn_pt,
237+
initial_point_fn,
237238
shape_info,
238-
) = _make_functions(model, mode="NUMBA", compute_grad=True, join_expanded=True)
239+
) = _make_functions(
240+
model,
241+
mode="NUMBA",
242+
compute_grad=True,
243+
join_expanded=True,
244+
pymc_initial_point_fn=pymc_initial_point_fn,
245+
)
239246

240247
expand_fn = expand_fn_pt.vm.jit_fn
241248
logp_fn = logp_fn_pt.vm.jit_fn
@@ -282,17 +289,15 @@ def _compile_pymc_model_numba(
282289
expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)
283290

284291
dims, coords = _prepare_dims_and_coords(model, shape_info)
285-
initial_point_fn_array = rv_dict_to_flat_array_wrapper(
286-
initial_point_fn, names=shape_info[0], shapes=shape_info[-1]
287-
)
292+
288293
return CompiledPyMCModel(
289294
_n_dim=n_dim,
290295
dims=dims,
291296
_coords=coords,
292297
_shapes={name: tuple(shape) for name, _, shape in zip(*shape_info)},
293298
compiled_logp_func=logp_numba,
294299
compiled_expand_func=expand_numba,
295-
initial_point_func=initial_point_fn_array,
300+
initial_point_func=initial_point_fn,
296301
shared_data=shared_data,
297302
user_data=user_data,
298303
n_expanded=n_expanded,
@@ -331,7 +336,7 @@ def _compile_pymc_model_jax(
331336
model,
332337
*,
333338
gradient_backend=None,
334-
initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
339+
pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
335340
**kwargs,
336341
):
337342
if find_spec("jax") is None:
@@ -353,12 +358,14 @@ def _compile_pymc_model_jax(
353358
_,
354359
logp_fn_pt,
355360
expand_fn_pt,
361+
initial_point_fn,
356362
shape_info,
357363
) = _make_functions(
358364
model,
359365
mode="JAX",
360366
compute_grad=gradient_backend == "pytensor",
361367
join_expanded=False,
368+
pymc_initial_point_fn=pymc_initial_point_fn,
362369
)
363370

364371
logp_fn = logp_fn_pt.vm.jit_fn
@@ -409,15 +416,11 @@ def expand(x, **shared):
409416

410417
dims, coords = _prepare_dims_and_coords(model, shape_info)
411418

412-
initial_point_fn_array = rv_dict_to_flat_array_wrapper(
413-
initial_point_fn, names=shape_info[0], shapes=shape_info[-1]
414-
)
415-
416419
return from_pyfunc(
417420
ndim=n_dim,
418421
make_logp_fn=make_logp_func,
419422
make_expand_fn=make_expand_func,
420-
make_initial_point_fn=initial_point_fn_array,
423+
make_initial_point_fn=initial_point_fn,
421424
expanded_dtypes=dtypes,
422425
expanded_shapes=shapes,
423426
expanded_names=names,
@@ -484,13 +487,13 @@ def compile_pymc_model(
484487
if gradient_backend == "jax":
485488
raise ValueError("Gradient backend cannot be jax when using numba backend")
486489
return _compile_pymc_model_numba(
487-
model, initial_point_fn=initial_point_fn, **kwargs
490+
model=model, pymc_initial_point_fn=initial_point_fn, **kwargs
488491
)
489492
elif backend.lower() == "jax":
490493
return _compile_pymc_model_jax(
491-
model,
494+
model=model,
492495
gradient_backend=gradient_backend,
493-
initial_point_fn=initial_point_fn,
496+
pymc_initial_point_fn=initial_point_fn,
494497
**kwargs,
495498
)
496499
else:
@@ -527,9 +530,19 @@ def _compute_shapes(model):
527530

528531

529532
def _make_functions(
530-
model, *, mode, compute_grad, join_expanded
533+
model: "pm.Model",
534+
*,
535+
mode: Literal["JAX", "NUMBA"],
536+
compute_grad: bool,
537+
join_expanded: bool,
538+
pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
531539
) -> tuple[
532-
int, int, Callable, Callable, tuple[list[str], list[slice], list[tuple[int, ...]]]
540+
int,
541+
int,
542+
Callable,
543+
Callable,
544+
Callable,
545+
tuple[list[str], list[slice], list[tuple[int, ...]]],
533546
]:
534547
"""
535548
Compile functions required by nuts-rs from a given PyMC model.
@@ -546,6 +559,8 @@ def _make_functions(
546559
join_expanded: bool
547560
Whether to join the expanded variables into a single array. If False, the expanded variables will be returned
548561
as a list of arrays.
562+
pymc_initial_point_fn: Callable
563+
Initial point function created by pymc.initial_point.make_initial_point_fn
549564
550565
Returns
551566
-------
@@ -558,6 +573,8 @@ def _make_functions(
558573
and the gradient, otherwise only the logp is returned.
559574
expand_fn_pt: Callable
560575
Compiled pytensor function that computes the remaining variables for the trace
576+
initial_point_fn: Callable
577+
Python function that takes a random seed and returns a flat array of initial values
561578
param_data: tuple of lists
562579
Tuple containing data necessary to unravel a flat array of model variables back into a ragged list of arrays.
563580
The first list contains the names of the variables, the second list contains the slices that correspond to the
@@ -607,6 +624,10 @@ def _make_functions(
607624

608625
num_free_vars = count
609626

627+
initial_point_fn = rv_dict_to_flat_array_wrapper(
628+
pymc_initial_point_fn, names=joined_names, shapes=joined_shapes
629+
)
630+
610631
joined = pt.TensorType("float64", shape=(num_free_vars,))(
611632
name="_unconstrained_point"
612633
)
@@ -673,6 +694,7 @@ def _make_functions(
673694
num_expanded,
674695
logp_fn_pt,
675696
expand_fn_pt,
697+
initial_point_fn,
676698
(all_names, all_slices, all_shapes),
677699
)
678700

0 commit comments

Comments
 (0)