@@ -556,7 +556,7 @@ def parallel_callable(fun: lu.WrappedFun,
                       donated_invars: Sequence[bool],
                       is_explicit_global_axis_size: bool,
                       *avals):
-  pmap_computation = lower_parallel_callable(
+  pmap_computation, _ = lower_parallel_callable(
       fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
       in_axes, out_axes_thunk, donated_invars,
       is_explicit_global_axis_size, avals,
@@ -679,7 +679,7 @@ def lower_parallel_callable(
     is_explicit_global_axis_size: bool,
     avals: Sequence[core.AbstractValue],
     *,
-    lowering_parameters: mlir.LoweringParameters) -> PmapComputation:
+    lowering_parameters: mlir.LoweringParameters) -> tuple[PmapComputation, core.ClosedJaxpr]:
   # Determine global_axis_size for use in AxisEnv.
   # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
   # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -761,6 +761,7 @@ def lower_parallel_callable(
   tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
                                           backend.platform)
   module_name = f"pmap_{fun.__name__}"
+  platforms = lowering_parameters.platforms or (backend.platform,)
   with maybe_extend_axis_env(axis_name, global_axis_size, None):
     ordered_effects = list(
         effects.ordered_effects.filter_in(closed_jaxpr.effects))
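The hunk above hoists the platform-defaulting expression into a local `platforms` so later code can reuse one value instead of recomputing it. A minimal sketch of the defaulting rule itself, using stand-in values rather than real `mlir.LoweringParameters` or backend objects:

```python
# Sketch of the defaulting rule with illustrative stand-in values:
# an explicit multi-platform request wins; otherwise fall back to the
# backend's single platform, wrapped in a one-element tuple.
def default_platforms(requested, backend_platform):
    return requested or (backend_platform,)

assert default_platforms(("cpu", "cuda"), "tpu") == ("cpu", "cuda")
assert default_platforms(None, "tpu") == ("tpu",)
```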
@@ -776,7 +777,7 @@ def lower_parallel_callable(
         closed_jaxpr,
         ordered_effects=ordered_effects,
         backend_or_name=backend,
-        platforms=lowering_parameters.platforms or (backend.platform,),
+        platforms=platforms,
         axis_context=sharding_impls.ReplicaAxisContext(axis_env),
         name_stack=name_stack,
         donated_args=donated_invars,
@@ -787,14 +788,16 @@ def lower_parallel_callable(
         result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
         num_replicas=replicas.num_global_replicas,
         lowering_parameters=lowering_parameters)
-  return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
+  return PmapComputation(lowering_result.module,
+                         platforms=platforms,
+                         pci=pci, replicas=replicas,
                          shards=shards, tuple_args=tuple_args,
                          unordered_effects=unordered_effects,
                          ordered_effects=ordered_effects,
                          keepalive=lowering_result.keepalive,
                          host_callbacks=lowering_result.host_callbacks,
                          jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
-                         shape_poly_state=lowering_result.shape_poly_state)
+                         shape_poly_state=lowering_result.shape_poly_state), closed_jaxpr


 def _pmap_unmap_shaped_array(
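With the hunks above, `lower_parallel_callable` now returns a `(PmapComputation, core.ClosedJaxpr)` pair instead of a bare `PmapComputation`, so every caller must unpack the result; `parallel_callable` (first hunk) simply discards the jaxpr with `_`. A hedged sketch of the updated call pattern, reusing the argument names visible in this diff but assuming the surrounding setup:

```python
# Hypothetical call site under the new contract: the closed jaxpr comes
# back alongside the lowered computation.
pmap_computation, closed_jaxpr = lower_parallel_callable(
    fun, backend_name, axis_name, axis_size, global_axis_size, devices,
    name, in_axes, out_axes_thunk, donated_invars,
    is_explicit_global_axis_size, avals,
    lowering_parameters=mlir.LoweringParameters())
pmap_executable = pmap_computation.compile()
```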
@@ -907,10 +910,13 @@ def from_hlo(hlo: ir.Module,
                host_callbacks: list[Any],
                keepalive: Any,
                jaxpr_debug_info: core.JaxprDebugInfo,
+               platforms: Sequence[str],
                shape_poly_state: mlir.ShapePolyLoweringState | None = None,
                compiler_options=None):
+    del platforms
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
       hlo = mlir.refine_polymorphic_shapes(hlo)
+
     devices = pci.devices
     if devices is None:
       if shards.num_global_shards > xb.device_count(pci.backend):
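`from_hlo` gains a `platforms` parameter only so its signature lines up with the keyword arguments that `PmapComputation` now carries; the immediate `del platforms` documents that this implementation has no use for the value yet and silences unused-argument lints. A self-contained sketch of that accept-and-discard pattern (the names here are illustrative, not the real signature):

```python
# Illustrative accept-and-discard: the keyword stays in the shared
# signature for uniform forwarding, but is dropped where it is unused.
def build_executable(module, *, platforms, **compile_args):
    del platforms  # unused here; kept for signature compatibility
    return module, compile_args

# Extra keywords flow through untouched; `platforms` is simply dropped.
result = build_executable("hlo", platforms=("cpu",), tuple_args=True)
assert result == ("hlo", {"tuple_args": True})
```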
@@ -1941,7 +1947,6 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
         "The following ordered effects are not supported for "
         f"more than 1 device: {unsupported_effects}")
   ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
-
   with dispatch.log_elapsed_time(
       "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
       fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
@@ -2141,6 +2146,7 @@ def lower_sharding_computation(
           for js, source_info in util.stable_unique(jaxpr_sharding))),
       devices_from_context)

+  platforms = lowering_parameters.platforms or (backend.platform,)
   # TODO(yashkatariya): Enable this when offload APIs are stable.
   # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))

@@ -2204,6 +2210,7 @@ def lower_sharding_computation(
       kept_var_idx=kept_var_idx,
       mut=mut,
       backend=backend,
+      platforms=platforms,
       device_assignment=da_object,
       committed=committed,
       in_layouts=in_layouts,
@@ -2244,6 +2251,7 @@ def lower_mesh_computation(
     lowering_parameters: mlir.LoweringParameters) -> MeshComputation:
   assert not mesh.empty
   backend = xb.get_device_backend(mesh.devices.flat[0])
+  platforms = lowering_parameters.platforms or (backend.platform,)
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

   global_axis_sizes = mesh.shape
@@ -2352,7 +2360,7 @@ def lower_mesh_computation(
         closed_jaxpr,
         ordered_effects=ordered_effects,
         backend_or_name=backend,
-        platforms=lowering_parameters.platforms or (backend.platform,),
+        platforms=platforms,
         axis_context=axis_ctx,
         name_stack=name_stack,
         donated_args=donated_invars,
@@ -2382,6 +2390,7 @@ def lower_mesh_computation(
       keepalive=lowering_result.keepalive,
       kept_var_idx=set(range(len(global_in_avals))),
       backend=backend,
+      platforms=platforms,
       device_assignment=_create_da_object(tuple(mesh.devices.flat)),
       committed=True,
       in_layouts=(None,) * len(global_in_avals),
@@ -2394,10 +2403,15 @@ class MeshComputation(stages.XlaLowering):
   _executable: MeshExecutable | None

   def __init__(self, name: str, hlo: ir.Module,
-               donated_invars: Sequence[bool], **compile_args):
+               donated_invars: Sequence[bool],
+               # TODO(necula): fix this when internal clients stop using this
+               # constructor directly.
+               platforms: Sequence[str] | None = None,
+               **compile_args):
     self._name = name
     self._hlo = hlo
     self._donated_invars = donated_invars
+    self._platforms = platforms
     self.compile_args = compile_args
     self._executable = None

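The `None` default on `platforms` exists, per the TODO, only so internal clients that still call the `MeshComputation` constructor directly keep working; the lowering paths in this diff always pass an explicit tuple. A self-contained sketch of why the default preserves both call styles (`MeshComputationSketch` is a hypothetical reduction of the real class, not JAX code):

```python
# Hypothetical reduction of the constructor change: platforms defaults
# to None, so legacy call sites remain valid while new callers pass it
# explicitly and other keywords still land in compile_args.
class MeshComputationSketch:
    def __init__(self, name, hlo, donated_invars,
                 platforms=None, **compile_args):
        self._name = name
        self._hlo = hlo
        self._donated_invars = donated_invars
        self._platforms = platforms
        self.compile_args = compile_args

legacy = MeshComputationSketch("jit_f", None, (False,))
assert legacy._platforms is None

new = MeshComputationSketch("jit_f", None, (False,),
                            platforms=("cpu",), backend="cpu")
assert new.compile_args == {"backend": "cpu"}
```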