@@ -15,6 +15,8 @@ limitations under the License.
1515
1616#include " jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
1717
18+ #include < algorithm>
19+ #include < cassert>
1820#include < cstdint>
1921#include < optional>
2022#include < utility>
@@ -23,14 +25,16 @@ limitations under the License.
2325#include " absl/log/log.h"
2426#include " llvm/ADT/APFloat.h"
2527#include " llvm/ADT/Hashing.h"
28+ #include " llvm/ADT/STLExtras.h"
2629#include " llvm/ADT/TypeSwitch.h" // IWYU pragma: keep.
30+ #include " llvm/Support/MathExtras.h"
2731#include " mlir/Dialect/Arith/IR/Arith.h"
2832#include " mlir/Dialect/Func/IR/FuncOps.h"
2933#include " mlir/Dialect/MemRef/IR/MemRef.h"
3034#include " mlir/IR/AffineExpr.h"
3135#include " mlir/IR/AffineMap.h"
3236#include " mlir/IR/Builders.h"
33- #include " mlir/IR/BuiltinAttributeInterfaces .h"
37+ #include " mlir/IR/BuiltinTypeInterfaces .h"
3438#include " mlir/IR/BuiltinTypes.h"
3539#include " mlir/IR/Diagnostics.h"
3640#include " mlir/IR/DialectImplementation.h" // IWYU pragma: keep.
@@ -215,32 +219,210 @@ Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) {
215219}
216220
// Builds the affine map of this layout. In this diff, lines marked "-" are the
// removed previous implementation and lines marked "+" are its replacement.
//
// New implementation: start with one dim expression per dimension
// (0..getRank()). For every tile, the leading expressions not covered by the
// tile are kept, and each covered expression is replaced by its floorDiv by
// the tile dimension, followed (after all quotients) by its remainder modulo
// the tile dimension. The map has a single result: the sum of every expanded
// expression multiplied by the matching entry of getExpandedStrides().
217221AffineMap TiledLayoutAttr::getAffineMap () const {
218- AffineMap map =
219- AffineMap::getMultiDimIdentityMap (getTileStrides ().size (), getContext ());
220222 SmallVector<AffineExpr, 8 > exprs;
221- for (const xla::Tile &tile : getTiles ()) {
222- exprs.clear ();
223+ for (int64_t i = 0 ; i < getRank (); ++i) {
224+ exprs.push_back (getAffineDimExpr (i, getContext ()));
225+ }
226+ for (const xla::Tile& tile : getTiles ()) {
227+ SmallVector<AffineExpr, 8 > new_exprs;
223228 auto dimensions = tile.dimensions ();
224- int64_t untiled_dims = map.getNumResults () - dimensions.size ();
225- if (untiled_dims < 0 ) {
226- LOG (FATAL) << " Invalid TiledLayoutAttr: Number of dims must be larger "
227- " or equal to the rank of the tile" ;
// Number of leading expressions this tile leaves untouched. The new code only
// asserts that it is non-negative (the removed code used LOG(FATAL)).
229+ int64_t untiled_rank = exprs.size () - dimensions.size ();
230+ assert (untiled_rank >= 0 );
231+ for (int64_t i = 0 ; i < untiled_rank; ++i) {
232+ new_exprs.push_back (exprs[i]);
233+ }
234+ for (int64_t i = 0 ; i < dimensions.size (); ++i) {
235+ new_exprs.push_back (exprs[untiled_rank + i].floorDiv (dimensions[i]));
236+ }
237+ for (int64_t i = 0 ; i < dimensions.size (); ++i) {
238+ new_exprs.push_back (exprs[untiled_rank + i] % dimensions[i]);
239+ }
240+ exprs = std::move (new_exprs);
241+ }
// Linearize: accumulate expr * stride over all expanded dimensions. Each
// dynamic stride is represented by a fresh symbol, numbered in order of
// appearance; static strides become constants.
242+ int64_t num_symbols = 0 ;
243+ AffineExpr result = getAffineConstantExpr (0 , getContext ());
244+ SmallVector<int64_t > strides = getExpandedStrides ();
245+ assert (strides.size () == exprs.size ());
246+ for (int64_t i = 0 ; i < exprs.size (); ++i) {
247+ AffineExpr stride_expr =
248+ ShapedType::isDynamic (strides[i])
249+ ? getAffineSymbolExpr (num_symbols++, getContext ())
250+ : getAffineConstantExpr (strides[i], getContext ());
251+ result = result + exprs[i] * stride_expr;
252+ }
253+ return AffineMap::get (getRank (), num_symbols, result);
254+ }
255+
// Helper: returns how many leading dimensions are not touched by any tile in
// `tiles`, for a layout with `rank` dimensions. For each tile the candidate is
// (current expanded rank - tile rank); the result is the minimum of `rank` and
// all candidates. The expanded rank grows by the tile rank after each tile.
256+ namespace {
257+ int64_t getUntiledRank (ArrayRef<xla::Tile> tiles, const int64_t rank) {
258+ // Note: This implementation does not assume there is no nested tiling across
259+ // the first level of tiling, though this is enforced by the verifier.
260+ int64_t untiled_rank = rank;
261+ int64_t tiled_rank = rank;
262+ for (const xla::Tile& tile : tiles) {
263+ const int64_t tile_ndims = tile.dimensions ().size ();
264+ untiled_rank = std::min (untiled_rank, tiled_rank - tile_ndims);
265+ tiled_rank += tile_ndims;
266+ }
267+ return untiled_rank;
268+ }
269+ } // namespace
270+
// Untiled rank of this layout: forwards this attribute's tiles and rank to the
// free function mlir::tpu::getUntiledRank.
271+ int64_t TiledLayoutAttr::getUntiledRank () const {
272+ return mlir::tpu::getUntiledRank (getTiles (), getRank ());
273+ }
274+
// Helper: applies `tiles`, in order, to `untiled_shape` and returns the
// resulting shape. Each tile replaces the trailing tile-rank dimensions with
// their tile counts (ceil(dim / tile_dim), or kDynamic when the dimension is
// dynamic), followed by the tile dimensions themselves.
// When `require_alignment` is true, returns failure() as soon as a tiled
// dimension is dynamic or is not a multiple of the corresponding tile
// dimension. No failure path is taken when `require_alignment` is false.
275+ namespace {
276+ FailureOr<SmallVector<int64_t >> getExpandedShape (
277+ const ArrayRef<int64_t > untiled_shape, const ArrayRef<xla::Tile> tiles,
278+ const bool require_alignment) {
279+ SmallVector<int64_t > shape (untiled_shape);
280+ for (const xla::Tile& tile : tiles) {
281+ const int64_t tile_ndims = tile.dimensions ().size ();
// `tiled_shape` is a view into `shape`; it is only read before `shape` is
// modified by pop_back_n/append below.
282+ const llvm::ArrayRef<int64_t > tiled_shape =
283+ llvm::ArrayRef (shape).take_back (tile_ndims);
// First half holds the tile counts, second half the tile dimensions.
284+ llvm::SmallVector<int64_t > new_tiled_shape (2 * tile_ndims);
285+ for (int64_t i = 0 ; i < tile_ndims; ++i) {
286+ if (require_alignment && (ShapedType::isDynamic (tiled_shape[i]) ||
287+ tiled_shape[i] % tile.dimension (i) != 0 )) {
288+ return failure ();
289+ }
290+ if (ShapedType::isDynamic (tiled_shape[i])) {
291+ new_tiled_shape[i] = ShapedType::kDynamic ;
292+ } else {
293+ new_tiled_shape[i] =
294+ llvm::divideCeil (tiled_shape[i], tile.dimension (i));
295+ }
296+ new_tiled_shape[tile_ndims + i] = tile.dimension (i);
297+ }
298+ shape.pop_back_n (tile_ndims);
299+ shape.append (new_tiled_shape);
300+ }
301+ return shape;
302+ }
303+ } // namespace
304+
// Computes dense default strides for `shape`, walking from the last dimension
// to the first. strides[d] is the product of the sizes of all later
// dimensions, where the size of a dimension covered by the first tile (the
// trailing first-tile-rank dimensions) is ceil(shape[d] / tile_dim) and the
// size of any other dimension is shape[d]. Only the first tile in `tiles` is
// consulted. Asserts that no dimension of `shape` is dynamic.
// The two lines marked "-" below are removed lines of the previous
// getAffineMap body that the diff interleaves here.
305+ SmallVector<int64_t > TiledLayoutAttr::getDefaultTileStrides (
306+ const ArrayRef<xla::Tile> tiles, const ArrayRef<int64_t > shape) {
307+ SmallVector<int64_t > strides (shape.size ());
308+ int64_t stride = 1 ;
309+ const xla::Tile* const first_tile = tiles.empty () ? nullptr : &tiles.front ();
310+ const int64_t first_tile_rank =
311+ first_tile == nullptr ? 0 : first_tile->dimensions ().size ();
312+ for (int64_t d = shape.size () - 1 ; d >= 0 ; --d) {
313+ assert (!ShapedType::isDynamic (shape[d]));
314+ strides[d] = stride;
// True only for the trailing dimensions covered by the first tile.
315+ if (d >= shape.size () - first_tile_rank) {
316+ assert (first_tile != nullptr );
// Index of dimension d within the first tile.
317+ const int64_t tile_d = d - (shape.size () - first_tile_rank);
318+ stride *= llvm::divideCeil (shape[d], first_tile->dimension (tile_d));
319+ } else {
320+ stride *= shape[d];
228321 }
229- for (int64_t i = 0 ; i < untiled_dims; ++i) {
230- exprs.push_back (getAffineDimExpr (i, getContext ()));
322+ }
323+ return strides;
324+ }
325+
// Returns true when this layout's tile strides match the dense strides implied
// by `shape`, checked from the last dimension to the first. The size of a
// dimension is ceil(shape[d] / tile_dim) when it is static and covered by the
// first tile, and shape[d] otherwise (which may be kDynamic). Returns false on
// a stride mismatch for a dimension whose size is not 1, or when a dimension
// other than dimension 0 has a dynamic size.
// Indexes getTileStrides() with indices of `shape` — assumes both have the
// same length; TODO confirm against callers.
// Lines marked "-" below are removed lines of the previous getAffineMap body
// that the diff interleaves here.
326+ bool TiledLayoutAttr::tilesAreKnownContiguous (
327+ const ArrayRef<int64_t > shape) const {
328+ const ArrayRef<xla::Tile> tiles = getTiles ();
329+ const ArrayRef<int64_t > tile_strides = getTileStrides ();
// Expected stride of dimension d if all later dimensions are densely packed.
330+ int64_t stride = 1 ;
331+ const xla::Tile* const first_tile = tiles.empty () ? nullptr : &tiles.front ();
332+ const int64_t first_tile_rank =
333+ first_tile == nullptr ? 0 : first_tile->dimensions ().size ();
334+ for (int64_t d = shape.size () - 1 ; d >= 0 ; --d) {
335+ int64_t size_tiles;
336+ if (d >= shape.size () - first_tile_rank &&
337+ shape[d] != ShapedType::kDynamic ) {
338+ assert (first_tile != nullptr );
339+ const int64_t tile_d = d - (shape.size () - first_tile_rank);
340+ size_tiles = llvm::divideCeil (shape[d], first_tile->dimension (tile_d));
341+ } else {
342+ size_tiles = shape[d];
231343 }
232- for ( int i = 0 ; i < dimensions. size (); ++i) {
233- exprs. push_back ( getAffineDimExpr (untiled_dims + i, getContext ())
234- . floorDiv (dimensions[i])) ;
344+ // Dimensions with only one element/tile can have any stride.
345+ if (stride != tile_strides[d] && size_tiles != 1 ) {
346+ return false ;
235347 }
236- for (int i = 0 ; i < dimensions.size (); ++i) {
237- exprs.push_back (getAffineDimExpr (untiled_dims + i, getContext ()) %
238- dimensions[i]);
// The leading dimension's size does not affect any stride, so stop before the
// dynamic-size check below.
348+ if (d == 0 ) {
349+ break ;
239350 }
240- auto tile_map = AffineMap::get (map.getNumResults (), 0 , exprs, getContext ());
241- map = tile_map.compose (map);
351+ // When any dimension other than the leading one has a dynamic size, we
352+ // cannot guarantee that there are no gaps.
353+ if (size_tiles == ShapedType::kDynamic ) {
354+ return false ;
355+ }
356+ stride *= size_tiles;
357+ }
358+ return true ;
359+ }
360+
// Returns `untiled_shape` expanded by this layout's tiles, without requiring
// that the shape be aligned to the tiles. The FailureOr returned by
// mlir::tpu::getExpandedShape is dereferenced without a success check.
361+ SmallVector<int64_t > TiledLayoutAttr::getExpandedShape (
362+ ArrayRef<int64_t > untiled_shape) const {
363+ // getExpandedShape should never fail without require_alignment
364+ return *mlir::tpu::getExpandedShape (untiled_shape, getTiles (),
365+ /* require_alignment=*/ false );
366+ }
367+
// Returns one stride per expanded dimension of this layout.
// With no tiles, this is a copy of the tile strides. Otherwise the first
// getRank() entries are the tile strides multiplied by the number of elements
// in the first tile, and the remaining entries are dense strides inside the
// first tile (last entry has stride 1), computed over the first tile's
// dimensions expanded by the remaining tiles.
368+ SmallVector<int64_t > TiledLayoutAttr::getExpandedStrides () const {
369+ if (getTiles ().empty ()) {
370+ return SmallVector<int64_t >(getTileStrides ());
371+ }
372+ SmallVector<int64_t > strides (getTileStrides ());
373+ // Expand front tile
374+ const xla::Tile& first_tile = getTiles ().front ();
375+ const FailureOr<SmallVector<int64_t >> failure_or_expanded_tile =
376+ mlir::tpu::getExpandedShape (first_tile.dimensions (),
377+ getTiles ().drop_front (),
378+ /* require_alignment=*/ true );
379+ // Verification should ensure this:
380+ assert (succeeded (failure_or_expanded_tile));
381+ const SmallVector<int64_t >& expanded_tile = *failure_or_expanded_tile;
// Every entry at index >= getRank() is written by the loop below before being
// read. NOTE(review): presumably getTileStrides().size() == getRank(), so the
// existing entries are exactly the first getRank() ones — confirm.
382+ strides.resize_for_overwrite (getRank () + expanded_tile.size ());
383+ int64_t first_tile_size = llvm::product_of (first_tile.dimensions ());
// Running product of the expanded tile dimensions already visited.
384+ int64_t tile_size = 1 ;
385+ for (int64_t d = strides.size () - 1 ; d >= 0 ; --d) {
386+ if (d >= getRank ()) {
387+ const int64_t new_stride = tile_size;
388+ tile_size *= expanded_tile[d - getRank ()];
389+ strides[d] = new_stride;
390+ } else {
391+ strides[d] *= first_tile_size;
392+ }
393+ }
394+ return strides;
395+ }
396+
// Verifier for TiledLayoutAttr. Emits an error through `emitError` when:
//   - any tile stride is dynamic;
//   - the untiled rank computed from `tiles` differs from
//     (number of tile strides - rank of the first tile);
//   - the tiles after the first do not evenly divide the first tile.
// A layout with no tiles passes once the dynamic-stride check succeeds.
// The line marked "-" near the end is a removed line of the previous
// getAffineMap body shown by the diff.
397+ LogicalResult TiledLayoutAttr::verify (
398+ function_ref<InFlightDiagnostic()> emitError,
399+ const llvm::ArrayRef<xla::Tile> tiles,
400+ const llvm::ArrayRef<int64_t> tile_strides) {
401+ if (llvm::any_of (tile_strides, ShapedType::isDynamic)) {
402+ return emitError () << " Not implemented: Dynamic tile strides" ;
403+ }
404+ if (tiles.empty ()) {
405+ return success ();
406+ }
// The layout's rank is taken to be the number of tile strides.
407+ const int64_t rank = tile_strides.size ();
408+ const xla::Tile& first_tile = tiles.front ();
409+ const int64_t first_tile_rank = first_tile.dimensions ().size ();
410+ // The interpretation of tile strides is unclear if there is nested tiling
411+ // across first tiles (e.g. T(8, 128)(2, 4, 64)), and this has no applications
412+ // anyway.
413+ if (mlir::tpu::getUntiledRank (tiles, rank) != rank - first_tile_rank) {
414+ return emitError () << " Not implemented: Nested tiling across first tiles" ;
415+ }
416+ // Check that nested tiles evenly divide previous tiles (so they don't add any
417+ // padding or change the tile size)
418+ if (failed (mlir::tpu::getExpandedShape (first_tile.dimensions (),
419+ tiles.drop_front (),
420+ /* require_alignment=*/ true ))) {
421+ return emitError () << " Not implemented: Nested tiles must evenly divide "
422+ << " the first tile " << first_tile.ToString ()
423+ << " but they do not (would add padding)" ;
242424 }
243- return map ;
425+ return success () ;
244426}
245427
246428MemRefType getMemRefType (Value value) {
0 commit comments