@@ -1042,12 +1042,12 @@ def _vector_reduction_constraint_system(
10421042 return cs .ConstraintSystem (), {in_variable : [in_variable .key ]}, []
10431043
10441044
1045- def _reduction_constraint_and_hint (
1045+ def _reduction_constraints_and_hint (
10461046 larger : cs .Variable ,
10471047 smaller : cs .Variable ,
10481048 larger_shape : tuple [int , ...],
10491049 reduction_dims : tuple [int , ...],
1050- ) -> tuple [cs .Constraint , Hint ]:
1050+ ) -> tuple [list [ cs .Constraint ] , Hint ]:
10511051 reduce_expr = cs .Reduce (larger , reduction_dims )
10521052 # There are always many options for broadcasting a layout, so we can only
10531053 # derive a broadcast hint in the out_variable -> source_variable direction.
@@ -1056,7 +1056,12 @@ def _reduction_constraint_and_hint(
10561056 )
10571057 broadcast_expr = cs .BroadcastInDim (smaller , broadcast_dims , larger_shape )
10581058 broadcast_hint = Hint (variable = larger , expression = broadcast_expr )
1059- return cs .Equals (lhs = smaller , rhs = reduce_expr ), broadcast_hint
1059+ constraints = [
1060+ cs .Equals (lhs = smaller , rhs = reduce_expr ),
1061+ # TODO(allanrenucci): Remove once we support reduction of strided layouts.
1062+ cs .NotOfType (larger , fa .WGStridedFragLayout ),
1063+ ]
1064+ return constraints , broadcast_hint
10601065
10611066
10621067@_add_constraint_system_derivation_rule (vector .MultiDimReductionOp )
@@ -1071,7 +1076,7 @@ def _multi_dim_reduction_constraint_system(
10711076 source_variable = cs .Variable (source )
10721077 out_variable = cs .Variable (out )
10731078
1074- reduction_constraint , broadcast_hint = _reduction_constraint_and_hint (
1079+ reduction_constraints , broadcast_hint = _reduction_constraints_and_hint (
10751080 source_variable ,
10761081 out_variable ,
10771082 tuple (ir .ShapedType (op .source .type ).shape ),
@@ -1081,7 +1086,7 @@ def _multi_dim_reduction_constraint_system(
10811086 # strided layouts from being chosen---since trying to reduce a strided layout
10821087 # may cause us to raise an Exception at the moment.
10831088 return (
1084- cs .ConstraintSystem (constraints = [ reduction_constraint ] ),
1089+ cs .ConstraintSystem (constraints = reduction_constraints ),
10851090 {source_variable : [source ], out_variable : [acc , out ]},
10861091 [broadcast_hint ],
10871092 )
@@ -1100,12 +1105,12 @@ def _broadcast_in_dim_constraint_system(
11001105 i for i in range (len (out_shape )) if i not in op .broadcast_dimensions
11011106 )
11021107
1103- reduction_constraint , broadcast_hint = _reduction_constraint_and_hint (
1108+ reduction_constraints , broadcast_hint = _reduction_constraints_and_hint (
11041109 out_variable , source_variable , out_shape , reduction_dims
11051110 )
11061111
11071112 return (
1108- cs .ConstraintSystem (constraints = [ reduction_constraint ] ),
1113+ cs .ConstraintSystem (constraints = reduction_constraints ),
11091114 {
11101115 source_variable : [source_variable .key ],
11111116 out_variable : [out_variable .key ],
0 commit comments