Skip to content

Commit 5d5ce1c

Browse files
author
jax authors
committed
Merge pull request #21433 from mattjj:scan-avoid-singleton-dim-in-lowering
PiperOrigin-RevId: 637255864
2 parents 787b7c2 + a24b738 commit 5d5ce1c

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

jax/_src/lax/control_flow/loops.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -392,13 +392,17 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
392392
consts, carry, xs_ = split_list(args, [num_consts, num_carry])
393393
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
394394
num_trips, remainder = divmod(length, unroll)
395-
if remainder:
396-
if not reverse:
397-
xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_))
398-
else:
399-
xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_))
400-
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
401-
yss = _map(partial(_empty_array, (num_trips, unroll)), y_avals)
395+
if unroll == 1:
396+
xss = xs_
397+
yss = _map(partial(_empty_array, (length,)), y_avals)
398+
else:
399+
if remainder:
400+
if not reverse:
401+
xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_))
402+
else:
403+
xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_))
404+
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
405+
yss = _map(partial(_empty_array, (num_trips, unroll)), y_avals)
402406

403407
def cond_fun(while_carry):
404408
i, _, _ = while_carry
@@ -413,6 +417,9 @@ def body_fun(while_carry):
413417
return i_ + 1, carry, yss
414418
def inner(n, carry, xs):
415419
ys = []
420+
if unroll == 1:
421+
carry_y = eval_jaxpr_p.bind(*consts, *carry, *xs, jaxpr=jaxpr)
422+
return split_list(carry_y, [num_carry])
416423
for i_ in range(n):
417424
i = n - i_ - 1 if reverse else i_
418425
x = [slicing.index_in_dim(x, i, keepdims=False) for x in xs]
@@ -425,7 +432,10 @@ def inner(n, carry, xs):
425432
if num_trips:
426433
i = lax._const(num_trips, 0)
427434
_, carry, yss = jax.lax.while_loop(cond_fun, body_fun, (i, carry, yss))
428-
ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss]
435+
if unroll != 1:
436+
ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss]
437+
else:
438+
ys = yss
429439
if remainder:
430440
carry, ys_rem = inner(remainder, carry, xs_rem)
431441
ys = _map(_concat, ys, ys_rem) if not reverse else _map(_concat, ys_rem, ys)

tests/lax_control_flow_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2913,6 +2913,21 @@ def f(x):
29132913
self.assertEqual(expect_a_dot, " dot(" in hlo)
29142914
self.assertEqual(not expect_a_dot, " while(" in hlo)
29152915

2916+
def test_scan_lowering_doesnt_introduce_singleton(self):
2917+
b = 4
2918+
i = 2
2919+
2920+
def scan(y):
2921+
def body(carry, x):
2922+
return carry, jnp.dot(x, x)
2923+
return jax.lax.scan(body, 1.0, y, unroll=False)
2924+
2925+
fn = jax.jit(scan)
2926+
2927+
init = np.array(np.arange(b * i * i), dtype=np.float32).reshape((b, i, i))
2928+
hlo_text = fn.lower(init).as_text('hlo')
2929+
self.assertNotIn('4,1,2,2', hlo_text)
2930+
29162931

29172932
if __name__ == '__main__':
29182933
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)