diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index a9b49448c1d0..f2715043d799 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -116,9 +116,7 @@ SmallVector getCTAOrder(Attribute layout); * (3) In the implementation of emitIndices, ShapePerCTATile will * be replicated or wrapped to fit ShapePerCTA. */ -SmallVector -getShapePerCTATile(Attribute layout, - ArrayRef tensorShape = ArrayRef()); +SmallVector getShapePerCTATile(Attribute layout); SmallVector getShapePerCTA(ArrayRef CTASplitNum, ArrayRef shape); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 07514126d2a4..e6be2f83323f 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -502,11 +502,6 @@ We call each individual tile "rep". "SmallVector", "getCTASplitNum">, - InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA", - "SmallVector", - "getShapePerCTATile", - (ins "ArrayRef":$tensorShape)>, - InterfaceMethod<"Gets the number of contiguous elements per thread.", "SmallVector", "getContigPerThread">, @@ -565,7 +560,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, SmallVector getThreadOrder() const; SmallVector getSizePerThread() const; - SmallVector getShapePerCTATile(ArrayRef tensorShape = ArrayRef()) const; std::optional toLinearLayout(ArrayRef shape) const; }]; @@ -765,13 +759,6 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { "bool", "supportReduction">, - InterfaceMethod<"Return shape per CTA.", - "SmallVector", - "getShapePerCTATileForOperand", - (ins "ArrayRef":$tensorShape, - "int":$kWidth, - "int":$opIdx)>, - InterfaceMethod<"Return size per thread for dot operands.", "SmallVector", "getSizePerThreadForOperand", @@ -900,7 +887,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129, return true; } SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; - SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; SmallVector getInstrShapeForOperand(int kWidth, int opIdx) const; SmallVector getRepForOperand(ArrayRef operandShape, int kWidth, int opIdx) const; @@ -1008,7 +994,6 @@ Row | warp 0 warp 2 return true; } SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; - SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; SmallVector getElemsPerInstrForOperands() const; SmallVector getRepForOperand(ArrayRef operandShape, @@ -1140,7 +1125,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: return false; }; SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; - SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; SmallVector getContigPerThread() { assert(isAmpere() || isHopper()); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 02269c9aacf5..53897578aa4a 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -41,10 +41,8 @@ static SmallVector getRepShapeForCvt(RankedTensorType srcTy, auto srcShapePerCTA = gpu::getShapePerCTA(srcTy); auto dstShapePerCTA = gpu::getShapePerCTA(dstTy); - auto srcShapePerCTATile = - gpu::getShapePerCTATile(srcLayout, srcTy.getShape()); - auto dstShapePerCTATile = - gpu::getShapePerCTATile(dstLayout, dstTy.getShape()); + auto srcShapePerCTATile = gpu::getShapePerCTATile(srcLayout); + auto dstShapePerCTATile = gpu::getShapePerCTATile(dstLayout); assert(srcTy.getRank() == dstTy.getRank() && "src and dst must have the same rank"); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index a802d62ace1f..e48cfca441d3 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -174,8 +174,8 @@ struct ConvertLayoutOpConversion SmallVector outNumCTAsEachRep(rank); SmallVector inNumCTAs(rank); SmallVector outNumCTAs(rank); - auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); - auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout); auto shapePerCTA = getShapePerCTA(srcLayout, shape); for (unsigned d = 0; d < rank; ++d) { diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 26dc8a537973..088dbd997602 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -421,7 +421,7 @@ struct ReduceOpConversion auto resultIndices = emitIndices(loc, rewriter, targetInfo, resultLayout, resultTy, true); auto resultShape = resultTy.getShape(); - auto resultCTATile = getShapePerCTATile(resultLayout, resultShape); + auto resultCTATile = getShapePerCTATile(resultLayout); assert(resultIndices.size() == resultElems); SmallVector resultVals(resultElems); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index dce2a6034f6f..721c8dd10c09 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -201,12 +201,25 @@ SmallVector getUniqueContigPerThread(Attribute layout, } return ret; } - -SmallVector getShapePerCTATile(Attribute layout, - ArrayRef tensorShape) { +SmallVector getShapePerCTATile(Attribute layout) { if (auto distributedLayout = mlir::dyn_cast(layout)) { - return distributedLayout.getShapePerCTATile(tensorShape); + auto sizePerThread = distributedLayout.getSizePerThread(); + auto threadsPerWarp = distributedLayout.getThreadsPerWarp(); + // ThreadsPerWarp does not align with this function for slice layout + if (auto sliceLayout = mlir::dyn_cast(layout)) { + threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent()); + threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim()); + } + auto warpsPerCTA = distributedLayout.getWarpsPerCTA(); + assert(sizePerThread.size() == threadsPerWarp.size() && + sizePerThread.size() == warpsPerCTA.size()); + SmallVector shape; + for (auto [size, thread, warp] : + llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) { + shape.push_back(size * thread * warp); + } + return shape; } else { llvm::report_fatal_error("getShapePerCTATile not implemented"); return SmallVector(); @@ -678,14 +691,6 @@ SmallVector BlockedEncodingAttr::getThreadOrder() const { SmallVector BlockedEncodingAttr::getSizePerThread() const { return SmallVector(getSizePerThread__()); } -SmallVector -BlockedEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - SmallVector shape; - for (unsigned d = 0, n = getOrder().size(); d < n; ++d) - shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] * - getWarpsPerCTA()[d]); - return shape; -} template SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const { @@ -787,12 +792,6 @@ SmallVector SliceEncodingAttr::getSizePerThread() const { sizePerThread.erase(sizePerThread.begin() + getDim()); return sizePerThread; } -SmallVector -SliceEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - SmallVector shape = ::getShapePerCTATile(getParent(), tensorShape); - shape.erase(shape.begin() + getDim()); - return shape; -} // @@ -979,9 +978,9 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, } if (auto blockedLayout = mlir::dyn_cast(getParent())) { auto shapePerCTA = getShapePerCTA(*this, shape); - auto shapePerCTATile = ::getShapePerCTATile(blockedLayout); + auto shapePerCTATile = getShapePerCTATile(blockedLayout); auto order = blockedLayout.getOrder(); - auto sizePerThread = ::getSizePerThread(blockedLayout); + auto sizePerThread = blockedLayout.getSizePerThread(); int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0]; int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0]; @@ -1043,19 +1042,6 @@ SmallVector DotOperandEncodingAttr::getThreadOrder() const { return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), /*kMajor*/ true); } -SmallVector DotOperandEncodingAttr::getShapePerCTATile( - ArrayRef tensorShape) const { - auto parentLayout = getParent(); - assert(parentLayout && "DotOperandEncodingAttr must have a parent"); - if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { - return parentMmaLayout.getShapePerCTATileForOperand( - tensorShape, getKWidth(), getOpIdx()); - } else { - llvm::report_fatal_error( - "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " - "supported yet"); - } -} LogicalResult DotOperandEncodingAttr::verify( ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, @@ -1562,16 +1548,6 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const { //===----------------------------------------------------------------------===// // TODO: there is a lot of common code with MmaEncoding here -SmallVector -AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - auto warpsPerCTA = getWarpsPerCTA(); - auto rank = warpsPerCTA.size(); - SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); - shapePerCTATile[rank - 1] *= getMDim(); - shapePerCTATile[rank - 2] *= getNDim(); - return shapePerCTATile; -} - SmallVector AMDMfmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -1715,43 +1691,10 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { return sizePerThread; } -SmallVector -AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef shape, - int kWidth, int opIdx) const { - assert(getMDim() == 32 || getMDim() == 16); - auto parentShapePerCTATile = getShapePerCTATile(shape); - auto rank = parentShapePerCTATile.size(); - if (opIdx == 0) { - if (rank == 2) - return {parentShapePerCTATile[rank - 2], 32}; - else - return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32}; - } else if (opIdx == 1) { - if (rank == 2) - return {32, parentShapePerCTATile[rank - 1]}; - else - return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]}; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - } - llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1"); -} - //===----------------------------------------------------------------------===// // Wmma encoding //===----------------------------------------------------------------------===// -SmallVector -AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - auto warpsPerCTA = getWarpsPerCTA(); - auto rank = warpsPerCTA.size(); - SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); - - auto mnkDim = getMNKDimPerInstr(); - shapePerCTATile[rank - 2] *= mnkDim[0]; - shapePerCTATile[rank - 1] *= mnkDim[1]; - return shapePerCTATile; -} SmallVector AMDWmmaEncodingAttr::getRepOrder() const { auto rank = getWarpsPerCTA().size(); return getMatrixOrder(rank, /*rowMajor*/ true); @@ -1816,21 +1759,6 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { return sizePerThread; } -SmallVector -AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef shape, - int kWidth, int opIdx) const { - auto parentShapePerCTA = getShapePerCTATile(shape); - auto rank = shape.size(); - assert(rank == 2); - if (opIdx == 0) { - return {parentShapePerCTA[0], static_cast(shape[1])}; - } else if (opIdx == 1) { - return {static_cast(shape[0]), parentShapePerCTA[1]}; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - } -} - unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx); @@ -1949,24 +1877,6 @@ SmallVector NvidiaMmaEncodingAttr::getSizePerThread() const { llvm_unreachable("Unexpected mma version"); } -SmallVector -NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - if (isAmpere()) { - auto warpsPerCTA = getWarpsPerCTA(); - auto rank = warpsPerCTA.size(); - SmallVector shapePerCTATile(warpsPerCTA.begin(), - warpsPerCTA.end()); - shapePerCTATile[rank - 1] *= 8; - shapePerCTATile[rank - 2] *= 16; - return shapePerCTATile; - } - if (isHopper()) { - auto instrShape = getInstrShape(); - return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]}; - } - llvm::report_fatal_error("Unexpected MMA layout version found"); -} - SmallVector NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { auto rank = getWarpsPerCTA().size(); @@ -2007,16 +1917,6 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, } } -SmallVector NvidiaMmaEncodingAttr::getShapePerCTATileForOperand( - ArrayRef shape, int kWidth, int opIdx) const { - assert(isAmpere() && "mmaLayout Hopper is not implemented yet"); - auto shapePerCTATile = getShapePerCTATile(shape); - auto rank = shapePerCTATile.size(); - auto kDim = opIdx == 0 ? rank - 1 : rank - 2; - // 4 threads * 2 subtiles - shapePerCTATile[kDim] = kWidth * 2 * 4; - return shapePerCTATile; -} SmallVector NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { auto rank = getWarpsPerCTA().size(); diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 7c2473dbe56f..0e2a9304ebfe 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -78,8 +78,7 @@ LogicalResult ExtractSliceOp::verify() { } auto srcShape = srcTy.getShape(); - auto shapePerCTATile = - mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape); + auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcLayout); shapePerCTATile[0] = std::min(static_cast(srcShape[0]), shapePerCTATile[0]); shapePerCTATile[1] = diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp index c0100812f299..ad56bd2d414e 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp @@ -70,7 +70,7 @@ struct ExtractSliceOpConversion auto order = triton::gpu::getOrder(srcLayout); // Calculate valid total number of workers in each dimension - auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout, srcShape); + auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout); shapePerCTATile[0] = std::min(static_cast(srcShape[0]), shapePerCTATile[0]); shapePerCTATile[1] = diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 2e28dec802b5..825697e0e911 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -48,7 +48,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, } else { warpOrder = triton::gpu::getWarpOrder(layout); } - auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); + auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout); Value warpSize = i32_val(triton::gpu::getWarpSize(layout)); Value laneId = urem(tid, warpSize); Value warpId = udiv(tid, warpSize); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 185b653f72b0..1324511aeb89 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -17,7 +17,6 @@ using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getShapePerCTATile; using ::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index a439b89270a9..cc52507121b5 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -46,7 +46,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, } else { warpOrder = triton::gpu::getWarpOrder(layout); } - auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); + auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout); Value warpSize = i32_val(32); Value laneId = urem(tid, warpSize); Value warpId = udiv(tid, warpSize); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index d1cef15a354e..76b565365406 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -12,8 +12,6 @@ using namespace mlir; using mlir::LLVM::getWrappedMultiDimOffset; using ::mlir::LLVM::linearize; -using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getShapePerCTATile; namespace { // declare vprintf(i8*, i8*) as external function LLVM::LLVMFuncOp getVprintfDeclaration(RewriterBase &rewriter) {