Skip to content

Commit e288c27

Browse files
allanrenucci authored and Google-ML-Automation committed
[Mosaic GPU] Support FragmentedArray.broadcast_in_dim for splat to other layouts.
PiperOrigin-RevId: 842145211
1 parent 3bec82d commit e288c27

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,9 +1301,8 @@ def to_layout(self, new_layout: FragmentedLayout) -> FragmentedArray:
13011301
raise NotImplementedError(
13021302
f"Cannot convert from {self.layout} to {new_layout}"
13031303
)
1304-
[reg] = self.registers.flat
13051304
return type(self).splat(
1306-
reg, self.shape, new_layout, is_signed=self.is_signed
1305+
self.registers.item(), self.shape, new_layout, is_signed=self.is_signed
13071306
)
13081307

13091308
def _pointwise(
@@ -2502,15 +2501,9 @@ def broadcast_in_dim(
25022501
f" {shape[target_dim]} in shape after broadcast"
25032502
)
25042503
if isinstance(self.layout, WGSplatFragLayout):
2505-
if isinstance(layout, WGSplatFragLayout):
2506-
if layout.shape != shape:
2507-
raise ValueError(
2508-
f"Layout shape {layout.shape} does not match broadcast shape {shape}"
2509-
)
2510-
return FragmentedArray(
2511-
_registers=self.registers, _layout=layout, _is_signed=self.is_signed,
2512-
)
2513-
# TODO: Support splat to other layouts
2504+
return type(self).splat(
2505+
self.registers.item(), shape, layout, is_signed=self.is_signed
2506+
)
25142507
if not isinstance(self.layout, TiledLayout) or not isinstance(layout, TiledLayout):
25152508
raise NotImplementedError(self.layout, layout)
25162509
if any(d1 >= d2 for d1, d2 in zip(source_dimensions, source_dimensions[1:])):

tests/mosaic/gpu_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3812,6 +3812,30 @@ def kernel(ctx, gmem_input, gmem_output, _):
38123812
out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,))
38133813
np.testing.assert_array_equal(result, out_ref)
38143814

3815+
@parameterized.parameters(*mtu.RegisterLayout)
def test_broadcast_splat(self, layout):
  """Broadcasting a splat array must work for every target register layout."""
  out_shape = (128, 128)

  def kernel_body(ctx, out_ref, scratch):
    # Neither the launch context nor SMEM scratch is needed here.
    del ctx, scratch
    i32 = ir.IntegerType.get_signless(32)
    splat_value = arith.constant(i32, 42)
    # Start from a 1D splat array and broadcast it into the 2D output
    # shape, requesting the layout under test for the result.
    splat_arr = mgpu.FragmentedArray.splat(splat_value, (128,), is_signed=True)
    target_layout = layout.to_mgpu(out_shape, jnp.int32)
    broadcasted = splat_arr.broadcast_in_dim(out_shape, (0,), target_layout)
    broadcasted.store_untiled(out_ref, optimized=False)

  kernel = mgpu.as_gpu_kernel(
      kernel_body,
      grid=(1, 1, 1),
      block=(128, 1, 1),
      in_shape=(),
      out_shape=jax.ShapeDtypeStruct(out_shape, jnp.int32),
      smem_scratch_shape=[],
  )
  # Every element of the output must equal the splatted constant.
  expected = np.full(out_shape, 42, dtype=np.int32)
  np.testing.assert_array_equal(kernel(), expected)
3838+
38153839
def test_warp_tree_reduce(self):
38163840
def kernel(ctx, out, *_):
38173841
del ctx

0 commit comments

Comments (0)