Merged
152 changes: 77 additions & 75 deletions pytensor/compile/function/types.py
@@ -393,6 +393,8 @@ def __init__(
assert len(self.input_storage) == len(self.maker.fgraph.inputs)
assert len(self.output_storage) == len(self.maker.fgraph.outputs)

self.has_defaults = any(refeed for _, refeed, _ in self.defaults)

# Group indexes of inputs that are potentially aliased to each other
# Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
# even though there could be two distinct types that use the same kinds of underlying objects.
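[Editor's note] The `has_defaults` flag above is precomputed at construction time so that `__call__` can skip `_restore_defaults()` with a single attribute test when no input needs to be re-fed. A minimal sketch of the pattern, assuming the `(required, refeed, value)` triple layout implied by the comprehension (names hypothetical):

```python
class FunctionSketch:
    def __init__(self, defaults):
        # `defaults`: one (required, refeed, value) triple per input
        # (layout assumed from `any(refeed for _, refeed, _ in self.defaults)` above).
        self.defaults = defaults
        # Computed once here, instead of being re-derived on every call.
        self.has_defaults = any(refeed for _, refeed, _ in defaults)

    def _restore_defaults(self):
        for _, refeed, value in self.defaults:
            if refeed:
                print("re-feeding default:", value)

    def __call__(self):
        # ... run the compiled graph ...
        if self.has_defaults:  # single flag test on the hot path
            self._restore_defaults()

FunctionSketch([(False, True, 0.0)])()  # prints: re-feeding default: 0.0
```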
@@ -540,14 +542,40 @@ def __contains__(self, item):
self._value = ValueAttribute()
self._container = ContainerAttribute()

# TODO: Get rid of all this `expanded_inputs` nonsense
assert len(self.maker.expanded_inputs) == len(self.input_storage)
update_storage = [
container
for inp, container in zip(
self.maker.expanded_inputs, input_storage, strict=True
)
if inp.update is not None
]
# Updates are the last inner outputs that are not returned by Function.__call__
self.n_returned_outputs = len(self.output_storage) - len(update_storage)

# Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
self.update_input_storage: tuple[tuple[int, Container], ...] = ()
if getattr(vm, "need_update_inputs", True):
self.update_input_storage = tuple(
zip(
range(self.n_returned_outputs, len(output_storage)),
update_storage,
strict=True,
)
)

# This is used only when `vm.need_update_inputs` is `False`, because
# we're using one of the VM objects and it is putting updates back into
# the input containers all by itself.
self.n_returned_outputs = len(self.output_storage) - sum(
inp.update is not None for inp in self.maker.expanded_inputs
)
# In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage
# After the call, we want to erase (some of) these references, to allow Python to GC them if unused
# Required input containers are the non-default inputs; they must always be provided again, so we can GC them
self.clear_input_storage_data = tuple(
container.storage for container in input_storage if container.required
)
# This is only done when `vm.allow_gc` is True, which can change at runtime.
self.clear_output_storage_data = tuple(
container.storage
for container, variable in zip(
self.output_storage, self.maker.fgraph.outputs, strict=True
)
if variable.owner is not None # Not a constant output
)
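[Editor's note] The block above hoists per-call bookkeeping into `__init__`: trailing outputs are paired with the input containers they update, and the storage cells to clear after each call are gathered into flat tuples. A toy sketch of the same pattern, with one-element lists standing in for pytensor's `Container.storage` (all names hypothetical):

```python
# One-element lists stand in for pytensor's Container.storage.
input_storage = [["x_value"], ["old_state"]]
required = [True, False]                  # update inputs (shared state) are not "required"
n_outputs, n_updates = 4, 1

n_returned = n_outputs - n_updates
# Pair each trailing output index with the input cell it feeds back into.
update_pairs = tuple(
    zip(range(n_returned, n_outputs), input_storage[-n_updates:], strict=True)
)
# Only required inputs are cleared; they must be provided again anyway.
clear_cells = tuple(c for c, r in zip(input_storage, required, strict=True) if r)

def call(raw_outputs):
    for out_idx, in_cell in update_pairs:
        in_cell[0] = raw_outputs[out_idx]  # feed the update back in
    returned = raw_outputs[:n_returned]    # user-visible outputs only
    for cell in clear_cells:
        cell[0] = None                     # release references for GC
    return returned

assert call([10, 20, 30, "new_state"]) == [10, 20, 30]
assert input_storage == [[None], ["new_state"]]
```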

for node in self.maker.fgraph.apply_nodes:
@@ -747,7 +775,7 @@ def checkSV(sv_ori, sv_rpl):
elif isinstance(profile, str):
profile = pytensor.compile.profiling.ProfileStats(message=profile)

f_cpy = maker.__class__(
f_cpy = type(maker)(
inputs=ins,
outputs=outs,
fgraph=fg_cpy,
@@ -765,6 +793,8 @@ def checkSV(sv_ori, sv_rpl):
# check that.
accept_inplace=True,
no_fgraph_prep=True,
output_keys=maker.output_keys,
name=name,
).create(input_storage, storage_map=new_storage_map)

for in_ori, in_cpy, ori, cpy in zip(
@@ -797,8 +827,6 @@ def checkSV(sv_ori, sv_rpl):

f_cpy.trust_input = self.trust_input
f_cpy.unpack_single = self.unpack_single
f_cpy.name = name
f_cpy.maker.fgraph.name = name
return f_cpy
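[Editor's note] Two small points in the hunk above: `type(maker)` is the idiomatic equivalent of `maker.__class__`, and `output_keys`/`name` are now forwarded to the maker constructor instead of being patched onto the copy afterwards. The equivalence, for ordinary classes:

```python
class Maker: ...

m = Maker()
assert type(m) is m.__class__ is Maker  # same class object either way
```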

def _restore_defaults(self):
@@ -808,7 +836,7 @@ def _restore_defaults(self):
value = value.storage[0]
self[i] = value

def __call__(self, *args, **kwargs):
def __call__(self, *args, output_subset=None, **kwargs):
"""
Evaluates value of a function on given arguments.

@@ -836,20 +864,21 @@ def __call__(self, *args, **kwargs):
List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed.
"""
trust_input = self.trust_input
input_storage = self.input_storage
vm = self.vm
profile = self.profile

if profile:
t0 = time.perf_counter()

output_subset = kwargs.pop("output_subset", None)
if output_subset is not None:
warnings.warn("output_subset is deprecated.", FutureWarning)
if self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset]
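[Editor's note] With this change, `output_subset` becomes an explicit keyword that emits a `FutureWarning`, instead of being popped from `**kwargs`. A usage sketch against the public `pytensor.function` API (the exact return shape is my reading of the new code, not verified output):

```python
import warnings

import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
f = pytensor.function([x], [x + 1, x * 2])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = f(3.0, output_subset=[1])     # deprecated keyword: warns
print(result)                              # expected: [array(6.)]
print([str(w.message) for w in caught])    # expected: ['output_subset is deprecated.']
```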

# Reinitialize each container's 'provided' counter
if self.trust_input:
if trust_input:
for arg_container, arg in zip(input_storage, args, strict=False):
arg_container.storage[0] = arg
else:
@@ -908,7 +937,7 @@ def __call__(self, *args, **kwargs):
for k, arg in kwargs.items():
self[k] = arg

if not self.trust_input:
if not trust_input:
# Collect aliased inputs among the storage space
for potential_group in self._potential_aliased_input_groups:
args_share_memory: list[list[int]] = []
@@ -960,11 +989,7 @@ def __call__(self, *args, **kwargs):
if profile:
t0_fn = time.perf_counter()
try:
outputs = (
self.vm()
if output_subset is None
else self.vm(output_subset=output_subset)
)
outputs = vm() if output_subset is None else vm(output_subset=output_subset)
except Exception:
self._restore_defaults()
if hasattr(self.vm, "position_of_error"):
@@ -991,73 +1016,53 @@

# Retrieve the values that were computed
if outputs is None:
outputs = [x.data for x in self.output_storage]

# Remove internal references to required inputs.
# These cannot be re-used anyway.
for arg_container in input_storage:
if arg_container.required:
arg_container.storage[0] = None

# if we are allowing garbage collection, remove the
# output reference from the internal storage cells
if getattr(self.vm, "allow_gc", False):
# strict=False because we are in a hot loop
for o_container, o_variable in zip(
self.output_storage, self.maker.fgraph.outputs, strict=False
):
if o_variable.owner is not None:
# this node is the variable of computation
# WARNING: This circumvents the 'readonly' attribute in x
o_container.storage[0] = None

if getattr(self.vm, "need_update_inputs", True):
# Update the inputs that have an update function
# strict=False because we are in a hot loop
for input, storage in reversed(
list(zip(self.maker.expanded_inputs, input_storage, strict=False))
):
if input.update is not None:
storage.data = outputs.pop()
else:
outputs = outputs[: self.n_returned_outputs]
outputs = [x.storage[0] for x in self.output_storage]

# Set updates and filter them out from the returned outputs
for i, input_storage in self.update_input_storage:
input_storage.storage[0] = outputs[i]
outputs = outputs[: self.n_returned_outputs]

# Remove input and output values from storage data
for storage_data in self.clear_input_storage_data:
storage_data[0] = None
if getattr(vm, "allow_gc", False):
for storage_data in self.clear_output_storage_data:
storage_data[0] = None

# Put default values back in the storage
self._restore_defaults()
if self.has_defaults:
self._restore_defaults()

if profile:
dt_call = time.perf_counter() - t0
pytensor.compile.profiling.total_fct_exec_time += dt_call
self.maker.mode.call_time += dt_call
profile.fct_callcount += 1
profile.fct_call_time += dt_call
if hasattr(self.vm, "update_profile"):
self.vm.update_profile(profile)
if hasattr(vm, "update_profile"):
vm.update_profile(profile)
if profile.ignore_first_call:
profile.reset()
profile.ignore_first_call = False

if self.return_none:
return None
elif self.unpack_single and len(outputs) == 1 and output_subset is None:
return outputs[0]
else:
if self.output_keys is not None:
assert len(self.output_keys) == len(outputs)

if output_subset is None:
# strict=False because we are in a hot loop
return dict(zip(self.output_keys, outputs, strict=False))
else:
return {
self.output_keys[index]: outputs[index]
for index in output_subset
}
if output_subset is not None:
outputs = [outputs[i] for i in output_subset]

if output_subset is None:
return outputs
if self.output_keys is None:
if self.unpack_single:
[out] = outputs
return out
else:
return [outputs[i] for i in output_subset]
return outputs
else:
output_keys = self.output_keys
if output_subset is not None:
output_keys = [output_keys[i] for i in output_subset]
return dict(zip(output_keys, outputs, strict=True))
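[Editor's note] The restructured return path reads as: filter by `output_subset` first, then either unpack a single output, return the list, or zip against `output_keys`. The same logic as a standalone sketch (not the class method itself):

```python
def shape_result(outputs, output_keys=None, output_subset=None, unpack_single=False):
    # Filter first, so the list and dict paths both see the same subset.
    if output_subset is not None:
        outputs = [outputs[i] for i in output_subset]
    if output_keys is None:
        if unpack_single:
            [out] = outputs  # raises if more than one output slipped through
            return out
        return outputs
    keys = output_keys if output_subset is None else [output_keys[i] for i in output_subset]
    return dict(zip(keys, outputs, strict=True))

assert shape_result([1, 2, 3], output_subset=[2, 0]) == [3, 1]
assert shape_result([7], unpack_single=True) == 7
assert shape_result([1, 2], output_keys=["a", "b"], output_subset=[1]) == {"b": 2}
```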

value = property(
lambda self: self._value,
@@ -1077,9 +1082,10 @@ def free(self):
# 1. If the vm has no allow_gc attribute, the getattr default (True) makes the test False: nothing to free.
# 2. If the vm has allow_gc and it is False, the test is True: clear the storage manually.
if not getattr(self.vm, "allow_gc", True):
for key in self.vm.storage_map:
if not isinstance(key, Constant):
self.vm.storage_map[key][0] = None
storage_map = self.vm.storage_map
for key, value in storage_map.items():
if key.owner is not None: # Not a constant
value[0] = None
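[Editor's note] `free` now identifies constants by `owner is None` rather than `isinstance(key, Constant)`. Note that graph root inputs also have no `owner`, so their storage is likewise skipped by the new test. A quick check of the attribute:

```python
import pytensor.tensor as pt

x = pt.vector("x")      # root input: x.owner is None
c = pt.constant([1.0])  # constant:   c.owner is None too
y = x + c               # computed:   y.owner is the Apply node that made it

assert x.owner is None and c.owner is None
assert y.owner is not None
```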

for node in self.nodes_with_inner_function:
if hasattr(node.fn, "free"):
@@ -1091,10 +1097,6 @@ def get_shared(self):
"""
return [i.variable for i in self.maker.inputs if i.implicit]

def sync_shared(self):
# NOTE: sync was needed on old gpu backend
pass

def dprint(self, **kwargs):
"""Debug print itself

52 changes: 10 additions & 42 deletions pytensor/link/basic.py
@@ -653,41 +653,36 @@ def create_jitable_thunk(
)

thunk_inputs = self.create_thunk_inputs(storage_map)

thunks = []

thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]

fgraph_jit = self.jit_compile(converted_fgraph)

def thunk(
fgraph=self.fgraph,
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
try:
outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
except Exception:
# TODO: Should we add a fake node that combines all outputs,
# since the error may come from any of them?
raise_with_op(self.fgraph, output_nodes[0], thunk)

# strict=False because we are in a hot loop
for o_var, o_storage, o_val in zip(
fgraph.outputs, thunk_outputs, outputs, strict=False
):
compute_map[o_var][0] = True
o_storage[0] = self.output_filter(o_var, o_val)
return outputs
for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
o_storage[0] = o_val

thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs
thunk.lazy = False

thunks.append(thunk)
thunks = [thunk]
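[Editor's note] The rewritten thunk is the entire per-call path for a JIT backend: call the compiled function on the current input cells, write results into the output cells, and let `raise_with_op` re-raise failures with graph context. A self-contained sketch of the closure shape (no pytensor types):

```python
def make_thunk(jit_fn, thunk_inputs, thunk_outputs):
    # Inputs/outputs are one-element lists acting as shared storage cells.
    def thunk():
        outputs = jit_fn(*(cell[0] for cell in thunk_inputs))
        for cell, value in zip(thunk_outputs, outputs):
            cell[0] = value  # publish results into the output storage

    thunk.inputs = thunk_inputs
    thunk.outputs = thunk_outputs
    thunk.lazy = False
    return thunk

ins, outs = [[2], [3]], [[None]]
make_thunk(lambda a, b: (a * b,), ins, outs)()
assert outs[0][0] == 6
```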

return thunks, output_nodes, fgraph_jit

def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph
nodes = self.schedule(fgraph)
no_recycling = self.no_recycling

input_storage, output_storage, storage_map = map_storage(
fgraph, nodes, input_storage, output_storage, storage_map
@@ -701,34 +696,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
compute_map, nodes, input_storage, output_storage, storage_map
)

computed, last_user = gc_helper(nodes)

if self.allow_gc:
[Review comment, Member] Is this gc happening somewhere else now? Why was it here if it could just be removed?

[Reply, Member Author] It doesn't do anything for jitted functions, where you don't control intermediate allocations.

post_thunk_old_storage = [
[
storage_map[input]
for input in node.inputs
if (input in computed)
and (input not in fgraph.outputs)
and (node == last_user[input])
]
for node in nodes
]
else:
post_thunk_old_storage = None

if no_recycling is True:
no_recycling = list(storage_map.values())
no_recycling = difference(no_recycling, input_storage)
else:
no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs
]

fn = streamline(
fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
)

[fn] = thunks
fn.jit_fn = jit_fn
fn.allow_gc = self.allow_gc
fn.storage_map = storage_map
8 changes: 1 addition & 7 deletions pytensor/link/jax/linker.py
Expand Up @@ -3,7 +3,6 @@
from numpy.random import Generator, RandomState

from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.graph.basic import Constant
from pytensor.link.basic import JITLinker


@@ -72,12 +71,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
def jit_compile(self, fn):
import jax

# I suppose we can consider `Constant`s to be "static" according to
# JAX.
static_argnums = [
n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
]
return jax.jit(fn, static_argnums=static_argnums)
return jax.jit(fn)
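[Editor's note] For context on what was dropped: `static_argnums` makes JAX treat the listed positional arguments as hashable compile-time constants, retracing whenever their values change; presumably it is unnecessary here because graph `Constant`s are not passed to the jitted function as runtime arguments. A small illustration of the flag itself:

```python
import jax
import jax.numpy as jnp

def scale(x, n):
    return x * n

jit_traced = jax.jit(scale)                       # n is traced like any other argument
jit_static = jax.jit(scale, static_argnums=(1,))  # n is baked in; a new n triggers a retrace

print(jit_traced(jnp.arange(3.0), 2.0))  # [0. 2. 4.]
print(jit_static(jnp.arange(3.0), 2))    # [0. 2. 4.]
```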

def create_thunk_inputs(self, storage_map):
from pytensor.link.jax.dispatch import jax_typify