
Commit 637982e

[MGPU] Add support for broadcast on major dim in WGStridedFragLayout.

PiperOrigin-RevId: 842680640
Parent: e43f4cb

File tree: 4 files changed, 78 additions & 7 deletions

  jax/_src/pallas/mosaic_gpu/lowering.py
  jax/experimental/mosaic/gpu/fragmented_array.py
  tests/mosaic/gpu_test.py
  tests/pallas/mosaic_gpu_test.py

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 8 additions & 0 deletions

@@ -1860,6 +1860,14 @@ def _broadcast_in_dim_lowering_rule(
   if (isinstance(x.layout, mgpu.WGSplatFragLayout) and
       broadcast_dimensions == tuple(range(rank_diff, rank_diff + x_aval.ndim))):
     return x.broadcast(shape)
+  if (
+      isinstance(x.layout, mgpu.WGStridedFragLayout)
+      and broadcast_dimensions == tuple(range(rank_diff, y_aval.ndim))
+  ):
+    new_layout = mgpu.WGStridedFragLayout(
+        shape=y_aval.shape, vec_size=x.layout.vec_size
+    )
+    return x.broadcast_in_dim(y_aval.shape, broadcast_dimensions, new_layout)
   if not isinstance(layout := x.layout, mgpu.TiledLayout):
     raise NotImplementedError(f"Unsupported layout: {x.layout}")
   if any(d1 >= d2 for d1, d2 in zip(broadcast_dimensions[:-1], broadcast_dimensions[1:])):
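
Note (not part of the diff): the new guard `broadcast_dimensions == tuple(range(rank_diff, y_aval.ndim))` accepts exactly the broadcasts that keep the source dims as the trailing (minor) dims of the output, i.e. that only prepend new major dimensions. A minimal sketch, assuming a rank-1 source broadcast to rank 2:

    # Hedged sketch: which broadcasts take the new WGStridedFragLayout path.
    x_ndim, y_ndim = 1, 2
    rank_diff = y_ndim - x_ndim
    # Source dim 0 maps to output dim 1: only a new major dim is added.
    assert (1,) == tuple(range(rank_diff, y_ndim))  # taken by the new path
    # Source dim 0 mapping to output dim 0 would broadcast along the minor
    # dim, which falls through to the TiledLayout handling below.
    assert (0,) != tuple(range(rank_diff, y_ndim))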

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 14 additions & 0 deletions

@@ -2504,6 +2504,20 @@ def broadcast_in_dim(
       return type(self).splat(
           self.registers.item(), shape, layout, is_signed=self.is_signed
       )
+    if isinstance(self.layout, WGStridedFragLayout) and isinstance(layout, WGStridedFragLayout):
+      new_dims = set(range(len(shape))) - set(source_dimensions)
+      vec_match = self.layout.vec_size == layout.vec_size
+      broadcast_dim_match = new_dims == set(range(len(new_dims)))
+      assert layout.shape == shape, (layout.shape, shape)
+      if vec_match and broadcast_dim_match:
+        return FragmentedArray(
+            _registers=np.tile(
+                self.registers,
+                np.prod(shape[:len(new_dims)]),
+            ),
+            _layout=layout,
+            _is_signed=self.is_signed,
+        )
     if not isinstance(self.layout, TiledLayout) or not isinstance(layout, TiledLayout):
       raise NotImplementedError(self.layout, layout)
     if any(d1 >= d2 for d1, d2 in zip(source_dimensions, source_dimensions[1:])):
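
Why `np.tile` is enough here (a sketch, not part of the diff): a WGStridedFragLayout holds the array as a flat, row-major vector of registers, so broadcasting along new leading (major) dimensions just repeats that flat vector once per added major element. Assuming that row-major flattening, the equivalence checks out in plain NumPy:

    import numpy as np

    m, n = 3, 4
    src = np.arange(n)       # stands in for the flat registers of a (n,) array
    tiled = np.tile(src, m)  # what the new broadcast_in_dim branch does
    ref = np.broadcast_to(src, (m, n)).reshape(-1)  # row-major flattening
    np.testing.assert_array_equal(tiled, ref)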

tests/mosaic/gpu_test.py

Lines changed: 40 additions & 7 deletions

@@ -3796,16 +3796,49 @@ def kernel(ctx, *args):
     )(inp)
     np.testing.assert_array_equal(result, inp)
 
-  @parameterized.parameters((128, 128), (128, 64), (64, 128))
-  def test_broadcast_major(self, m, n):
+  @parameterized.product(
+      mns=((128, 128), (128, 64), (64, 128)),
+      layout=(mtu.RegisterLayout.WG_STRIDED, mtu.RegisterLayout.WGMMA),
+  )
+  def test_broadcast_major(self, mns, layout):
+    m, n = mns
+
+    if n < 128 and layout == mtu.RegisterLayout.WG_STRIDED:
+      self.skipTest(f"{n=} < 128 not supported for {layout=}")
+
+    dtype = jnp.float16
+    load_layout = (
+        layout.to_mgpu((n,), dtype)
+        if layout == mtu.RegisterLayout.WG_STRIDED
+        else mgpu.WGMMA_COL_LAYOUT
+    )
+    broadcast_layout = (
+        mgpu.WGStridedFragLayout((m, n), load_layout.vec_size)
+        if layout == mtu.RegisterLayout.WG_STRIDED
+        else layout.to_mgpu((m, n), dtype)
+    )
+
+    def load(gmem_input):
+      match layout:
+        case mtu.RegisterLayout.WG_STRIDED:
+          return mgpu.FragmentedArray.load_strided(
+              gmem_input, vec_size=load_layout.vec_size
+          )
+        case mtu.RegisterLayout.WGMMA:
+          return mgpu.FragmentedArray.load_untiled(
+              gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False
+          )
+        case _:
+          raise NotImplementedError(f"Unsupported layout: {layout}")
+
     def kernel(ctx, gmem_input, gmem_output, _):
-      t = mgpu.FragmentedArray.load_untiled(
-          gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False
+      t = load(gmem_input)
+      t.broadcast_in_dim((m, n), (1,), broadcast_layout).store_untiled(
+          gmem_output, optimized=False
       )
-      t.broadcast_in_dim((m, n), (1,), mgpu.WGMMA_LAYOUT).store_untiled(gmem_output, optimized=False)
 
-    inp = self.prng.uniform(-1, 1, (n,)).astype(jnp.float16)
-    out_shape = jax.ShapeDtypeStruct((m, n), jnp.float16)
+    inp = self.prng.uniform(-1, 1, (n,)).astype(dtype)
+    out_shape = jax.ShapeDtypeStruct((m, n), dtype)
     result = mgpu.as_gpu_kernel(
         kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, inp
     )(inp)
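
On the `n < 128` skip for WG_STRIDED (a hedged note, not part of the diff): a warpgroup-strided layout distributes vec_size-element chunks round-robin over the 128 threads of a warpgroup, so the assumption here is that the element count must be a multiple of 128 * vec_size, which rules out the (64,) input:

    # Hedged check of the skip condition, under the divisibility assumption above.
    WARPGROUP_SIZE = 128
    vec_size = 1
    for n in (64, 128):
        ok = n % (WARPGROUP_SIZE * vec_size) == 0
        print(f"{n=}: {'supported' if ok else 'skipped'}")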

tests/pallas/mosaic_gpu_test.py

Lines changed: 16 additions & 0 deletions

@@ -2407,6 +2407,22 @@ def test_broadcast_in_dim_does_not_crash_on_small_shape(self):
         shape, plgpu.Layout.TCGEN05_TMEM_NATIVE, axis=1, hint=False
     )
 
+  def test_broadcast_in_dim_wg_strided_majormost_dim(self):
+    self.skip_if_wg_semantics()
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((256, 128), jnp.float32),
+    )
+    def kernel(x_ref, y_ref):
+      to_be_broadcasted = plgpu.load(
+          x_ref, (), layout=plgpu.Layout.WG_STRIDED((128,), 1)
+      )
+      broadcasted = lax.broadcast_in_dim(to_be_broadcasted, (256, 128), (1,))
+      y_ref[...] = broadcasted
+
+    result = jax.random.uniform(jax.random.key(0), shape=(128,), dtype=jnp.float32)
+    np.testing.assert_array_equal(kernel(result), jnp.broadcast_to(result[None,:], (256, 128)))
+
   def test_broadcast_in_dim_tcgen05_native_layout(self):
     @functools.partial(
         self.kernel,
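
For reference (not part of the diff), the pure-JAX computation the new Pallas test checks against, independent of any GPU register layout:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(128, dtype=jnp.float32)
    # Broadcast a (128,) vector along a new major dim of size 256.
    y = lax.broadcast_in_dim(x, (256, 128), broadcast_dimensions=(1,))
    assert (y == jnp.broadcast_to(x[None, :], (256, 128))).all()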
