@@ -4869,6 +4869,42 @@ def testArangeJaxpr(self, args, specify_device):
48694869 self .assertEqual (len (jaxpr .jaxpr .eqns ), num_eqs )
48704870 self .assertEqual (jaxpr .jaxpr .eqns [0 ].primitive , lax .iota_p )
48714871
4872+ @jtu .sample_product (specify_device = [True , False ])
4873+ def testArangeJaxprNonZeroStart (self , specify_device ):
4874+ device = jax .devices ()[- 1 ] if specify_device else None
4875+ jaxpr = jax .make_jaxpr (lambda : jnp .arange (1 , 5 , device = device ))()
4876+ # Non-zero start should produce iota + add (+ device_put if device specified)
4877+ num_eqs = 3 if device is not None else 2
4878+ self .assertEqual (len (jaxpr .jaxpr .eqns ), num_eqs )
4879+ self .assertEqual (jaxpr .jaxpr .eqns [0 ].primitive , lax .iota_p )
4880+ self .assertEqual (jaxpr .jaxpr .eqns [1 ].primitive , lax .add_p )
4881+
4882+ @jtu .sample_product (
4883+ dtype = [np .int32 , np .float32 ],
4884+ iteration = range (10 )
4885+ )
4886+ def testArangeRandomValues (self , dtype , iteration ):
4887+ del iteration # not needed: each test case gets its own random seed.
4888+ rng = jtu .rand_default (self .rng ())
4889+ start = rng ((), dtype )
4890+ stop = rng ((), dtype )
4891+ jax_result = jnp .arange (start , stop , dtype = dtype )
4892+ np_result = np .arange (start , stop , dtype = dtype )
4893+ self .assertAllClose (jax_result , np_result )
4894+
4895+ def testArangeComplex (self ):
4896+ test_cases = [
4897+ (1 + 2j , 5 + 3j ),
4898+ (0 + 0j , 5 + 0j ),
4899+ (1.0 + 0j , 5.0 + 0j ),
4900+ (0 , 5 , 1 + 1j ),
4901+ ]
4902+ for args in test_cases :
4903+ with self .subTest (args = args ):
4904+ jax_result = jnp .arange (* args )
4905+ np_result = np .arange (* args )
4906+ self .assertArraysEqual (jax_result , np_result )
4907+
48724908 def testIssue830 (self ):
48734909 a = jnp .arange (4 , dtype = jnp .complex64 )
48744910 self .assertEqual (a .dtype , jnp .complex64 )
0 commit comments