From 2307940eaff922d4c5b751717937770eec1bbcda Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 5 Jun 2025 11:29:01 +0200
Subject: [PATCH 1/5] Fix typo in test variable

---
 tests/tensor/test_math_scipy.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py
index fbfa5fb77e..37c2b9ea13 100644
--- a/tests/tensor/test_math_scipy.py
+++ b/tests/tensor/test_math_scipy.py
@@ -376,8 +376,8 @@ def test_gammainc_ddk_tabulated_values():
     # https://github.com/stan-dev/math/blob/21333bb70b669a1bd54d444ecbe1258078d33153/test/unit/math/prim/scal/fun/grad_reg_lower_inc_gamma_test.cpp
     k, x = pt.scalars("k", "x")
     gammainc_out = pt.gammainc(k, x)
-    gammaincc_ddk = pt.grad(gammainc_out, k)
-    f_grad = function([k, x], gammaincc_ddk)
+    gammainc_ddk = pt.grad(gammainc_out, k)
+    f_grad = function([k, x], gammainc_ddk)
 
     rtol = 1e-5 if config.floatX == "float64" else 1e-2
     atol = 1e-10 if config.floatX == "float64" else 1e-6

From 39fa1fce7cf99417a6e7a93028eca522ee8007f6 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 18 Nov 2025 10:40:54 +0100
Subject: [PATCH 2/5] Method call, not property

---
 pytensor/link/numba/dispatch/scalar.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index 4a4d9b319d..b8f14f02a4 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -73,7 +73,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
         scalar_func_numba = wrap_cython_function(
             cython_func, output_dtype, input_dtypes
         )
-        has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch
+        has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch()
         input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
         output_inner_dtype = scalar_func_numba.numpy_output_dtype()

From fa177fbbeb699713ed4989e03f0b5e4596617feb Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 18 Nov 2025 10:39:30 +0100
Subject: [PATCH 3/5] Workaround https://github.com/numba/numba/issues/9554

---
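Note: numba's power fails to compile when the exponent is an integer
narrower than 64 bits and fastmath=True, so the dispatch below widens the
exponent to int64 before taking the power. A minimal standalone sketch of
the idea, assuming numba is installed (pow_int8 is an illustrative name,
not part of the patch):

    import numba
    import numpy as np

    @numba.njit(fastmath=True)
    def pow_int8(x, y):
        # Widening the narrow integer exponent to int64 sidesteps
        # https://github.com/numba/numba/issues/9554
        return x ** np.asarray(y, dtype=np.int64).item()

    pow_int8(0.5, np.int8(2))  # compiles and returns 0.25
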
 pytensor/link/numba/dispatch/scalar.py | 18 ++++++++++++++++++
 tests/link/numba/test_scalar.py        | 10 ++++++++++
 2 files changed, 28 insertions(+)

diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index b8f14f02a4..aa6537bbcd 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -23,6 +23,7 @@
     Composite,
     Identity,
     Mul,
+    Pow,
     Reciprocal,
     ScalarOp,
     Second,
@@ -171,6 +172,23 @@ def {binary_op_name}({input_signature}):
     return nary_fn
 
 
+@register_funcify_and_cache_key(Pow)
+def numba_funcify_Pow(op, node, **kwargs):
+    pow_dtype = node.inputs[1].type.dtype
+    if pow_dtype.startswith("int"):
+        # Numba power fails to compile when exponents are non-64-bit integers and fastmath=True
+        # https://github.com/numba/numba/issues/9554
+
+        def pow(x, y):
+            return x ** np.asarray(y, dtype=np.int64).item()
+    else:
+
+        def pow(x, y):
+            return x**y
+
+    return numba_basic.numba_njit(pow), scalar_op_cache_key(op)
+
+
 @register_funcify_and_cache_key(Add)
 def numba_funcify_Add(op, node, **kwargs):
     nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py
index 2125d7cc0e..b431b81379 100644
--- a/tests/link/numba/test_scalar.py
+++ b/tests/link/numba/test_scalar.py
@@ -184,3 +184,13 @@ def test_Softplus(dtype):
             strict=True,
             err_msg=f"Failed for value {value}",
         )
+
+
+def test_discrete_power():
+    # Test that power with discrete exponents compiles despite https://github.com/numba/numba/issues/9554
+    x = pt.scalar("x", dtype="float64")
+    exponent = pt.scalar("exponent", dtype="int8")
+    out = pt.power(x, exponent)
+    compare_numba_and_py(
+        [x, exponent], [out], [np.array(0.5), np.array(2, dtype="int8")]
+    )

From 8dd6f7e6d63de3ac988542a86a77e149925ea74e Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 5 Jun 2025 11:29:22 +0200
Subject: [PATCH 4/5] Scalar while loop until output defaults to False if n_steps == 0

---
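Note: before this change, a while ScalarLoop run for zero steps reported
until=True even though its stopping condition was never evaluated; the fix
initializes the flag to False so "never ran" reads as "did not terminate".
A minimal sketch of the corrected semantics, mirroring perform below
(run_while is an illustrative helper, not part of the patch):

    def run_while(inner_fn, n_steps, carry, constant):
        # If the body never executes (n_steps == 0), the condition was
        # never met, so the until flag must default to False.
        until = False
        for _ in range(n_steps):
            *carry, until = inner_fn(*carry, *constant)
            if until:
                break
        return *carry, until
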
 pytensor/scalar/loop.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py
index f23c4e1c42..d588f59b1b 100644
--- a/pytensor/scalar/loop.py
+++ b/pytensor/scalar/loop.py
@@ -183,7 +183,7 @@ def perform(self, node, inputs, output_storage):
         inner_fn = self.py_perform_fn
 
         if self.is_while:
-            until = True
+            until = False
             for i in range(n_steps):
                 *carry, until = inner_fn(*carry, *constant)
                 if until:

From 89d567af19da8d8fa2c5d3a58157c57faa7ce90a Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 5 Jun 2025 11:31:14 +0200
Subject: [PATCH 5/5] Numba dispatch of ScalarLoop

---
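Note: the dispatch below compiles the loop body once (inner_fn) and wraps
it in one of two kernels: a while-style kernel that also returns the
termination flag, or a for-style kernel that returns only the carried
state. A minimal usage sketch, assuming the registered "NUMBA" mode
(the step count and values are illustrative):

    import pytensor
    import pytensor.scalar as ps
    from pytensor.scalar import ScalarLoop

    x0 = ps.float64("x0")
    n_steps = ps.int64("n_steps")
    # A for-style loop that doubles its single carried state each step
    op = ScalarLoop(init=[x0], update=[x0 * 2])

    fn = pytensor.function([n_steps, x0], op(n_steps, x0), mode="NUMBA")
    print(fn(4, 1.0))  # doubles 1.0 four times: 16.0
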
 pytensor/link/numba/dispatch/scalar.py | 50 ++++++++++++++++++
 tests/link/numba/test_elemwise.py      | 44 ++++++++++++----
 tests/link/numba/test_scalar.py        | 72 ++++++++++++++++++++++++++
 3 files changed, 156 insertions(+), 10 deletions(-)

diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index aa6537bbcd..c6b070469e 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -16,6 +16,7 @@
     get_name_for_object,
     unique_name_generator,
 )
+from pytensor.scalar import ScalarLoop
 from pytensor.scalar.basic import (
     Add,
     Cast,
@@ -336,3 +337,52 @@ def softplus(x):
         return numba_basic.direct_cast(value, out_dtype)
 
     return softplus, scalar_op_cache_key(op)
+
+
+@register_funcify_and_cache_key(ScalarLoop)
+def numba_funcify_ScalarLoop(op, node, **kwargs):
+    inner_fn, inner_fn_cache_key = numba_funcify_and_cache_key(op.fgraph)
+    if inner_fn_cache_key is None:
+        loop_cache_key = None
+    else:
+        loop_cache_key = sha256(
+            str((type(op), op.is_while, inner_fn_cache_key)).encode()
+        ).hexdigest()
+
+    if op.is_while:
+        n_update = len(op.outputs) - 1
+
+        @numba_basic.numba_njit
+        def while_loop(n_steps, *inputs):
+            carry, constant = inputs[:n_update], inputs[n_update:]
+
+            until = False
+            for i in range(n_steps):
+                outputs = inner_fn(*carry, *constant)
+                carry, until = outputs[:-1], outputs[-1]
+                if until:
+                    break
+
+            return *carry, until
+
+        return while_loop, loop_cache_key
+
+    else:
+        n_update = len(op.outputs)
+
+        @numba_basic.numba_njit
+        def for_loop(n_steps, *inputs):
+            carry, constant = inputs[:n_update], inputs[n_update:]
+
+            if n_steps < 0:
+                raise ValueError("ScalarLoop does not have a termination condition.")
+
+            for i in range(n_steps):
+                carry = inner_fn(*carry, *constant)
+
+            if n_update == 1:
+                return carry[0]
+            else:
+                return carry
+
+        return for_loop, loop_cache_key

diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index 954656cebe..de05bc6831 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -587,18 +587,42 @@ def test_elemwise_multiple_inplace_outs():
 
 
 def test_scalar_loop():
-    a = float64("a")
-    scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
+    a_scalar = float64("a")
+    const_scalar = float64("const")
+    scalar_loop = pytensor.scalar.ScalarLoop(
+        init=[a_scalar],
+        update=[a_scalar + a_scalar + const_scalar],
+        constant=[const_scalar],
+    )
 
-    x = pt.tensor("x", shape=(3,))
-    elemwise_loop = Elemwise(scalar_loop)(3, x)
+    a = pt.tensor("a", shape=(3,))
+    const = pt.tensor("const", shape=(3,))
+    n_steps = 3
+    elemwise_loop = Elemwise(scalar_loop)(n_steps, a, const)
 
-    with pytest.warns(UserWarning, match="object mode"):
-        compare_numba_and_py(
-            [x],
-            [elemwise_loop],
-            (np.array([1, 2, 3], dtype="float64"),),
-        )
+    compare_numba_and_py(
+        [a, const],
+        [elemwise_loop],
+        [np.array([1, 2, 3], dtype="float64"), np.array([1, 1, 1], dtype="float64")],
+    )
+
+
+def test_gammainc_wrt_k_grad():
+    x = pt.vector("x", dtype="float64")
+    k = pt.vector("k", dtype="float64")
+
+    out = pt.gammainc(k, x)
+    grad_out = grad(out.sum(), k)
+
+    compare_numba_and_py(
+        [x, k],
+        [grad_out],
+        # These values of x and k trigger all the branches in the gradient of gammainc
+        [
+            np.array([0.0, 29.0, 31.0], dtype="float64"),
+            np.array([1.0, 13.0, 11.0], dtype="float64"),
+        ],
+    )
 
 
 class TestsBenchmark:

diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py
index b431b81379..3c83168dcd 100644
--- a/tests/link/numba/test_scalar.py
+++ b/tests/link/numba/test_scalar.py
@@ -6,6 +6,7 @@
 import pytensor.scalar.math as psm
 import pytensor.tensor as pt
 from pytensor import config, function
+from pytensor.scalar import ScalarLoop
 from pytensor.scalar.basic import Composite
 from pytensor.tensor import tensor
 from pytensor.tensor.elemwise import Elemwise
@@ -194,3 +195,74 @@ def test_discrete_power():
     compare_numba_and_py(
         [x, exponent], [out], [np.array(0.5), np.array(2, dtype="int8")]
     )
+
+
+class TestScalarLoop:
+    def test_scalar_for_loop_single_out(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        const = ps.float64("const")
+        x = x0 + const
+
+        op = ScalarLoop(init=[x0], constant=[const], update=[x])
+        x = op(n_steps, x0, const)
+
+        fn = function([n_steps, x0, const], [x], mode=numba_mode)
+
+        res_x = fn(n_steps=5, x0=0, const=1)
+        np.testing.assert_allclose(res_x, 5)
+
+        res_x = fn(n_steps=5, x0=0, const=2)
+        np.testing.assert_allclose(res_x, 10)
+
+        res_x = fn(n_steps=4, x0=3, const=-1)
+        np.testing.assert_allclose(res_x, -1)
+
+    def test_scalar_for_loop_multiple_outs(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        y0 = ps.int64("y0")
+        const = ps.float64("const")
+        x = x0 + const
+        y = y0 + 1
+
+        op = ScalarLoop(init=[x0, y0], constant=[const], update=[x, y])
+        x, y = op(n_steps, x0, y0, const)
+
+        fn = function([n_steps, x0, y0, const], [x, y], mode=numba_mode)
+
+        res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=1)
+        np.testing.assert_allclose(res_x, 5)
+        np.testing.assert_allclose(res_y, 5)
+
+        res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=2)
+        np.testing.assert_allclose(res_x, 10)
+        np.testing.assert_allclose(res_y, 5)
+
+        res_x, res_y = fn(n_steps=4, x0=3, y0=2, const=-1)
+        np.testing.assert_allclose(res_x, -1)
+        np.testing.assert_allclose(res_y, 6)
+
+    def test_scalar_while_loop(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        x = x0 + 1
+        until = x >= 10
+
+        op = ScalarLoop(init=[x0], update=[x], until=until)
+        fn = function([n_steps, x0], op(n_steps, x0), mode=numba_mode)
+        np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
+        np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
+        np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
+        np.testing.assert_allclose(fn(n_steps=0, x0=1), [1, False])
+
+    def test_loop_with_cython_wrapped_op(self):
+        x = ps.float64("x")
+        op = ScalarLoop(init=[x], update=[ps.psi(x)])
+        out = op(1, x)
+
+        fn = function([x], out, mode=numba_mode)
+        x_test = np.float64(0.5)
+        res = fn(x_test)
+        expected_res = ps.psi(x).eval({x: x_test})
+        np.testing.assert_allclose(res, expected_res)
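
Note: in the last test, ps.psi is the digamma function, whose numba
implementation goes through a cython-wrapped scipy routine, so the test
exercises ScalarLoop bodies that contain wrapped cython ops. As a worked
check of the expected value, a single loop step from x0 = 0.5 should
return psi(0.5) = -gamma - 2 ln 2 ≈ -1.96351, matching the pure-PyTensor
evaluation stored in expected_res.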