Skip to content

Commit 0ff0d7b

Browse files
committed
jnp.take: fix annotation for fill_value
1 parent 5d5ce1c commit 0ff0d7b

File tree

5 files changed

+22
-10
lines changed

5 files changed

+22
-10
lines changed

jax/_src/basearray.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,14 @@ def sharding(self) -> Sharding:
115115

116116
Array.__module__ = "jax"
117117

118+
# StaticScalar is the Union of all scalar types that can be converted to
119+
# JAX arrays, and are possible to mark as static arguments.
120+
StaticScalar = Union[
121+
np.bool_, np.number, # NumPy scalar types
122+
bool, int, float, complex, # Python scalar types
123+
]
124+
StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars."
125+
118126

119127
# ArrayLike is a Union of all objects that can be implicitly converted to a
120128
# standard JAX array (i.e. not including future non-standard array types like
@@ -123,7 +131,6 @@ def sharding(self) -> Sharding:
123131
ArrayLike = Union[
124132
Array, # JAX array type
125133
np.ndarray, # NumPy array type
126-
np.bool_, np.number, # NumPy scalar types
127-
bool, int, float, complex, # Python scalar types
134+
StaticScalar, # valid scalars
128135
]
129136
ArrayLike.__doc__ = "Type annotation for JAX array-like objects."

jax/_src/basearray.pyi

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,11 +217,15 @@ class Array(abc.ABC):
217217
def unsafe_buffer_pointer(self) -> int: ...
218218

219219

220+
StaticScalar = Union[
221+
np.bool_, np.number, # NumPy scalar types
222+
bool, int, float, complex, # Python scalar types
223+
]
224+
220225
ArrayLike = Union[
221226
Array, # JAX array type
222227
np.ndarray, # NumPy array type
223-
np.bool_, np.number, # NumPy scalar types
224-
bool, int, float, complex, # Python scalar types
228+
StaticScalar, # valid scalars
225229
]
226230

227231

jax/_src/numpy/lax_numpy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@
6868
from jax._src.numpy import util
6969
from jax._src.numpy.vectorize import vectorize
7070
from jax._src.typing import (
71-
Array, ArrayLike, DimSize, DuckTypedArray,
72-
DType, DTypeLike, Shape, DeprecatedArg
71+
Array, ArrayLike, DeprecatedArg, DimSize, DuckTypedArray,
72+
DType, DTypeLike, Shape, StaticScalar,
7373
)
7474
from jax._src.util import (unzip2, subvals, safe_zip,
7575
ceil_of_ratio, partition_list,
@@ -5637,7 +5637,7 @@ def take(
56375637
mode: str | None = None,
56385638
unique_indices: bool = False,
56395639
indices_are_sorted: bool = False,
5640-
fill_value: ArrayLike | None = None,
5640+
fill_value: StaticScalar | None = None,
56415641
) -> Array:
56425642
return _take(a, indices, None if axis is None else operator.index(axis), out,
56435643
mode, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,

jax/_src/typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from jax._src.basearray import (
3535
Array as Array,
3636
ArrayLike as ArrayLike,
37+
StaticScalar as StaticScalar,
3738
)
3839

3940
DType = np.dtype

jax/numpy/__init__.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ from jax._src.lax.slicing import GatherScatterMode
1111
from jax._src.lib import Device
1212
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
1313
from jax._src.typing import (
14-
Array, ArrayLike, DType, DTypeLike,
15-
DimSize, DuckTypedArray, Shape, DeprecatedArg
14+
Array, ArrayLike, DType, DTypeLike, DeprecatedArg,
15+
DimSize, DuckTypedArray, Shape, StaticScalar,
1616
)
1717
from jax.numpy import fft as fft, linalg as linalg
1818
from jax.sharding import Sharding as _Sharding
@@ -804,7 +804,7 @@ def take(
804804
mode: Optional[str] = ...,
805805
unique_indices: builtins.bool = ...,
806806
indices_are_sorted: builtins.bool = ...,
807-
fill_value: Optional[ArrayLike] = ...,
807+
fill_value: Optional[StaticScalar] = ...,
808808
) -> Array: ...
809809
def take_along_axis(
810810
arr: ArrayLike,

0 commit comments

Comments
 (0)