From 4a5cebd93b294e344df55635b62d42f8635e90ca Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 11 Nov 2024 14:11:41 +0100 Subject: [PATCH 1/2] Don't raise raw Exception in eval --- pytensor/graph/basic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 6b4ca7570d..29b2043f6d 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -622,9 +622,9 @@ def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]: if isinstance(key, str): matching_vars = get_var_by_name([self], key) if not matching_vars: - raise Exception(f"{key} not found in graph") + raise ValueError(f"{key} not found in graph") elif len(matching_vars) > 1: - raise Exception(f"Found multiple variables with name {key}") + raise ValueError(f"Found multiple variables with name {key}") new_input_to_values[matching_vars[0]] = value else: new_input_to_values[key] = value From 11b8fd652f959259f6f63c39cadb4d64231a3f75 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 11 Nov 2024 14:12:29 +0100 Subject: [PATCH 2/2] Fix bug in local_div_switch_sink rewrite Introduced in 4f7d7096ea98fa1285b50a9d583373b5963d425d --- pytensor/tensor/rewriting/math.py | 5 ++++- tests/tensor/rewriting/test_math.py | 23 ++++++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index b230f035cc..16df2d1b08 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -699,7 +699,10 @@ def local_div_switch_sink(fgraph, node): # will point to the new division op. copy_stack_trace(node.outputs, fdiv) - fct = switch(switch_cond, zero_switch_input, fdiv) + if branch == 0: + fct = switch(switch_cond, zero_switch_input, fdiv) + else: + fct = switch(switch_cond, fdiv, zero_switch_input) # Tell debug_mode than the output is correct, even if nan disappear fct.tag.values_eq_approx = values_eq_approx_remove_nan diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index e4a08cdf81..1160562e62 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -2163,7 +2163,7 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite): # The zero branch upcasts the output, so we can't ignore its dtype zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch") other_branch = scalar("other_branch", dtype="float32") - outer_var = scalar("mul_var", dtype="bool") + outer_var = scalar("outer_var", dtype="bool") out = op(switch(cond, zero_branch, other_branch), outer_var) fgraph = FunctionGraph(outputs=[out], clone=False) @@ -2173,6 +2173,27 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite): expected_out = switch(cond, zero_branch, op(other_branch, outer_var)) assert equal_computations([new_out], [expected_out]) + @pytest.mark.parametrize( + "op, rewrite", [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)] + ) + def test_local_mul_div_switch_sink_branch_order(self, op, rewrite): + cond = scalar("cond", dtype="bool") + zero_branch = constant(np.array(0.0, dtype="float64"), "zero_branch") + other_branch = scalar("other_branch", dtype="float64") + outer_var = scalar("outer_var", dtype="float64") + + left = op(switch(cond, zero_branch, other_branch), outer_var) + right = op(switch(cond, other_branch, zero_branch), outer_var) + fgraph = FunctionGraph(outputs=[left, right], clone=False) + [new_left] = rewrite.transform(fgraph, left.owner) + [new_right] = rewrite.transform(fgraph, right.owner) + + expected_left = switch(cond, zero_branch, op(other_branch, outer_var)) + expected_right = switch(cond, op(other_branch, outer_var), zero_branch) + assert equal_computations( + [new_left, new_right], [expected_left, expected_right] + ) + @pytest.mark.skipif( config.cxx == "",