2828from pytensor .graph .rewriting .basic import (
2929 GraphRewriter ,
3030 copy_stack_trace ,
31- in2out ,
31+ dfs_rewriter ,
3232 node_rewriter ,
3333)
3434from pytensor .graph .rewriting .db import EquilibriumDB , SequenceDB
@@ -2548,15 +2548,15 @@ def scan_push_out_dot1(fgraph, node):
25482548# ScanSaveMem should execute only once per node.
25492549optdb .register (
25502550 "scan_save_mem_prealloc" ,
2551- in2out (scan_save_mem_prealloc , ignore_newtrees = True ),
2551+ dfs_rewriter (scan_save_mem_prealloc , ignore_newtrees = True ),
25522552 "fast_run" ,
25532553 "scan" ,
25542554 "scan_save_mem" ,
25552555 position = 1.61 ,
25562556)
25572557optdb .register (
25582558 "scan_save_mem_no_prealloc" ,
2559- in2out (scan_save_mem_no_prealloc , ignore_newtrees = True ),
2559+ dfs_rewriter (scan_save_mem_no_prealloc , ignore_newtrees = True ),
25602560 "numba" ,
25612561 "jax" ,
25622562 "pytorch" ,
@@ -2577,7 +2577,7 @@ def scan_push_out_dot1(fgraph, node):
25772577
25782578scan_seqopt1 .register (
25792579 "scan_remove_constants_and_unused_inputs0" ,
2580- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2580+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
25812581 "remove_constants_and_unused_inputs_scan" ,
25822582 "fast_run" ,
25832583 "scan" ,
@@ -2586,7 +2586,7 @@ def scan_push_out_dot1(fgraph, node):
25862586
25872587scan_seqopt1 .register (
25882588 "scan_push_out_non_seq" ,
2589- in2out (scan_push_out_non_seq , ignore_newtrees = True ),
2589+ dfs_rewriter (scan_push_out_non_seq , ignore_newtrees = True ),
25902590 "scan_pushout_nonseqs_ops" , # For backcompat: so it can be tagged with old name
25912591 "fast_run" ,
25922592 "scan" ,
@@ -2596,7 +2596,7 @@ def scan_push_out_dot1(fgraph, node):
25962596
25972597scan_seqopt1 .register (
25982598 "scan_push_out_seq" ,
2599- in2out (scan_push_out_seq , ignore_newtrees = True ),
2599+ dfs_rewriter (scan_push_out_seq , ignore_newtrees = True ),
26002600 "scan_pushout_seqs_ops" , # For backcompat: so it can be tagged with old name
26012601 "fast_run" ,
26022602 "scan" ,
@@ -2607,7 +2607,7 @@ def scan_push_out_dot1(fgraph, node):
26072607
26082608scan_seqopt1 .register (
26092609 "scan_push_out_dot1" ,
2610- in2out (scan_push_out_dot1 , ignore_newtrees = True ),
2610+ dfs_rewriter (scan_push_out_dot1 , ignore_newtrees = True ),
26112611 "scan_pushout_dot1" , # For backcompat: so it can be tagged with old name
26122612 "fast_run" ,
26132613 "more_mem" ,
@@ -2620,7 +2620,7 @@ def scan_push_out_dot1(fgraph, node):
26202620scan_seqopt1 .register (
26212621 "scan_push_out_add" ,
26222622 # TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
2623- in2out (scan_push_out_add , ignore_newtrees = False ),
2623+ dfs_rewriter (scan_push_out_add , ignore_newtrees = False ),
26242624 "scan_pushout_add" , # For backcompat: so it can be tagged with old name
26252625 "fast_run" ,
26262626 "more_mem" ,
@@ -2631,22 +2631,22 @@ def scan_push_out_dot1(fgraph, node):
26312631
26322632scan_eqopt2 .register (
26332633 "while_scan_merge_subtensor_last_element" ,
2634- in2out (while_scan_merge_subtensor_last_element , ignore_newtrees = True ),
2634+ dfs_rewriter (while_scan_merge_subtensor_last_element , ignore_newtrees = True ),
26352635 "fast_run" ,
26362636 "scan" ,
26372637)
26382638
26392639scan_eqopt2 .register (
26402640 "constant_folding_for_scan2" ,
2641- in2out (constant_folding , ignore_newtrees = True ),
2641+ dfs_rewriter (constant_folding , ignore_newtrees = True ),
26422642 "fast_run" ,
26432643 "scan" ,
26442644)
26452645
26462646
26472647scan_eqopt2 .register (
26482648 "scan_remove_constants_and_unused_inputs1" ,
2649- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2649+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
26502650 "remove_constants_and_unused_inputs_scan" ,
26512651 "fast_run" ,
26522652 "scan" ,
@@ -2661,23 +2661,23 @@ def scan_push_out_dot1(fgraph, node):
26612661# After Merge optimization
26622662scan_eqopt2 .register (
26632663 "scan_remove_constants_and_unused_inputs2" ,
2664- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2664+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
26652665 "remove_constants_and_unused_inputs_scan" ,
26662666 "fast_run" ,
26672667 "scan" ,
26682668)
26692669
26702670scan_eqopt2 .register (
26712671 "scan_merge_inouts" ,
2672- in2out (scan_merge_inouts , ignore_newtrees = True ),
2672+ dfs_rewriter (scan_merge_inouts , ignore_newtrees = True ),
26732673 "fast_run" ,
26742674 "scan" ,
26752675)
26762676
26772677# After everything else
26782678scan_eqopt2 .register (
26792679 "scan_remove_constants_and_unused_inputs3" ,
2680- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2680+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
26812681 "remove_constants_and_unused_inputs_scan" ,
26822682 "fast_run" ,
26832683 "scan" ,
0 commit comments