@@ -827,7 +827,8 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
827827 ++idxs_local.back ();
828828 }
829829 }
830- *v = builder.create <PackSubelementsOp>(res_vreg_ty, parts);
830+ *v = builder.create <PackSubelementsOp>(res_vreg_ty, parts,
831+ tpu::PackFormat::kCompressed );
831832 });
832833 } else if (layout_out.hasNativeTiling (ctx.target_shape )) {
833834 int packing = layout_out.packing ();
@@ -848,7 +849,8 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
848849 parts.push_back (parts.back ());
849850 }
850851 }
851- *v = builder.create <PackSubelementsOp>(res_vreg_ty, parts);
852+ *v = builder.create <PackSubelementsOp>(res_vreg_ty, parts,
853+ tpu::PackFormat::kCompressed );
852854 parts.clear ();
853855 });
854856 } else {
@@ -4576,10 +4578,8 @@ FailureOr<TypedValue<VectorType>> relayout(
45764578 } else if ( // TODO(b/265133506): Generalize retiling.
45774579 // (8,128) -> (8 * packing,128) tiling change for packed type.
45784580 src.implicit_dim () == VectorLayout::ImplicitDim::kNone &&
4579- dst.implicit_dim () == VectorLayout::ImplicitDim::kNone &&
4580- vty.getElementTypeBitWidth () < 32 &&
4581- 32 % vty.getElementTypeBitWidth () == 0 &&
4582- src.offsets () == dst.offsets () &&
4581+ dst.implicit_dim () == VectorLayout::ImplicitDim::kNone && bitwidth < 32 &&
4582+ 32 % bitwidth == 0 && src.offsets () == dst.offsets () &&
45834583 src.tiling () == std::array<int64_t , 2 >{8 , 128 } &&
45844584 dst.tiling () == std::array<int64_t , 2 >{8 * dst.packing (), 128 }) {
45854585 const VectorLayout new_src (src.bitwidth (), src.offsets (), dst.tiling ());
@@ -4606,7 +4606,93 @@ FailureOr<TypedValue<VectorType>> relayout(
46064606 }
46074607 }
46084608 *tile = builder.create <tpu::PackSubelementsOp>(
4609- v.getLoc (), src_tiles.begin ()->getType (), parts);
4609+ v.getLoc (), src_tiles.begin ()->getType (), parts,
4610+ tpu::PackFormat::kCompressed );
4611+ });
4612+ src = new_src;
4613+ src_tiles = std::move (src_tiles_retiled);
4614+ } else if ( // Handle retiling from (1, 128 * packing) to (packing, 128) for
4615+ // packed data.
4616+ // We do compressed unpacking followed by interleaved packing.
4617+ // TODO(tlongeri): This can be used as a first step before using
4618+ // a generalized retiling where we only move sublanes around
4619+ // (without packing/unpacking).
4620+ // TODO(tlongeri): Interleaved unpacking followed by interleaved
4621+ // packing (but with different pairings) might also be
4622+ // interesting if the next step is a retile, since we can also
4623+ // match corresponding elements without shifting. It's just that
4624+ // the tiles are not adjacent (no contiguous vreg slice).
4625+ src.implicit_dim () == VectorLayout::ImplicitDim::kNone &&
4626+ dst.implicit_dim () == VectorLayout::ImplicitDim::kNone && bitwidth < 32 &&
4627+ 32 % bitwidth == 0 && src.offsets () == dst.offsets () &&
4628+ src.tiling () == std::array<int64_t , 2 >{1 , 128 * packing} &&
4629+ dst.tiling () == std::array<int64_t , 2 >{packing, 128 }) {
4630+ // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
4631+ // 4 sublanes and 2 lanes (this is convenient for to keep the example small
4632+ // yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
4633+ //
4634+ // The vreg slice is 1 x 16, that is, the vreg contains the data for a
4635+ // 1 x 16 window of the logical shape.
4636+ //
4637+ // [a b c d e f g h i j k l m n o p] -> vreg 1
4638+ // [A B C D E F G H I J K L M N O P] -> vreg 2
4639+ //
4640+ // Note: we support multiple vregs per row of the logical shape, but we use
4641+ // one here just to keep the example small.
4642+ //
4643+ // When we do a compressed unpack, the resulting vregs effectively have a
4644+ // tiling of (1, 2) and cover a vreg slice of 1 x 8 logical elements.
4645+ //
4646+ // [a b c d e f g h] -> vreg 1, part 1 [i j k l m n o p] -> vreg 1, part 2
4647+ // [A B C D E F G H] -> vreg 2, part 1 [I J K L M N O P] -> vreg 2, part 2
4648+ //
4649+ // It is clear that if combine vreg 1, part 1 and vreg 2, part 1 we get data
4650+ // that covers a 2 x 8 vreg slice. Note, however, that we will have to mind
4651+ // the internal ordering of the vreg.
4652+ //
4653+ // [a b c d e f g h [i j k l m n o p
4654+ // A B C D E F G H] -> new vreg 1 I J K L M N O P] -> new vreg 2
4655+ //
4656+ // To see if we can get the right internal ordering that we need for (2, 2)
4657+ // tiling, let's break new vreg 1 into (1, 2) rows, which correspond to
4658+ // sublanes when unpacked and half-sublanes when packed.
4659+ //
4660+ // [(a b) (c d) (e f) (g h)
4661+ // (A B) (C D) (E F) (G H)]
4662+ //
4663+ // The sublane order for the vreg parts is [(a b) (c d) ...] for vreg 1,
4664+ // part 1 and [(A B) (C D) ...] for vreg 2, part 1.
4665+ //
4666+ // The desired half-sublane order, for packed (2, 2) tiling, is
4667+ // [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
4668+ // moving to the next one. This is exactly an interleaving of the sublanes
4669+ // of the vreg parts.
4670+ const VectorLayout new_src (src.bitwidth (), src.offsets (),
4671+ std::array<int64_t , 2 >{packing, 128 });
4672+ xla::Array<Value> src_tiles_retiled (
4673+ new_src.tileArrayShape (vty.getShape (), target_shape));
4674+ const VectorType vreg_x32 =
4675+ vty.getElementType ().isSignlessInteger ()
4676+ ? VectorType::get (target_shape, builder.getI32Type ())
4677+ : VectorType::get (target_shape, builder.getF32Type ());
4678+ src_tiles_retiled.Each ([&](absl::Span<const int64_t > idx, Value *tile) {
4679+ SmallVector<Value> parts;
4680+ parts.reserve (packing);
4681+ SmallVector<int64_t > src_idx (toArrayRef (idx));
4682+ *(src_idx.end () - 2 ) *= packing;
4683+ const int64_t vreg_part = *(src_idx.end () - 1 ) % packing;
4684+ *(src_idx.end () - 1 ) /= packing;
4685+ for (int i = 0 ; i < packing; ++i) {
4686+ parts.push_back (builder.create <tpu::UnpackSubelementsOp>(
4687+ v.getLoc (), vreg_x32, src_tiles (src_idx), vreg_part));
4688+ if (*(src_idx.end () - 2 ) < *(src_tiles.dimensions ().end () - 2 )) {
4689+ ++*(src_idx.end () - 2 );
4690+ } // The rest is padding, so just pick any of the input parts (but not
4691+ // an arbitrary vreg so we don't add an extra dependency).
4692+ }
4693+ *tile = builder.create <tpu::PackSubelementsOp>(
4694+ v.getLoc (), src_tiles.begin ()->getType (), parts,
4695+ tpu::PackFormat::kInterleaved );
46104696 });
46114697 src = new_src;
46124698 src_tiles = std::move (src_tiles_retiled);
0 commit comments