4 changes: 4 additions & 0 deletions pytensor/compile/debugmode.py
@@ -1331,7 +1331,11 @@ def printstuff(self):
# the external requirements of the .linker attribute of a mode
# 1) it's a class instance
# 2) it has a .clone() method
# 3) it has required_rewrites and incompatible_rewrites class attributes
class _DummyLinker:
required_rewrites = ()
incompatible_rewrites = ()

# This is not a real linker anyway
def clone(self, allow_gc=None):
return self
60 changes: 13 additions & 47 deletions pytensor/compile/mode.py
@@ -352,7 +352,14 @@ def __setstate__(self, state):
if isinstance(optimizer, str) or optimizer is None:
optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, RewriteDatabaseQuery):
# TODO: From the __init__ signature this should always be the case,
# but some tests and internal logic allow passing a GraphRewriter directly as the optimizer.
# Clean this up!
self.provided_optimizer = optimizer
if r := linker.required_rewrites:
optimizer = optimizer.including(*r)
if r := linker.incompatible_rewrites:
optimizer = optimizer.excluding(*r)
self._optimizer = optimizer
self.call_time = 0
self.fn_time = 0
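
For readers unfamiliar with the rewrite-query API used above, the two walrus-guarded branches fold the linker's declarations into the provided query via `including`/`excluding`. A standalone sketch of that composition (illustrative only; the tag names are examples, not tied to any particular linker):

from pytensor.graph.rewriting.db import RewriteDatabaseQuery

# Start from a user-provided query, e.g. one built from the "fast_run" tag.
query = RewriteDatabaseQuery(include=["fast_run"])

# Tags the linker says it needs are added to the include set ...
query = query.including("minimum_compile", "numba")
# ... and tags it cannot handle are added to the exclude set.
query = query.excluding("cxx", "BlasOpt")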
@@ -365,14 +372,13 @@ def __str__(self):
f"optdb={self.optdb})"
)

def __get_optimizer(self):
@property
def optimizer(self):
if isinstance(self._optimizer, RewriteDatabaseQuery):
return self.optdb.query(self._optimizer)
else:
return self._optimizer

optimizer = property(__get_optimizer)

def get_linker_optimizer(self, linker, optimizer):
if isinstance(linker, str) or linker is None:
linker = predefined_linkers[linker]
@@ -466,61 +472,21 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):

NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run", "numba"],
exclude=[
"cxx_only",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
],
),
RewriteDatabaseQuery(include=["fast_run", "numba"]),
)

JAX = Mode(
JAXLinker(),
RewriteDatabaseQuery(
include=["fast_run", "jax"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
# There are specific variants for the LU decompositions supported by JAX
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
],
),
RewriteDatabaseQuery(include=["fast_run", "jax"]),
)
PYTORCH = Mode(
PytorchLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
],
),
RewriteDatabaseQuery(include=["fast_run"]),
)

MLX = Mode(
MLXLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
],
),
RewriteDatabaseQuery(include=["fast_run"]),
)
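
With the per-mode exclude lists gone, the effective query for each backend mode is assembled from the linker's declarations at Mode construction time. A quick way to check this (a sketch, assuming the attribute names shown in this diff; exact contents depend on the installed version):

from pytensor.compile.mode import NUMBA

# The linker carries the backend-specific rewrite policy ...
print(NUMBA.linker.required_rewrites)      # ("minimum_compile", "numba")
print(NUMBA.linker.incompatible_rewrites)  # ("cxx", "BlasOpt", ...)
# ... and the mode's stored query has already been combined with it.
print(NUMBA._optimizer.include, NUMBA._optimizer.exclude)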


51 changes: 37 additions & 14 deletions pytensor/link/basic.py
@@ -157,6 +157,9 @@ class Linker(ABC):
the FunctionGraph.
"""

required_rewrites: tuple[str, ...] = ("minimum_compile",)
incompatible_rewrites: tuple[str, ...] = ()

def __init__(
self,
*,
@@ -656,21 +659,37 @@ def create_jitable_thunk(
thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
fgraph_jit = self.jit_compile(converted_fgraph)

def thunk(
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
try:
outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
except Exception:
# TODO: Should we add a fake node that combines all outputs,
# since the error may come from any of them?
raise_with_op(self.fgraph, output_nodes[0], thunk)
if thunk_outputs:

# zip strict not specified because we are in a hot loop
for o_storage, o_val in zip(thunk_outputs, outputs):
o_storage[0] = o_val
def thunk(
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
try:
outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
except Exception:
# TODO: Should we add a fake node that combines all outputs,
# since the error may come from any of them?
raise_with_op(self.fgraph, output_nodes[0], thunk)

# zip strict not specified because we are in a hot loop
for o_storage, o_val in zip(thunk_outputs, outputs):
o_storage[0] = o_val

else:
# Edge case - functions without outputs
def thunk(
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
try:
res = fgraph_jit(*(x[0] for x in thunk_inputs))
except Exception:
raise_with_op(self.fgraph, output_nodes[0], thunk)
assert res is None
return thunk_outputs

thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs
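
The new else branch handles the edge case of compiled functions that have no outputs at all (for example, update-only functions), where the jitted callable returns None. A minimal way to exercise it (a sketch, assuming a JIT backend such as Numba is installed):

import pytensor

x = pytensor.shared(0.0, name="x")
# No outputs, only an update: the jitted fgraph returns None, which the
# dedicated no-output thunk asserts before returning the (empty) storage.
f = pytensor.function([], [], updates={x: x + 1.0}, mode="NUMBA")
f()
print(x.get_value())  # 1.0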
@@ -714,3 +733,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
thunks,
nodes,
)

def __repr__(self):
# Assumes no subclass needs init arguments
return f"{self.__class__.__name__}()"
16 changes: 16 additions & 0 deletions pytensor/link/jax/linker.py
@@ -9,6 +9,22 @@
class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""

required_rewrites = (
"minimum_compile",
"jax",
) # TODO: Distinguish between optional "jax" and "minimum_compile_jax"
incompatible_rewrites = (
"cxx",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
# JAX does its own inplace optimization
"inplace",
# There are specific variants for the LU decompositions supported by JAX
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
)

scalar_shape_inputs: tuple[int, ...]

def __init__(self, *args, **kwargs):
8 changes: 8 additions & 0 deletions pytensor/link/mlx/linker.py
@@ -4,6 +4,14 @@
class MLXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX."""

incompatible_rewrites = (
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
)

def __init__(self, use_compile=True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gen_functors = []
11 changes: 11 additions & 0 deletions pytensor/link/numba/linker.py
@@ -2,6 +2,17 @@


class NumbaLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""

required_rewrites = (
"minimum_compile",
"numba",
) # TODO: Distinguish between optional "numba" and "minimum_compile_numba"
incompatible_rewrites = (
"cxx",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
)

def fgraph_convert(self, fgraph, **kwargs):
10 changes: 10 additions & 0 deletions pytensor/link/pytorch/linker.py
@@ -5,6 +5,16 @@
class PytorchLinker(JITLinker):
"""A `Linker` that compiles NumPy-based operations using torch.compile."""

incompatible_rewrites = (
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gen_functors = []
10 changes: 7 additions & 3 deletions tests/compile/test_mode.py
@@ -55,11 +55,15 @@ def test_NoOutputFromInplace():


def test_including():
mode = Mode(optimizer="merge")
assert set(mode._optimizer.include) == {"merge"}
mode = Mode(linker="py", optimizer="merge")
assert set(mode._optimizer.include) == {"minimum_compile", "merge"}

new_mode = mode.including("fast_compile")
assert set(new_mode._optimizer.include) == {"merge", "fast_compile"}
assert set(new_mode._optimizer.include) == {
"minimum_compile",
"merge",
"fast_compile",
}


class TestBunchOfModes: