@@ -490,10 +490,9 @@ def __str__(self):
490490
491491def reduce_constraint (
492492 constraint : Constraint , assignments : dict [Variable , Constant ]
493- ) -> Constraint | Tautological | Unsatisfiable :
493+ ) -> Constraint | Unsatisfiable :
494494 """Reduces a constraint."""
495495
496- new_constraint : Constraint
497496 match constraint :
498497 case Equals (lhs = lhs , rhs = rhs ):
499498 lhs_red = reduce_expression (lhs , assignments )
@@ -502,39 +501,34 @@ def reduce_constraint(
502501 rhs_red = reduce_expression (rhs , assignments )
503502 if isinstance (rhs_red , Unsatisfiable ):
504503 return Unsatisfiable ()
505- new_constraint = Equals (lhs_red , rhs_red )
504+ return Equals (lhs_red , rhs_red )
506505 case Relayout (source = source , target = target ):
507506 source_red = reduce_expression (source , assignments )
508507 target_red = reduce_expression (target , assignments )
509508 if isinstance (source_red , Unsatisfiable ) or isinstance (
510509 target_red , Unsatisfiable
511510 ):
512511 return Unsatisfiable ()
513- new_constraint = Relayout (source_red , target_red )
512+ return Relayout (source_red , target_red )
514513 case NotOfType (expr = expr , type = type ):
515514 expr_red = reduce_expression (expr , assignments )
516515 if isinstance (expr_red , Unsatisfiable ):
517516 return Unsatisfiable ()
518- new_constraint = NotOfType (expr_red , type )
517+ return NotOfType (expr_red , type )
519518 case IsTransferable (source = source , target = target , shape = shape ):
520519 source_red = reduce_expression (source , assignments )
521520 target_red = reduce_expression (target , assignments )
522521 if isinstance (source_red , Unsatisfiable ) or isinstance (target_red , Unsatisfiable ):
523522 return Unsatisfiable ()
524- new_constraint = IsTransferable (source_red , target_red , shape )
523+ return IsTransferable (source_red , target_red , shape )
525524 case Divides (expr = expr , tiling_multiple = tiling_multiple ):
526525 expr_red = reduce_expression (expr , assignments )
527526 if isinstance (expr_red , Unsatisfiable ):
528527 return Unsatisfiable ()
529- new_constraint = Divides (expr_red , tiling_multiple )
528+ return Divides (expr_red , tiling_multiple )
530529 case _ as never :
531530 assert_never (never )
532531
533- constraint_holds = new_constraint .holds ()
534- if constraint_holds is None :
535- return new_constraint
536- return Tautological () if constraint_holds else Unsatisfiable ()
537-
538532
539533@dataclasses .dataclass
540534class ConstraintSystem :
@@ -620,10 +614,6 @@ def __and__(self, other: ConstraintSystem | Unsatisfiable) -> Unsatisfiable:
620614 return self
621615
622616
623- class Tautological :
624- ...
625-
626-
627617def non_splat_variables (
628618 constraints : Sequence [Constraint ],
629619) -> set [Variable ]:
@@ -832,11 +822,16 @@ def try_assign(var: Variable, cst: Constant) -> bool:
832822 if not try_assign (var , cst ):
833823 return Unsatisfiable ()
834824 changed = True
835- case Tautological ():
836- changed = True
837825 case _ as new_constraint :
838- changed |= new_constraint != constraint
839- constraints .append (new_constraint )
826+ assert isinstance (new_constraint , Constraint ) # make pytype happy
827+ match new_constraint .holds ():
828+ case None :
829+ constraints .append (new_constraint )
830+ changed |= new_constraint != constraint
831+ case False :
832+ return Unsatisfiable ()
833+ case True :
834+ changed = True
840835
841836 new_constraints = merge_divides_constraints (constraints )
842837 changed |= len (new_constraints ) != len (constraints )
0 commit comments