Skip to content

Commit 11bb981

Browse files
allanrenucciGoogle-ML-Automation
authored andcommitted
[Pallas:MGPU] Support not tiled transposed loads in swap LANE lowering rule.
PiperOrigin-RevId: 842740700
1 parent c7bc9b2 commit 11bb981

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1717,9 +1717,12 @@ def _swap_lowering_rule(
17171717
layout=value.layout,
17181718
)
17191719
value.store_tiled(x_smem, swizzle=swizzle)
1720-
case ():
1720+
case () | (gpu_core.TransposeRef((1, 0)),):
1721+
transposed = bool(transforms)
17211722
match value.layout:
17221723
case mgpu.TiledLayout():
1724+
if transposed:
1725+
x_smem = mgpu.memref_transpose(x_smem, (1, 0))
17231726
old_value = mgpu.FragmentedArray.load_untiled(
17241727
x_smem,
17251728
layout=value.layout,
@@ -1728,6 +1731,8 @@ def _swap_lowering_rule(
17281731
)
17291732
value.store_untiled(x_smem, optimized=False)
17301733
case _:
1734+
if transposed:
1735+
raise NotImplementedError(f"Unsupported transforms: {transforms}")
17311736
old_value = mgpu.FragmentedArray.load_strided(
17321737
x_smem, is_signed=mgpu_utils.is_signed(v_aval.dtype)
17331738
)

tests/pallas/mosaic_gpu_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,12 +1056,6 @@ def test_transposed_load_store(self, src_layout, dst_layout):
10561056
def is_transposed(layout):
10571057
return layout == plgpu.Layout.WGMMA_TRANSPOSED
10581058

1059-
if (
1060-
self.LOWERING_SEMANTICS == mgpu.LoweringSemantics.Lane
1061-
and is_transposed(dst_layout)
1062-
):
1063-
self.skipTest("Not implemented: transposed, not tiled")
1064-
10651059
shape, dtype = (128, 128), jnp.float32
10661060

10671061
@functools.partial(

0 commit comments

Comments
 (0)