File tree Expand file tree Collapse file tree 1 file changed +4
-5
lines changed
Expand file tree Collapse file tree 1 file changed +4
-5
lines changed Original file line number Diff line number Diff line change @@ -354,8 +354,8 @@ def expand(x, **shared):
354354def compile_pymc_model (
355355 model : "pm.Model" ,
356356 * ,
357- backend : Literal ["numba" , "jax" ] | None = None ,
358- gradient_backend : Literal ["pytensor" , "jax" ] | None = None ,
357+ backend : Literal ["numba" , "jax" ] = "numba" ,
358+ gradient_backend : Literal ["pytensor" , "jax" ] = "pytensor" ,
359359 ** kwargs ,
360360) -> CompiledModel :
361361 """Compile necessary functions for sampling a pymc model.
@@ -384,10 +384,9 @@ def compile_pymc_model(
384384 "and restart your kernel in case you are in an interactive session."
385385 )
386386
387- if backend is None :
388- backend = "numba"
389-
390387 if backend .lower () == "numba" :
388+ if gradient_backend == "jax" :
389+ raise ValueError ("Gradient backend cannot be jax when using numba backend" )
391390 return _compile_pymc_model_numba (model , ** kwargs )
392391 elif backend .lower () == "jax" :
393392 return _compile_pymc_model_jax (
You can’t perform that action at this time.
0 commit comments