Skip to content

Commit 3bec82d

Browse files
bchetiouiGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Remove {Least,Most}ReplicatedExpression constructors.
After getting rid of hints, these constructors are no longer necessary. PiperOrigin-RevId: 842112040
1 parent 9be13f9 commit 3bec82d

File tree

4 files changed

+3
-314
lines changed

4 files changed

+3
-314
lines changed

jax/experimental/mosaic/gpu/constraints.py

Lines changed: 1 addition & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from collections.abc import Sequence
2525
import dataclasses
2626
import math
27-
from typing import Any, Callable, assert_never, final
27+
from typing import Any, assert_never, final
2828

2929
from . import fragmented_array as fa
3030
from . import launch_context as lc
@@ -86,22 +86,6 @@ def __str__(self):
8686
return f"C({self.value})"
8787

8888

89-
@dataclasses.dataclass(frozen=True)
90-
class LeastReplicated:
91-
expressions: tuple[Expression, ...]
92-
93-
def __post_init__(self):
94-
assert len(self.expressions) >= 1
95-
96-
97-
@dataclasses.dataclass(frozen=True)
98-
class MostReplicated:
99-
expressions: tuple[Expression, ...]
100-
101-
def __post_init__(self):
102-
assert len(self.expressions) >= 1
103-
104-
10589
@dataclasses.dataclass(frozen=True)
10690
class Reduce:
10791
expression: Expression
@@ -136,71 +120,13 @@ def __str__(self):
136120
Expression = (
137121
Variable
138122
| Constant
139-
| LeastReplicated
140-
| MostReplicated
141123
| Reduce
142124
| BroadcastInDim
143125
| Reshape
144126
| Transpose
145127
)
146128

147129

