Skip to content

Commit 8b95853

Browse files
tlongeri authored and
jax authors committed
[Mosaic] Add relayout for (1, 128 * packing) -> (packing, 128).
PiperOrigin-RevId: 637951690
1 parent 8fec5d6 commit 8b95853

File tree

2 files changed

+109
-8
lines changed

2 files changed

+109
-8
lines changed

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,18 @@ def TPU_ContractPrecisionEnum
9797
let assemblyFormat = "`<` $value `>`";
9898
}
9999

100+
def TPU_PackFormat : I32EnumAttr<"PackFormat", "Pack format", [
101+
I32EnumAttrCase<"kCompressed", 0, "compressed">,
102+
I32EnumAttrCase<"kInterleaved", 1, "interleaved">
103+
]> {
104+
let genSpecializedAttr = 0;
105+
let cppNamespace = "::mlir::tpu";
106+
}
107+
108+
def TPU_PackFormatEnum : EnumAttr<TPU_Dialect, TPU_PackFormat, "pack_format"> {
109+
let assemblyFormat = "`<` $value `>`";
110+
}
111+
100112
def TPU_TiledCase : I32EnumAttrCase<"tiled", 0>;
101113
def TPU_LaneCase : I32EnumAttrCase<"lanes", 1>;
102114
def TPU_SublaneCase : I32EnumAttrCase<"sublanes", 2>;
@@ -278,7 +290,10 @@ def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
278290

