diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 4f6aff739cdd..37d24ac929a9 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -216,8 +216,6 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); -bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); - // Return true if the src and dst layout match. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy); diff --git a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h index 22c8f9c8a330..8c7ab9831667 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -18,14 +18,6 @@ namespace gpu { SmallVector reorderValues(const SmallVector &values, Type inType, Type ouType); -SmallVector unpackI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter); - -SmallVector packI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter); - Type getElementType(Value value); class MultipleOperandsRange @@ -187,8 +179,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { for (auto operand : adaptor.getOperands()) { auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); - subOperands = unpackI32(subOperands, argTy, rewriter, loc, - this->getTypeConverter()); + subOperands = unpackI32s(subOperands, argTy, rewriter, loc, + this->getTypeConverter()); allOperands.resize(subOperands.size()); for (auto v : llvm::enumerate(subOperands)) allOperands[v.index()].push_back(v.value()); @@ -215,7 +207,7 @@ class 
ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { } resultVals = maybeDeduplicate(op, resultVals); resultVals = - packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); + packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); Value view = packLLElements(loc, this->getTypeConverter(), resultVals, rewriter, resultTy); rewriter.replaceOp(op, view); diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 29b8865c03ae..56a82d7cc0fb 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1388,6 +1388,67 @@ inline Value getStructFromSharedMemoryObject(Location loc, return llvmStruct; } +// For some reason, LLVM's NVPTX backend inserts unnecessary (?) integer +// instructions to pack & unpack sub-word integers. A workaround is to +// store the results of tensors with dot operand encodings in i32 to +// facilitate instructions such as `ldmatrix`. +// +// TODO: Confirm if the problem is still there. 
+inline bool requiresI32Conversion(Type type) { + auto tensorTy = dyn_cast(type); + if (!tensorTy) + return false; + auto dotOpEnc = dyn_cast(tensorTy.getEncoding()); + if (!dotOpEnc) + return false; + auto parent = dyn_cast(dotOpEnc.getParent()); + if (!(parent && parent.getVersionMajor() < 3)) + return false; + return true; +} + +inline SmallVector packI32s(const SmallVector &inValues, + Type type, RewriterBase &rewriter, + Location loc, + const LLVMTypeConverter *typeConverter) { + if (!requiresI32Conversion(type)) + return inValues; + Type eltTy = + typeConverter->convertType(cast(type).getElementType()); + + SmallVector outValues; + int vecWidth = 32 / eltTy.getIntOrFloatBitWidth(); + auto vecTy = vec_ty(eltTy, vecWidth); + for (int i = 0; i < inValues.size(); i += vecWidth) { + Value vec = undef(vecTy); + for (int j = 0; j < vecWidth; j++) { + vec = insert_element(vec, inValues[i + j], i32_val(j)); + } + outValues.push_back(bitcast(vec, i32_ty)); + } + return outValues; +} + +inline SmallVector unpackI32s(const SmallVector &inValues, + Type type, RewriterBase &rewriter, + Location loc, + const LLVMTypeConverter *typeConverter) { + if (!requiresI32Conversion(type)) + return inValues; + Type eltTy = + typeConverter->convertType(cast(type).getElementType()); + + SmallVector outValues; + for (auto v : inValues) { + auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth()); + auto vec = bitcast(v, vecTy); + for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) { + outValues.push_back(extract_element(vec, i32_val(i))); + } + } + return outValues; +} + inline SmallVector unpackLLElements(Location loc, Value llvmStruct, RewriterBase &rewriter) { assert(bool(llvmStruct) && "can not unpack null values"); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 9782be48d7d8..30ba11c31782 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -731,14 +731,14 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType 
dstTy) { } bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { - // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`, - // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully - // subsumed by the linear-layout checks. + // TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and + // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout + // checks. // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not // supported yet in Triton's backend. return !cvtReordersRegisters(srcTy, dstTy) && !isBlockedToDotShortcut(srcTy, dstTy) && - !isMmaToDotShortcut(srcTy, dstTy) && + !matchMmaV3AndDotOperandLayout(srcTy, dstTy) && !isMfmaToDotShortcut(srcTy, dstTy); } @@ -749,20 +749,6 @@ bool atomicNeedsSharedMemory(Value value) { return true; } -bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { - if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) - return true; - // dot_op = #mma - // when #mma = MmaEncoding - auto mmaLayout = dyn_cast(srcTy.getEncoding()); - auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); - return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 && - mmaLayout.getWarpsPerCTA()[1] == 1 && - dotOperandLayout.getOpIdx() == 0 && - dotOperandLayout.getParent() == mmaLayout && - !srcTy.getElementType().isF32(); -} - namespace { /// A data structure similar to SetVector but maintains diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index ea9091f4e19b..65ee8cc0023e 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -328,7 +328,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } else { // Cast 5. The two layouts are equivalent. We should probably remove // these in RemoveLayoutConversion. 
- rewriter.replaceOp(op, adaptor.getSrc()); + auto dstCvt = requiresI32Conversion(dstTy); + auto srcCvt = requiresI32Conversion(srcTy); + if (dstCvt || srcCvt) { + auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter); + inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(), + getTypeConverter()); + inVals = + packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter()); + auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals, + rewriter, op.getType()); + rewriter.replaceOp(op, res); + } else { + rewriter.replaceOp(op, adaptor.getSrc()); + } return success(); } } @@ -342,9 +355,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion StringAttr kRegister = str_attr("register"); assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter()); SmallVector outVals(numRegs); - for (int i = 0; i < outVals.size(); i++) { + for (int i = 0; i < numRegs; i++) { // Remove free masks from the register index // For example, if idx = 0b00111, and masks = 0b00100, then we get // 0b00011. 
It means that register 7 (0b111) has the same value as @@ -355,6 +371,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion : idx; outVals[i] = inVals[srcIdx]; } + outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter()); Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); @@ -386,9 +403,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion if (auto dotOperand = dyn_cast(layout)) { if (auto nvidiaMma = dyn_cast(dotOperand.getParent())) { - if (product(getCTAsPerCGA(nvidiaMma)) > 1) { - return false; - } if (useLegacyMMAConversion) { return false; } @@ -398,6 +412,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64; return largeKWidth && nvidiaMma.isAmpere(); } + return false; } if (isa(layout)) { return true; @@ -439,6 +454,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion inVals[it.index()] = ptrtoint(llvmElemTy, it.value()); } } + inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter()); // Pretty sure this is the identity function ATM // It'd be better to simply call `quotient({kBlock})` and @@ -458,22 +474,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } } - // FIXME [Dot LL] - // We know it's just for largeKWidth case in Ampere - // In this case, we need to pack the outputs into i32 - if (isa(dstTy.getEncoding())) { - auto concat = [&](Value a, Value b) { - return or_(zext(i32_ty, bitcast(a, i16_ty)), - shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); - }; - - SmallVector outVals32(outVals.size() / 2); - for (int i = 0; i < outVals32.size(); ++i) { - outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); - } - outVals = outVals32; - } - + outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter()); Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); diff --git 
a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 8ee166866974..470e8b32b540 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -103,51 +103,6 @@ SmallVector reorderValues(const SmallVector &values, Type inType, llvm_unreachable("unimplemented code path"); } -SmallVector unpackI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter) { - auto tensorTy = dyn_cast(srcTy); - if (!tensorTy) - return inValues; - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) - return inValues; - SmallVector outValues; - for (auto v : inValues) { - // cast i32 to appropriate eltType vector and extract elements - auto eltType = typeConverter->convertType(tensorTy.getElementType()); - auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth()); - auto vec = bitcast(v, vecType); - for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) { - outValues.push_back(extract_element(vec, i32_val(i))); - } - } - return outValues; -} - -SmallVector packI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter) { - auto tensorTy = dyn_cast(srcTy); - if (!tensorTy) - return inValues; - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) - return inValues; - SmallVector outValues; - auto eltType = typeConverter->convertType(tensorTy.getElementType()); - int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); - auto vecType = vec_ty(eltType, vecWidth); - for (int i = 0; i < inValues.size(); i += vecWidth) { - Value vec = undef(vecType); - for (int j = 0; j < vecWidth; j++) { - vec = insert_element(vec, inValues[i + j], i32_val(j)); - } - outValues.push_back(bitcast(vec, i32_ty)); - } - 
return outValues; -} - int getNumElementsPerThreads(Type type, const LLVMTypeConverter *typeConverter) { int numElemsPerThread = 1; @@ -500,7 +455,7 @@ struct ElementwiseInlineAsmOpConversion auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); unpackedOperands.push_back( - unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter())); + unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter())); } int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), @@ -560,10 +515,11 @@ struct ElementwiseInlineAsmOpConversion unpackedResults[i], /*inType=*/op->getOperand(0).getType(), /*ouType=*/op->getResult(i).getType()); } - auto packed = packI32(unpackedResults[i], op->getResult(i).getType(), - rewriter, loc, getTypeConverter()); - outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter, - op->getResult(i).getType())); + auto dstTy = op->getResult(i).getType(); + unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc, + getTypeConverter()); + outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i], + rewriter, op->getResult(i).getType())); } rewriter.replaceOp(op, outs); diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 1a0c115a9ecf..e2ed0228de8d 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -184,42 +184,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { SmallVector outVals = loadSharedToDistributed( dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo); - // FIXME [Dot LL] - // Ampere case - // In this case, we need to pack the outputs into i32 - if (auto dotOp = dyn_cast(dstTy.getEncoding())) { - if (auto parent = dyn_cast(dotOp.getParent())) { - if (parent.isAmpere()) { - if (elemLlvmTy.isInteger(8)) { - auto concat = [&](Value a1, Value a2, Value a3, Value a4) { - 
return or_( - or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))), - or_(shl(zext(i32_ty, a3), i32_val(16)), - shl(zext(i32_ty, a4), i32_val(24)))); - }; - SmallVector outVals32(outVals.size() / 4); - for (int i = 0; i < outVals32.size(); ++i) { - outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1], - outVals[4 * i + 2], outVals[4 * i + 3]); - } - outVals = outVals32; - } else { - assert(elemLlvmTy.isBF16() && "Unexpected element type"); - auto concat = [&](Value a, Value b) { - return or_(zext(i32_ty, bitcast(a, i16_ty)), - shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); - }; - - SmallVector outVals32(outVals.size() / 2); - for (int i = 0; i < outVals32.size(); ++i) { - outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); - } - outVals = outVals32; - } - } - } - } - + outVals = packI32s(outVals, dstTy, rewriter, loc, typeConverter); Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); rewriter.replaceOp(op, result); diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 56af4eaef8b9..6978ccfb2553 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -551,8 +551,8 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { } std::optional -dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, - ArrayRef shape) { +mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape) { // Current linear layout conversion for dot operand is only necessary to // enable LDS bypass for operand B in the MFMA dot path. 
To achieve @@ -895,7 +895,7 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); auto order = dot.getCTAOrder(); - assert(order[0] == 1 && order[1] == 0); + assert(order[0] == rank - 1 && order[1] == rank - 2); ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames); return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); @@ -903,13 +903,11 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, std::optional DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { - if (auto mfmaLayout = llvm::dyn_cast(getParent())) { - return dotOperandMfmaToLinearLayout(*this, shape); - } else if (auto mma = mlir::dyn_cast(getParent())) { - // FIXME [Dot LL] - // Do this unconditionally - auto largeKWidth = getKWidth() == 8; - if (mma.isAmpere() && largeKWidth) { + auto parent = getParent(); + if (auto mfmaLayout = llvm::dyn_cast(parent)) { + return mfmaDotToLinearLayout(*this, shape); + } else if (auto mma = mlir::dyn_cast(parent)) { + if (mma.getVersionMajor() == 2 && mma.getVersionMinor() == 0) { return ampereDotToLinearLayout(shape, *this); } } diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 6d8279795209..9f3d8fff491b 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -290,7 +290,7 @@ struct MMAV3UseRegOperand dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0); auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), dotOperandEnc); - if (!isMmaToDotShortcut(srcTy, newTy)) + if (!matchMmaV3AndDotOperandLayout(srcTy, newTy)) return failure(); Value newOperand = diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index a0719c974f9c..208f6b80bfe5 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -6,7 +6,7 @@ 
#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> #B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 2054853b30c1..df4f5ab01feb 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -5,7 +5,7 @@ #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 4a61ee4bc1b0..61dc9c7ef9a6 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -823,6 +823,110 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // 
CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<1x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1x16xf16, #mma> -> tensor<1x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#slice = #triton_gpu.slice<{dim = 0, parent = #mma}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_slice_mmav2_blocked_reg + tt.func @convert_layout_slice_mmav2_blocked_reg(%arg0: tensor<1xf16, #slice>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1xf16, #slice> -> tensor<1xf16, #blocked> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_0 + tt.func 
@convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_1 + tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_2 + tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_3 + tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: 
llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + tt.return + } +} + +// ----- + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 16]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 71fd3c0cd4e7..96289bbb2e47 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -631,46 +631,6 @@ struct ConvertLayoutOpConversion convertMMAV3To8BitsDotOperand(op, adaptor, rewriter); return success(); } - - if (isMmaToDotShortcut(srcTy, dstTy)) { - // get source values - auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - unsigned elems = getTotalElemsPerThread(srcTy); - Type elemTy = - this->getTypeConverter()->convertType(srcTy.getElementType()); - // for the destination type, we need to pack values together - // so they can be consumed by tensor core operations - SmallVector vecVals; - // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer - // instructions to pack & unpack sub-word integers. 
A workaround is to - // store the results of ldmatrix in i32 - auto elemSize = elemTy.getIntOrFloatBitWidth(); - if (auto intTy = dyn_cast(elemTy) && elemSize <= 16) { - auto fold = 32 / elemSize; - for (unsigned i = 0; i < elems; i += fold) { - Value val = i32_val(0); - for (unsigned j = 0; j < fold; j++) { - auto ext = - shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j)); - val = or_(i32_ty, val, ext); - } - vecVals.push_back(bitcast(val, i32_ty)); - } - } else { - unsigned vecSize = std::max(32 / elemSize, 1); - Type vecTy = vec_ty(elemTy, vecSize); - for (unsigned i = 0; i < elems; i += vecSize) { - Value packed = rewriter.create(loc, vecTy); - for (unsigned j = 0; j < vecSize; j++) - packed = insert_element(vecTy, packed, vals[i + j], i32_val(j)); - vecVals.push_back(bitcast(packed, i32_ty)); - } - } - Value view = - packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy); - rewriter.replaceOp(op, view); - return success(); - } return failure(); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index cf0ddc248dd1..36b14e270b27 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -70,10 +70,18 @@ struct DecomposeUnsupportedConversions : public mlir::triton::impl::DecomposeUnsupportedNVIDIAConversionsBase< DecomposeUnsupportedConversions> { void runOnOperation() override { + // FIXME [Dot LL] + // Remove the decomposeTensorCoreToDotLayoutConversion class entirely after + // we have enabled the new layout conversion for all the cases. 
+ auto nvidiaShortCutFn = [&](RankedTensorType srcTy, + RankedTensorType dstTy) { + return matchMmaV3AndDotOperandLayout(srcTy, dstTy) || + cvtReordersRegisters(srcTy, dstTy); + }; ModuleOp mod = getOperation(); triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, - isMmaToDotShortcut); + nvidiaShortCutFn); triton::gpu::decomposeBlockedToDotLayoutConversion(mod); mlir::RewritePatternSet patterns(&getContext()); diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index fd65233e5c6b..d4c15bbad03f 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -41,9 +41,9 @@ class LinearLayoutConversionsTest : public ::testing::Test { CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape); } - DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef warps, - ArrayRef order) { - auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order); + DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, + ArrayRef warps) { + auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, {1, 0}); return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth); } @@ -301,6 +301,19 @@ TEST_F(LinearLayoutConversionsTest, Blocked4D) { {S("dim0"), S("dim1"), S("dim2"), S("dim3")})); } +TEST_F(LinearLayoutConversionsTest, MMAv2_16x16) { + EXPECT_EQ(toLinearLayout({16, 16}, + mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, MMAv2_32x32) { EXPECT_EQ(toLinearLayout({32, 32}, mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), @@ -502,7 +515,7 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) { } 
TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { - EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1})), LinearLayout( { {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, @@ -511,7 +524,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1})), LinearLayout( { {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, @@ -524,7 +537,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { EXPECT_EQ( - toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})), + toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1})), LinearLayout( { {S("register"), @@ -534,7 +547,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1})), LinearLayout( { {S("register"), @@ -554,7 +567,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1})), LinearLayout( { {S("register"),