@@ -1490,26 +1490,11 @@ def zerosumnormal(name, *, sigma=1.0, size, model_logp):
14901490 return joined_inputs , [model_logp , model_dlogp ]
14911491
14921492
1493- @pytest .fixture (scope = "session " )
1493+ @pytest .fixture (scope = "module " )
14941494def radon_model ():
14951495 return create_radon_model ()
14961496
14971497
1498- @pytest .fixture (scope = "session" )
1499- def radon_model_variants ():
1500- # Convert to list comp
1501- return [
1502- create_radon_model (
1503- intercept_dist = intercept_dist ,
1504- sigma_dist = sigma_dist ,
1505- centered = centered ,
1506- )
1507- for centered in (True , False )
1508- for intercept_dist in ("normal" , "lognormal" )
1509- for sigma_dist in ("halfnormal" , "lognormal" )
1510- ]
1511-
1512-
15131498@pytest .mark .parametrize ("mode" , ["C" , "C_VM" , "NUMBA" ])
15141499def test_radon_model_repeated_compile_benchmark (mode , radon_model , benchmark ):
15151500 joined_inputs , [model_logp , model_dlogp ] = radon_model
@@ -1526,15 +1511,13 @@ def compile_and_call_once():
15261511
15271512
15281513@pytest .mark .parametrize ("mode" , ["C" , "C_VM" , "NUMBA" ])
1529- def test_radon_model_variants_compile_benchmark (
1530- mode , radon_model , radon_model_variants , benchmark
1531- ):
1514+ def test_radon_model_small_variants_compile_benchmark (mode , radon_model , benchmark ):
15321515 """Test compilation speed when a slightly variant of a function is compiled each time.
15331516
15341517 This test more realistically simulates a use case where a model is recompiled
15351518 multiple times with small changes, such as in an interactive environment.
15361519
1537- NOTE: For this test to be meaningful on subsequent runs, the cache must be cleared
1520+ TODO: Change cache directory to force recompile. This must be done before pytensor is imported
15381521 """
15391522 joined_inputs , [model_logp , model_dlogp ] = radon_model
15401523 rng = np .random .default_rng (1 )
@@ -1547,11 +1530,21 @@ def test_radon_model_variants_compile_benchmark(
15471530 fn (x )
15481531
15491532 def compile_and_call_once ():
1550- for joined_inputs , [model_logp , model_dlogp ] in radon_model_variants :
1551- fn = function (
1552- [joined_inputs ], [model_logp , model_dlogp ], mode = mode , trust_input = True
1553- )
1554- fn (x )
1533+ for centered in (True , False ):
1534+ for intercept_dist in ("normal" , "lognormal" ):
1535+ for sigma_dist in ("halfnormal" , "lognormal" ):
1536+ joined_inputs , [model_logp , model_dlogp ] = create_radon_model (
1537+ intercept_dist = intercept_dist ,
1538+ sigma_dist = sigma_dist ,
1539+ centered = centered ,
1540+ )
1541+ fn = function (
1542+ [joined_inputs ],
1543+ [model_logp , model_dlogp ],
1544+ mode = mode ,
1545+ trust_input = True ,
1546+ )
1547+ fn (x )
15551548
15561549 benchmark .pedantic (compile_and_call_once , rounds = 1 , iterations = 1 )
15571550
0 commit comments