Skip to content

Commit 1a6bb95

Browse files
Merge pull request #33924 from mattjj:while-loop-internal-ref-effect
PiperOrigin-RevId: 843789461
2 parents 95048f2 + 3ef63ce commit 1a6bb95

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

jax/_src/ad_checkpoint.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,8 @@ def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
845845
pe.dce_rules[remat_p] = remat_dce
846846

847847
def _has_effects(effects) -> bool:
848-
return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)})
848+
not_really_effects = (core.NamedAxisEffect, core.InternalMutableArrayEffect)
849+
return any(not isinstance(e, not_really_effects) for e in effects)
849850

850851

851852
def remat_expansion(

jax/_src/interpreters/partial_eval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1050,7 +1050,8 @@ def ensure_instantiated(inst: bool, x: Atom) -> Atom:
10501050
return x
10511051

10521052
def has_effects(effects) -> bool:
1053-
return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)})
1053+
not_really_effects = (core.NamedAxisEffect, core.InternalMutableArrayEffect)
1054+
return any(not isinstance(e, not_really_effects) for e in effects)
10541055

10551056
known_eqns, staged_eqns = [], []
10561057
foreach(write, in_unknowns, in_inst, jaxpr.invars)

tests/mutable_array_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,34 @@ def f(x):
10531053

10541054
f(3.) # don't crash
10551055

1056+
def test_remat_while_loop_residuals(self):
1057+
@jax.custom_vjp
1058+
def ra2a(x):
1059+
return jax.freeze(jax.new_ref(x))
1060+
1061+
def ra2a_fwd(x):
1062+
o = ra2a(x)
1063+
return o, ()
1064+
1065+
def ra2a_bwd(res, g):
1066+
return (ra2a(g),)
1067+
1068+
ra2a.defvjp(ra2a_fwd, ra2a_bwd)
1069+
1070+
@jax.jit
1071+
@jax.remat
1072+
def f(x):
1073+
1074+
def g(x):
1075+
def body(carry):
1076+
i, x = carry
1077+
x = ra2a(x)
1078+
return i + 1, x
1079+
return jax.lax.while_loop(lambda x: x[0] < 5, body, (0, x))[1]
1080+
return g(x)
1081+
1082+
jax.linearize(f, 5.) # don't crash
1083+
10561084

10571085
@jtu.with_config(jax_mutable_array_checks=True)
10581086
class MutableArrayErrorsTest(jtu.JaxTestCase):

0 commit comments

Comments
 (0)