Skip to content

Commit eeab9e4

Browse files
Merge pull request #33819 from jakevdp:arange-followup
PiperOrigin-RevId: 842326801
2 parents 863e4e7 + 1bfd8dc commit eeab9e4

File tree

1 file changed

+39
-43
lines changed

1 file changed

+39
-43
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import numpy as np
3737

3838
from jax._src import api
39-
from jax._src import config
4039
from jax._src import core
4140
from jax._src import deprecations
4241
from jax._src import dtypes
@@ -143,7 +142,8 @@ def iscomplexobj(x: Any) -> bool:
143142
>>> jnp.iscomplexobj(jnp.array([0, 1+2j]))
144143
True
145144
"""
146-
if x is None:
145+
# Check for int here to avoid potential overflow in jnp.array below.
146+
if x is None or isinstance(x, int):
147147
return False
148148
try:
149149
typ = x.dtype.type
@@ -5954,60 +5954,56 @@ def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None,
59545954
def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None,
59555955
step: ArrayLike | None = None, dtype: DTypeLike | None = None,
59565956
out_sharding: NamedSharding | None = None) -> Array:
5957+
# Validate inputs
59575958
if dtype is not None:
59585959
dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "arange")
5959-
if not config.dynamic_shapes.value:
5960-
util.check_arraylike("arange", start)
5961-
if stop is None and step is None:
5962-
start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'")
5963-
else:
5964-
start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'start'")
5965-
util.check_arraylike_or_none("arange", None, stop, step)
5960+
util.check_arraylike_or_none("arange", start, stop, step)
5961+
5962+
# Ensure start/stop/step are concrete
5963+
start_name = "stop" if stop is None and step is None else "start"
5964+
start = core.concrete_or_error(None, start, f"It arose in the jnp.arange argument '{start_name}'")
59665965
stop = core.concrete_or_error(None, stop, "It arose in the jnp.arange argument 'stop'")
59675966
step = core.concrete_or_error(None, step, "It arose in the jnp.arange argument 'step'")
5968-
start_name = "stop" if stop is None and step is None else "start"
5967+
5968+
# Ensure start/stop/step are scalars
59695969
for name, val in [(start_name, start), ("stop", stop), ("step", step)]:
59705970
if val is not None and np.ndim(val) != 0:
59715971
raise ValueError(f"jax.numpy.arange: arguments must be scalars; got {name}={val}")
5972+
5973+
# Handle symbolic dimensions
59725974
if any(core.is_symbolic_dim(v) for v in (start, stop, step)):
5973-
# Some dynamic shapes
5974-
if stop is None and step is None:
5975-
stop = start
5976-
start = 0
5977-
step = 1
5978-
elif stop is not None and step is None:
5975+
if stop is None:
5976+
start, stop = 0, start
5977+
if step is None:
59795978
step = 1
59805979
return _arange_dynamic(start, stop, step, dtype or dtypes.default_int_dtype())
5980+
59815981
if dtype is None:
5982-
dtype = result_type(start, *(x for x in [stop, step] if x is not None))
5982+
dtype = dtypes.result_type(start, *(x for x in [stop, step] if x is not None))
59835983
dtype = dtypes.jax_dtype(dtype)
5984-
if stop is None and step is None:
5985-
start_dtype = _dtype(start)
5986-
if (not dtypes.issubdtype(start_dtype, np.integer) and
5987-
not dtypes.issubdtype(start_dtype, dtypes.extended)):
5988-
ceil_ = ufuncs.ceil if isinstance(start, core.Tracer) else np.ceil
5989-
start = ceil_(start).astype(int)
5990-
return lax.broadcasted_iota(dtype, (start,), 0, out_sharding=out_sharding) # type: ignore[arg-type]
5984+
5985+
if iscomplexobj(start) or iscomplexobj(stop) or iscomplexobj(step):
5986+
# Complex arange is poorly defined; fall back to NumPy here.
5987+
# TODO(jakevdp): deprecate the complex case.
5988+
return array(np.arange(start, stop, step, dtype=dtype), device=out_sharding)
5989+
5990+
if step is not None:
5991+
# arange(N, M, K): when step is specified, fall back to NumPy.
5992+
return array(np.arange(start, stop, step, dtype=dtype), device=out_sharding)
5993+
5994+
if stop is None:
5995+
start, stop = 0, start
5996+
5997+
if start == 0:
5998+
# arange(M) or arange(0, M)
5999+
size = max(0, int(np.ceil(stop)))
6000+
return lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding)
6001+
59916002
else:
5992-
if step is None and stop is not None:
5993-
# Skip optimization if start or stop is complex (ceil doesn't support complex)
5994-
start_dtype = _dtype(start)
5995-
stop_dtype = _dtype(stop)
5996-
if (dtypes.issubdtype(start_dtype, np.complexfloating) or
5997-
dtypes.issubdtype(stop_dtype, np.complexfloating)):
5998-
return array(np.arange(start, stop=stop, step=step, dtype=dtype),
5999-
device=out_sharding)
6000-
# Use iota + offset instead of creating a constant array
6001-
size = int(np.ceil(stop - start))
6002-
if size <= 0:
6003-
return array([], dtype=dtype, device=out_sharding)
6004-
result = lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding)
6005-
if start != 0:
6006-
# Add offset if start is non-zero
6007-
result = lax.add(result, lax.convert_element_type(start, dtype))
6008-
return result
6009-
return array(np.arange(start, stop=stop, step=step, dtype=dtype),
6010-
device=out_sharding)
6003+
# arange(N, M)
6004+
size = max(0, int(np.ceil(stop - start)))
6005+
return lax.add(lax.convert_element_type(start, dtype),
6006+
lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding))
60116007

60126008

60136009
def _arange_dynamic(

0 commit comments

Comments
 (0)