@@ -1907,7 +1907,7 @@ def kernel(x_ref, y_ref, o_ref):
19071907 y = jax .lax .iota (jnp .float32 , 128 ) * 3
19081908 np .testing .assert_array_equal (kernel (x , y ), x + y )
19091909
1910- def test_smem_aliasing_works (self ):
1910+ def test_smem_aliasing_works_basic (self ):
19111911 self .skip_if_wg_semantics ()
19121912
19131913 in_shape = (2 , 256 )
@@ -1938,17 +1938,16 @@ def test_smem_aliasing_works(self):
19381938 plgpu .SMEM (
19391939 (128 ,),
19401940 jnp .float32 ,
1941- transforms = (plgpu .TilingTransform ((64 ,)),),
1942- ),
1941+ transforms = (plgpu .TilingTransform ((64 ,)),)),
19431942 ]
19441943 ],
19451944 )
19461945 ],
19471946 )
19481947 def kernel (x_ref , o_ref128 , aliased_ref ):
1949- smem_ref256 , _ , smem_ref128 = aliased_ref
1948+ smem_ref256 , [ _ , [ smem_ref128 ]] = aliased_ref
19501949 # Ensure that extraction via index works the same as unfolding.
1951- smem_ref128_2 = aliased_ref [2 ]
1950+ smem_ref128_2 = aliased_ref [1 ][ 1 ][ 0 ]
19521951 self .assertIsInstance (smem_ref128 , state_types .TransformedRef )
19531952 self .assertIsInstance (smem_ref128_2 , state_types .TransformedRef )
19541953 self .assertIs (smem_ref128 .ref , smem_ref128_2 .ref )
@@ -2005,7 +2004,7 @@ def test_smem_aliasing_works_with_subbyte_dtypes(self):
20052004 ],
20062005 )
20072006 def kernel (x_ref , o_refi4 , aliased_ref ):
2008- _ , smem_refi8 , _ , smem_refi4 = aliased_ref
2007+ [ _ , smem_refi8 ], [ _ , smem_refi4 ] = aliased_ref
20092008 smem_refi8 [...] = x_ref [...]
20102009 plgpu .commit_smem ()
20112010 plgpu .copy_smem_to_gmem (smem_refi4 , o_refi4 )
@@ -3415,7 +3414,7 @@ def test_tmem_ref_aliasing(self):
34153414 thread_name = "x" ,
34163415 )
34173416 def kernel (x_ref , y_ref , aliased_ref , smem_ref , barrier_ref ):
3418- tmem_128x32a , tmem_128x32b , tmem_128x64 = aliased_ref
3417+ [ tmem_128x32a , tmem_128x32b ] , tmem_128x64 = aliased_ref
34193418 plgpu .copy_gmem_to_smem (x_ref , smem_ref , barrier_ref )
34203419 plgpu .barrier_wait (barrier_ref )
34213420 # Test tmem_128x32 a and b
@@ -4268,7 +4267,7 @@ def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64,
42684267 plgpu .barrier_wait (tma_barrier )
42694268 plgpu .copy_gmem_to_smem (b_gmem , b_smem , tma_barrier )
42704269 plgpu .barrier_wait (tma_barrier )
4271- acc_128 , lhs_128 , lhs_64 , acc_64 , _ = aliased_refs
4270+ [ acc_128 , lhs_128 ], [ lhs_64 , acc_64 ] , _ = aliased_refs
42724271
42734272 # Do 128x128 @ 128x128 matmul
42744273 plgpu .async_store_tmem (lhs_128 , plgpu .load (a_smem , (), layout = plgpu .Layout .TCGEN05 ))
@@ -4305,21 +4304,27 @@ def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64,
43054304
43064305 f = self .kernel (
43074306 kernel ,
4308- out_shape = [jax .ShapeDtypeStruct (shape , dtype ),
4309- jax .ShapeDtypeStruct (shape , dtype )],
4307+ out_shape = [
4308+ jax .ShapeDtypeStruct (shape , dtype ),
4309+ jax .ShapeDtypeStruct (shape , dtype ),
4310+ ],
43104311 scratch_shapes = [
4311- plgpu .SMEM (shape , dtype , transforms = transforms ), # a_smem
4312- plgpu .SMEM (shape , dtype , transforms = transforms ), # b_smem
4313- plgpu .SMEM (shape , dtype , transforms = transforms ), # out_smem
4314- plgpu .Barrier (), # tma_barrier
4315- plgpu .Barrier (orders_tensor_core = True ), # mma_barrier
4316- plgpu .RefUnion ( # aliased_refs
4317- [plgpu .TMEM ((128 , 128 ), jnp .float32 ), # acc
4318- plgpu .TMEM ((128 , 128 ), dtype , packed = True )], # lhs
4319- [plgpu .TMEM ((128 , 64 ), dtype , packed = True ), # lhs
4320- plgpu .TMEM ((128 , 128 ), jnp .float32 )], # acc
4321- plgpu .TMEM ((128 , 128 ), jnp .float32 ) # unused
4322- ),
4312+ plgpu .SMEM (shape , dtype , transforms = transforms ), # a_smem
4313+ plgpu .SMEM (shape , dtype , transforms = transforms ), # b_smem
4314+ plgpu .SMEM (shape , dtype , transforms = transforms ), # out_smem
4315+ plgpu .Barrier (), # tma_barrier
4316+ plgpu .Barrier (orders_tensor_core = True ), # mma_barrier
4317+ plgpu .RefUnion ( # aliased_refs
4318+ [
4319+ plgpu .TMEM ((128 , 128 ), jnp .float32 ), # acc
4320+ plgpu .TMEM ((128 , 128 ), dtype , packed = True ), # lhs
4321+ ],
4322+ [
4323+ plgpu .TMEM ((128 , 64 ), dtype , packed = True ), # lhs
4324+ plgpu .TMEM ((128 , 128 ), jnp .float32 ), # acc
4325+ ],
4326+ plgpu .TMEM ((128 , 128 ), jnp .float32 ), # unused
4327+ ),
43234328 ],
43244329 )
43254330 x = jax .random .uniform (jax .random .key (0 ), shape = shape , dtype = dtype )
0 commit comments