include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -250,6 +250,13 @@ chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order, int swizzleByteSize);

// FIXME
// Exposing to use it in LinearLayoutConversionsTest.cpp
// Remove it once we fully activate the DotOperand conversion via LLs
class DotOperandEncodingAttr;
LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
DotOperandEncodingAttr dot);
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
16 changes: 6 additions & 10 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1037,16 +1037,12 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
return res;
}
SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
auto parentLayout = getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
if (auto distributedLayout =
mlir::dyn_cast<DistributedEncodingTrait>(parentLayout)) {
return distributedLayout.getWarpsPerCTA();
} else {
llvm::report_fatal_error(
"DotOperandEncodingAttr non-DistributedEncodingAttr parent not "
"supported yet");
}
auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent());
auto warps = distributedLayout.getWarpsPerCTA();
auto rank = warps.size();
auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
warps[kDim] = 1;
return warps;
Comment on lines 1039 to +1045
@lezcano (Contributor, Author), Oct 14, 2024:
@Jokeren added this fix as I needed it for the layout conversion to be correct and pass the newly added tests.

}
SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
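
For reference, a minimal standalone sketch of the revised logic (hypothetical helper name, not part of the PR): warps along the K dimension never split a dot operand, so their count is collapsed to 1.

#include <cstddef>
#include <vector>

// Hypothetical mock of the revised DotOperandEncodingAttr::getWarpsPerCTA
// logic (illustration only, not part of the PR). opIdx == 0 is the A operand,
// opIdx == 1 is B; the operand's K dimension is never split across warps.
std::vector<unsigned> warpsPerCTAForDotOperand(std::vector<unsigned> warps,
                                               unsigned opIdx) {
  std::size_t rank = warps.size();
  std::size_t kDim = opIdx == 0 ? rank - 1 : rank - 2;
  warps[kDim] = 1;
  return warps;
}

// Example: for a parent layout with warpsPerCTA = {4, 2}, the A operand
// (opIdx = 0) gets {4, 1} and the B operand (opIdx = 1) gets {1, 2}.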
69 changes: 67 additions & 2 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -4,6 +4,7 @@
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/DenseMap.h"
@@ -821,13 +822,77 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
return ret;
}

LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
DotOperandEncodingAttr dot) {
// TODO,BE. Implement ampereMMA in terms of this one
Contributor: What does "BE" mean? Backend?
@lezcano (Contributor, Author): Better Engineering

int rank = shape.size();
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
int kWidth = dot.getKWidth();
bool isA = dot.getOpIdx() == 0;

assert(mma.isAmpere());
assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);

// Implement A. For B transpose in the end
std::vector<std::vector<int32_t>> registers;
std::vector<std::vector<int32_t>> lanes;
int32_t i = 1;
// kWidth contiguous elements
while (i < kWidth) {
registers.push_back({i, 0});
i *= 2;
}
// 4 threads per chunk
for (int j = 0; j < 2; j++) {
lanes.push_back({i, 0});
i *= 2;
}
// 8 threads going down
lanes.push_back({0, 1});
lanes.push_back({0, 2});
lanes.push_back({0, 4});
// 2 tiles in column-major order
// Just one if it's the B operand
if (isA) {
registers.push_back({0, 8});
}
registers.push_back({i, 0});

if (!isA) {
for (auto &r : registers) {
std::swap(r[0], r[1]);
}
for (auto &l : lanes) {
std::swap(l[0], l[1]);
}
}

LinearLayout ctaLayout(
{{S("register"), registers}, {S("lane"), lanes}},
llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));

auto order = dot.getCTAOrder();
assert(order[0] == 1 && order[1] == 0);
ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}

std::optional<LinearLayout>
DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
return dotOperandMfmaToLinearLayout(*this, shape);
}

// TODO Activate in a follow-up PR
// else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
// if (mma.isAmpere()) {
// return ampereDotToLinearLayout(shape, *this);
// }
//}
return std::nullopt;
}

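
As a reading aid for the bases built above and checked in the tests below: Triton LinearLayouts are linear over GF(2), so each set bit of an input index (register, lane, warp) selects one basis vector and the selected vectors are XOR-ed together. A minimal sketch of that application rule, assuming a hypothetical helper applyBases that is not part of the Triton API:

#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (illustration only): apply one input dimension's bases
// to an index. Each set bit of `index` selects the corresponding basis vector;
// selections are combined with XOR. Contributions from different input
// dimensions (register, lane, warp) combine the same way.
std::array<int32_t, 2>
applyBases(const std::vector<std::array<int32_t, 2>> &bases, uint32_t index) {
  std::array<int32_t, 2> out = {0, 0};
  for (std::size_t b = 0; b < bases.size(); ++b)
    if (index & (1u << b)) {
      out[0] ^= bases[b][0];
      out[1] ^= bases[b][1];
    }
  return out;
}

// Example with the A-operand register bases for kWidth = 8,
// {{0,1}, {0,2}, {0,4}, {8,0}, {0,32}}: register index 3 selects the first
// two bases, giving coordinate (0,1) ^ (0,2) = (0,3).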
82 changes: 82 additions & 0 deletions unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -4,6 +4,7 @@
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Signals.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -40,6 +41,12 @@ class LinearLayoutConversionsTest : public ::testing::Test {
CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape);
}

DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef<unsigned> warps,
ArrayRef<unsigned> order) {
auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order);
return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth);
}

AMDMfmaEncodingAttr mfma(ArrayRef<unsigned> warps, unsigned mDim,
unsigned nDim, bool isTransposed) {
SmallVector<unsigned> cpg(warps.size(), 1u);
@@ -494,6 +501,81 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) {
{S("dim0"), S("dim1")}));
}

TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) {
EXPECT_EQ(ampereDotToLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})),
LinearLayout(
{
{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
{S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
{S("warp"), {}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
EXPECT_EQ(ampereDotToLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})),
LinearLayout(
{
{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{S("warp"), {}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
}
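
As a quick hand check (not part of the test), the first expectation above can be decoded with the XOR rule sketched earlier:

// register 3 = 0b00011 -> (0,1) ^ (0,2) = (0,3)
// lane     5 = 0b00101 -> (0,8) ^ (1,0) = (1,8)
// so (lane 5, register 3) in warp 0 holds element (0,3) ^ (1,8) = (1,11)
// of the 16x64 A-operand tile.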

TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
EXPECT_EQ(
ampereDotToLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})),
LinearLayout(
{
{S("register"),
{{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {64, 0}}},
{S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
{S("warp"), {{16, 0}, {32, 0}}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
EXPECT_EQ(ampereDotToLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
LinearLayout(
{
{S("register"),
{{1, 0},
{2, 0},
{4, 0},
{32, 0},
{0, 8},
{0, 16},
{0, 32},
{64, 0}}},
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{
S("warp"),
{},
},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
EXPECT_EQ(ampereDotToLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
LinearLayout(
{
{S("register"),
{{1, 0},
{2, 0},
{4, 0},
{32, 0},
{0, 8},
{0, 16},
{0, 32},
{0, 64}}},
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{
S("warp"),
{},
},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
}

TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
/*isTransposed=*/false);