@@ -3768,6 +3768,133 @@ def format_scales(scales):
     )
     np.testing.assert_allclose(result, expected, rtol=1e-3)
 
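+  # Block-scaled matmul across a 2-CTA cluster: fp4 operands with fp8
+  # per-block scales are staged through SMEM and TMEM and multiplied with a
+  # collective MMA, then checked against a plain float32 reference.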
+  @parameterized.product(
+      m=[256],
+      n=[256],
+      scale_jax_dtype=[jnp.float8_e8m0fnu, jnp.float8_e4m3fn],
+  )
+  def test_collective_scaled_matmul(self, m, n, scale_jax_dtype):
+    self.skip_if_wg_semantics()
+
+    in_jax_dtype = jnp.float4_e2m1fn
+    out_jax_dtype = jnp.float32
+    scale_block = 32 if scale_jax_dtype == jnp.float8_e8m0fnu else 16
+    swizzle = 128
+    k_steps = 2
+    swizzle_elems = 8 * swizzle // dtypes.itemsize_bits(in_jax_dtype)
+    k = swizzle_elems * k_steps
+    tiling = (8, swizzle_elems)
+    transforms = (
+        plgpu.TilingTransform(tiling), plgpu.SwizzleTransform(swizzle)
+    )
+    out_transforms = self.default_transforms(dtype=out_jax_dtype)
+
+    m_block = m // 2
+    n_block = n // 2
+
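+    # The two CTAs in the cluster each hold half of the LHS rows and half of
+    # the RHS rows; the kernel stages them into SMEM with partitioned
+    # collective copies over the cluster axis "x".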
+    def kernel(lhs_gmem, rhs_gmem, lhs_scales_gmem, rhs_scales_gmem, out_gmem,
+               lhs_smem, rhs_smem, lhs_scales_smem, rhs_scales_smem, out_smem,
+               tma_barrier, mma_barrier,
+               acc_tmem, lhs_scales_tmem, rhs_scales_tmem):
+      plgpu.copy_gmem_to_smem(lhs_gmem, lhs_smem, tma_barrier,
+                              collective_axes="x", partitioned_axis=0)
+      plgpu.copy_gmem_to_smem(rhs_gmem, rhs_smem, tma_barrier,
+                              collective_axes="x", partitioned_axis=0)
+      plgpu.copy_gmem_to_smem(lhs_scales_gmem, lhs_scales_smem, tma_barrier,
+                              collective_axes="x", partitioned_axis=0)
+      # RHS scales are replicated (multicast)
+      plgpu.copy_gmem_to_smem(rhs_scales_gmem, rhs_scales_smem, tma_barrier,
+                              collective_axes="x", partitioned_axis=None)
+      cluster_idx = lax.axis_index("x")
+
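+      # Only the first CTA in the cluster stages the scales into TMEM and
+      # issues the MMA; with collective_axis="x" the MMA reads operands from
+      # both CTAs.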
+      @pl.when(cluster_idx == 0)
+      def _leader_block():
+        plgpu.barrier_wait(tma_barrier)
+        plgpu.async_copy_scales_to_tmem(
+            lhs_scales_smem, lhs_scales_tmem, collective_axis="x")
+        plgpu.async_copy_scales_to_tmem(
+            rhs_scales_smem, rhs_scales_tmem, collective_axis="x")
+        plgpu.tcgen05_mma(
+            acc_tmem,
+            lhs_smem,
+            plgpu.transpose_ref(rhs_smem, (1, 0)),
+            mma_barrier,
+            a_scale=lhs_scales_tmem,
+            b_scale=rhs_scales_tmem,
+            accumulate=False,
+            collective_axis="x",
+        )
+
+      plgpu.barrier_wait(mma_barrier)
+
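+      # Each CTA drains its (m_block, n) accumulator from TMEM and writes its
+      # own slice of the output rows.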
+      out_smem[...] = plgpu.async_load_tmem(acc_tmem)
+      plgpu.commit_smem()
+      slice_out = pl.ds(cluster_idx * m_block, m_block)
+      plgpu.copy_smem_to_gmem(out_smem, out_gmem.at[slice_out, :])
+      plgpu.wait_smem_to_gmem(0)
+
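+    # Scratch: SMEM tiles for the operands and output, SMEM scale buffers in
+    # the blocked (blocks, k_blocks, 32, 16) form produced by format_scales
+    # below, the TMA/MMA barriers, and collective TMEM for the accumulator
+    # and both scale operands.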
+    scratch_shapes = [
+        plgpu.SMEM((m_block, k), in_jax_dtype, transforms=transforms),
+        plgpu.SMEM((n_block, k), in_jax_dtype, transforms=transforms),
+        plgpu.SMEM((m_block // 128, k // (scale_block * 4), 32, 16), scale_jax_dtype),
+        plgpu.SMEM((n // 128, k // (scale_block * 4), 32, 16), scale_jax_dtype),
+        plgpu.SMEM((m_block, n), out_jax_dtype, transforms=out_transforms),
+        plgpu.Barrier(num_arrivals=4),
+        plgpu.Barrier(orders_tensor_core=True),
+        plgpu.TMEM((m_block, n), out_jax_dtype, collective=True),
+        plgpu.TMEM((m_block, k // scale_block), scale_jax_dtype,
+                   layout=plgpu.TMEMLayout.SCALES_LAYOUT, collective=True),
+        plgpu.TMEM((n, k // scale_block), scale_jax_dtype,
+                   layout=plgpu.TMEMLayout.SCALES_LAYOUT, collective=True),
+    ]
+
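+    # One grid step; the cluster of two CTAs along "x" supplies the pair of
+    # blocks consumed by the collective MMA.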
+    f = self.kernel(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct((m, n), out_jax_dtype),
+        grid=(1,),
+        grid_names=("_",),
+        cluster=(2,),
+        cluster_names=("x",),
+        scratch_shapes=scratch_shapes,
+    )
+
+    x = jax.random.uniform(
+        jax.random.key(1), shape=(m, k), dtype=jnp.float32
+    ).astype(in_jax_dtype)
+    y = jax.random.uniform(
+        jax.random.key(2), shape=(n, k), dtype=jnp.float32
+    ).astype(in_jax_dtype)
+
+    ka, kb = jax.random.split(jax.random.key(1234), 2)
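+    # float8_e8m0fnu is a pure power-of-two exponent format (bias 127), so
+    # biased exponents drawn from [122, 132) give scales between 2**-5 and
+    # 2**4.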
+    if scale_jax_dtype == jnp.float8_e8m0fnu:
+      x_scale = jax.lax.bitcast_convert_type(
+          jax.random.randint(ka, (m, k // scale_block), 122, 132, dtype=jnp.uint8),
+          scale_jax_dtype,
+      )
+      y_scale = jax.lax.bitcast_convert_type(
+          jax.random.randint(kb, (n, k // scale_block), 122, 132, dtype=jnp.uint8),
+          scale_jax_dtype,
+      )
+    else:
+      x_scale = jnp.abs(
+          jax.random.normal(ka, (m, k // scale_block), dtype=jnp.float32)
+          .astype(scale_jax_dtype)
+      )
+      y_scale = jnp.abs(
+          jax.random.normal(kb, (n, k // scale_block), dtype=jnp.float32)
+          .astype(scale_jax_dtype)
+      )
+
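+    # Rearrange the (mn, k // scale_block) scale matrices into the blocked
+    # (mn // 128, k' // 4, 32, 16) form used by the scale scratch buffers
+    # above.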
+    def format_scales(scales):
+      mn, k = scales.shape
+      assert mn % 128 == 0 and k % 4 == 0
+      return (
+          scales.reshape(mn // 128, 4, 32, k // 4, 4)
+          .transpose(0, 3, 2, 1, 4)
+          .reshape(mn // 128, k // 4, 32, 16)
+      )
+
+    result = f(x, y, format_scales(x_scale), format_scales(y_scale))
+
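+    # Reference: broadcast each per-block scale across its K block and apply
+    # it elementwise before an ordinary float32 matmul.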
+    x_logical_scale = jnp.repeat(x_scale, scale_block, axis=1).astype(jnp.float32)
+    y_logical_scale = jnp.repeat(y_scale, scale_block, axis=1).astype(jnp.float32)
+    expected = jnp.dot(
+        x.astype(jnp.float32) * x_logical_scale,
+        (y.astype(jnp.float32) * y_logical_scale).T,
+    )
+    np.testing.assert_allclose(result, expected, rtol=1e-3)
+
   @parameterized.product(
       m=[128],
       n=[128, 256],