Skip to content

Commit de28ee6

Browse files
author
jax authors
committed
Merge pull request #21451 from gnecula:poly_pmap
PiperOrigin-RevId: 637890579
2 parents 720d2b8 + acb56a2 commit de28ee6

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,8 @@ def lower_parallel_callable(
793793
ordered_effects=ordered_effects,
794794
keepalive=lowering_result.keepalive,
795795
host_callbacks=lowering_result.host_callbacks,
796-
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
796+
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
797+
shape_poly_state=lowering_result.shape_poly_state)
797798

798799

799800
def _pmap_unmap_shaped_array(
@@ -906,7 +907,10 @@ def from_hlo(hlo: ir.Module,
906907
host_callbacks: list[Any],
907908
keepalive: Any,
908909
jaxpr_debug_info: core.JaxprDebugInfo,
910+
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
909911
compiler_options=None):
912+
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
913+
hlo = mlir.refine_polymorphic_shapes(hlo)
910914
devices = pci.devices
911915
if devices is None:
912916
if shards.num_global_shards > xb.device_count(pci.backend):

tests/export_test.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,23 @@ def f(x): # x: f32[b]
848848
a = exp2.in_avals[0].shape[0]
849849
self.assertEqual(exp2.out_avals[0].shape, output_shape(a))
850850

851+
def test_poly_call_pmap(self):
852+
if len(jax.devices()) < 2:
853+
self.skipTest("Need at least 2 devices")
854+
def f(x): # x: f32[a, 4]
855+
return x + jnp.arange(x.shape[0], dtype=x.dtype).reshape((x.shape[0], 1))
856+
857+
a, = export.symbolic_shape("a")
858+
exp = export.export(f)(
859+
jax.ShapeDtypeStruct((a, 4), np.float32))
860+
f_exp = export.call_exported(exp)
861+
x_jit = np.arange(12, dtype=np.float32).reshape((3, 4))
862+
res_jit = jax.jit(f_exp)(x_jit)
863+
self.assertAllClose(res_jit, f(x_jit))
864+
x_pmap = np.arange(24, dtype=np.float32).reshape((2, 3, 4))
865+
res_pmap = jax.pmap(f_exp)(x_pmap)
866+
self.assertAllClose(res_pmap, jnp.stack([f(x) for x in x_pmap]))
867+
851868
def test_with_sharding(self):
852869
nr_devices = 2
853870
if len(jax.devices()) < nr_devices:
@@ -1204,7 +1221,6 @@ def f(x):
12041221
g_rev = jax.grad(export.call(exp))(input)
12051222
self.assertAllClose(g, g_rev)
12061223

1207-
12081224
def test_multi_platform(self):
12091225
x = np.arange(8, dtype=np.float32)
12101226
exp = get_exported(_testing_multi_platform_func,

0 commit comments

Comments (0)