Skip to content

Commit 3ef63ce

Browse files
mattjjsharadmv
andcommitted
[mutable-arrays] ignroe InternalMutableArrayEffect in partial_eval has_effects
It's not really an effect from the point of view of partial eval; like NamedAxisEffect, it's just a bit of info we're sneaking into the effect system as a handy way of tracking it. But it should be ignored for the purpose of deciding "is this code effectful". (We're hoping that InternalMutableArrayEffect, and the effect system more generally, and partial eval, are not long for this world...) This is a follow-up on #33906 Co-authored-by: Sharad Vikram <sharadmv@google.com>
1 parent 1c0bb94 commit 3ef63ce

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

jax/_src/ad_checkpoint.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,8 @@ def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
845845
pe.dce_rules[remat_p] = remat_dce
846846

847847
def _has_effects(effects) -> bool:
848-
return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)})
848+
not_really_effects = (core.NamedAxisEffect, core.InternalMutableArrayEffect)
849+
return any(not isinstance(e, not_really_effects) for e in effects)
849850

850851

851852
def remat_expansion(

jax/_src/interpreters/partial_eval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,8 @@ def ensure_instantiated(inst: bool, x: Atom) -> Atom:
10511051
return x
10521052

10531053
def has_effects(effects) -> bool:
1054-
return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)})
1054+
not_really_effects = (core.NamedAxisEffect, core.InternalMutableArrayEffect)
1055+
return any(not isinstance(e, not_really_effects) for e in effects)
10551056

10561057
known_eqns, staged_eqns = [], []
10571058
foreach(write, in_unknowns, in_inst, jaxpr.invars)

tests/mutable_array_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,34 @@ def f(x):
10531053

10541054
f(3.) # don't crash
10551055

1056+
def test_remat_while_loop_residuals(self):
1057+
@jax.custom_vjp
1058+
def ra2a(x):
1059+
return jax.freeze(jax.new_ref(x))
1060+
1061+
def ra2a_fwd(x):
1062+
o = ra2a(x)
1063+
return o, ()
1064+
1065+
def ra2a_bwd(res, g):
1066+
return (ra2a(g),)
1067+
1068+
ra2a.defvjp(ra2a_fwd, ra2a_bwd)
1069+
1070+
@jax.jit
1071+
@jax.remat
1072+
def f(x):
1073+
1074+
def g(x):
1075+
def body(carry):
1076+
i, x = carry
1077+
x = ra2a(x)
1078+
return i + 1, x
1079+
return jax.lax.while_loop(lambda x: x[0] < 5, body, (0, x))[1]
1080+
return g(x)
1081+
1082+
jax.linearize(f, 5.) # don't crash
1083+
10561084

10571085
@jtu.with_config(jax_mutable_array_checks=True)
10581086
class MutableArrayErrorsTest(jtu.JaxTestCase):

0 commit comments

Comments
 (0)