diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 1e63c4b390d4..4c68ac9858ff 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1017,7 +1017,8 @@ SmallVector DotOperandEncodingAttr::getCTASplitNum() const { assert(rank == 2 || rank == 3 && "Invalid dotLayout"); // Do not split CTA in K dimension - getOpIdx() == 0 ? res[rank - 1] = 1 : res[rank - 2] = 1; + auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2; + res[kDim] = 1; return res; } SmallVector DotOperandEncodingAttr::getWarpsPerCTA() const { diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 43c87af487a1..b707d8f7d328 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -280,78 +280,6 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, return ret; } -LinearLayout ampereMmaToLinearLayout(ArrayRef shape, - NvidiaMmaEncodingAttr mma) { - int rank = shape.size(); - - assert(mma.isAmpere()); - assert(rank == 2 || rank == 3); - assert(mma.getInstrShape().size() == rank); - assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) || - (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8}))); - - MLIRContext *ctx = mma.getContext(); - SmallVector dimNames = standardOutDimNames(ctx, rank); - - auto orderedDimNames = permuteDimNames(dimNames, mma.getRepOrder()); - assert(mma.getRepOrder() == getMatrixOrder(rank, /*rowMajor=*/true)); - - LinearLayout ctaLayout( - {{S("register"), {{1, 0}, {0, 8}}}, - {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, - ArrayRef(orderedDimNames).take_front(2)); - assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true)); - // FIXME(Lezcano). identityND should not have an `order` param as it's - // redundant with the order of the out dims. - ctaLayout *= - identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames); - - return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); -} - -LinearLayout hopperMmaToLinearLayout(ArrayRef shape, - NvidiaMmaEncodingAttr mma) { - int rank = shape.size(); - assert(mma.isHopper()); - assert(rank == 2); - - // wgmma operates on groups of 4 warps. - assert(product(mma.getWarpsPerCTA()) % 4 == 0); - - // Check that it's a known MMA layout. - assert(mma.getInstrShape().size() == 3); - int m = mma.getInstrShape()[0]; - int n = mma.getInstrShape()[1]; - int k = mma.getInstrShape()[2]; - assert(m == 16); - assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256); - assert(k == 8 || k == 16 || k == 32); - - MLIRContext *ctx = mma.getContext(); - LinearLayout ctaLayout( - {{S("register"), {{1, 0}, {0, 8}}}, - {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, - {S("dim1"), S("dim0")}); - - // Expand the `register` dimension so the size of dim1 matches `n`. - ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")), - S("register"), S("dim1")); - - // The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major. - // Since the warpOrder needs to be M-major, we need to transpose the out - // dimensions AND transpose the order - // FIXME(Lezcano). identityND should not have an `order` param as it's - // redundant. The order is already given by the order of the - // out dims, and if it has an order, it shouldn't change the - // order of the out dims. - assert(getWarpOrder(mma) == SmallVector({0, 1})); - ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1}, - {S("dim0"), S("dim1")}) - .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); - - return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); -} - LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef shape, SharedEncodingAttr shared) { assert(!shared.getHasLeadingOffset()); @@ -779,13 +707,153 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); } +LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef tileShape, + unsigned kWidth, ArrayRef order, + ArrayRef repOrder) { + // Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder + int rank = repOrder.size(); + auto dimNames = standardOutDimNames(ctx, rank); + auto trivialShape = SmallVector(rank, 1); + LinearLayout ctaLayout = + identityND(S("register"), trivialShape, repOrder, dimNames); + + assert(rank >= 2); + auto inner = order[0]; + auto outer = order[1]; + + assert(tileShape.size() == rank); + int m = tileShape[outer]; + int n = tileShape[inner]; + + // The relative order of registers and lanes is given by: + // - Inner dim: kWidth registers + // - Inner dim: 4 lanes + // - Outer dim: 8 lanes + // - Outer dim: repeat m / 8 times + // - Inner dim: repeat n / (kWidth * 4) times + assert(m % 8 == 0); + assert(n % (kWidth * 4) == 0); + // There is at least one subtile on the inner-most dimension + // FIXME. We should implement operator* in terms of operator*= + // and chain *= instead of using * + auto outDimNames = llvm::to_vector(ctaLayout.getOutDimNames()); + ctaLayout = ctaLayout * + LinearLayout::identity1D(kWidth, S("register"), dimNames[inner]) * + LinearLayout::identity1D(4, S("lane"), dimNames[inner]) * + LinearLayout::identity1D(8, S("lane"), dimNames[outer]) * + LinearLayout::identity1D(m / 8, S("register"), dimNames[outer]) * + LinearLayout::identity1D(n / (kWidth * 4), S("register"), + dimNames[inner]); + return ctaLayout; +} + std::optional NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto ctx = getContext(); + int rank = shape.size(); + + SmallVector tileShape; if (isAmpere()) { - return ampereMmaToLinearLayout(shape, *this); + // Ampere.getInstrShape() returns the tile shape + tileShape = SmallVector(getInstrShape()); + } else { + assert(isHopper()); + auto instrShapeMNK = getInstrShape(); + tileShape = SmallVector({instrShapeMNK[0], instrShapeMNK[1]}); } - if (isHopper()) { - return hopperMmaToLinearLayout(shape, *this); + // nvidiamma layout always assumes kWidth = 2 + constexpr auto kWidth = 2; + auto ctaLayout = + nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(*this), getRepOrder()); + + // The triton orders are defined on [dim0, dim1, ...], so we need to pass + // those dims Then, for some reason, operator* requires the orders to match + // so we need to reorder the outs to match + // FIXME(Lezcano). identityND should not take a dim name, as it's redundant. + // The order in triton assumes the standardDims, so it should + // use those. + ctaLayout *= identityND(S("warp"), getWarpsPerCTA(), getWarpOrder(), + standardOutDimNames(ctx, rank)) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef mmaWarpShape, + ArrayRef mmaWarpOrder, bool isA) { + // Let warpsPerCTAMma = {2, 2}, then + // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB + // assume warpOrder = {1, 0} + // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that + // the C is owned as per the following layout: + // C: 0 | 1 + // - | - + // 2 | 3 + // In order to be able to compute C, we need the following warp tiling of + // A and B: + // A: 0 1 | 0 1 B: 0 2 | 1 3 + // - - | - - - - | - - + // 2 3 | 2 3 0 2 | 1 3 + // In other words, we need to broadcast along K + auto rank = mmaWarpOrder.size(); + auto inner = isA ? rank - 1 : rank - 2; + auto outer = isA ? rank - 2 : rank - 1; + auto dimNames = standardOutDimNames(ctx, rank); + auto trivialShape = SmallVector(rank, 1); + LinearLayout warpLayout = + identityND(S("warp"), trivialShape, mmaWarpOrder, dimNames); + + // We have to broadcast along the inner dimension + // For A, when moving along M we go from 0 to 2. + // For B, when moving along N we go from 0 to 1. + // As such, choosing the order of A {1, 0}, gives us the correct broadcasting + // Same happens if the mmaWarpOrder is {0, 1}, like in Hopper + for (auto d : mmaWarpOrder) { + if (d == inner) { + warpLayout *= + LinearLayout::zeros1D(mmaWarpShape[d], S("warp"), dimNames[d]); + } else { + warpLayout *= + LinearLayout::identity1D(mmaWarpShape[d], S("warp"), dimNames[d]); + } + } + return warpLayout; +} + +LinearLayout nvidiaDotToLinearLayout(ArrayRef shape, + DotOperandEncodingAttr dot) { + int rank = shape.size(); + auto mma = cast(dot.getParent()); + int kWidth = dot.getKWidth(); + bool isA = dot.getOpIdx() == 0; + MLIRContext *ctx = mma.getContext(); + + SmallVector tileShape(rank, 1); + if (isA) { + tileShape[rank - 2] = 16; + tileShape[rank - 1] = kWidth * 8; + } else { + // Hopper takes the rhs via shared memory + assert(mma.isAmpere()); + tileShape[rank - 2] = kWidth * 8; + tileShape[rank - 1] = 8; + } + auto ctaLayout = + nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder()); + ctaLayout *= + warpsNvidiaDot(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), isA) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape); +} + +std::optional +DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto parent = getParent(); + if (auto mfmaLayout = llvm::dyn_cast(parent)) { + return mfmaDotToLinearLayout(*this, shape); + } else if (auto mma = mlir::dyn_cast(parent)) { + return nvidiaDotToLinearLayout(shape, *this); } return std::nullopt; } @@ -860,116 +928,6 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { return ret; } -LinearLayout ampereDotToLinearLayout(ArrayRef shape, - DotOperandEncodingAttr dot) { - // Note that, even though MMAv2 looks similar to this layout, they are just - // the same at a register and lane level. The warps treatment is different! - int rank = shape.size(); - auto mma = cast(dot.getParent()); - int kWidth = dot.getKWidth(); - bool isA = dot.getOpIdx() == 0; - - assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) || - (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8}))); - assert(mma.isAmpere()); - - MLIRContext *ctx = mma.getContext(); - - // The A and B operands are tiled in a kMajor fashion - auto kMajorOrder = dot.getRepOrder(); - assert(kMajorOrder == - getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true)); - - auto kMajorDims = - permuteDimNames(standardOutDimNames(ctx, rank), kMajorOrder); - // This agrees with the order of the elements, which means that we can share - // the code below for both A and B without having to perform any swaps - assert(getOrder(dot) == kMajorOrder); - - std::vector> registers; - std::vector> lanes; - int32_t i = 1; - // kWidth contiguous elements - while (i < kWidth) { - registers.push_back({i, 0}); - i *= 2; - } - // 4 threads per chunk - for (int j = 0; j < 2; j++) { - lanes.push_back({i, 0}); - i *= 2; - } - // 8 threads going down - lanes.push_back({0, 1}); - lanes.push_back({0, 2}); - lanes.push_back({0, 4}); - // 2 tiles in column-major order - // Just one if it's the B operand - if (isA) { - registers.push_back({0, 8}); - } - registers.push_back({i, 0}); - - LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}}, - ArrayRef(kMajorDims).take_front(2)); - - // Let warpsPerCTAMma = {2, 2}, then - // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB - // assume warpOrder = {0, 1} - // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that - // the C is owned as per the following layout: - // C: 0 | 1 - // - | - - // 2 | 3 - // In order to be able to compute C, we need the following warp tiling of - // A and B: - // A: 0 1 | 0 1 B: 0 2 | 1 3 - // - - | - - - - | - - - // 2 3 | 2 3 0 2 | 1 3 - // In particular, for A and B we need to broadcast along K - - assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true)); - auto warpsPerCTAMma = mma.getWarpsPerCTA(); - std::vector> warps; - if (isA) { - for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) { - warps.push_back({0, 0}); - } - for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) { - warps.push_back({0, i}); - } - } else { - for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) { - warps.push_back({0, i}); - } - for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) { - warps.push_back({0, 0}); - } - } - if (rank == 3) { - for (auto &w : warps) { - w.push_back(0); - } - } - - ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims); - - return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape); -} - -std::optional -DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { - auto parent = getParent(); - if (auto mfmaLayout = llvm::dyn_cast(parent)) { - return mfmaDotToLinearLayout(*this, shape); - } else if (auto mma = mlir::dyn_cast(parent)) { - if (mma.isAmpere()) { - return ampereDotToLinearLayout(shape, *this); - } - } - return std::nullopt; -} - std::optional toLinearLayout(ArrayRef shape, Attribute layout, std::optional elemBitWidth /*= std::nullopt*/) { diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index d662537ed72d..af6242b59662 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -41,10 +41,16 @@ class LinearLayoutConversionsTest : public ::testing::Test { CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape); } - DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, - ArrayRef warps) { - auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, {1, 0}); - return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth); + NvidiaMmaEncodingAttr mma(unsigned versionMaj, unsigned versionMin, + ArrayRef instrShape, + ArrayRef numWarps) { + auto ctaLayout = CTALayoutAttr::getDefault(&ctx, numWarps.size()); + return NvidiaMmaEncodingAttr::get(&ctx, versionMaj, versionMin, numWarps, + std::move(ctaLayout), instrShape); + } + + DotOperandEncodingAttr dot(Attribute parent, int idx, int kWidth) { + return DotOperandEncodingAttr::get(&ctx, idx, parent, /*kWidth=*/kWidth); } AMDMfmaEncodingAttr mfma(ArrayRef warps, unsigned mDim, @@ -391,8 +397,7 @@ TEST_F(LinearLayoutConversionsTest, MMAv2_Small3D) { } TEST_F(LinearLayoutConversionsTest, MMAv3_64x16) { - SmallVector, 4> instrShapes = { - {16, 16, 8}, {16, 16, 8}, {16, 8, 8}}; + SmallVector, 2> instrShapes = {{16, 16, 8}, {16, 8, 8}}; for (auto instrShape : instrShapes) { SCOPED_TRACE(triton::join(instrShape, ",")); EXPECT_EQ(toLinearLayout({64, 16}, mma(3, 0, instrShape, {4, 1}, {1, 1}, @@ -515,7 +520,8 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) { } TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { - EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1})), + auto parent = mma(2, 0, {16, 8}, {1, 1}); + EXPECT_EQ(toLinearLayout({16, 64}, dot(parent, 0, 8)), LinearLayout( { {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, @@ -524,7 +530,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1})), + EXPECT_EQ(toLinearLayout({64, 8}, dot(parent, 1, 8)), LinearLayout( { {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, @@ -536,8 +542,9 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { } TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { + auto parent = mma(2, 0, {16, 8}, {4, 1}); EXPECT_EQ( - toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1})), + toLinearLayout({128, 128}, dot(parent, 0, 8)), LinearLayout( { {S("register"), @@ -547,7 +554,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1})), + EXPECT_EQ(toLinearLayout({128, 64}, dot(parent, 1, 8)), LinearLayout( { {S("register"), @@ -567,7 +574,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1})), + EXPECT_EQ(toLinearLayout({64, 128}, dot(parent, 1, 8)), LinearLayout( { {S("register"), @@ -589,22 +596,122 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, DotMMAv2_3D) { + // We implement one that exercises all the paths + auto parent = mma(2, 0, {1, 16, 8}, {2, 4, 2}); + EXPECT_EQ(toLinearLayout({16, 128, 128}, dot(parent, 0, 8)), + LinearLayout( + { + {S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 8, 0}, + {0, 0, 32}, + {0, 0, 64}, + {0, 64, 0}, + {2, 0, 0}, + {4, 0, 0}, + {8, 0, 0}}}, + {S("lane"), + {{0, 0, 8}, {0, 0, 16}, {0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 16, 0}, {0, 32, 0}, {1, 0, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ(toLinearLayout({8, 128, 64}, dot(parent, 1, 8)), + LinearLayout( + { + {S("register"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 32, 0}, + {0, 64, 0}, + {0, 0, 16}, + {0, 0, 32}, + {2, 0, 0}, + {4, 0, 0}}}, + {S("lane"), + {{0, 8, 0}, {0, 16, 0}, {0, 0, 1}, {0, 0, 2}, {0, 0, 4}}}, + { + S("warp"), + {{0, 0, 8}, {0, 0, 0}, {0, 0, 0}, {1, 0, 0}}, + }, + {S("block"), {}}, + }, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv3_warp4_kwidth2) { + auto parent = mma(3, 0, {16, 16, 8}, {4, 1}); + auto dotOp = dot(parent, 0, 2); + + EXPECT_EQ(toLinearLayout({64, 16}, dotOp), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({128, 16}, dotOp), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}, {64, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({128, 32}, dotOp), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}, {0, 16}, {64, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv3_mixed_warp_kwidth4) { + // Testing dot with MMAv3 encoding for opIdx = 0 and kWidth = 4 + auto parent = mma(3, 0, {16, 16, 8}, {4, 2}); + auto dotOp = dot(parent, 0, 4); + + EXPECT_EQ(toLinearLayout({128, 64}, dotOp), + LinearLayout( + { + {S("register"), + {{0, 1}, {0, 2}, {8, 0}, {0, 16}, {0, 32}, {64, 0}}}, + {S("lane"), {{0, 4}, {0, 8}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}, {0, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, DotMMAv2_split_warp_kwidth8) { + auto parent = mma(2, 0, {16, 8}, {2, 2}); EXPECT_EQ( - toLinearLayout({32, 64}, dotMMAv2(0, 8, {2, 2})), + toLinearLayout({32, 64}, dot(parent, 0, 8)), LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, {S("warp"), {{0, 0}, {16, 0}}}, {S("block"), {}}}, {S("dim0"), S("dim1")})); EXPECT_EQ( - toLinearLayout({64, 16}, dotMMAv2(1, 8, {2, 2})), + toLinearLayout({64, 16}, dot(parent, 1, 8)), LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, {S("warp"), {{0, 8}, {0, 0}}}, {S("block"), {}}}, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(0, 8, {2, 2})), + EXPECT_EQ(toLinearLayout({64, 128}, dot(parent, 0, 8)), LinearLayout( {{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {32, 0}}}, @@ -613,7 +720,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_split_warp_kwidth8) { {S("block"), {}}}, {S("dim0"), S("dim1")})); EXPECT_EQ( - toLinearLayout({128, 32}, dotMMAv2(1, 8, {2, 2})), + toLinearLayout({128, 32}, dot(parent, 1, 8)), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 16}}}, {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},