Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 27 additions & 33 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace {

#define S(v) StringAttr::get(ctx, (v))

// Returns ["out0", "out1", ..., "out<rank-1>"].
// Returns ["dim0", "dim1", ..., "dim<rank-1>"].
SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
SmallVector<StringAttr> ret;
for (int i = 0; i < rank; i++) {
Expand Down Expand Up @@ -71,14 +71,18 @@ void assertIsRegisterLayout(const LinearLayout &layout) {
expectedOuts.end()));
}

// Returns a 1D -> ND layout that's equivalent to creating a 1D -> 1D mapping of
// size product(shape) and then reshaping to permute(shape, order).
LinearLayout identityND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order,
ArrayRef<StringAttr> outDimNames) {
// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
// permute(shape, order).
LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) {
assert(shape.size() == order.size());

MLIRContext *ctx = inDimName.getContext();
auto rank = shape.size();

// The order in triton is written wrt. [dim0, dim1, ...].
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);

LinearLayout ret = LinearLayout::empty();
for (int i = 0; i < shape.size(); i++) {
// Start with the most-minor dimension, which is order[0].
Expand Down Expand Up @@ -491,7 +495,7 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
// And each warp takes the same register and lane sub-layout. So multiply with
// an identity layout for the warp.
LinearLayout warpLayout =
identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
identityStandardND(S("warp"), getWarpsPerCTA(), order);
LinearLayout ctaLayout = tileLayout * warpLayout;

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
Expand Down Expand Up @@ -601,8 +605,7 @@ mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
}

LinearLayout warpLayout =
identityND(kWarp, warpsPerCTA, warpOrder, outDimNames);
LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);

LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
warpLayout.transposeOuts(outDimNames);
Expand Down Expand Up @@ -684,7 +687,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
// And each warp takes the same register and lane sub-layout. So multiply with
// an identity layout for the warp.
LinearLayout warpLayout =
identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
identityStandardND(S("warp"), getWarpsPerCTA(), order);
LinearLayout ctaLayout = tileLayout * warpLayout;

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
Expand All @@ -700,9 +703,9 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

const auto &order = getOrder();
LinearLayout ctaLayout =
identityND(S("register"), getSizePerThread(), order, outDimNames) *
identityND(S("lane"), getThreadsPerWarp(), order, outDimNames) *
identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
identityStandardND(S("register"), getSizePerThread(), order) *
identityStandardND(S("lane"), getThreadsPerWarp(), order) *
identityStandardND(S("warp"), getWarpsPerCTA(), order);

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
}
Expand All @@ -711,11 +714,12 @@ LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
unsigned kWidth, ArrayRef<unsigned> order,
ArrayRef<unsigned> repOrder) {
// Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder
// Like LinearLayout::empty() but with a rank and an order
int rank = repOrder.size();
auto dimNames = standardOutDimNames(ctx, rank);
auto trivialShape = SmallVector<unsigned>(rank, 1);
LinearLayout ctaLayout =
identityND(S("register"), trivialShape, repOrder, dimNames);
identityStandardND(S("register"), trivialShape, repOrder);

assert(rank >= 2);
auto inner = order[0];
Expand Down Expand Up @@ -769,11 +773,7 @@ NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
// The triton orders are defined on [dim0, dim1, ...], so we need to pass
// those dims. Then, for some reason, operator* requires the orders to match,
// so we need to reorder the outs to match.
// FIXME(Lezcano). identityND should not take a dim name, as it's redundant.
// The order in triton assumes the standardDims, so it should
// use those.
ctaLayout *= identityND(S("warp"), getWarpsPerCTA(), getWarpOrder(),
standardOutDimNames(ctx, rank))
ctaLayout *= identityStandardND(S("warp"), getWarpsPerCTA(), getWarpOrder())
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
Expand All @@ -797,11 +797,8 @@ LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef<unsigned> mmaWarpShape,
// In other words, we need to broadcast along K
auto rank = mmaWarpOrder.size();
auto inner = isA ? rank - 1 : rank - 2;
auto outer = isA ? rank - 2 : rank - 1;
auto dimNames = standardOutDimNames(ctx, rank);
auto trivialShape = SmallVector<unsigned>(rank, 1);
LinearLayout warpLayout =
identityND(S("warp"), trivialShape, mmaWarpOrder, dimNames);
LinearLayout warpLayout = LinearLayout::empty();

// We have to broadcast along the inner dimension
// For A, when moving along M we go from 0 to 2.
Expand Down Expand Up @@ -1086,9 +1083,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(

// Expand the `warp` dimension according to warpsPerCTA.
auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
layout *=
identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));

// Expand the `register` dimension so the size of columns matches `n`.
int n = mma.getInstrShape()[1];
Expand Down Expand Up @@ -1126,9 +1122,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol);

// Expand the `warp` dimension according to warpsPerCTA.
layout *=
identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
auto ret =
combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape());
Expand All @@ -1138,9 +1133,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
ret = ensureLayoutNotSmallerThan(ret, namedTensorShape);
ret = ensureLayoutNotLargerThan(ret, namedTensorShape);
return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
.reshapeOuts({{S("offset"), ret.getTotalOutDimSize()},
{S("iteration"), 1}}) *
identityND(kBlock, {1, 1}, {0, 1}, {S("offset"), S("iteration")});
.reshapeOuts(
{{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
}

} // anonymous namespace
Expand Down