From ff68d22b53c6711189af686513301ca36740442f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 23 Jan 2025 14:34:49 +0100 Subject: [PATCH] Respect predefined modes in `get_default_mode` Also allow arbitrary capitalization of the modes. Also make linker and optimizer non-mutable config as the mode is cached after using them for the first time. --- pytensor/compile/__init__.py | 1 - pytensor/compile/mode.py | 78 +++++++++++++--------------- pytensor/configdefaults.py | 4 +- tests/compile/function/test_types.py | 8 +-- tests/compile/test_mode.py | 13 +++++ 5 files changed, 54 insertions(+), 50 deletions(-) diff --git a/pytensor/compile/__init__.py b/pytensor/compile/__init__.py index 9bd140d746..f6a95fe163 100644 --- a/pytensor/compile/__init__.py +++ b/pytensor/compile/__init__.py @@ -37,7 +37,6 @@ PrintCurrentFunctionGraph, get_default_mode, get_mode, - instantiated_default_mode, local_useless, optdb, predefined_linkers, diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 152ad3554d..ae905089b5 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -492,7 +492,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): "PYTORCH": PYTORCH, } -instantiated_default_mode = None +_CACHED_RUNTIME_MODES: dict[str, Mode] = {} def get_mode(orig_string): @@ -500,50 +500,46 @@ def get_mode(orig_string): string = config.mode else: string = orig_string + if not isinstance(string, str): return string # it is hopefully already a mode... - global instantiated_default_mode - # The default mode is cached. However, config.mode can change - # If instantiated_default_mode has the right class, use it. - if orig_string is None and instantiated_default_mode: - if string in predefined_modes: - default_mode_class = predefined_modes[string].__class__.__name__ - else: - default_mode_class = string - if instantiated_default_mode.__class__.__name__ == default_mode_class: - return instantiated_default_mode - - if string in ("Mode", "DebugMode", "NanGuardMode"): - if string == "DebugMode": - # need to import later to break circular dependency. - from .debugmode import DebugMode - - # DebugMode use its own linker. - ret = DebugMode(optimizer=config.optimizer) - elif string == "NanGuardMode": - # need to import later to break circular dependency. - from .nanguardmode import NanGuardMode - - # NanGuardMode use its own linker. - ret = NanGuardMode(True, True, True, optimizer=config.optimizer) - else: - # TODO: Can't we look up the name and invoke it rather than using eval here? - ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)") - elif string in predefined_modes: - ret = predefined_modes[string] - else: - raise Exception(f"No predefined mode exist for string: {string}") + # Keep the original string for error messages + upper_string = string.upper() - if orig_string is None: - # Build and cache the default mode - if config.optimizer_excluding: - ret = ret.excluding(*config.optimizer_excluding.split(":")) - if config.optimizer_including: - ret = ret.including(*config.optimizer_including.split(":")) - if config.optimizer_requiring: - ret = ret.requiring(*config.optimizer_requiring.split(":")) - instantiated_default_mode = ret + if upper_string in predefined_modes: + return predefined_modes[upper_string] + + global _CACHED_RUNTIME_MODES + + if upper_string in _CACHED_RUNTIME_MODES: + return _CACHED_RUNTIME_MODES[upper_string] + + # Need to define the mode for the first time + if upper_string == "MODE": + ret = Mode(linker=config.linker, optimizer=config.optimizer) + elif upper_string in ("DEBUGMODE", "DEBUG_MODE"): + from pytensor.compile.debugmode import DebugMode + + # DebugMode use its own linker. + ret = DebugMode(optimizer=config.optimizer) + elif upper_string == "NANGUARDMODE": + from pytensor.compile.nanguardmode import NanGuardMode + + # NanGuardMode use its own linker. + ret = NanGuardMode(True, True, True, optimizer=config.optimizer) + + else: + raise ValueError(f"No predefined mode exist for string: {string}") + + if config.optimizer_excluding: + ret = ret.excluding(*config.optimizer_excluding.split(":")) + if config.optimizer_including: + ret = ret.including(*config.optimizer_including.split(":")) + if config.optimizer_requiring: + ret = ret.requiring(*config.optimizer_requiring.split(":")) + # Cache the mode for next time + _CACHED_RUNTIME_MODES[upper_string] = ret return ret diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index fcc36f0c6f..6000311df7 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -387,7 +387,8 @@ def add_compile_configvars(): config.add( "linker", "Default linker used if the pytensor flags mode is Mode", - EnumStr("cvm", linker_options), + # Not mutable because the default mode is cached after the first use. + EnumStr("cvm", linker_options, mutable=False), in_c_key=False, ) @@ -410,6 +411,7 @@ def add_compile_configvars(): EnumStr( "o4", ["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"], + mutable=False, # Not mutable because the default mode is cached after the first use. ), in_c_key=False, ) diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index bef3ae25bf..0990dbeca0 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -1105,14 +1105,10 @@ def test_optimizations_preserved(self): ((a.T.T) * (dot(xm, (sm.T.T.T)) + x).T * (x / x) + s), ) old_default_mode = config.mode - old_default_opt = config.optimizer - old_default_link = config.linker try: try: str_f = pickle.dumps(f, protocol=-1) - config.mode = "Mode" - config.linker = "py" - config.optimizer = "None" + config.mode = "NUMBA" g = pickle.loads(str_f) # print g.maker.mode # print compile.mode.default_mode @@ -1121,8 +1117,6 @@ def test_optimizations_preserved(self): g = "ok" finally: config.mode = old_default_mode - config.optimizer = old_default_opt - config.linker = old_default_link if g == "ok": return diff --git a/tests/compile/test_mode.py b/tests/compile/test_mode.py index c965087ea2..291eac0782 100644 --- a/tests/compile/test_mode.py +++ b/tests/compile/test_mode.py @@ -13,6 +13,7 @@ from pytensor.graph.features import NoOutputFromInplace from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB from pytensor.link.basic import LocalLinker +from pytensor.link.jax import JAXLinker from pytensor.tensor.math import dot, tanh from pytensor.tensor.type import matrix, vector @@ -142,3 +143,15 @@ class MyLinker(LocalLinker): test_mode = Mode(linker=MyLinker()) with pytest.raises(Exception): get_target_language(test_mode) + + +def test_predefined_modes_respected(): + default_mode = get_default_mode() + assert not isinstance(default_mode.linker, JAXLinker) + + with config.change_flags(mode="JAX"): + jax_mode = get_default_mode() + assert isinstance(jax_mode.linker, JAXLinker) + + default_mode_again = get_default_mode() + assert not isinstance(default_mode_again.linker, JAXLinker)