
Commit bb4c073

Add the function name, the Jaxpr, and lowering platforms to Lowered.
These changes ensure that `Lowered` carries all the information needed for export and serialization. They are in preparation for a cleanup of the exporting and serialization APIs to integrate them with the AOT APIs: exporting will start from a `Lowered` object and will no longer include its own lowering code. We add the lowered function name and the Jaxpr to `Lowered` (as the attributes `_fun_name` and `_jaxpr`), and we add the tuple of lowering platforms (as `Lowered._lowering._platforms`). The function name enables better error messages when exporting and serializing. The Jaxpr makes it possible to also export the VJP of the function and obtain an `Exported` that can be differentiated.
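As a rough usage sketch (hypothetical example code; `_fun_name`, `_jaxpr`, and `_lowering._platforms` are private internals introduced by this commit and subject to change), the new metadata can be inspected on the `Lowered` object produced by the AOT APIs:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

lowered = jax.jit(f).lower(jnp.ones((8,)))
print(lowered._fun_name)             # name of the lowered function, e.g. "f"
print(type(lowered._jaxpr))          # core.ClosedJaxpr, or None outside JAX core
print(lowered._lowering._platforms)  # tuple of lowering platforms, e.g. ('cpu',)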
1 parent 4fae9aa commit bb4c073

File tree

6 files changed: +175 −138 lines changed


jax/_src/api.py

Lines changed: 3 additions & 2 deletions
@@ -1846,7 +1846,7 @@ def lower(*args, **kwargs) -> stages.Lowered:
         fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
         devices, backend, axis_size, args, kwargs)
     abstract_args = list(map(shaped_abstractify, p.flat_args))
-    computation = pxla.lower_parallel_callable(
+    computation, closed_jaxpr = pxla.lower_parallel_callable(
         p.flat_fun, backend, axis_name,
         axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
         devices=p.devices,
@@ -1858,7 +1858,8 @@ def lower(*args, **kwargs) -> stages.Lowered:
         avals=abstract_args,
         lowering_parameters=lowering_parameters)
     return stages.Lowered.from_flat_info(
-        computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())
+        computation, p.in_tree, abstract_args, donate_tuple, p.out_tree(),
+        fun_name=p.flat_fun.__name__, jaxpr=closed_jaxpr)

   return lower


jax/_src/interpreters/pxla.py

Lines changed: 21 additions & 8 deletions
@@ -556,7 +556,7 @@ def parallel_callable(fun: lu.WrappedFun,
                       donated_invars: Sequence[bool],
                       is_explicit_global_axis_size: bool,
                       *avals):
-  pmap_computation = lower_parallel_callable(
+  pmap_computation, _ = lower_parallel_callable(
       fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
       in_axes, out_axes_thunk, donated_invars,
       is_explicit_global_axis_size, avals,
@@ -679,7 +679,7 @@ def lower_parallel_callable(
     is_explicit_global_axis_size: bool,
     avals: Sequence[core.AbstractValue],
     *,
-    lowering_parameters: mlir.LoweringParameters) -> PmapComputation:
+    lowering_parameters: mlir.LoweringParameters) -> tuple[PmapComputation, core.ClosedJaxpr]:
   # Determine global_axis_size for use in AxisEnv.
   # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
   # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -761,6 +761,7 @@ def lower_parallel_callable(
   tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
                                           backend.platform)
   module_name = f"pmap_{fun.__name__}"
+  platforms = lowering_parameters.platforms or (backend.platform,)
   with maybe_extend_axis_env(axis_name, global_axis_size, None):
     ordered_effects = list(
         effects.ordered_effects.filter_in(closed_jaxpr.effects))
@@ -776,7 +777,7 @@
           closed_jaxpr,
           ordered_effects=ordered_effects,
           backend_or_name=backend,
-          platforms=lowering_parameters.platforms or (backend.platform,),
+          platforms=platforms,
           axis_context=sharding_impls.ReplicaAxisContext(axis_env),
           name_stack=name_stack,
           donated_args=donated_invars,
@@ -787,14 +788,16 @@
           result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
           num_replicas=replicas.num_global_replicas,
           lowering_parameters=lowering_parameters)
-  return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
+  return PmapComputation(lowering_result.module,
+                         platforms=platforms,
+                         pci=pci, replicas=replicas,
                          shards=shards, tuple_args=tuple_args,
                          unordered_effects=unordered_effects,
                          ordered_effects=ordered_effects,
                          keepalive=lowering_result.keepalive,
                          host_callbacks=lowering_result.host_callbacks,
                          jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
-                         shape_poly_state=lowering_result.shape_poly_state)
+                         shape_poly_state=lowering_result.shape_poly_state), closed_jaxpr


 def _pmap_unmap_shaped_array(
@@ -907,10 +910,13 @@ def from_hlo(hlo: ir.Module,
                host_callbacks: list[Any],
                keepalive: Any,
                jaxpr_debug_info: core.JaxprDebugInfo,
+               platforms: Sequence[str],
                shape_poly_state: mlir.ShapePolyLoweringState | None = None,
                compiler_options=None):
+    del platforms
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
       hlo = mlir.refine_polymorphic_shapes(hlo)
+
     devices = pci.devices
     if devices is None:
       if shards.num_global_shards > xb.device_count(pci.backend):
@@ -1941,7 +1947,6 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
         "The following ordered effects are not supported for "
         f"more than 1 device: {unsupported_effects}")
   ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
-
   with dispatch.log_elapsed_time(
       "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
       fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
@@ -2141,6 +2146,7 @@ def lower_sharding_computation(
           for js, source_info in util.stable_unique(jaxpr_sharding))),
       devices_from_context)

+  platforms = lowering_parameters.platforms or (backend.platform,)
   # TODO(yashkatariya): Enable this when offload APIs are stable.
   # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))

@@ -2204,6 +2210,7 @@
       kept_var_idx=kept_var_idx,
       mut=mut,
       backend=backend,
+      platforms=platforms,
       device_assignment=da_object,
       committed=committed,
       in_layouts=in_layouts,
@@ -2244,6 +2251,7 @@ def lower_mesh_computation(
     lowering_parameters: mlir.LoweringParameters) -> MeshComputation:
   assert not mesh.empty
   backend = xb.get_device_backend(mesh.devices.flat[0])
+  platforms = lowering_parameters.platforms or (backend.platform,)
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

   global_axis_sizes = mesh.shape
@@ -2352,7 +2360,7 @@
         closed_jaxpr,
         ordered_effects=ordered_effects,
         backend_or_name=backend,
-        platforms=lowering_parameters.platforms or (backend.platform,),
+        platforms=platforms,
         axis_context=axis_ctx,
         name_stack=name_stack,
         donated_args=donated_invars,
@@ -2382,6 +2390,7 @@
       keepalive=lowering_result.keepalive,
       kept_var_idx=set(range(len(global_in_avals))),
       backend=backend,
+      platforms=platforms,
       device_assignment=_create_da_object(tuple(mesh.devices.flat)),
       committed=True,
       in_layouts=(None,) * len(global_in_avals),
@@ -2394,10 +2403,14 @@ class MeshComputation(stages.XlaLowering):
   _executable: MeshExecutable | None

   def __init__(self, name: str, hlo: ir.Module,
-               donated_invars: Sequence[bool], **compile_args):
+               donated_invars: Sequence[bool],
+               platforms: Sequence[str] | None = None,  # None only for backwards
+                                                        # compatibility with PartIR
+               **compile_args):
     self._name = name
     self._hlo = hlo
     self._donated_invars = donated_invars
+    self._platforms = platforms
     self.compile_args = compile_args
     self._executable = None

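The recurring edit in this file hoists the platform tuple into a single `platforms` variable, computed once with the backend's platform as the fallback, then both passed to MLIR lowering and recorded on the resulting computation. A minimal standalone sketch of that idiom (hypothetical names, not JAX internals):

from collections.abc import Sequence

class Computation:
  # Stand-in for PmapComputation/MeshComputation: remembers its platforms.
  def __init__(self, module: str, platforms: Sequence[str]):
    self.module = module
    self._platforms = tuple(platforms)

def lower(requested: Sequence[str] | None, backend_platform: str) -> Computation:
  # Compute the tuple once: the requested platforms, else the backend's own.
  platforms = tuple(requested or (backend_platform,))
  module = f"lowered-for-{','.join(platforms)}"  # stand-in for MLIR lowering
  return Computation(module, platforms)

assert lower(None, "cpu")._platforms == ("cpu",)
assert lower(("cuda", "tpu"), "cpu")._platforms == ("cuda", "tpu")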
jax/_src/maps.py

Lines changed: 7 additions & 6 deletions
@@ -617,7 +617,7 @@ def lower(*args, **kwargs):
         '_experimental_lowering_platform', mlir.LoweringParameters())
     fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
     avals_flat = [shaped_abstractify(arg) for arg in args_flat]
-    computation = make_xmap_callable(
+    computation, jaxpr = make_xmap_callable(
         fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'],
         params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
         params['resource_env'], params['backend'], params['spmd_in_axes'],
@@ -628,7 +628,7 @@ def lower(*args, **kwargs):
     in_avals = in_tree.unflatten(avals_flat)
     return stages.Lowered.from_flat_info(
         computation, in_tree, in_avals, donate_argnums, out_tree(),
-        no_kwargs=True)
+        no_kwargs=True, fun_name=params['name'], jaxpr=jaxpr)

   fun_mapped.lower = lower
   return type_cast(stages.Wrapped, fun_mapped)
@@ -637,11 +637,12 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_
               global_axis_sizes, axis_resources, resource_env, backend,
               spmd_in_axes, spmd_out_axes_thunk):
   in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
-  xmap_callable = make_xmap_callable(
+  computation, _ = make_xmap_callable(
       fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
       axis_resources, resource_env, backend,
       spmd_in_axes, spmd_out_axes_thunk,
-      mlir.LoweringParameters(), *in_avals).compile().unsafe_call
+      mlir.LoweringParameters(), *in_avals)
+  xmap_callable = computation.compile().unsafe_call
   distributed_debug_log(("Running xmapped function", name),
                         ("python function", fun.f),
                         ("mesh", resource_env.physical_mesh),
@@ -708,15 +709,15 @@ def make_xmap_callable(fun: lu.WrappedFun,
         in_shardings, out_shardings, donated_invars,
         use_spmd_lowering, in_avals,
         tiling_method=tiling_method,
-        lowering_parameters=lowering_parameters)
+        lowering_parameters=lowering_parameters), jaxpr
   else:
     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals)
     return pxla.lower_sharding_computation(
         core.ClosedJaxpr(jaxpr, consts), 'jit', name,
         (UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
         (None,) * len(in_avals), (None,) * len(out_avals),
         donated_invars, keep_unused=True, inline=False,
-        devices_from_context=None, lowering_parameters=lowering_parameters)
+        devices_from_context=None, lowering_parameters=lowering_parameters), jaxpr


 class EvaluationPlan(NamedTuple):

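Note that `make_xmap_callable` above has two return paths (mesh lowering and sharding lowering), and both must now return the same `(computation, jaxpr)` pair. A tiny standalone sketch of keeping multi-branch return shapes consistent (hypothetical stubs, not JAX code):

def trace(x: int) -> str:
  return f"jaxpr<{x}>"          # stand-in for tracing to a jaxpr

def lower_mesh(jaxpr: str) -> str:
  return f"mesh({jaxpr})"       # stand-in for lower_mesh_computation

def lower_sharding(jaxpr: str) -> str:
  return f"sharding({jaxpr})"   # stand-in for lower_sharding_computation

def make_callable(x: int, use_mesh: bool) -> tuple[str, str]:
  jaxpr = trace(x)
  if use_mesh:
    return lower_mesh(jaxpr), jaxpr      # both branches return the
  return lower_sharding(jaxpr), jaxpr    # same (computation, jaxpr) shape

computation, jaxpr = make_callable(3, use_mesh=False)
assert (computation, jaxpr) == ("sharding(jaxpr<3>)", "jaxpr<3>")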
jax/_src/pjit.py

Lines changed: 1 addition & 1 deletion
@@ -469,7 +469,7 @@ def lower(*args, **kwargs):
     donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
     return stages.Lowered.from_flat_info(
         lowering, in_tree, flat_global_in_avals, donate_argnums,
-        out_tree)
+        out_tree, fun_name=params["name"], jaxpr=params["jaxpr"])

   @api_boundary
   def eval_shape(*args, **kwargs):

jax/_src/stages.py

Lines changed: 16 additions & 6 deletions
@@ -601,23 +601,29 @@ class Lowered(Stage):
   querying properties of lowered computations across JAX's various
   lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.).
   """
-  __slots__ = ["args_info", "out_tree", "_lowering", "_no_kwargs"]
-
+  __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs", "_fun_name", "_jaxpr"]
+  _lowering: XlaLowering
   args_info: Any  # PyTree of ArgInfo
   out_tree: tree_util.PyTreeDef
-  _lowering: XlaLowering
   _no_kwargs: bool
+  _fun_name: str
+  _jaxpr: core.ClosedJaxpr | None  # Can be None when this class is constructed
+                                   # outside of JAX core.

   def __init__(
       self,
       lowering: XlaLowering,
       args_info,  # PyTree of ArgInfo
       out_tree: tree_util.PyTreeDef,
-      no_kwargs: bool = False):
+      no_kwargs: bool = False,
+      fun_name: str = "unknown",
+      jaxpr: core.ClosedJaxpr | None = None):
     self._lowering = lowering
     self._no_kwargs = no_kwargs
     self.args_info = args_info
     self.out_tree = out_tree
+    self._fun_name = fun_name
+    self._jaxpr = jaxpr

   @classmethod
   def from_flat_info(cls,
@@ -626,7 +632,9 @@ def from_flat_info(cls,
                      in_avals,
                      donate_argnums: tuple[int, ...],
                      out_tree: tree_util.PyTreeDef,
-                     no_kwargs: bool = False):
+                     no_kwargs: bool = False,
+                     fun_name: str = "unknown",
+                     jaxpr: core.ClosedJaxpr | None = None):
     """Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.

     Args:
@@ -635,12 +643,14 @@ def from_flat_info(cls,
       no_kwargs: If ``True`` the transformation, and the
         ``Compiled`` returned from this object will not support keyword
         arguments (an error will be raised if some are provided).
+      fun_name: the name of the lowered function, if available.
+      jaxpr: the Jaxpr of the lowered function, if available.
     """
     return cls(
         lowering,
         make_args_info(in_tree, in_avals, donate_argnums),
         out_tree,
-        no_kwargs=no_kwargs)
+        no_kwargs=no_kwargs, fun_name=fun_name, jaxpr=jaxpr)

   def compile(
       self, compiler_options: CompilerOptions | None = None) -> Compiled:

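The `Lowered` changes above stay backward compatible because the new parameters default to placeholder values; `_jaxpr` in particular defaults to `None` for instances constructed outside JAX core. A simplified standalone sketch of the pattern (hypothetical class, not the real `jax._src.stages.Lowered`):

class LoweredSketch:
  __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs",
               "_fun_name", "_jaxpr"]

  def __init__(self, lowering, args_info, out_tree,
               no_kwargs: bool = False,
               fun_name: str = "unknown",  # new, defaulted
               jaxpr=None):                # new, defaulted
    self._lowering = lowering
    self.args_info = args_info
    self.out_tree = out_tree
    self._no_kwargs = no_kwargs
    self._fun_name = fun_name
    self._jaxpr = jaxpr

# A pre-existing three-argument call still works; new metadata uses defaults.
old = LoweredSketch(object(), {}, None)
assert old._fun_name == "unknown" and old._jaxpr is None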