From 6742df0fd6c649952323e3fb136ef42e138cbd18 Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 11 Oct 2024 15:52:56 +0100 Subject: [PATCH 1/4] Implement LL --- .../TritonGPU/IR/LinearLayoutConversions.cpp | 69 +++++++++++++++- .../TritonGPU/LinearLayoutConversionsTest.cpp | 82 +++++++++++++++++++ 2 files changed, 149 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index c35b186fbf97..316dfa3468a0 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -4,6 +4,7 @@ #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" #include "triton/Tools/LinearLayout.h" #include "triton/Tools/StrUtil.h" #include "llvm/ADT/DenseMap.h" @@ -821,13 +822,77 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { return ret; } +LinearLayout ampereDotToLinearLayout(ArrayRef shape, + DotOperandEncodingAttr dot) { + // TODO,BE. Implement ampereMMA in terms of this one + int rank = shape.size(); + auto mma = cast(dot.getParent()); + int kWidth = dot.getKWidth(); + bool isA = dot.getOpIdx() == 0; + + assert(mma.isAmpere()); + assert(rank == 2 || rank == 3); + assert(mma.getInstrShape().size() == rank); + assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) || + (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8}))); + + MLIRContext *ctx = mma.getContext(); + SmallVector dimNames = standardOutDimNames(ctx, rank); + + // Implement A. For B transpose in the end + std::vector> registers; + std::vector> lanes; + int32_t i = 1; + // kWidth contiguous elements + while (i < kWidth) { + registers.push_back({i, 0}); + i *= 2; + } + // 4 threads per chunk + for (int j = 0; j < 2; j++) { + lanes.push_back({i, 0}); + i *= 2; + } + // 8 threads going down + lanes.push_back({0, 1}); + lanes.push_back({0, 2}); + lanes.push_back({0, 4}); + // 2 tiles in column-major order + // Just one if it's the B operand + if (isA) { + registers.push_back({0, 8}); + } + registers.push_back({i, 0}); + + if (!isA) { + for (auto &r : registers) { + std::swap(r[0], r[1]); + } + for (auto &l : lanes) { + std::swap(l[0], l[1]); + } + } + + LinearLayout ctaLayout( + {{S("register"), registers}, {S("lane"), lanes}}, + llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); + + auto order = dot.getCTAOrder(); + assert(order[0] == 1 && order[1] == 0); + ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames); + + return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); +} + std::optional DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { - if (auto mfmaLayout = llvm::dyn_cast(getParent())) { return dotOperandMfmaToLinearLayout(*this, shape); + } else if (auto mma = mlir::dyn_cast(getParent())) { + if (mma.isAmpere()) { + return ampereDotToLinearLayout(shape, *this); + } } - return std::nullopt; } diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index c65428d03975..76c9c442257d 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -4,6 +4,7 @@ #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Tools/StrUtil.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Signals.h" #include #include @@ -40,6 +41,12 @@ class LinearLayoutConversionsTest : public ::testing::Test { CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape); } + DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef warps, + ArrayRef order) { + auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order); + return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth); + } + AMDMfmaEncodingAttr mfma(ArrayRef warps, unsigned mDim, unsigned nDim, bool isTransposed) { SmallVector cpg(warps.size(), 1u); @@ -494,6 +501,81 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) { {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { + EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})), + LinearLayout( + { + {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, + {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})), + LinearLayout( + { + {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { + EXPECT_EQ( + toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})), + LinearLayout( + { + {S("register"), + {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {64, 0}}}, + {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + LinearLayout( + { + {S("register"), + {{1, 0}, + {2, 0}, + {4, 0}, + {32, 0}, + {0, 8}, + {0, 16}, + {0, 32}, + {64, 0}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + { + S("warp"), + {}, + }, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + LinearLayout( + { + {S("register"), + {{1, 0}, + {2, 0}, + {4, 0}, + {32, 0}, + {0, 8}, + {0, 16}, + {0, 32}, + {0, 64}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + { + S("warp"), + {}, + }, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) { auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, /*isTransposed=*/false); From cf47c27df70dc45225d6d51e493c2a4822c69380 Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 14 Oct 2024 14:19:33 +0100 Subject: [PATCH 2/4] Deactivate the LL cast for DotOperand for now --- lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp | 10 ++++++---- .../Dialect/TritonGPU/LinearLayoutConversionsTest.cpp | 10 +++++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 316dfa3468a0..0131462439fa 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -888,11 +888,13 @@ std::optional DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { if (auto mfmaLayout = llvm::dyn_cast(getParent())) { return dotOperandMfmaToLinearLayout(*this, shape); - } else if (auto mma = mlir::dyn_cast(getParent())) { - if (mma.isAmpere()) { - return ampereDotToLinearLayout(shape, *this); - } } + // TODO Activate in a follow-up PR + // else if (auto mma = mlir::dyn_cast(getParent())) { + // if (mma.isAmpere()) { + // return ampereDotToLinearLayout(shape, *this); + // } + //} return std::nullopt; } diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index 76c9c442257d..e91a0ecf26ab 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -502,7 +502,7 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) { } TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { - EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})), + EXPECT_EQ(ampereDotToLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})), LinearLayout( { {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, @@ -511,7 +511,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})), + EXPECT_EQ(ampereDotToLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})), LinearLayout( { {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, @@ -524,7 +524,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { EXPECT_EQ( - toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})), + ampereDotToLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})), LinearLayout( { {S("register"), @@ -534,7 +534,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + EXPECT_EQ(ampereDotToLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})), LinearLayout( { {S("register"), @@ -554,7 +554,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + EXPECT_EQ(ampereDotToLinearLayot({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})), LinearLayout( { {S("register"), From 1bec1b6fd2ec83ac7fcf50630c51c51b994ac3fc Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 14 Oct 2024 14:54:51 +0100 Subject: [PATCH 3/4] Fix getWarpsPerCTA --- .../TritonGPU/IR/LinearLayoutConversions.h | 7 +++++++ lib/Dialect/TritonGPU/IR/Dialect.cpp | 16 ++++++---------- .../TritonGPU/LinearLayoutConversionsTest.cpp | 2 +- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h index 1367f65a031f..1124daec6dfc 100644 --- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h +++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -250,6 +250,13 @@ chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef repShape, ArrayRef paddedRepShape, ArrayRef order, int swizzleByteSize); + +// FIXME +// Exposing to use it in LinearLayoutConversionsTest.cpp +// Remove it once we fully activate the DotOperand conversion via LLs +class DotOperandEncodingAttr; +LinearLayout ampereDotToLinearLayout(ArrayRef shape, + DotOperandEncodingAttr dot); } // namespace mlir::triton::gpu #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 80fe1aed29f4..5d1d3617b08a 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1037,16 +1037,12 @@ SmallVector DotOperandEncodingAttr::getCTASplitNum() const { return res; } SmallVector DotOperandEncodingAttr::getWarpsPerCTA() const { - auto parentLayout = getParent(); - assert(parentLayout && "DotOperandEncodingAttr must have a parent"); - if (auto distributedLayout = - mlir::dyn_cast(parentLayout)) { - return distributedLayout.getWarpsPerCTA(); - } else { - llvm::report_fatal_error( - "DotOperandEncodingAttr non-DistributedEncodingAttr parent not " - "supported yet"); - } + auto distributedLayout = mlir::cast(getParent()); + auto warps = distributedLayout.getWarpsPerCTA(); + auto rank = warps.size(); + auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2; + warps[kDim] = 1; + return warps; } SmallVector DotOperandEncodingAttr::getWarpOrder() const { return ::getWarpOrder(*this); diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index e91a0ecf26ab..015a450dfff0 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -554,7 +554,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(ampereDotToLinearLayot({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + EXPECT_EQ(ampereDotToLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})), LinearLayout( { {S("register"), From b0484fc8d4f518fefb0c34c217fc2222c8015343 Mon Sep 17 00:00:00 2001 From: lezcano Date: Tue, 15 Oct 2024 17:45:59 +0100 Subject: [PATCH 4/4] address review --- lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 0131462439fa..bc365057f811 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -831,8 +831,6 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, bool isA = dot.getOpIdx() == 0; assert(mma.isAmpere()); - assert(rank == 2 || rank == 3); - assert(mma.getInstrShape().size() == rank); assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) || (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8})));