Skip to content

Commit e43f4cb

Browse files
allanrenucci authored and
Google-ML-Automation committed
[Pallas:MGPU] Expose fragmented_array.Replicated as part of the public API.
PiperOrigin-RevId: 842675493
1 parent 9ec840f commit e43f4cb

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

jax/experimental/pallas/mosaic_gpu.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,7 @@
87 87
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait
88 88
from jax._src.pallas.mosaic_gpu.torch import as_torch_kernel as as_torch_kernel
89 89
from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics
90+
from jax.experimental.mosaic.gpu.fragmented_array import Replicated as Replicated
90 91
from jax.experimental.mosaic.gpu.fragmented_array import Tiling as Tiling
91 92

92 93

tests/pallas/mosaic_gpu_test.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2563,6 +2563,23 @@ def kernel(dst, collective_barrier):
2563 2563
)()
2564 2564
np.testing.assert_array_equal(y, np.ones((), dtype=np.int32))
2565 2565

2566+
def test_replicated_layout(self):
2567+
shape = (32,)
2568+
@functools.partial(
2569+
self.pallas_call,
2570+
out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
2571+
)
2572+
def kernel(src_ref, dst_ref):
2573+
layout = plgpu.Layout.TILED(
2574+
plgpu.Tiling(((32,), (1,))),
2575+
warp_dims=(plgpu.Replicated(4),),
2576+
lane_dims=(-2,),
2577+
vector_dim=-1,
2578+
)
2579+
dst_ref[...] = plgpu.load(src_ref, (), layout=layout)
2580+
src = jnp.arange(shape[0], dtype=jnp.float32)
2581+
np.testing.assert_array_equal(kernel(src), src)
2582+
2566 2583

2567 2584
class PallasCallWarpPrimitiveSemanticsTest(PallasTest):
2568 2585
def setUp(self):

0 commit comments

Comments
 (0)