diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 6527fe314737..50be42bbc822 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -537,12 +537,6 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, const TargetInfoBase &target, std::function perVectorCallback); -[[nodiscard]] bool emitTransferBetweenRegistersAndShared( - LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, - std::optional maxVecElems, const SharedMemoryObject &smemObj, - Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - std::function perVectorCallback); - [[nodiscard]] bool emitTransferBetweenRegistersAndShared( LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, std::optional maxVecElems, const SharedMemoryObject &smemObj, diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index d97f18785543..a8bdf8a54f6d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -167,13 +167,14 @@ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { ]; } -def SwizzledSharedEncodingAttr : - TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> { +def SwizzledSharedEncodingAttr + : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding", + [SharedEncodingTrait, LayoutEncodingTrait]> { let mnemonic = "swizzled_shared"; let description = [{ An encoding for tensors whose elements may be simultaneously accessed by -different cuda threads in the programs, via shared memory. In other words, +different GPU threads in the programs, via shared memory. In other words, for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}. In order to avoid shared memory bank conflicts, elements may be swizzled. @@ -181,7 +182,7 @@ Here are some examples. In all cases, the input tensor is [0, 1, ..., n-1]. 1. Basic swizzling - #shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}> + #ttg.swizzled_shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}> [ 0, 1, 2, 3], // xor with 0 [ 5, 4, 7, 6], // xor with 1 [10, 11, 8, 9], // xor with 2 @@ -192,7 +193,7 @@ out[r][c^r]). 2. Multiple rows per phase - #shared<{vec=1, perPhase=2, maxPhase=4, order=[1,0]}> + #ttg.swizzled_shared<{vec=1, perPhase=2, maxPhase=4, order=[1,0]}> [ 0, 1, 2, 3], // phase 0 (xor with 0) [ 4, 5, 6, 7], [ 9, 8, 11, 10], // phase 1 (xor with 1) @@ -203,7 +204,7 @@ means that pairs of 2 rows get the same swizzling. 3. Max-phase applied - $shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}> + #ttg.swizzled_shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}> [ 0, 1, 2, 3], // phase 0 (xor with 0) [ 5, 4, 7, 6], // phase 1 (xor with 1) [ 8, 9, 10, 11], // phase 0 @@ -218,7 +219,7 @@ effect of limiting the maximum value of the xor to m-1. 4. Max-phase and per-phase - #shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}> + #ttg.swizzled_shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}> [ 0, 1, 2, 3], // phase 0 (xor with 0) [ 4, 5, 6, 7], // phase 0 [ 9, 8, 11, 10], // phase 1 (xor with 1) @@ -234,7 +235,7 @@ maximum value of maxPhase-1. In other words, elements of row r are xor'ed with 5. Adding vec - #shared<{vec=2, perPhase=1, maxPhase=4, order=[1,0]}> + #ttg.swizzled_shared<{vec=2, perPhase=1, maxPhase=4, order=[1,0]}> [ 0, 1, 2, 3, 4, 5, 6, 7], [10, 11, 8, 9, 14, 15, 12, 13], [20, 21, 22, 23, 16, 17, 18, 19], @@ -383,6 +384,88 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at let genVerifyDecl = 1; } +def PaddeddSharedEncodingAttr + : TritonGPU_Attr<"PaddedSharedEncoding", "padded_shared_encoding", + [SharedEncodingTrait, LayoutEncodingTrait]> { + let mnemonic = "padded_shared"; + + let description = [{ +An encoding for tensors whose elements may be simultaneously accessed by +different GPU threads in the programs, via shared memory. In other words, +for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}. +Compared to SwizzledSharedEncodingAttr, this encoding uses padding to avoid +shared memory bank conflicts. + +Formally, given a layout: + padded_shared<[:+, :+, ...]> +We insert a padding of `` elements after every `` elements. +Multi interval-padding pairs are supported for flexibility of multi tiered +padding schemes; they compose in an additive manner. So for a 1-D tensor element +at index i, the corresponding shared memory location index is + i + \sum_{k} (i / interval_k) * pad_k = 1 +`` and `` all need to be power of two. + +Some concrete examples, using `eM` to mean tensor elements and `pN` to mean +padding: + +1. Single interval-padding pair: + + #ttg.padded_shared<[2:+2]> + [e0, e1, p0, p1, + e2, e3, p2, p3, + ...] + +2. Double interval-padding pairs: + + #ttg.padded_shared<[2:+1, 4:+2]> + [e0, e1, p0, + e2, e3, p1, p2, p3, + e4, e5, p4, + e6, e7, p5, p6, p7, + ...] + +In addition to interval-padding pairs, this encoding requires an `order` to +specify the logical tensor dimenions from the fastest-to slowest-varying. +It may optionally support CGA level organization like other encoding +attributes too, for example, + #ttg.padded_shared<[2:+1, 4:+2] { + order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], + CTAOrder = [0, 1]}> + }]; + + let parameters = (ins + ArrayRefParameter<"unsigned">:$intervals, + ArrayRefParameter<"unsigned">:$paddings, + // Order of logical tensor dimensions; fastest-varying first. + ArrayRefParameter<"unsigned">:$order, + "CTALayoutAttr":$CTALayout + ); + + let builders = [ + AttrBuilder<(ins "ArrayRef>":$intervalPads, + "ArrayRef":$order, "CTALayoutAttr":$ctaLayout)>, + ]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + unsigned getRank() const { return getOrder().size(); } + int32_t getAlignment() const { return 16; } + + unsigned getMinInterval() const { + return *llvm::min_element(getIntervals()); + } + + // Returns the total number of elements including padding given the input + // tensor shape. + int64_t getPaddedSize(ArrayRef shape) const; + + SmallVector getCTAsPerCGA() const; + SmallVector getCTAOrder() const; + SmallVector getCTASplitNum() const; + }]; + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> { let mnemonic = "nvmma_shared"; diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 7b897aaacc11..9325e2309713 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -260,12 +260,17 @@ class AllocationAnalysis { auto alloc = dyn_cast(op); if (!alloc || !alloc.isSharedMemoryAlloc()) return; - // Bytes could be a different value once we support padding or other - // allocation policies. auto allocType = alloc.getType(); - auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType); - auto bytes = - product(shapePerCTA) * allocType.getElementTypeBitWidth() / 8; + int64_t numElems = 0; + if (auto paddedLayout = + dyn_cast(allocType.getEncoding())) { + SmallVector unpaddedShape = gpu::getShapePerCTA(allocType); + numElems = paddedLayout.getPaddedSize(unpaddedShape); + } else { + auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType); + numElems = product(shapePerCTA); + } + int64_t bytes = numElems * allocType.getElementTypeBitWidth() / 8; auto alignment = alloc.getAlignmentOrDefault(); allocation->addBuffer(alloc, bytes, diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index fd5fde431a81..681d78e20f88 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -8,7 +8,9 @@ #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/MathExtras.h" #if defined(_MSC_VER) && !defined(__clang__) // from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0 @@ -408,6 +410,10 @@ Value getSmemVecAddr(const LinearLayout ®Layout, // We propose case 2 (see comments below), which provides a more general // solution for all swizzled shared memory scenarios, including the edge case // mentioned above. + // + // Padded shared layout falls into case 1--we can rely on the logic for case 1 + // to get the 1-D offset into shared memory. Then we just need to add the + // padding offset. if (isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1 smemOffset = applyLinearLayout(loc, rewriter, regToSharedLayout, {{kRegister, regId}, @@ -436,6 +442,18 @@ Value getSmemVecAddr(const LinearLayout ®Layout, smemOffset = dot(rewriter, loc, smemOffsets, applyPermutation(smemStrides, smemOrder)); } + if (auto paddedLayout = + dyn_cast(sharedEnc)) { + // Apply the offset needed for padding. + Value padOffset = b.i32_val(0); + for (auto [interval, padding] : llvm::zip_equal( + paddedLayout.getIntervals(), paddedLayout.getPaddings())) { + Value iVal = b.i32_val(llvm::Log2_32(interval)); + Value pVal = b.i32_val(llvm::Log2_32(padding)); + padOffset = b.add(padOffset, b.shl(b.ashr(smemOffset, iVal), pVal)); + } + smemOffset = b.add(smemOffset, padOffset); + } } else { // Case 2 -> rank-reduced swizzling assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2"); assert((isa lowerLocalLdSt(Location loc, MLIRContext *ctx, rewriter, targetInfo); } -bool emitTransferBetweenRegistersAndShared( - LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, - std::optional maxVecElems, const SharedMemoryObject &smemObj, - Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - std::function perVectorCallback) { - auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); - return emitTransferBetweenRegistersAndShared( - regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter, - target, laneId, warpId, perVectorCallback); -} - bool emitTransferBetweenRegistersAndShared( LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, std::optional maxVecElems, const SharedMemoryObject &smemObj, @@ -652,11 +659,19 @@ bool emitTransferBetweenRegistersAndShared( StringAttr kRegister = str_attr("register"); StringAttr kLane = str_attr("lane"); StringAttr kWarp = str_attr("warp"); + StringAttr kOffset = str_attr("offset"); auto shape = sharedTy.getShape(); - LinearLayout sharedLayout = - triton::gpu::toLinearLayout(shape, sharedTy.getEncoding()); - LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout); + auto paddedLayout = + dyn_cast(sharedTy.getEncoding()); + LinearLayout regToSharedLayout = LinearLayout::empty(); + if (paddedLayout) { + regToSharedLayout = + regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}}); + } else { + auto sharedLL = triton::gpu::toLinearLayout(shape, sharedTy.getEncoding()); + regToSharedLayout = regLayout.invertAndCompose(sharedLL); + } // TODO(jlebar): We don't currently support loading from shared memory in a // different CTA. We'd need to emit `mapa.shared::cluster` instructions. @@ -681,9 +696,12 @@ bool emitTransferBetweenRegistersAndShared( // // It's OK if the vector width we choose here is wider than the hardware // supports; LLVM will legalize it. - const int vecElems = - std::min(regToSharedLayout.getNumConsecutiveInOut(), - maxVecElems.value_or(std::numeric_limits::max())); + int vecElems = + std::min({regToSharedLayout.getNumConsecutiveInOut(), + maxVecElems.value_or(std::numeric_limits::max())}); + if (paddedLayout) { + vecElems = std::min(vecElems, int(paddedLayout.getMinInterval())); + } auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1; Value blockId = @@ -697,10 +715,14 @@ bool emitTransferBetweenRegistersAndShared( // take out the "block" dimension. // Thus we use `pseudoinvert` instead of `invert` here for simplicity. auto allocShape = sharedTy.getAllocShape(); - LinearLayout invertAllocSharedLayout = - triton::gpu::toLinearLayout(allocShape.take_back(sharedTy.getRank()), - sharedTy.getEncoding()) - .pseudoinvert(); + auto invertAllocSharedLayout = LinearLayout::empty(); + if (!paddedLayout) { + // For now this is only needed for the cases where we have swizzling. + invertAllocSharedLayout = + triton::gpu::toLinearLayout(allocShape.take_back(sharedTy.getRank()), + sharedTy.getEncoding()) + .pseudoinvert(); + } int numElems = regToSharedLayout.getInDimSize(kRegister); auto vecTy = vec_ty(elemLlvmTy, vecElems); @@ -723,9 +745,10 @@ bool emitTransferBetweenRegistersAndShared( std::function perVectorCallback) { auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(), registerTy.getEncoding()); + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); return emitTransferBetweenRegistersAndShared( regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter, - target, perVectorCallback); + target, laneId, warpId, perVectorCallback); } SmallVector loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp, @@ -913,10 +936,13 @@ bool isSimpleSharedMemoryAccess(ArrayRef shape, ArrayRef allocShape, triton::gpu::SharedEncodingTrait sharedEnc) { auto rank = shape.size(); + auto paddedLayout = + dyn_cast(sharedEnc); auto swizzledLayout = dyn_cast(sharedEnc); auto nvmmaLayout = dyn_cast(sharedEnc); - bool noSwizzling = (swizzledLayout && swizzledLayout.getMaxPhase() == 1) || + bool noSwizzling = paddedLayout || + (swizzledLayout && swizzledLayout.getMaxPhase() == 1) || (nvmmaLayout && nvmmaLayout.getSwizzlingByteWidth() == 0); return /*no swizzling*/ noSwizzling || /*swizzling but same shape*/ shape == allocShape || diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index c4c8a19081ec..6719783b6581 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -5,6 +5,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/Support/LLVM.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Interfaces.h" @@ -20,7 +21,9 @@ #include "triton/Tools/LinearLayout.h" #include "triton/Tools/StrUtil.h" #include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/MathExtras.h" // Include TableGen'erated code #include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" @@ -175,18 +178,19 @@ SmallVector getRepOrder(RankedTensorType type) { // This one's not terribly bad as we don't broadcast ShareEncodings SmallVector getOrder(SharedEncodingTrait layout, ArrayRef shape) { - if (auto swizzledLayout = - mlir::dyn_cast(layout)) { + if (auto swizzledLayout = dyn_cast(layout)) { return llvm::to_vector(swizzledLayout.getOrder()); } - if (auto sharedLayout = mlir::dyn_cast(layout)) { + if (auto paddedLayout = dyn_cast(layout)) { + return llvm::to_vector(paddedLayout.getOrder()); + } + if (auto sharedLayout = dyn_cast(layout)) { if (shape.size() == 1) { return {0}; } return getMatrixOrder(shape.size(), !sharedLayout.getTransposed()); } - if (auto sharedLayout = - mlir::dyn_cast(layout)) { + if (auto sharedLayout = dyn_cast(layout)) { return llvm::to_vector(sharedLayout.getOrder()); } llvm::report_fatal_error("Unimplemented usage of getOrder for MemDescType"); @@ -314,7 +318,7 @@ SmallVector getShapePerCTA(Attribute layout, ArrayRef shape) { SmallVector getAllocationShapePerCTA(Attribute layout, ArrayRef shapeLogical) { SmallVector shape(shapeLogical); - if (auto sharedMMALayout = mlir::dyn_cast(layout)) { + if (auto sharedMMALayout = dyn_cast(layout)) { if (sharedMMALayout.getFp4Padded()) { auto packedAxis = getOrder(sharedMMALayout, shapeLogical)[0]; shape[packedAxis] *= 2; @@ -658,6 +662,16 @@ SmallVector SwizzledSharedEncodingAttr::getCTASplitNum() const { return SmallVector(getCTALayout().getCTASplitNum()); } +SmallVector PaddedSharedEncodingAttr::getCTAsPerCGA() const { + return llvm::to_vector(getCTALayout().getCTAsPerCGA()); +} +SmallVector PaddedSharedEncodingAttr::getCTAOrder() const { + return llvm::to_vector(getCTALayout().getCTAOrder()); +} +SmallVector PaddedSharedEncodingAttr::getCTASplitNum() const { + return llvm::to_vector(getCTALayout().getCTASplitNum()); +} + int32_t AMDRotatingSharedEncodingAttr::getAlignment() const { return 16; } SmallVector AMDRotatingSharedEncodingAttr::getCTAsPerCGA() const { @@ -1509,6 +1523,35 @@ void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const { // Helper shared encoding functions //===----------------------------------------------------------------------===// +std::optional +parseCTAAttrs(AsmParser &parser, NamedAttrList attrList, unsigned rank) { + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : attrList) { + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + return getCTALayoutOrError(parser, CTAsPerCGA, CTASplitNum, CTAOrder, rank); +} + template Attribute parseSwizzledEncoding(AsmParser &parser, Type type) { if (parser.parseLess().failed()) @@ -1524,9 +1567,7 @@ Attribute parseSwizzledEncoding(AsmParser &parser, Type type) { unsigned perPhase = 0; unsigned maxPhase = 0; SmallVector order; - std::optional> CTAsPerCGA; - std::optional> CTASplitNum; - std::optional> CTAOrder; + NamedAttrList remainingAttrs; for (const NamedAttribute &attr : dict) { if (attr.getName() == "vec") { if (parseUInt(parser, attr, vec, "vec").failed()) @@ -1540,32 +1581,15 @@ Attribute parseSwizzledEncoding(AsmParser &parser, Type type) { } else if (attr.getName() == "order") { if (parseIntArrayAttr(parser, attr, order, "order").failed()) return {}; - } else if (attr.getName() == "CTAsPerCGA") { - if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") - .failed()) - return {}; - } else if (attr.getName() == "CTASplitNum") { - if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") - .failed()) - return {}; - } else if (attr.getName() == "CTAOrder") { - if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") - .failed()) - return {}; } else { - parser.emitError(parser.getNameLoc(), "unexpected key: ") - << attr.getName().strref(); - return {}; + remainingAttrs.push_back(attr); } } - std::optional CTALayout = getCTALayoutOrError( - parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/order.size()); - if (!CTALayout.has_value()) - return {}; - - return parser.getChecked(parser.getContext(), vec, perPhase, - maxPhase, order, *CTALayout); + if (auto CTALayout = parseCTAAttrs(parser, remainingAttrs, order.size())) + return parser.getChecked( + parser.getContext(), vec, perPhase, maxPhase, order, *CTALayout); + return {}; } //===----------------------------------------------------------------------===// @@ -1600,6 +1624,123 @@ void SwizzledSharedEncodingAttr::print(AsmPrinter &printer) const { printer << "}>"; } +//===----------------------------------------------------------------------===// +// PaddedShared encoding +//===----------------------------------------------------------------------===// + +Attribute PaddedSharedEncodingAttr::parse(AsmParser &parser, Type type) { + // <[ + if (failed(parser.parseLess()) || failed(parser.parseLSquare())) + return {}; + + // :+ + SmallVector intervals, paddings; + auto parseIntervalPaddingPair = [&]() { + unsigned interval = 0, padding = 0; + if (failed(parser.parseInteger(interval)) || failed(parser.parseColon()) || + failed(parser.parsePlus()) || failed(parser.parseInteger(padding))) + return failure(); + intervals.push_back(interval); + paddings.push_back(padding); + return success(); + }; + // ] + if (failed(parser.parseCommaSeparatedList(parseIntervalPaddingPair)) || + failed(parser.parseRSquare())) + return {}; + + // {}> + NamedAttrList attrList; + if (failed(parser.parseOptionalAttrDict(attrList)) || + failed(parser.parseGreater())) + return {}; + + // Decode order and CTA attributes + SmallVector order; + NamedAttrList remainingAttrs; + for (const NamedAttribute &attr : attrList) { + if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else { + remainingAttrs.push_back(attr); + } + } + if (auto ctaLayout = parseCTAAttrs(parser, remainingAttrs, order.size())) + return parser.getChecked( + parser.getContext(), intervals, paddings, order, *ctaLayout); + return {}; +} + +void PaddedSharedEncodingAttr::print(AsmPrinter &printer) const { + printer << "<["; + llvm::interleaveComma(llvm::zip(getIntervals(), getPaddings()), printer, + [&](std::tuple intervalPad) { + printer << std::get<0>(intervalPad) << ":+" + << std::get<1>(intervalPad); + }); + printer << "] {order = [" << getOrder() << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getOrder().size()); + printer << "}>"; +} + +LogicalResult PaddedSharedEncodingAttr::verify( + function_ref emitError, ArrayRef intervals, + ArrayRef paddings, ArrayRef order, + CTALayoutAttr ctaLayout) { + if (intervals.size() != paddings.size()) + return emitError() << "intervals size (" << intervals.size() + << ") must match paddings size (" << paddings.size() + << ")"; + + if (intervals.empty()) + return emitError() << "must have at least one interval-padding pair"; + + if (!llvm::all_of(intervals, llvm::isPowerOf2_32)) + return emitError() << "interval values must all be power of two"; + if (!llvm::all_of(paddings, llvm::isPowerOf2_32)) + return emitError() << "padding values must all be power of two"; + + llvm::SmallSet intervalValues(intervals.begin(), + intervals.end()); + if (intervalValues.size() != intervals.size()) + return emitError() << "interval values cannot have duplicates"; + + if (order.empty()) + return emitError() << "order cannot be empty"; + + if (order.size() != ctaLayout.getRank()) + return emitError() << "order size (" << order.size() + << ") must match CTALayout rank (" << ctaLayout.getRank() + << ")"; + return verifyLayoutOrder(emitError, order); +} + +PaddedSharedEncodingAttr PaddedSharedEncodingAttr::get( + MLIRContext *context, ArrayRef> intervalPads, + ArrayRef order, CTALayoutAttr ctaLayout) { + SmallVector intervals, paddings; + intervals.reserve(intervalPads.size()); + paddings.reserve(intervalPads.size()); + for (auto [interval, padding] : intervalPads) { + intervals.push_back(interval); + paddings.push_back(padding); + } + return get(context, intervals, paddings, order, ctaLayout); +} + +int64_t PaddedSharedEncodingAttr::getPaddedSize(ArrayRef shape) const { + int64_t unpaddedSize = product(shape); + int64_t paddingSize = 0; + for (auto [interval, padding] : + llvm::zip_equal(getIntervals(), getPaddings())) { + paddingSize += (unpaddedSize >> llvm::Log2_32(interval)) + << llvm::Log2_32(padding); + } + return unpaddedSize + paddingSize; +} + //===----------------------------------------------------------------------===// // NVMMAShared encoding //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 0ac56a8a78ef..f06526a1a63b 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -1,7 +1,5 @@ #include -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" @@ -11,13 +9,10 @@ #include "triton/Tools/LinearLayout.h" #include "triton/Tools/StrUtil.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" -using mlir::triton::ScaleDotElemType; - namespace mlir::triton::gpu { namespace { diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 55f8039558e2..5a6c25136def 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -8,7 +8,6 @@ #include "third_party/f2reduce/f2reduce.h" #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/StrUtil.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/StringRef.h" diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 3400039ed352..27739e0e561d 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -21,6 +21,9 @@ #NVMMA_SHARED_64 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}> #NVMMA_SHARED_128 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> #NVMMA_SHARED_FP4PADDED = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8, fp4Padded = true}> +#PADDED_SHARED_0 = #ttg.padded_shared<[256:+8] {order = [1, 0]}> +#PADDED_SHARED_1 = #ttg.padded_shared<[128:+4, 256:+8] {order = [1, 0]}> +#PADDED_SHARED_2 = #ttg.padded_shared<[64:+2, 128:+4, 256:+8] {order = [1, 0]}> #smem = #ttg.shared_memory @@ -937,4 +940,64 @@ tt.func @nvmma_alignment(%lb : index, %ub : index, %step : index, %A : !tt.ptr !ttg.memdesc<1x255xf16, #PADDED_SHARED_0, #ttg.shared_memory, mutable> + // expected-remark @+2 {{offset = 0, size = 528}} + // (256 + 8) * 2B = 528B + %alloc1 = ttg.local_alloc : () -> !ttg.memdesc<1x256xf16, #PADDED_SHARED_0, #ttg.shared_memory, mutable> + // expected-remark @+2 {{offset = 0, size = 530}} + // (257 + 8) * 2B = 530B + %alloc2 = ttg.local_alloc : () -> !ttg.memdesc<1x257xf16, #PADDED_SHARED_0, #ttg.shared_memory, mutable> + // expected-remark @+2 {{offset = 0, size = 1038}} + // (511 + 8) * 2B = 1038B + %alloc3 = ttg.local_alloc : () -> !ttg.memdesc<1x511xf16, #PADDED_SHARED_0, #ttg.shared_memory, mutable> + // expected-remark @+2 {{offset = 0, size = 1056}} + // (512 + 8 * 2) * 2B = 1056B + %alloc4 = ttg.local_alloc : () -> !ttg.memdesc<1x512xf16, #PADDED_SHARED_0, #ttg.shared_memory, mutable> + // expected-remark @+2 {{offset = 0, size = 1058}} + // (513 + 8 * 2) * 2B = 1058B + %alloc5 = ttg.local_alloc : () -> !ttg.memdesc<1x513xf16, #PADDED_SHARED_0, #ttg.shared_memory, mutable> + // expected-remark @+2 {{offset = 0, size = 528}} + // (16 * 16 + 8) * 2B = 528B + %alloc6 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #PADDED_SHARED_0, #ttg.shared_memory, mutable> + // expected-remark @+2 {{offset = 0, size = 1056}} + // (16 * 32 + 8 * 2) * 2B = 1056B + %alloc7 = ttg.local_alloc : () -> !ttg.memdesc<16x32xf16, #PADDED_SHARED_0, #ttg.shared_memory, mutable> + // expected-remark @+2 {{offset = 0, size = 1008}} + // (31 * 16 + 8) * 2B = 1008B + %alloc8 = ttg.local_alloc : () -> !ttg.memdesc<31x16xf16, #PADDED_SHARED_0, #ttg.shared_memory, mutable> + tt.return +} + +// expected-remark @below {{padded_shared_layout_element_type}} +// expected-remark @below {{size = 16896}} +tt.func @padded_shared_layout_element_type() { + // expected-remark @+2 {{offset = 0, size = 4224}} + // (16 * 256 + 8 * 16) * 1B = 4224B + %alloc0 = ttg.local_alloc : () -> !ttg.memdesc<16x256xi8, #PADDED_SHARED_0, #ttg.shared_memory, mutable> + // expected-remark @+2 {{offset = 0, size = 8448}} + // (16 * 256 + 8 * 16) * 2B = 8448B + %alloc1 = ttg.local_alloc : () -> !ttg.memdesc<16x256xf16, #PADDED_SHARED_0, #ttg.shared_memory, mutable> + // expected-remark @+2 {{offset = 0, size = 16896}} + // (16 * 256 + 8 * 16) * 4B = 16896B + %alloc2 = ttg.local_alloc : () -> !ttg.memdesc<16x256xf32, #PADDED_SHARED_0, #ttg.shared_memory, mutable> + tt.return +} + +// expected-remark @below {{padded_shared_layout_multi_tier}} +// expected-remark @below {{size = 4480}} +tt.func @padded_shared_layout_multi_tier() { + // expected-remark @+2 {{offset = 0, size = 4352}} + // (16 * 256 + 4 * 32 + 8 * 16) * 1B = 4352B + %alloc0 = ttg.local_alloc : () -> !ttg.memdesc<16x256xi8, #PADDED_SHARED_1, #ttg.shared_memory, mutable> + // expected-remark @+2 {{offset = 0, size = 4480}} + // (16 * 256 + 2 * 64 + 4 * 32 + 8 * 16) * 1B = 4480B + %alloc1 = ttg.local_alloc : () -> !ttg.memdesc<16x256xi8, #PADDED_SHARED_2, #ttg.shared_memory, mutable> + tt.return +} } diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index aea48d2a4d05..16784f2150b7 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -380,3 +380,32 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n tt.return } } + +// ----- + +// CHECK-LABEL: padded_shared_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.padded_shared<[128:+4, 256:+8] {order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @padded_shared_layout(%arg0: tensor<64x64xf16, #blocked>) { + // CHECK-DAG: %[[CST0:.+]] = llvm.mlir.constant(0 : i32) + // CHECK-DAG: %[[CST2:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[CST3:.+]] = llvm.mlir.constant(3 : i32) + // CHECK-DAG: %[[CST7:.+]] = llvm.mlir.constant(7 : i32) + // CHECK-DAG: %[[CST8:.+]] = llvm.mlir.constant(8 : i32) + + // CHECK: %[[SHR0:.+]] = llvm.ashr %[[XOR:.+]], %[[CST7]] : i32 + // CHECK-NEXT: %[[SHL0:.+]] = llvm.shl %[[SHR0]], %[[CST2]] : i32 + // CHECK-NEXT: %[[ADD0:.+]] = llvm.add %[[SHL0]], %[[CST0]] : i32 + // CHECK-NEXT: %[[SHR1:.+]] = llvm.ashr %[[XOR]], %[[CST8]] : i32 + // CHECK-NEXT: %[[SHL1:.+]] = llvm.shl %[[SHR1]], %[[CST3]] : i32 + // CHECK-NEXT: %[[ADD1:.+]] = llvm.add %[[ADD0]], %[[SHL1]] : i32 + // CHECK-NEXT: %[[ADD2:.+]] = llvm.add %[[XOR]], %[[ADD1]] : i32 + // CHECK-NEXT: llvm.getelementptr inbounds %{{.+}}[%[[ADD2]]] + + // CHECK-COUNT-16: llvm.store {{.*}} : vector<1xf16>, !llvm.ptr<3> + %0 = ttg.local_alloc %arg0 : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> + tt.return + } +} diff --git a/test/TritonGPU/invalid-attributes.mlir b/test/TritonGPU/invalid-attributes.mlir index df693a6ea81c..22938b6055b7 100644 --- a/test/TritonGPU/invalid-attributes.mlir +++ b/test/TritonGPU/invalid-attributes.mlir @@ -76,3 +76,43 @@ // expected-error@+1 {{(M, N) cases other than (32, 32) or (16, 16) unimplemented}} #mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [16, 8], isTransposed = false}> + +// ----- + +// expected-error@+1 {{interval values must all be power of two}} +#shared = #ttg.padded_shared<[3:+2]> + +// ----- + +// expected-error@+1 {{interval values must all be power of two}} +#shared = #ttg.padded_shared<[0:+2]> + +// ----- + +// expected-error@+1 {{padding values must all be power of two}} +#shared = #ttg.padded_shared<[2:+3]> + +// ----- + +// expected-error@+1 {{padding values must all be power of two}} +#shared = #ttg.padded_shared<[2:+0]> + +// ----- + +// expected-error@+1 {{interval values cannot have duplicates}} +#shared = #ttg.padded_shared<[2:+1, 2:+4]> + +// ----- + +// expected-error@+1 {{order cannot be empty}} +#shared = #ttg.padded_shared<[2:+1, 4:+2]> + +// ----- + +// expected-error@+1 {{unexpected key: unknown}} +#shared = #ttg.padded_shared<[2:+1, 4:+2] {order = [1, 0], unknown = 5}> + +// ----- + +// expected-error@+1 {{order size (3) must match CTALayout rank (2)}} +#shared = #ttg.padded_shared<[2:+1, 4:+2] {order = [2, 1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index 221d3b849d1b..77fc628b446e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -154,7 +154,10 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto aTensorTy = cast(tensor.getType()); ArrayRef shape = aTensorTy.getShape(); - auto sharedLayout = cast(aTensorTy.getEncoding()); + auto sharedLayout = + dyn_cast(aTensorTy.getEncoding()); + if (!sharedLayout) + return Value(); auto order = sharedLayout.getOrder(); // Rely on the linear layout conversion logic in this case, since only slowest diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp index dcc70aa64198..a0c4f2083b19 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp @@ -271,10 +271,11 @@ struct TransLocalLoadOpConversion SmallVector outVals; SmallVector elemsI32; mlir::Type retTy = dstTy; + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); bool valid = emitTransferBetweenRegistersAndShared( ldsTransLayout, srcTy, llvmElemTy, /*maxVecElems=*/std::nullopt, smemObj, loc, rewriter, targetInfo, - [&](VectorType vecTy, Value vecAddr) { + laneId, warpId, [&](VectorType vecTy, Value vecAddr) { if (bitwidth == 16) { auto dsReadOp = rewriter.create(loc, vecTy, vecAddr);