279291
// Integer packs are always signed at the moment.
280292
def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure]> {
281-
let arguments = (ins Variadic<AnyVector>:$sources);
293+
let arguments = (ins
294+
Variadic<AnyVector>:$sources,
295+
TPU_PackFormatEnum:$pack_format
296+
);
282297
let results = (outs AnyVector:$output);
283298
let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }];
284299
}

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,8 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
827827
++idxs_local.back();
828828
}
829829
}
830-
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
830+
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
831+
tpu::PackFormat::kCompressed);
831832
});
832833
} else if (layout_out.hasNativeTiling(ctx.target_shape)) {
833834
int packing = layout_out.packing();
@@ -848,7 +849,8 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
848849
parts.push_back(parts.back());
849850
}
850851
}
851-
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
852+
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
853+
tpu::PackFormat::kCompressed);
852854
parts.clear();
853855
});
854856
} else {
@@ -4576,10 +4578,8 @@ FailureOr<TypedValue<VectorType>> relayout(
45764578
} else if ( // TODO(b/265133506): Generalize retiling.
45774579
// (8,128) -> (8 * packing,128) tiling change for packed type.
45784580
src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
4579-
dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
4580-
vty.getElementTypeBitWidth() < 32 &&
4581-
32 % vty.getElementTypeBitWidth() == 0 &&
4582-
src.offsets() == dst.offsets() &&
4581+
dst.implicit_dim() == VectorLayout::ImplicitDim::kNone && bitwidth < 32 &&
4582+
32 % bitwidth == 0 && src.offsets() == dst.offsets() &&
45834583
src.tiling() == std::array<int64_t, 2>{8, 128} &&
45844584
dst.tiling() == std::array<int64_t, 2>{8 * dst.packing(), 128}) {
45854585
const VectorLayout new_src(src.bitwidth(), src.offsets(), dst.tiling());
@@ -4606,7 +4606,93 @@ FailureOr<TypedValue<VectorType>> relayout(
46064606
}
46074607
}
46084608
*tile = builder.create<tpu::PackSubelementsOp>(
4609-
v.getLoc(), src_tiles.begin()->getType(), parts);
4609+
v.getLoc(), src_tiles.begin()->getType(), parts,
4610+
tpu::PackFormat::kCompressed);
4611+
});
4612+
src = new_src;
4613+
src_tiles = std::move(src_tiles_retiled);
4614+
} else if ( // Handle retiling from (1, 128 * packing) to (packing, 128) for
4615+
// packed data.
4616+
// We do compressed unpacking followed by interleaved packing.
4617+
// TODO(tlongeri): This can be used as a first step before using
4618+
// a generalized retiling where we only move sublanes around
4619+
// (without packing/unpacking).
4620+
// TODO(tlongeri): Interleaved unpacking followed by interleaved
4621+
// packing (but with different pairings) might also be
4622+
// interesting if the next step is a retile, since we can also
4623+
// match corresponding elements without shifting. It's just that
4624+
// the tiles are not adjacent (no contiguous vreg slice).
4625+
src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
4626+
dst.implicit_dim() == VectorLayout::ImplicitDim::kNone && bitwidth < 32 &&
4627+
32 % bitwidth == 0 && src.offsets() == dst.offsets() &&
4628+
src.tiling() == std::array<int64_t, 2>{1, 128 * packing} &&
4629+
dst.tiling() == std::array<int64_t, 2>{packing, 128}) {
4630+
// To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
4631+
// 4 sublanes and 2 lanes (this is convenient to keep the example small
4632+
// yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
4633+
//
4634+
// The vreg slice is 1 x 16, that is, the vreg contains the data for a
4635+
// 1 x 16 window of the logical shape.
4636+
//
4637+
// [a b c d e f g h i j k l m n o p] -> vreg 1
4638+
// [A B C D E F G H I J K L M N O P] -> vreg 2
4639+
//
4640+
// Note: we support multiple vregs per row of the logical shape, but we use
4641+
// one here just to keep the example small.
4642+
//
4643+
// When we do a compressed unpack, the resulting vregs effectively have a
4644+
// tiling of (1, 2) and cover a vreg slice of 1 x 8 logical elements.
4645+
//
4646+
// [a b c d e f g h] -> vreg 1, part 1 [i j k l m n o p] -> vreg 1, part 2
4647+
// [A B C D E F G H] -> vreg 2, part 1 [I J K L M N O P] -> vreg 2, part 2
4648+
//
4649+
// It is clear that if we combine vreg 1, part 1 and vreg 2, part 1 we get data
4650+
// that covers a 2 x 8 vreg slice. Note, however, that we will have to mind
4651+
// the internal ordering of the vreg.
4652+
//
4653+
// [a b c d e f g h [i j k l m n o p
4654+
// A B C D E F G H] -> new vreg 1 I J K L M N O P] -> new vreg 2
4655+
//
4656+
// To see if we can get the right internal ordering that we need for (2, 2)
4657+
// tiling, let's break new vreg 1 into (1, 2) rows, which correspond to
4658+
// sublanes when unpacked and half-sublanes when packed.
4659+
//
4660+
// [(a b) (c d) (e f) (g h)
4661+
// (A B) (C D) (E F) (G H)]
4662+
//
4663+
// The sublane order for the vreg parts is [(a b) (c d) ...] for vreg 1,
4664+
// part 1 and [(A B) (C D) ...] for vreg 2, part 1.
4665+
//
4666+
// The desired half-sublane order, for packed (2, 2) tiling, is
4667+
// [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
4668+
// moving to the next one. This is exactly an interleaving of the sublanes
4669+
// of the vreg parts.
4670+
const VectorLayout new_src(src.bitwidth(), src.offsets(),
4671+
std::array<int64_t, 2>{packing, 128});
4672+
xla::Array<Value> src_tiles_retiled(
4673+
new_src.tileArrayShape(vty.getShape(), target_shape));
4674+
const VectorType vreg_x32 =
4675+
vty.getElementType().isSignlessInteger()
4676+
? VectorType::get(target_shape, builder.getI32Type())
4677+
: VectorType::get(target_shape, builder.getF32Type());
4678+
src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
4679+
SmallVector<Value> parts;
4680+
parts.reserve(packing);
4681+
SmallVector<int64_t> src_idx(toArrayRef(idx));
4682+
*(src_idx.end() - 2) *= packing;
4683+
const int64_t vreg_part = *(src_idx.end() - 1) % packing;
4684+
*(src_idx.end() - 1) /= packing;
4685+
for (int i = 0; i < packing; ++i) {
4686+
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
4687+
v.getLoc(), vreg_x32, src_tiles(src_idx), vreg_part));
4688+
if (*(src_idx.end() - 2) < *(src_tiles.dimensions().end() - 2) - 1) {
4689+
++*(src_idx.end() - 2);
4690+
} // Clamp at the last valid source row: the rest is padding, so just pick any of the input parts (but not
4691+
// an arbitrary vreg so we don't add an extra dependency). Advancing past the last row would index src_tiles out of bounds.
4692+
}
4693+
*tile = builder.create<tpu::PackSubelementsOp>(
4694+
v.getLoc(), src_tiles.begin()->getType(), parts,
4695+
tpu::PackFormat::kInterleaved);
46104696
});
46114697
src = new_src;
46124698
src_tiles = std::move(src_tiles_retiled);

0 commit comments

Comments
 (0)