Skip to content

Commit 97b7f0f

Browse files
committed
Speedup FunctionGraph methods
1 parent 9be6796 commit 97b7f0f

File tree

6 files changed

+140
-154
lines changed

6 files changed

+140
-154
lines changed

pytensor/graph/fg.py

Lines changed: 46 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
applys_between,
2121
graph_inputs,
2222
io_toposort,
23+
toposort,
2324
vars_between,
2425
)
2526
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
26-
from pytensor.misc.ordered_set import OrderedSet
2727

2828

2929
ClientType = tuple[Apply, int]
@@ -132,7 +132,6 @@ def __init__(
132132
features = []
133133

134134
self._features: list[Feature] = []
135-
136135
# All apply nodes in the subgraph defined by inputs and
137136
# outputs are cached in this field
138137
self.apply_nodes: set[Apply] = set()
@@ -160,7 +159,8 @@ def __init__(
160159
"input's owner or use graph.clone."
161160
)
162161

163-
self.add_input(in_var, check=False)
162+
self.inputs.append(in_var)
163+
self.clients.setdefault(in_var, [])
164164

165165
for output in outputs:
166166
self.add_output(output, reason="init")
@@ -188,16 +188,6 @@ def add_input(self, var: Variable, check: bool = True) -> None:
188188
return
189189

190190
self.inputs.append(var)
191-
self.setup_var(var)
192-
193-
def setup_var(self, var: Variable) -> None:
194-
"""Set up a variable so it belongs to this `FunctionGraph`.
195-
196-
Parameters
197-
----------
198-
var : pytensor.graph.basic.Variable
199-
200-
"""
201191
self.clients.setdefault(var, [])
202192

203193
def get_clients(self, var: Variable) -> list[ClientType]:
@@ -321,10 +311,11 @@ def import_var(
321311
322312
"""
323313
# Imports the owners of the variables
324-
if var.owner and var.owner not in self.apply_nodes:
325-
self.import_node(var.owner, reason=reason, import_missing=import_missing)
314+
apply = var.owner
315+
if apply is not None and apply not in self.apply_nodes:
316+
self.import_node(apply, reason=reason, import_missing=import_missing)
326317
elif (
327-
var.owner is None
318+
apply is None
328319
and not isinstance(var, AtomicVariable)
329320
and var not in self.inputs
330321
):
@@ -335,10 +326,11 @@ def import_var(
335326
f"Computation graph contains a NaN. {var.type.why_null}"
336327
)
337328
if import_missing:
338-
self.add_input(var)
329+
self.inputs.append(var)
330+
self.clients.setdefault(var, [])
339331
else:
340332
raise MissingInputError(f"Undeclared input: {var}", variable=var)
341-
self.setup_var(var)
333+
self.clients.setdefault(var, [])
342334
self.variables.add(var)
343335

344336
def import_node(
@@ -355,29 +347,29 @@ def import_node(
355347
apply_node : Apply
356348
The node to be imported.
357349
check : bool
358-
Check that the inputs for the imported nodes are also present in
359-
the `FunctionGraph`.
350+
Check that the inputs for the imported nodes are also present in the `FunctionGraph`.
360351
reason : str
361352
The name of the optimization or operation in progress.
362353
import_missing : bool
363354
Add missing inputs instead of raising an exception.
364355
"""
365356
# We import the nodes in topological order. We are only interested in
366-
# new nodes, so we use all variables we know of as if they were the
367-
# input set. (The functions in the graph module only use the input set
368-
# to know where to stop going down.)
369-
new_nodes = io_toposort(self.variables, apply_node.outputs)
370-
371-
if check:
372-
for node in new_nodes:
357+
# new nodes, so we pass all variables we already know of as blockers to stop the toposort
358+
self_variables = self.variables
359+
self_clients = self.clients
360+
self_apply_nodes = self.apply_nodes
361+
self_inputs = self.inputs
362+
for node in toposort(apply_node.outputs, blockers=self_variables):
363+
if check:
373364
for var in node.inputs:
374365
if (
375366
var.owner is None
376367
and not isinstance(var, AtomicVariable)
377-
and var not in self.inputs
368+
and var not in self_inputs
378369
):
379370
if import_missing:
380-
self.add_input(var)
371+
self_inputs.append(var)
372+
self_clients.setdefault(var, [])
381373
else:
382374
error_msg = (
383375
f"Input {node.inputs.index(var)} ({var})"
@@ -389,20 +381,20 @@ def import_node(
389381
)
390382
raise MissingInputError(error_msg, variable=var)
391383

392-
for node in new_nodes:
393-
assert node not in self.apply_nodes
394-
self.apply_nodes.add(node)
395-
if not hasattr(node.tag, "imported_by"):
396-
node.tag.imported_by = []
397-
node.tag.imported_by.append(str(reason))
384+
self_apply_nodes.add(node)
385+
tag = node.tag
386+
if not hasattr(tag, "imported_by"):
387+
tag.imported_by = [str(reason)]
388+
else:
389+
tag.imported_by.append(str(reason))
398390
for output in node.outputs:
399-
self.setup_var(output)
400-
self.variables.add(output)
401-
for i, input in enumerate(node.inputs):
402-
if input not in self.variables:
403-
self.setup_var(input)
404-
self.variables.add(input)
405-
self.add_client(input, (node, i))
391+
self_clients.setdefault(output, [])
392+
self_variables.add(output)
393+
for i, inp in enumerate(node.inputs):
394+
if inp not in self_variables:
395+
self_clients.setdefault(inp, [])
396+
self_variables.add(inp)
397+
self_clients[inp].append((node, i))
406398
self.execute_callbacks("on_import", node, reason)
407399

408400
def change_node_input(
@@ -456,7 +448,7 @@ def change_node_input(
456448
self.outputs[node.op.idx] = new_var
457449

458450
self.import_var(new_var, reason=reason, import_missing=import_missing)
459-
self.add_client(new_var, (node, i))
451+
self.clients[new_var].append((node, i))
460452
self.remove_client(r, (node, i), reason=reason)
461453
# Precondition: the substitution is semantically valid. However, it may
462454
# introduce cycles to the graph, in which case the transaction will be
@@ -755,11 +747,7 @@ def toposort(self) -> list[Apply]:
755747
:meth:`FunctionGraph.orderings`.
756748
757749
"""
758-
if len(self.apply_nodes) < 2:
759-
# No sorting is necessary
760-
return list(self.apply_nodes)
761-
762-
return io_toposort(self.inputs, self.outputs, self.orderings())
750+
return io_toposort(self.inputs, self.outputs, orderings=self.orderings())
763751

764752
def orderings(self) -> dict[Apply, list[Apply]]:
765753
"""Return a map of node to node evaluation dependencies.
@@ -778,29 +766,17 @@ def orderings(self) -> dict[Apply, list[Apply]]:
778766
take care of computing the dependencies by itself.
779767
780768
"""
781-
assert isinstance(self._features, list)
782-
all_orderings: list[dict] = []
769+
all_orderings: list[dict] = [
770+
orderings
771+
for feature in self._features
772+
if (
773+
hasattr(feature, "orderings") and (orderings := feature.orderings(self))
774+
)
775+
]
783776

784-
for feature in self._features:
785-
if hasattr(feature, "orderings"):
786-
orderings = feature.orderings(self)
787-
if not isinstance(orderings, dict):
788-
raise TypeError(
789-
"Non-deterministic return value from "
790-
+ str(feature.orderings)
791-
+ ". Nondeterministic object is "
792-
+ str(orderings)
793-
)
794-
if len(orderings) > 0:
795-
all_orderings.append(orderings)
796-
for node, prereqs in orderings.items():
797-
if not isinstance(prereqs, list | OrderedSet):
798-
raise TypeError(
799-
"prereqs must be a type with a "
800-
"deterministic iteration order, or toposort "
801-
" will be non-deterministic."
802-
)
803-
if len(all_orderings) == 1:
777+
if not all_orderings:
778+
return {}
779+
elif len(all_orderings) == 1:
804780
# If there is only 1 ordering, we reuse it directly.
805781
return all_orderings[0].copy()
806782
else:

0 commit comments

Comments
 (0)