Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Remember to align the itemized text with the first line of an item within a list
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
the old behavior by transforming the arguments via
`jax.tree.map(np.asarray, args)` before passing them to the callback.
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.

* Deprecations & Removals
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
Expand Down
7 changes: 6 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2263,11 +2263,16 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
implementation dependent.
""")
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
util.check_arraylike("astype", x)
x_arr = asarray(x)
del copy # unused in JAX
if dtype is None:
dtype = dtypes.canonicalize_dtype(float_)
dtypes.check_user_dtype_supported(dtype, "astype")
return lax.convert_element_type(x, dtype)
# convert_element_type(complex, bool) has the wrong semantics.
if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating):
return (x_arr != _lax_const(x_arr, 0))
return lax.convert_element_type(x_arr, dtype)


@util.implements(np.asarray, lax_description=_ARRAY_DOC)
Expand Down
18 changes: 18 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3822,6 +3822,24 @@ def testAstype(self, from_dtype, to_dtype, use_method):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)

@jtu.sample_product(
from_dtype=['int32', 'float32', 'complex64'],
use_method=[True, False],
)
def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng((3, 4), from_dtype)]
if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0
np_op = lambda x: np.astype(x, to_dtype)
else:
np_op = lambda x: np.asarray(x).astype(to_dtype)
if use_method:
jnp_op = lambda x: jnp.asarray(x).astype(to_dtype)
else:
jnp_op = lambda x: jnp.astype(x, to_dtype)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)

def testAstypeInt4(self):
# Test converting from int4 to int8
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)
Expand Down