Skip to content

Commit ab3b9b2

Browse files
allanrenucciGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Exclude strided layouts in reduction rules.
We currently don't support reducing strided layouts. PiperOrigin-RevId: 841886121
1 parent 6cb78e5 commit ab3b9b2

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

jax/experimental/mosaic/gpu/layout_inference.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,12 +1042,12 @@ def _vector_reduction_constraint_system(
10421042
return cs.ConstraintSystem(), {in_variable: [in_variable.key]}, []
10431043

10441044

1045-
def _reduction_constraint_and_hint(
1045+
def _reduction_constraints_and_hint(
10461046
larger: cs.Variable,
10471047
smaller: cs.Variable,
10481048
larger_shape: tuple[int, ...],
10491049
reduction_dims: tuple[int, ...],
1050-
) -> tuple[cs.Constraint, Hint]:
1050+
) -> tuple[list[cs.Constraint], Hint]:
10511051
reduce_expr = cs.Reduce(larger, reduction_dims)
10521052
# There are always many options for broadcasting a layout, so we can only
10531053
# derive a broadcast hint in the out_variable -> source_variable direction.
@@ -1056,7 +1056,12 @@ def _reduction_constraint_and_hint(
10561056
)
10571057
broadcast_expr = cs.BroadcastInDim(smaller, broadcast_dims, larger_shape)
10581058
broadcast_hint = Hint(variable=larger, expression=broadcast_expr)
1059-
return cs.Equals(lhs=smaller, rhs=reduce_expr), broadcast_hint
1059+
constraints = [
1060+
cs.Equals(lhs=smaller, rhs=reduce_expr),
1061+
# TODO(allanrenucci): Remove once we support reduction of strided layouts.
1062+
cs.NotOfType(larger, fa.WGStridedFragLayout),
1063+
]
1064+
return constraints, broadcast_hint
10601065

10611066

10621067
@_add_constraint_system_derivation_rule(vector.MultiDimReductionOp)
@@ -1071,7 +1076,7 @@ def _multi_dim_reduction_constraint_system(
10711076
source_variable = cs.Variable(source)
10721077
out_variable = cs.Variable(out)
10731078

1074-
reduction_constraint, broadcast_hint = _reduction_constraint_and_hint(
1079+
reduction_constraints, broadcast_hint = _reduction_constraints_and_hint(
10751080
source_variable,
10761081
out_variable,
10771082
tuple(ir.ShapedType(op.source.type).shape),
@@ -1081,7 +1086,7 @@ def _multi_dim_reduction_constraint_system(
10811086
# strided layouts from being chosen---since trying to reduce a strided layout
10821087
# may cause us to raise an Exception at the moment.
10831088
return (
1084-
cs.ConstraintSystem(constraints=[reduction_constraint]),
1089+
cs.ConstraintSystem(constraints=reduction_constraints),
10851090
{source_variable: [source], out_variable: [acc, out]},
10861091
[broadcast_hint],
10871092
)
@@ -1100,12 +1105,12 @@ def _broadcast_in_dim_constraint_system(
11001105
i for i in range(len(out_shape)) if i not in op.broadcast_dimensions
11011106
)
11021107

1103-
reduction_constraint, broadcast_hint = _reduction_constraint_and_hint(
1108+
reduction_constraints, broadcast_hint = _reduction_constraints_and_hint(
11041109
out_variable, source_variable, out_shape, reduction_dims
11051110
)
11061111

11071112
return (
1108-
cs.ConstraintSystem(constraints=[reduction_constraint]),
1113+
cs.ConstraintSystem(constraints=reduction_constraints),
11091114
{
11101115
source_variable: [source_variable.key],
11111116
out_variable: [out_variable.key],

tests/mosaic/gpu_layout_inference_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,26 @@ def test_infer_broadcast_in_dim_layout(self, layout, axis, hint_on_input):
327327
self.checkInLayouts(bcast, [in_layout])
328328
self.checkOutLayouts(bcast, [out_layout])
329329

330+
# TODO(allanrenucci): Turn into a positive test. This is currently not
331+
# implemented. The test checks we fail gracefully.
332+
@parameterized.parameters(True, False)
333+
def test_cant_infer_reduced_strided_layout(self, hint_on_input):
334+
with ir.InsertionPoint(self.module.body):
335+
[x] = undefs(ir.VectorType.get((128,), ir.F32Type.get()))
336+
if hint_on_input:
337+
layout = mgpu.WGStridedFragLayout.from_shaped_type(x.type)
338+
x = layout_cast(x, layout)
339+
out_type = ir.VectorType.get((128, 128), ir.F32Type.get())
340+
out = mgpu.dialect.broadcast_in_dim(out_type, x, [0])
341+
if not hint_on_input:
342+
layout = mgpu.WGStridedFragLayout.from_shaped_type(out.type)
343+
layout_cast(out, layout)
344+
345+
with self.assertRaisesRegex(
346+
ValueError, "Failed to infer a possible set of layouts"
347+
):
348+
mgpu.infer_layout(self.module)
349+
330350
@parameterized.parameters(
331351
(1, mgpu.WGMMA_LAYOUT, None, None),
332352
(0, mgpu.WGMMA_LAYOUT, None, None),

0 commit comments

Comments
 (0)