Skip to content

Commit 4257c62

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Remove one sized mesh axis from spmd_axis_name during comparison with explicit axes if remove_size_one_mesh_axis_from_type is turned on.
PiperOrigin-RevId: 842467979
1 parent 548eaa5 commit 4257c62

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

jax/_src/api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
from jax._src.lib import xla_client as xc
6969
from jax._src.lib import pmap_lib
7070
from jax._src.sharding import Sharding
71-
from jax._src.mesh import get_concrete_mesh
71+
from jax._src.mesh import get_concrete_mesh, get_abstract_mesh
7272
from jax._src.sharding_impls import (PmapSharding, PartitionSpec as P,
7373
NamedSharding)
7474
from jax._src.layout import Format
@@ -1191,6 +1191,9 @@ def vmap_f(*args, **kwargs):
11911191
_mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap"))
11921192
explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat)
11931193
if spmd_axis_name is not None and explicit_mesh_axis is not None:
1194+
spmd_axis_name = (
1195+
tuple(core.remove_size_one_mesh_axis(P(spmd_axis_name), get_abstract_mesh()))
1196+
if config.remove_size_one_mesh_axis_from_type.value else spmd_axis_name)
11941197
if spmd_axis_name == explicit_mesh_axis:
11951198
spmd_axis_name = None
11961199
else:

tests/pjit_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7268,6 +7268,25 @@ def f(x):
72687268
"Only one of spmd_axis_name or arrays sharded on.*spmd_axis_name"):
72697269
f(arr)
72707270

7271+
@config.remove_size_one_mesh_axis_from_type(True)
7272+
@jtu.with_explicit_mesh((2, 1), ('x', 'y'))
7273+
def test_spmd_axis_name_explicit_mode_assert_remove_one_size(self, mesh):
7274+
np_inp = np.arange(16).reshape(8, 2)
7275+
arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y'), None)))
7276+
7277+
@jax.jit
7278+
@partial(jax.vmap, spmd_axis_name=('x', 'y'))
7279+
def f(x):
7280+
# breakpoint()
7281+
self.assertEqual(x.aval.sharding.spec, P(None))
7282+
out = x * 2
7283+
self.assertEqual(out.aval.sharding.spec, P(None))
7284+
return out
7285+
7286+
out = f(arr)
7287+
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
7288+
self.assertArraysEqual(out, np_inp * 2)
7289+
72717290
@jtu.with_explicit_mesh((2,), ('x',))
72727291
def test_unmapped_last_vmap(self, mesh):
72737292
np_inp = np.arange(8)

0 commit comments

Comments
 (0)