diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index bb323f592cff..c312fb2a8c84 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -178,11 +178,11 @@ unsigned getNumCTAs(Attribute layout); // len(shape) == rank. SmallVector getMatrixOrder(unsigned rank, bool rowMajor); -// Return the order that represents that the dot operand is in kMajor +// Return the order that represents that the dot operand is in kContig // (contiguous in the inner dimension) or it's contiguous on the outer // dimension. SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, - bool kMajor); + bool kContig); bool isExpensiveCat(CatOp cat, Attribute targetEncoding); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 9153db3309fa..b793c78ce163 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -227,15 +227,15 @@ SmallVector getMatrixOrder(unsigned rank, bool rowMajor) { } SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, - bool kMajor) { - // kMajor: if true, the matrix is fastest-running on k, + bool kContig) { + // kContig: if true, the matrix is fastest-running on k, // otherwise it is on m (resp. n) // opIdx=0: [batch, m, k] if rank == 3 else [m, k] // opIdx=1: [batch, k, n] if rank == 3 else [k, n] // batch (if rank == 3) is always the slowest running dimension assert(rank == 2 || rank == 3); assert(opIdx == 0 || opIdx == 1); - auto rowMajor = bool(opIdx) != kMajor; + auto rowMajor = bool(opIdx) != kContig; return getMatrixOrder(rank, rowMajor); } @@ -268,7 +268,7 @@ SmallVector getOrder(Attribute layout) { } if (auto dotLayout = dyn_cast(layout)) { auto rank = dotLayout.getWarpsPerCTA().size(); - return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true); + return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kContig*/ true); } if (auto sliceLayout = dyn_cast(layout)) { SmallVector parentOrder = getOrder(sliceLayout.getParent()); @@ -987,7 +987,7 @@ SmallVector DotOperandEncodingAttr::getWarpOrder() const { } SmallVector DotOperandEncodingAttr::getThreadOrder() const { return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), - /*kMajor*/ true); + /*kContig*/ true); } LogicalResult DotOperandEncodingAttr::verify( @@ -1959,7 +1959,7 @@ SmallVector AMDMfmaEncodingAttr::getRepOrder() const { SmallVector AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const { auto rank = getWarpsPerCTA().size(); - return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); + return getOrderForDotOperand(opIdx, rank, /*kContig*/ true); } SmallVector @@ -2027,7 +2027,7 @@ SmallVector AMDWmmaEncodingAttr::getRepOrder() const { SmallVector AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const { auto rank = getWarpsPerCTA().size(); - return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); + return getOrderForDotOperand(opIdx, rank, /*kContig*/ true); } SmallVector @@ -2219,7 +2219,7 @@ SmallVector NvidiaMmaEncodingAttr::getSizePerThread() const { SmallVector NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { auto rank = getWarpsPerCTA().size(); - return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); + return getOrderForDotOperand(opIdx, rank, /*kContig*/ true); } SmallVector diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp index cd8f0db2ca3c..4206705c5175 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp @@ -74,7 +74,7 @@ Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc, return base; } -bool isKMajor(llvm::ArrayRef order, int opIdx) { +bool isKContig(llvm::ArrayRef order, int opIdx) { auto rank = order.size(); int kdim = opIdx == 0 ? rank - 1 : rank - 2; return order[0] == kdim; @@ -102,9 +102,9 @@ bool isSwizzlePatternFitsIntoBlock(const SharedEncodingAttr sharedLayout, const auto swizzleSlowDimSize = sharedLayout.getMaxPhase() * sharedLayout.getPerPhase(); const auto swizzlePatternSizeK = - isKMajor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize; + isKContig(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize; const auto swizzlePatternSizeNonK = - !isKMajor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize; + !isKContig(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize; const auto blockSizeK = mfmaInstrK * reps[reps.size() - 1]; const auto blockSizeNonK = mfmaInstrNonK * warpsPerBlockNonK; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h index e691dbbc437f..fe4613606dfc 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h @@ -36,7 +36,7 @@ Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc, const SharedMemoryObject &smemObj, ArrayRef strides); -bool isKMajor(llvm::ArrayRef order, int opIdx); +bool isKContig(llvm::ArrayRef order, int opIdx); using computeTensorElemMappingInBlockT = std::function>( diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index e9fcfffbb86c..e4a683d870f8 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -279,7 +279,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Value smemBase; auto smemStrides = smemObj.getStrides(aTensorTy, loc, rewriter); bool isFastPath = - !AMD::isKMajor(order, opIdx) && !hasSwizzleEnabled(sharedLayout); + !AMD::isKContig(order, opIdx) && !hasSwizzleEnabled(sharedLayout); if (isFastPath) { // fast path handles tensors that are not k-major and have swizzling // disabled, in which case offsets computation can be simplified diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index 8af2bb926648..c5ec00097d93 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -498,13 +498,13 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, // getValuesFromDotOperandLayoutStruct as both a and b are K-major assert(dotOpA.getRepOrder() == getOrderForDotOperand(dotOpA.getOpIdx(), aShapePerCTA.size(), - /*kMajor=*/true)); + /*kContig=*/true)); auto ha = getValuesFromDotOperandLayoutStruct( typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy); assert(dotOpB.getRepOrder() == getOrderForDotOperand(dotOpB.getOpIdx(), bShapePerCTA.size(), - /*kMajor=*/true)); + /*kContig=*/true)); auto hb = getValuesFromDotOperandLayoutStruct( typeConverter, loc, rewriter, loadedB, repBatch, repN, repK, bTensorTy);