Skip to content

Commit 2fbe731

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic TPU] Only canonicalize the dtype of transposes if they use the XLU
We should never change the dtypes of transposes that only deal with untiled dimensions, for example. PiperOrigin-RevId: 815013892
1 parent 3e5be14 commit 2fbe731

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1799,9 +1799,11 @@ FailureOr<Value> canonicalize_transpose(const CanonicalizeContext &ctx,
17991799
return res;
18001800
};
18011801

1802+
bool uses_xlu = !op.getPermutation().empty() &&
1803+
op.getPermutation().back() != op.getPermutation().size() - 1;
18021804
// TODO(b/448848595): Enable 8-bit transposes on generation 7.
18031805
if (element_type.getIntOrFloatBitWidth() == 8 && ctx.compatibility_mode &&
1804-
ctx.hardware_generation > 3) {
1806+
ctx.hardware_generation > 3 && uses_xlu) {
18051807
VectorType input_vty_int = VectorType::get(
18061808
input_vty.getShape(),
18071809
builder.getIntegerType(input_vty.getElementTypeBitWidth()));

0 commit comments

Comments
 (0)