@@ -66,7 +66,7 @@ def seeded_array_fn(seed: SeedType = None):
6666 flat_array = np .empty (total_size , dtype = "float64" , order = "C" )
6767 cursor = 0
6868
69- for name , shape in zip (names , shapes ):
69+ for name , shape in zip (names , shapes , strict = True ):
7070 initial_value = initial_value_dict [name ]
7171 n = int (np .prod (initial_value .shape ))
7272 if initial_value .shape != shape :
@@ -217,7 +217,7 @@ def make_user_data(shared_vars, shared_data):
217217
218218def _compile_pymc_model_numba (
219219 model : "pm.Model" ,
220- initial_point_fn : Callable [[SeedType ], dict [str , np .ndarray ]],
220+ pymc_initial_point_fn : Callable [[SeedType ], dict [str , np .ndarray ]],
221221 ** kwargs ,
222222) -> CompiledPyMCModel :
223223 if find_spec ("numba" ) is None :
@@ -234,8 +234,15 @@ def _compile_pymc_model_numba(
234234 n_expanded ,
235235 logp_fn_pt ,
236236 expand_fn_pt ,
237+ initial_point_fn ,
237238 shape_info ,
238- ) = _make_functions (model , mode = "NUMBA" , compute_grad = True , join_expanded = True )
239+ ) = _make_functions (
240+ model ,
241+ mode = "NUMBA" ,
242+ compute_grad = True ,
243+ join_expanded = True ,
244+ pymc_initial_point_fn = pymc_initial_point_fn ,
245+ )
239246
240247 expand_fn = expand_fn_pt .vm .jit_fn
241248 logp_fn = logp_fn_pt .vm .jit_fn
@@ -282,17 +289,15 @@ def _compile_pymc_model_numba(
282289 expand_numba = numba .cfunc (c_sig_expand , ** kwargs )(expand_numba_raw )
283290
284291 dims , coords = _prepare_dims_and_coords (model , shape_info )
285- initial_point_fn_array = rv_dict_to_flat_array_wrapper (
286- initial_point_fn , names = shape_info [0 ], shapes = shape_info [- 1 ]
287- )
292+
288293 return CompiledPyMCModel (
289294 _n_dim = n_dim ,
290295 dims = dims ,
291296 _coords = coords ,
292297 _shapes = {name : tuple (shape ) for name , _ , shape in zip (* shape_info )},
293298 compiled_logp_func = logp_numba ,
294299 compiled_expand_func = expand_numba ,
295- initial_point_func = initial_point_fn_array ,
300+ initial_point_func = initial_point_fn ,
296301 shared_data = shared_data ,
297302 user_data = user_data ,
298303 n_expanded = n_expanded ,
@@ -331,7 +336,7 @@ def _compile_pymc_model_jax(
331336 model ,
332337 * ,
333338 gradient_backend = None ,
334- initial_point_fn : Callable [[SeedType ], dict [str , np .ndarray ]],
339+ pymc_initial_point_fn : Callable [[SeedType ], dict [str , np .ndarray ]],
335340 ** kwargs ,
336341):
337342 if find_spec ("jax" ) is None :
@@ -353,12 +358,14 @@ def _compile_pymc_model_jax(
353358 _ ,
354359 logp_fn_pt ,
355360 expand_fn_pt ,
361+ initial_point_fn ,
356362 shape_info ,
357363 ) = _make_functions (
358364 model ,
359365 mode = "JAX" ,
360366 compute_grad = gradient_backend == "pytensor" ,
361367 join_expanded = False ,
368+ pymc_initial_point_fn = pymc_initial_point_fn ,
362369 )
363370
364371 logp_fn = logp_fn_pt .vm .jit_fn
@@ -409,15 +416,11 @@ def expand(x, **shared):
409416
410417 dims , coords = _prepare_dims_and_coords (model , shape_info )
411418
412- initial_point_fn_array = rv_dict_to_flat_array_wrapper (
413- initial_point_fn , names = shape_info [0 ], shapes = shape_info [- 1 ]
414- )
415-
416419 return from_pyfunc (
417420 ndim = n_dim ,
418421 make_logp_fn = make_logp_func ,
419422 make_expand_fn = make_expand_func ,
420- make_initial_point_fn = initial_point_fn_array ,
423+ make_initial_point_fn = initial_point_fn ,
421424 expanded_dtypes = dtypes ,
422425 expanded_shapes = shapes ,
423426 expanded_names = names ,
@@ -484,13 +487,13 @@ def compile_pymc_model(
484487 if gradient_backend == "jax" :
485488 raise ValueError ("Gradient backend cannot be jax when using numba backend" )
486489 return _compile_pymc_model_numba (
487- model , initial_point_fn = initial_point_fn , ** kwargs
490+ model = model , pymc_initial_point_fn = initial_point_fn , ** kwargs
488491 )
489492 elif backend .lower () == "jax" :
490493 return _compile_pymc_model_jax (
491- model ,
494+ model = model ,
492495 gradient_backend = gradient_backend ,
493- initial_point_fn = initial_point_fn ,
496+ pymc_initial_point_fn = initial_point_fn ,
494497 ** kwargs ,
495498 )
496499 else :
@@ -527,9 +530,19 @@ def _compute_shapes(model):
527530
528531
529532def _make_functions (
530- model , * , mode , compute_grad , join_expanded
533+ model : "pm.Model" ,
534+ * ,
535+ mode : Literal ["JAX" , "NUMBA" ],
536+ compute_grad : bool ,
537+ join_expanded : bool ,
538+ pymc_initial_point_fn : Callable [[SeedType ], dict [str , np .ndarray ]],
531539) -> tuple [
532- int , int , Callable , Callable , tuple [list [str ], list [slice ], list [tuple [int , ...]]]
540+ int ,
541+ int ,
542+ Callable ,
543+ Callable ,
544+ Callable ,
545+ tuple [list [str ], list [slice ], list [tuple [int , ...]]],
533546]:
534547 """
535548 Compile functions required by nuts-rs from a given PyMC model.
@@ -546,6 +559,8 @@ def _make_functions(
546559 join_expanded: bool
547560 Whether to join the expanded variables into a single array. If False, the expanded variables will be returned
548561 as a list of arrays.
562+ pymc_initial_point_fn: Callable
563+ Initial point function created by pymc.initial_point.make_initial_point_fn
549564
550565 Returns
551566 -------
@@ -558,6 +573,8 @@ def _make_functions(
558573 and the gradient, otherwise only the logp is returned.
559574 expand_fn_pt: Callable
560575 Compiled pytensor function that computes the remaining variables for the trace
576+ initial_point_fn: Callable
577+ Python function that takes a random seed and returns a flat array of initial values
561578 param_data: tuple of lists
562579 Tuple containing data necessary to unravel a flat array of model variables back into a ragged list of arrays.
563580 The first list contains the names of the variables, the second list contains the slices that correspond to the
@@ -607,6 +624,10 @@ def _make_functions(
607624
608625 num_free_vars = count
609626
627+ initial_point_fn = rv_dict_to_flat_array_wrapper (
628+ pymc_initial_point_fn , names = joined_names , shapes = joined_shapes
629+ )
630+
610631 joined = pt .TensorType ("float64" , shape = (num_free_vars ,))(
611632 name = "_unconstrained_point"
612633 )
@@ -673,6 +694,7 @@ def _make_functions(
673694 num_expanded ,
674695 logp_fn_pt ,
675696 expand_fn_pt ,
697+ initial_point_fn ,
676698 (all_names , all_slices , all_shapes ),
677699 )
678700
0 commit comments