66from collections import defaultdict , deque
77from collections .abc import Generator , Sequence
88from functools import cache , reduce
9+ from operator import or_
910from typing import Literal
1011from warnings import warn
1112
2930)
3031from pytensor .graph .rewriting .db import SequenceDB
3132from pytensor .graph .rewriting .unify import OpPattern
32- from pytensor .graph .traversal import ancestors , toposort
33+ from pytensor .graph .traversal import toposort
3334from pytensor .graph .utils import InconsistencyError , MethodNotDefined
3435from pytensor .scalar .math import Grad2F1Loop , _grad_2f1_loop
3536from pytensor .tensor .basic import (
@@ -663,16 +664,9 @@ def find_fuseable_subgraph(
663664 visited_nodes : set [Apply ],
664665 fuseable_clients : FUSEABLE_MAPPING ,
665666 unfuseable_clients : UNFUSEABLE_MAPPING ,
667+ ancestors_bitset : dict [Apply , int ],
666668 toposort_index : dict [Apply , int ],
667669 ) -> tuple [list [Variable ], list [Variable ]]:
668- def variables_depend_on (
669- variables , depend_on , stop_search_at = None
670- ) -> bool :
671- return any (
672- a in depend_on
673- for a in ancestors (variables , blockers = stop_search_at )
674- )
675-
676670 for starting_node in toposort_index :
677671 if starting_node in visited_nodes :
678672 continue
@@ -684,7 +678,8 @@ def variables_depend_on(
684678
685679 subgraph_inputs : dict [Variable , Literal [None ]] = {} # ordered set
686680 subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
687- unfuseable_clients_subgraph : set [Variable ] = set ()
681+ subgraph_inputs_ancestors_bitset = 0
682+ unfuseable_clients_subgraph_bitset = 0
688683
689684 # If we need to manipulate the maps in place, we'll do a shallow copy later
690685 # For now we query on the original ones
@@ -716,50 +711,32 @@ def variables_depend_on(
716711 if must_become_output :
717712 subgraph_outputs .pop (next_out , None )
718713
719- required_unfuseable_inputs = [
720- inp
721- for inp in next_node .inputs
722- if next_node in unfuseable_clients_clone .get (inp )
723- ]
724- new_required_unfuseable_inputs = [
725- inp
726- for inp in required_unfuseable_inputs
727- if inp not in subgraph_inputs
728- ]
729-
730- must_backtrack = False
731- if new_required_unfuseable_inputs and subgraph_outputs :
732- # We need to check that any new inputs required by this node
733- # do not depend on other outputs of the current subgraph,
734- # via an unfuseable path.
735- if variables_depend_on (
736- [next_out ],
737- depend_on = unfuseable_clients_subgraph ,
738- stop_search_at = subgraph_outputs ,
739- ):
740- must_backtrack = True
714+ # We need to check that any inputs required by this node
715+ # do not depend on other outputs of the current subgraph,
716+ # via an unfuseable path.
717+ must_backtrack = (
718+ ancestors_bitset [next_node ]
719+ & unfuseable_clients_subgraph_bitset
720+ )
741721
742722 if not must_backtrack :
743- implied_unfuseable_clients = {
744- c
745- for client in unfuseable_clients_clone .get (next_out )
746- if not isinstance (client .op , Output )
747- for c in client .outputs
748- }
749-
750- new_implied_unfuseable_clients = (
751- implied_unfuseable_clients - unfuseable_clients_subgraph
723+ implied_unfuseable_clients_bitset = reduce (
724+ or_ ,
725+ (
726+ 1 << toposort_index [client ]
727+ for client in unfuseable_clients_clone .get (next_out )
728+ if not isinstance (client .op , Output )
729+ ),
730+ 0 ,
752731 )
753732
754- if new_implied_unfuseable_clients and subgraph_inputs :
755- # We need to check that any inputs of the current subgraph
756- # do not depend on other clients of this node,
757- # via an unfuseable path.
758- if variables_depend_on (
759- subgraph_inputs ,
760- depend_on = new_implied_unfuseable_clients ,
761- ):
762- must_backtrack = True
733+ # We need to check that any inputs of the current subgraph
734+ # do not depend on other clients of this node,
735+ # via an unfuseable path.
736+ must_backtrack = (
737+ subgraph_inputs_ancestors_bitset
738+ & implied_unfuseable_clients_bitset
739+ )
763740
764741 if must_backtrack :
765742 for inp in next_node .inputs :
@@ -800,29 +777,24 @@ def variables_depend_on(
800777 # immediate dependency problems. Update subgraph
801778 # mappings as if it next_node was part of it.
802779 # Useless inputs will be removed by the useless Composite rewrite
803- for inp in new_required_unfuseable_inputs :
804- subgraph_inputs [inp ] = None
805-
806780 if must_become_output :
807781 subgraph_outputs [next_out ] = None
808- unfuseable_clients_subgraph . update (
809- new_implied_unfuseable_clients
782+ unfuseable_clients_subgraph_bitset |= (
783+ implied_unfuseable_clients_bitset
810784 )
811785
812- # Expand through unvisited fuseable ancestors
813- fuseable_nodes_to_visit .extendleft (
814- sorted (
815- (
816- inp .owner
817- for inp in next_node .inputs
818- if (
819- inp not in required_unfuseable_inputs
820- and inp .owner not in visited_nodes
821- )
822- ),
823- key = toposort_index .get , # type: ignore[arg-type]
824- )
825- )
786+ for inp in sorted (
787+ next_node .inputs ,
788+ key = lambda x : toposort_index .get (x .owner , - 1 ),
789+ ):
790+ if next_node in unfuseable_clients_clone .get (inp , ()):
791+ # input must become an input of the subgraph since it's unfuseable with new node
792+ subgraph_inputs_ancestors_bitset |= (
793+ ancestors_bitset .get (inp .owner , 0 )
794+ )
795+ subgraph_inputs [inp ] = None
796+ elif inp .owner not in visited_nodes :
797+ fuseable_nodes_to_visit .appendleft (inp .owner )
826798
827799 # Expand through unvisited fuseable clients
828800 fuseable_nodes_to_visit .extend (
@@ -859,6 +831,8 @@ def update_fuseable_mappings_after_fg_replace(
859831 visited_nodes : set [Apply ],
860832 fuseable_clients : FUSEABLE_MAPPING ,
861833 unfuseable_clients : UNFUSEABLE_MAPPING ,
834+ toposort_index : dict [Apply , int ],
835+ ancestors_bitset : dict [Apply , int ],
862836 starting_nodes : set [Apply ],
863837 updated_nodes : set [Apply ],
864838 ) -> None :
@@ -869,11 +843,25 @@ def update_fuseable_mappings_after_fg_replace(
869843 dropped_nodes = starting_nodes - updated_nodes
870844
871845 # Remove intermediate Composite nodes from mappings
846+ # And compute the ancestors bitset of the new composite node
847+ # As well as the new toposort index for the new node
848+ new_node_ancestor_bitset = 0
849+ new_node_toposort_index = len (toposort_index )
872850 for dropped_node in dropped_nodes :
873851 (dropped_out ,) = dropped_node .outputs
874852 fuseable_clients .pop (dropped_out , None )
875853 unfuseable_clients .pop (dropped_out , None )
876854 visited_nodes .remove (dropped_node )
855+ # The new composite ancestor bitset is the union
856+ # of the ancestors of all the dropped nodes
857+ new_node_ancestor_bitset |= ancestors_bitset [dropped_node ]
858+ # The new composite node can have the same order as the latest node that was absorbed into it
859+ new_node_toposort_index = max (
860+ new_node_toposort_index , toposort_index [dropped_node ]
861+ )
862+
863+ ancestors_bitset [new_composite_node ] = new_node_ancestor_bitset
864+ toposort_index [new_composite_node ] = new_node_toposort_index
877865
878866 # Update fuseable information for subgraph inputs
879867 for inp in subgraph_inputs :
@@ -905,12 +893,23 @@ def update_fuseable_mappings_after_fg_replace(
905893 fuseable_clients , unfuseable_clients = initialize_fuseable_mappings (fg = fg )
906894 visited_nodes : set [Apply ] = set ()
907895 toposort_index = {node : i for i , node in enumerate (fgraph .toposort ())}
896+ # Create a bitset for each node of all its ancestors
897+ # This allows to quickly check if a variable depends on a set
898+ ancestors_bitset : dict [Apply , int ] = {}
899+ for node , index in toposort_index .items ():
900+ node_ancestor_bitset = 1 << index
901+ for inp in node .inputs :
902+ if (inp_node := inp .owner ) is not None :
903+ node_ancestor_bitset |= ancestors_bitset [inp_node ]
904+ ancestors_bitset [node ] = node_ancestor_bitset
905+
908906 while True :
909907 try :
910908 subgraph_inputs , subgraph_outputs = find_fuseable_subgraph (
911909 visited_nodes = visited_nodes ,
912910 fuseable_clients = fuseable_clients ,
913911 unfuseable_clients = unfuseable_clients ,
912+ ancestors_bitset = ancestors_bitset ,
914913 toposort_index = toposort_index ,
915914 )
916915 except ValueError :
@@ -929,6 +928,8 @@ def update_fuseable_mappings_after_fg_replace(
929928 visited_nodes = visited_nodes ,
930929 fuseable_clients = fuseable_clients ,
931930 unfuseable_clients = unfuseable_clients ,
931+ toposort_index = toposort_index ,
932+ ancestors_bitset = ancestors_bitset ,
932933 starting_nodes = starting_nodes ,
933934 updated_nodes = fg .apply_nodes ,
934935 )
0 commit comments