Skip to content

Commit be1f505

Browse files
Merge pull request #33841 from mattjj:bjp
PiperOrigin-RevId: 842541857
2 parents 4257c62 + 8186c19 commit be1f505

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

jax/_src/interpreters/mlir.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2451,8 +2451,15 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params):
24512451
wrapped_fun = lu.wrap_init(f, params,
24522452
debug_info=api_util.debug_info("lower_fun", fun, args, {}))
24532453

2454-
jaxpr, _, consts_for_constvars = pe.trace_to_jaxpr_dynamic(wrapped_fun,
2455-
ctx.avals_in)
2454+
jaxpr, _, consts_for_constvars = pe.trace_to_jaxpr_dynamic(
2455+
wrapped_fun, ctx.avals_in)
2456+
2457+
if any(isinstance(e, core.InternalMutableArrayEffect) for e in jaxpr.effects):
2458+
from jax._src.interpreters import pxla # type: ignore
2459+
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts_for_constvars)
2460+
closed_jaxpr = pxla._discharge_internal_refs(closed_jaxpr)
2461+
jaxpr, consts_for_constvars = closed_jaxpr.jaxpr, closed_jaxpr.consts
2462+
24562463
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
24572464

24582465
if ctx.platforms is not None:

tests/mutable_array_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from jax._src import test_util as jtu
2828
from jax._src.api import vjp3
2929
from jax._src.util import safe_map, safe_zip
30+
from jax._src.interpreters import mlir
3031
from jax.sharding import NamedSharding, PartitionSpec as P, AxisType
3132
import jax.numpy as jnp
3233

@@ -1031,6 +1032,28 @@ def test_none_index(self):
10311032
y = ref[None]
10321033
self.assertEqual(y.shape, (1, 3))
10331034

1035+
def test_what_if_you_lower_fun_something_with_internal_effects(self):
1036+
bjp_p = core.Primitive('bjp')
1037+
1038+
@bjp_p.def_abstract_eval
1039+
def _(aval):
1040+
return aval
1041+
1042+
def lowering(x):
1043+
x_ref = jax.new_ref(x)
1044+
x_ref[...] += 1
1045+
x_ref[...] += -1
1046+
return jax.freeze(x_ref)
1047+
1048+
mlir.register_lowering(bjp_p, mlir.lower_fun(lowering, multiple_results=False))
1049+
1050+
@jax.jit
1051+
def f(x):
1052+
return bjp_p.bind(x)
1053+
1054+
f(3.) # don't crash
1055+
1056+
10341057
@jtu.with_config(jax_mutable_array_checks=True)
10351058
class MutableArrayErrorsTest(jtu.JaxTestCase):
10361059
def test_return_from_jit(self):

0 commit comments

Comments
 (0)