Skip to content

Commit f374387

Browse files
Merge pull request #33837 from jakevdp:arange-complex-dep
PiperOrigin-RevId: 842354532
2 parents 53ca7bb + c006c30 commit f374387

File tree

4 files changed

+25
-8
lines changed

4 files changed

+25
-8
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2121
Please use `jax.lax.pcast(..., to='varying')` as the replacement.
2222
* `with mesh:` context manager has been deprecated.
2323
Please use `with jax.set_mesh(mesh):` instead.
24+
* Complex arguments passed to {func}`jax.numpy.arange` now result in a
25+
deprecation warning, because the output is poorly-defined.
2426

2527
* Changes:
2628
* jax's `Tracer` no longer inherits from `jax.Array` at runtime. However,

jax/_src/deprecations.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
128128
register('jax-lax-dot-positional-args')
129129
register('jax-lib-module')
130130
register('jax-nn-one-hot-float-input')
131-
register("jax-numpy-astype-complex-to-real")
131+
register('jax-numpy-arange-complex')
132+
register('jax-numpy-astype-complex-to-real')
132133
register('jax-numpy-clip-args')
133134
register('jax-scipy-special-sph-harm')
134135
register('safer-randint-config')

jax/_src/numpy/lax_numpy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5983,6 +5983,14 @@ def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None,
59835983
dtype = dtypes.jax_dtype(dtype)
59845984

59855985
if iscomplexobj(start) or iscomplexobj(stop) or iscomplexobj(step):
5986+
deprecations.warn(
5987+
"jax-numpy-arange-complex",
5988+
(
5989+
"Passing complex start/stop/step to jnp.arange is deprecated;"
5990+
" in the future this will result in a ValueError."
5991+
),
5992+
stacklevel=3
5993+
)
59865994
# Complex arange is poorly defined; fall back to NumPy here.
59875995
# TODO(jakevdp): deprecate the complex case.
59885996
return array(np.arange(start, stop, step, dtype=dtype), device=out_sharding)

tests/lax_numpy_test.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from jax._src import array
4747
from jax._src import config
4848
from jax._src import core
49+
from jax._src import deprecations
4950
from jax._src import dtypes
5051
from jax._src import test_util as jtu
5152
from jax._src.lax import lax as lax_internal
@@ -4892,18 +4893,23 @@ def testArangeRandomValues(self, dtype, iteration):
48924893
np_result = np.arange(start, stop, dtype=dtype)
48934894
self.assertAllClose(jax_result, np_result)
48944895

4895-
def testArangeComplex(self):
4896-
test_cases = [
4896+
@parameterized.parameters(
48974897
(1+2j, 5+3j),
48984898
(0+0j, 5+0j),
48994899
(1.0+0j, 5.0+0j),
49004900
(0, 5, 1+1j),
4901-
]
4902-
for args in test_cases:
4903-
with self.subTest(args=args):
4901+
)
4902+
def testArangeComplex(self, *args):
4903+
dep_id = "jax-numpy-arange-complex"
4904+
msg = "Passing complex start/stop/step to jnp.arange is deprecated"
4905+
if deprecations.is_accelerated(dep_id):
4906+
with self.assertRaisesRegex(ValueError, msg):
4907+
jax_result = jnp.arange(*args)
4908+
else:
4909+
with self.assertWarnsRegex(DeprecationWarning, msg):
49044910
jax_result = jnp.arange(*args)
4905-
np_result = np.arange(*args)
4906-
self.assertArraysEqual(jax_result, np_result)
4911+
np_result = np.arange(*args)
4912+
self.assertArraysEqual(jax_result, np_result)
49074913

49084914
def testIssue830(self):
49094915
a = jnp.arange(4, dtype=jnp.complex64)

0 commit comments

Comments
 (0)