Skip to content

Commit 2d23a66

Browse files
committed
jnp.take_along_axis: support fill_value
1 parent 720d2b8 commit 2d23a66

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5725,12 +5725,13 @@ def _normalize_index(index, axis_size):
57255725

57265726
@util.implements(np.take_along_axis, update_doc=False,
57275727
lax_description=TAKE_ALONG_AXIS_DOC)
5728-
@partial(jit, static_argnames=('axis', 'mode'))
5728+
@partial(jit, static_argnames=('axis', 'mode', 'fill_value'))
57295729
def take_along_axis(
57305730
arr: ArrayLike,
57315731
indices: ArrayLike,
57325732
axis: int | None,
57335733
mode: str | lax.GatherScatterMode | None = None,
5734+
fill_value: StaticScalar | None = None,
57345735
) -> Array:
57355736
util.check_arraylike("take_along_axis", arr, indices)
57365737
a = asarray(arr)
@@ -5743,8 +5744,9 @@ def take_along_axis(
57435744
if ndim(indices) != 1:
57445745
msg = "take_along_axis indices must be 1D if axis=None, got shape {}"
57455746
raise ValueError(msg.format(idx_shape))
5746-
return take_along_axis(a.ravel(), indices, 0)
5747-
rank = ndim(arr)
5747+
a = a.ravel()
5748+
axis = 0
5749+
rank = a.ndim
57485750
if rank != ndim(indices):
57495751
msg = "indices and arr must have the same number of dimensions; {} vs. {}"
57505752
raise ValueError(msg.format(ndim(indices), a.ndim))
@@ -5812,7 +5814,7 @@ def replace(tup, val):
58125814
collapsed_slice_dims=tuple(collapsed_slice_dims),
58135815
start_index_map=tuple(start_index_map))
58145816
return lax.gather(a, gather_indices_arr, dnums, tuple(slice_sizes),
5815-
mode="fill" if mode is None else mode)
5817+
mode="fill" if mode is None else mode, fill_value=fill_value)
58165818

58175819

58185820
### Indexing

jax/numpy/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,7 @@ def take_along_axis(
811811
indices: ArrayLike,
812812
axis: Optional[int],
813813
mode: Optional[Union[str, GatherScatterMode]] = ...,
814+
fill_value: Optional[StaticScalar] = None,
814815
) -> Array: ...
815816
def tan(x: ArrayLike, /) -> Array: ...
816817
def tanh(x: ArrayLike, /) -> Array: ...

tests/lax_numpy_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4483,6 +4483,13 @@ def testTakeAlongAxisWithEmptyArgs(self):
44834483
x = jnp.ones((4, 0, 3), dtype=jnp.int32)
44844484
np.testing.assert_array_equal(x, jnp.take_along_axis(x, x, axis=1))
44854485

4486+
def testTakeAlongAxisOptionalArgs(self):
4487+
x = jnp.arange(5.0)
4488+
ind = jnp.array([0, 2, 4, 6])
4489+
expected = jnp.array([0.0, 2.0, 4.0, 10.0], dtype=x.dtype)
4490+
actual = jnp.take_along_axis(x, ind, axis=None, mode='fill', fill_value=10.0)
4491+
self.assertArraysEqual(expected, actual)
4492+
44864493
@jtu.sample_product(
44874494
dtype=inexact_dtypes,
44884495
shape=[0, 5],
@@ -5973,7 +5980,7 @@ def testWrappedSignaturesMatch(self):
59735980
'clip': ['x', 'max', 'min'],
59745981
'einsum': ['subscripts', 'precision'],
59755982
'einsum_path': ['subscripts'],
5976-
'take_along_axis': ['mode'],
5983+
'take_along_axis': ['mode', 'fill_value'],
59775984
'fill_diagonal': ['inplace'],
59785985
}
59795986

0 commit comments

Comments
 (0)