From a6444c70a8f2c5795f70a313f329043c9c5dac27 Mon Sep 17 00:00:00 2001 From: nimish Date: Sat, 8 Mar 2025 19:05:52 +0530 Subject: [PATCH 1/8] Modify np.tri Op, wrap around OpFromGraph --- pytensor/tensor/basic.py | 40 ++++++++++++---------------------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index e30887cfe3..34ba6e80d6 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -43,6 +43,7 @@ get_vector_length, ) from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback +from pytensor.tensor.einsum import _iota from pytensor.tensor.elemwise import ( DimShuffle, Elemwise, @@ -1084,35 +1085,18 @@ def nonzero_values(a): return _a.flatten()[flatnonzero(_a)] -class Tri(Op): +class Tri(OpFromGraph): + """ + Wrapper Op for np.tri graphs + """ + __props__ = ("dtype",) - def __init__(self, dtype=None): - if dtype is None: - dtype = config.floatX + def __init__(self, *args, M, k, dtype, **kwargs): + self.M = M + self.k = k self.dtype = dtype - - def make_node(self, N, M, k): - N = as_tensor_variable(N) - M = as_tensor_variable(M) - k = as_tensor_variable(k) - return Apply( - self, - [N, M, k], - [TensorType(dtype=self.dtype, shape=(None, None))()], - ) - - def perform(self, node, inp, out_): - N, M, k = inp - (out,) = out_ - out[0] = np.tri(N, M, k, dtype=self.dtype) - - def infer_shape(self, fgraph, node, in_shapes): - out_shape = [node.inputs[0], node.inputs[1]] - return [out_shape] - - def grad(self, inp, grads): - return [grad_undefined(self, i, inp[i]) for i in range(3)] + super().__init__(*args, **kwargs, strict=True) def tri(N, M=None, k=0, dtype=None): @@ -1144,8 +1128,8 @@ def tri(N, M=None, k=0, dtype=None): dtype = config.floatX if M is None: M = N - op = Tri(dtype) - return op(N, M, k) + output = ((_iota(M) + k) > _iota(N)).astype(int) + return Tri(inputs=[N], outputs=[output], M=M, k=k, dtype=dtype)(N) def tril(m, k=0): From 21a158deaba34cd1e3b8242c1747c4cf4f5f32ac Mon Sep 17 00:00:00 2001 From: nimish Date: Tue, 11 Mar 2025 00:06:10 +0530 Subject: [PATCH 2/8] Move iota from einsum.py to basic.py --- pytensor/tensor/basic.py | 65 +++++++++++++++++++++++++++++++++++++-- pytensor/tensor/einsum.py | 65 ++------------------------------------- 2 files changed, 65 insertions(+), 65 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 34ba6e80d6..0ef405eb1a 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -43,7 +43,6 @@ get_vector_length, ) from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback -from pytensor.tensor.einsum import _iota from pytensor.tensor.elemwise import ( DimShuffle, Elemwise, @@ -1061,6 +1060,65 @@ def flatnonzero(a): return nonzero(_a.flatten(), return_matrix=False)[0] +def iota(shape: TensorVariable, axis: int) -> TensorVariable: + """ + Create an array with values increasing along the specified axis. + + Iota is a multidimensional generalization of the `arange` function. The returned array is filled with whole numbers + increasing along the specified axis. + + Parameters + ---------- + shape: TensorVariable + The shape of the array to be created. + axis: int + The axis along which to fill the array with increasing values. + + Returns + ------- + TensorVariable + An array with values increasing along the specified axis. + + Examples + -------- + In the simplest case where ``shape`` is 1d, the output will be equivalent to ``pt.arange``: + + .. 
testcode:: + + import pytensor.tensor as pt + + shape = pt.as_tensor((5,)) + print(pt.basic.iota(shape, 0).eval()) + + .. testoutput:: + + [0 1 2 3 4] + + In higher dimensions, it will look like many concatenated `arange`: + + .. testcode:: + + shape = pt.as_tensor((5, 5)) + print(pt.basic.iota(shape, 1).eval()) + + .. testoutput:: + + [[0 1 2 3 4] + [0 1 2 3 4] + [0 1 2 3 4] + [0 1 2 3 4] + [0 1 2 3 4]] + + Setting ``axis=0`` above would result in the transpose of the output. + """ + len_shape = get_vector_length(shape) + axis = normalize_axis_index(axis, len_shape) + values = arange(shape[axis]) + return pytensor.tensor.extra_ops.broadcast_to( + shape_padright(values, len_shape - axis - 1), shape + ) + + def nonzero_values(a): """Return a vector of non-zero elements contained in the input array. @@ -1128,7 +1186,10 @@ def tri(N, M=None, k=0, dtype=None): dtype = config.floatX if M is None: M = N - output = ((_iota(M) + k) > _iota(N)).astype(int) + output = ((iota(as_tensor((N, 1)), 0) + k + 1) > iota(as_tensor((1, M)), 1)).astype( + int + ) + N = as_tensor_variable(N) return Tri(inputs=[N], outputs=[output], M=M, k=k, dtype=dtype)(N) diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index 660c16d387..145a3c1bfd 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -11,15 +11,13 @@ from pytensor.npy_2_compat import ( _find_contraction, _parse_einsum_input, - normalize_axis_index, normalize_axis_tuple, ) from pytensor.tensor import TensorLike from pytensor.tensor.basic import ( - arange, as_tensor, expand_dims, - get_vector_length, + iota, moveaxis, stack, transpose, @@ -28,7 +26,6 @@ from pytensor.tensor.extra_ops import broadcast_to from pytensor.tensor.functional import vectorize from pytensor.tensor.math import and_, eq, tensordot -from pytensor.tensor.shape import shape_padright from pytensor.tensor.variable import TensorVariable @@ -63,64 +60,6 @@ def __str__(self): return f"Einsum{{{self.subscripts=}, {self.path=}, {self.optimized=}}}" -def _iota(shape: TensorVariable, axis: int) -> TensorVariable: - """ - Create an array with values increasing along the specified axis. - - Iota is a multidimensional generalization of the `arange` function. The returned array is filled with whole numbers - increasing along the specified axis. - - Parameters - ---------- - shape: TensorVariable - The shape of the array to be created. - axis: int - The axis along which to fill the array with increasing values. - - Returns - ------- - TensorVariable - An array with values increasing along the specified axis. - - Examples - -------- - In the simplest case where ``shape`` is 1d, the output will be equivalent to ``pt.arange``: - - .. testcode:: - - import pytensor.tensor as pt - from pytensor.tensor.einsum import _iota - - shape = pt.as_tensor((5,)) - print(_iota(shape, 0).eval()) - - .. testoutput:: - - [0 1 2 3 4] - - In higher dimensions, it will look like many concatenated `arange`: - - .. testcode:: - - shape = pt.as_tensor((5, 5)) - print(_iota(shape, 1).eval()) - - .. testoutput:: - - [[0 1 2 3 4] - [0 1 2 3 4] - [0 1 2 3 4] - [0 1 2 3 4] - [0 1 2 3 4]] - - Setting ``axis=0`` above would result in the transpose of the output. - """ - len_shape = get_vector_length(shape) - axis = normalize_axis_index(axis, len_shape) - values = arange(shape[axis]) - return broadcast_to(shape_padright(values, len_shape - axis - 1), shape) - - def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable: """ Create a Kroncker delta tensor. 
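The relocated helper is easier to follow outside the diff: `iota` is `arange` broadcast along one axis of an n-dimensional shape. A minimal NumPy sketch of the same construction (`iota_sketch` is an illustrative name, not part of the patch; `np.broadcast_to` stands in for PyTensor's `broadcast_to`, and the reshape for `shape_padright`):

    import numpy as np

    def iota_sketch(shape, axis):
        # 1-D range along `axis`, padded with trailing length-1 dims so it
        # broadcasts across every axis after `axis`
        values = np.arange(shape[axis])
        values = values.reshape(values.shape + (1,) * (len(shape) - axis - 1))
        return np.broadcast_to(values, shape)

    # rows of 0..7, matching the docstring example above
    assert (iota_sketch((4, 8), 1) == np.arange(8)).all()
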
@@ -201,7 +140,7 @@ def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable: if len(axes) == 1: raise ValueError("Need at least two axes to create a delta tensor") base_shape = stack([shape[axis] for axis in axes]) - iotas = [_iota(base_shape, i) for i in range(len(axes))] + iotas = [iota(base_shape, i) for i in range(len(axes))] eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)] result = reduce(and_, eyes) non_axes = [i for i in range(len(tuple(shape))) if i not in axes] From 5172a520347200b6d452e82507e3ffc29e80eb90 Mon Sep 17 00:00:00 2001 From: nimish Date: Wed, 12 Mar 2025 08:56:21 +0530 Subject: [PATCH 3/8] Modify tests for tri Op --- pytensor/tensor/basic.py | 9 ++++++--- tests/tensor/test_basic.py | 33 ++++++++++++++++++--------------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 0ef405eb1a..07506c1baf 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -34,7 +34,7 @@ from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.raise_op import CheckAndRaise -from pytensor.scalar import int32 +from pytensor.scalar import int32, upcast from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable from pytensor.tensor import ( _as_tensor_variable, @@ -1186,8 +1186,9 @@ def tri(N, M=None, k=0, dtype=None): dtype = config.floatX if M is None: M = N + output = ((iota(as_tensor((N, 1)), 0) + k + 1) > iota(as_tensor((1, M)), 1)).astype( - int + dtype ) N = as_tensor_variable(N) return Tri(inputs=[N], outputs=[output], M=M, k=k, dtype=dtype)(N) @@ -1244,7 +1245,9 @@ def tril(m, k=0): [55, 56, 57, 58, 0]]]) """ - return m * tri(*m.shape[-2:], k=k, dtype=m.dtype) + N, M = m.shape[-2:] + dtype = upcast(m.dtype) + return m * tri(N, M=M, k=k, dtype=dtype) # M is symbolic, while it shouldnt be def triu(m, k=0): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 60643e2984..107e69f322 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -991,16 +991,19 @@ def check(dtype, N, M_=None, k=0): if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]: M = N N_symb = iscalar() - M_symb = iscalar() - k_symb = iscalar() - f = function( - [N_symb, M_symb, k_symb], tri(N_symb, M_symb, k_symb, dtype=dtype) - ) - result = f(N, M, k) + f = function([N_symb], tri(N_symb, M=M, k=k, dtype=dtype)) + # kwargs = {} + result = f(N) assert np.allclose(result, np.tri(N, M_, k, dtype=dtype)) assert result.dtype == np.dtype(dtype) - for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]: + for dtype in [ + "int32", + "int64", + "float32", + "float64", + "uint16", + ]: # Handle "complex64" ? 
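These checks compare the new graph against `np.tri` directly, which holds because the comparison wrapped by `Tri` fills row i, column j exactly when j <= i + k — `np.tri`'s definition. A quick NumPy verification of that identity (N, M, k here are arbitrary example values):

    import numpy as np

    N, M, k = 4, 5, 1
    rows = np.arange(N).reshape(N, 1)  # plays the role of iota((N, 1), 0)
    cols = np.arange(M).reshape(1, M)  # plays the role of iota((1, M), 1)
    mask = (rows + k + 1) > cols       # j < i + k + 1  <=>  j <= i + k
    assert (mask.astype("float64") == np.tri(N, M, k)).all()
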
check(dtype, 3) # M != N, k = 0 check(dtype, 3, 5) @@ -1022,15 +1025,15 @@ def test_tril_triu(self): def check_l(m, k=0): m_symb = matrix(dtype=m.dtype) - k_symb = iscalar() - f = function([m_symb, k_symb], tril(m_symb, k_symb)) + # k_symb = iscalar() + f = function([m_symb], tril(m_symb, k=k)) f_indx = function( - [m_symb, k_symb], tril_indices(m_symb.shape[0], k_symb, m_symb.shape[1]) + [m_symb], tril_indices(m_symb.shape[0], k=k, m=m_symb.shape[1]) ) - f_indx_from = function([m_symb, k_symb], tril_indices_from(m_symb, k_symb)) - result = f(m, k) - result_indx = f_indx(m, k) - result_from = f_indx_from(m, k) + f_indx_from = function([m_symb], tril_indices_from(m_symb)) + result = f(m) + result_indx = f_indx(m, k=k) + result_from = f_indx_from(m, k=k) assert np.allclose(result, np.tril(m, k)) assert np.allclose(result_indx, np.tril_indices(m.shape[0], k, m.shape[1])) assert np.allclose(result_from, np.tril_indices_from(m, k)) @@ -1040,7 +1043,7 @@ def check_l(m, k=0): def check_u(m, k=0): m_symb = matrix(dtype=m.dtype) k_symb = iscalar() - f = function([m_symb, k_symb], triu(m_symb, k_symb)) + f = function([m_symb, k_symb], triu(m_symb, k=k)) f_indx = function( [m_symb, k_symb], triu_indices(m_symb.shape[0], k_symb, m_symb.shape[1]) ) From d28d7741aa2e9b1e9a665615be36c8549dd017f5 Mon Sep 17 00:00:00 2001 From: nimish Date: Thu, 13 Mar 2025 19:23:22 +0530 Subject: [PATCH 4/8] Modify Tri class, revert tests --- pytensor/tensor/basic.py | 23 +++++++++-------------- tests/tensor/test_basic.py | 29 ++++++++++++++++------------- 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 07506c1baf..c0e6f51d2a 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -34,7 +34,7 @@ from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.raise_op import CheckAndRaise -from pytensor.scalar import int32, upcast +from pytensor.scalar import int32 from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable from pytensor.tensor import ( _as_tensor_variable, @@ -1148,14 +1148,6 @@ class Tri(OpFromGraph): Wrapper Op for np.tri graphs """ - __props__ = ("dtype",) - - def __init__(self, *args, M, k, dtype, **kwargs): - self.M = M - self.k = k - self.dtype = dtype - super().__init__(*args, **kwargs, strict=True) - def tri(N, M=None, k=0, dtype=None): """ @@ -1184,14 +1176,19 @@ def tri(N, M=None, k=0, dtype=None): """ if dtype is None: dtype = config.floatX + dtype = np.dtype(dtype) + if M is None: M = N + N = as_tensor_variable(N) + M = as_tensor_variable(M) + k = as_tensor_variable(k) + output = ((iota(as_tensor((N, 1)), 0) + k + 1) > iota(as_tensor((1, M)), 1)).astype( dtype ) - N = as_tensor_variable(N) - return Tri(inputs=[N], outputs=[output], M=M, k=k, dtype=dtype)(N) + return Tri(inputs=[N, M, k], outputs=[output])(N, M, k) def tril(m, k=0): @@ -1245,9 +1242,7 @@ def tril(m, k=0): [55, 56, 57, 58, 0]]]) """ - N, M = m.shape[-2:] - dtype = upcast(m.dtype) - return m * tri(N, M=M, k=k, dtype=dtype) # M is symbolic, while it shouldnt be + return m * tri(*m.shape[-2:], k=k, dtype=m.dtype) def triu(m, k=0): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 107e69f322..21b10de8e7 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -988,12 +988,15 @@ def check(dtype, N, M_=None, k=0): M = M_ # Currently DebugMode does not support None as inputs even 
if this is # allowed. - if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]: + if M is None: # and config.mode in ["DebugMode", "DEBUG_MODE"]: M = N N_symb = iscalar() - f = function([N_symb], tri(N_symb, M=M, k=k, dtype=dtype)) - # kwargs = {} - result = f(N) + M_symb = iscalar() + k_symb = iscalar() + f = function( + [N_symb, M_symb, k_symb], tri(N_symb, M_symb, k_symb, dtype=dtype) + ) + result = f(N, M, k) assert np.allclose(result, np.tri(N, M_, k, dtype=dtype)) assert result.dtype == np.dtype(dtype) @@ -1025,15 +1028,15 @@ def test_tril_triu(self): def check_l(m, k=0): m_symb = matrix(dtype=m.dtype) - # k_symb = iscalar() - f = function([m_symb], tril(m_symb, k=k)) + k_symb = iscalar() + f = function([m_symb, k_symb], tril(m_symb, k_symb)) f_indx = function( - [m_symb], tril_indices(m_symb.shape[0], k=k, m=m_symb.shape[1]) + [m_symb, k_symb], tril_indices(m_symb.shape[0], k_symb, m_symb.shape[1]) ) - f_indx_from = function([m_symb], tril_indices_from(m_symb)) - result = f(m) - result_indx = f_indx(m, k=k) - result_from = f_indx_from(m, k=k) + f_indx_from = function([m_symb, k_symb], tril_indices_from(m_symb, k_symb)) + result = f(m, k) + result_indx = f_indx(m, k) + result_from = f_indx_from(m, k) assert np.allclose(result, np.tril(m, k)) assert np.allclose(result_indx, np.tril_indices(m.shape[0], k, m.shape[1])) assert np.allclose(result_from, np.tril_indices_from(m, k)) @@ -1043,7 +1046,7 @@ def check_l(m, k=0): def check_u(m, k=0): m_symb = matrix(dtype=m.dtype) k_symb = iscalar() - f = function([m_symb, k_symb], triu(m_symb, k=k)) + f = function([m_symb, k_symb], triu(m_symb, k_symb)) f_indx = function( [m_symb, k_symb], triu_indices(m_symb.shape[0], k_symb, m_symb.shape[1]) ) @@ -1075,7 +1078,7 @@ def check_u_batch(m): assert np.allclose(result, np.triu(m, k)) assert result.dtype == np.dtype(dtype) - for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]: + for dtype in ["int32", "int64", "float32", "float64", "uint16"]: m = random_of_dtype((10, 10), dtype) check_l(m, 0) check_l(m, 1) From 3d930676f29cc764a8bc57e305cd024277e7d8bc Mon Sep 17 00:00:00 2001 From: nimish Date: Thu, 13 Mar 2025 21:24:41 +0530 Subject: [PATCH 5/8] Add back 'complex64' type to tests --- tests/tensor/test_basic.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 21b10de8e7..baa2378f79 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -1006,7 +1006,8 @@ def check(dtype, N, M_=None, k=0): "float32", "float64", "uint16", - ]: # Handle "complex64" ? 
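With patch 4 above, `Tri` is a bare `OpFromGraph` subclass, so shape inference and gradients come from the wrapped graph rather than hand-written `infer_shape`/`grad` methods. A hedged sketch of the resulting pattern (assuming the post-move `iota` in `pytensor.tensor.basic`; the dtype is chosen arbitrarily):

    import pytensor.tensor as pt
    from pytensor.compile.builders import OpFromGraph

    N, M, k = pt.iscalar("N"), pt.iscalar("M"), pt.iscalar("k")
    out = (
        (pt.basic.iota(pt.as_tensor((N, 1)), 0) + k + 1)
        > pt.basic.iota(pt.as_tensor((1, M)), 1)
    ).astype("float64")
    tri_op = OpFromGraph(inputs=[N, M, k], outputs=[out])
    print(tri_op(N, M, k).eval({N: 3, M: 4, k: 0}))  # same values as np.tri(3, 4)
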
+ "complex64", + ]: check(dtype, 3) # M != N, k = 0 check(dtype, 3, 5) @@ -1078,7 +1079,7 @@ def check_u_batch(m): assert np.allclose(result, np.triu(m, k)) assert result.dtype == np.dtype(dtype) - for dtype in ["int32", "int64", "float32", "float64", "uint16"]: + for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]: m = random_of_dtype((10, 10), dtype) check_l(m, 0) check_l(m, 1) From 3cc8e68a97b7f7bcd72f17d1100228b6213b5d01 Mon Sep 17 00:00:00 2001 From: nimish Date: Thu, 13 Mar 2025 21:34:33 +0530 Subject: [PATCH 6/8] Move iota tests to test_basic.py --- tests/tensor/test_basic.py | 24 ++++++++++++++++++++++++ tests/tensor/test_einsum.py | 25 +------------------------ 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index baa2378f79..bda86f2c6e 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -58,6 +58,7 @@ identity_like, infer_static_shape, inverse_permutation, + iota, join, make_vector, mgrid, @@ -980,6 +981,29 @@ def test_static_output_type(self): assert eye(1, l, 3).type.shape == (1, None) +def test_iota(): + mode = Mode(linker="py", optimizer=None) + np.testing.assert_allclose( + iota((4, 8), 0).eval(mode=mode), + [ + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2, 2, 2], + [3, 3, 3, 3, 3, 3, 3, 3], + ], + ) + + np.testing.assert_allclose( + iota((4, 8), 1).eval(mode=mode), + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7], + ], + ) + + class TestTriangle: def test_tri(self): def check(dtype, N, M_=None, k=0): diff --git a/tests/tensor/test_einsum.py b/tests/tensor/test_einsum.py index ba8e354518..8e4e14855c 100644 --- a/tests/tensor/test_einsum.py +++ b/tests/tensor/test_einsum.py @@ -10,7 +10,7 @@ from pytensor.graph.op import HasInnerGraph from pytensor.tensor.basic import moveaxis from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum +from pytensor.tensor.einsum import _delta, _general_dot, einsum from pytensor.tensor.shape import Reshape from pytensor.tensor.type import tensor @@ -38,29 +38,6 @@ def assert_no_blockwise_in_graph(fgraph: FunctionGraph, core_op=None) -> None: assert_no_blockwise_in_graph(inner_fgraph, core_op=core_op) -def test_iota(): - mode = Mode(linker="py", optimizer=None) - np.testing.assert_allclose( - _iota((4, 8), 0).eval(mode=mode), - [ - [0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1], - [2, 2, 2, 2, 2, 2, 2, 2], - [3, 3, 3, 3, 3, 3, 3, 3], - ], - ) - - np.testing.assert_allclose( - _iota((4, 8), 1).eval(mode=mode), - [ - [0, 1, 2, 3, 4, 5, 6, 7], - [0, 1, 2, 3, 4, 5, 6, 7], - [0, 1, 2, 3, 4, 5, 6, 7], - [0, 1, 2, 3, 4, 5, 6, 7], - ], - ) - - def test_delta(): mode = Mode(linker="py", optimizer=None) np.testing.assert_allclose( From f5fc9df731f56ff67eb6e9b5c25c81dc32e42ad5 Mon Sep 17 00:00:00 2001 From: nimish Date: Fri, 14 Mar 2025 17:13:17 +0530 Subject: [PATCH 7/8] Modify InferShape test for Tri class --- tests/tensor/test_basic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index bda86f2c6e..ba9adb966e 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3930,15 +3930,15 @@ def test_Tri(self): biscal = iscalar() ciscal = iscalar() self._compile_and_check( - [aiscal, biscal, ciscal], [Tri()(aiscal, biscal, ciscal)], [4, 4, 0], Tri + [aiscal, biscal, ciscal], 
[tri(aiscal, biscal, ciscal)], [4, 4, 0], Tri ) self._compile_and_check( - [aiscal, biscal, ciscal], [Tri()(aiscal, biscal, ciscal)], [4, 5, 0], Tri + [aiscal, biscal, ciscal], [tri(aiscal, biscal, ciscal)], [4, 5, 0], Tri ) self._compile_and_check( - [aiscal, biscal, ciscal], [Tri()(aiscal, biscal, ciscal)], [3, 5, 0], Tri + [aiscal, biscal, ciscal], [tri(aiscal, biscal, ciscal)], [3, 5, 0], Tri ) def test_ExtractDiag(self): From 00ab6b246632f2ce458a5ac50d31c4bd5fac39f6 Mon Sep 17 00:00:00 2001 From: nimish Date: Fri, 14 Mar 2025 20:23:50 +0530 Subject: [PATCH 8/8] Fix triu docstring; remove dtype prop in JAX version --- pytensor/link/jax/dispatch/tensor_basic.py | 2 +- pytensor/tensor/basic.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index 2956afad02..b0546b42a4 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -203,6 +203,6 @@ def tri(*args): x if const_x is None else const_x for x, const_x in zip(args, const_args, strict=True) ] - return jnp.tri(*args, dtype=op.dtype) + return jnp.tri(*args) return tri diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index c0e6f51d2a..dc85fbd3b6 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1268,7 +1268,7 @@ def triu(m, k=0): [ 0, 8, 9], [ 0, 0, 12]]) - >>> pt.triu(np.arange(3 * 4 * 5).reshape((3, 4, 5))).eval() + >>> pt.triu(pt.arange(3 * 4 * 5).reshape((3, 4, 5))).eval() array([[[ 0, 1, 2, 3, 4], [ 0, 6, 7, 8, 9], [ 0, 0, 12, 13, 14],