Skip to content

Commit 6cb91ed

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[mosaic] Added documentation and a few useful methods to tpu.tiled attribute
PiperOrigin-RevId: 842307512
1 parent f7e3bdb commit 6cb91ed

File tree

2 files changed

+229
-21
lines changed

2 files changed

+229
-21
lines changed

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,39 @@ def TPU_VectorLayoutAttr : TPU_Attr<"VectorLayout", "vpad"> {
168168
// Memref layout attribute describing a (possibly multi-level) tiled layout.
// Carries the list of tiles and one stride per logical dimension.
def TPU_TiledLayoutAttr
  : TPU_Attr<"TiledLayout", "tiled",
             [DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>]> {
  let description = [{
    This attribute represents tiled layouts in memrefs.

    Multiple levels of tiling are supported with the following restrictions:
    - Additional levels of tiling may not add any padding.
    - Additional levels of tiling may not tile previously untiled dimensions,
      that is, they cannot tile across first-level tiles.

    Tile strides encode the stride when moving along a given dimension. They
    must have the same rank as the shape and must be decreasing with increasing
    dimension number. For tiled dimensions, the stride applies only when moving
    across first-level tiles. The strides are in units of the size of the first
    tile, or 1 if there are no tiles.
  }];
  let parameters = (ins
    ArrayRefParameter<"::xla::Tile", "">:$tiles,
    ArrayRefParameter<"int64_t", "">:$tile_strides
  );
  let extraClassDeclaration = [{
    // Returns the tile strides of a densely packed layout for the static
    // `shape`. Dimensions covered by the first tile are counted in whole
    // first-level tiles (rounded up); the remaining ones in elements.
    static ::llvm::SmallVector<int64_t> getDefaultTileStrides(::llvm::ArrayRef<::xla::Tile> tiles, ::llvm::ArrayRef<int64_t> shape);
    // Returns true only if the tile strides are known to describe a layout of
    // `shape` without gaps between tiles. Dynamic sizes in any dimension but
    // the leading one make the answer false.
    bool tilesAreKnownContiguous(::llvm::ArrayRef<int64_t> shape) const;

    // Rank of the layout, i.e. the number of tile strides.
    int64_t getRank() const {
      return getTileStrides().size();
    }
    // Number of leading dimensions that no tile touches.
    int64_t getUntiledRank() const;

    // Shape obtained by applying every tile in order to `shape`: each tiled
    // dimension is replaced by its tile count (rounded up), and the tile
    // dimensions are appended.
    ::llvm::SmallVector<int64_t> getExpandedShape(::llvm::ArrayRef<int64_t> shape) const;
    // Strides, one per dimension of the expanded shape.
    ::llvm::SmallVector<int64_t> getExpandedStrides() const;
  }];

  let hasCustomAssemblyFormat = 1;
  let genVerifyDecl = 1;
}
179205

180206
def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [

jaxlib/mosaic/dialect/tpu/tpu_dialect.cc

Lines changed: 202 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License.
1515

1616
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
1717

18+
#include <algorithm>
19+
#include <cassert>
1820
#include <cstdint>
1921
#include <optional>
2022
#include <utility>
@@ -23,14 +25,16 @@ limitations under the License.
2325
#include "absl/log/log.h"
2426
#include "llvm/ADT/APFloat.h"
2527
#include "llvm/ADT/Hashing.h"
28+
#include "llvm/ADT/STLExtras.h"
2629
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep.
30+
#include "llvm/Support/MathExtras.h"
2731
#include "mlir/Dialect/Arith/IR/Arith.h"
2832
#include "mlir/Dialect/Func/IR/FuncOps.h"
2933
#include "mlir/Dialect/MemRef/IR/MemRef.h"
3034
#include "mlir/IR/AffineExpr.h"
3135
#include "mlir/IR/AffineMap.h"
3236
#include "mlir/IR/Builders.h"
33-
#include "mlir/IR/BuiltinAttributeInterfaces.h"
37+
#include "mlir/IR/BuiltinTypeInterfaces.h"
3438
#include "mlir/IR/BuiltinTypes.h"
3539
#include "mlir/IR/Diagnostics.h"
3640
#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep.
@@ -215,32 +219,210 @@ Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) {
215219
}
216220

217221
// Builds the single-result affine map that sends a logical index to its linear
// offset under this layout.
//
// Starting from one expression per logical dimension, every tile replaces the
// trailing expressions it covers by their quotients (tile indices) followed by
// their remainders (indices inside the tile). The resulting expressions are
// then combined with the expanded strides; a dynamic stride is represented by a
// fresh symbol.
AffineMap TiledLayoutAttr::getAffineMap() const {
  auto* const ctx = getContext();
  const int64_t rank = getRank();

  SmallVector<AffineExpr, 8> index_exprs;
  for (int64_t dim = 0; dim < rank; ++dim) {
    index_exprs.push_back(getAffineDimExpr(dim, ctx));
  }

  for (const xla::Tile& tile : getTiles()) {
    auto tile_dims = tile.dimensions();
    const int64_t num_tiled = tile_dims.size();
    const int64_t num_untiled =
        static_cast<int64_t>(index_exprs.size()) - num_tiled;
    assert(num_untiled >= 0);

    SmallVector<AffineExpr, 8> next_exprs;
    // Leading expressions are not covered by this tile and pass through.
    for (int64_t i = 0; i < num_untiled; ++i) {
      next_exprs.push_back(index_exprs[i]);
    }
    // Which tile the index falls into...
    for (int64_t i = 0; i < num_tiled; ++i) {
      next_exprs.push_back(
          index_exprs[num_untiled + i].floorDiv(tile_dims[i]));
    }
    // ...and where inside that tile.
    for (int64_t i = 0; i < num_tiled; ++i) {
      next_exprs.push_back(index_exprs[num_untiled + i] % tile_dims[i]);
    }
    index_exprs = std::move(next_exprs);
  }

  const SmallVector<int64_t> expanded_strides = getExpandedStrides();
  const int64_t num_exprs = index_exprs.size();
  assert(expanded_strides.size() == index_exprs.size());

  int64_t symbol_count = 0;
  AffineExpr offset = getAffineConstantExpr(0, ctx);
  for (int64_t i = 0; i < num_exprs; ++i) {
    AffineExpr stride_expr =
        ShapedType::isDynamic(expanded_strides[i])
            ? getAffineSymbolExpr(symbol_count++, ctx)
            : getAffineConstantExpr(expanded_strides[i], ctx);
    offset = offset + index_exprs[i] * stride_expr;
  }
  return AffineMap::get(rank, symbol_count, offset);
}
255+
256+
namespace {
257+
// Returns the number of leading dimensions of a rank-`rank` shape that none of
// `tiles` touches.
//
// Each tile covers the trailing dimensions of the shape produced by the tiles
// before it and grows that shape by its own rank. The result is the smallest
// index at which any tile starts. Nested tiling across the first level of
// tiling is not assumed to be absent here, although the verifier enforces
// that.
int64_t getUntiledRank(ArrayRef<xla::Tile> tiles, const int64_t rank) {
  int64_t lowest_tiled_dim = rank;
  int64_t expanded_rank = rank;
  for (const xla::Tile& tile : tiles) {
    const int64_t num_tiled = tile.dimensions().size();
    const int64_t first_tiled_dim = expanded_rank - num_tiled;
    if (first_tiled_dim < lowest_tiled_dim) {
      lowest_tiled_dim = first_tiled_dim;
    }
    expanded_rank += num_tiled;
  }
  return lowest_tiled_dim;
}
269+
} // namespace
270+
271+
int64_t TiledLayoutAttr::getUntiledRank() const {
272+
return mlir::tpu::getUntiledRank(getTiles(), getRank());
273+
}
274+
275+
namespace {
276+
// Applies `tiles` in order to `untiled_shape` and returns the expanded shape.
//
// Every tile consumes the trailing `tile rank` dimensions of the current shape
// and replaces them by the number of tiles along each of them (rounded up;
// dynamic sizes stay dynamic) followed by the tile dimensions themselves.
//
// Returns failure if
//  - a tile has more dimensions than the shape it is applied to, or
//  - `require_alignment` is set and a tiled dimension is dynamic or is not a
//    multiple of the corresponding tile dimension.
FailureOr<SmallVector<int64_t>> getExpandedShape(
    const ArrayRef<int64_t> untiled_shape, const ArrayRef<xla::Tile> tiles,
    const bool require_alignment) {
  SmallVector<int64_t> shape(untiled_shape);
  for (const xla::Tile& tile : tiles) {
    const int64_t tile_ndims = tile.dimensions().size();
    // A tile cannot cover more dimensions than there are. Without this guard
    // take_back() would clamp to the whole shape and the loop below would read
    // past its end.
    if (tile_ndims > static_cast<int64_t>(shape.size())) {
      return failure();
    }
    const llvm::ArrayRef<int64_t> tiled_shape =
        llvm::ArrayRef(shape).take_back(tile_ndims);
    // First half: number of tiles per dimension. Second half: tile dimensions.
    llvm::SmallVector<int64_t> new_tiled_shape(2 * tile_ndims);
    for (int64_t i = 0; i < tile_ndims; ++i) {
      if (require_alignment && (ShapedType::isDynamic(tiled_shape[i]) ||
                                tiled_shape[i] % tile.dimension(i) != 0)) {
        return failure();
      }
      if (ShapedType::isDynamic(tiled_shape[i])) {
        new_tiled_shape[i] = ShapedType::kDynamic;
      } else {
        new_tiled_shape[i] =
            llvm::divideCeil(tiled_shape[i], tile.dimension(i));
      }
      new_tiled_shape[tile_ndims + i] = tile.dimension(i);
    }
    // `tiled_shape` views into `shape`; it must not be used past this point.
    shape.pop_back_n(tile_ndims);
    shape.append(new_tiled_shape);
  }
  return shape;
}
303+
} // namespace
304+
305+
// Computes the tile strides of a densely packed layout for `shape`.
//
// Strides are accumulated from the minor-most dimension outwards. Dimensions
// covered by the first tile contribute their size in whole tiles (rounded up);
// all other dimensions contribute their size in elements. Only the first tile
// is consulted.
//
// Preconditions (asserted): `shape` is fully static and the first tile has at
// most as many dimensions as `shape`.
SmallVector<int64_t> TiledLayoutAttr::getDefaultTileStrides(
    const ArrayRef<xla::Tile> tiles, const ArrayRef<int64_t> shape) {
  const int64_t rank = shape.size();
  const xla::Tile* const first_tile = tiles.empty() ? nullptr : &tiles.front();
  const int64_t first_tile_rank =
      first_tile == nullptr ? 0 : first_tile->dimensions().size();
  assert(first_tile_rank <= rank &&
         "first tile has more dimensions than the shape");
  // Index of the first dimension covered by the first tile. Kept signed so the
  // comparison below cannot wrap around.
  const int64_t first_tiled_dim = rank - first_tile_rank;

  SmallVector<int64_t> strides(rank);
  int64_t stride = 1;
  for (int64_t d = rank - 1; d >= 0; --d) {
    assert(!ShapedType::isDynamic(shape[d]));
    strides[d] = stride;
    if (d >= first_tiled_dim) {
      assert(first_tile != nullptr);
      const int64_t tile_d = d - first_tiled_dim;
      stride *= llvm::divideCeil(shape[d], first_tile->dimension(tile_d));
    } else {
      stride *= shape[d];
    }
  }
  return strides;
}
325+
326+
bool TiledLayoutAttr::tilesAreKnownContiguous(
327+
const ArrayRef<int64_t> shape) const {
328+
const ArrayRef<xla::Tile> tiles = getTiles();
329+
const ArrayRef<int64_t> tile_strides = getTileStrides();
330+
int64_t stride = 1;
331+
const xla::Tile* const first_tile = tiles.empty() ? nullptr : &tiles.front();
332+
const int64_t first_tile_rank =
333+
first_tile == nullptr ? 0 : first_tile->dimensions().size();
334+
for (int64_t d = shape.size() - 1; d >= 0; --d) {
335+
int64_t size_tiles;
336+
if (d >= shape.size() - first_tile_rank &&
337+
shape[d] != ShapedType::kDynamic) {
338+
assert(first_tile != nullptr);
339+
const int64_t tile_d = d - (shape.size() - first_tile_rank);
340+
size_tiles = llvm::divideCeil(shape[d], first_tile->dimension(tile_d));
341+
} else {
342+
size_tiles = shape[d];
231343
}
232-
for (int i = 0; i < dimensions.size(); ++i) {
233-
exprs.push_back(getAffineDimExpr(untiled_dims + i, getContext())
234-
.floorDiv(dimensions[i]));
344+
// Dimensions with only one element/tile can have any stride.
345+
if (stride != tile_strides[d] && size_tiles != 1) {
346+
return false;
235347
}
236-
for (int i = 0; i < dimensions.size(); ++i) {
237-
exprs.push_back(getAffineDimExpr(untiled_dims + i, getContext()) %
238-
dimensions[i]);
348+
if (d == 0) {
349+
break;
239350
}
240-
auto tile_map = AffineMap::get(map.getNumResults(), 0, exprs, getContext());
241-
map = tile_map.compose(map);
351+
// When any dimension other than the leading one has a dynamic size, we
352+
// cannot guarantee that there are no gaps.
353+
if (size_tiles == ShapedType::kDynamic) {
354+
return false;
355+
}
356+
stride *= size_tiles;
357+
}
358+
return true;
359+
}
360+
361+
// Returns `untiled_shape` with every tile of this layout applied in order.
// Tiled dimensions need not be multiples of the tile dimensions.
SmallVector<int64_t> TiledLayoutAttr::getExpandedShape(
    ArrayRef<int64_t> untiled_shape) const {
  FailureOr<SmallVector<int64_t>> expanded = mlir::tpu::getExpandedShape(
      untiled_shape, getTiles(), /*require_alignment=*/false);
  // Alignment is not required here, so the helper is not expected to fail.
  // Check before dereferencing, as getExpandedStrides() does.
  assert(succeeded(expanded) &&
         "getExpandedShape should never fail without require_alignment");
  return std::move(*expanded);
}
367+
368+
// Returns one stride per dimension of the expanded (fully tiled) shape.
//
// Without tiles this is just the tile strides. Otherwise the tile strides are
// scaled by the number of elements in the first tile, and the dimensions inside
// the first tile (after applying the nested tiles to it) get dense strides,
// minor-most dimension first.
SmallVector<int64_t> TiledLayoutAttr::getExpandedStrides() const {
  SmallVector<int64_t> expanded_strides(getTileStrides());
  const ArrayRef<xla::Tile> tiles = getTiles();
  if (tiles.empty()) {
    return expanded_strides;
  }

  // Shape of the first tile once all nested tiles have been applied to it.
  const xla::Tile& first_tile = tiles.front();
  const FailureOr<SmallVector<int64_t>> maybe_inner_shape =
      mlir::tpu::getExpandedShape(first_tile.dimensions(), tiles.drop_front(),
                                  /*require_alignment=*/true);
  // Verification should ensure this:
  assert(succeeded(maybe_inner_shape));
  const SmallVector<int64_t>& inner_shape = *maybe_inner_shape;

  // Tile strides are in units of first tiles; convert them to elements.
  const int64_t rank = getRank();
  const int64_t elements_per_tile = llvm::product_of(first_tile.dimensions());
  expanded_strides.resize_for_overwrite(rank + inner_shape.size());
  for (int64_t d = 0; d < rank; ++d) {
    expanded_strides[d] *= elements_per_tile;
  }

  // Dense strides inside the tile, built from the minor-most dimension.
  int64_t running_size = 1;
  for (int64_t i = static_cast<int64_t>(inner_shape.size()) - 1; i >= 0; --i) {
    expanded_strides[rank + i] = running_size;
    running_size *= inner_shape[i];
  }
  return expanded_strides;
}
396+
397+
// Verifies the invariants of a tiled layout and reports violations through
// `emitError`.
//
// Checked, in order:
//  - no tile stride is dynamic;
//  - the first tile has at most as many dimensions as there are tile strides;
//  - nested tiles do not reach outside the (expanded) first tile;
//  - nested tiles evenly divide the first tile, i.e. add no padding.
// A layout without tiles only needs to pass the first check.
LogicalResult TiledLayoutAttr::verify(
    function_ref<InFlightDiagnostic()> emitError,
    const llvm::ArrayRef<xla::Tile> tiles,
    const llvm::ArrayRef<int64_t> tile_strides) {
  if (llvm::any_of(tile_strides, ShapedType::isDynamic)) {
    return emitError() << "Not implemented: Dynamic tile strides";
  }
  if (tiles.empty()) {
    return success();
  }
  const int64_t rank = tile_strides.size();
  const xla::Tile& first_tile = tiles.front();
  const int64_t first_tile_rank = first_tile.dimensions().size();
  // The first tile covers the trailing dimensions of the shape, so it cannot
  // have more dimensions than the layout. getAffineMap() relies on this.
  if (first_tile_rank > rank) {
    return emitError() << "First tile " << first_tile.ToString()
                       << " has more dimensions than there are tile strides ("
                       << rank << ")";
  }
  // The interpretation of tile strides is unclear if there is nested tiling
  // across first tiles (e.g. T(8, 128)(2, 4, 64)), and this has no applications
  // anyway.
  if (mlir::tpu::getUntiledRank(tiles, rank) != rank - first_tile_rank) {
    return emitError() << "Not implemented: Nested tiling across first tiles";
  }
  // The check above only catches nested tiles that reach into dimensions no
  // tile touched before. A nested tile may also not be wider than the first
  // tile as expanded by the nested tiles preceding it, otherwise it would tile
  // across first tiles (and the divisibility check below would index past the
  // end of the tile shape).
  int64_t expanded_tile_rank = first_tile_rank;
  for (const xla::Tile& nested_tile : tiles.drop_front()) {
    const int64_t nested_rank = nested_tile.dimensions().size();
    if (nested_rank > expanded_tile_rank) {
      return emitError() << "Not implemented: Nested tiling across first tiles";
    }
    expanded_tile_rank += nested_rank;
  }
  // Check that nested tiles evenly divide previous tiles (so they don't add any
  // padding or change the tile size)
  if (failed(mlir::tpu::getExpandedShape(first_tile.dimensions(),
                                         tiles.drop_front(),
                                         /*require_alignment=*/true))) {
    return emitError() << "Not implemented: Nested tiles must evenly divide "
                       << "the first tile " << first_tile.ToString()
                       << " but they do not (would add padding)";
  }
  return success();
}
245427

246428
MemRefType getMemRefType(Value value) {

0 commit comments

Comments
 (0)