Skip to content

Commit ce70501

Browse files
committed
Add numba overload for Nonzero
1 parent 2a7f3e1 commit ce70501

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

pytensor/tensor/basic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import TYPE_CHECKING, Union
1414
from typing import cast as type_cast
1515

16+
import numba as nb
1617
import numpy as np
1718
from numpy.exceptions import AxisError
1819

@@ -972,6 +973,7 @@ def make_node(self, a):
972973
output = [TensorType(dtype="int64", shape=(None,))() for i in range(a.ndim)]
973974
return Apply(self, [a], output)
974975

976+
@nb.njit
975977
def perform(self, node, inp, out_):
976978
a = inp[0]
977979

0 commit comments

Comments
 (0)