
Commit 55de89b

Author: jax authors
Merge pull request #21479 from gnecula:jax2tf_fix_grad_mesh
PiperOrigin-RevId: 638328377
Parents: fa4980c, 7d92328

File tree: 3 files changed, +23 -11 lines changed


jax/experimental/export/_export.py (7 additions, 1 deletion)

@@ -351,6 +351,7 @@ def export(fun_jax: Callable,
            *,
            lowering_platforms: Sequence[str] | None = None,
            disabled_checks: Sequence[DisabledSafetyCheck] = (),
+           _device_assignment_for_internal_jax2tf_use_only = None,
            ) -> Callable[..., Exported]:
   """Exports native serialization for a JAX function.

@@ -413,12 +414,15 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
         _experimental_lowering_parameters=mlir.LoweringParameters(
           platforms=actual_lowering_platforms,
         ))
-    return _export_lowered(lowered, disabled_checks=disabled_checks)
+    return _export_lowered(
+        lowered, disabled_checks=disabled_checks,
+        _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only)
   return do_export

 def _export_lowered(
     lowered: stages.Lowered,
     disabled_checks: Sequence[DisabledSafetyCheck] = (),
+    _device_assignment_for_internal_jax2tf_use_only = None,
     ) -> Exported:
   version = config.jax_serialization_version.value
   if (version < minimum_supported_serialization_version or

@@ -498,6 +502,8 @@ def export_sharding(s: LoweringSharding,
       for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat))

   device_assignment = lowering.compile_args["device_assignment"]
+  if _device_assignment_for_internal_jax2tf_use_only is not None:
+    _device_assignment_for_internal_jax2tf_use_only[0] = device_assignment
   def _get_exported_vjp(exp_primal: Exported) -> Exported:
     # Turn the primal jaxpr into a function, in preparation for exporting
     # the VJP. Note that jaxpr_as_fun produces a function with flat arguments
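The new keyword argument threads a one-element list through export so that _export_lowered can write the computed device assignment back to the caller without changing any return type. Below is a minimal, self-contained sketch of this out-parameter pattern; the names export_sketch and device_assignment_out are hypothetical and only illustrate the idea:

# Minimal sketch of the out-parameter pattern used in this commit.
# `export_sketch` and `device_assignment_out` are hypothetical names.
def export_sketch(fun, *, device_assignment_out=None):
  device_assignment = ("device:0", "device:1")  # stand-in for the real value
  if device_assignment_out is not None:
    # Mutate the caller-supplied list; the value escapes without
    # altering the function's signature or return value.
    device_assignment_out[0] = device_assignment
  return fun

capture = [None]
export_sketch(lambda x: x, device_assignment_out=capture)
assert capture[0] is not None  # the caller now holds the device assignment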

jax/experimental/jax2tf/jax2tf.py (6 additions, 4 deletions)

@@ -514,11 +514,15 @@ def _restore_context():
       _thread_local_state.call_tf_concrete_function_list = _prev_func_list

     self._restore_context = _restore_context
+    _exported_device_assignment = [None]
     self.exported = export.export(
         self.fun_jax,
         lowering_platforms=self.native_serialization_platforms,
-        disabled_checks=self.native_serialization_disabled_checks
+        disabled_checks=self.native_serialization_disabled_checks,
+        _device_assignment_for_internal_jax2tf_use_only=_exported_device_assignment,
     )(*self.args_specs, **self.kwargs_specs)
+    assert(_exported_device_assignment[0] is not None)
+    self.device_assignment = _exported_device_assignment[0]

   def after_conversion(self):
     self._restore_context()

@@ -531,15 +535,13 @@ def run_fun_tf(self,

   def get_vjp_fun(self) -> tuple[Callable,
                                  Sequence[core.AbstractValue]]:
-    # TODO(necula): use the actual device assignment from the primal function
-    device_assignment = jax.devices(jax.default_backend())[:self.exported.nr_devices]
     return _export._get_vjp_fun(self.fun_jax,
                                 in_tree=self.exported.in_tree,
                                 in_avals=self.exported.in_avals,
                                 in_shardings=self.exported.in_shardings,
                                 out_avals=self.exported.out_avals,
                                 out_shardings=self.exported.out_shardings,
-                                device_assignment=device_assignment,
+                                device_assignment=self.device_assignment,
                                 apply_jit=True)

 class GraphSerializationImpl(SerializationImpl):
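Before this change, get_vjp_fun rebuilt a device assignment from jax.devices() in default order, which need not match the order of the primal function's mesh. A hedged illustration of the mismatch, assuming at least two local devices are available:

import jax

# Assumes >= 2 local devices (e.g. on CPU, run with
# XLA_FLAGS=--xla_force_host_platform_device_count=2).
devices = jax.local_devices()[:2]
mesh_order = list(reversed(devices))                # order of a reversed mesh
fallback = jax.devices(jax.default_backend())[:2]   # order the old code assumed
# Same devices, different order: the VJP was built against a device
# assignment that disagreed with the primal's, which is what broke grad.
print(mesh_order == fallback)  # -> False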

jax/experimental/jax2tf/tests/sharding_test.py (10 additions, 6 deletions)

@@ -455,22 +455,26 @@ def f_grad_tf(x_v, res_ct):
     ])

   def test_grad_sharding_different_mesh(self):
-    self.skipTest("TODO: fix the plumbing of device_assignment for jax2tf: https://github.com/google/jax/pull/21319")
     # Convert with two similar meshes, the only difference being
     # the order of the devices. grad should not fail.
     # https://github.com/google/jax/issues/21314
+    devices = jax.local_devices()[:2]
+    if len(devices) < 2:
+      raise unittest.SkipTest("Test requires 2 local devices")
     def f_jax(x):
       return jnp.sum(x * 2.)

-    mesh = Mesh(jax.local_devices(), "i")
+    mesh = Mesh(devices, "i")
     # The same mesh with reversed order of devices
-    mesh_rev = Mesh(list(reversed(jax.local_devices())), "i")
+    mesh_rev = Mesh(list(reversed(devices)), "i")
     shardings = NamedSharding(mesh, jax.sharding.PartitionSpec(("i",)))
     shardings_rev = NamedSharding(mesh_rev, jax.sharding.PartitionSpec(("i",)))

-    f_tf = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings)))
-    f_tf_rev = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings_rev)))
-    inp = np.ones((jax.local_device_count(), 4), dtype=np.float32)
+    f_tf = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings)),
+                       autograph=False)
+    f_tf_rev = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings_rev)),
+                           autograph=False)
+    inp = np.ones((2, 4), dtype=np.float32)

     input_v = tf.Variable(inp)
     with tf.GradientTape(persistent=True) as tape:
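The diff cuts off inside the tf.GradientTape block, so the test's actual assertions are not shown here. As a hedged, self-contained sketch (not the test's real tail), taking a gradient through a jax2tf-converted function looks like this:

import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

# Single-device sketch; the real test does this with the two sharded
# conversions f_tf and f_tf_rev and expects neither gradient to fail.
f_tf = tf.function(jax2tf.convert(lambda x: jnp.sum(x * 2.)),
                   autograph=False)
x = tf.Variable(np.ones((2, 4), dtype=np.float32))
with tf.GradientTape() as tape:
  y = f_tf(x)
grad = tape.gradient(y, x)  # gradient flows back through the JAX function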
