
Commit d5bf94e

apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Add support for collective scale/sparse metadata copies to TMEM
PiperOrigin-RevId: 841879848
1 parent 31e61ae · commit d5bf94e

File tree

3 files changed: +154 -8 lines changed

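The gist of the change: `async_copy_scales_to_tmem` and `async_copy_sparse_metadata_to_tmem` gain a `collective_axis` argument so that, in a 2-CTA cluster, the leader block can issue a single TMEM copy backing a collective `tcgen05_mma`. Below is a minimal usage sketch distilled from the test added in this commit (tests/pallas/mosaic_gpu_test.py); the ref and barrier names are placeholders taken from that test, not part of the API.

# Minimal sketch (adapted from the new test): a kernel launched with
# cluster=(2,) and cluster_names=("x",). Ref and barrier names are illustrative.
cluster_idx = lax.axis_index("x")

@pl.when(cluster_idx == 0)  # only the leader block issues the copies and the MMA
def _leader_block():
  plgpu.barrier_wait(tma_barrier)
  plgpu.async_copy_scales_to_tmem(
      lhs_scales_smem, lhs_scales_tmem, collective_axis="x")
  plgpu.async_copy_scales_to_tmem(
      rhs_scales_smem, rhs_scales_tmem, collective_axis="x")
  plgpu.tcgen05_mma(
      acc_tmem, lhs_smem, plgpu.transpose_ref(rhs_smem, (1, 0)), mma_barrier,
      a_scale=lhs_scales_tmem, b_scale=rhs_scales_tmem,
      accumulate=False, collective_axis="x")

plgpu.barrier_wait(mma_barrier)  # both blocks wait before reading the accumulator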

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 26 additions & 6 deletions

@@ -53,6 +53,7 @@
 import numpy as np


+AxisName = jax_core.AxisName
 WARP_SIZE = 32
 WARPGROUP_SIZE = 128

@@ -3202,7 +3203,10 @@ def _async_store_tmem_lowering_rule_wg(
 async_copy_scales_to_tmem_p = jax_core.Primitive("async_copy_scales_to_tmem")
 async_copy_scales_to_tmem_p.multiple_results = True

-def async_copy_scales_to_tmem(smem_ref: _Ref, tmem_ref: _Ref):
+
+def async_copy_scales_to_tmem(
+    smem_ref: _Ref, tmem_ref: _Ref, collective_axis: AxisName | None = None,
+):
   """Copies the MMA scales from SMEM to TMEM.

   The copy is performed asynchronously and can be awaited by calling
@@ -3226,12 +3230,17 @@ def async_copy_scales_to_tmem(smem_ref: _Ref, tmem_ref: _Ref):
   async_copy_scales_to_tmem_p.bind(
       smem_ref, tmem_ref, *flat_smem_transforms, *flat_tmem_transforms,
       smem_tree=smem_transforms_treedef, tmem_tree=tmem_transforms_treedef,
+      collective_axis=collective_axis,
   )

+
 async_copy_sparse_metadata_to_tmem_p = jax_core.Primitive("async_copy_sparse_metadata_to_tmem")
 async_copy_sparse_metadata_to_tmem_p.multiple_results = True

-def async_copy_sparse_metadata_to_tmem(smem_ref: _Ref, tmem_ref: _Ref):
+
+def async_copy_sparse_metadata_to_tmem(
+    smem_ref: _Ref, tmem_ref: _Ref, collective_axis: AxisName | None = None
+):
   """Copies the MMA sparse metadata from SMEM to TMEM.

   The copy is performed asynchronously and can be awaited by calling
@@ -3255,19 +3264,21 @@ def async_copy_sparse_metadata_to_tmem(smem_ref: _Ref, tmem_ref: _Ref):
   async_copy_sparse_metadata_to_tmem_p.bind(
       smem_ref, tmem_ref, *flat_smem_transforms, *flat_tmem_transforms,
       smem_tree=smem_transforms_treedef, tmem_tree=tmem_transforms_treedef,
+      collective_axis=collective_axis,
   )

+
 @async_copy_scales_to_tmem_p.def_effectful_abstract_eval
 @async_copy_sparse_metadata_to_tmem_p.def_effectful_abstract_eval
-def _async_copy_to_tmem_abstract_eval(smem_ref, tmem_ref, *avals_flat, smem_tree, tmem_tree):
+def _async_copy_to_tmem_abstract_eval(smem_ref, tmem_ref, *_args, **_kwargs):
   if smem_ref.memory_space != gpu_core.MemorySpace.SMEM:
     raise ValueError("async_copy_scales_to_tmem source must be an SMEM ref")
   if tmem_ref.memory_space != gpu_core.MemorySpace.TMEM:
     raise ValueError("async_copy_scales_to_tmem target must be a TMEM ref")
   return (), {gpu_core._memory_effect}

 def _async_copy_to_tmem_lowering_rule(
-    impl, ctx: lowering.LoweringRuleContext, smem_ref, tmem_ref, *leaves, smem_tree, tmem_tree
+    impl, ctx: lowering.LoweringRuleContext, smem_ref, tmem_ref, *leaves, smem_tree, tmem_tree, collective_axis
 ):
   assert isinstance(tmem_ref, tcgen05.TMEMRef)
   smem_leaves, tmem_leaves = util.split_list(leaves, [smem_tree.num_leaves])
@@ -3279,8 +3290,17 @@ def _async_copy_to_tmem_lowering_rule(
     raise NotImplementedError(f"Unimplemented transforms for SMEM refs: {smem_transforms}")
   if tmem_transforms:
     raise NotImplementedError(f"Unimplemented transforms for TMEM refs: {tmem_transforms}")
-  with mgpu.when(ctx.module_ctx.single_lane_predicate):
-    impl(smem_ref, tmem_ref)
+
+  predicate = ctx.module_ctx.single_lane_predicate
+  if collective_axis is not None:
+    is_leader_block = _collective_mma_predicate(ctx, collective_axis)
+    predicate = arith_dialect.andi(predicate, is_leader_block)
+    collective = True
+  else:
+    collective = False
+
+  with mgpu.when(predicate):
+    impl(smem_ref, tmem_ref, collective=collective)
   return ()

 @lowering.register_lowering_rule(
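Reading the lowering rule above: the copy is always issued from a single lane, and when `collective_axis` is set that predicate is AND-ed with a leader-block predicate so only the cluster's leader block issues it, with the underlying copy switched to its collective form. An annotated restatement of the gating follows (the comments are added here, not part of the commit, and the note on the collective flag is an assumption about the underlying Mosaic GPU copy):

  predicate = ctx.module_ctx.single_lane_predicate  # one lane per block issues the copy
  if collective_axis is not None:
    # Restrict further to the leader block of the cluster along `collective_axis`.
    is_leader_block = _collective_mma_predicate(ctx, collective_axis)
    predicate = arith_dialect.andi(predicate, is_leader_block)
    collective = True  # presumably selects the CTA-pair variant of the tcgen05 copy
  else:
    collective = False

  with mgpu.when(predicate):
    impl(smem_ref, tmem_ref, collective=collective)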

tests/mosaic/gpu_test.py

Lines changed: 1 addition & 2 deletions

@@ -1760,11 +1760,10 @@ def format_scales(scales):

   @parameterized.product(
       m=(256,),
-      n=(64, 128, 256),
+      n=(256,),
       scale_jax_dtype=(jnp.float8_e8m0fnu, jnp.float8_e4m3fn),
   )
   def test_mma_block_scaled_collective(self, m, n, scale_jax_dtype):
-    m, n = 256, 256
     in_jax_dtype = jnp.float4_e2m1fn
     out_jax_dtype = jnp.float32
     scale_block = 32 if scale_jax_dtype == jnp.float8_e8m0fnu else 16

tests/pallas/mosaic_gpu_test.py

Lines changed: 127 additions & 0 deletions

@@ -3768,6 +3768,133 @@ def format_scales(scales):
     )
     np.testing.assert_allclose(result, expected, rtol=1e-3)

+  @parameterized.product(
+      m=[256],
+      n=[256],
+      scale_jax_dtype=[jnp.float8_e8m0fnu, jnp.float8_e4m3fn],
+  )
+  def test_collective_scaled_matmul(self, m, n, scale_jax_dtype):
+    self.skip_if_wg_semantics()
+
+    in_jax_dtype = jnp.float4_e2m1fn
+    out_jax_dtype = jnp.float32
+    scale_block = 32 if scale_jax_dtype == jnp.float8_e8m0fnu else 16
+    swizzle = 128
+    k_steps = 2
+    swizzle_elems = 8 * swizzle // dtypes.itemsize_bits(in_jax_dtype)
+    k = swizzle_elems * k_steps
+    tiling = (8, swizzle_elems)
+    transforms = (
+        plgpu.TilingTransform(tiling), plgpu.SwizzleTransform(swizzle)
+    )
+    out_transforms = self.default_transforms(dtype=out_jax_dtype)
+
+    m_block = m // 2
+    n_block = n // 2
+
+    def kernel(lhs_gmem, rhs_gmem, lhs_scales_gmem, rhs_scales_gmem, out_gmem,
+               lhs_smem, rhs_smem, lhs_scales_smem, rhs_scales_smem, out_smem,
+               tma_barrier, mma_barrier,
+               acc_tmem, lhs_scales_tmem, rhs_scales_tmem):
+      plgpu.copy_gmem_to_smem(lhs_gmem, lhs_smem, tma_barrier,
+                              collective_axes="x", partitioned_axis=0)
+      plgpu.copy_gmem_to_smem(rhs_gmem, rhs_smem, tma_barrier,
+                              collective_axes="x", partitioned_axis=0)
+      plgpu.copy_gmem_to_smem(lhs_scales_gmem, lhs_scales_smem, tma_barrier,
+                              collective_axes="x", partitioned_axis=0)
+      # RHS scales are replicated (multicast)
+      plgpu.copy_gmem_to_smem(rhs_scales_gmem, rhs_scales_smem, tma_barrier,
+                              collective_axes="x", partitioned_axis=None)
+      cluster_idx = lax.axis_index("x")
+
+      @pl.when(cluster_idx == 0)
+      def _leader_block():
+        plgpu.barrier_wait(tma_barrier)
+        plgpu.async_copy_scales_to_tmem(lhs_scales_smem, lhs_scales_tmem, collective_axis="x")
+        plgpu.async_copy_scales_to_tmem(rhs_scales_smem, rhs_scales_tmem, collective_axis="x")
+        plgpu.tcgen05_mma(
+            acc_tmem,
+            lhs_smem,
+            plgpu.transpose_ref(rhs_smem, (1, 0)),
+            mma_barrier,
+            a_scale=lhs_scales_tmem,
+            b_scale=rhs_scales_tmem,
+            accumulate=False,
+            collective_axis="x"
+        )
+      plgpu.barrier_wait(mma_barrier)
+
+      out_smem[...] = plgpu.async_load_tmem(acc_tmem)
+      plgpu.commit_smem()
+      slice_out = pl.ds(cluster_idx * m_block, m_block)
+      plgpu.copy_smem_to_gmem(out_smem, out_gmem.at[slice_out, :])
+      plgpu.wait_smem_to_gmem(0)
+
+    scratch_shapes = [
+        plgpu.SMEM((m_block, k), in_jax_dtype, transforms=transforms),
+        plgpu.SMEM((n_block, k), in_jax_dtype, transforms=transforms),
+        plgpu.SMEM((m_block // 128, k // (scale_block * 4), 32, 16), scale_jax_dtype),
+        plgpu.SMEM((n // 128, k // (scale_block * 4), 32, 16), scale_jax_dtype),
+        plgpu.SMEM((m_block, n), out_jax_dtype, transforms=out_transforms),
+        plgpu.Barrier(num_arrivals=4),
+        plgpu.Barrier(orders_tensor_core=True),
+        plgpu.TMEM((m_block, n), out_jax_dtype, collective=True),
+        plgpu.TMEM((m_block, k // scale_block), scale_jax_dtype,
+                   layout=plgpu.TMEMLayout.SCALES_LAYOUT, collective=True),
+        plgpu.TMEM((n, k // scale_block), scale_jax_dtype,
+                   layout=plgpu.TMEMLayout.SCALES_LAYOUT, collective=True),
+    ]
+
+    f = self.kernel(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct((m, n), out_jax_dtype),
+        grid=(1,),
+        grid_names=("_",),
+        cluster=(2,),
+        cluster_names=("x",),
+        scratch_shapes=scratch_shapes,
+    )
+
+    x = jax.random.uniform(jax.random.key(1), shape=(m, k), dtype=jnp.float32).astype(in_jax_dtype)
+    y = jax.random.uniform(jax.random.key(2), shape=(n, k), dtype=jnp.float32).astype(in_jax_dtype)
+
+    ka, kb = jax.random.split(jax.random.key(1234), 2)
+    if scale_jax_dtype == jnp.float8_e8m0fnu:
+      x_scale = jax.lax.bitcast_convert_type(
+          jax.random.randint(ka, (m, k // scale_block), 122, 132, dtype=jnp.uint8),
+          scale_jax_dtype
+      )
+      y_scale = jax.lax.bitcast_convert_type(
+          jax.random.randint(kb, (n, k // scale_block), 122, 132, dtype=jnp.uint8),
+          scale_jax_dtype
+      )
+    else:
+      x_scale = jnp.abs(
+          jax.random.normal(ka, (m, k // scale_block), dtype=jnp.float32).astype(scale_jax_dtype)
+      )
+      y_scale = jnp.abs(
+          jax.random.normal(kb, (n, k // scale_block), dtype=jnp.float32).astype(scale_jax_dtype)
+      )
+
+    def format_scales(scales):
+      mn, k = scales.shape
+      assert mn % 128 == 0 and k % 4 == 0
+      return (
+          scales.reshape(mn // 128, 4, 32, k // 4, 4)
+          .transpose(0, 3, 2, 1, 4)
+          .reshape(mn // 128, k // 4, 32, 16)
+      )
+
+    result = f(x, y, format_scales(x_scale), format_scales(y_scale))
+
+    x_logical_scale = jnp.repeat(x_scale, scale_block, axis=1).astype(jnp.float32)
+    y_logical_scale = jnp.repeat(y_scale, scale_block, axis=1).astype(jnp.float32)
+    expected = jnp.dot(
+        x.astype(jnp.float32) * x_logical_scale,
+        (y.astype(jnp.float32) * y_logical_scale).T,
+    )
+    np.testing.assert_allclose(result, expected, rtol=1e-3)
+
   @parameterized.product(
       m=[128],
       n=[128, 256],
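A quick shape check on the new test for the float8_e8m0fnu case (a sketch; the numbers follow directly from the constants used above):

# Shape arithmetic for the e8m0 parametrization of test_collective_scaled_matmul.
m = n = 256
swizzle, k_steps, scale_block = 128, 2, 32
swizzle_elems = 8 * swizzle // 4             # float4_e2m1fn is 4 bits wide -> 256
k = swizzle_elems * k_steps                  # 512
print((m, k // scale_block))                 # (256, 16): raw per-operand scale matrix
# format_scales rearranges (256, 16) into (mn // 128, k' // 4, 32, 16) = (2, 4, 32, 16).
# The LHS scales are partitioned across the two blocks, so each block's SMEM scratch
# is (1, 4, 32, 16); the RHS scales are replicated, hence the (2, 4, 32, 16) scratch.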
