|
36 | 36 | import numpy as np |
37 | 37 |
|
38 | 38 | from jax._src import api |
39 | | -from jax._src import config |
40 | 39 | from jax._src import core |
41 | 40 | from jax._src import deprecations |
42 | 41 | from jax._src import dtypes |
@@ -143,7 +142,8 @@ def iscomplexobj(x: Any) -> bool: |
143 | 142 | >>> jnp.iscomplexobj(jnp.array([0, 1+2j])) |
144 | 143 | True |
145 | 144 | """ |
146 | | - if x is None: |
| 145 | + # Check for int here to avoid potential overflow in jnp.array below. |
| 146 | + if x is None or isinstance(x, int): |
147 | 147 | return False |
148 | 148 | try: |
149 | 149 | typ = x.dtype.type |
@@ -5954,60 +5954,56 @@ def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, |
5954 | 5954 | def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, |
5955 | 5955 | step: ArrayLike | None = None, dtype: DTypeLike | None = None, |
5956 | 5956 | out_sharding: NamedSharding | None = None) -> Array: |
| 5957 | + # Validate inputs |
5957 | 5958 | if dtype is not None: |
5958 | 5959 | dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "arange") |
5959 | | - if not config.dynamic_shapes.value: |
5960 | | - util.check_arraylike("arange", start) |
5961 | | - if stop is None and step is None: |
5962 | | - start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'") |
5963 | | - else: |
5964 | | - start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'start'") |
5965 | | - util.check_arraylike_or_none("arange", None, stop, step) |
| 5960 | + util.check_arraylike_or_none("arange", start, stop, step) |
| 5961 | + |
| 5962 | + # Ensure start/stop/step are concrete |
| 5963 | + start_name = "stop" if stop is None and step is None else "start" |
| 5964 | + start = core.concrete_or_error(None, start, f"It arose in the jnp.arange argument '{start_name}'") |
5966 | 5965 | stop = core.concrete_or_error(None, stop, "It arose in the jnp.arange argument 'stop'") |
5967 | 5966 | step = core.concrete_or_error(None, step, "It arose in the jnp.arange argument 'step'") |
5968 | | - start_name = "stop" if stop is None and step is None else "start" |
| 5967 | + |
| 5968 | + # Ensure start/stop/step are scalars |
5969 | 5969 | for name, val in [(start_name, start), ("stop", stop), ("step", step)]: |
5970 | 5970 | if val is not None and np.ndim(val) != 0: |
5971 | 5971 | raise ValueError(f"jax.numpy.arange: arguments must be scalars; got {name}={val}") |
| 5972 | + |
| 5973 | + # Handle symbolic dimensions |
5972 | 5974 | if any(core.is_symbolic_dim(v) for v in (start, stop, step)): |
5973 | | - # Some dynamic shapes |
5974 | | - if stop is None and step is None: |
5975 | | - stop = start |
5976 | | - start = 0 |
5977 | | - step = 1 |
5978 | | - elif stop is not None and step is None: |
| 5975 | + if stop is None: |
| 5976 | + start, stop = 0, start |
| 5977 | + if step is None: |
5979 | 5978 | step = 1 |
5980 | 5979 | return _arange_dynamic(start, stop, step, dtype or dtypes.default_int_dtype()) |
| 5980 | + |
5981 | 5981 | if dtype is None: |
5982 | | - dtype = result_type(start, *(x for x in [stop, step] if x is not None)) |
| 5982 | + dtype = dtypes.result_type(start, *(x for x in [stop, step] if x is not None)) |
5983 | 5983 | dtype = dtypes.jax_dtype(dtype) |
5984 | | - if stop is None and step is None: |
5985 | | - start_dtype = _dtype(start) |
5986 | | - if (not dtypes.issubdtype(start_dtype, np.integer) and |
5987 | | - not dtypes.issubdtype(start_dtype, dtypes.extended)): |
5988 | | - ceil_ = ufuncs.ceil if isinstance(start, core.Tracer) else np.ceil |
5989 | | - start = ceil_(start).astype(int) |
5990 | | - return lax.broadcasted_iota(dtype, (start,), 0, out_sharding=out_sharding) # type: ignore[arg-type] |
| 5984 | + |
| 5985 | + if iscomplexobj(start) or iscomplexobj(stop) or iscomplexobj(step): |
| 5986 | + # Complex arange is poorly defined; fall back to NumPy here. |
| 5987 | + # TODO(jakevdp): deprecate the complex case. |
| 5988 | + return array(np.arange(start, stop, step, dtype=dtype), device=out_sharding) |
| 5989 | + |
| 5990 | + if step is not None: |
| 5991 | + # arange(N, M, K): when step is specified, fall back to NumPy. |
| 5992 | + return array(np.arange(start, stop, step, dtype=dtype), device=out_sharding) |
| 5993 | + |
| 5994 | + if stop is None: |
| 5995 | + start, stop = 0, start |
| 5996 | + |
| 5997 | + if start == 0: |
| 5998 | + # arange(M) or arange(0, M) |
| 5999 | + size = max(0, int(np.ceil(stop))) |
| 6000 | + return lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding) |
| 6001 | + |
5991 | 6002 | else: |
5992 | | - if step is None and stop is not None: |
5993 | | - # Skip optimization if start or stop is complex (ceil doesn't support complex) |
5994 | | - start_dtype = _dtype(start) |
5995 | | - stop_dtype = _dtype(stop) |
5996 | | - if (dtypes.issubdtype(start_dtype, np.complexfloating) or |
5997 | | - dtypes.issubdtype(stop_dtype, np.complexfloating)): |
5998 | | - return array(np.arange(start, stop=stop, step=step, dtype=dtype), |
5999 | | - device=out_sharding) |
6000 | | - # Use iota + offset instead of creating a constant array |
6001 | | - size = int(np.ceil(stop - start)) |
6002 | | - if size <= 0: |
6003 | | - return array([], dtype=dtype, device=out_sharding) |
6004 | | - result = lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding) |
6005 | | - if start != 0: |
6006 | | - # Add offset if start is non-zero |
6007 | | - result = lax.add(result, lax.convert_element_type(start, dtype)) |
6008 | | - return result |
6009 | | - return array(np.arange(start, stop=stop, step=step, dtype=dtype), |
6010 | | - device=out_sharding) |
| 6003 | + # arange(N, M) |
| 6004 | + size = max(0, int(np.ceil(stop - start))) |
| 6005 | + return lax.add(lax.convert_element_type(start, dtype), |
| 6006 | + lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding)) |
6011 | 6007 |
|
6012 | 6008 |
|
6013 | 6009 | def _arange_dynamic( |
|
0 commit comments