|
68 | 68 | from jax._src.numpy import util |
69 | 69 | from jax._src.numpy.vectorize import vectorize |
70 | 70 | from jax._src.typing import ( |
71 | | - Array, ArrayLike, DimSize, DuckTypedArray, |
72 | | - DType, DTypeLike, Shape, DeprecatedArg |
| 71 | + Array, ArrayLike, DeprecatedArg, DimSize, DuckTypedArray, |
| 72 | + DType, DTypeLike, Shape, StaticScalar, |
73 | 73 | ) |
74 | 74 | from jax._src.util import (unzip2, subvals, safe_zip, |
75 | 75 | ceil_of_ratio, partition_list, |
@@ -5637,7 +5637,7 @@ def take( |
5637 | 5637 | mode: str | None = None, |
5638 | 5638 | unique_indices: bool = False, |
5639 | 5639 | indices_are_sorted: bool = False, |
5640 | | - fill_value: ArrayLike | None = None, |
| 5640 | + fill_value: StaticScalar | None = None, |
5641 | 5641 | ) -> Array: |
5642 | 5642 | return _take(a, indices, None if axis is None else operator.index(axis), out, |
5643 | 5643 | mode, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, |
|
0 commit comments