Commit 795c11c

yashk2810 authored and Google-ML-Automation committed
Make unreduced + scan_over_layers + microbatching loop work with jax.grad + reduced annotations, in addition to custom_vjp
PiperOrigin-RevId: 814975937
1 parent b974342 commit 795c11c
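
What the title describes, in ordinary JAX terms: a jax.lax.scan over stacked layer weights, nested inside a microbatching loop that accumulates gradients before applying an update. The sketch below is illustrative only; it is not the test added in this commit, it carries no sharding or unreduced/reduced annotations, and the shapes, loss, and 0.01 update rule are placeholder choices that borrow the test's naming.

import jax
import jax.numpy as jnp

def model(stacked_ws, xs_mubatch):
  # Apply one layer per scan step over the leading (layer) axis of the
  # stacked weights, then reduce to a scalar loss.
  def scan_over_layers(carry_xs, w):
    return jnp.dot(carry_xs, w), None
  out, _ = jax.lax.scan(scan_over_layers, xs_mubatch, stacked_ws)
  return jnp.sum(out)

def step(stacked_ws, xs):
  # Accumulate gradients over microbatches, then apply a plain SGD update.
  def mubatch_loop_body(stacked_grad_acc, xs_mubatch):
    g = jax.grad(model)(stacked_ws, xs_mubatch)
    return jax.tree.map(jnp.add, stacked_grad_acc, g), None
  zero = jax.tree.map(jnp.zeros_like, stacked_ws)
  stacked_grad_acc, _ = jax.lax.scan(mubatch_loop_body, zero, xs)
  return jax.tree.map(lambda W, g: W - g * 0.01, stacked_ws, stacked_grad_acc)

stacked_ws = jnp.stack([jnp.ones((4, 4)) for _ in range(4)], axis=0)
xs = jnp.ones((2, 2, 4))   # 2 microbatches, each of shape (2, 4)
new_ws = step(stacked_ws, xs)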

File tree

4 files changed: +41 −19 lines

  jax/_src/lax/lax.py
  jax/_src/lax/utils.py
  tests/BUILD
  tests/pjit_test.py


jax/_src/lax/lax.py

Lines changed: 2 additions & 1 deletion

@@ -7083,7 +7083,8 @@ def _squeeze_sharding_rule(operand, *, dimensions):
   dims_set = set(dimensions)
   new_spec = tuple(s for i, s in enumerate(operand.sharding.spec)
                    if i not in dims_set)
-  return operand.sharding.update(spec=new_spec)
+  return operand.sharding.update(
+      spec=operand.sharding.spec.update(partitions=new_spec))
 
 def _compute_squeeze_shape(shape, dimensions):
   dims_set = set(dimensions)
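
The behavioral point of this one-line change: rebuilding the result sharding with spec.update(partitions=...) keeps whatever else the operand's spec carries (notably unreduced/reduced annotations) while swapping in the surviving partitions, rather than constructing the spec from a bare tuple of axes. A minimal standalone sketch, not part of the diff; axis names are arbitrary and it assumes the PartitionSpec unreduced= keyword and spec.update(partitions=...) method used in the hunk above.

from jax.sharding import PartitionSpec as P

spec = P(None, 'y', unreduced={'x'})                   # spec carrying an unreduced axis
keep = tuple(s for i, s in enumerate(spec) if i != 0)  # drop the squeezed dim 0

rebuilt = P(*keep)                        # ('y',) -- the unreduced={'x'} annotation is gone
preserved = spec.update(partitions=keep)  # ('y',) with unreduced={'x'} intact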

jax/_src/lax/utils.py

Lines changed: 11 additions & 5 deletions

@@ -69,14 +69,18 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh:
     m = a.sharding.mesh
   return mesh_lib.empty_abstract_mesh if m is None else m
 
-def call_unreduced_rule(prim, unreduced_rule, out_s, *avals, **kwargs):
+def call_unreduced_rule(prim, unreduced_rule, out_s, num_out, *avals, **kwargs):
   if unreduced_rule is not None:
     return unreduced_rule(out_s, *avals, **kwargs)
 
   if any(a.sharding.spec.unreduced for a in avals):
     raise NotImplementedError(
         f'unreduced rule for {prim.name} is not implemented. Please file an'
         ' issue at https://github.com/jax-ml/jax/issues')
+  if any(s.spec.unreduced for s in ([out_s] if num_out is None else out_s)):
+    raise NotImplementedError(
+        f'unreduced rule for {prim.name} is not implemented. Please file an'
+        ' issue at https://github.com/jax-ml/jax/issues')
   return out_s
 
 def call_sharding_rule(prim, sh_rule, unreduced_rule, num_out, *avals, **kwargs):

@@ -85,9 +89,11 @@ def call_sharding_rule(prim, sh_rule, unreduced_rule, num_out, *avals, **kwargs)
   if ((cur_mesh.empty or cur_mesh._are_all_axes_auto_or_manual) and
       (aval_mesh.empty or aval_mesh._are_all_axes_auto_or_manual)):
     aval_mesh = cur_mesh if aval_mesh.empty else aval_mesh
-    s = NamedSharding(aval_mesh, P())
-    s = call_unreduced_rule(prim, unreduced_rule, s, *avals, **kwargs)
-    return s if num_out is None else [s] * num_out
+    out_s = NamedSharding(aval_mesh, P())
+    out_s = out_s if num_out is None else [out_s] * num_out
+    out_s = call_unreduced_rule(prim, unreduced_rule, out_s, num_out,
+                                *avals, **kwargs)
+    return out_s
   if sh_rule is None:
     raise core.ShardingTypeError(
         f'sharding rule for {prim.name} is not implemented. Please file an'

@@ -96,7 +102,7 @@ def call_sharding_rule(prim, sh_rule, unreduced_rule, num_out, *avals, **kwargs)
         ' mode via: `jax.sharding.auto_axes(fun, out_shardings=...)`')
   out_sharding = sh_rule(*avals, **kwargs)
   out_sharding = call_unreduced_rule(prim, unreduced_rule, out_sharding,
-                                     *avals, **kwargs)
+                                     num_out, *avals, **kwargs)
   return out_sharding
 
 def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule,
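
Restating the new plumbing in call_unreduced_rule: out_s is a single sharding when num_out is None and a list of per-output shardings otherwise, so the check first normalizes to a list; if any input or requested output carries an unreduced axis and the primitive has no unreduced rule, a NotImplementedError is raised. A hypothetical helper (not part of JAX) capturing just that normalization:

def _any_output_unreduced(out_s, num_out):
  # out_s: one sharding when num_out is None, else a list with one sharding per output.
  outs = [out_s] if num_out is None else out_s
  return any(s.spec.unreduced for s in outs)

# Hypothetical usage, assuming NamedSharding/P as in the hunk above:
#   _any_output_unreduced(NamedSharding(mesh, P(unreduced={'x'})), None)  -> True
#   _any_output_unreduced([NamedSharding(mesh, P())] * 2, 2)              -> False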

tests/BUILD

Lines changed: 2 additions & 2 deletions

@@ -425,9 +425,9 @@ jax_multiplatform_test(
         "gpu_h100x2",
     ],
     shard_count = {
-        "cpu": 3,
+        "cpu": 5,
         "gpu": 2,
-        "tpu": 2,
+        "tpu": 5,
     },
     tags = ["multiaccelerator"],
     deps = [

tests/pjit_test.py

Lines changed: 26 additions & 11 deletions

@@ -9375,8 +9375,12 @@ def f(x, y):
                              ' same'):
       f(arr1, arr2)
 
+  @parameterized.named_parameters(
+      ('custom_vjp', True),
+      ('grad', False),
+  )
   @jtu.with_explicit_mesh((2,), 'x')
-  def test_scan_over_layers_minibatch_unreduced(self, mesh):
+  def test_scan_over_layers_minibatch_unreduced(self, use_custom_vjp, mesh):
     if ifrt_version < 30:
       self.skipTest('Requires ifrt_version >= 30')
     if not jtu.if_cloud_tpu_at_least(2025, 9, 21):

@@ -9385,17 +9389,21 @@ def test_scan_over_layers_minibatch_unreduced(self, mesh):
     def assert_unreduced(val):
       self.assertEqual(val.aval.sharding.spec.unreduced, {'x'})
 
-    @jax.custom_vjp
-    def f(xs, w):
-      return jnp.dot(xs, w)
+    if use_custom_vjp:
+      @jax.custom_vjp
+      def f(xs, w):
+        return jnp.dot(xs, w)
 
-    def f_fwd(xs, w):
-      return f(xs, w), (xs, w)
+      def f_fwd(xs, w):
+        return f(xs, w), (xs, w)
 
-    def f_bwd(res, g):
-      xs, w = res
-      return jnp.dot(g, w), jnp.dot(xs.T, g, out_sharding=P(unreduced={'x'}))
-    f.defvjp(f_fwd, f_bwd)
+      def f_bwd(res, g):
+        xs, w = res
+        return jnp.dot(g, w), jnp.dot(xs.T, g, out_sharding=P(unreduced={'x'}))
+      f.defvjp(f_fwd, f_bwd)
+    else:
+      def f(xs, w):
+        return jnp.dot(xs, w)
 
     def model(stacked_ws, xs_mubatch):
       def scan_over_layers(carry_xs, w):

@@ -9423,7 +9431,14 @@ def mubatch_loop_body(stacked_grad_acc, xs_mubatch):
       return jax.tree.map(
           lambda W, g: W - g * 0.01, stacked_ws, stacked_grad_acc)
 
-    ws = tuple(jax.device_put(jnp.ones((4, 4)), P()) for _ in range(4))
+    if use_custom_vjp:
+      ws = tuple(jax.device_put(jnp.ones((4, 4)), P()) for _ in range(4))
+    else:
+      # Mark `w` with `reduced={'x'}` so that on the bwd pass we will induce
+      # an `unreduced={'x'}`.
+      ws = tuple(jax.device_put(jnp.ones((4, 4)), P(reduced={'x'}))
+                 for _ in range(4))
+
     xs = jax.device_put(jnp.ones((2, 2, 4)), P(None, 'x', None))
     stacked_ws = jnp.stack(ws, axis=0)
     step(stacked_ws, xs)  # doesn't crash
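
The new grad branch in a nutshell: with the weights annotated P(reduced={'x'}), jax.grad is expected to produce a weight gradient carrying unreduced={'x'} (per-shard partial sums not yet all-reduced), which is what the custom_vjp branch constructs explicitly via out_sharding=P(unreduced={'x'}). The condensed example below is hypothetical and separate from the test; it assumes an explicit-axes mesh built with jax.make_mesh(..., axis_types=...) and activated with jax.sharding.use_mesh, and a JAX build new enough to support the reduced/unreduced annotations used in this diff.

import jax
import jax.numpy as jnp
from jax.sharding import AxisType, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))
with jax.sharding.use_mesh(mesh):
  xs = jax.device_put(jnp.ones((2, 4)), P('x', None))      # batch split over 'x'
  w = jax.device_put(jnp.ones((4, 4)), P(reduced={'x'}))    # reduced annotation on 'x'

  def loss(w, xs):
    return jnp.sum(jnp.dot(xs, w))

  g = jax.grad(loss)(w, xs)
  # Expected under the assumptions above: g's sharding spec has unreduced == {'x'},
  # mirroring the assert_unreduced check in the test.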
