2020 applys_between ,
2121 graph_inputs ,
2222 io_toposort ,
23+ toposort ,
2324 vars_between ,
2425)
2526from pytensor .graph .utils import MetaObject , MissingInputError , TestValueError
26- from pytensor .misc .ordered_set import OrderedSet
2727
2828
2929ClientType = tuple [Apply , int ]
@@ -132,7 +132,6 @@ def __init__(
132132 features = []
133133
134134 self ._features : list [Feature ] = []
135-
136135 # All apply nodes in the subgraph defined by inputs and
137136 # outputs are cached in this field
138137 self .apply_nodes : set [Apply ] = set ()
@@ -160,7 +159,8 @@ def __init__(
160159 "input's owner or use graph.clone."
161160 )
162161
163- self .add_input (in_var , check = False )
162+ self .inputs .append (in_var )
163+ self .clients .setdefault (in_var , [])
164164
165165 for output in outputs :
166166 self .add_output (output , reason = "init" )
@@ -188,16 +188,6 @@ def add_input(self, var: Variable, check: bool = True) -> None:
188188 return
189189
190190 self .inputs .append (var )
191- self .setup_var (var )
192-
193- def setup_var (self , var : Variable ) -> None :
194- """Set up a variable so it belongs to this `FunctionGraph`.
195-
196- Parameters
197- ----------
198- var : pytensor.graph.basic.Variable
199-
200- """
201191 self .clients .setdefault (var , [])
202192
203193 def get_clients (self , var : Variable ) -> list [ClientType ]:
@@ -321,10 +311,11 @@ def import_var(
321311
322312 """
323313 # Imports the owners of the variables
324- if var .owner and var .owner not in self .apply_nodes :
325- self .import_node (var .owner , reason = reason , import_missing = import_missing )
314+ apply = var .owner
315+ if apply is not None and apply not in self .apply_nodes :
316+ self .import_node (apply , reason = reason , import_missing = import_missing )
326317 elif (
327- var . owner is None
318+ apply is None
328319 and not isinstance (var , AtomicVariable )
329320 and var not in self .inputs
330321 ):
@@ -335,10 +326,11 @@ def import_var(
335326 f"Computation graph contains a NaN. { var .type .why_null } "
336327 )
337328 if import_missing :
338- self .add_input (var )
329+ self .inputs .append (var )
330+ self .clients .setdefault (var , [])
339331 else :
340332 raise MissingInputError (f"Undeclared input: { var } " , variable = var )
341- self .setup_var (var )
333+ self .clients . setdefault (var , [] )
342334 self .variables .add (var )
343335
344336 def import_node (
@@ -355,29 +347,29 @@ def import_node(
355347 apply_node : Apply
356348 The node to be imported.
357349 check : bool
358- Check that the inputs for the imported nodes are also present in
359- the `FunctionGraph`.
350+ Check that the inputs for the imported nodes are also present in the `FunctionGraph`.
360351 reason : str
361352 The name of the optimization or operation in progress.
362353 import_missing : bool
363354 Add missing inputs instead of raising an exception.
364355 """
365356 # We import the nodes in topological order. We only are interested in
366- # new nodes, so we use all variables we know of as if they were the
367- # input set. (The functions in the graph module only use the input set
368- # to know where to stop going down.)
369- new_nodes = io_toposort ( self .variables , apply_node . outputs )
370-
371- if check :
372- for node in new_nodes :
357+ # new nodes, so we use all nodes we know of as inputs to interrupt the toposort
358+ self_variables = self . variables
359+ self_clients = self . clients
360+ self_apply_nodes = self .apply_nodes
361+ self_inputs = self . inputs
362+ for node in toposort ( apply_node . outputs , blockers = self_variables ) :
363+ if check :
373364 for var in node .inputs :
374365 if (
375366 var .owner is None
376367 and not isinstance (var , AtomicVariable )
377- and var not in self . inputs
368+ and var not in self_inputs
378369 ):
379370 if import_missing :
380- self .add_input (var )
371+ self_inputs .append (var )
372+ self_clients .setdefault (var , [])
381373 else :
382374 error_msg = (
383375 f"Input { node .inputs .index (var )} ({ var } )"
@@ -389,20 +381,20 @@ def import_node(
389381 )
390382 raise MissingInputError (error_msg , variable = var )
391383
392- for node in new_nodes :
393- assert node not in self . apply_nodes
394- self . apply_nodes . add ( node )
395- if not hasattr ( node . tag , " imported_by" ):
396- node . tag . imported_by = []
397- node . tag .imported_by .append (str (reason ))
384+ self_apply_nodes . add ( node )
385+ tag = node . tag
386+ if not hasattr ( tag , "imported_by" ):
387+ tag . imported_by = [ str ( reason )]
388+ else :
389+ tag .imported_by .append (str (reason ))
398390 for output in node .outputs :
399- self . setup_var (output )
400- self . variables .add (output )
401- for i , input in enumerate (node .inputs ):
402- if input not in self . variables :
403- self . setup_var ( input )
404- self . variables . add (input )
405- self . add_client ( input , (node , i ))
391+ self_clients . setdefault (output , [] )
392+ self_variables .add (output )
393+ for i , inp in enumerate (node .inputs ):
394+ if inp not in self_variables :
395+ self_clients . setdefault ( inp , [] )
396+ self_variables . add (inp )
397+ self_clients [ inp ]. append ( (node , i ))
406398 self .execute_callbacks ("on_import" , node , reason )
407399
408400 def change_node_input (
@@ -456,7 +448,7 @@ def change_node_input(
456448 self .outputs [node .op .idx ] = new_var
457449
458450 self .import_var (new_var , reason = reason , import_missing = import_missing )
459- self .add_client ( new_var , (node , i ))
451+ self .clients [ new_var ]. append ( (node , i ))
460452 self .remove_client (r , (node , i ), reason = reason )
461453 # Precondition: the substitution is semantically valid However it may
462454 # introduce cycles to the graph, in which case the transaction will be
@@ -755,11 +747,7 @@ def toposort(self) -> list[Apply]:
755747 :meth:`FunctionGraph.orderings`.
756748
757749 """
758- if len (self .apply_nodes ) < 2 :
759- # No sorting is necessary
760- return list (self .apply_nodes )
761-
762- return io_toposort (self .inputs , self .outputs , self .orderings ())
750+ return io_toposort (self .inputs , self .outputs , orderings = self .orderings ())
763751
764752 def orderings (self ) -> dict [Apply , list [Apply ]]:
765753 """Return a map of node to node evaluation dependencies.
@@ -778,29 +766,17 @@ def orderings(self) -> dict[Apply, list[Apply]]:
778766 take care of computing the dependencies by itself.
779767
780768 """
781- assert isinstance (self ._features , list )
782- all_orderings : list [dict ] = []
769+ all_orderings : list [dict ] = [
770+ orderings
771+ for feature in self ._features
772+ if (
773+ hasattr (feature , "orderings" ) and (orderings := feature .orderings (self ))
774+ )
775+ ]
783776
784- for feature in self ._features :
785- if hasattr (feature , "orderings" ):
786- orderings = feature .orderings (self )
787- if not isinstance (orderings , dict ):
788- raise TypeError (
789- "Non-deterministic return value from "
790- + str (feature .orderings )
791- + ". Nondeterministic object is "
792- + str (orderings )
793- )
794- if len (orderings ) > 0 :
795- all_orderings .append (orderings )
796- for node , prereqs in orderings .items ():
797- if not isinstance (prereqs , list | OrderedSet ):
798- raise TypeError (
799- "prereqs must be a type with a "
800- "deterministic iteration order, or toposort "
801- " will be non-deterministic."
802- )
803- if len (all_orderings ) == 1 :
777+ if not all_orderings :
778+ return {}
779+ elif len (all_orderings ) == 1 :
804780 # If there is only 1 ordering, we reuse it directly.
805781 return all_orderings [0 ].copy ()
806782 else :
0 commit comments