From 560f41e0a3adb583ccb09277f3fd42a414663bd6 Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Mon, 4 Nov 2024 12:57:12 +0100
Subject: [PATCH] Implement unconditional constant_folding rewrite

---
 pytensor/tensor/rewriting/basic.py   |  23 ++++-
 tests/tensor/rewriting/test_basic.py | 132 +++++++++++++++++----------
 2 files changed, 105 insertions(+), 50 deletions(-)

diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 78d00790ac..c239b4bec4 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -32,6 +32,7 @@
 from pytensor.graph import FunctionGraph
 from pytensor.graph.basic import Constant, Variable
 from pytensor.graph.rewriting.basic import (
+    NodeProcessingGraphRewriter,
     NodeRewriter,
     RemovalNodeRewriter,
     Rewriter,
@@ -1101,10 +1102,7 @@ def local_useless_split(fgraph, node):
 
 
 @node_rewriter(None)
-def constant_folding(fgraph, node):
-    if not node.op.do_constant_folding(fgraph, node):
-        return False
-
+def unconditional_constant_folding(fgraph, node):
     if not all(isinstance(inp, Constant) for inp in node.inputs):
         return False
 
@@ -1151,6 +1149,23 @@ def constant_folding(fgraph, node):
     return rval
 
 
+topo_unconditional_constant_folding = in2out(
+    unconditional_constant_folding,
+    ignore_newtrees=True,
+    name="topo_unconditional_constant_folding",
+    # Not all Ops have a perform method, so we ignore failures to constant_fold
+    failure_callback=NodeProcessingGraphRewriter.warn_ignore,
+)
+
+
+@node_rewriter(None)
+def constant_folding(fgraph, node):
+    if not node.op.do_constant_folding(fgraph, node):
+        return False
+
+    return unconditional_constant_folding.transform(fgraph, node)
+
+
 topo_constant_folding = in2out(
     constant_folding, ignore_newtrees=True, name="topo_constant_folding"
 )
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index 4ff773dbb8..8911f56630 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -12,7 +12,8 @@
 from pytensor.compile.mode import get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import equal_computations
+from pytensor.graph import Op
+from pytensor.graph.basic import Constant, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import check_stack_trace, out2in
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -29,6 +30,7 @@
     TensorFromScalar,
     as_tensor,
     cast,
+    constant,
     join,
     tile,
 )
@@ -65,6 +67,8 @@
     local_merge_alloc,
     local_useless_alloc,
     local_useless_elemwise,
+    topo_constant_folding,
+    topo_unconditional_constant_folding,
     topological_fill_sink,
 )
 from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot
@@ -742,56 +746,92 @@ def test_upcast(self):
         ) or (len(topo) > 1)
 
 
-def test_constant_folding():
-    # Test that constant folding get registered at fast_compile
-    # An error removed that registration during the registration.
-    x = dvector()
-    mode = get_mode("FAST_COMPILE").excluding("fusion")
-    f = function([x], [x * 2, x + x], mode=mode)
-    topo = f.maker.fgraph.toposort()
-    assert len(topo) == 2
-
-    # Test that we do not crash when constant folding elemwise scalar
-    # as they should not generate c code.
+class TestConstantFolding:
+    def test_constant_folding(self):
+        # Test that constant folding gets registered at fast_compile.
+        # A past bug removed that registration while rewrites were being registered.
+        x = dvector()
+        mode = get_mode("FAST_COMPILE").excluding("fusion")
+        f = function([x], [x * 2, x + x], mode=mode)
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) == 2
 
-    x = pt.constant(3)
-    assert x.ndim == 0
-    mode = get_mode("FAST_COMPILE").excluding("fusion")
-    f = function([], [x * 2, x + x], mode=mode)
-    topo = f.maker.fgraph.toposort()
-    assert len(topo) == 2
-    assert all(isinstance(n.op, DeepCopyOp) for n in topo)
+        # Test that we do not crash when constant folding an elemwise scalar,
+        # as it should not generate C code.
+        x = pt.constant(3)
+        assert x.ndim == 0
+        mode = get_mode("FAST_COMPILE").excluding("fusion")
+        f = function([], [x * 2, x + x], mode=mode)
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) == 2
+        assert all(isinstance(n.op, DeepCopyOp) for n in topo)
 
-@pytest.mark.xfail(
-    reason="PyTensor rewrites constants before stabilization. "
-    "This breaks stabilization rewrites in some cases. See #504.",
-    raises=AssertionError,
-)
-def test_constant_get_stabilized():
-    # Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
-    # This caused some stabilization rewrites to not be activated and that
-    # caused inf values to appear when they should not.
+    @pytest.mark.xfail(
+        reason="PyTensor rewrites constants before stabilization. "
+        "This breaks stabilization rewrites in some cases. See #504.",
+        raises=AssertionError,
+    )
+    def test_constant_get_stabilized(self):
+        # Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
+        # This prevented some stabilization rewrites from being applied and
+        # caused inf values to appear when they should not.
 
-    # We can't simply move the `constant_folding` rewrite to
-    # specialize since this will break other rewrites. We will need to
-    # partially duplicate some canonicalize rewrites to fix this issue.
+        # We can't simply move the `constant_folding` rewrite to
+        # specialize, since that would break other rewrites. We will need to
+        # partially duplicate some canonicalize rewrites to fix this issue.
 
-    x2 = scalar()
-    y2 = log(1 + exp(x2))
-    mode = get_default_mode()
-    mode.check_isfinite = False
-    f2 = function([x2], y2, mode=mode)
-
-    assert len(f2.maker.fgraph.toposort()) == 1
-    assert f2.maker.fgraph.toposort()[0].op == softplus
-    assert f2(800) == 800
-
-    x = pt.as_tensor_variable(800)
-    y = log(1 + exp(x))
-    f = function([], y, mode=mode)
-    # When this error is fixed, the following line should be ok.
-    assert f() == 800, f()
+        x2 = scalar()
+        y2 = log(1 + exp(x2))
+        mode = get_default_mode()
+        mode.check_isfinite = False
+        f2 = function([x2], y2, mode=mode)
+
+        assert len(f2.maker.fgraph.toposort()) == 1
+        assert f2.maker.fgraph.toposort()[0].op == softplus
+        assert f2(800) == 800
+
+        x = pt.as_tensor_variable(800)
+        y = log(1 + exp(x))
+        f = function([], y, mode=mode)
+        # When this error is fixed, the following assertion should pass.
+        assert f() == 800, f()
+
+    def test_unconditional(self):
+        x = pt.alloc(np.e, 3, 5)
+        fg = FunctionGraph(outputs=[x], clone=False)
+
+        # Default constant folding doesn't apply to an Alloc used as an output
+        topo_constant_folding.apply(fg)
+        assert not isinstance(fg.outputs[0], Constant)
+
+        # Unconditional constant folding does apply
+        topo_unconditional_constant_folding.apply(fg)
+        assert isinstance(fg.outputs[0], Constant)
+        np.testing.assert_allclose(fg.outputs[0].data, np.full((3, 5), np.e))
+
+    def test_unconditional_no_perform_method(self):
+        """Test that failures are ignored when the Op does not have a usable perform method."""
+
+        class OpNoPerform(Op):
+            itypes = [scalar(dtype="float64").type]
+            otypes = [scalar(dtype="float64").type]
+
+            def perform(self, *args, **kwargs):
+                raise NotImplementedError("This Op cannot be evaluated")
+
+        x = constant(np.array(5.0))
+        out = OpNoPerform()(x)
+
+        fg = FunctionGraph(outputs=[out], clone=False)
+        # Default constant_folding will raise
+        with pytest.raises(NotImplementedError):
+            topo_constant_folding.apply(fg)
+
+        # Unconditional constant folding will be silent
+        topo_unconditional_constant_folding.apply(fg)
+        assert not isinstance(fg.outputs[0], Constant)
+        assert isinstance(fg.outputs[0].owner.op, OpNoPerform)
 
 
 class TestLocalSwitchSink:
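
Usage sketch (not part of the patch): a minimal example of the behavioral difference this patch introduces, mirroring test_unconditional above. It assumes a PyTensor checkout with this patch applied; all names come from the modules touched here.

    import numpy as np

    import pytensor.tensor as pt
    from pytensor.graph.basic import Constant
    from pytensor.graph.fg import FunctionGraph
    from pytensor.tensor.rewriting.basic import (
        topo_constant_folding,
        topo_unconditional_constant_folding,
    )

    # Build a graph whose only output is an Alloc of a scalar constant.
    x = pt.alloc(np.e, 3, 5)
    fg = FunctionGraph(outputs=[x], clone=False)

    # `Alloc.do_constant_folding` declines to fold graph outputs, so the
    # conditional rewriter leaves the graph unchanged.
    topo_constant_folding.apply(fg)
    print(isinstance(fg.outputs[0], Constant))  # False

    # The unconditional variant skips the `do_constant_folding` check and
    # folds anyway; Ops whose perform method fails are skipped with a
    # warning (via the failure_callback) instead of raising.
    topo_unconditional_constant_folding.apply(fg)
    print(isinstance(fg.outputs[0], Constant))  # True

The conditional rewriter delegates to unconditional_constant_folding.transform after the do_constant_folding check, so the folding logic itself lives in one place.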