Skip to content

Commit 8e8ea0c

Browse files
Ignore axis argument in numba CumOp when input is 1d
1 parent 17c675a commit 8e8ea0c

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
4949

5050
@numba_basic.numba_njit
5151
def cumop(x):
52-
return np.cumsum(x, axis=axis)
52+
return np.cumsum(x)
5353

5454
else:
5555

@@ -73,7 +73,7 @@ def cumop(x):
7373

7474
@numba_basic.numba_njit
7575
def cumop(x):
76-
return np.cumprod(x, axis=axis)
76+
return np.cumprod(x)
7777

7878
else:
7979

tests/link/numba/test_extra_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,19 @@ def test_Searchsorted(a, v, side, sorter, exc):
372372
g,
373373
[test_a, test_v] if sorter is None else [test_a, test_v, test_sorter],
374374
)
375+
376+
377+
@pytest.mark.parametrize(
378+
"op", [extra_ops.cumsum, extra_ops.cumprod], ids=["sum", "prod"]
379+
)
380+
def test_1d_cumsum_with_axis_arg(op):
381+
x = pt.vector()
382+
test_x = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
383+
384+
g = op(x, axis=0)
385+
386+
compare_numba_and_py(
387+
[x],
388+
g,
389+
[test_x],
390+
)

0 commit comments

Comments
 (0)