Skip to content

Commit 599eb7b

Browse files
Merge pull request #33602 from yashwantbezawada:fix-arange-constant-folding
PiperOrigin-RevId: 841846305
2 parents fac6550 + c71327c commit 599eb7b

File tree

2 files changed

+53
-3
lines changed

2 files changed

+53
-3
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5989,9 +5989,23 @@ def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None,
59895989
start = ceil_(start).astype(int)
59905990
return lax.broadcasted_iota(dtype, (start,), 0, out_sharding=out_sharding) # type: ignore[arg-type]
59915991
else:
5992-
if step is None and start == 0 and stop is not None:
5993-
return lax.broadcasted_iota(dtype, (np.ceil(stop).astype(int),), 0,
5994-
out_sharding=out_sharding)
5992+
if step is None and stop is not None:
5993+
# Skip optimization if start or stop is complex (ceil doesn't support complex)
5994+
start_dtype = _dtype(start)
5995+
stop_dtype = _dtype(stop)
5996+
if (dtypes.issubdtype(start_dtype, np.complexfloating) or
5997+
dtypes.issubdtype(stop_dtype, np.complexfloating)):
5998+
return array(np.arange(start, stop=stop, step=step, dtype=dtype),
5999+
device=out_sharding)
6000+
# Use iota + offset instead of creating a constant array
6001+
size = int(np.ceil(stop - start))
6002+
if size <= 0:
6003+
return array([], dtype=dtype, device=out_sharding)
6004+
result = lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding)
6005+
if start != 0:
6006+
# Add offset if start is non-zero
6007+
result = lax.add(result, lax.convert_element_type(start, dtype))
6008+
return result
59956009
return array(np.arange(start, stop=stop, step=step, dtype=dtype),
59966010
device=out_sharding)
59976011

tests/lax_numpy_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4869,6 +4869,42 @@ def testArangeJaxpr(self, args, specify_device):
48694869
self.assertEqual(len(jaxpr.jaxpr.eqns), num_eqs)
48704870
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p)
48714871

4872+
@jtu.sample_product(specify_device=[True, False])
4873+
def testArangeJaxprNonZeroStart(self, specify_device):
4874+
device = jax.devices()[-1] if specify_device else None
4875+
jaxpr = jax.make_jaxpr(lambda: jnp.arange(1, 5, device=device))()
4876+
# Non-zero start should produce iota + add (+ device_put if device specified)
4877+
num_eqs = 3 if device is not None else 2
4878+
self.assertEqual(len(jaxpr.jaxpr.eqns), num_eqs)
4879+
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p)
4880+
self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.add_p)
4881+
4882+
@jtu.sample_product(
4883+
dtype=[np.int32, np.float32],
4884+
iteration=range(10)
4885+
)
4886+
def testArangeRandomValues(self, dtype, iteration):
4887+
del iteration # not needed: each test case gets its own random seed.
4888+
rng = jtu.rand_default(self.rng())
4889+
start = rng((), dtype)
4890+
stop = rng((), dtype)
4891+
jax_result = jnp.arange(start, stop, dtype=dtype)
4892+
np_result = np.arange(start, stop, dtype=dtype)
4893+
self.assertAllClose(jax_result, np_result)
4894+
4895+
def testArangeComplex(self):
4896+
test_cases = [
4897+
(1+2j, 5+3j),
4898+
(0+0j, 5+0j),
4899+
(1.0+0j, 5.0+0j),
4900+
(0, 5, 1+1j),
4901+
]
4902+
for args in test_cases:
4903+
with self.subTest(args=args):
4904+
jax_result = jnp.arange(*args)
4905+
np_result = np.arange(*args)
4906+
self.assertArraysEqual(jax_result, np_result)
4907+
48724908
def testIssue830(self):
48734909
a = jnp.arange(4, dtype=jnp.complex64)
48744910
self.assertEqual(a.dtype, jnp.complex64)

0 commit comments

Comments
 (0)