From e4d762384d398307b7cea63ee1a4d5ae3071c8a8 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 24 Oct 2024 11:30:20 -0400 Subject: [PATCH 01/12] Update --- include/triton/Analysis/Utility.h | 2 +- include/triton/Tools/LinearLayout.h | 7 +++++++ lib/Analysis/Utility.cpp | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index cb3e3d292efa..4f6aff739cdd 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -212,7 +212,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy); bool atomicNeedsSharedMemory(Value result); -bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT); +bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index c728cfbb32cf..47e3fca79bb1 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -679,6 +679,13 @@ class LinearLayout { // (i.e. every input bit affects the output). llvm::MapVector getFreeVariableMasks() const; + // Increase an input dimension without affecting the output dimension. The + // added free variables are mapped to 0, ensuring that the new input + // dimensions correspond directly to the existing output space. The function + // errors out if `newInDimSize` is less than the current size or the new size + // is not a power of 2. + LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const; + std::string toString() const; friend bool operator==(LinearLayout lhs, LinearLayout rhs); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 4915d7b1acda..c637258b4560 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) { (elemTy.isInteger(8) && version >= 2); } -bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { +bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { auto blockedLayout = dyn_cast(srcTy.getEncoding()); auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); if (blockedLayout == nullptr || dotOperandLayout == nullptr) From ce70be65d6213af944b9a7d415a9364853537f6f Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 24 Oct 2024 15:03:05 -0400 Subject: [PATCH 02/12] Update --- include/triton/Tools/LinearLayout.h | 2 + lib/Analysis/Utility.cpp | 46 +++++++++++- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 21 ++++-- lib/Tools/LinearLayout.cpp | 15 ++++ test/Conversion/tritongpu_to_llvm.mlir | 74 +++++++++++++++++++ unittest/Tools/LinearLayoutTest.cpp | 33 +++++++++ 6 files changed, 183 insertions(+), 8 deletions(-) diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 47e3fca79bb1..649f2766e793 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -413,6 +413,8 @@ class LinearLayout { bool isSurjective() const { return surjective; } + bool isEmpty() const { return bases.empty(); } + const BasesT &getBases() const { return bases; } // Get the pos'th basis vector for the inDim -> outDim mapping. diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index c637258b4560..1b597f655a2f 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -640,6 +640,31 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, return ans; } +namespace { + +// Count the number of not free registers in a layout. +// A register is free if it maps to duplicate values in the layout. +// For example, in the following layout, the number of not free registers is 8 +// +// register=1 -> 0 +// register=2 -> 1 +// register=4 -> 0 +// register=8 -> 2 +// register=16 -> 4 +// +int32_t countNotFreeRegs(int32_t numRegs, int32_t freeRegMasks) { + auto numRegsLog2 = llvm::Log2_32(numRegs); + auto numNotFreeRegsLog2 = 0; + for (auto i = 0; i < numRegsLog2; i++) { + if ((freeRegMasks & (1 << i)) == 0) { + numNotFreeRegsLog2++; + } + } + return 1 << numNotFreeRegsLog2; +} + +} // namespace + // We get the smallest submap of srcTy^{-1} * dstTy that is not the identity // under kBlock, kWarp or kLane (in that order). The idea here is that if we // have a transformation that's the identity on kBlock, we don't need to use @@ -655,13 +680,30 @@ std::optional minimalCvtLayout(RankedTensorType srcTy, toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); if (!(srcLayout.has_value() && dstLayout.has_value())) return std::nullopt; + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); + auto numSrcRegs = srcLayout->getInDimSize(kRegister); + auto numDstRegs = dstLayout->getInDimSize(kRegister); + auto srcFreeVarMasks = srcLayout->getFreeVariableMasks()[kRegister]; + auto dstFreeVarMasks = dstLayout->getFreeVariableMasks()[kRegister]; + auto numNotFreeSrcRegs = countNotFreeRegs(numSrcRegs, srcFreeVarMasks); + auto numNotFreeDstRegs = countNotFreeRegs(numDstRegs, dstFreeVarMasks); + // We need to ensure that the number of registers is the same in the source + // and destination layouts. We do this by padding the smaller layout with + // extra registers. + auto numRegs = std::max(numSrcRegs, numDstRegs); + auto srcLayoutWithFreeReg = srcLayout->resize(kRegister, numRegs); + auto dstLayoutWithFreeReg = dstLayout->resize(kRegister, numRegs); // comp describes the layout function to create dst from src. - LinearLayout comp = dstLayout->invertAndCompose(*srcLayout); + LinearLayout comp = + dstLayoutWithFreeReg.invertAndCompose(srcLayoutWithFreeReg); // We try to quotient by the largest subspace first auto dims = SmallVector{"block", "warp", "lane", "register"}; for (auto dim : dims) { auto quotient = comp.quotient(StringAttr::get(ctx, dim)); - if (!quotient.has_value()) { + if (!quotient.has_value() || quotient->isEmpty()) { break; } comp = *quotient; diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index a18b2cbc308c..d00aca24fab2 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -317,9 +317,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); } else if (llvm::is_contained(dims, str_attr("register"))) { + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); // Case 4. Transfer between values in the same thread, in which case we // simply reorder the elements of adaptor.getSrc(). - return transferWithinThread(op, *conversion, adaptor, rewriter); + return transferWithinThread(op, dstLayout, *conversion, adaptor, + rewriter); } else { // The two layouts are equivalent. We should probably remove these in // RemoveLayoutConversion. @@ -329,8 +332,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } LogicalResult - transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, - OpAdaptor adaptor, + transferWithinThread(ConvertLayoutOp op, const LinearLayout &dstLayout, + const LinearLayout &conversion, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); @@ -339,9 +342,15 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); SmallVector outVals; - outVals.resize(conversion.getInDimSize(kRegister)); - for (int i = 0; i < conversion.getInDimSize(kRegister); i++) { - auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; + outVals.resize(dstLayout.getInDimSize(kRegister)); + auto masks = dstLayout.getFreeVariableMasks()[kRegister]; + for (int i = 0; i < outVals.size(); i++) { + // Remove free masks from the register index + // For example, if idx = 0b00111, and masks = 0b00100, then we get + // 0b00011. It means that register 7 (0b111) has the same value as + // register 3 (0b011). + auto idx = i & (~masks); + auto srcIdx = conversion.apply({{kRegister, idx}}).begin()->second; outVals[i] = inVals[srcIdx]; } Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index bf017f8c6463..4319d1f086dd 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -1016,6 +1016,21 @@ bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const { return true; } +LinearLayout LinearLayout::resize(StringAttr inDim, + int32_t newInDimSize) const { + BasesT bases = getBases(); + assert(bases.contains(inDim) && "inDim not in layout"); + assert(llvm::isPowerOf2_32(newInDimSize) && + "newInDimSize must be a power of 2"); + assert(newInDimSize >= getInDimSize(inDim) && + "newInDimSize must be >= old size"); + auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim); + for (int i = 0; i < numFreeVariables; i++) { + bases[inDim].push_back(std::vector(getNumOutDims(), 0)); + } + return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames())); +} + std::string LinearLayout::toString() const { // Start with a newline because we print out a bulleted list; it doesn't // make sense for the first line of this list to be on the same line as diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index e1a2ec68bd5a..4a61ee4bc1b0 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -847,6 +847,80 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_0 + tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_1 + tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_2 + tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_3 + tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + tt.return + } +} + +// ----- + #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp index f006447002ef..897172fd6d34 100644 --- a/unittest/Tools/LinearLayoutTest.cpp +++ b/unittest/Tools/LinearLayoutTest.cpp @@ -747,6 +747,39 @@ TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) { ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value()); } +TEST_F(LinearLayoutTest, Resize) { + auto init = LinearLayout( + { + {S("in0"), {{0, 1}, {0, 2}}}, + {S("in1"), {{1, 0}, {2, 0}}}, + {S("in2"), {}}, + }, + {S("dim0"), S("dim1")}); + EXPECT_EQ(init.resize(S("in0"), 8), + LinearLayout( + { + {S("in0"), {{0, 1}, {0, 2}, {0, 0}}}, + {S("in1"), {{1, 0}, {2, 0}}}, + {S("in2"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(init.resize(S("in0"), 4), LinearLayout( + { + {S("in0"), {{0, 1}, {0, 2}}}, + {S("in1"), {{1, 0}, {2, 0}}}, + {S("in2"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(init.resize(S("in1"), 8), + LinearLayout( + { + {S("in0"), {{0, 1}, {0, 2}}}, + {S("in1"), {{1, 0}, {2, 0}, {0, 0}}}, + {S("in2"), {}}, + }, + {S("dim0"), S("dim1")})); +} + } // anonymous namespace } // namespace mlir::triton From f6f0f3a4a2efae3a4c849d38f724ad5bbd00b934 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 24 Oct 2024 20:56:39 -0400 Subject: [PATCH 03/12] Update --- include/triton/Tools/LinearLayout.h | 2 - lib/Analysis/Utility.cpp | 37 +-------- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 75 ++++++++++++------- 3 files changed, 51 insertions(+), 63 deletions(-) diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 649f2766e793..47e3fca79bb1 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -413,8 +413,6 @@ class LinearLayout { bool isSurjective() const { return surjective; } - bool isEmpty() const { return bases.empty(); } - const BasesT &getBases() const { return bases; } // Get the pos'th basis vector for the inDim -> outDim mapping. diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 1b597f655a2f..e6a73b996cd0 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -640,31 +640,6 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, return ans; } -namespace { - -// Count the number of not free registers in a layout. -// A register is free if it maps to duplicate values in the layout. -// For example, in the following layout, the number of not free registers is 8 -// -// register=1 -> 0 -// register=2 -> 1 -// register=4 -> 0 -// register=8 -> 2 -// register=16 -> 4 -// -int32_t countNotFreeRegs(int32_t numRegs, int32_t freeRegMasks) { - auto numRegsLog2 = llvm::Log2_32(numRegs); - auto numNotFreeRegsLog2 = 0; - for (auto i = 0; i < numRegsLog2; i++) { - if ((freeRegMasks & (1 << i)) == 0) { - numNotFreeRegsLog2++; - } - } - return 1 << numNotFreeRegsLog2; -} - -} // namespace - // We get the smallest submap of srcTy^{-1} * dstTy that is not the identity // under kBlock, kWarp or kLane (in that order). The idea here is that if we // have a transformation that's the identity on kBlock, we don't need to use @@ -686,24 +661,20 @@ std::optional minimalCvtLayout(RankedTensorType srcTy, StringAttr kBlock = StringAttr::get(ctx, "block"); auto numSrcRegs = srcLayout->getInDimSize(kRegister); auto numDstRegs = dstLayout->getInDimSize(kRegister); - auto srcFreeVarMasks = srcLayout->getFreeVariableMasks()[kRegister]; - auto dstFreeVarMasks = dstLayout->getFreeVariableMasks()[kRegister]; - auto numNotFreeSrcRegs = countNotFreeRegs(numSrcRegs, srcFreeVarMasks); - auto numNotFreeDstRegs = countNotFreeRegs(numDstRegs, dstFreeVarMasks); // We need to ensure that the number of registers is the same in the source // and destination layouts. We do this by padding the smaller layout with // extra registers. auto numRegs = std::max(numSrcRegs, numDstRegs); - auto srcLayoutWithFreeReg = srcLayout->resize(kRegister, numRegs); - auto dstLayoutWithFreeReg = dstLayout->resize(kRegister, numRegs); + auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs); + auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs); // comp describes the layout function to create dst from src. LinearLayout comp = - dstLayoutWithFreeReg.invertAndCompose(srcLayoutWithFreeReg); + dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs); // We try to quotient by the largest subspace first auto dims = SmallVector{"block", "warp", "lane", "register"}; for (auto dim : dims) { auto quotient = comp.quotient(StringAttr::get(ctx, dim)); - if (!quotient.has_value() || quotient->isEmpty()) { + if (!quotient.has_value()) { break; } comp = *quotient; diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index d00aca24fab2..ce3acabe59b5 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -288,69 +288,88 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return rewriter.notifyMatchFailure( op, "NYI. srcTy and/or dstTy don't implement LLs yet"); } + LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); assert(to_vector(conversion->getInDimNames()) == to_vector(conversion->getOutDimNames())); auto dims = conversion->getInDimNames(); - if (llvm::is_contained(dims, str_attr("block"))) { + if (llvm::is_contained(dims, kBlock)) { // Case 1: Transfer between values in different CTAs. // This requires moving values through distributed shared memory. return rewriter.notifyMatchFailure( op, "NYI: Transfer between different CTAs"); - } else if (llvm::is_contained(dims, str_attr("warp"))) { + } else if (llvm::is_contained(dims, kWarp)) { // Case 2: Transfer between values in the same CTA, in which case we move // values through shared memory. - LinearLayout srcLayout = - *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - LinearLayout dstLayout = - *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); - } else if (llvm::is_contained(dims, str_attr("lane"))) { + } else if (llvm::is_contained(dims, kLane)) { // Case 3. Transfer between values in the same warp, in which case we try // to move values using warp shuffles, though if the pattern is // complicated enough we may fall back to using shared memory // TODO(Keren): implement warp shuffle instead of using the general // approach that uses shared memory - LinearLayout srcLayout = - *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - LinearLayout dstLayout = - *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); - } else if (llvm::is_contained(dims, str_attr("register"))) { - LinearLayout dstLayout = - *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + } else if (llvm::is_contained(dims, kRegister)) { // Case 4. Transfer between values in the same thread, in which case we // simply reorder the elements of adaptor.getSrc(). - return transferWithinThread(op, dstLayout, *conversion, adaptor, - rewriter); + return transferWithinThread(op, *conversion, adaptor, rewriter); + } else if (dstLayout.getInDimSize(kRegister) > + srcLayout.getInDimSize(kRegister)) { + // Case 5. We don't need to transfer from one layout to another, but + // need to replicate the values. + return replicateWithinThread(op, dstLayout, adaptor, rewriter); + } else { - // The two layouts are equivalent. We should probably remove these in - // RemoveLayoutConversion. + // Cast 6. The two layouts are equivalent. We should probably remove + // these in RemoveLayoutConversion. rewriter.replaceOp(op, adaptor.getSrc()); return success(); } } LogicalResult - transferWithinThread(ConvertLayoutOp op, const LinearLayout &dstLayout, - const LinearLayout &conversion, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + replicateWithinThread(ConvertLayoutOp op, const LinearLayout &dstLayout, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); StringAttr kRegister = str_attr("register"); assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - SmallVector outVals; - outVals.resize(dstLayout.getInDimSize(kRegister)); + SmallVector outVals(dstLayout.getInDimSize(kRegister)); auto masks = dstLayout.getFreeVariableMasks()[kRegister]; for (int i = 0; i < outVals.size(); i++) { - // Remove free masks from the register index - // For example, if idx = 0b00111, and masks = 0b00100, then we get - // 0b00011. It means that register 7 (0b111) has the same value as - // register 3 (0b011). auto idx = i & (~masks); - auto srcIdx = conversion.apply({{kRegister, idx}}).begin()->second; + outVals[i] = inVals[idx]; + } + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + + LogicalResult + transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + StringAttr kRegister = str_attr("register"); + assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); + + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector outVals(conversion.getInDimSize(kRegister)); + for (int i = 0; i < outVals.size(); i++) { + auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; outVals[i] = inVals[srcIdx]; } Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, From f836abedb62c1f2d945cda25fc2d24b083c29c0a Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 24 Oct 2024 20:58:20 -0400 Subject: [PATCH 04/12] Update --- lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index ce3acabe59b5..95102b55ae3e 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -323,10 +323,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return transferWithinThread(op, *conversion, adaptor, rewriter); } else if (dstLayout.getInDimSize(kRegister) > srcLayout.getInDimSize(kRegister)) { - // Case 5. We don't need to transfer from one layout to another, but - // need to replicate the values. + // Case 5: `dims` is empty, so no layout conversion is required, but the + // values need to be replicated. return replicateWithinThread(op, dstLayout, adaptor, rewriter); - } else { // Cast 6. The two layouts are equivalent. We should probably remove // these in RemoveLayoutConversion. From e813456795630688f257f1985997b05c9d8f67a1 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 24 Oct 2024 21:00:31 -0400 Subject: [PATCH 05/12] Update --- lib/Analysis/Utility.cpp | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index e6a73b996cd0..9782be48d7d8 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -661,9 +661,34 @@ std::optional minimalCvtLayout(RankedTensorType srcTy, StringAttr kBlock = StringAttr::get(ctx, "block"); auto numSrcRegs = srcLayout->getInDimSize(kRegister); auto numDstRegs = dstLayout->getInDimSize(kRegister); - // We need to ensure that the number of registers is the same in the source - // and destination layouts. We do this by padding the smaller layout with - // extra registers. + // The `invertAndCompose` function will generate a layout that is injective + // by assigning new output dimensions to free variables. For instance, + // consider a scenario where `srcLayout` has a free variable in the lane + // dimension, while `dstLayout` has two free variables in the lane + // dimension and also a larger number of registers. + // The injective form of `srcLayout` will add only a single additional row + // to the transformation matrix, whereas the injective form of `dstLayout` + // will add two additional rows. This discrepancy causes misleading results + // because the matrices end up with a different number of rows. + // + // Take `dstLayout ⋅ srcLayout^-1` as an example: + // + // - `injective(dstLayout)`: [n, m] → [n + 2, m] + // - `injective(srcLayout)`: [n, m] → [n + 1, m] + // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1] + // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n + + // 1] → [n + 2, n + 1] + // + // Here, the `(n + 1)`-th row added by `dstLayout` represents the free + // variable in registers, and the `(n + 2)`-th row represents the free + // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout` + // represents the free variable in lanes. As a result, the `(n + 1)`-th row + // in two layouts do not correspond to the same free variable. + // + // To address this issue, we pad the free variables in `srcLayout` and + // `dstLayout` to ensure they have the same number of registers. This + // guarantees that the resulting matrices have the same number of rows, + // ensuring consistency in the composition process. auto numRegs = std::max(numSrcRegs, numDstRegs); auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs); auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs); From 3c072a6a4acf675e7240a7a8aa60950e5b528018 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 24 Oct 2024 21:05:30 -0400 Subject: [PATCH 06/12] Update --- lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 95102b55ae3e..2824cc8db255 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -347,6 +347,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion SmallVector outVals(dstLayout.getInDimSize(kRegister)); auto masks = dstLayout.getFreeVariableMasks()[kRegister]; for (int i = 0; i < outVals.size(); i++) { + // Remove free masks from the register index + // For example, if idx = 0b00111, and masks = 0b00100, then we get + // 0b00011. It means that register 7 (0b111) has the same value as + // register 3 (0b011). auto idx = i & (~masks); outVals[i] = inVals[idx]; } From c675401ec0579fbb86f48c150548360a4df7df92 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 24 Oct 2024 22:29:47 -0400 Subject: [PATCH 07/12] Update --- lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 2824cc8db255..d43bd3a2c09a 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -321,7 +321,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // Case 4. Transfer between values in the same thread, in which case we // simply reorder the elements of adaptor.getSrc(). return transferWithinThread(op, *conversion, adaptor, rewriter); - } else if (dstLayout.getInDimSize(kRegister) > + } else if (dstLayout.getInDimSize(kRegister) != srcLayout.getInDimSize(kRegister)) { // Case 5: `dims` is empty, so no layout conversion is required, but the // values need to be replicated. @@ -371,6 +371,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); SmallVector outVals(conversion.getInDimSize(kRegister)); + auto masks = dstLayout.getFreeVariableMasks()[kRegister]; for (int i = 0; i < outVals.size(); i++) { auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; outVals[i] = inVals[srcIdx]; From 6f79395c0a8d6dc8bc1693abb872731ddc474b48 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 24 Oct 2024 22:36:54 -0400 Subject: [PATCH 08/12] Update --- lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index d43bd3a2c09a..7dbbd2dd6c80 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -371,7 +371,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); SmallVector outVals(conversion.getInDimSize(kRegister)); - auto masks = dstLayout.getFreeVariableMasks()[kRegister]; for (int i = 0; i < outVals.size(); i++) { auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; outVals[i] = inVals[srcIdx]; From 1d9243c06e41c02233b7aef15491a932cf6671c5 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 24 Oct 2024 22:55:43 -0400 Subject: [PATCH 09/12] Update --- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 36 ++++++------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 7dbbd2dd6c80..fd33da7e5b06 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -320,12 +320,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } else if (llvm::is_contained(dims, kRegister)) { // Case 4. Transfer between values in the same thread, in which case we // simply reorder the elements of adaptor.getSrc(). - return transferWithinThread(op, *conversion, adaptor, rewriter); + return transferWithinThread(op, dstLayout, conversion, adaptor, rewriter); } else if (dstLayout.getInDimSize(kRegister) != srcLayout.getInDimSize(kRegister)) { // Case 5: `dims` is empty, so no layout conversion is required, but the // values need to be replicated. - return replicateWithinThread(op, dstLayout, adaptor, rewriter); + return transferWithinThread(op, dstLayout, std::nullopt, adaptor, + rewriter); } else { // Cast 6. The two layouts are equivalent. We should probably remove // these in RemoveLayoutConversion. @@ -335,9 +336,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } LogicalResult - replicateWithinThread(ConvertLayoutOp op, const LinearLayout &dstLayout, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + transferWithinThread(ConvertLayoutOp op, const LinearLayout &dstLayout, + std::optional conversion, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); StringAttr kRegister = str_attr("register"); @@ -352,27 +354,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // 0b00011. It means that register 7 (0b111) has the same value as // register 3 (0b011). auto idx = i & (~masks); - outVals[i] = inVals[idx]; - } - Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, - op.getType()); - rewriter.replaceOp(op, result); - return success(); - } - - LogicalResult - transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - MLIRContext *ctx = op.getContext(); - auto loc = op.getLoc(); - StringAttr kRegister = str_attr("register"); - assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); - - auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - SmallVector outVals(conversion.getInDimSize(kRegister)); - for (int i = 0; i < outVals.size(); i++) { - auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; + auto srcIdx = conversion + ? conversion->apply({{kRegister, idx}}).begin()->second + : idx; outVals[i] = inVals[srcIdx]; } Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, From 527d14707f3e9e7b0bfd735ad495a8a8ffc195dc Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 25 Oct 2024 09:58:39 -0400 Subject: [PATCH 10/12] Update --- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index fd33da7e5b06..73f1f036d788 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -317,18 +317,16 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // TODO(Keren): implement warp shuffle instead of using the general // approach that uses shared memory return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); - } else if (llvm::is_contained(dims, kRegister)) { + } else if (llvm::is_contained(dims, kRegister) || + dstLayout.getInDimSize(kRegister) != + srcLayout.getInDimSize(kRegister)) { // Case 4. Transfer between values in the same thread, in which case we // simply reorder the elements of adaptor.getSrc(). - return transferWithinThread(op, dstLayout, conversion, adaptor, rewriter); - } else if (dstLayout.getInDimSize(kRegister) != - srcLayout.getInDimSize(kRegister)) { - // Case 5: `dims` is empty, so no layout conversion is required, but the - // values need to be replicated. - return transferWithinThread(op, dstLayout, std::nullopt, adaptor, - rewriter); + return transferWithinThread( + op, dstLayout.getFreeVariableMasks()[kRegister], + dstLayout.getInDimSize(kRegister), conversion, adaptor, rewriter); } else { - // Cast 6. The two layouts are equivalent. We should probably remove + // Cast 5. The two layouts are equivalent. We should probably remove // these in RemoveLayoutConversion. rewriter.replaceOp(op, adaptor.getSrc()); return success(); @@ -336,7 +334,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } LogicalResult - transferWithinThread(ConvertLayoutOp op, const LinearLayout &dstLayout, + transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs, std::optional conversion, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -346,14 +344,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - SmallVector outVals(dstLayout.getInDimSize(kRegister)); - auto masks = dstLayout.getFreeVariableMasks()[kRegister]; + SmallVector outVals(numRegs); for (int i = 0; i < outVals.size(); i++) { // Remove free masks from the register index // For example, if idx = 0b00111, and masks = 0b00100, then we get // 0b00011. It means that register 7 (0b111) has the same value as // register 3 (0b011). - auto idx = i & (~masks); + auto idx = i & (~regMasks); auto srcIdx = conversion ? conversion->apply({{kRegister, idx}}).begin()->second : idx; From 05ab549f9dc436a4c5cf92fcd9a884a3205bc30e Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 25 Oct 2024 10:01:09 -0400 Subject: [PATCH 11/12] Update --- lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 73f1f036d788..7966cd695c0d 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -335,8 +335,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion LogicalResult transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs, - std::optional conversion, - OpAdaptor adaptor, + const LinearLayout &conversion, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); @@ -351,8 +350,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // 0b00011. It means that register 7 (0b111) has the same value as // register 3 (0b011). auto idx = i & (~regMasks); - auto srcIdx = conversion - ? conversion->apply({{kRegister, idx}}).begin()->second + auto srcIdx = conversion.hasInDim(kRegister) + ? conversion.apply({{kRegister, idx}}).begin()->second : idx; outVals[i] = inVals[srcIdx]; } From eda45bafc2dc962df99f9fd88b660ac2c7116911 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 25 Oct 2024 10:08:47 -0400 Subject: [PATCH 12/12] Update --- lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 7966cd695c0d..ea9091f4e19b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -324,7 +324,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // simply reorder the elements of adaptor.getSrc(). return transferWithinThread( op, dstLayout.getFreeVariableMasks()[kRegister], - dstLayout.getInDimSize(kRegister), conversion, adaptor, rewriter); + dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter); } else { // Cast 5. The two layouts are equivalent. We should probably remove // these in RemoveLayoutConversion.