148-
def reduce_replicated_expression(
149-
input_expr: LeastReplicated | MostReplicated,
150-
assignments: dict[Variable, Constant],
151-
reducer: Callable[[fa.FragmentedLayout, fa.FragmentedLayout], fa.FragmentedLayout | None]
152-
) -> Expression | Unsatisfiable:
153-
assert input_expr.expressions
154-
155-
new_expressions: list[Expression] = []
156-
# Use a set to eliminate duplicates, but preserve the order.
157-
seen: set[Expression] = set()
158-
for expr in input_expr.expressions:
159-
reduced_expr = reduce_expression(expr, assignments)
160-
if isinstance(reduced_expr, Unsatisfiable):
161-
return Unsatisfiable()
162-
if reduced_expr in seen:
163-
continue
164-
new_expressions.append(reduced_expr)
165-
seen.add(reduced_expr)
166-
167-
if len(new_expressions) == 1:
168-
return new_expressions[0]
169-
170-
consts = []
171-
unknowns = []
172-
for e in new_expressions:
173-
if not isinstance(e, Constant):
174-
unknowns.append(e)
175-
continue
176-
if not isinstance(e, RegisterLayout):
177-
raise ValueError(
178-
f"Reduction of non-register layout constant is not supported: {e}"
179-
)
180-
consts.append(e)
181-
182-
if consts:
183-
const_red, *consts = consts
184-
red = const_red
185-
for cst in consts:
186-
red_value = reducer(red.value, cst.value)
187-
if red_value is None:
188-
# The layouts are not compatible up to replication, this expression
189-
# cannot be simplified.
190-
return Unsatisfiable()
191-
red = RegisterLayout(red_value)
192-
else:
193-
red = None
194-
195-
constructor = type(input_expr)
196-
if red is not None:
197-
if unknowns:
198-
return constructor((red, *unknowns))
199-
return red
200-
201-
return constructor(tuple(unknowns))
202-
203-
204130
def reduce_broadcast_expression(
205131
broadcast: BroadcastInDim, assignments: dict[Variable, Constant]
206132
) -> Expression | Unsatisfiable:
@@ -314,14 +240,6 @@ def reduce_expression(
314240
return expr
315241
case Variable():
316242
return assignments.get(expr, expr)
317-
case MostReplicated():
318-
return reduce_replicated_expression(
319-
expr, assignments, layouts_lib.join_layouts
320-
)
321-
case LeastReplicated():
322-
return reduce_replicated_expression(
323-
expr, assignments, layouts_lib.meet_layouts
324-
)
325243
case Reduce(expression=expr, axes=axes):
326244
reduced_expr = reduce_expression(expr, assignments)
327245
match reduced_expr:
@@ -640,12 +558,6 @@ def extract_variables(expr: Expression) -> None:
640558
free_variables.append(expr)
641559
case Constant():
642560
...
643-
case MostReplicated(expressions=expressions):
644-
for e in expressions:
645-
extract_variables(e)
646-
case LeastReplicated(expressions=expressions):
647-
for e in expressions:
648-
extract_variables(e)
649561
case Reduce(expression=e):
650562
extract_variables(e)
651563
case BroadcastInDim(expression=e):

jax/experimental/mosaic/gpu/layouts.py

Lines changed: 0 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""Layout utilities."""
1616

1717
import re
18-
from typing import assert_never
1918

2019
from jax._src.lib import mosaic_gpu_dialect as mgpu
2120
from jax._src.lib.mlir import ir
@@ -224,118 +223,6 @@ def splat_is_compatible_with_tiled(
224223
return all(d1 % d2 == 0 for d1, d2 in zip(s1, s2))
225224

226225

227-
def meet_layouts(
228-
layout1: fa.FragmentedLayout, layout2: fa.FragmentedLayout
229-
) -> fa.FragmentedLayout | None:
230-
"""Returns the "meet" of two layouts that are compatible up to replication.
231-
232-
The "meet" of the two layouts is the most replicated layout that is still
233-
less replicated than the arguments.
234-
235-
This is the dual of `join_layouts`.
236-
237-
Returns:
238-
The "meet" of the two layouts if both layouts are compatible up to
239-
replication.
240-
241-
Raises:
242-
ValueError: if the two layouts are not compatible up to replication.
243-
"""
244-
if layout1 == layout2:
245-
return layout1
246-
247-
match (layout1, layout2):
248-
case (fa.WGSplatFragLayout(), _):
249-
if isinstance(layout2, fa.TiledLayout):
250-
if splat_is_compatible_with_tiled(layout1, layout2):
251-
return layout2
252-
elif layout1.shape == layout2.shape:
253-
return layout2
254-
case (_, fa.WGSplatFragLayout()):
255-
if isinstance(layout1, fa.TiledLayout):
256-
if splat_is_compatible_with_tiled(layout2, layout1):
257-
return layout1
258-
elif layout1.shape == layout2.shape:
259-
return layout1
260-
case (fa.TiledLayout(), fa.TiledLayout()):
261-
# TODO(bchetioui): handle `TiledLayout` replication.
262-
raise NotImplementedError("TiledLayout replication not supported yet")
263-
264-
# Layouts are not compatible up to replication.
265-
return None
266-
267-
# NOTE: We say that two layouts are compatible up to replication if the two
268-
# layouts satisfy at least one of the following conditions together:
269-
#
270-
# - The two layouts are equal;
271-
# - One of the layouts is a `WGSplatFragLayout`, and
272-
# * The other layout is a `WGStridedFragLayout` with the same shape;
273-
# * The other layout is a `TiledLayout` that can be used to tile the shape
274-
# embedded in the `WGSplatFragLayout`.
275-
#
276-
# If any of these conditions hold, then we are always able to substitute one
277-
# layout with the other without having to reorder any data in the underlying
278-
# array---i.e. a relayout is free.
279-
#
280-
# Note that there are other combinations of layouts for which relayout is free,
281-
# but we voluntarily narrowed down our definition to span a small, useful
282-
# subset.
283-
284-
def join_layouts(
285-
layout1: fa.FragmentedLayout, layout2: fa.FragmentedLayout
286-
) -> fa.FragmentedLayout | None:
287-
"""Returns the "join" of two layouts that are compatible up to replication.
288-
289-
The "join" of the two layouts is the least replicated layout that is still
290-
more replicated than the arguments.
291-
292-
This is the dual of `meet_layouts`.
293-
294-
Returns:
295-
The "join" of the two layouts if both layouts are compatible up to
296-
replication.
297-
298-
Raises:
299-
ValueError: if the two layouts are not compatible up to replication.
300-
"""
301-
if layout1 == layout2:
302-
return layout1
303-
304-
match (layout1, layout2):
305-
case (fa.WGSplatFragLayout(), _):
306-
if isinstance(layout2, fa.TiledLayout):
307-
if splat_is_compatible_with_tiled(layout1, layout2):
308-
return layout1
309-
elif layout1.shape == layout2.shape:
310-
return layout1
311-
case (_, fa.WGSplatFragLayout()):
312-
if isinstance(layout1, fa.TiledLayout):
313-
if splat_is_compatible_with_tiled(layout2, layout1):
314-
return layout2
315-
elif layout1.shape == layout2.shape:
316-
return layout2
317-
case (fa.TiledLayout(), fa.TiledLayout()):
318-
# TODO(bchetioui): handle `TiledLayout` replication.
319-
raise NotImplementedError("TiledLayout replication not supported yet")
320-
321-
# Layouts are not compatible up to replication.
322-
return None
323-
324-
325-
def has_any_replication(layout: fa.FragmentedLayout) -> bool:
326-
match layout:
327-
case fa.WGSplatFragLayout():
328-
return True
329-
case fa.WGStridedFragLayout():
330-
return False
331-
case fa.TiledLayout():
332-
is_warp_replicated = any(isinstance(d, fa.Replicated) for d in layout.warp_dims)
333-
is_lane_replicated = any(isinstance(d, fa.Replicated) for d in layout.lane_dims)
334-
return is_warp_replicated or is_lane_replicated
335-
case _ as unreachable:
336-
return assert_never(unreachable) # pytype: disable=wrong-arg-types
337-
338-
339226
_tile_transform_attr_pattern = re.compile(
340227
r"^#mosaic_gpu.tile<[^>]+>$"
341228
)

tests/mosaic/gpu_constraints_test.py

Lines changed: 1 addition & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,9 @@ def test_constraint_system_unknowns_are_all_the_variables_without_assignment(
127127
):
128128
v0, v1, v2, v3 = V(0), V(1), V(2), V(3)
129129
layout = RL(mgpu.WGSplatFragLayout((1, 1)))
130-
least_replicated = cs.LeastReplicated((v2, v3))
131-
most_replicated = cs.MostReplicated((least_replicated,))
132130
system = cs.ConstraintSystem(
133131
assignments={v0: layout},
134-
constraints=[Eq(v1, most_replicated)],
132+
constraints=[Eq(v1, v2), cs.Relayout(v2, v3)],
135133
)
136134
self.assertSequenceEqual(system.unknowns(), [v1, v2, v3])
137135

@@ -164,112 +162,6 @@ def test_intersection_of_compatible_systems_is_union_of_fields(self):
164162
self.assertSequenceEqual(system1.unknowns(), [v1])
165163
self.assertSequenceEqual(system_intersection.unknowns(), [v0, v1])
166164

167-
def test_reduce_extracts_most_replicated_expression_correctly(self):
168-
v0 = V(0)
169-
shape = (1, 128)
170-
layout0 = RL(mgpu.WGSplatFragLayout(shape))
171-
layout1 = RL(mgpu.WGStridedFragLayout(shape, vec_size=1))
172-
with self.subTest("most-replicated-expression-exists"):
173-
system = cs.ConstraintSystem(
174-
constraints=[Eq(v0, cs.MostReplicated((layout0, layout1)))],
175-
)
176-
self.assertEqual(
177-
cs.reduce(system),
178-
cs.ConstraintSystem(assignments={v0: layout0}),
179-
)
180-
181-
with self.subTest("most-replicated-expression-is-unique-expression"):
182-
system = cs.ConstraintSystem(
183-
constraints=[Eq(v0, cs.MostReplicated((layout0,)))],
184-
)
185-
self.assertEqual(
186-
cs.reduce(system),
187-
cs.ConstraintSystem(assignments={v0: layout0}),
188-
)
189-
190-
with self.subTest("most-replicated-expression-does-not-exist"):
191-
system = cs.ConstraintSystem(
192-
constraints=[Eq(v0, cs.MostReplicated((layout1, v0)))],
193-
)
194-
self.assertEqual(cs.reduce(system), system)
195-
196-
def test_reduce_extracts_least_replicated_expression_correctly(self):
197-
v0 = V(0)
198-
shape = (1, 128)
199-
layout0 = RL(mgpu.WGSplatFragLayout(shape))
200-
layout1 = RL(mgpu.WGStridedFragLayout(shape, vec_size=1))
201-
with self.subTest("least-replicated-expression-exists"):
202-
system = cs.ConstraintSystem(
203-
constraints=[Eq(v0, cs.LeastReplicated([layout0, layout1]))],
204-
)
205-
self.assertEqual(
206-
cs.reduce(system),
207-
cs.ConstraintSystem(assignments={v0: layout1}),
208-
)
209-
210-
with self.subTest("least-replicated-expression-is-unique-expression"):
211-
system = cs.ConstraintSystem(
212-
constraints=[Eq(v0, cs.LeastReplicated((layout0,)))],
213-
)
214-
self.assertEqual(
215-
cs.reduce(system),
216-
cs.ConstraintSystem(assignments={v0: layout0}),
217-
)
218-
219-
with self.subTest("least-replicated-expression-does-not-exist"):
220-
system = cs.ConstraintSystem(
221-
constraints=[Eq(v0, cs.LeastReplicated((layout0, v0)))],
222-
)
223-
self.assertEqual(cs.reduce(system), system)
224-
225-
def test_reduce_most_replicated_expression_reduces_compatible_layouts(self):
226-
splat_layout = RL(mgpu.WGSplatFragLayout((128, 64)))
227-
tiled_layout = RL(mgpu.WGMMA_LAYOUT)
228-
self.assertEqual(
229-
cs.reduce_expression(
230-
cs.MostReplicated((splat_layout, tiled_layout)),
231-
{},
232-
),
233-
splat_layout,
234-
)
235-
236-
def test_reduce_most_replicated_expression_is_unsatisfiable_for_incompatible_layouts(
237-
self,
238-
):
239-
splat_layout = RL(mgpu.WGSplatFragLayout((1, 2)))
240-
tiled_layout = RL(mgpu.WGMMA_LAYOUT)
241-
self.assertIsInstance(
242-
cs.reduce_expression(
243-
cs.MostReplicated((splat_layout, tiled_layout)),
244-
{},
245-
),
246-
cs.Unsatisfiable,
247-
)
248-
249-
def test_reduce_least_replicated_expression_reduces_compatible_layouts(self):
250-
splat_layout = RL(mgpu.WGSplatFragLayout((128, 64)))
251-
tiled_layout = RL(mgpu.WGMMA_LAYOUT)
252-
self.assertEqual(
253-
cs.reduce_expression(
254-
cs.LeastReplicated((splat_layout, tiled_layout)),
255-
{},
256-
),
257-
tiled_layout,
258-
)
259-
260-
def test_reduce_least_replicated_expression_is_unsatisfiable_for_incompatible_layouts(
261-
self,
262-
):
263-
splat_layout = RL(mgpu.WGSplatFragLayout((1, 2)))
264-
tiled_layout = RL(mgpu.WGMMA_LAYOUT)
265-
self.assertIsInstance(
266-
cs.reduce_expression(
267-
cs.LeastReplicated((splat_layout, tiled_layout)),
268-
{},
269-
),
270-
cs.Unsatisfiable,
271-
)
272-
273165
@parameterized.named_parameters(
274166
("reduce_to_row_layout", (1,), mgpu.WGMMA_ROW_LAYOUT),
275167
("reduce_to_col_layout", (0,), mgpu.WGMMA_COL_LAYOUT),

0 commit comments

Comments
 (0)