Commit b974342

yashk2810 authored and Google-ML-Automation committed
Add support for unreduced + scan over layers.
This requires preserving the unreduced property through broadcast_in_dim and dynamic_update_slice, which works because both operations are linear.

PiperOrigin-RevId: 814943483
1 parent a7afcf7 commit b974342
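
A minimal sketch of the pattern this commit enables, distilled from the new test in tests/pjit_test.py below: per-microbatch gradients are produced with their all-reduce over the data axis deferred (marked unreduced), accumulated through a lax.scan carry, and reduced once per batch. The mesh setup via jax.make_mesh / jax.sharding.use_mesh, the jax.sharding.reshard path, and the assumption of at least two devices stand in for the test's jtu.with_explicit_mesh decorator and are not part of this commit.

import jax
import jax.numpy as jnp
from jax.sharding import AxisType, PartitionSpec as P

# Assumption: at least two devices; the test uses jtu.with_explicit_mesh instead.
mesh = jax.make_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))

with jax.sharding.use_mesh(mesh):
  w = jnp.ones((4, 4))                                          # one layer's weights
  xs = jax.device_put(jnp.ones((2, 8, 4)), P(None, 'x', None))  # microbatches, data split on 'x'

  def grad_mubatch(w, x_mb):
    # The contraction over the 'x'-sharded dimension is left unreduced:
    # each shard keeps a partial sum instead of all-reducing right away.
    return jnp.dot(x_mb.T, x_mb @ w, out_sharding=P(unreduced={'x'}))

  @jax.jit
  def accumulate(w, xs):
    # reshard mirrors the test's `reshard`; its exact import path may differ by JAX version.
    acc0 = jax.sharding.reshard(jnp.zeros_like(w), P(unreduced={'x'}))
    def body(acc, x_mb):
      return acc + grad_mubatch(w, x_mb), None                  # carry stays unreduced across steps
    acc, _ = jax.lax.scan(body, acc0, xs)
    return jax.sharding.reshard(acc, P())                       # one all-reduce per batch

  accumulate(w, xs)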

File tree

  jax/_src/core.py
  jax/_src/lax/control_flow/loops.py
  jax/_src/lax/lax.py
  jax/_src/lax/slicing.py
  tests/pjit_test.py

5 files changed: +83 -6 lines changed


jax/_src/core.py

Lines changed: 6 additions & 3 deletions
@@ -3090,7 +3090,9 @@ def _map_shaped_array(
   assert axis is None or aval.shape[axis] == size
   if axis is None:
     return aval
-  sharding = aval.sharding.update(spec=tuple_delete(aval.sharding.spec, axis))
+  aval_s = aval.sharding
+  sharding = aval_s.update(
+      spec=aval_s.spec.update(partitions=tuple_delete(aval_s.spec, axis)))
   return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
                      weak_type=aval.weak_type, sharding=sharding, vma=aval.vma,
                      memory_space=aval.memory_space)
@@ -3101,8 +3103,9 @@ def _unmap_shaped_array(
   if axis is None:
     return aval
   elif type(axis) is int:
-    sharding = aval.sharding.update(spec=tuple_insert(
-        aval.sharding.spec, axis, explicit_mesh_axis))
+    aval_s = aval.sharding
+    sharding = aval_s.update(spec=aval_s.spec.update(partitions=tuple_insert(
+        aval_s.spec, axis, explicit_mesh_axis)))
     return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
                        weak_type=aval.weak_type, sharding=sharding,
                        vma=aval.vma, memory_space=aval.memory_space)
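
The switch from rebuilding the spec tuple to spec.update(partitions=...) matters because the spec attached to an aval's sharding now carries the unreduced axis set alongside the per-dimension partitions; passing a bare tuple to sharding.update(spec=...) would construct a fresh spec and drop it. A toy sketch with hypothetical classes (not JAX's internals) of the invariant the new vmap map/unmap rules preserve:

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Spec:                      # stand-in for a partition spec with extra metadata
  partitions: tuple
  unreduced: frozenset = frozenset()
  def update(self, **kw):
    return replace(self, **kw)

def tuple_delete(t, i):
  return t[:i] + t[i + 1:]

spec = Spec(partitions=('x', None), unreduced=frozenset({'x'}))

# Old pattern: rebuild the spec from a bare tuple -> the unreduced set is lost.
rebuilt = Spec(partitions=tuple_delete(spec.partitions, 1))
assert rebuilt.unreduced == frozenset()

# New pattern: update only the partitions -> the unreduced set survives.
mapped = spec.update(partitions=tuple_delete(spec.partitions, 1))
assert mapped.unreduced == frozenset({'x'})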

jax/_src/lax/control_flow/loops.py

Lines changed: 3 additions & 1 deletion
@@ -531,6 +531,7 @@ def cond_fun(while_carry):
   # knows not to AR at the boundary of while. This is a no-op at the trace level
   # but during lowering time, it inserts an extra sharding constraint.
   carry = tree_map(_constrain_unreduced, carry)
+  ys = tree_map(_constrain_unreduced, ys)
   return [*carry, *ys]
 
 def _constrain_unreduced(val):
@@ -544,7 +545,8 @@ def _split_leading(sz, x):
 def _concat(a, b): return lax.concatenate([a, b], 0)
 
 def _empty_array(prefix, length_spec, aval):
-  sharding = aval.sharding.update(spec=(*length_spec, *aval.sharding.spec))
+  sharding = aval.sharding.update(spec=aval.sharding.spec.update(
+      partitions=(*length_spec, *aval.sharding.spec)))
   # TODO(yashkatariya): Replace `lax.empty2` with `lax.empty` once
   # AllocateBuffer issues are fixed. Also delete `empty2` after this usage is
   # removed. Basically uncomment the following 2 lines.

jax/_src/lax/lax.py

Lines changed: 2 additions & 1 deletion
@@ -6528,7 +6528,8 @@ def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions,
   orig_spec = iter(operand.sharding.spec)
   new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))]
   assert next(orig_spec, None) is None
-  return operand.sharding.update(spec=new_spec)
+  return operand.sharding.update(
+      spec=operand.sharding.spec.update(partitions=new_spec))
 
 def _broadcast_in_dim_typecheck_rule(
     _, operand, *dyn_shape, shape, broadcast_dimensions, sharding):
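
The dimension mapping in this rule is unchanged: each operand dimension's spec entry lands at its position in broadcast_dimensions and every newly broadcast dimension gets None; the only change is routing the result through spec.update(partitions=...) so any unreduced axes on the operand's spec carry over to the output. A worked example of the comprehension above, with made-up shapes:

# Hypothetical operand: shape (8, 4) sharded ('x', None), broadcast to (2, 8, 4).
operand_spec = ('x', None)
shape = (2, 8, 4)
broadcast_dimensions = (1, 2)    # operand dim i maps to output dim broadcast_dimensions[i]

bds = set(broadcast_dimensions)
orig_spec = iter(operand_spec)
new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))]
assert next(orig_spec, None) is None   # every operand entry was consumed
assert new_spec == [None, 'x', None]   # the new leading dimension is unsharded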

jax/_src/lax/slicing.py

Lines changed: 11 additions & 1 deletion
@@ -1670,6 +1670,15 @@ def _dynamic_update_slice_sharding_rule(operand, update, *start_indices):
       f" {update.str_short(mesh_axis_types=True)}.")
   return operand.sharding
 
+def _dynamic_update_slice_unreduced_rule(out_s, operand, update, *start_indices):
+  if operand.sharding.spec.unreduced != update.sharding.spec.unreduced:
+    raise core.ShardingTypeError(
+        "dynamic_update_slice operand and update must be unreduced along the"
+        " same axes. Got operand sharding"
+        f" {operand.str_short(mesh_axis_types=True)} and update sharding"
+        f" {update.str_short(mesh_axis_types=True)}.")
+  return out_s
+
 def _dynamic_update_slice_dtype_rule(operand, update, *start_indices):
   lax.check_same_dtypes("dynamic_update_slice", operand, update)
   if any(i.dtype != start_indices[0].dtype or
@@ -1735,7 +1744,8 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
 dynamic_update_slice_p = standard_primitive(
     _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule,
     'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule,
-    vma_rule=partial(core.standard_vma_rule, 'dynamic_update_slice'))
+    vma_rule=partial(core.standard_vma_rule, 'dynamic_update_slice'),
+    unreduced_rule=_dynamic_update_slice_unreduced_rule)
 ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp
 ad.primitive_transposes[dynamic_update_slice_p] = \
     _dynamic_update_slice_transpose_rule
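
The new unreduced rule only checks consistency: the buffer being updated and the update must be unreduced over exactly the same mesh axes, since writing a fully reduced value into an unreduced buffer (or vice versa) would make the deferred all-reduce produce the wrong value for that slice. A minimal sketch of that check with a hypothetical helper (not the rule's real signature):

def check_same_unreduced(operand_unreduced: frozenset, update_unreduced: frozenset) -> None:
  # Mirrors the invariant enforced by _dynamic_update_slice_unreduced_rule.
  if operand_unreduced != update_unreduced:
    raise ValueError(
        "dynamic_update_slice operand and update must be unreduced along the "
        f"same axes; got {set(operand_unreduced)} vs {set(update_unreduced)}")

check_same_unreduced(frozenset({'x'}), frozenset({'x'}))  # OK
# check_same_unreduced(frozenset({'x'}), frozenset())     # would raise ValueError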

tests/pjit_test.py

Lines changed: 61 additions & 0 deletions
@@ -9375,6 +9375,67 @@ def f(x, y):
         ' same'):
       f(arr1, arr2)
 
+  @jtu.with_explicit_mesh((2,), 'x')
+  def test_scan_over_layers_minibatch_unreduced(self, mesh):
+    if ifrt_version < 30:
+      self.skipTest('Requires ifrt_version >= 30')
+    if not jtu.if_cloud_tpu_at_least(2025, 9, 21):
+      self.skipTest("Requires libtpu built after 2025-09-21")
+
+    def assert_unreduced(val):
+      self.assertEqual(val.aval.sharding.spec.unreduced, {'x'})
+
+    @jax.custom_vjp
+    def f(xs, w):
+      return jnp.dot(xs, w)
+
+    def f_fwd(xs, w):
+      return f(xs, w), (xs, w)
+
+    def f_bwd(res, g):
+      xs, w = res
+      return jnp.dot(g, w), jnp.dot(xs.T, g, out_sharding=P(unreduced={'x'}))
+    f.defvjp(f_fwd, f_bwd)
+
+    def model(stacked_ws, xs_mubatch):
+      def scan_over_layers(carry_xs, w):
+        return f(carry_xs, w), None
+      final_xs, _ = jax.lax.scan(scan_over_layers, xs_mubatch, stacked_ws)
+      return jnp.sum(final_xs)
+
+    @partial(jax.jit, donate_argnums=(0,))
+    def step(stacked_ws, xs):
+      def mubatch_loop_body(stacked_grad_acc, xs_mubatch):
+        grad = jax.grad(model)(stacked_ws, xs_mubatch)
+        assert_unreduced(grad)
+        assert_unreduced(stacked_grad_acc)
+        stacked_grad_acc = jax.tree.map(jnp.add, stacked_grad_acc, grad)
+        assert_unreduced(stacked_grad_acc)
+        return stacked_grad_acc, None
+
+      stacked_grad_acc = jax.tree.map(jnp.zeros_like, stacked_ws)
+      stacked_grad_acc = reshard(stacked_grad_acc, P(unreduced={'x'}))
+      stacked_grad_acc, _ = jax.lax.scan(
+          mubatch_loop_body, stacked_grad_acc, xs)
+      assert_unreduced(stacked_grad_acc)
+      # AR once for a batch
+      stacked_grad_acc = reshard(stacked_grad_acc, P())
+      return jax.tree.map(
+          lambda W, g: W - g * 0.01, stacked_ws, stacked_grad_acc)
+
+    ws = tuple(jax.device_put(jnp.ones((4, 4)), P()) for _ in range(4))
+    xs = jax.device_put(jnp.ones((2, 2, 4)), P(None, 'x', None))
+    stacked_ws = jnp.stack(ws, axis=0)
+    step(stacked_ws, xs)  # doesn't crash
+
+    compiled_text = step.lower(stacked_ws, xs).compile().as_text()
+    if compiled_text is not None:
+      if jtu.test_device_matches(['gpu']):
+        self.assertEqual(compiled_text.count('all-reduce-start('), 1)
+        self.assertEqual(compiled_text.count('all-reduce-done('), 1)
+      else:
+        self.assertEqual(compiled_text.count('all-reduce('), 1)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
