Skip to content

Commit 15ba1b7

Browse files
allanrenucciGoogle-ML-Automation
authored andcommitted
[Mosaic GPU][NFC] Remove Tautological from constraint reduction.
PiperOrigin-RevId: 842652535
1 parent 3a28e93 commit 15ba1b7

File tree

1 file changed

+15
-20
lines changed

1 file changed

+15
-20
lines changed

jax/experimental/mosaic/gpu/constraints.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -490,10 +490,9 @@ def __str__(self):
490490

491491
def reduce_constraint(
492492
constraint: Constraint, assignments: dict[Variable, Constant]
493-
) -> Constraint | Tautological | Unsatisfiable:
493+
) -> Constraint | Unsatisfiable:
494494
"""Reduces a constraint."""
495495

496-
new_constraint: Constraint
497496
match constraint:
498497
case Equals(lhs=lhs, rhs=rhs):
499498
lhs_red = reduce_expression(lhs, assignments)
@@ -502,39 +501,34 @@ def reduce_constraint(
502501
rhs_red = reduce_expression(rhs, assignments)
503502
if isinstance(rhs_red, Unsatisfiable):
504503
return Unsatisfiable()
505-
new_constraint = Equals(lhs_red, rhs_red)
504+
return Equals(lhs_red, rhs_red)
506505
case Relayout(source=source, target=target):
507506
source_red = reduce_expression(source, assignments)
508507
target_red = reduce_expression(target, assignments)
509508
if isinstance(source_red, Unsatisfiable) or isinstance(
510509
target_red, Unsatisfiable
511510
):
512511
return Unsatisfiable()
513-
new_constraint = Relayout(source_red, target_red)
512+
return Relayout(source_red, target_red)
514513
case NotOfType(expr=expr, type=type):
515514
expr_red = reduce_expression(expr, assignments)
516515
if isinstance(expr_red, Unsatisfiable):
517516
return Unsatisfiable()
518-
new_constraint = NotOfType(expr_red, type)
517+
return NotOfType(expr_red, type)
519518
case IsTransferable(source=source, target=target, shape=shape):
520519
source_red = reduce_expression(source, assignments)
521520
target_red = reduce_expression(target, assignments)
522521
if isinstance(source_red, Unsatisfiable) or isinstance(target_red, Unsatisfiable):
523522
return Unsatisfiable()
524-
new_constraint = IsTransferable(source_red, target_red, shape)
523+
return IsTransferable(source_red, target_red, shape)
525524
case Divides(expr=expr, tiling_multiple=tiling_multiple):
526525
expr_red = reduce_expression(expr, assignments)
527526
if isinstance(expr_red, Unsatisfiable):
528527
return Unsatisfiable()
529-
new_constraint = Divides(expr_red, tiling_multiple)
528+
return Divides(expr_red, tiling_multiple)
530529
case _ as never:
531530
assert_never(never)
532531

533-
constraint_holds = new_constraint.holds()
534-
if constraint_holds is None:
535-
return new_constraint
536-
return Tautological() if constraint_holds else Unsatisfiable()
537-
538532

539533
@dataclasses.dataclass
540534
class ConstraintSystem:
@@ -620,10 +614,6 @@ def __and__(self, other: ConstraintSystem | Unsatisfiable) -> Unsatisfiable:
620614
return self
621615

622616

623-
class Tautological:
624-
...
625-
626-
627617
def non_splat_variables(
628618
constraints: Sequence[Constraint],
629619
) -> set[Variable]:
@@ -832,11 +822,16 @@ def try_assign(var: Variable, cst: Constant) -> bool:
832822
if not try_assign(var, cst):
833823
return Unsatisfiable()
834824
changed = True
835-
case Tautological():
836-
changed = True
837825
case _ as new_constraint:
838-
changed |= new_constraint != constraint
839-
constraints.append(new_constraint)
826+
assert isinstance(new_constraint, Constraint) # make pytype happy
827+
match new_constraint.holds():
828+
case None:
829+
constraints.append(new_constraint)
830+
changed |= new_constraint != constraint
831+
case False:
832+
return Unsatisfiable()
833+
case True:
834+
changed = True
840835

841836
new_constraints = merge_divides_constraints(constraints)
842837
changed |= len(new_constraints) != len(constraints)

0 commit comments

Comments
 (0)