diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index d5eb81eb9f4a..f094ce963a5a 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -63,6 +63,8 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerTritonAMDGPUStreamPipelineV2(); mlir::registerTritonAMDGPUCanonicalizePointers(); mlir::registerTritonAMDGPUConvertToBufferOps(); + mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); + mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); // TODO: register Triton & TritonGPU passes registry.insert reorderValues(const SmallVector &values, Type inType, - Type ouType); - -SmallVector unpackI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter); - -SmallVector packI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter); - Type getElementType(Value value); class MultipleOperandsRange @@ -187,8 +176,6 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { for (auto operand : adaptor.getOperands()) { auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); - subOperands = unpackI32(subOperands, argTy, rewriter, loc, - this->getTypeConverter()); allOperands.resize(subOperands.size()); for (auto v : llvm::enumerate(subOperands)) allOperands[v.index()].push_back(v.value()); @@ -209,13 +196,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { } it += curr.size(); } - if (op->getNumOperands() > 0) { - auto argTy = op->getOperand(0).getType(); - resultVals = reorderValues(resultVals, argTy, resultTy); - } resultVals = maybeDeduplicate(op, resultVals); - resultVals = - packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); Value view = packLLElements(loc, this->getTypeConverter(), resultVals, rewriter, resultTy); rewriter.replaceOp(op, view); diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index 29aec5904e8e..b6d2fbeff94f 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -27,15 +27,33 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; constexpr int patternBenefitClampOptimizedPattern = 20; constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; +struct BackendCallbacks { + /** + * A backend-specific callback for appending auxiliary data during + * `LocalStoreOp` conversion. + * + * @param[in] op The reference to the re-written `LocalStoreOp`. + * @param[in] count The number of issued LLVM instructions. + * @param[in] type The input type of issued LLVM instructions. + */ + std::function + localStoreOpConversion = nullptr; +}; + void populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, PatternBenefit benefit); -void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, - PatternBenefit benefit); +// The given callback is invoked at the end of a successful rewrite. The +// callback receives 1) the current source op, 2) the number of issued LLVM +// instructions and 3) their input types. 
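// Editorial aside, not part of the patch: a minimal sketch of how a backend
// might wire this hook, assuming the (op, count, type) signature documented
// above. `recordLocalStore` is a hypothetical backend helper used only for
// illustration.
BackendCallbacks callbacks;
callbacks.localStoreOpConversion = [](triton::gpu::LocalStoreOp op,
                                      size_t count, Type vecTy) {
  // Forward the number of emitted vector stores and their element type, e.g.
  // so an instruction-scheduling pass can consume them later.
  recordLocalStore(op, count, vecTy);
};
populateMemoryOpToLLVMPattern(typeConverter, targetInfo, patterns, benefit,
                              callbacks);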
Each MLIR backend can provide a +// callback and, thus, handle backend-specific behaviors. +void populateMemoryOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit, + std::optional backendCallbacks = std::nullopt); void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 29b8865c03ae..253033e98e8c 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -391,6 +391,19 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); return base; } + +// ----------------------------------------------------------------------- +// MXFP utilities +// ----------------------------------------------------------------------- + +// Convert each value, which is an int8 containing 2 packed mxfp4 values, +// into 2 standalone bf16 values +SmallVector convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc, + ArrayRef values); + +// Scale a mxfp4 value by a given scale. +Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale); + } // namespace LLVM /* ------------------------------------ */ @@ -453,15 +466,16 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter, auto sizePerThread = blockedLayout.getSizePerThread(); auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); - auto order = blockedLayout.getOrder(); + auto threadOrder = blockedLayout.getThreadOrder(); + auto warpOrder = blockedLayout.getWarpOrder(); auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape); unsigned rank = shape.size(); // delinearize threadId to get the base index SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, order); + delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); SmallVector multiDimThreadId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); + delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder); SmallVector multiDimBase(rank); for (unsigned k = 0; k < rank; ++k) { @@ -1366,11 +1380,11 @@ SmallVector loadSharedToDistributed(RankedTensorType dstTy, Location loc, RewriterBase &rewriter, const TargetInfoBase &target); -void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, - Type elemLlvmTy, ArrayRef srcVals, - Value smemBase, ArrayRef dstStrides, - Location loc, RewriterBase &rewriter, - const TargetInfoBase &target); +void storeDistributedToShared( + MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, + ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, + Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + std::pair *const llvmOpCount = nullptr); inline Value getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index f3159338bd0a..04e4c25fd6d8 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -119,15 +119,16 @@ def TT_InputPrecisionAttr : I32EnumAttr< let cppNamespace = "::mlir::triton"; } -// Type for F8F6F4 kind of floats. 
-def TT_F8F6F4TypeAttr : I32EnumAttr< - "F8F6F4Type", "", +// Type for ScaleDotElemType kind of floats. +def TT_ScaleDotElemTypeAttr : I32EnumAttr< + "ScaleDotElemType", "", [ I32EnumAttrCase<"E4M3", 0, "e4m3">, I32EnumAttrCase<"E5M2", 1, "e5m2">, I32EnumAttrCase<"E2M3", 2, "e2m3">, I32EnumAttrCase<"E3M2", 3, "e3m2">, - I32EnumAttrCase<"E2M1", 4, "e2m1"> + I32EnumAttrCase<"E2M1", 4, "e2m1">, + I32EnumAttrCase<"BF16", 5, "bf16"> ]>{ let cppNamespace = "::mlir::triton"; diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 283dd9165918..010901f28735 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -685,15 +685,15 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, let arguments = ( ins - // inputs are integer types as they are packed types and we currently - // don't have a representation for those. - TT_IntTensor:$lhs, - TT_IntTensor:$rhs, + // inputs are floats if we have a type for them, otherwise (fp4), + // they are packed in pairs in an I8Tensor + RankedTensorOf<[TT_Float,I8]>:$lhs, + RankedTensorOf<[TT_Float,I8]>:$rhs, TT_FloatTensor:$c, - TT_IntTensor:$lhs_scale, - Optional:$rhs_scale, - TT_F8F6F4TypeAttr:$lhs_type, - TT_F8F6F4TypeAttr:$rhs_type + RankedTensorOf<[I8]>:$lhs_scale, + Optional>:$rhs_scale, + TT_ScaleDotElemTypeAttr:$lhs_type, + TT_ScaleDotElemTypeAttr:$rhs_type ); let results = (outs TT_FloatTensor:$d); diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 74ea99b58891..a9b49448c1d0 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -76,9 +76,8 @@ SmallVector getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape); // Returns the dimensions of the tensor from minor (fast-varying) to -// major (slow-varying). For blocked, mma, and dotOperand layouts, -// though the elements are in registers, the order refers to memory -// layout of the original tensor in global memory. +// major (slow-varying). For distributed layouts, this represents +// the order of the elements within a thread. // For shared Layout, the order refers to which dimension of the original tensor // is contiguous in shared memory. SmallVector getOrder(Attribute layout); @@ -130,6 +129,17 @@ unsigned getNumWarpsPerCTA(Attribute layout); unsigned getNumCTAs(Attribute layout); +// Return the order that represents that the batch is in row-major or +// column-major order for a batch of matrices of shape [*, m, n] with +// len(shape) == rank. +SmallVector getMatrixOrder(unsigned rank, bool rowMajor); + +// Return the order that represents that the dot operand is in kMajor +// (contiguous in the inner dimension) or it's contiguous on the outer +// dimension. +SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, + bool kMajor); + bool isExpensiveCat(CatOp cat, Attribute targetEncoding); // Return true if a view between the two types cannot be implemented as a no-op. diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index c8512fce57fa..33308fb24569 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -361,8 +361,8 @@ compared to 1*64 when the hasLeadingOffset is false. 
return get(context, vec, perPhase, maxPhase, order, CTALayout); } - // ---- begin Ampere ---- - if (mmaEnc.isAmpere()) { + // ---- begin Ampere & Hopper ---- + if (mmaEnc.isAmpere() || mmaEnc.isHopper()) { int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()); perPhase = std::max(perPhase, 1); std::vector matShape = {8, 8, 4 * dotOpEnc.getKWidth()}; @@ -397,13 +397,6 @@ compared to 1*64 when the hasLeadingOffset is false. llvm_unreachable("invalid operand index"); } - // ---- begin version 3 ---- - if (mmaEnc.isHopper()) { - llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr" - " is Hopper has not been implemented yet"); - return $_get(context, 1, 1, 1, order, CTALayout, true); - } - // ---- not implemented ---- llvm_unreachable("unsupported swizzling for provided MMA version"); }]>, @@ -481,9 +474,16 @@ layout = [0 4 8 12] [3 7 11 15] For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding. + +If the layout does not completely cover the tensor, we tile it until we cover the entire tensor. +We call each individual tile "rep". }]; let methods = [ + InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first", + "SmallVector", + "getRepOrder">, + // Interface for the meta information about the multiple thread hierarchy. InterfaceMethod<"Get the shape of the CTAs per CGA.", "SmallVector", @@ -570,6 +570,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, }]; code extraDistributedDeclaration = extraBaseClassDeclaration # [{ + SmallVector getRepOrder() const; SmallVector getCTAsPerCGA() const; SmallVector getCTAOrder() const; SmallVector getCTASplitNum() const; @@ -921,6 +922,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129, unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; SmallVector getInstrShapeForOperand(int kWidth, int opIdx) const; SmallVector getRepForOperand(ArrayRef operandShape, int kWidth, int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; SmallVector getContigPerThread() { auto rank = getWarpsPerCTA().size(); @@ -1029,6 +1031,7 @@ Row | warp 0 warp 2 SmallVector getElemsPerInstrForOperands() const; SmallVector getRepForOperand(ArrayRef operandShape, Type elemType, int kWidth, int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; static SmallVector getMNKDimPerInstr(); SmallVector getContigPerThread() { @@ -1206,8 +1209,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: bool isAmpere() const; bool isHopper() const; - unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef shape) const; - // Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor std::tuple decodeVoltaLayoutStates() const; @@ -1224,8 +1225,9 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: SmallVector getMMAv1Rep(int opIdx) const; SmallVector getMMAv1ShapePerWarp(int opIdx) const; int getMMAv1Vec(int opIdx) const; - SmallVector getMMAv2RepForOperand(ArrayRef shape, - int bitwidth, int kWidth, int opIdx) const; + SmallVector getRepForOperand(ArrayRef shape, + int bitwidth, int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; bool supportReduction() const { if (isAmpere() || isHopper()) { @@ -1319,6 +1321,27 @@ The parent field is the layout of d. kWidth defines number of consecutive elements stored by one thread along k dimension. 
Some layouts do not use this parameter, either because they have a fixed number of elements along the K dim, or they use all elements of the tensor along the K dim. + +# WGMMA Notes +We require kWidth to be provided for Hopper because the dtype at loading might be +different from the dtype at WGMMA, due to casting. The kWidth is determined by the +dtype at WGMMA. + +The encoded tensor consists of operand A for possibly multiple wgmma instructions. +For each wgmma, each warp in a warp group feeds a single "warp matrix" +Each warp matrix consists of 2x2 "quads". +Each thread holds several elements in each quad. Right before a wgmma, +the sum of bitwidth of +the elements in each quad should add up to 32. + +These values are stored unrolled in `elements`. +The ordering of dimensions is as follows by convention: +batch (only 1 batch for Hopper currently) +matM (m-index of the "warp matrix") +matK (k-index of the "warp matrix") +quadK (k-index of the "quad" in the core matrix) +quadM (m-index of the "quad" in the core matrix) +vecIdx (index of the element in the quad; this is always along the k-dim) }]; let parameters = ( @@ -1329,16 +1352,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim ); let builders = [ - // Specially for MMAV1(Volta) AttrBuilder<(ins "unsigned":$opIdx, "Attribute":$parent, "Type":$eltTy), [{ NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); - if (!parentAttr || !parentAttr.isAmpere()) - return $_get(context, opIdx, parent, 0); + if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper())) + return $_get(context, opIdx, parent, 0); // For MMAV1 + // For MMAV2 and V3 unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); - unsigned MMAv2kWidth = 32 / bitwidth; - return $_get(context, opIdx, parent, MMAv2kWidth); + unsigned kWidth = 32 / bitwidth; + return $_get(context, opIdx, parent, kWidth); }]> ]; diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index a290cb20310a..6299ee6ed43d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -268,7 +268,7 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods getFreeVariableMasks() const; + // Increase an input dimension without affecting the output dimension. The + // added free variables are mapped to 0, ensuring that the new input + // dimensions correspond directly to the existing output space. The function + // errors out if `newInDimSize` is less than the current size or the new size + // is not a power of 2. 
+ LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const; + std::string toString() const; friend bool operator==(LinearLayout lhs, LinearLayout rhs); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 276a6e7004df..f8ea6e42cd70 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -80,6 +80,9 @@ static SmallVector getRepShapeForCvt(RankedTensorType srcTy, auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape()); + assert(srcTy.getRank() == dstTy.getRank() && + "src and dst must have the same rank"); + unsigned rank = dstTy.getRank(); SmallVector repShape(rank); for (unsigned d = 0; d < rank; ++d) { @@ -115,12 +118,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, assert(!isMfmaToDotShortcut(srcTy, dstTy)); - // FIXME This is NOT entirely correct - // This should be getElemOrder, but we don't have such a method - // TODO Implement getElemOrder and make sure it's consistent with - // getContigPerThread - auto inOrd = gpu::getThreadOrder(srcLayout); - auto outOrd = gpu::getThreadOrder(dstLayout); + auto inOrd = gpu::getOrder(srcLayout); + auto outOrd = gpu::getOrder(dstLayout); scratchConfig.order = outOrd; unsigned srcContigPerThread = diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 4915d7b1acda..30ba11c31782 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) { (elemTy.isInteger(8) && version >= 2); } -bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { +bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { auto blockedLayout = dyn_cast(srcTy.getEncoding()); auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); if (blockedLayout == nullptr || dotOperandLayout == nullptr) @@ -655,8 +655,46 @@ std::optional minimalCvtLayout(RankedTensorType srcTy, toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); if (!(srcLayout.has_value() && dstLayout.has_value())) return std::nullopt; + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); + auto numSrcRegs = srcLayout->getInDimSize(kRegister); + auto numDstRegs = dstLayout->getInDimSize(kRegister); + // The `invertAndCompose` function will generate a layout that is injective + // by assigning new output dimensions to free variables. For instance, + // consider a scenario where `srcLayout` has a free variable in the lane + // dimension, while `dstLayout` has two free variables in the lane + // dimension and also a larger number of registers. + // The injective form of `srcLayout` will add only a single additional row + // to the transformation matrix, whereas the injective form of `dstLayout` + // will add two additional rows. This discrepancy causes misleading results + // because the matrices end up with a different number of rows. 
+ // + // Take `dstLayout ⋅ srcLayout^-1` as an example: + // + // - `injective(dstLayout)`: [n, m] → [n + 2, m] + // - `injective(srcLayout)`: [n, m] → [n + 1, m] + // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1] + // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n + + // 1] → [n + 2, n + 1] + // + // Here, the `(n + 1)`-th row added by `dstLayout` represents the free + // variable in registers, and the `(n + 2)`-th row represents the free + // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout` + // represents the free variable in lanes. As a result, the `(n + 1)`-th row + // in two layouts do not correspond to the same free variable. + // + // To address this issue, we pad the free variables in `srcLayout` and + // `dstLayout` to ensure they have the same number of registers. This + // guarantees that the resulting matrices have the same number of rows, + // ensuring consistency in the composition process. + auto numRegs = std::max(numSrcRegs, numDstRegs); + auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs); + auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs); // comp describes the layout function to create dst from src. - LinearLayout comp = dstLayout->invertAndCompose(*srcLayout); + LinearLayout comp = + dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs); // We try to quotient by the largest subspace first auto dims = SmallVector{"block", "warp", "lane", "register"}; for (auto dim : dims) { @@ -693,14 +731,14 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { } bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { - // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`, - // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully - // subsumed by the linear-layout checks. + // TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and + // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout + // checks. // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not // supported yet in Triton's backend. return !cvtReordersRegisters(srcTy, dstTy) && !isBlockedToDotShortcut(srcTy, dstTy) && - !isMmaToDotShortcut(srcTy, dstTy) && + !matchMmaV3AndDotOperandLayout(srcTy, dstTy) && !isMfmaToDotShortcut(srcTy, dstTy); } @@ -711,20 +749,6 @@ bool atomicNeedsSharedMemory(Value value) { return true; } -bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { - if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) - return true; - // dot_op = #mma - // when #mma = MmaEncoding - auto mmaLayout = dyn_cast(srcTy.getEncoding()); - auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); - return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 && - mmaLayout.getWarpsPerCTA()[1] == 1 && - dotOperandLayout.getOpIdx() == 0 && - dotOperandLayout.getParent() == mmaLayout && - !srcTy.getElementType().isF32(); -} - namespace { /// A data structure similar to SetVector but maintains diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index a18b2cbc308c..3fcae4897829 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -288,60 +288,73 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return rewriter.notifyMatchFailure( op, "NYI. 
srcTy and/or dstTy don't implement LLs yet"); } + LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); assert(to_vector(conversion->getInDimNames()) == to_vector(conversion->getOutDimNames())); auto dims = conversion->getInDimNames(); - if (llvm::is_contained(dims, str_attr("block"))) { + if (llvm::is_contained(dims, kBlock)) { // Case 1: Transfer between values in different CTAs. // This requires moving values through distributed shared memory. return rewriter.notifyMatchFailure( op, "NYI: Transfer between different CTAs"); - } else if (llvm::is_contained(dims, str_attr("warp"))) { + } else if (llvm::is_contained(dims, kWarp)) { // Case 2: Transfer between values in the same CTA, in which case we move // values through shared memory. - LinearLayout srcLayout = - *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - LinearLayout dstLayout = - *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); - } else if (llvm::is_contained(dims, str_attr("lane"))) { + } else if (llvm::is_contained(dims, kLane)) { // Case 3. Transfer between values in the same warp, in which case we try // to move values using warp shuffles, though if the pattern is // complicated enough we may fall back to using shared memory // TODO(Keren): implement warp shuffle instead of using the general // approach that uses shared memory - LinearLayout srcLayout = - *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - LinearLayout dstLayout = - *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); - } else if (llvm::is_contained(dims, str_attr("register"))) { + } else if (llvm::is_contained(dims, kRegister) || + dstLayout.getInDimSize(kRegister) != + srcLayout.getInDimSize(kRegister)) { // Case 4. Transfer between values in the same thread, in which case we // simply reorder the elements of adaptor.getSrc(). - return transferWithinThread(op, *conversion, adaptor, rewriter); + return transferWithinThread( + op, dstLayout.getFreeVariableMasks()[kRegister], + dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter); } else { - // The two layouts are equivalent. We should probably remove these in - // RemoveLayoutConversion. + // Cast 5. The two layouts are equivalent. We should probably remove + // these in RemoveLayoutConversion. 
rewriter.replaceOp(op, adaptor.getSrc()); return success(); } } LogicalResult - transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, - OpAdaptor adaptor, + transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs, + const LinearLayout &conversion, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); StringAttr kRegister = str_attr("register"); assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - SmallVector outVals; - outVals.resize(conversion.getInDimSize(kRegister)); - for (int i = 0; i < conversion.getInDimSize(kRegister); i++) { - auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; + SmallVector outVals(numRegs); + for (int i = 0; i < numRegs; i++) { + // Remove free masks from the register index + // For example, if idx = 0b00111, and masks = 0b00100, then we get + // 0b00011. It means that register 7 (0b111) has the same value as + // register 3 (0b011). + auto idx = i & (~regMasks); + auto srcIdx = conversion.hasInDim(kRegister) + ? conversion.apply({{kRegister, idx}}).begin()->second + : idx; outVals[i] = inVals[srcIdx]; } Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, @@ -375,18 +388,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion if (auto dotOperand = dyn_cast(layout)) { if (auto nvidiaMma = dyn_cast(dotOperand.getParent())) { - if (product(getCTAsPerCGA(nvidiaMma)) > 1) { - return false; - } if (useLegacyMMAConversion) { return false; } - // FIXME [Dot LL] - // Enabling LL path for buggy kWidth path - bool largeKWidth = - dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64; - return largeKWidth && nvidiaMma.isAmpere(); + if (nvidiaMma.isAmpere()) { + return true; + } } + return false; } if (isa(layout)) { return true; @@ -447,22 +456,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } } - // FIXME [Dot LL] - // We know it's just for largeKWidth case in Ampere - // In this case, we need to pack the outputs into i32 - if (isa(dstTy.getEncoding())) { - auto concat = [&](Value a, Value b) { - return or_(zext(i32_ty, bitcast(a, i16_ty)), - shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); - }; - - SmallVector outVals32(outVals.size() / 2); - for (int i = 0; i < outVals32.size(); ++i) { - outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); - } - outVals = outVals32; - } - Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp index 1346cc143ed2..74b2767f0de1 100644 --- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -90,6 +90,10 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) { auto dstDotOp = dyn_cast(dstType.getEncoding()); if (srcBlocked && dstDotOp) { + auto dotParent = dyn_cast(dstDotOp.getParent()); + if (dotParent && dotParent.isAmpere()) { + return; + } Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); auto tmpType = MemDescType::get( diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 
8ee166866974..632ccf10848c 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -17,158 +17,17 @@ Type getElementType(Value value) { return tensorType.getElementType(); return type; } -// MMA encoding has a different order depending on the element's bit width; -// reorder if we're in this case. -SmallVector reorderValues(const SmallVector &values, Type inType, - Type ouType) { - auto inTensorTy = dyn_cast(inType); - auto ouTensorTy = dyn_cast(ouType); - if (!inTensorTy || !ouTensorTy) - return values; - auto inEncoding = dyn_cast(inTensorTy.getEncoding()); - auto ouEncoding = dyn_cast(ouTensorTy.getEncoding()); - assert(inEncoding == ouEncoding); - if (!inEncoding) - return values; - // If the parent of the dot operand is in block encoding, we don't need to - // reorder elements - auto parentEncoding = dyn_cast(ouEncoding.getParent()); - if (!parentEncoding) - return values; - size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth(); - size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth(); - auto ouEltTy = ouTensorTy.getElementType(); - if (inBitWidth == ouBitWidth) - return values; - if (inBitWidth == 16 && ouBitWidth == 32) { - // Register layout conversion: - // - // [0, 1], [4, 5] ⟶ [0], [1], [4], [5] - // [2, 3], [6, 7] [2], [3], [6], [7] - // - // Original access order: - // - // [0, 1], [2, 3], [4, 5], [6, 7] - // - // Transformed access order: - // - // [0], [2], [1], [3], [4], [6], [5], [7] - SmallVector ret; - for (unsigned i = 0; i < values.size(); i += 8) { - ret.push_back(values[i]); - ret.push_back(values[i + 2]); - ret.push_back(values[i + 1]); - ret.push_back(values[i + 3]); - ret.push_back(values[i + 4]); - ret.push_back(values[i + 6]); - ret.push_back(values[i + 5]); - ret.push_back(values[i + 7]); - } - return ret; - } - if (inBitWidth == 8 && ouBitWidth == 16) { - // Register layout conversion: - // - // [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11] - // [4, 5, 6, 7], [12, 13, 14, 15] [4, 5], [6, 7], [12, 13], [14, 15] - // - // Original access order: - // - // [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15] - // - // Transformed access order: - // - // [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15] - SmallVector ret; - for (unsigned i = 0; i < values.size(); i += 16) { - ret.push_back(values[i]); - ret.push_back(values[i + 1]); - ret.push_back(values[i + 4]); - ret.push_back(values[i + 5]); - ret.push_back(values[i + 2]); - ret.push_back(values[i + 3]); - ret.push_back(values[i + 6]); - ret.push_back(values[i + 7]); - ret.push_back(values[i + 8]); - ret.push_back(values[i + 9]); - ret.push_back(values[i + 12]); - ret.push_back(values[i + 13]); - ret.push_back(values[i + 10]); - ret.push_back(values[i + 11]); - ret.push_back(values[i + 14]); - ret.push_back(values[i + 15]); - } - return ret; - } - llvm_unreachable("unimplemented code path"); -} - -SmallVector unpackI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter) { - auto tensorTy = dyn_cast(srcTy); - if (!tensorTy) - return inValues; - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) - return inValues; - SmallVector outValues; - for (auto v : inValues) { - // cast i32 to appropriate eltType vector and extract elements - auto eltType = typeConverter->convertType(tensorTy.getElementType()); - auto vecType = 
vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth()); - auto vec = bitcast(v, vecType); - for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) { - outValues.push_back(extract_element(vec, i32_val(i))); - } - } - return outValues; -} - -SmallVector packI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter) { - auto tensorTy = dyn_cast(srcTy); - if (!tensorTy) - return inValues; - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) - return inValues; - SmallVector outValues; - auto eltType = typeConverter->convertType(tensorTy.getElementType()); - int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); - auto vecType = vec_ty(eltType, vecWidth); - for (int i = 0; i < inValues.size(); i += vecWidth) { - Value vec = undef(vecType); - for (int j = 0; j < vecWidth; j++) { - vec = insert_element(vec, inValues[i + j], i32_val(j)); - } - outValues.push_back(bitcast(vec, i32_ty)); - } - return outValues; -} int getNumElementsPerThreads(Type type, const LLVMTypeConverter *typeConverter) { int numElemsPerThread = 1; - auto tensorTy = dyn_cast(type); - if (!tensorTy) - return numElemsPerThread; - auto structType = - dyn_cast(typeConverter->convertType(type)); - if (structType) { - numElemsPerThread = structType.getBody().size(); + if (auto tensorTy = dyn_cast(type)) { + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (structType) + numElemsPerThread = structType.getBody().size(); } - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) - return numElemsPerThread; - auto eltType = tensorTy.getElementType(); - assert(eltType.getIntOrFloatBitWidth() <= 32 && - "Only support element type with bit width <= 32 in dot operand mma " - "layout"); - // dot operand data are packed into i32 elements so use the following formula - // to get the number of elements per thread. - return (32 / eltType.getIntOrFloatBitWidth()) * numElemsPerThread; + return numElemsPerThread; } } // namespace mlir::triton::gpu @@ -499,8 +358,7 @@ struct ElementwiseInlineAsmOpConversion for (auto operand : adaptor.getOperands()) { auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); - unpackedOperands.push_back( - unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter())); + unpackedOperands.push_back(subOperands); } int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), @@ -553,17 +411,8 @@ struct ElementwiseInlineAsmOpConversion // Reorder and pack the results. SmallVector outs; for (int i = 0; i < unpackedResults.size(); i++) { - // We reordered all the inputs so they match operand 0. Reorder the - // outputs accordingly. 
- if (op->getNumOperands() > 0) { - unpackedResults[i] = reorderValues( - unpackedResults[i], /*inType=*/op->getOperand(0).getType(), - /*ouType=*/op->getResult(i).getType()); - } - auto packed = packI32(unpackedResults[i], op->getResult(i).getType(), - rewriter, loc, getTypeConverter()); - outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter, - op->getResult(i).getType())); + outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i], + rewriter, op->getResult(i).getType())); } rewriter.replaceOp(op, outs); diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 1a0c115a9ecf..fbd6248fe710 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -15,12 +15,11 @@ using namespace mlir::triton::gpu; // blocked -> shared. // Swizzling in shared memory to avoid bank conflict. Normally used for // A/B operands of dots. -void lowerDistributedToShared(Location loc, Value src, Value dst, - Value adaptorSrc, - const SharedMemoryObject &smemObj, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, - const TargetInfoBase &targetInfo) { +void lowerDistributedToShared( + Location loc, Value src, Value dst, Value adaptorSrc, + const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, + std::pair *const llvmOpCount = nullptr) { auto srcTy = cast(src.getType()); auto dstTy = cast(dst.getType()); auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); @@ -33,7 +32,7 @@ void lowerDistributedToShared(Location loc, Value src, Value dst, auto dstStrides = smemObj.getStrides(); auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides, - loc, rewriter, targetInfo); + loc, rewriter, targetInfo, llvmOpCount); } struct LocalAllocOpConversion @@ -174,7 +173,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { auto dstLayout = dstTy.getEncoding(); assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) && "Unexpected rank of ConvertLayout(shared->distributed)"); - auto inOrd = getOrder(srcSharedLayout); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getSrc(), @@ -184,42 +182,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { SmallVector outVals = loadSharedToDistributed( dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo); - // FIXME [Dot LL] - // Ampere case - // In this case, we need to pack the outputs into i32 - if (auto dotOp = dyn_cast(dstTy.getEncoding())) { - if (auto parent = dyn_cast(dotOp.getParent())) { - if (parent.isAmpere()) { - if (elemLlvmTy.isInteger(8)) { - auto concat = [&](Value a1, Value a2, Value a3, Value a4) { - return or_( - or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))), - or_(shl(zext(i32_ty, a3), i32_val(16)), - shl(zext(i32_ty, a4), i32_val(24)))); - }; - SmallVector outVals32(outVals.size() / 4); - for (int i = 0; i < outVals32.size(); ++i) { - outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1], - outVals[4 * i + 2], outVals[4 * i + 3]); - } - outVals = outVals32; - } else { - assert(elemLlvmTy.isBF16() && "Unexpected element type"); - auto concat = [&](Value a, Value b) { - return or_(zext(i32_ty, bitcast(a, i16_ty)), - shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); - }; - - SmallVector outVals32(outVals.size() / 2); - for (int i = 0; i < 
outVals32.size(); ++i) { - outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); - } - outVals = outVals32; - } - } - } - } - Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); rewriter.replaceOp(op, result); @@ -235,12 +197,15 @@ struct LocalStoreOpConversion public: using ConvertOpToLLVMPattern< triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern; + using BackendCallbackType = + decltype(BackendCallbacks::localStoreOpConversion); LocalStoreOpConversion(const LLVMTypeConverter &converter, const TargetInfoBase &targetInfo, + BackendCallbackType backendCallback, PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(converter, benefit), - targetInfo(targetInfo) {} + targetInfo(targetInfo), backendCallback(backendCallback) {} LogicalResult matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, @@ -250,24 +215,36 @@ struct LocalStoreOpConversion getTypeConverter()->convertType(op.getDst().getType().getElementType()); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + + std::pair llvmOpCount; lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(), adaptor.getSrc(), smemObj, getTypeConverter(), - rewriter, targetInfo); + rewriter, targetInfo, &llvmOpCount); + + if (backendCallback) + (backendCallback)(op, llvmOpCount.first, llvmOpCount.second); + rewriter.eraseOp(op); return success(); } private: const TargetInfoBase &targetInfo; + BackendCallbackType backendCallback; }; } // namespace void mlir::triton::populateMemoryOpToLLVMPattern( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, PatternBenefit benefit) { + RewritePatternSet &patterns, PatternBenefit benefit, + std::optional backendCallbacks) { patterns.add(typeConverter, targetInfo, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, targetInfo, benefit); - patterns.add(typeConverter, targetInfo, benefit); + + auto backendCall = + backendCallbacks ? 
backendCallbacks->localStoreOpConversion : nullptr; + patterns.add(typeConverter, targetInfo, backendCall, + benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index 969b227c8dda..64e6ca787780 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -389,10 +389,10 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); - auto order = triton::gpu::getOrder(srcEncoding); + auto threadOrder = triton::gpu::getThreadOrder(srcEncoding); auto warpOrder = triton::gpu::getWarpOrder(srcEncoding); SmallVector multiDimLaneId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); + delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder); SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); @@ -402,7 +402,7 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, multiDimLaneId[axis] = i32_val(0); threadsPerWarp[axis] = 1; Value laneIdParallel = - linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, order); + linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, threadOrder); multiDimWarpId[axis] = i32_val(0); warpsPerCTA[axis] = 1; Value warpIdParallel = diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index cc6d8875b5c7..8cac1efbff8b 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -70,29 +70,12 @@ Type TritonGPUToLLVMTypeConverter::convertTritonPointerType( return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); } -Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( - TensorOrMemDesc type) { - auto ctx = type.getContext(); - Attribute layout = type.getEncoding(); - Type elemTy = convertType(type.getElementType()); - auto dotOpLayout = mlir::dyn_cast(layout); - if (!dotOpLayout) - return elemTy; - auto mmaParent = - mlir::dyn_cast(dotOpLayout.getParent()); - if (!mmaParent || mmaParent.isHopper()) - return elemTy; - int bitwidth = elemTy.getIntOrFloatBitWidth(); - assert(bitwidth <= 32); - return IntegerType::get(ctx, 32); -} - Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( RankedTensorType type, const TargetInfoBase &targetInfo) { auto ctx = type.getContext(); Attribute layout = type.getEncoding(); SmallVector shape(type.getShape().begin(), type.getShape().end()); - Type eltType = getElementTypeForStruct(cast(type)); + Type eltType = convertType(type.getElementType()); if (auto shared_layout = mlir::dyn_cast(layout)) { SmallVector types; diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index e857dd36f6cb..b5ab3601ea17 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -404,7 +404,8 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, Location loc, RewriterBase &rewriter, - const TargetInfoBase &target) { + const TargetInfoBase &target, + std::pair *const llvmOpCount) { bool success = emitTransferBetweenRegistersAndShared( srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase, dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { @@ -418,7 +419,12 @@ void 
storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, store(vec, vecAddr) .setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); + if (llvmOpCount) { + ++(llvmOpCount->first); + llvmOpCount->second = vecTy; + } }); + if (!success) llvm::report_fatal_error("Failed to emit transfer from register to shared"); } @@ -856,5 +862,49 @@ SmallVector getWrappedMultiDimOffset( return multiDimOffsetWrapped; } +SmallVector convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc, + ArrayRef values) { + SmallVector results; + for (auto v : values) { + auto em0 = and_(v, i8_val(0x70)); + auto em1 = and_(v, i8_val(0x7)); + Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)), + shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8))); + Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)), + shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12))); + + // Three cases: + // 1) x is normal and non-zero: Correct bias + v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)), + add(v0, i16_val((127 - 1) << 7)), v0); + v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)), + add(v1, i16_val((127 - 1) << 7)), v1); + + // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in + // bf16 + v0 = bitcast(select(icmp_eq(em0, i8_val(0x10)), + or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0), + bf16_ty); + v1 = bitcast(select(icmp_eq(em1, i8_val(0x1)), + or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1), + bf16_ty); + // 3) x is zero, nothing to do + results.push_back(v0); + results.push_back(v1); + } + return results; +} + +Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, + Value scale) { + Value vBf16 = bitcast(v, bf16_ty); + Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); + Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); + Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty); + Value scaledBf16 = fmul(vBf16, scaleBf16); + // Account for NaN in the scale as per the mxfp specification. + return select(scaleIsNan, nanBf16, scaledBf16); +}; + } // namespace LLVM } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 297a94e851f6..8ba0fd3356f6 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -336,7 +336,6 @@ struct BroadcastOpConversion unsigned rank = srcTy.getRank(); auto typeConverter = getTypeConverter(); assert(rank == resultTy.getRank()); - auto order = triton::gpu::getOrder(srcLayout); auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); SmallVector srcVals = unpackLLElements(loc, src, rewriter); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 71506ecbb9f0..3338638d48b5 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -217,7 +217,7 @@ bool isExpensiveView(Type srcType, Type dstType) { return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType); } -/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr. +/* Utility function used by get.*Order methods of SliceEncodingAttr. * Erase dim and decrease all values larger than dim by 1. 
* Example: order = [0, 2, 4, 3, 1], dim = 2 * resOrder = [0, 3, 2, 1] @@ -235,6 +235,19 @@ static SmallVector eraseOrder(ArrayRef order, return resOrder; } +SmallVector getMatrixOrder(unsigned rank, bool rowMajor) { + // Return the order that represents that the batch is in row-major or + // column-major order for a batch of matrices of shape [*, m, n] with + // len(shape) == rank. + assert(rank >= 2); + SmallVector order(rank); + std::iota(order.rbegin(), order.rend(), 0); + if (!rowMajor) { + std::swap(order[0], order[1]); + } + return order; +} + SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, bool kMajor) { // kMajor: if true, the matrix is fastest-running on k, @@ -244,42 +257,28 @@ SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, // batch (if rank == 3) is always the slowest running dimension assert(rank == 2 || rank == 3); assert(opIdx == 0 || opIdx == 1); - SmallVector order(rank); - std::iota(order.rbegin(), order.rend(), 0); - // If opIdx is 1 and kMajor is true, the order is [0, 1] - // (resp. [1, 2, 0] if rank == 3) - // Same if opIdx is 0 and kMajor is false - if (bool(opIdx) == kMajor) { - std::swap(order[0], order[1]); - } - return order; + auto rowMajor = bool(opIdx) != kMajor; + return getMatrixOrder(rank, rowMajor); +} + +SmallVector getRepOrder(Attribute layout) { + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getRepOrder(); + else + llvm::report_fatal_error("Unimplemented usage of getRepOrder"); + return {}; } SmallVector getWarpOrder(Attribute layout) { - if (auto dotLayout = dyn_cast(layout)) { - if (isa(dotLayout.getParent())) { - return getWarpOrder(dotLayout.getParent()); - } - } - auto order = getOrder(layout); - // FIXME: This mmaLayout if should just return - // getOrderForDotOperand(0, order.size(), kMajor=false) - // as mma has the same order as DotOperand(opIdx=0) - if (auto mmaLayout = dyn_cast(layout)) { - if (mmaLayout.isHopper()) { - // Hopper MMA instructions force a warp order of [0, 1]. 
See docs: - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8 - auto it = std::find(order.begin(), order.end(), 0); - order.erase(it); - order.insert(order.begin(), 0); - } - } else if (auto dotOpLayout = dyn_cast(layout)) { - order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(), - /*kMajor*/ false); - } - return order; + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getWarpOrder(); + else + llvm::report_fatal_error("Unimplemented usage of getThreadOrder"); + return {}; } +// Returns the order of the elements in a layout from the fastest running +// dimension to the slowest SmallVector getOrder(Attribute layout) { if (auto blockedLayout = dyn_cast(layout)) { return llvm::to_vector(blockedLayout.getOrder()); @@ -287,9 +286,7 @@ SmallVector getOrder(Attribute layout) { if (auto mmaLayout = dyn_cast(layout)) { auto distributedLayout = cast(layout); auto rank = distributedLayout.getWarpsPerCTA().size(); - SmallVector order(rank); - std::iota(order.rbegin(), order.rend(), 0); - return order; + return getMatrixOrder(rank, /*rowMajor*/ true); } if (auto dotLayout = dyn_cast(layout)) { auto rank = dotLayout.getWarpsPerCTA().size(); @@ -311,7 +308,7 @@ SmallVector getOrder(Attribute layout) { llvm::report_fatal_error("Unimplemented usage of getOrder"); return {}; -}; +} SmallVector getThreadOrder(Attribute layout) { if (auto distributedLayout = mlir::dyn_cast(layout)) @@ -319,7 +316,7 @@ SmallVector getThreadOrder(Attribute layout) { else llvm::report_fatal_error("Unimplemented usage of getThreadOrder"); return {}; -}; +} CTALayoutAttr getCTALayout(Attribute layout) { if (auto distributedLayout = @@ -421,7 +418,7 @@ unsigned getNumWarpsPerCTA(Attribute layout) { else if (auto wmmaLayout = dyn_cast(layout)) warpsPerCTA = wmmaLayout.getWarpsPerCTA(); else if (auto dotLayout = dyn_cast(layout)) - return getNumWarpsPerCTA(dotLayout.getParent()); + warpsPerCTA = dotLayout.getWarpsPerCTA(); else if (auto sharedLayout = dyn_cast(layout)) llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr"); else @@ -654,6 +651,9 @@ unsigned BlockedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, // If we only had BlockedEncodingAttr, we could simply return ArrayRefs here. // But we need to have a consistent interface with e.g. SliceEncodingAttr, which // computes some of these fields. 
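// Editorial aside, not part of the patch: worked examples for the helpers
// introduced above, following the code exactly.
//   getMatrixOrder(2, /*rowMajor=*/true)   -> [1, 0]     (n fastest, m slowest)
//   getMatrixOrder(3, /*rowMajor=*/true)   -> [2, 1, 0]  (batch slowest)
//   getMatrixOrder(3, /*rowMajor=*/false)  -> [1, 2, 0]  (m fastest, batch slowest)
//   getOrderForDotOperand(/*opIdx=*/0, 2, /*kMajor=*/true) -> [1, 0]  (A[m,k], k fastest)
//   getOrderForDotOperand(/*opIdx=*/1, 2, /*kMajor=*/true) -> [0, 1]  (B[k,n], k fastest)
// With the NvidiaMmaEncodingAttr::getWarpOrder change further down, Hopper maps
// to getMatrixOrder(rank, /*rowMajor=*/false) = [0, 1], matching the wgmma
// constraint the removed special case enforced, while Ampere keeps [1, 0].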
+SmallVector BlockedEncodingAttr::getRepOrder() const { + return SmallVector(getOrder()); +} SmallVector BlockedEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -720,6 +720,10 @@ unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { return product(getElemsPerThread(shape, eltTy)); } +SmallVector SliceEncodingAttr::getRepOrder() const { + auto parentRepOrder = ::getRepOrder(getParent()); + return eraseOrder(parentRepOrder, getDim()); +} SmallVector SliceEncodingAttr::getCTASplitNum() const { SmallVector res = ::getCTASplitNum(getParent()); res.erase(res.begin() + getDim()); @@ -762,7 +766,8 @@ SmallVector SliceEncodingAttr::getWarpsPerCTA() const { return warpsPerCTA; } SmallVector SliceEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + auto parentWarpOrder = ::getWarpOrder(getParent()); + return eraseOrder(parentWarpOrder, getDim()); } SmallVector SliceEncodingAttr::getThreadsPerWarp() const { auto parent = getParent(); @@ -774,7 +779,8 @@ SmallVector SliceEncodingAttr::getThreadsPerWarp() const { return threadsPerWarp; } SmallVector SliceEncodingAttr::getThreadOrder() const { - return ::getOrder(*this); + auto parentThreadOrder = ::getThreadOrder(getParent()); + return eraseOrder(parentThreadOrder, getDim()); } SmallVector SliceEncodingAttr::getSizePerThread() const { auto sizePerThread = ::getSizePerThread(getParent()); @@ -907,36 +913,6 @@ NvidiaMmaEncodingAttr::getElemsPerThread(ArrayRef shape, return elemsPerThread; } -unsigned NvidiaMmaEncodingAttr::getElemsPerThreadOfOperand( - int opIdx, ArrayRef shape) const { - size_t rank = shape.size(); - assert(rank == 2 && "Unexpected rank of mma layout"); - auto shapePerCTA = getShapePerCTA(*this, shape); - int res = 0; - if (isVolta()) { - llvm_unreachable( - "getElemsPerThreadOfOperand() not supported for version 1"); - } else if (isAmpere()) { - llvm_unreachable( - "getElemsPerThreadOfOperand() not supported for version 2"); - } else if (isHopper()) { - auto wpt = getWarpsPerCTA(); - auto instrMNK = getInstrShape(); - if (opIdx == 0) { - int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); - int repK = ceil(shapePerCTA[1], instrMNK[2]); - return 8 * repM * repK; - - } else if (opIdx == 1) { - int repK = ceil(shapePerCTA[0], instrMNK[2]); - int repN = ceil(shapePerCTA[1], instrMNK[1] * wpt[1]); - // benzh@ here need more check - return 4 * std::max(instrMNK[1] / 32, 1) * repK * repN; - } - } - return res; -} - unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { return product(getElemsPerThread(shape, eltTy)); @@ -959,25 +935,41 @@ unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, SmallVector DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, Type eltTy) const { + auto rank = shape.size(); + assert(rank == 2 || rank == 3); - if (auto parent = mlir::dyn_cast(getParent())) { - auto rank = shape.size(); - assert(rank == 2 || rank == 3); - - auto idx = getOpIdx(); - assert(idx == 0 || idx == 1); - - SmallVector elemsPerThread(rank); + auto idx = getOpIdx(); + assert(idx == 0 || idx == 1); - auto kWidth = getKWidth(); - auto rep = parent.getRepForOperand(shape, kWidth, idx); + SmallVector elemsPerThread(rank); + auto parent = getParent(); + auto kWidth = getKWidth(); + if (auto mfma = mlir::dyn_cast(parent)) { + auto rep = mfma.getRepForOperand(shape, kWidth, idx); if (rank == 3) elemsPerThread[0] = rep[0]; elemsPerThread[rank - 2] = (idx == 0) ? 
rep[1] : rep[1] * kWidth; elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2]; - return elemsPerThread; + } else if (auto mma = mlir::dyn_cast(parent)) { + if (mma.isAmpere()) { + auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth(); + auto rep = mma.getRepForOperand(shape, bitwidth, idx); + auto sizePerThread = getSizePerThread(); + auto elemsPerKRep = 32 / bitwidth * 2; + if (rank == 3) + elemsPerThread[0] = rep[0]; + elemsPerThread[rank - 2] = + (idx == 0) + ? rep[1] * sizePerThread[rank - 2] + : std::max(rep[1] * elemsPerKRep, sizePerThread[rank - 2]); + elemsPerThread[rank - 1] = + (idx == 0) + ? std::max(rep[2] * elemsPerKRep, sizePerThread[rank - 1]) + : rep[2] * sizePerThread[rank - 1]; + return elemsPerThread; + } } llvm_unreachable("getElemsPerThread is not supported for dot operand"); @@ -987,6 +979,10 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { if (auto mmaParent = mlir::dyn_cast(getParent())) { + if (auto nvidiaMmaParent = mlir::dyn_cast(mmaParent); + nvidiaMmaParent && nvidiaMmaParent.isAmpere()) { + return product(getElemsPerThread(shape, eltTy)); + } return mmaParent.getTotalElemsPerThreadForOperand(shape, eltTy, getKWidth(), getOpIdx()); } @@ -1042,7 +1038,14 @@ SmallVector DotOperandEncodingAttr::getWarpsPerCTA() const { return warps; } SmallVector DotOperandEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + // FIXME(Lezcano): Preexisting. Do we want to have this path at all? + if (mlir::isa(getParent())) { + return ::getWarpOrder(getParent()); + } + // It's quite weird to talk about warp order when that the warps + // are broadcasted along the K dimension + llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented"); + return {}; } SmallVector DotOperandEncodingAttr::getThreadOrder() const { return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), @@ -1074,13 +1077,18 @@ LogicalResult DotOperandEncodingAttr::verify( return emitError() << "triton_gpu.dot_op parent paramenter cannot be null"; } if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth != 0 && !parentAttr.isAmpere()) + if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) return emitError() << "triton_gpu.dot_op kWidth parameter can only be " - "non-zero for Ampere MMA parent"; - if (kWidth == 0 && parentAttr.isAmpere()) + "non-zero for Ampere or Hopper MMA parent"; + if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) return emitError() << "triton_gpu.dot_op kWidth parameter is mandatory for " - "Ampere MMA parent"; + "Ampere or Hopper MMA parent"; + if (opIdx != 0 && parentAttr.isHopper()) + return emitError() + << "triton_gpu.dot_op opIdx parameter must be 0 for " + "Hopper MMA parent, since Hopper WGMMA only allows first " + "operand to be in registers"; return success(); } @@ -1585,7 +1593,7 @@ SmallVector AMDMfmaEncodingAttr::getWarpsPerCTA() const { return SmallVector(getWarpsPerCTA__()); } SmallVector AMDMfmaEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + return ::getOrder(*this); } SmallVector AMDMfmaEncodingAttr::getThreadOrder() const { auto order = ::getOrder(*this); @@ -1658,6 +1666,10 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const { return {kDim, nDim}; } +SmallVector AMDMfmaEncodingAttr::getRepOrder() const { + llvm::report_fatal_error("NYI. 
AMDMfmaEncodingAttr::getRepOrder"); +} + SmallVector AMDMfmaEncodingAttr::getRepForOperand(ArrayRef operandShape, int kWidth, int opIdx) const { @@ -1741,6 +1753,9 @@ AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { shapePerCTATile[rank - 1] *= mnkDim[1]; return shapePerCTATile; } +SmallVector AMDWmmaEncodingAttr::getRepOrder() const { + llvm::report_fatal_error("NYI. AMDWmmaEncodingAttr::getRepOrder"); +} SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -1754,7 +1769,7 @@ SmallVector AMDWmmaEncodingAttr::getWarpsPerCTA() const { return SmallVector(getWarpsPerCTA__()); } SmallVector AMDWmmaEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + return ::getOrder(*this); } SmallVector AMDWmmaEncodingAttr::getThreadOrder() const { return ::getOrder(*this); @@ -1865,6 +1880,10 @@ bool NvidiaMmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; } bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; } +SmallVector NvidiaMmaEncodingAttr::getRepOrder() const { + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} SmallVector NvidiaMmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -1878,7 +1897,11 @@ SmallVector NvidiaMmaEncodingAttr::getWarpsPerCTA() const { return SmallVector(getWarpsPerCTA__()); } SmallVector NvidiaMmaEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + auto rank = getWarpsPerCTA().size(); + // Hopper (wgmma) uses column-major as this is embeded in the instruction + // For Ampere we can choose either row-major or column-major. + // We choose row-major as the legacy path did so + return getMatrixOrder(rank, /*rowMajor*/ !isHopper()); } SmallVector NvidiaMmaEncodingAttr::getThreadsPerWarp() const { auto rank = getWarpsPerCTA().size(); @@ -1902,10 +1925,11 @@ SmallVector NvidiaMmaEncodingAttr::getThreadsPerWarp() const { "getThreadsPerWarp not implemented for unknown Mma version "); } SmallVector NvidiaMmaEncodingAttr::getThreadOrder() const { - return ::getOrder(*this); + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); } SmallVector NvidiaMmaEncodingAttr::getSizePerThread() const { - auto rank = ::getOrder(*this).size(); + auto rank = getWarpsPerCTA().size(); SmallVector res(rank, 1); if (isAmpere()) { res[rank - 2] = 2; @@ -2013,31 +2037,43 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const { int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { return 2 * getMMAv1Rep(opIdx)[opIdx]; } -SmallVector NvidiaMmaEncodingAttr::getMMAv2RepForOperand( - ArrayRef shape, int bitwidth, int kWidth, int opIdx) const { + +SmallVector +NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); +} + +SmallVector +NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, + int opIdx) const { auto rank = shape.size(); auto warpsPerCTA = getWarpsPerCTA(); + // {batch, m, n, k} + // Hopper path never uses the n value, since this method is only invoked + // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF SmallVector shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth}; int numRepBatch = rank == 3 ? 
std::max(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])) : 1; - assert(isAmpere()); - if (opIdx == 0) + if (opIdx == 0) { return {numRepBatch, - std::max(1, shape[rank - 2] / + std::max(1, /*repM=*/shape[rank - 2] / (shapePerWarp[1] * warpsPerCTA[rank - 2])), - std::max(1, shape[rank - 1] / shapePerWarp[3])}; - else { + std::max(1, /*repK=*/shape[rank - 1] / shapePerWarp[3])}; + } else { assert(opIdx == 1); - return {numRepBatch, - std::max(1, shape[rank - 2] / shapePerWarp[3]), - std::max(1, shape[rank - 1] / (shapePerWarp[2] * - warpsPerCTA[rank - 1]))}; + return { + numRepBatch, + std::max(1, /*repK=*/shape[rank - 2] / shapePerWarp[3]), + std::max(1, /*repN=*/shape[rank - 1] / + (shapePerWarp[2] * warpsPerCTA[rank - 1]))}; } } + unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { auto shapePerCTA = getShapePerCTA(*this, shape); @@ -2045,16 +2081,13 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand( int warpsPerCTAN = getWarpsPerCTA()[1]; // H100 if (isHopper()) { - return getTotalElemsPerThread(shape, eltTy); - } - // A100 - if (isAmpere()) { - auto rep = getMMAv2RepForOperand(shapePerCTA, eltTy.getIntOrFloatBitWidth(), - kWidth, opIdx); - if (opIdx == 0) - return 4 * rep[0] * rep[1] * rep[2]; - if (opIdx == 1) - return 4 * rep[0] * rep[1] * std::max(rep[2] / 2, 1); + assert(opIdx == 0); + auto instrMNK = getInstrShape(); + int repM = ceil(shapePerCTA[0], instrMNK[0] * warpsPerCTAM); + int repK = ceil(shapePerCTA[1], instrMNK[2]); + // For each WGMMA instr, a 2x2 matrix fragment is loaded. Each thread holds + // kWidth elements for each quadrant. WGMMA is repeated repM * repK times. + return 4 * kWidth * repM * repK; } // V100 if (isVolta()) { @@ -2121,25 +2154,12 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand( SmallVector NvidiaMmaEncodingAttr::getShapePerCTATileForOperand( ArrayRef shape, int kWidth, int opIdx) const { assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); - auto parentShapePerCTATile = getShapePerCTATile(shape); - auto rank = parentShapePerCTATile.size(); + auto shapePerCTATile = getShapePerCTATile(shape); + auto rank = shapePerCTATile.size(); + auto kDim = opIdx == 0 ? 
rank - 1 : rank - 2; // 4 threads * 2 subtiles - unsigned kWidthTile = kWidth * 2 * 4; - if (opIdx == 0) { - if (rank == 2) - return {parentShapePerCTATile[rank - 2], kWidthTile}; - else - return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], - kWidthTile}; - } else if (opIdx == 1) { - if (rank == 2) - return {kWidthTile, parentShapePerCTATile[rank - 1]}; - else - return {parentShapePerCTATile[0], kWidthTile, - parentShapePerCTATile[rank - 1]}; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - } + shapePerCTATile[kDim] = kWidth * 2 * 4; + return shapePerCTATile; } SmallVector NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { @@ -2149,11 +2169,10 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { if (opIdx == 0) { sizePerThread[rank - 2] = 2; sizePerThread[rank - 1] = 2 * kWidth; - } else if (opIdx == 1) { + } else { + assert(opIdx == 1); sizePerThread[rank - 2] = 2 * kWidth; sizePerThread[rank - 1] = 1; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); } return sizePerThread; } @@ -2161,6 +2180,15 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { //===----------------------------------------------------------------------===// // DotOperand Encoding //===----------------------------------------------------------------------===// +SmallVector DotOperandEncodingAttr::getRepOrder() const { + if (auto mma = mlir::dyn_cast(getParent())) { + return mma.getRepOrderForOperand(getOpIdx()); + } + llvm::report_fatal_error( + "getRepOrder not implemented for DotOperandEncodingAttr"); + return {}; +} + SmallVector DotOperandEncodingAttr::getThreadsPerWarp() const { auto parent = getParent(); if (auto mma = mlir::dyn_cast(parent)) { diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 56af4eaef8b9..43c87af487a1 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -41,6 +41,17 @@ SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { return ret; } +// TODO Have order be a mandatory argument of standardOutDimNames. +SmallVector permuteDimNames(const SmallVector &names, + const SmallVector &order) { + assert(names.size() == order.size()); + SmallVector ret; + for (unsigned i : order) { + ret.push_back(names[i]); + } + return ret; +} + void assertIsRegisterLayout(const LinearLayout &layout) { assert(layout.getNumInDims() > 0); MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); @@ -282,14 +293,18 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef shape, MLIRContext *ctx = mma.getContext(); SmallVector dimNames = standardOutDimNames(ctx, rank); + auto orderedDimNames = permuteDimNames(dimNames, mma.getRepOrder()); + assert(mma.getRepOrder() == getMatrixOrder(rank, /*rowMajor=*/true)); + LinearLayout ctaLayout( {{S("register"), {{1, 0}, {0, 8}}}, {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, - llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); - - ctaLayout *= identityND( - S("warp"), mma.getWarpsPerCTA(), - llvm::to_vector(llvm::reverse(llvm::seq(rank))), dimNames); + ArrayRef(orderedDimNames).take_front(2)); + assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true)); + // FIXME(Lezcano). identityND should not have an `order` param as it's + // redundant with the order of the out dims. 
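  // Illustrative sketch, not part of the original change: assumes the standard
  // PTX mma.m16n8 accumulator fragment. With the out dims ordered (dim1, dim0),
  // the register bases {1, 0} and {0, 8} above step one column and eight rows
  // respectively, while the lane bases give col = 2 * (lane % 4) and
  // row = lane / 4. A thread with lane id t and register index r in 0..3 thus
  // covers:
  //   row = t / 4 + 8 * (r >> 1)
  //   col = 2 * (t % 4) + (r & 1)
  // e.g. lane 5 holds (1, 2), (1, 3), (9, 2), (9, 3) within each 16x8 tile.
  // The identityND call below then replicates this tile across warps,
  // following warpsPerCTA in warpOrder.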
+ ctaLayout *= + identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames); return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); } @@ -322,10 +337,14 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef shape, ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")), S("register"), S("dim1")); - // Expand the `warp` dimension according to warpsPerCTA. - // - // It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but - // this really does seem to be correct. + // The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major. + // Since the warpOrder needs to be M-major, we need to transpose the out + // dimensions AND transpose the order + // FIXME(Lezcano). identityND should not have an `order` param as it's + // redundant. The order is already given by the order of the + // out dims, and if it has an order, it shouldn't change the + // order of the out dims. + assert(getWarpOrder(mma) == SmallVector({0, 1})); ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1}, {S("dim0"), S("dim1")}) .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); @@ -551,8 +570,8 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { } std::optional -dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, - ArrayRef shape) { +mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape) { // Current linear layout conversion for dot operand is only necessary to // enable LDS bypass for operand B in the MFMA dot path. To achieve @@ -843,20 +862,30 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { LinearLayout ampereDotToLinearLayout(ArrayRef shape, DotOperandEncodingAttr dot) { - // TODO,BE. Implement ampereMMA in terms of this one + // Note that, even though MMAv2 looks similar to this layout, they are just + // the same at a register and lane level. The warps treatment is different! int rank = shape.size(); auto mma = cast(dot.getParent()); int kWidth = dot.getKWidth(); bool isA = dot.getOpIdx() == 0; - assert(mma.isAmpere()); assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) || (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8}))); + assert(mma.isAmpere()); MLIRContext *ctx = mma.getContext(); - SmallVector dimNames = standardOutDimNames(ctx, rank); - // Implement A. For B transpose in the end + // The A and B operands are tiled in a kMajor fashion + auto kMajorOrder = dot.getRepOrder(); + assert(kMajorOrder == + getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true)); + + auto kMajorDims = + permuteDimNames(standardOutDimNames(ctx, rank), kMajorOrder); + // This agrees with the order of the elements, which means that we can share + // the code below for both A and B without having to perform any swaps + assert(getOrder(dot) == kMajorOrder); + std::vector> registers; std::vector> lanes; int32_t i = 1; @@ -881,35 +910,60 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, } registers.push_back({i, 0}); - if (!isA) { - for (auto &r : registers) { - std::swap(r[0], r[1]); + LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}}, + ArrayRef(kMajorDims).take_front(2)); + + // Let warpsPerCTAMma = {2, 2}, then + // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB + // assume warpOrder = {0, 1} + // Assume that C is tiled by 2x2 tiles. 
Since warpOrder={1, 0}, we have that + // the C is owned as per the following layout: + // C: 0 | 1 + // - | - + // 2 | 3 + // In order to be able to compute C, we need the following warp tiling of + // A and B: + // A: 0 1 | 0 1 B: 0 2 | 1 3 + // - - | - - - - | - - + // 2 3 | 2 3 0 2 | 1 3 + // In particular, for A and B we need to broadcast along K + + assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true)); + auto warpsPerCTAMma = mma.getWarpsPerCTA(); + std::vector> warps; + if (isA) { + for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) { + warps.push_back({0, 0}); + } + for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) { + warps.push_back({0, i}); + } + } else { + for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) { + warps.push_back({0, i}); } - for (auto &l : lanes) { - std::swap(l[0], l[1]); + for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) { + warps.push_back({0, 0}); + } + } + if (rank == 3) { + for (auto &w : warps) { + w.push_back(0); } } - LinearLayout ctaLayout( - {{S("register"), registers}, {S("lane"), lanes}}, - llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); - - auto order = dot.getCTAOrder(); - assert(order[0] == 1 && order[1] == 0); - ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames); + ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims); - return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); + return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape); } std::optional DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { - if (auto mfmaLayout = llvm::dyn_cast(getParent())) { - return dotOperandMfmaToLinearLayout(*this, shape); - } else if (auto mma = mlir::dyn_cast(getParent())) { - // FIXME [Dot LL] - // Do this unconditionally - auto largeKWidth = getKWidth() == 8; - if (mma.isAmpere() && largeKWidth) { + auto parent = getParent(); + if (auto mfmaLayout = llvm::dyn_cast(parent)) { + return mfmaDotToLinearLayout(*this, shape); + } else if (auto mma = mlir::dyn_cast(parent)) { + if (mma.isAmpere()) { return ampereDotToLinearLayout(shape, *this); } } diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index e61fe096e10b..65c647e16b1e 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -34,13 +34,13 @@ LogicalResult UpcastMXFPOp::verify() { "operands must have the same number of dimensions, at least 2"); } - if (!(fpType == F8F6F4Type::E2M1 || fpType == F8F6F4Type::E4M3 || - fpType == F8F6F4Type::E5M2)) { + if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 || + fpType == ScaleDotElemType::E5M2)) { return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2"); } // Change to support fp8 types - const auto elems_packed = fpType == F8F6F4Type::E2M1 ? 2 : 1; + const auto elems_packed = fpType == ScaleDotElemType::E2M1 ? 
2 : 1; if (xShape.back() != (32 / elems_packed) * scaleShape.back()) { return emitOpError("last dimension of first operand must be 16 times " @@ -52,21 +52,25 @@ LogicalResult UpcastMXFPOp::verify() { "all dimensions except the last must match between operands"); } - auto layoutX = xTy.getEncoding(); - if (!layoutX || !isa(layoutX)) { + auto dotEncoding = + dyn_cast_or_null(xTy.getEncoding()); + if (!dotEncoding) { return emitOpError("Expected a DotOperandEncodingAttr for values"); } - auto layoutScale = scaleTy.getEncoding(); - if (!layoutScale || !isa(layoutScale)) { + + auto blockedScale = + dyn_cast_or_null(scaleTy.getEncoding()); + if (!blockedScale) { return emitOpError("Expected a BlockOperandEncoding for scales"); } - auto blockedScale = cast(layoutScale); - // Necessary to keep all of the scales of a given block of values in the same - // warp - auto threadsPerWarp = blockedScale.getThreadsPerWarp(); - if (threadsPerWarp != ArrayRef({16, 2})) { - return emitOpError("Expected threads per warp to be {16, 2}"); + if (isa(dotEncoding.getParent())) { + // Necessary to keep all of the scales of a given block of values in the + // same warp + auto threadsPerWarp = blockedScale.getThreadsPerWarp(); + if (threadsPerWarp != ArrayRef({16, 2})) { + return emitOpError("Expected threads per warp to be {16, 2}"); + } } return success(); @@ -89,7 +93,7 @@ LogicalResult UpcastMXFPOp::inferReturnTypes( return emitOptionalError(loc, "expected a dotOperand encoding"); } - if (typeEncoded == F8F6F4Type::E2M1) { + if (typeEncoded == ScaleDotElemType::E2M1) { auto oldEncoding = cast(encoding); auto newVEncoding = DotOperandEncodingAttr::get( ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(), diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index a2d4012bf23e..ef9ed531cec5 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -183,7 +183,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { // elements distribution to the order of higher precision primitives. As a // result, kwidth can be the bitwidth of the lower precision primitive. // Conversely, in the downcasting scenario, no reordering is performed, - // making it directory use the lower precision primitive. + // making it directly use the lower precision primitive. 
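  // Illustrative example, not part of the original change (tensor shapes and
  // op choices below are for exposition only): for an upcast chain such as
  //   %a8 = tt.load %ptrs : tensor<64x64x!tt.ptr<f8E5M2>, #blocked>
  //   %x  = tt.fp_to_fp %a8 : tensor<64x64xf8E5M2, #blocked>
  //                           -> tensor<64x64xbf16, #blocked>
  // finalBitWidth is 16 (bf16), while the backward traversal finds the 8-bit
  // source, so origBitWidth becomes 8 and kWidth can be derived from the fp8
  // type rather than from the upcast bf16.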
static int computeOrigBitWidth(Value x) { int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth(); int origBitWidth = finalBitWidth; @@ -415,22 +415,12 @@ class ScaledBlockedToMMAv2 auto aType = dotOp.getLhsType(); auto bType = dotOp.getRhsType(); - auto enumToType = [&rewriter](F8F6F4Type type) { - switch (type) { - case F8F6F4Type::E4M3: - return rewriter.getFloat8E4M3FNType(); - case F8F6F4Type::E5M2: - return rewriter.getFloat8E5M2Type(); - default: - llvm_unreachable("unexpected type"); - } - }; - - assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 || - aType == F8F6F4Type::E2M1) && + assert((aType == ScaleDotElemType::E4M3 || + aType == ScaleDotElemType::E5M2 || + aType == ScaleDotElemType::E2M1) && "NYI: lhs supports fp4 or fp8"); - assert(bType == F8F6F4Type::E4M3 || - bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8"); + assert(bType == ScaleDotElemType::E4M3 || bType == ScaleDotElemType::E5M2 || + bType == ScaleDotElemType::BF16 && "NYI: rhs supports fp8 and bf16"); // TODO run accelerate matmul on A and B first to choose their layouts // Set return type @@ -454,11 +444,12 @@ class ScaledBlockedToMMAv2 auto newAcc = rewriter.create(oldAcc.getLoc(), newRetType, oldAcc); - auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType]( - TypedValue v, int idx, - F8F6F4Type type) -> TypedValue { + auto toMMABf16 = + [&newRetType, &rewriter, + &ctx](TypedValue v, int idx, + ScaleDotElemType type) -> TypedValue { auto vType = v.getType(); - if (type == F8F6F4Type::E2M1) { + if (type == ScaleDotElemType::E2M1) { // A bit too dynamically typed... // perhaps return ints in both cases? @@ -469,23 +460,23 @@ class ScaledBlockedToMMAv2 vType.getShape(), vType.getElementType(), newVEncoding); return rewriter.create(v.getLoc(), newVType, v); } else { - assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3); + assert(type == ScaleDotElemType::E5M2 || + type == ScaleDotElemType::E4M3 || + type == ScaleDotElemType::BF16); auto newVEncoding = DotOperandEncodingAttr::get( ctx, idx, newRetType.getEncoding(), /*kWidth=*/8); auto newVType = RankedTensorType::get( vType.getShape(), vType.getElementType(), newVEncoding); v = rewriter.create(v.getLoc(), newVType, v); - // Bitcast - auto vTypeFp8 = RankedTensorType::get(vType.getShape(), - enumToType(type), newVEncoding); - v = cast>( - rewriter.create(v.getLoc(), vTypeFp8, v).getResult()); - - // Convert to bf16 - auto vTypeBf16 = RankedTensorType::get( - vType.getShape(), rewriter.getBF16Type(), newVEncoding); - return rewriter.create(v.getLoc(), vTypeBf16, v); + if (type == ScaleDotElemType::BF16) { + return v; + } else { + // Convert to bf16 + auto vTypeBf16 = RankedTensorType::get( + vType.getShape(), rewriter.getBF16Type(), newVEncoding); + return rewriter.create(v.getLoc(), vTypeBf16, v); + } } }; a = toMMABf16(a, 0, aType); @@ -515,11 +506,11 @@ class ScaledBlockedToMMAv2 auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, CTALayout); - auto newScaleType = RankedTensorType::get(scale.getType().getShape(), - scale.getType().getElementType(), - newScaleEncoding); - scale = - rewriter.create(scale.getLoc(), newScaleType, scale); + auto newScaleDotElemType = RankedTensorType::get( + scale.getType().getShape(), scale.getType().getElementType(), + newScaleEncoding); + scale = rewriter.create(scale.getLoc(), + newScaleDotElemType, scale); auto scaledA = rewriter.create( dotOp.getLoc(), a, scale, dotOp.getLhsType()); diff --git 
a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 6d8279795209..4695984acfd3 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -286,11 +286,12 @@ struct MMAV3UseRegOperand dstEnc.getVersionMajor() != 3) return failure(); auto srcTy = cast(alloc.getSrc().getType()); + auto kWidth = 32 / srcTy.getElementTypeBitWidth(); auto dotOperandEnc = DotOperandEncodingAttr::get( - dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0); + dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/kWidth); auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), dotOperandEnc); - if (!isMmaToDotShortcut(srcTy, newTy)) + if (!matchMmaV3AndDotOperandLayout(srcTy, newTy)) return failure(); Value newOperand = diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index bf017f8c6463..4319d1f086dd 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -1016,6 +1016,21 @@ bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const { return true; } +LinearLayout LinearLayout::resize(StringAttr inDim, + int32_t newInDimSize) const { + BasesT bases = getBases(); + assert(bases.contains(inDim) && "inDim not in layout"); + assert(llvm::isPowerOf2_32(newInDimSize) && + "newInDimSize must be a power of 2"); + assert(newInDimSize >= getInDimSize(inDim) && + "newInDimSize must be >= old size"); + auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim); + for (int i = 0; i < numFreeVariables; i++) { + bases[inDim].push_back(std::vector(getNumOutDims(), 0)); + } + return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames())); +} + std::string LinearLayout::toString() const { // Start with a newline because we print out a bulleted list; it doesn't // make sense for the first line of this list to be on the same line as diff --git a/python/src/ir.cc b/python/src/ir.cc index 9945c6188294..cce7c87e8d87 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -205,12 +205,13 @@ void init_triton_ir(py::module &&m) { .value("IEEE", InputPrecision::IEEE) .export_values(); - py::enum_(m, "F8F6F4TY", py::module_local()) - .value("E4M3", F8F6F4Type::E4M3) - .value("E5M2", F8F6F4Type::E5M2) - .value("E2M3", F8F6F4Type::E2M3) - .value("E3M2", F8F6F4Type::E3M2) - .value("E2M1", F8F6F4Type::E2M1) + py::enum_(m, "ScaleDotElemTypeTY", py::module_local()) + .value("E4M3", ScaleDotElemType::E4M3) + .value("E5M2", ScaleDotElemType::E5M2) + .value("E2M3", ScaleDotElemType::E2M3) + .value("E3M2", ScaleDotElemType::E3M2) + .value("E2M1", ScaleDotElemType::E2M1) + .value("BF16", ScaleDotElemType::BF16) .export_values(); py::class_(m, "context", py::module_local()) @@ -1423,9 +1424,9 @@ void init_triton_ir(py::module &&m) { }) .def("create_dot_scaled", [](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale, - F8F6F4Type lhs_format, mlir::Value &rhs, - std::optional &rhs_scale, F8F6F4Type rhs_format, - mlir::Value &c) -> mlir::Value { + ScaleDotElemType lhs_format, mlir::Value &rhs, + std::optional &rhs_scale, + ScaleDotElemType rhs_format, mlir::Value &c) -> mlir::Value { return self.create( c.getType(), lhs, rhs, c, lhs_scale, rhs_scale.value_or(Value()), lhs_format, rhs_format); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 569ed16bbd8b..7499391c9f2c 100644 --- a/python/test/unit/language/test_core.py +++ 
b/python/test/unit/language/test_core.py @@ -28,6 +28,7 @@ is_cuda, is_interpreter, is_hip, + is_hip_mi200, get_arch, torch_float8_dtypes, torch_dtypes, @@ -139,6 +140,17 @@ def __str__(self): return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" +class DotOperandLayout: + + def __init__(self, parent, op_idx, k_width): + self.parent = parent + self.op_idx = op_idx + self.k_width = k_width + + def __str__(self): + return f"#{GPU_DIALECT}.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>" + + class BlockedLayout: def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): @@ -1720,6 +1732,7 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): assert torch.all(output == ref) +<<<<<<< HEAD @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -3326,25 +3339,30 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx -@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps", - [(M, N, K, col_a, col_b, type_a, type_b, 4) +@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack", + [(M, N, K, col_a, col_b, type_a, type_b, 4, mma, kpack) for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128]) for col_a, col_b in itertools.product([True, False], repeat=2) for type_a in ["e2m1", "e4m3", "e5m2"] - for type_b in ["e4m3", "e5m2"]]) -def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device): - if not is_cuda(): - pytest.skip("scaled_dot only supported on CUDA") - else: + for type_b in ["e4m3", "e5m2", "bf16"] + for mma in ([32, 16] if is_hip() else [16]) + for kpack in ([1, 2] if is_hip() else [1])]) +def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device): + if is_cuda(): cc = torch.cuda.get_device_capability() if cc < (8, 9): pytest.skip("float8e4nv not supported on CUDA < 8.9") + if is_hip(): + if (type_a not in ["e2m1", "e5m2"]) or (type_b not in ["e2m1", "e5m2", "bf16"]): + pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP") + if mma == 16 and K == 64: + pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot") @triton.jit def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr, type_b: tl.constexpr): - tl.static_assert(type_b == "e4m3" or type_b == "e5m2", "type_b must be fp8") + tl.static_assert((type_b == "e4m3" or type_b == "e5m2") or type_b == "bf16", "type_b must be fp8 or bf16") IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2" DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2 PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR @@ -3435,7 +3453,7 @@ def mxfp_to_bf16_kernel( def dot_scale_ref(x, scale, y, type_x, type_y): e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x] - type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y] + type_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type_y] comp_dtype = torch.bfloat16 @@ -3448,7 +3466,7 @@ def dot_scale_ref(x, scale, y, type_x, type_y): 
mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps) assert x_upcast.isfinite().all() - y_upcast = y.view(type_fp8_y).to(comp_dtype) + y_upcast = y.view(type_y).to(comp_dtype) class AccumulateInFp32: @@ -3460,28 +3478,30 @@ def __exit__(self, exc_type, exc_val, exc_tb): torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value with AccumulateInFp32(): - return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype)) + return torch.matmul(x_upcast, y_upcast) torch.manual_seed(0) - def create_uint8(shape, col_major=False, max_val=255): + def make_arg(shape, ty, col_major=False, max_val=255): if col_major: shape = shape[:-2] + (shape[-1], shape[-2]) - ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) + if ty == "bf16": + ret = torch.randn(shape, dtype=torch.bfloat16, device=device) + # Clamp to avoid relative error issues + ret.clamp_(-2**15, 2**15 - 1) + else: + ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) if col_major: ret = ret.mT return ret DIV_FACTOR = 2 if type_a == "e2m1" else 1 - x = create_uint8((M, K // DIV_FACTOR), col_major=col_a) - y = create_uint8((K, N), col_major=col_b) + x = make_arg((M, K // DIV_FACTOR), type_a, col_major=col_a) + y = make_arg((K, N), type_b, col_major=col_b) # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright) - # We substract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow - m_bytes = int(type_a[1]) - bias_type_a = 1 << (m_bytes - 1) - 1 - max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a - scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64) + # Max scale= 2**15 + scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15) def make_finite(x, dtype): # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and @@ -3497,22 +3517,31 @@ def make_finite(x, dtype): x = make_finite(x, type_a) y = make_finite(y, type_b) + kernel_kwargs = {"num_warps": num_warps} + if is_hip(): + kernel_kwargs["kpack"] = kpack + kernel_kwargs["matrix_instr_nonkdim"] = mma z = x.new_empty((M, N), dtype=torch.bfloat16) - pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, - num_warps=num_warps) + pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, **kernel_kwargs) z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b) - # generous rtol as we are sampling the whole range of floats - torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2) + # Bigger tolerance for AMD MI200 devices. + # MI200 devices use reduced precision fp16 and bf16 and flush input and output denormal values + # to zero. 
Detailed info is at: + # https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + atol = 2e-4 if is_hip_mi200() else 1e-5 + rtol = 2e-2 if is_hip_mi200() else 1e-2 + torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) # make sure ld/st are vectorized - ptx = pgm.asm['ptx'] - if (max(M, N) * K) // (num_warps * 32) >= 4: - assert 'ld.global.v4' in ptx - if M * N // (num_warps * 32) >= 4: - assert 'st.global.v4' in ptx - assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) + if is_cuda(): + ptx = pgm.asm['ptx'] + if (max(M, N) * K) // (num_warps * 32) >= 4: + assert 'ld.global.v4' in ptx + if M * N // (num_warps * 32) >= 4: + assert 'st.global.v4' in ptx + assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) @pytest.mark.interpreter @@ -5173,6 +5202,14 @@ def kernel(Out): BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), ] @@ -5210,6 +5247,10 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape): def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): if str(src_layout) == str(dst_layout): pytest.skip() + if (isinstance(src_layout, DotOperandLayout) + and isinstance(interm_layout, SharedLayout)) or (isinstance(dst_layout, DotOperandLayout) + and isinstance(interm_layout, SharedLayout)): + pytest.skip("DotOperandLayout <-> SharedLayout conversion is not completely supported") if is_hip(): try: scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N)) diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index f8909f7c0587..d8b76b6ea79a 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -40,6 +40,11 @@ def is_hip(): return False if target is None else target.backend == "hip" +def is_hip_mi200(): + target = get_current_target() + return target.backend == 'hip' and target.arch == 'gfx90a' + + def get_arch(): target = get_current_target() return "" if target is None else str(target.arch) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e2c57b388bb0..856b537c5103 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1555,15 +1555,17 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, lhs and rhs use 
microscaling formats described here: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf :param lhs: The first tensor to be multiplied. - :type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :type lhs: 2D tensor representing fp4 or fp8 elements packed into uint8 for fp4 inputs, or in uint8 or the corresponding fp8 type for fp8 inputs. :param lhs_scale: Scale factor for lhs tensor. - :type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). - :param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :type lhs_scale: e8m0 type represented as an uint8 tensor. + :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code: `e5m2`}. + :type lhs_format: str :param rhs: The second tensor to be multiplied. - :type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :type rhs: 2D tensor representing fp8 or bf16 elements in uint8 or the corresponding fp8 type for fp8 inputs or bf16 for bf16 inputs. :param rhs_scale: Scale factor for rhs tensor. - :type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). - :param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :type rhs_scale: e8m0 type represented as an uint8 tensor. + :param rhs_format: format of the rhs tensor. Available formats: {:code:`e4m3`, :code: `e5m2`, :code:`bf16`}. + :type rhs_format: str :param acc: The accumulator tensor. If not None, the result is added to this tensor. """ out_dtype = _constexpr_to_value(out_dtype) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 8e9f87b5ed12..fe20aa436ad8 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1527,33 +1527,48 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona ret_ty) -def _str_to_fp_type(float_format: Optional[str]): - if float_format == 'e4m3': - return ir.F8F6F4TY.E4M3 - if float_format == 'e5m2': - return ir.F8F6F4TY.E5M2 - if float_format == 'e2m3': - return ir.F8F6F4TY.E2M3 - if float_format == 'e3m2': - return ir.F8F6F4TY.E3M2 - if float_format == 'e2m1': - return ir.F8F6F4TY.E2M1 - raise ValueError(f"Invalid float format: {float_format}.") - - -def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], - rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: +def _str_to_fp_type(float_format: str): + ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None) + if ty_enum is None: + raise ValueError(f"Invalid float format: {float_format}.") + return ty_enum + + +def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder): + """ + If float_format is subbyte, make sure it's packed as uint8 and return it. + Otherwise, return a tensor (perhaps bitcasting) of the specified float format. + """ + triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16}.get(float_format) + if triton_ty is None: + assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" + assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. 
Got {val.dtype}" + return val + if val.dtype == triton_ty: + return val + else: + unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16}[float_format] + assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}" + return bitcast(val, triton_ty, builder) + + +def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], + rhs_format: str, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() #TODO: validate types. lhs_rank = len(lhs.shape) rhs_rank = len(rhs.shape) assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + lhs_format: str = lhs_format.value + rhs_format: str = rhs_format.value lhs_format_enum = _str_to_fp_type(lhs_format) rhs_format_enum = _str_to_fp_type(rhs_format) assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}" - assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}" + assert rhs_format in ("e4m3", "e5m2", "bf16"), f"NYI: rhs_format {rhs_format}" rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None assert rhs_scale_is_none, "NYI: rhs_scale not supported" + lhs = _bitcast_to_fp_type(lhs, lhs_format, builder) + rhs = _bitcast_to_fp_type(rhs, rhs_format, builder) M = lhs.type.shape[-2] K, N = rhs.type.shape[-2:] diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index a0719c974f9c..208f6b80bfe5 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -6,7 +6,7 @@ #A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> #B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 2054853b30c1..df4f5ab01feb 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -5,7 +5,7 @@ #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> diff --git a/test/Conversion/amd/load_store.mlir b/test/Conversion/amd/load_store.mlir index 93796439b012..543ed4f2df12 100644 --- a/test/Conversion/amd/load_store.mlir +++ b/test/Conversion/amd/load_store.mlir @@ -27,3 +27,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.return } } + +// ----- + +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +module attributes 
{"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: global_store_mfma_vec16 + tt.func public @global_store_mfma_vec16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma> + %1 = math.exp2 %0 : tensor<32x32xf32, #mma> + %2 = arith.truncf %1 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %c32_i32 = arith.constant 32 : i32 + %100 = tt.get_program_id x : i32 + %101 = arith.muli %100, %c32_i32 : i32 + %102 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %300 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma> + %200 = tt.broadcast %300 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma> + %103 = tt.splat %101 : i32 -> tensor<32x32xi32, #mma> + %104 = arith.addi %103, %200 : tensor<32x32xi32, #mma> + %105 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + %106 = tt.addptr %105, %104 : tensor<32x32x!tt.ptr, #mma>, tensor<32x32xi32, #mma> + // Store 16 elements with four vectorized store instruction + // CHECK-COUNT-4: llvm.intr.masked.store {{.*}}, {{.*}}, {{.*}} {alignment = 16 : i32} : vector<4xf16>, vector<4xi1> into !llvm.ptr + tt.store %106, %2 : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 34573f7739b8..cd5e5bc6363f 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -821,6 +821,110 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<1x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1x16xf16, #mma> -> tensor<1x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#blocked = 
#triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#slice = #triton_gpu.slice<{dim = 0, parent = #mma}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_slice_mmav2_blocked_reg + tt.func @convert_layout_slice_mmav2_blocked_reg(%arg0: tensor<1xf16, #slice>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1xf16, #slice> -> tensor<1xf16, #blocked> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_0 + tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_1 + tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_2 + tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_3 + tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + tt.return + } +} + +// ----- + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 16]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { @@ -845,6 +949,80 @@ module 
attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_0 + tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_1 + tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_2 + tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_3 + tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + tt.return + } +} + +// ----- + #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { @@ -1267,9 
+1445,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.func @matmul_tf32_cst_b(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a: tensor<32x16xf32, #dot_operand_a>, %c: tensor<32x32xf32, #mma>) { // CHECK: %[[CST:.+]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 - // CHECK: %[[BC:.+]] = llvm.bitcast %[[CST]] : f32 to i32 - // CHECK: %[[SI:.+]] = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> - // CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %[[BC:.+]] = llvm.bitcast %[[CST]] : f32 to f32 + // CHECK: %[[SI:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> %b_mat = arith.constant dense<1.000000e+00> : tensor<16x32xf32, #dot_operand_b> %28 = tt.dot %a, %b_mat, %c, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> %38 = triton_gpu.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> @@ -1288,16 +1466,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: matmul_f16_cst_operands tt.func public @matmul_f16_cst_operands(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - // CHECK: %[[C1f:.+]] = llvm.mlir.constant(1.000000e+00 : f16) : f16 - // CHECK: %[[Ci16:.+]] = llvm.bitcast %[[C1f]] : f16 to i16 - // CHECK: %[[U:.+]] = llvm.mlir.undef : vector<2xi16> + // CHECK: %[[U:.+]] = llvm.mlir.undef : vector<2xf16> // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[V0:.+]] = llvm.insertelement %[[Ci16]], %[[U]][%[[C0]] : i32] : vector<2xi16> + // CHECK: %[[V0:.+]] = llvm.insertelement %{{.*}}, %[[U]][%[[C0]] : i32] : vector<2xf16> // CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[V1:.+]] = llvm.insertelement %[[Ci16]], %[[V0]][%[[C1]] : i32] : vector<2xi16> - // CHECK: %[[BC:.+]] = llvm.bitcast %[[V1]] : vector<2xi16> to i32 - // CHECK: %[[SU:.+]] = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> - // CHECK: llvm.insertvalue %[[BC]], %[[SU]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %[[V1:.+]] = llvm.insertelement %{{.*}}, %[[V0]][%[[C1]] : i32] : vector<2xf16> + // CHECK: %[[BC:.+]] = llvm.bitcast %[[V1]] : vector<2xf16> to i32 %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %cst_1 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %cst_2 = arith.constant dense<32> : tensor<32x1xi32, #blocked> diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 113ec3cf6651..65ab0194a9af 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -97,9 +97,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %m = triton_nvidia_gpu.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: - tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return } } @@ -114,10 +114,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Generate a wgmma where the first operand is a struct. // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} - tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { + tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1> %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } : - tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> + tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> + tt.return + } +} +// +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + tt.func @dot_reg_operand_upcast(%a_desc: !tt.memdesc<128x64xi8, #shared>, %b: !tt.memdesc<64x64xf16, #shared>, %acc: tensor<128x64xf32, #mma>) { + %a_dotop = triton_gpu.local_load %a_desc : !tt.memdesc<128x64xi8, #shared> -> tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + 
%a_casted = arith.sitofp %a_dotop : tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %res = triton_nvidia_gpu.warp_group_dot %a_casted, %b, %acc : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return } } @@ -193,7 +207,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: prmt.b32 // CHECK: prmt.b32 tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) { - %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 85b37f3ed3a9..420a9d5c2cbf 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -164,21 +164,21 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -// Verify that dot_scaled (mxfp8 x fp8) decomposes as expected +// Verify that dot_scaled (mxfp4 x bf16) decomposes as expected #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: dot_scaled tt.func @dot_scaled( - %a: tensor<128x64xi8, #blocked2>, + %a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, - %b: tensor<64x128xi8, #blocked>) + %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> { // CHECK: triton_gpu.upcast_mxfp // CHECK: tt.dot %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %result = tt.dot_scaled %a, %scale, %b, %cst lhs = e4m3 rhs = e4m3 : tensor<128x64xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xi8, #blocked> -> tensor<128x128xf32, #blocked> + %result = tt.dot_scaled %a, %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> tt.return %result : tensor<128x128xf32, #blocked> } } diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir new file mode 100644 index 000000000000..400c219b6790 --- /dev/null +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -0,0 +1,103 @@ +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' 
-tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=ck_v3' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_CKV3_GLOBAL_LOAD +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2 + +module { + // INSERT_IGLP0-LABEL: @test_dot_op + // INSERT_IGLP1-LABEL: @test_dot_op + // INSTR_COUNT_NS1-LABEL: @test_dot_op + // INSTR_COUNT_NS2-LABEL: @test_dot_op + // LABELING_PS_1-LABEL: @test_dot_op + // LABELING_PS_2-LABEL: @test_dot_op + tt.func @test_dot_op(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}, + %C : !tt.ptr {tt.divisibility = 16 : i32}) { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32> -> tensor<128x32xi32> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr>, tensor<128x32xi32> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : 
tensor<128xi32> -> tensor<1x128xi32> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32> -> tensor<32x128xi32> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr>, tensor<32x128xi32> + + %a_mask = arith.constant dense : tensor<128x32xi1> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16> + %b_mask = arith.constant dense : tensor<32x128xi1> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32> + + %a_off = arith.constant dense<4> : tensor<128x32xi32> + %b_off = arith.constant dense<4> : tensor<32x128xi32> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr>, tensor<32x128x!tt.ptr>, tensor<128x128xf32>) { + %a = tt.load %a_ptr : tensor<128x32x!tt.ptr> + %b = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr> + + // INSERT_IGLP0: rocdl.iglp.opt 0 + // INSERT_IGLP1: rocdl.iglp.opt 1 + + // INSTR_COUNT_NS1: amdgpu.instruction_sched_hint + // INSTR_COUNT_NS1-SAME: isBufferLoadsAEnabled = false + // INSTR_COUNT_NS1-SAME: isBufferLoadsBEnabled = false + // INSTR_COUNT_NS1-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> + // INSTR_COUNT_NS1-SAME: numDsWritesA = #amdgpu.InstCounter<0, none> + // INSTR_COUNT_NS1-SAME: numDsWritesB = #amdgpu.InstCounter<0, none> + // INSTR_COUNT_NS1-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> + + // INSTR_COUNT_NS2: amdgpu.instruction_sched_hint + // INSTR_COUNT_NS2-SAME: isBufferLoadsAEnabled = false + // INSTR_COUNT_NS2-SAME: isBufferLoadsBEnabled = false + // INSTR_COUNT_NS2-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> + // INSTR_COUNT_NS2-SAME: numDsWritesA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> + + // USE_CKV3_GLOBAL_LOAD: [lower-insert-instruction-sched-hints] + // USE_CKV3_GLOBAL_LOAD-SAME: Skipping instruction scheduling because `ck_v3` scheduling can be used only with `buffer_load` instructions. 
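    // For orientation, the INSTR_COUNT_NS1 prefixes above together match one hint op of
    // roughly this shape (attribute order as printed may differ):
    //   amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false,
    //     numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>>, numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>>,
    //     numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>,
    //     numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>>, numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>,
    //     numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>}
    // The INSTR_COUNT_NS2 run differs only in the ds-write counters (4 x vector<4xf16> each),
    // since the two-stage pipeline stages the tiles through LDS, as the LABELING_PS_2
    // local_store checks below show.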
+ + // LABELING_PS_1: scf.for + // LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_1: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} + // LABELING_PS_1: %[[REG1_OP0:.+]] = triton_gpu.convert_layout %[[REG0_OP0]] + // LABELING_PS_1: %[[REG1_OP1:.+]] = triton_gpu.convert_layout %[[REG0_OP1]] + // LABELING_PS_1: tt.dot %[[REG1_OP0]], %[[REG1_OP1]], {{.*}} + + // LABELING_PS_2: scf.for + // LABELING_PS_2: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_2: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} + // LABELING_PS_2: triton_gpu.local_store %[[REG0_OP0]], %{{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_2: triton_gpu.local_store %[[REG0_OP1]], %{{.*}} {OpIdx = #amdgpu.OpIdx<1>} + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr>, tensor<128x32xi32> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr>, tensor<32x128xi32> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr>, tensor<32x128x!tt.ptr>, tensor<128x128xf32> + } + + // C ptrs + %c_ptr_splat = tt.splat %C : !tt.ptr -> tensor<128x128x!tt.ptr> + %c_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32> + %c_tmp1 = tt.expand_dims %c_tmp0 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %c_offs = tt.broadcast %c_tmp1 : tensor<1x128xi32> -> tensor<128x128xi32> + %c_ptr = tt.addptr %c_ptr_splat, %c_offs : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + + tt.store %c_ptr, %loop#2 : tensor<128x128x!tt.ptr> + tt.return +} +} diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 82fc1ddf7b65..2bdc4436713e 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -164,8 +164,8 @@ tt.func @update_kwidth_slice( #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A -// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> +// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdesc<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !tt.memdesc<128x64xf16, #shared1> %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> @@ -180,8 +180,8 @@ tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdes #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], 
hasLeadingOffset = true}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A_fp8 -// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> +// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !tt.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !tt.memdesc<128x64xf8E5M2, #shared1> %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> diff --git a/test/TritonGPU/invalid-attributes.mlir b/test/TritonGPU/invalid-attributes.mlir index c8b3c2ef6b0b..26a7c0773b9f 100644 --- a/test/TritonGPU/invalid-attributes.mlir +++ b/test/TritonGPU/invalid-attributes.mlir @@ -2,7 +2,7 @@ // expected-error@+2 {{triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: 2}} #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 2, parent = #blocked}> +#dot_op = #triton_gpu.dot_op<{opIdx = 2, parent = #blocked, kWidth = 2}> // ----- @@ -12,19 +12,25 @@ // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere MMA parent}} +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere or Hopper MMA parent}} #mma = #triton_gpu.nvidia_mma<{versionMajor = 1, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere MMA parent}} +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere MMA parent}} +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}> + +// ----- + +// expected-error@+2 {{triton_gpu.dot_op opIdx parameter must be 0 for Hopper MMA parent, since Hopper WGMMA only allows first operand to be in registers}} #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, 
warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index d391be688c23..2c2182154d6a 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -398,8 +398,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> %23 = tt.trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> + %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked> } @@ -481,7 +481,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c0_i64 = arith.constant 0 : i64 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 @@ -519,7 +519,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, 
tensor<128x16xf32, #mma1> } @@ -624,7 +624,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c0_i64 = arith.constant 0 : i64 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 @@ -685,7 +685,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // This dot can be async even though %prev_dot2 is not used directly by an // async dot, because that use follows the synchronous dot above. %prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> - %dot2 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %dot2 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %dot2, %26, %dot1.1, %dot0 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir index 74fd2e05551b..a7064ea82204 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -113,7 +113,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %115 = triton_nvidia_gpu.warp_group_dot %113, %114, %cst :!tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> %116 = arith.truncf %115 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> %117 = triton_gpu.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !tt.memdesc<64x128xf16, #shared> - %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // The first dot gets converted to dot-async + wait. The second one // doesn't have a wait because the first wait is sufficient. 
// CHECK: triton_nvidia_gpu.warp_group_dot @@ -121,7 +121,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: triton_nvidia_gpu.warp_group_dot // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait // CHECK: scf.yield - %119 = triton_nvidia_gpu.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> + %119 = triton_nvidia_gpu.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %121 = arith.addf %120, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %122 = arith.extsi %c0_i32 : i32 to i64 diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 390d1c83e61d..8669f5e04707 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -274,7 +274,7 @@ def make_llir(src, metadata, options): passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) - amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant) + amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.num_stages, options.instruction_sched_variant) if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": passes.llvmir.add_di_scope(pm) amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td index 31a43acd2f89..c0aa08421bdd 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -32,4 +32,31 @@ class TritonAMDGPU_Attr traits = [], : AttrDef { } +def TritonAMDGPU_OpIdxAttr : TritonAMDGPU_Attr<"OpIdx"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "OpIdx"; + let summary = "An operand index attribute."; + let description = [{ + The attribute is a way to describe which input argument of the target + operation (e.g., `tt.dot`) the result of a given operation belongs to. + }]; + + let parameters = (ins "uint32_t":$value); + let assemblyFormat = "`<` $value `>`"; +} + +def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "InstCounter"; + let summary = "An instruction counter attribute."; + let description = [{ + The attribute holds the number of issued LLVM instructions of a specific kind as well as + the data type. 
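    For example, a counter recording 8 issued `ds_read` instructions that each read a
    `vector<4xf16>` prints as `#amdgpu.InstCounter<8, vector<4xf16>>`, which is the form the
    new amd-instruction-sched.mlir test checks for.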
+ }]; + + let parameters = (ins "uint32_t":$value, "Type":$type); + let assemblyFormat = "`<` params `>`"; +} + + #endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td index d5956cf7a33c..c0c18b07e907 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td @@ -35,6 +35,9 @@ def TritonAMDGPU_Dialect : Dialect { }]; let dependentDialects = []; + + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; } #endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 538e31378fe8..68c50d48635b 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -57,7 +57,29 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { interleave for better instruction level parallelism. }]; - let assemblyFormat = [{attr-dict}]; + let arguments = (ins + TritonAMDGPU_InstCounter:$numDsReadsA, + TritonAMDGPU_InstCounter:$numDsReadsB, + TritonAMDGPU_InstCounter:$numDsWritesA, + TritonAMDGPU_InstCounter:$numDsWritesB, + TritonAMDGPU_InstCounter:$numGlobalLoadsA, + TritonAMDGPU_InstCounter:$numGlobalLoadsB, + BoolAttr:$isBufferLoadsAEnabled, + BoolAttr:$isBufferLoadsBEnabled, + TritonAMDGPU_InstCounter:$numMMAs + ); + + let builders = [ + OpBuilder<(ins), [{ + auto ctx = $_state.getContext(); + auto noneType = NoneType::get(ctx); + auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, noneType); + build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr, + emptyAttr, emptyAttr, false, false, emptyAttr); + }]> + ]; + + let assemblyFormat = [{ attr-dict }]; } // diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h index bd726bd845d2..4036cdecd1bd 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -36,9 +36,10 @@ createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz); std::unique_ptr> createConvertBuiltinFuncToLLVMPass(bool ftz); std::unique_ptr> -createInsertInstructionSchedHintsPass(); +createTritonAMDGPUInsertInstructionSchedHintsPass(); std::unique_ptr> -createLowerInstructionSchedHintsPass(std::string variant); +createTritonAMDGPULowerInstructionSchedHintsPass(int32_t numStages, + std::string variant); #define GEN_PASS_REGISTRATION #include "TritonAMDGPUToLLVM/Passes.h.inc" diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index 9f4665aef217..0c1ccee76d77 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -59,20 +59,25 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul ]; } -def InsertInstructionSchedHints : Pass<"insert-instruction-sched-hints", "mlir::ModuleOp"> { +def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruction-sched-hints", "mlir::ModuleOp"> { let summary = "Insert instruction scheduling hints after the dot ops in the main loop"; - let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()"; + let constructor = 
"mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass()"; - let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect"]; } -def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { +def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { let summary = "Lower instruction scheduling hints to LLVM intrinsics"; - let constructor = "mlir::triton::createLowerInstructionSchedHintsPass(\"\")"; + let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*numStages=*/2, /*variant=*/\"\")"; - let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::ROCDL::ROCDLDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect"]; let options = [ + Option<"numStages", "num_stages", "int32_t", /*default*/"2", + "number of pipeline stages">, Option<"variant", "variant", "std::string", /*default*/"\"default\"", "instruction scheduling variant">, ]; diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 433e60be67f6..93345b0d6de4 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -13,7 +13,7 @@ def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir let constructor = "mlir::createTritonAMDGPUStreamPipelineV2Pass()"; - let dependentDialects = []; + let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; let options = [ Option<"numStages", "num_stages", diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index a82a77e9f57e..1e429fdc39a9 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -24,6 +24,9 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" + +#include "llvm/ADT/TypeSwitch.h" // clang-format off #include "Dialect/TritonAMDGPU/IR/Dialect.h" @@ -45,5 +48,8 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { >(); } +#define GET_ATTRDEF_CLASSES +#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc" + #define GET_OP_CLASSES #include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index b6a514f450cc..abd86dc03301 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -20,6 +20,7 @@ add_triton_library(TritonAMDGPUToLLVM OptimizeLDSUtility.cpp SPMDOpToLLVM.cpp SchedInstructions.cpp + UpcastMXFPToLLVM.cpp DEPENDS TritonAMDGPUConversionPassIncGen diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index b832d985bbe7..9043090802bf 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -21,6 +21,7 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "SharedToDotOperandHelper.h" #include "Utility.h" @@ -336,6 +337,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int elemsPerLoad = numOfElems / loadsPerThread; assert(numOfElems % loadsPerThread == 0); + VectorType loadVecTy = vec_ty(elemTy, elemsPerLoad); for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; Value batchOffset = mul(i32_val(operandSize), @@ -346,7 +348,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, for (int k = 0; k < numRepK; ++k) { auto vecTy = vec_ty(resElemTy, numOfElems); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { - auto loadVecTy = vec_ty(elemTy, elemsPerLoad); Value loadOffset; loadOffset = offsets[nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId]; @@ -363,6 +364,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, } } + for (auto op : tensor.getUsers()) { + if (auto localLoadOp = llvm::dyn_cast(op)) { + const size_t numDsReadsCount = + repB * numRepNonK * numRepK * loadsPerThread; + setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy); + } + } + MLIRContext *ctx = mfmaLayout.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(loadedValues.size(), loadedValues[0].getType())); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index b60c86e1a3a5..1ca9e49745d6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -21,6 +21,7 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "SharedToDotOperandHelper.h" #include "Utility.h" @@ -212,6 +213,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int loadsPerThread = offsets.size() / (numRepNonK * numRepK); int elemsPerLoad = numElemsPerThreadPerRep / loadsPerThread; assert(numElemsPerThreadPerRep % loadsPerThread == 0); + auto loadVecTy = vec_ty(elemTy, elemsPerLoad); for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; Value batchOffset = mul(i32_val(operandSize), @@ -221,7 +223,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto vecTy = vec_ty(resElemTy, numElemsPerThreadPerRep); Value valVec = undef(vecTy); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { - auto loadVecTy = vec_ty(elemTy, elemsPerLoad); Value loadOffset = offsets[nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId]; loadOffset = add(loadOffset, batchOffset); @@ -237,6 +238,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, } } + for (auto op : tensor.getUsers()) { + if (auto localLoadOp = llvm::dyn_cast(op)) { + const size_t numDsReadsCount = + repB * numRepNonK * numRepK * loadsPerThread; + setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy); + } + } + MLIRContext *ctx = wmmaLayout.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(loadedValues.size(), loadedValues[0].getType())); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp index 204d54894d3b..1eed112c30c0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -21,9 +21,9 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "TritonAMDGPUTransforms/MfmaGroup.h" #include "Utility.h" - #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" using namespace mlir; @@ -261,6 +261,14 @@ struct DotOpMFMAConversionHelper { Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(fc.size(), dstElemTy)); Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); + + Type elemtTy = elemTyA; + const size_t mmaCount = + numRepB * numRepM * numRepN * numRepK * kWidth / kBase; + setNumGeneratedMMAs(op, mmaCount, maybeMfmaInsn->getMDim(), + maybeMfmaInsn->getNDim(), maybeMfmaInsn->getKDim(), + elemtTy); + rewriter.replaceOp(op, res); return success(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp index 5a003f768833..0042cf89e93b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -22,6 +22,7 @@ */ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "Utility.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -325,6 +326,10 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, Type structTy = LLVM::LLVMStructType::getLiteral( wmmaLayout.getContext(), SmallVector(fc.size(), dstElemTy)); Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); + + const size_t mmaCount = numRepB * numRepM * numRepN * numRepK; + setNumGeneratedMMAs(op, mmaCount, mnkDim[0], mnkDim[1], mnkDim[2], elemTy); + rewriter.replaceOp(op, res); return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index a45efd4a7971..d8bc29d30dd1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1,6 +1,7 @@ #include "BufferOpsEmitter.h" #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "PatternTritonGPUOpToLLVM.h" +#include "SchedInstructions.h" #include "TargetInfo.h" #include "Utility.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -39,15 +40,27 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, auto sizePerThread = triton::gpu::getSizePerThread(layout); auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout); - auto order = triton::gpu::getOrder(layout); + auto threadOrder = triton::gpu::getThreadOrder(layout); + SmallVector warpOrder(rank); + if (auto enc = dyn_cast(layout)) { + warpOrder = + triton::gpu::getMatrixOrder(rank, /*rowMajor=*/enc.getOpIdx() == 1); + } else { + warpOrder = triton::gpu::getWarpOrder(layout); + } auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); Value warpSize = i32_val(triton::gpu::getWarpSize(layout)); Value laneId = urem(tid, warpSize); Value warpId = udiv(tid, warpSize); + // TODO: [DOT LL] + // The delinearize function is not entirely correct for certain layouts, + // such as wgmma. The correct approach is to convert a legacy layout to its + // corresponding linear layout and use the linear layout's + // getFreeVariableMasks to identify redundant elements. 
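  // A minimal sketch of that direction (assuming the existing toLinearLayout and
  // LinearLayout::getFreeVariableMasks helpers; illustration only, not wired into this path):
  //   if (auto ll = triton::gpu::toLinearLayout(shape, layout)) {
  //     auto freeMasks = ll->getFreeVariableMasks();
  //     // Thread-id bits flagged in the "lane"/"warp" entries do not change the
  //     // accessed element, so they mark redundant lanes/warps to mask off.
  //   }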
SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, order); + delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); SmallVector multiDimThreadId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); + delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder); for (unsigned dim = 0; dim < rank; ++dim) { // if there is no data replication across threads on this dimension if (shape[dim] >= shapePerCTATile[dim]) @@ -276,6 +289,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, auto cacheMod = op.getCache(); SmallVector loadedVals; + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { const size_t maxWordWidth = std::max(32, valueElemNBits); const size_t totalWidth = valueElemNBits * vec; @@ -286,7 +300,6 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, assert(wordNElems * nWords * numVecs == numElems); Value pred = mask ? maskElems[vecStart] : int_val(1, 1); - auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]); Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); @@ -309,6 +322,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, rewriter, llvmResultStructTy); + + setNumGeneratedGlobalLoads(op, numVecs, vecTy); + rewriter.replaceOp(op, {resultStruct}); return success(); } @@ -391,6 +407,10 @@ struct BufferLoadOpConversion Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, rewriter, llvmResultStructTy); + + const int numVecs = numElems / vec; + setNumGeneratedGlobalLoads(op, numVecs, vecTy); + rewriter.replaceOp(op, {resultStruct}); return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index 764f31a610e1..1fdf3bdaa1cd 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -34,6 +34,11 @@ void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); +void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit); + } // namespace mlir::triton::AMD #endif diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index 9bed87961966..62ef7a164337 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -1,87 +1,157 @@ +#include "SchedInstructions.h" #include "TritonAMDGPUToLLVM/Passes.h" - +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Pass/Pass.h" -#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" namespace mlir::triton { -#define GEN_PASS_DEF_INSERTINSTRUCTIONSCHEDHINTS -#define GEN_PASS_DEF_LOWERINSTRUCTIONSCHEDHINTS +#define GEN_PASS_DEF_TRITONAMDGPUINSERTINSTRUCTIONSCHEDHINTS 
+#define GEN_PASS_DEF_TRITONAMDGPULOWERINSTRUCTIONSCHEDHINTS #include "TritonAMDGPUToLLVM/Passes.h.inc" } // namespace mlir::triton +#undef DEBUG_TYPE +#define DEBUG_TYPE "lower-insert-instruction-sched-hints" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + using namespace mlir; -namespace { +// TODO: The following passes/algorithms are applicable only for a single +// `tt.dot` op in a `scf.for` block -i.e., a single schedule hint op per block. +// Note, we need to relax this assumption in the future and extend the current +// implementation. -// The bitmask that encodes kinds of the instructions from AMD ISA. -// The bitmask is used for providing instruction scheduling hints. -enum InstructionKindMask { - NONE = 0x0000000, - ALL_ALU = 0x00000001, - VALU = 0x00000002, - SALU = 0x00000004, - MFMA = 0x00000008, - ALL_VMEM = 0x00000010, - VMEM_READ = 0x00000020, - VMEM_WRITE = 0x00000040, - ALL_DS = 0x00000080, - DS_READ = 0x00000100, - DS_WRITE = 0x00000200 -}; +namespace mlir::triton { +void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, + unsigned k, Type elementType) { + auto *ctx = op->getContext(); + auto mmaType = RankedTensorType::get({m, n, k}, elementType); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, mmaCount, mmaType); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + schedHint.setNumMMAsAttr(counterAttr); + }); +} + +template +void setNumGeneratedGlobalLoads(LoadOpType op, size_t globalLoadsCount, + Type type) { + MLIRContext *ctx = op->getContext(); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, globalLoadsCount, type); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + if (auto opIdxAttr = op->template getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + assert(opIdxAttr.getValue() < 2); + const bool isBufferLoadOp = + std::is_same_v; + if (opIdxAttr.getValue() == 0) { + schedHint.setNumGlobalLoadsAAttr(counterAttr); + schedHint.setIsBufferLoadsAEnabled(isBufferLoadOp); + } else { + schedHint.setNumGlobalLoadsBAttr(counterAttr); + schedHint.setIsBufferLoadsBEnabled(isBufferLoadOp); + } + } + }); +} +template void setNumGeneratedGlobalLoads(triton::amdgpu::BufferLoadOp op, + size_t globalLoadsCount, Type type); +template void setNumGeneratedGlobalLoads(triton::LoadOp op, + size_t globalLoadsCount, Type type); + +void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t dsReadsCount, + Type type) { + auto *ctx = op->getContext(); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, dsReadsCount, type); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + Value dst = op.getResult(); + auto dstTensorTy = cast(dst.getType()); + auto dotOperandLayout = + cast(dstTensorTy.getEncoding()); + const size_t opIdx = dotOperandLayout.getOpIdx(); + assert(opIdx < 2); + if (opIdx == 0) + schedHint.setNumDsReadsAAttr(counterAttr); + else + schedHint.setNumDsReadsBAttr(counterAttr); + }); +} + +void storeOpConversionCallback(triton::gpu::LocalStoreOp op, + size_t localStoreOpCount, Type type) { + MLIRContext *ctx = op->getContext(); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, localStoreOpCount, type); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + if (auto opIdxAttr = op->getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + assert(opIdxAttr.getValue() < 2); + if (opIdxAttr.getValue() == 0) + 
schedHint.setNumDsWritesAAttr(counterAttr); + else + schedHint.setNumDsWritesBAttr(counterAttr); + } + }); +} + +triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp) { + triton::DotOp dotOp = nullptr; + size_t dotCounter = 0; + forOp->walk( + [&dotOp, &dotCounter](triton::DotOp op) { dotOp = op, ++dotCounter; }); + + return (dotCounter == 1) ? dotOp : nullptr; +} +} // namespace mlir::triton + +namespace { // Create an intrinsic to control how different instruction kinds should // interleave for better ILP. void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc, - InstructionKindMask maskValue, int sizeValue, - int groupIdValue) { - MLIRContext *ctx = rewriter.getContext(); - const char *intrinsicName = "llvm.amdgcn.sched.group.barrier"; - - Value mask = - LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); - Value size = - LLVM::createConstantI32(loc, rewriter, static_cast(sizeValue)); - Value groupId = LLVM::createConstantI32(loc, rewriter, - static_cast(groupIdValue)); - - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, TypeRange{}, - ValueRange{mask, size, groupId}); + mlir::amdgpu::sched_barrier_opt_enum maskValue, + int sizeValue, int groupIdValue) { + IntegerAttr mask = + rewriter.getI32IntegerAttr(static_cast(maskValue)); + IntegerAttr size = + rewriter.getI32IntegerAttr(static_cast(sizeValue)); + IntegerAttr groupId = + rewriter.getI32IntegerAttr(static_cast(groupIdValue)); + rewriter.create(loc, mask, size, groupId); } // Insert intrinsic that controls the types of instructions that may be -// allowed to cross the intrinsic during instruction scheduling +// allowed to cross the intrinsic during instruction scheduling. Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc, - int64_t maskValue) { - MLIRContext *ctx = rewriter.getContext(); - const char *intrinsicName = "llvm.amdgcn.sched.barrier"; - LLVM::FastmathFlagsAttr defaultFlags{}; - - Value mask = - LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); - return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, - TypeRange{}, ValueRange{mask}); + mlir::amdgpu::sched_barrier_opt_enum maskValue) { + IntegerAttr mask = + rewriter.getI32IntegerAttr(static_cast(maskValue)); + return rewriter.create(loc, mask); } // Insert an experimental intrinsic for instruction group level parallelism. // The intrinsic takes a value that specifies the strategy. 
Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) { - MLIRContext *ctx = rewriter.getContext(); - const char *intrinsicName = "llvm.amdgcn.iglp.opt"; - LLVM::FastmathFlagsAttr defaultFlags{}; - Value iglpValue = - LLVM::createConstantI32(loc, rewriter, static_cast(value)); - return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, - TypeRange{}, ValueRange{iglpValue}); + IntegerAttr iglpValue = + rewriter.getI32IntegerAttr(static_cast(value)); + return rewriter.create(loc, iglpValue); } struct InstructionSchedHintsRewriter : public OpRewritePattern { - InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant) - : OpRewritePattern(ctx) { + InstructionSchedHintsRewriter(MLIRContext *ctx, int32_t numStages, + std::string variant) + : OpRewritePattern(ctx), numStages(numStages) { std::transform(variant.begin(), variant.end(), variant.begin(), [](unsigned char c) { return std::tolower(c); }); @@ -89,20 +159,162 @@ struct InstructionSchedHintsRewriter .Case("default", SchedulingType::NONE) .Case("iglp0", SchedulingType::IGLP0) .Case("iglp1", SchedulingType::IGLP1) + .Case("ck_v3", SchedulingType::CK_V3) .Default(SchedulingType::UNKNOWN); + + if (this->numStages < 2) { + this->schedulingType = SchedulingType::NONE; + LDBG("ignoring instruction scheduling due to a very low num. " + "stages value. Must be >= 2"); + } } - enum class SchedulingType : uint32_t { NONE = 0, IGLP0, IGLP1, UNKNOWN }; + enum class SchedulingType : uint32_t { + NONE = 0, + IGLP0, + IGLP1, + CK_V3, + UNKNOWN + }; + + // This is the implementation of the CK's V3 pipelining (see + // see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp). + // This scheduling requires 1x register and 1x LDS buffers combined with the + // local (LDS to registers) and global (HBM to registers) data prefetching. + // see: + // include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.h + void + createCKV3Schedule(PatternRewriter &rewriter, Location loc, + triton::amdgpu::InstructionSchedHint schedHint) const { + + if (!(schedHint.getIsBufferLoadsAEnabled() && + schedHint.getIsBufferLoadsBEnabled())) { + LDBG("Skipping instruction scheduling because `ck_v3` " + "scheduling can be used only with `buffer_load` instructions."); + return; + } + + const uint32_t numDsReadInstA = schedHint.getNumDsReadsA().getValue(); + const uint32_t numDsReadInstB = schedHint.getNumDsReadsB().getValue(); + + const uint32_t numDsWriteInstA = schedHint.getNumDsWritesA().getValue(); + const uint32_t numDsWriteInstB = schedHint.getNumDsWritesB().getValue(); + + const uint32_t numBufferLoadInstA = + schedHint.getNumGlobalLoadsA().getValue(); + const uint32_t numBufferLoadInstB = + schedHint.getNumGlobalLoadsB().getValue(); + + if (numBufferLoadInstA == 0) + schedHint.emitError("buffer load count for tile A must be initialized"); + + if (numBufferLoadInstB == 0) + schedHint.emitError("buffer load count for tile B must be initialized"); + + const uint32_t numMfmaInst = schedHint.getNumMMAs().getValue(); + + auto mfmaType = cast(schedHint.getNumMMAs().getType()); + const uint32_t nPerXDL = mfmaType.getShape()[1]; + const uint32_t mfmaCycle = nPerXDL == 16 ? 16 : 32; + + auto dsReadsAType = cast(schedHint.getNumDsReadsA().getType()); + auto dsReadsBType = cast(schedHint.getNumDsReadsB().getType()); + + const uint32_t dsReadAIssueCycle = dsReadsAType.getShape()[0] == 16 ? 8 : 4; + const uint32_t dsReadBIssueCycle = dsReadsBType.getShape()[0] == 16 ? 
8 : 4; + + const auto dsReadAMfmaRate = + (mfmaCycle - 4 + 2 * dsReadAIssueCycle - 1) / (2 * dsReadAIssueCycle); + const auto dsReadBMfmaRate = + (mfmaCycle - 4 + 2 * dsReadBIssueCycle - 1) / (2 * dsReadBIssueCycle); + + const auto numDsreadAMfma = + (numDsReadInstA + dsReadAMfmaRate - 1) / dsReadAMfmaRate; + const auto numDsreadBMfma = + (numDsReadInstB + dsReadBMfmaRate - 1) / dsReadBMfmaRate; + + // stage 1 + const auto numMfmaStage1 = numMfmaInst - (numDsreadAMfma + numDsreadBMfma); + const auto numMfmaPerIssue = + numMfmaStage1 / (numBufferLoadInstA + numBufferLoadInstB); + + const auto numDswritePerIssueA = numDsWriteInstA / numBufferLoadInstA; + const auto numDswritePerIssueB = numDsWriteInstB / numBufferLoadInstB; + + for (size_t i = 0; i < numBufferLoadInstA; ++i) { + for (size_t idswrite = 0; idswrite < numDswritePerIssueA; ++idswrite) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_write, + 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + 1, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + numMfmaPerIssue - numDswritePerIssueA, 0); + } + + for (size_t i = 0; i < numBufferLoadInstB; ++i) { + for (size_t idswrite = 0; idswrite < numDswritePerIssueB; ++idswrite) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_write, + 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + 1, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + numMfmaPerIssue - numDswritePerIssueB, 0); + } + + // stage 2 + for (size_t i = 0; i < numDsreadAMfma; ++i) { + if ((numDsReadInstA - (i + 1) * dsReadAMfmaRate) >= dsReadAMfmaRate) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_read, + dsReadAMfmaRate, 0); + } else { + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::ds_read, + numDsReadInstA - (numDsreadAMfma - 1) * dsReadAMfmaRate, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0); + } + + for (size_t i = 0; i < numDsreadBMfma; ++i) { + if ((numDsReadInstB - (i + 1) * dsReadBMfmaRate) >= dsReadBMfmaRate) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_read, + dsReadBMfmaRate, 0); + } else { + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::ds_read, + numDsReadInstB - (numDsreadBMfma - 1) * dsReadBMfmaRate, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0); + } + } LogicalResult matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, PatternRewriter &rewriter) const override { + if (this->schedulingType == SchedulingType::NONE) { + rewriter.eraseOp(instructionSchedHint); + return success(); + } if (this->schedulingType == SchedulingType::UNKNOWN) { - llvm::dbgs() - << "[" << getDebugName() << "]: " - << "unknown instruction scheduling variant has been provided\n"; - return mlir::failure(); + instructionSchedHint.emitError( + "unknown instruction scheduling variant has been provided"); + return failure(); } // The switch controls whether instructions are allowed to cross the 
basic @@ -110,13 +322,15 @@ struct InstructionSchedHintsRewriter // not supposed to be used together with IGLP OPT according to the AMDGPU // backend documentation. const bool limitSchedulingRange = - !(schedulingType == SchedulingType::IGLP0 || + !(schedulingType == SchedulingType::NONE || + schedulingType == SchedulingType::IGLP0 || schedulingType == SchedulingType::IGLP1); Location loc = instructionSchedHint->getLoc(); Block *block = instructionSchedHint->getBlock(); if (limitSchedulingRange) { rewriter.setInsertionPointToStart(block); - createSchedBarrier(rewriter, loc, InstructionKindMask::NONE); + createSchedBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::none); } rewriter.setInsertionPoint(block, std::prev(block->end())); @@ -128,6 +342,10 @@ struct InstructionSchedHintsRewriter createIglpOpt(rewriter, loc, static_cast(schedulingType) - 1); break; } + case SchedulingType::CK_V3: { + createCKV3Schedule(rewriter, loc, instructionSchedHint); + break; + } case SchedulingType::NONE: [[fallthrough]]; default: { @@ -136,21 +354,25 @@ struct InstructionSchedHintsRewriter } if (limitSchedulingRange) - createSchedBarrier(rewriter, loc, InstructionKindMask::NONE); + createSchedBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::none); rewriter.eraseOp(instructionSchedHint); - return mlir::success(); + return success(); } private: + int32_t numStages; SchedulingType schedulingType; }; -struct LowerInstructionSchedHints - : public triton::impl::LowerInstructionSchedHintsBase< - LowerInstructionSchedHints> { +struct TritonAMDGPULowerInstructionSchedHints + : public triton::impl::TritonAMDGPULowerInstructionSchedHintsBase< + TritonAMDGPULowerInstructionSchedHints> { - explicit LowerInstructionSchedHints(std::string variant) { + explicit TritonAMDGPULowerInstructionSchedHints(int32_t numStages, + std::string variant) { + this->numStages = numStages; this->variant = variant; } @@ -161,29 +383,40 @@ struct LowerInstructionSchedHints ConversionTarget target(*ctx); target.addLegalDialect(); target.addIllegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); RewritePatternSet patterns(ctx); - patterns.add(ctx, this->variant); + + patterns.add(ctx, this->numStages, + + this->variant); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { + signalPassFailure(); } } }; -struct InsertInstructionSchedHints - : public triton::impl::InsertInstructionSchedHintsBase< - InsertInstructionSchedHints> { +struct TritonAMDGPUInsertInstructionSchedHints + : public triton::impl::TritonAMDGPUInsertInstructionSchedHintsBase< + TritonAMDGPUInsertInstructionSchedHints> { + void runOnOperation() override { MLIRContext *ctx = &getContext(); ModuleOp mod = getOperation(); - mod->walk([ctx](triton::DotOp dot) { - if (dyn_cast(dot->getParentOp())) { - mlir::OpBuilder rewriter(ctx); - rewriter.setInsertionPointAfter(dot); - rewriter.create(dot->getLoc()); + mod.walk([this, ctx](scf::ForOp forOp) { + // Note, instruction schedule barriers are inserted only in the case of + // a single `tt.dot` op in a `scf::ForOp` scope in the current + // implementation. 
+ if (auto dotOp = getSingleDotOpIfExists(forOp)) { + OpBuilder rewriter(ctx); + rewriter.setInsertionPointAfter(dotOp); + rewriter.create(dotOp->getLoc()); } }); } @@ -192,12 +425,14 @@ struct InsertInstructionSchedHints namespace mlir::triton { std::unique_ptr> -createLowerInstructionSchedHintsPass(std::string variant) { - return std::make_unique(variant); +createTritonAMDGPULowerInstructionSchedHintsPass(int32_t numStages, + std::string variant) { + return std::make_unique(numStages, + variant); } std::unique_ptr> -createInsertInstructionSchedHintsPass() { - return std::make_unique(); +createTritonAMDGPUInsertInstructionSchedHintsPass() { + return std::make_unique(); } } // namespace mlir::triton diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h new file mode 100644 index 000000000000..45985fe808f2 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h @@ -0,0 +1,26 @@ +#ifndef TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_SCHED_INSTRUCTIONS_H +#define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_SCHED_INSTRUCTIONS_H + +#include "mlir/IR/Types.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// The following functions are used to collect and set side-channel information +// during to LLVM conversion/lowering to facilitate instruction scheduling +// controls. +namespace mlir::triton { +void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, + unsigned k, Type elementType); + +template +void setNumGeneratedGlobalLoads(LoadOpType op, size_t globalLoadsCount, + Type type); +void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t numDsReadsCount, + Type type); +void storeOpConversionCallback(triton::gpu::LocalStoreOp op, size_t llvmOpCount, + Type type); +triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp); +} // namespace mlir::triton + +#endif diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index aa71c92666f7..f99cd50b0d27 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -1,6 +1,7 @@ #include "TritonAMDGPUToLLVM/Passes.h" #include "PatternTritonGPUOpToLLVM.h" +#include "SchedInstructions.h" #include "TargetInfo.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -20,6 +21,7 @@ #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -72,8 +74,9 @@ struct ConvertTritonAMDGPUToLLVM } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); } void runOnOperation() override { @@ -193,8 +196,12 @@ struct ConvertTritonAMDGPUToLLVM commonBenefit); populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns, commonBenefit); - mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, - patterns, commonBenefit); + + mlir::triton::BackendCallbacks callbacks; + callbacks.localStoreOpConversion = storeOpConversionCallback; + + 
mlir::triton::populateMemoryOpToLLVMPattern( + typeConverter, targetInfo, patterns, commonBenefit, callbacks); mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, patterns, commonBenefit); mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, @@ -207,6 +214,8 @@ struct ConvertTritonAMDGPUToLLVM mlir::triton::AMD::populateTritonAMDGPUToLLVMPatterns(typeConverter, patterns, AMDBenefit); + mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns(typeConverter, patterns, + targetInfo, AMDBenefit); // TODO(thomas): this should probably be done in a separate step to not // interfere with our own lowering of arith ops. Add arith/math's patterns @@ -223,6 +232,7 @@ struct ConvertTritonAMDGPUToLLVM mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, commonBenefit); mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { return signalPassFailure(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp new file mode 100644 index 000000000000..f8165a769335 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -0,0 +1,143 @@ +#include "PatternTritonGPUOpToLLVM.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { + +class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { +private: + const TargetInfoBase &targetInfo; + +public: + UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto fpType = op.getFpType(); + bool isPacked = fpType == ScaleDotElemType::E2M1; + if (!(isPacked || fpType == ScaleDotElemType::E4M3 || + fpType == ScaleDotElemType::E5M2)) + return rewriter.notifyMatchFailure(op, "NYI: non-mxfp8 cases"); + + Location loc = op.getLoc(); + auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto scaleVals = unpackLLElements(loc, adaptor.getScale(), rewriter); + LDBG("x: " << xVals.size() << " x " << xVals.front().getType()); + LDBG("scale: " << scaleVals.size() << " x " << scaleVals.front().getType()); + + // When we lower scaled dot op, we made sure to distribute K only on one + // warp. MXFP spec mandates 1 scale value for every 32 onsecutive values + // along the K dimension. So in total each thread should read 32x main + // element values. + if (xVals.size() != scaleVals.size() * (isPacked ? 
16 : 32)) + return rewriter.notifyMatchFailure(op, "unsupported problem size"); + + auto dotEncoding = + cast(op.getSrc().getType().getEncoding()); + if (dotEncoding.getOpIdx() == 1) + return rewriter.notifyMatchFailure(op, "NYI: dot RHS"); + auto mfmaEncoding = dyn_cast(dotEncoding.getParent()); + if (!mfmaEncoding) + return rewriter.notifyMatchFailure(op, "NYI: non-mfma dot operand"); + LDBG("mfma: " << mfmaEncoding); + + int mDim = mfmaEncoding.getMDim(); + if (mDim != 32 && mDim != 16) + return rewriter.notifyMatchFailure(op, "NYI: non-mfma32/16 intrinsics"); + + int numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp( + op->getParentOfType()); + Value warpSize = i32_val(numThreads); + Value tid = tid_val(); + Value warpId = udiv(tid, warpSize); + Value laneId = urem(tid, warpSize); + + if (isPacked) + xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals); + + // Given that MFMA layout for the A tensor arranges thread in a column-major + // manner, for the current tid, it's at row (tid % mDim). When we set up + // blocked layout for the A scale tensor, we made sure that it has a + // threadsPerWarp = [M=mDim, K=64/mDim]. So the threads holding scale values + // for the current thread starts at ((tid % mDim) * (64 / mDim)). + Value offset = mul(urem(laneId, i32_val(mDim)), i32_val(numThreads / mDim)); + + if (mDim == 32) { + // One mfma32 intrinsic processes a 32x8 A tensor slice. Due to how we + // tile, the same warp owns the whole K dim. Inside a warp, each thread + // only holds 4 consecutive elements along K--a 1x4 vector. We need to + // tile the warp 4 times to cover 32 values along K. So for a thread, the + // first 4 1x4 vectors it holds shares the first scale value at row (tid % + // mDim). the second 4 1x4 vectors shares the second scale value at row + // (tid % mDim); and so forth. + std::array scaleThreads = {offset, add(offset, i32_val(1))}; + + for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { + std::array si = { + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]), + }; + + for (int j = 0; j < 32; ++j) { + int index = 32 * i + j; + xVals[index] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]); + } + } + } else { + assert(mDim == 16); + // One mfma16 intrinsic processes a 16x16 A tensor slice. Similarly, we + // need to tile the warp 2 times to cover 32 valeus. So for a thread, the + // first 2 1x4 vectors shares the first scale value at row (tid % mDim). 
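Both the mfma32 case above and the mfma16 case below reduce to the same lane arithmetic: the scale for element j of a 32-wide K group lives on lane (tid % mDim) * (64 / mDim) + j / (mDim / 2). A purely illustrative distillation, assuming a 64-lane wavefront as in the comments above:

// Illustrative only: which lane supplies the scale for element j of one
// 32-element K group, for the A-operand layout described above.
int scaleSourceLane(int laneId, int mDim /* 32 or 16 */, int j /* 0..31 */) {
  const int numThreads = 64;
  int base = (laneId % mDim) * (numThreads / mDim); // row (tid % mDim) of the scale tensor
  int elemsPerScaleLane = mDim / 2;                 // 16 for mfma32, 8 for mfma16
  return base + j / elemsPerScaleLane;              // matches si[j / 16] resp. si[j / 8]
}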
+ std::array scaleThreads = {offset, add(offset, i32_val(1)), + add(offset, i32_val(2)), + add(offset, i32_val(3))}; + + for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { + auto si = std::array{ + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[2]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[3]), + }; + + for (int j = 0; j < 32; ++j) { + int index = 32 * i + j; + xVals[index] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]); + } + } + } + + Value result = + packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType()); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // anonymous namespace + +void mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 6f93bfee99c7..ba60a96f47b6 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -160,6 +160,15 @@ FailureOr chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion, aType.getShape().back(), mfmaVersion, nonKDim); } +FailureOr chooseMfmaInstruction(tt::DotScaledOp dot, int mfmaVersion, + int nonKDim) { + // For scaled dot, we handle it with bf16 emulation for now. + Type bf16Type = Builder(dot.getContext()).getBF16Type(); + return chooseMfmaInstruction( + dot.getC().getType(), /*aElemType=*/bf16Type, /*bElemType=*/bf16Type, + dot.getLhs().getType().getShape().back(), mfmaVersion, nonKDim); +} + using OperandTypesVector = SmallVector; OperandTypesVector selectMatrixCoreOperandTypes(tt::DotOp dot, @@ -469,6 +478,140 @@ class BlockedToMFMA : public OpRewritePattern { } }; +class ScaledBlockedToMFMA final : public OpRewritePattern { + int mfmaVersion; + int nonKDim; + int kPack; + +public: + ScaledBlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, + int kPack, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), + nonKDim(nonKDim), kPack(kPack) {} + + LogicalResult matchAndRewrite(triton::DotScaledOp dotOp, + PatternRewriter &rewriter) const override { + using TensorValue = TypedValue; + + RankedTensorType oldRetType = dotOp.getType(); + if (!isa_and_nonnull(oldRetType.getEncoding())) + return rewriter.notifyMatchFailure( + dotOp, "expected blocked encoding result tensor"); + + if (dotOp.getRhsScale()) + return rewriter.notifyMatchFailure(dotOp, "NYI: RHS scale"); + + TensorValue a = dotOp.getLhs(); + TensorValue b = dotOp.getRhs(); + TensorValue aScale = dotOp.getLhsScale(); + ScaleDotElemType aElemType = dotOp.getLhsType(); + ScaleDotElemType bElemType = dotOp.getRhsType(); + + if (!(aElemType == ScaleDotElemType::E2M1 || + aElemType == ScaleDotElemType::E4M3 || + aElemType == ScaleDotElemType::E5M2)) + return rewriter.notifyMatchFailure(dotOp, "NYI: non-mxfp8/mxfp4 LHS"); + if (!(bElemType == ScaleDotElemType::E4M3 || + bElemType == ScaleDotElemType::E5M2 || + bElemType == ScaleDotElemType::BF16)) + return rewriter.notifyMatchFailure(dotOp, "NYI: non-fp8/bf16 RHS"); + + MLIRContext *ctx = dotOp.getContext(); + auto moduleOp = dotOp->getParentOfType(); + + 
ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(oldRetType.getEncoding()); + int numWarps = ttg::TritonGPUDialect::getNumWarps(moduleOp); + int numThreads = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp); + + // Choose a suitable MFMA instruction for this scaled dot op. + FailureOr mfmaInstr = + chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim); + if (failed(mfmaInstr)) + return rewriter.notifyMatchFailure(dotOp, "cannot choose mfma intrinsic"); + + unsigned mDim = mfmaInstr.value().getMDim(); + unsigned nDim = mfmaInstr.value().getNDim(); + unsigned kDim = mfmaInstr.value().getKDim(); + unsigned kBase = mfmaInstr.value().getKBase(); + + // If A tensor contains mxfp4, we pack every two values into one int8 value + // there. For such cases, we have different initial kWidth for LHS and RHS, + // which will be "fixed" later by using upcast_mxfp to convert LHS to + // unpacked values. For such packed cases, we cannot support flexible kPack + // choices from the developer--it just does not apply here. So mandate the + // choice here. + bool isPacked = aElemType == ScaleDotElemType::E2M1; + unsigned kWdiths[] = {isPacked ? 4 : kBase * kPack, + isPacked ? 8 : kBase * kPack}; + + // For A tensor, 32 consecutive elements along K dim share the same scale. + // We'd like to keep the scale values together with the base values in the + // same warp to avoid cross-warp data exchange. It means we want warpsPerCTA + // = 1 along the N dimension. + SmallVector warpsPerCTA(oldRetType.getRank(), 1); + warpsPerCTA.front() = numWarps; + + // Always use transposed mfma layout. This enables larger vectorization + // for global store instructions. + auto mfmaEnc = ttg::AMDMfmaEncodingAttr::get( + ctx, /*versionMajor=*/mfmaVersion, /*versionMinor=*/0, warpsPerCTA, + /*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout); + + auto newRetType = RankedTensorType::get( + oldRetType.getShape(), oldRetType.getElementType(), mfmaEnc); + + auto newAcc = rewriter.create( + dotOp.getC().getLoc(), newRetType, dotOp.getC()); + + auto toMMABf16 = [&](TensorValue v, int idx, + ScaleDotElemType type) -> TensorValue { + auto vType = v.getType(); + auto newVEncoding = DotOperandEncodingAttr::get( + ctx, idx, newRetType.getEncoding(), kWdiths[idx]); + auto newVType = RankedTensorType::get( + vType.getShape(), vType.getElementType(), newVEncoding); + v = rewriter.create(v.getLoc(), newVType, v); + if (type == ScaleDotElemType::BF16) + return v; + // Don't need to covert int8 holding mxfp4 for A--the upcast_mxfp op can + // take int8 tensor as input. + if (idx == 0 && type == ScaleDotElemType::E2M1) + return v; + + auto vTypeBf16 = RankedTensorType::get( + vType.getShape(), rewriter.getBF16Type(), newVEncoding); + return rewriter.create(v.getLoc(), vTypeBf16, v); + }; + a = toMMABf16(a, 0, aElemType); + b = toMMABf16(b, 1, bElemType); + + // We need to have "matching" encoding between the A tensor and A scale + // tensor to make sure the scale values needed is in the same warp. So we + // adopt the same CTA layout and warps per CTA. The warp dimensions needs to + // match along M dimension too. With in a warp, we have 64 threads. We let + // each thread read in one scale value. So we need a threadsPerWarp = mDim + // along M dimension. 
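For reference, the per-warp thread shape chosen below for the scale tensor works out as follows (illustrative numbers only, assuming a 64-lane wavefront):

#include <array>
// Worked example of threadsPerWarp = {mDim, waveSize / mDim} for the A scale tensor:
//   mfma32 -> {32, 2}: 32 rows x 2 scale columns per warp
//   mfma16 -> {16, 4}: 16 rows x 4 scale columns per warp
constexpr unsigned kWaveSize = 64;
constexpr std::array<unsigned, 2> scaleThreadsPerWarp(unsigned mDim) {
  return {mDim, kWaveSize / mDim};
}
// With sizePerThread {1, 1} and order {1, 0}, every lane owns exactly one
// scale value, and scale rows line up with the mfma M dimension of the A tile.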
+ SmallVector threadsPerWarp = {mDim, numThreads / mDim}; + auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( + ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, ctaLayout); + + auto newScaleType = RankedTensorType::get(aScale.getType().getShape(), + aScale.getType().getElementType(), + newScaleEncoding); + aScale = rewriter.create(aScale.getLoc(), + newScaleType, aScale); + + auto scaledA = rewriter.create( + dotOp.getLoc(), a, aScale, dotOp.getLhsType()); + + auto newDot = + rewriter.create(dotOp.getLoc(), newRetType, scaledA, b, newAcc); + rewriter.replaceOpWithNewOp(dotOp, oldRetType, + newDot); + return success(); + } +}; + static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, Type promotedType) { Type tensorPromotedType = cast(operand.getType()) @@ -690,8 +833,8 @@ class TritonAMDGPUAccelerateMatmulPass case ISAFamily::CDNA1: case ISAFamily::CDNA2: case ISAFamily::CDNA3: - patterns.add<::BlockedToMFMA>(context, getMfmaVersion(isaFamily), - matrixInstructionSize, kPack); + patterns.add<::BlockedToMFMA, ::ScaledBlockedToMFMA>( + context, getMfmaVersion(isaFamily), matrixInstructionSize, kPack); break; case ISAFamily::RDNA3: patterns.add<::BlockedToWMMA>(context, diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index f1d922041fcf..e66a2feb57fe 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -177,8 +177,21 @@ struct ConvertTritonLoadToBufferLoad Value maybeMask{}; if (op.getMask() && !isZeroConst(op.getMask())) maybeMask = op.getMask(); - rewriter.replaceOpWithNewOp( - op, op.getType(), basePtr, tensorOffset, maybeMask, maybeOther); + + auto bufferLoadOp = rewriter.create( + op->getLoc(), op.getType(), basePtr, tensorOffset, maybeMask, + maybeOther); + + // Propagate `OpIdxAttr` if the currently processed `tt.LoadOp` was + // labeled it. The attribute needs to be preserved for custom instruction + // scheduling. + if (auto opIdxAttr = op->getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + bufferLoadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), + opIdxAttr); + } + rewriter.replaceOp(op, bufferLoadOp); + return success(); } LDBG("Failed to convert: " << op); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp index deb566a8b1b5..3b4935026c3f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -1,6 +1,8 @@ #include "TritonAMDGPUTransforms/Passes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -168,6 +170,15 @@ void StreamPipeliner::createStreamCopy( result = select->getResults(); } + // If the currently processed `LoadOp` is labeled with an index regarding + // to which `DotOp` operand the corresponding data belongs to, then label the + // expanded `LocalStoreOp` with the same index. This is required for + // instruction scheduling hints to correctly count the emitted `ds_write` + // instructions for each GEMM tile. 
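The `OpIdxAttr` forwarding used here has the same shape as the buffer-load case earlier in this patch; in generic form it is simply "copy the discardable attribute if present". A sketch of that idiom (the patch inlines it at each call site rather than sharing a helper):

// Sketch of the attribute-forwarding idiom used by this patch.
static void forwardOpIdxAttr(mlir::Operation *from, mlir::Operation *to) {
  auto name = mlir::triton::amdgpu::OpIdxAttr::getMnemonic();
  if (auto attr = from->getAttr(name))
    to->setAttr(name, attr);
}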
+ if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) { + storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); + } + loadOp->replaceAllUsesWith(result); // Prefetch load ahead of the dot stage if is used by the dot. @@ -685,6 +696,41 @@ bool StreamPipeliner::pipelineLoop() { } namespace { +// Go through a single use chain to get the result of the target op after all +// unary ops - e.g., `convert_layout`, `fp_to_fp`, etc. +template Operation *passPrevUnaryOps(Value value) { + auto getNextUnaryOps = [](Value value) -> Operation * { + if (auto defOp = value.getDefiningOp()) { + if ((defOp->getNumOperands() == 1) || llvm::dyn_cast(defOp)) + return defOp; + } + return nullptr; + }; + + auto unaryOp = getNextUnaryOps(value); + while (unaryOp) { + if (llvm::dyn_cast(unaryOp)) + return unaryOp; + unaryOp = getNextUnaryOps(unaryOp->getOperand(0)); + } + return nullptr; +} + +// Annotate each `tt.LoadOp` instruction with its corresponding gemm operand +// index. Note, this is a part of the instruction scheduling routine. Currently, +// we support `forOp`s which contain only a single `tt.DotOp` in the bodies. +void labelLoadOpsForTritonDot(scf::ForOp forOp) { + mlir::MLIRContext *ctx = forOp->getContext(); + if (auto dotOp = triton::getSingleDotOpIfExists(forOp)) { + for (auto [opIdx, dotOperand] : llvm::enumerate(dotOp->getOperands())) { + if (auto loadOp = passPrevUnaryOps(dotOperand)) { + auto opIdxAttr = triton::amdgpu::OpIdxAttr::get(ctx, opIdx); + loadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); + } + } + } +} + struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { PipelinePass() = default; PipelinePass(int32_t numStages) { this->numStages = numStages; } @@ -692,6 +738,7 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { void runOnOperation() override { SmallVector loops; getOperation()->walk([&](scf::ForOp forOp) { + labelLoadOpsForTritonDot(forOp); // Bail out for loops with num_stage <= 1. 
if (getNumStagesOrDefault(forOp) > 1) loops.push_back(forOp); diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index f97676aafe36..d30be6959839 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -45,11 +45,12 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { pm.addPass(createConvertBuiltinFuncToLLVMPass(ftz)); }); m.def("insert_instruction_sched_hints", [](mlir::PassManager &pm) { - pm.addPass(createInsertInstructionSchedHintsPass()); + pm.addPass(createTritonAMDGPUInsertInstructionSchedHintsPass()); }); m.def("lower_instruction_sched_hints", - [](mlir::PassManager &pm, std::string variant) { - pm.addPass(createLowerInstructionSchedHintsPass(variant)); + [](mlir::PassManager &pm, int32_t numStages, std::string variant) { + pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass(numStages, + variant)); }); m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm, const std::string &arch) { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt index a944da1c83f1..b26c73b88d76 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -1,6 +1,6 @@ add_triton_library(TritonNVIDIAGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp - ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp + ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp ConvertLayoutOpToLLVM.cpp DotOpToLLVM/MMAv1.cpp DotOpToLLVM/MMAv2.cpp diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 71fd3c0cd4e7..3f3a2817de43 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -34,13 +34,13 @@ Value convertLayout(int opIdx, Value tensor, const SharedMemoryObject &smemObj, } // namespace SharedToDotOperandMMAv1 -namespace SharedToDotOperandMMAv2 { +namespace SharedToDotOperandMMAv2OrV3 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr bEncoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread); -} +} // namespace SharedToDotOperandMMAv2OrV3 namespace { @@ -88,11 +88,20 @@ struct LocalLoadOpConversion auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), llvmElemTy, rewriter); Value res; - if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2 - res = SharedToDotOperandMMAv2::convertLayout( + + if (isOuter) { + assert(false && "MMA Layout does not support outer product"); + return res; + } + + if (mmaLayout.isHopper() || mmaLayout.isAmpere()) { // tensor core v2 or v3 + if (mmaLayout.isHopper()) + assert(dotOperandLayout.getOpIdx() == 0); + + res = SharedToDotOperandMMAv2OrV3::convertLayout( dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, smemObj, typeConverter, getThreadId(rewriter, loc)); - } else if (!isOuter && mmaLayout.isVolta() && isMMA) { // tensor core v1 + } else if (mmaLayout.isVolta() && isMMA) { // tensor core v1 bool isMMAv1Row = mmaLayout.getMMAv1IsRow(dotOperandLayout.getOpIdx()); auto srcSharedLayout = cast(src.getType().getEncoding()); @@ -631,46 +640,6 @@ struct ConvertLayoutOpConversion convertMMAV3To8BitsDotOperand(op, adaptor, rewriter); return success(); } - - if 
(isMmaToDotShortcut(srcTy, dstTy)) { - // get source values - auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - unsigned elems = getTotalElemsPerThread(srcTy); - Type elemTy = - this->getTypeConverter()->convertType(srcTy.getElementType()); - // for the destination type, we need to pack values together - // so they can be consumed by tensor core operations - SmallVector vecVals; - // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer - // instructions to pack & unpack sub-word integers. A workaround is to - // store the results of ldmatrix in i32 - auto elemSize = elemTy.getIntOrFloatBitWidth(); - if (auto intTy = dyn_cast(elemTy) && elemSize <= 16) { - auto fold = 32 / elemSize; - for (unsigned i = 0; i < elems; i += fold) { - Value val = i32_val(0); - for (unsigned j = 0; j < fold; j++) { - auto ext = - shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j)); - val = or_(i32_ty, val, ext); - } - vecVals.push_back(bitcast(val, i32_ty)); - } - } else { - unsigned vecSize = std::max(32 / elemSize, 1); - Type vecTy = vec_ty(elemTy, vecSize); - for (unsigned i = 0; i < elems; i += vecSize) { - Value packed = rewriter.create(loc, vecTy); - for (unsigned j = 0; j < vecSize; j++) - packed = insert_element(vecTy, packed, vals[i + j], i32_val(j)); - vecVals.push_back(bitcast(packed, i32_ty)); - } - } - Value view = - packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy); - rewriter.replaceOp(op, view); - return success(); - } return failure(); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp similarity index 84% rename from third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp rename to third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index 21c2bee584a6..4c99a44dff15 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -25,6 +25,7 @@ class MMA16816SmemLoader { ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, SmallVector multiDimWarpId, int perPhase, int maxPhase, int elemBytes, + int mmaElemBytes, bool isHopper, ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, const Location &loc); @@ -67,6 +68,8 @@ class MMA16816SmemLoader { int perPhase; int maxPhase; int elemBytes; + int mmaElemBytes; + bool isHopper; ConversionPatternRewriter &rewriter; const Location &loc; MLIRContext *ctx{}; @@ -203,10 +206,10 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value lane, Value cSwizzleOffset) { // vecWidth // <-------> // *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || *t0 ... *t0 t1 ... t1 t2 ... t2 t3 ... t3 /|\ -// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 | -// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 | quad height -// ... | -// t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 \|/ +// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 | +// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 | quad height +// ... | +// t28 ... t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. 
t31 \|/ // --------------------------------------------- || -------------------------------------------- // *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 // t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 @@ -364,6 +367,7 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)}; } else { // base pointers + // ptrs[k][...] holds `vec` pointers each for (quadK == k) std::array, 2> ptrs; for (int i = 0; i < vecWidth; i++) ptrs[0][i] = getPtr(ptrIdx + i); @@ -383,11 +387,13 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, i0 = add(i0, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); i1 = add(i1, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); } + // ii[m] holds the offset for (quadM == m) std::array ii = {i0, i1}; // load 4 32-bit values from shared memory // (equivalent to ldmatrix.x4) SmallVector> vptrs(4, SmallVector(vecWidth)); + // i iterates the 2x2 quads, m-first for (int i = 0; i < 4; ++i) for (int j = 0; j < vecWidth; ++j) { vptrs[i][j] = gep(ptr_ty(ctx, 3), shemTy, ptrs[i / 2][j], ii[i % 2]); @@ -402,7 +408,9 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, int canonWidth = (8 * elemBytes * inc) / canonBits; Type canonInt = int_ty(canonBits); std::array retElems; - retElems.fill(undef(vec_ty(canonInt, 32 / canonBits))); + // don't pack to 32b for Hopper + int vecSize = isHopper ? 1 : 32 / canonBits; + retElems.fill(undef(vec_ty(canonInt, vecSize))); for (int r = 0; r < 2; ++r) { for (int em = 0; em < 2 * vecWidth; em += inc) { int e = em % vecWidth; @@ -421,8 +429,11 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, } if (isActualTrans) std::swap(retElems[1], retElems[2]); - return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty), - bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)}; + + auto iTy = isHopper ? int_ty(8 * elemBytes * inc) : i32_ty; + + return {bitcast(retElems[0], iTy), bitcast(retElems[1], iTy), + bitcast(retElems[2], iTy), bitcast(retElems[3], iTy)}; } } @@ -432,8 +443,9 @@ MMA16816SmemLoader::MMA16816SmemLoader( ArrayRef smemStrides, ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, SmallVector multiDimWarpId, int perPhase, int maxPhase, - int elemBytes, ConversionPatternRewriter &rewriter, - const LLVMTypeConverter *typeConverter, const Location &loc) + int elemBytes, int mmaElemBytes, bool isHopper, + ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, + const Location &loc) : nPerWarp(nPerWarp), order(order.begin(), order.end()), warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder), kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()), @@ -441,17 +453,29 @@ MMA16816SmemLoader::MMA16816SmemLoader( matShape(matShape.begin(), matShape.end()), multiDimWarpId(multiDimWarpId.begin(), multiDimWarpId.end()), perPhase(perPhase), maxPhase(maxPhase), elemBytes(elemBytes), - rewriter(rewriter), loc(loc), ctx(rewriter.getContext()) { + mmaElemBytes(mmaElemBytes), isHopper(isHopper), rewriter(rewriter), + loc(loc), ctx(rewriter.getContext()) { + // If the current elemType width is different from the MMA elemType width, + // i.e. width-changing casting is done later in DotOp Layout... then, in the + // case of Hopper, the number of bytes held by each thread after loading will + // no longer be 32B. 
Hence this flag is required to stipulate different logic. + bool isHopperWidthChange = isHopper && (mmaElemBytes != elemBytes); + contiguousMatShape = matShape[order[0]]; stridedMatShape = matShape[order[1]]; stridedSmemOffset = smemStrides[order[1]]; smemBatchOffset = smemStrides[order[2]]; - vecWidth = 4 / elemBytes; + if (isHopperWidthChange) { + vecWidth = 4 / mmaElemBytes; + } else { + vecWidth = 4 / elemBytes; + } // rule: k must be the fast-changing axis. needTrans = kOrder != order[0]; nonKOrder = (kOrder == 2) ? 1 : 2; canUseLdmatrix = elemBytes == 2 || (!needTrans); canUseLdmatrix = canUseLdmatrix && (kWidth == vecWidth); + canUseLdmatrix = canUseLdmatrix && !isHopperWidthChange; if (canUseLdmatrix) { // Each CTA, the warps is arranged as [1xwarpsPerTile] if not transposed, @@ -505,24 +529,58 @@ Type getSharedMemTy(Type argType) { } Value composeValuesToDotOperandLayoutStruct( - const ValueTable &vals, int batch, int n0, int n1, + const ValueTable &vals, int batch, int repOuter, int repK, const LLVMTypeConverter *typeConverter, Location loc, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter &rewriter, Type eltTy, int kWidth, bool isHopper, + bool isA) { + auto bitwidth = eltTy.getIntOrFloatBitWidth(); + assert(32 >= bitwidth && "only support 32-bit or less"); + auto numElemsPerVec = 32 / bitwidth; + auto vecTy = vec_ty(eltTy, numElemsPerVec); + // FIXME: Fix the hopper path + // FIXME: [DOT LL] + // `kWidth` specifies the number of contiguous elements each thread will load. + // Loaded elements are packed into a vector of int32, which will then be + // unpacked into individual elements. + // `kIters` specifies the number of contiguous int32 elements each thread + // should load. + auto kIters = isHopper ? 1 : kWidth / (32 / bitwidth); + std::vector elems; - for (int b = 0; b < batch; ++b) - for (int m = 0; m < n0; ++m) - for (int k = 0; k < n1; ++k) { - elems.push_back(vals.at({b, 2 * m, 2 * k})); - elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); - elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); - elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1})); + auto unpackVec = [&](int b, int m, int k) { + for (auto kIter = 0; kIter < kIters; ++kIter) { + auto val = vals.at({b, m, k + kIter}); + auto vec = bitcast(val, vecTy); + for (auto i = 0; i < numElemsPerVec; ++i) { + elems.push_back(extract_element(eltTy, vec, i32_val(i))); } + } + }; + + // Loading A tile is different from loading B tile since each tile of A is + // 16x16 while B is 16x8. 
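The register arithmetic above is easy to sanity-check in isolation; the A/B loops below then walk repOuter x repK, consuming kIters packed registers per step. A standalone illustration (the worked numbers assume 16-bit elements such as bf16):

// Standalone sanity check of the unpacking arithmetic (illustrative only).
constexpr int numElemsPerVec(int bitwidth) { return 32 / bitwidth; }
constexpr int kIters(int kWidth, int bitwidth, bool isHopper) {
  return isHopper ? 1 : kWidth / numElemsPerVec(bitwidth);
}
static_assert(numElemsPerVec(16) == 2);    // two bf16 elements per i32 register
static_assert(kIters(8, 16, false) == 4);  // Ampere, kWidth = 8: four i32 per (m, k) step
static_assert(kIters(8, 16, true) == 1);   // Hopper path keeps a single register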
+ if (isA) { + for (int b = 0; b < batch; ++b) + for (int m = 0; m < repOuter; ++m) + for (int k = 0; k < std::max(repK / kIters, 1); ++k) { + unpackVec(b, 2 * m, kIters * 2 * k); + unpackVec(b, 2 * m + 1, kIters * 2 * k); + unpackVec(b, 2 * m, kIters * (2 * k + 1)); + unpackVec(b, 2 * m + 1, kIters * (2 * k + 1)); + } + } else { + for (int b = 0; b < batch; ++b) + for (int n = 0; n < repOuter; ++n) + for (int k = 0; k < std::max(repK / kIters, 1); ++k) { + unpackVec(b, n, kIters * 2 * k); + unpackVec(b, n, kIters * (2 * k + 1)); + } + } assert(!elems.empty()); - Type elemTy = elems[0].getType(); - MLIRContext *ctx = elemTy.getContext(); + MLIRContext *ctx = eltTy.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(elems.size(), elemTy)); + ctx, SmallVector(elems.size(), eltTy)); auto result = packLLElements(loc, typeConverter, elems, rewriter, structTy); return result; } @@ -544,18 +602,20 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, const int maxPhase = sharedLayout.getMaxPhase(); const int vecPhase = sharedLayout.getVec(); const int elemBytes = descTy.getElementTypeBitWidth() / 8; + const int mmaElemBytes = 4 / kWidth; + const bool isHopper = mmaLayout.getVersionMajor() == 3; auto order = sharedLayout.getOrder(); int nPerWarp = std::max(shapePerCTA[2] / mmaLayout.getWarpsPerCTA()[2], 8); - // (a, b) is the coordinate. auto load = [=, &rewriter, &vals](int batch, int a, int b) { - MMA16816SmemLoader loader( - nPerWarp, warpsPerTile, sharedLayout.getOrder(), - mmaLayout.getWarpsPerCTA(), kOrder, kWidth, smemObj.strides, - shapePerCTA /*tileShape*/, instrShape, matShape, multiDimWarpId, - perPhase, maxPhase, elemBytes, rewriter, typeConverter, loc); + MMA16816SmemLoader loader(nPerWarp, warpsPerTile, sharedLayout.getOrder(), + mmaLayout.getWarpsPerCTA(), kOrder, kWidth, + smemObj.strides, shapePerCTA /*tileShape*/, + instrShape, matShape, multiDimWarpId, perPhase, + maxPhase, elemBytes, mmaElemBytes, isHopper, + rewriter, typeConverter, loc); // Offset of a slice within the original tensor in shared memory Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); SmallVector offs = loader.computeOffsets(lane, cSwizzleOffset); @@ -573,6 +633,7 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, auto [ha0, ha1, ha2, ha3] = loader.loadX4( batch, (kOrder == 2) ? a : b /*mat0*/, (kOrder == 2) ? b : a /*mat1*/, ptrs, matTy, getSharedMemTy(eltTy)); + if (!isA) std::swap(ha1, ha2); // the following is incorrect @@ -595,28 +656,32 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, MemDescType descTy, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread, bool isA) { + auto mmaLayout = mlir::cast(encoding.getParent()); + bool isHopper = mmaLayout.getVersionMajor() == 3; auto shapePerCTA = getShapePerCTA(descTy); int bitwidth = descTy.getElementTypeBitWidth(); - auto mmaLayout = mlir::cast(encoding.getParent()); + // For Hopper WGMMA, the sum of bitwidth of the elements in each quad should + // add up to 32. We use kWidth to compute the element bitwidth of the input to + // WGMMA, which could be different from `bitwidth` due to later casting. + int mmaBitwidth = isHopper ? 
(32 / encoding.getKWidth()) : bitwidth; ValueTable vals; - int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth; - int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth; + int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / mmaBitwidth; + int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth; int kWidth = encoding.getKWidth(); - auto numRep = mmaLayout.getMMAv2RepForOperand(shapePerCTA, bitwidth, kWidth, - encoding.getOpIdx()); + auto numRep = + mmaLayout.getRepForOperand(shapePerCTA, mmaBitwidth, encoding.getOpIdx()); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); - auto order = triton::gpu::getOrder(mmaLayout); + auto warpOrder = mmaLayout.getWarpOrder(); Value warp = udiv(thread, i32_val(32)); Value lane = urem(thread, i32_val(32)); SmallVector multiDimWarpId = - delinearize(rewriter, loc, warp, warpsPerCTA, order); + delinearize(rewriter, loc, warp, warpsPerCTA, warpOrder); Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0])); int warpsPerTile; - auto rank = shapePerCTA.size(); Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16)); Value warpN = urem(multiDimWarpId[2], i32_val(shapePerCTA[2] / 8)); if (isA) @@ -651,8 +716,10 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, loadFn(b, 2 * m, 2 * k); // Format the values to LLVM::Struct to passing to mma codegen. + Type eltTy = typeConverter->convertType(descTy.getElementType()); return composeValuesToDotOperandLayoutStruct( - vals, numRepBatch, numRepOuter, numRepK, typeConverter, loc, rewriter); + vals, numRepBatch, isA ? numRep[1] : numRep[2], numRepK, typeConverter, + loc, rewriter, eltTy, kWidth, isHopper, isA); } template @@ -764,7 +831,7 @@ getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, return expandedSmemObj; } -namespace SharedToDotOperandMMAv2 { +namespace SharedToDotOperandMMAv2OrV3 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, @@ -785,4 +852,4 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, expandedSmemObj, typeConverter, thread, false); } } -} // namespace SharedToDotOperandMMAv2 +} // namespace SharedToDotOperandMMAv2OrV3 diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index cf0ddc248dd1..40cb55bbc00d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -70,10 +70,23 @@ struct DecomposeUnsupportedConversions : public mlir::triton::impl::DecomposeUnsupportedNVIDIAConversionsBase< DecomposeUnsupportedConversions> { void runOnOperation() override { + // FIXME [Dot LL] + // Remove the decomposeTensorCoreToDotLayoutConversion class entirely after + // we have enabled the new layout conversion for all the cases. 
+ auto nvidiaShortCutFn = [&](RankedTensorType srcTy, + RankedTensorType dstTy) { + auto nvidiaMma = dyn_cast(srcTy.getEncoding()); + // Supported mma to dot conversion + if (nvidiaMma && nvidiaMma.isAmpere()) + return true; + // No need to decompose if shared memory is not needed + return matchMmaV3AndDotOperandLayout(srcTy, dstTy) || + cvtReordersRegisters(srcTy, dstTy); + }; ModuleOp mod = getOperation(); triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, - isMmaToDotShortcut); + nvidiaShortCutFn); triton::gpu::decomposeBlockedToDotLayoutConversion(mod); mlir::RewritePatternSet patterns(&getContext()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index c2940a04386f..69c5c9e6df8d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -9,6 +9,7 @@ using namespace mlir; using namespace mlir::triton; using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getOrderForDotOperand; using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; using ValueTableV2 = std::map, Value>; @@ -59,56 +60,73 @@ Value loadC(Value tensor, Value llTensor, ValueTableV2 getValuesFromDotOperandLayoutStruct( const LLVMTypeConverter *typeConverter, Location loc, - ConversionPatternRewriter &rewriter, Value value, int batch, int n0, int n1, - RankedTensorType type) { + ConversionPatternRewriter &rewriter, Value value, int batch, int repOuter, + int repK, RankedTensorType type) { auto elems = unpackLLElements(loc, value, rewriter); + auto eltTy = type.getElementType(); int offset{}; ValueTableV2 vals; + auto bitwidth = eltTy.getIntOrFloatBitWidth(); + auto numElemsPerVec = 32 / bitwidth; + auto vecTy = vec_ty(eltTy, numElemsPerVec); + + auto packVec = [&](std::array dstIdx) { + Value vec = undef(vecTy); + for (auto i = 0; i < numElemsPerVec; ++i) { + vec = insert_element(vec, bitcast(elems[offset + i], eltTy), i32_val(i)); + } + vals[dstIdx] = bitcast(vec, i32_ty); + offset += numElemsPerVec; + }; - // FIXME [Dot LL] - // [ez] Generalize the logic below for kWidth * elemBitWidth > 32 auto dot = cast(type.getEncoding()); - auto largeK = dot.getKWidth() == 8 && - cast(dot.getParent()).isAmpere(); + auto kWidth = dot.getKWidth(); + auto largeK = bitwidth * kWidth > 32; if (largeK) { + // For layouts with a large K dimension, the original register layout needs + // to be divided into multiple MMAs, where each MMA has contiguous 32 bits + // along the K dimension per thread. + // Using kWidth = 8 and bitwidth = 2 as an example, + // we split the MMA into 4 sub-MMAs, each with a stride 4 x 32-bit along the + // K dimension. llvm::SmallVector si; - // For kWidth = 8, split the mma into 4 mmas with "stride 4" along K if (dot.getOpIdx() == 0) { // Original register layout: // - // [0, 1, 2, 3], [8, 9, 10, 11] - // [4, 5, 6, 7], [12, 13, 14, 15] - // - // Each element in the layout consists of two bf16 values. - // For example, the row [0, 1, 2, 3] expands to: + // [0, 1, 2, 3, 4, 5, 6, 7], [16, 17, 18, 19, 20, 21, 22, 23, 23] + // [8, 9, 10, 11, 12, 13, 14, 15], [24, 25, 26, 27, 28, 29, 30, 31] // - // [[0/0, 0/1], [1/0, 1/1], [2/0, 2/1], [3/0, 3/1]] - // - // Here, 0/0 refers to the first half of element 0, and 0/1 refers to the - // second half, matching kWidth = 8. + // Each element in the layout is a single bf16. 
// // To derive four independent MMA operations, a stride of 4 is applied to // the original register layout: // - // 1st MMA: [0, 4, 8, 12] - // 2nd MMA: [1, 5, 9, 13] - // 3rd MMA: [2, 6, 10, 14] - // 4th MMA: [3, 7, 11, 15] - si = llvm::SmallVector{0, 4, 8, 12, 1, 5, 9, 13, - 2, 6, 10, 14, 3, 7, 11, 15}; + // 1st MMA: [[0, 1], [8, 9], [16, 17], [24, 25]] + // 2nd MMA: [[2, 3], [10, 11], [18, 19], [26, 27]] + // 3rd MMA: [[4, 5], [12, 13], [20, 21], [28, 29]] + // 4th MMA: [[6, 7], [14, 15], [22, 23], [30, 31]] + for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep) + for (size_t tile = 0; tile < 4; ++tile) + for (size_t e = 0; e < numElemsPerVec; ++e) { + si.push_back(kRep * numElemsPerVec + tile * kWidth + e); + } } else { // Original register layout: // - // [0, 1, 2, 3]^T, [4, 5, 6, 7]^T + // [0, 1, 2, 3, 4, 5, 6, 7]^T, [8, 9, 10, 11, 12, 13, 14, 15]^T // // A stride of 4 is applied to derive four independent MMA operations: // - // 1st MMA: [0, 4] - // 2nd MMA: [1, 5] - // 3rd MMA: [2, 6] - // 4th MMA: [3, 7] - si = llvm::SmallVector{0, 4, 1, 5, 2, 6, 3, 7}; + // 1st MMA: [[0, 1], [8, 9]] + // 2nd MMA: [[2, 3], [10, 11]] + // 3rd MMA: [[4, 5], [12, 13]] + // 4th MMA: [[6, 7], [14, 15]] + for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep) + for (size_t tile = 0; tile < 2; ++tile) + for (size_t e = 0; e < numElemsPerVec; ++e) { + si.push_back(kRep * numElemsPerVec + tile * kWidth + e); + } } auto step = si.size(); @@ -119,34 +137,25 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( } std::copy(perm.begin(), perm.end(), elems.begin() + i * step); } - - if (dot.getOpIdx() == 1) { - // there are kWidth * 2 elems packed as bf16x2 - int elemsInTile = dot.getKWidth(); - // n0 and n1 are unrolled in the legacy path - // Unrolling n1 makes some sense, but unrolling n0 makes absolutely no - // sense IMO - n0 *= 2; - n1 *= 2; - for (auto b = 0; b < batch; ++b) - for (auto j = 0; j < n1 / elemsInTile; ++j) - for (auto i = 0; i < n0; ++i) - for (auto k = 0; k < elemsInTile; ++k) { - vals[{b, i, elemsInTile * j + k}] = elems[offset++]; - } - return vals; - } } - for (auto b = 0; b < batch; ++b) - for (auto i = 0; i < n0; ++i) { - for (auto j = 0; j < n1; j++) { - vals[{b, 2 * i, 2 * j}] = elems[offset++]; - vals[{b, 2 * i + 1, 2 * j}] = elems[offset++]; - vals[{b, 2 * i, 2 * j + 1}] = elems[offset++]; - vals[{b, 2 * i + 1, 2 * j + 1}] = elems[offset++]; - } - } + if (dot.getOpIdx() == 0) { + for (auto b = 0; b < batch; ++b) + for (auto m = 0; m < repOuter; ++m) + for (auto k = 0; k < repK; ++k) { + packVec({b, 2 * m, 2 * k}); + packVec({b, 2 * m + 1, 2 * k}); + packVec({b, 2 * m, 2 * k + 1}); + packVec({b, 2 * m + 1, 2 * k + 1}); + } + } else { + for (auto b = 0; b < batch; ++b) + for (auto n = 0; n < repOuter; ++n) + for (auto k = 0; k < repK; ++k) { + packVec({b, n, 2 * k}); + packVec({b, n, 2 * k + 1}); + } + } return vals; } @@ -394,28 +403,30 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth(); auto dotOpA = cast(aTensorTy.getEncoding()); auto repA = cast(dotOpA.getParent()) - .getMMAv2RepForOperand(aShapePerCTA, bitwidth, - dotOpA.getKWidth(), dotOpA.getOpIdx()); + .getRepForOperand(aShapePerCTA, bitwidth, dotOpA.getOpIdx()); auto dotOpB = cast(bTensorTy.getEncoding()); auto repB = cast(dotOpB.getParent()) - .getMMAv2RepForOperand(bShapePerCTA, bitwidth, - dotOpB.getKWidth(), dotOpB.getOpIdx()); + .getRepForOperand(bShapePerCTA, bitwidth, dotOpB.getOpIdx()); assert(repA[2] == repB[1]); 
assert(repA[0] == repB[0]); int repM = repA[1], repN = repB[2], repK = repA[2]; int repBatch = repA[0]; + // We can reuse the same iteration order in + // getValuesFromDotOperandLayoutStruct as both a and b are K-major + assert(dotOpA.getRepOrder() == getOrderForDotOperand(dotOpA.getOpIdx(), + aShapePerCTA.size(), + /*kMajor=*/true)); auto ha = getValuesFromDotOperandLayoutStruct( typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy); - // FIXME [Dot LL] - // max(repN / 2, 1) is wrong for repN = 1! - // This is also wrong in - // NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand + assert(dotOpB.getRepOrder() == getOrderForDotOperand(dotOpB.getOpIdx(), + bShapePerCTA.size(), + /*kMajor=*/true)); auto hb = getValuesFromDotOperandLayoutStruct( - typeConverter, loc, rewriter, loadedB, repBatch, std::max(repN / 2, 1), - repK, bTensorTy); + typeConverter, loc, rewriter, loadedB, repBatch, repN, repK, bTensorTy); + auto fc = unpackLLElements(loc, loadedC, rewriter); auto numMmaRets = dTensorTy.getElementType().getIntOrFloatBitWidth() / 8; int numCPackedElem = 4 / numMmaRets; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 1bb55373e046..2b9b4f159bf4 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -442,6 +442,11 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, if (aSharedLayout) { a = aLoader.smemLoad(m, k, rewriter, loc); } else { + auto aDotOpEnc = + cast(aTensorTy.getEncoding()); + assert(aDotOpEnc.getKWidth() == + 32 / aTensorTy.getElementTypeBitWidth()); + unsigned regASize = (instrShape[0] * instrShape[2]) / 32; llvm::SmallVector regA = loadReg(rewriter, loc, structA, (m * numRepK + k) * regASize, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index cb430d8fadef..115a3316dc5c 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -38,16 +38,27 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, auto sizePerThread = triton::gpu::getSizePerThread(layout); auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout); - auto order = triton::gpu::getOrder(layout); - auto warpOrder = triton::gpu::getWarpOrder(layout); + auto threadOrder = triton::gpu::getThreadOrder(layout); + SmallVector warpOrder(rank); + if (auto enc = dyn_cast(layout)) { + warpOrder = + triton::gpu::getMatrixOrder(rank, /*rowMajor=*/enc.getOpIdx() == 1); + } else { + warpOrder = triton::gpu::getWarpOrder(layout); + } auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); Value warpSize = i32_val(32); Value laneId = urem(tid, warpSize); Value warpId = udiv(tid, warpSize); + // TODO: [DOT LL] + // The delinearize function is not entirely correct for certain layouts, + // such as wgmma. The correct approach is to convert a legacy layout to its + // corresponding linear layout and use the linear layout's + // getFreeVariableMasks to identify redundant elements. 
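The `delinearize` helper used below decomposes a linear id into one coordinate per dimension, consuming dimensions in the given order (fastest-varying first). A plain-integer illustration of the intended semantics (not the LLVM-IR-emitting Triton helper itself):

// Plain-integer illustration of delinearization with an explicit order.
llvm::SmallVector<unsigned> delinearizeSketch(unsigned linear,
                                              llvm::ArrayRef<unsigned> shape,
                                              llvm::ArrayRef<unsigned> order) {
  llvm::SmallVector<unsigned> multiDim(shape.size());
  for (unsigned dim : order) { // order[0] is the fastest-varying dimension
    multiDim[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return multiDim;
}
// e.g. warpsPerCTA = {2, 4}, warpOrder = {1, 0}: warpId 5 -> {1, 1}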
SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); SmallVector multiDimThreadId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); + delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder); for (unsigned dim = 0; dim < rank; ++dim) { // if there is no data replication across threads on this dimension if (shape[dim] >= shapePerCTATile[dim]) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp index 722bf56cd015..6cba3f45da4d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -1,6 +1,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "PatternTritonGPUOpToLLVM.h" @@ -12,7 +13,6 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" #include using namespace mlir; @@ -30,60 +30,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {} - llvm::SmallVector - unpackFP4Elements(Location loc, ConversionPatternRewriter &rewriter, - const llvm::SmallVector &vals, Value laneId) const { - auto fp4x2ToBf16x2 = [&loc, &rewriter](Value v) -> Value { - auto em0 = and_(v, i8_val(0x70)); - auto em1 = and_(v, i8_val(0x7)); - Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)), - shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8))); - Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)), - shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12))); - - // Three cases: - // 1) x is normal and non-zero: Correct bias - v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)), - add(v0, i16_val((127 - 1) << 7)), v0); - v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)), - add(v1, i16_val((127 - 1) << 7)), v1); - - // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in - // bf16 - v0 = select(icmp_eq(em0, i8_val(0x10)), - or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0); - v1 = select(icmp_eq(em1, i8_val(0x1)), - or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1); - // 3) x is zero, nothing to do - - // Swap as they come packed in big endian - return or_(zext(i32_ty, v0), shl(zext(i32_ty, v1), i32_val(16))); - }; - - auto fp4x8ToBf16x2 = [&loc, &rewriter, &fp4x2ToBf16x2]( - Value v) -> llvm::SmallVector { - llvm::SmallVector results(4); - for (int i = 0; i < 4; ++i) { - auto v_i = trunc(i8_ty, lshr(v, i32_val(8 * i))); - results[i] = fp4x2ToBf16x2(v_i); - } - return results; - }; - - // Split fp4x8 into 4 bf16x2 - llvm::SmallVector ret; - ret.reserve(vals.size() * 4); - for (int i = 0; i < vals.size(); ++i) { - auto vs = fp4x8ToBf16x2(vals[i]); - assert(vs.size() == 4); - for (auto v : vs) { - ret.push_back(v); - } - } - - return ret; - } - LogicalResult matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -103,27 +49,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { Value warpId = udiv(tid, warpSize); Value laneId = urem(tid, warpSize); - if (fpType == F8F6F4Type::E2M1) { - xVals = unpackFP4Elements(loc, rewriter, xVals, laneId); - } - - auto scaleBf16x2 = [&loc, &rewriter](Value v, Value s) -> Value { - // Split bf16x2 into 2 bf16, scale each of 
them, and pack them back - // TODO Is it true that the bfloats are always packed as bf16x2? - auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty); - auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty); - auto scaleIsNan = icmp_eq(s, i8_val(0xff)); - auto scaleBf16 = bitcast(shl(zext(i16_ty, s), i16_val(7)), bf16_ty); - auto scaledBf16_0 = fmul(bf16_0, scaleBf16); - auto scaledBf16_1 = fmul(bf16_1, scaleBf16); - auto i16_0 = bitcast(scaledBf16_0, i16_ty); - auto i16_1 = bitcast(scaledBf16_1, i16_ty); - auto packed = - or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16))); - // Account for NaN in the scale as per the mxfp specification - auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed); - return packed_nan; - }; + if (fpType == ScaleDotElemType::E2M1) + xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals); // Each thread owns elements of 4 mxfp vectors so we need 4 scales // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c + @@ -141,8 +68,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[3]), }; - for (int j = 0; j < 16; ++j) { - xVals[16 * i + j] = scaleBf16x2(xVals[16 * i + j], si[j / 4]); + for (int j = 0; j < 32; ++j) { + xVals[32 * i + j] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[32 * i + j], si[j / 8]); } } diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index fd65233e5c6b..d662537ed72d 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -41,9 +41,9 @@ class LinearLayoutConversionsTest : public ::testing::Test { CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape); } - DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef warps, - ArrayRef order) { - auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order); + DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, + ArrayRef warps) { + auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, {1, 0}); return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth); } @@ -301,6 +301,19 @@ TEST_F(LinearLayoutConversionsTest, Blocked4D) { {S("dim0"), S("dim1"), S("dim2"), S("dim3")})); } +TEST_F(LinearLayoutConversionsTest, MMAv2_16x16) { + EXPECT_EQ(toLinearLayout({16, 16}, + mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, MMAv2_32x32) { EXPECT_EQ(toLinearLayout({32, 32}, mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), @@ -502,7 +515,7 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) { } TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { - EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1})), LinearLayout( { {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, @@ -511,7 +524,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1})), LinearLayout( { {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, @@ -524,7 +537,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { 
 TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
   EXPECT_EQ(
-      toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})),
+      toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1})),
       LinearLayout(
           {
               {S("register"),
@@ -534,7 +547,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
               {S("block"), {}},
           },
           {S("dim0"), S("dim1")}));
-  EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
+  EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1})),
             LinearLayout(
                 {
                     {S("register"),
@@ -542,19 +555,19 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
                      {2, 0},
                      {4, 0},
                      {32, 0},
+                     {64, 0},
                      {0, 8},
                      {0, 16},
-                     {0, 32},
-                     {64, 0}}},
+                     {0, 32}}},
                     {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
                     {
                         S("warp"),
-                        {},
+                        {{0, 0}, {0, 0}},
                     },
                     {S("block"), {}},
                 },
                 {S("dim0"), S("dim1")}));
-  EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
+  EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1})),
             LinearLayout(
                 {
                     {S("register"),
@@ -569,13 +582,46 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
                     {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
                     {
                         S("warp"),
-                        {},
+                        {{0, 0}, {0, 0}},
                     },
                     {S("block"), {}},
                 },
                 {S("dim0"), S("dim1")}));
 }
 
+TEST_F(LinearLayoutConversionsTest, DotMMAv2_split_warp_kwidth8) {
+  EXPECT_EQ(
+      toLinearLayout({32, 64}, dotMMAv2(0, 8, {2, 2})),
+      LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
+                    {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
+                    {S("warp"), {{0, 0}, {16, 0}}},
+                    {S("block"), {}}},
+                   {S("dim0"), S("dim1")}));
+  EXPECT_EQ(
+      toLinearLayout({64, 16}, dotMMAv2(1, 8, {2, 2})),
+      LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
+                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+                    {S("warp"), {{0, 8}, {0, 0}}},
+                    {S("block"), {}}},
+                   {S("dim0"), S("dim1")}));
+  EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(0, 8, {2, 2})),
+            LinearLayout(
+                {{S("register"),
+                  {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {32, 0}}},
+                 {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
+                 {S("warp"), {{0, 0}, {16, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(
+      toLinearLayout({128, 32}, dotMMAv2(1, 8, {2, 2})),
+      LinearLayout(
+          {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 16}}},
+           {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+           {S("warp"), {{0, 8}, {0, 0}}},
+           {S("block"), {}}},
+          {S("dim0"), S("dim1")}));
+}
+
 TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
   auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
                      /*isTransposed=*/false);
diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp
index f006447002ef..897172fd6d34 100644
--- a/unittest/Tools/LinearLayoutTest.cpp
+++ b/unittest/Tools/LinearLayoutTest.cpp
@@ -747,6 +747,39 @@ TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) {
   ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value());
 }
 
+TEST_F(LinearLayoutTest, Resize) {
+  auto init = LinearLayout(
+      {
+          {S("in0"), {{0, 1}, {0, 2}}},
+          {S("in1"), {{1, 0}, {2, 0}}},
+          {S("in2"), {}},
+      },
+      {S("dim0"), S("dim1")});
+  EXPECT_EQ(init.resize(S("in0"), 8),
+            LinearLayout(
+                {
+                    {S("in0"), {{0, 1}, {0, 2}, {0, 0}}},
+                    {S("in1"), {{1, 0}, {2, 0}}},
+                    {S("in2"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(init.resize(S("in0"), 4), LinearLayout(
+                                          {
+                                              {S("in0"), {{0, 1}, {0, 2}}},
+                                              {S("in1"), {{1, 0}, {2, 0}}},
+                                              {S("in2"), {}},
+                                          },
+                                          {S("dim0"), S("dim1")}));
+  EXPECT_EQ(init.resize(S("in1"), 8),
+            LinearLayout(
+                {
+                    {S("in0"), {{0, 1}, {0, 2}}},
+                    {S("in1"), {{1, 0}, {2, 0}, {0, 0}}},
+                    {S("in2"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+}
+
 } // anonymous namespace
 } // namespace mlir::triton
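
For reference, the following standalone sketch mirrors the e2m1 (mxfp4) decode and scaling logic that this diff moves out of UpcastMXFPToLLVM.cpp into the shared LLVM::convertMxfp4x2ToBf16x2 and LLVM::mxfpScaleBf16 helpers. It reproduces, on plain integers, the bit manipulation of the removed fp4x2ToBf16x2 and scaleBf16x2 lambdas; the function names, the scalar signatures, and the truncating bf16 conversion are illustrative assumptions, not the in-tree API.

// Scalar reference (sketch) for the mxfp4 handling above. Each input byte
// packs two e2m1 values; the high nibble yields the first bf16, the low
// nibble the second, and an e8m0 scale of 0xFF propagates NaN.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <utility>

// Decode one packed fp4x2 byte into two bf16 bit patterns.
static std::pair<uint16_t, uint16_t> decodeMxfp4x2(uint8_t v) {
  uint8_t em0 = v & 0x70; // exponent+mantissa bits of the high nibble
  uint8_t em1 = v & 0x07; // exponent+mantissa bits of the low nibble

  // Move exponent/mantissa and sign into their bf16 bit positions.
  uint16_t bf0 = (uint16_t(em0) << 2) | (uint16_t(v & 0x80) << 8);
  uint16_t bf1 = (uint16_t(em1) << 6) | (uint16_t(v & 0x08) << 12);

  // 1) Normal, non-zero values: add the bf16 bias (127) minus the e2m1 bias (1).
  if (em0 & 0x60)
    bf0 += (127 - 1) << 7;
  if (em1 & 0x06)
    bf1 += (127 - 1) << 7;

  // 2) Subnormal e2m1 (0bS001) maps to +/-0.5 in bf16 (0x3F00).
  if (em0 == 0x10)
    bf0 = 0x3F00 | (bf0 & 0x8000);
  if (em1 == 0x01)
    bf1 = 0x3F00 | (bf1 & 0x8000);

  // 3) Zero needs no adjustment.
  return {bf0, bf1};
}

// Scale a bf16 value by an e8m0 scale, returning NaN when the scale is NaN
// (0xFF), as the mxfp specification requires.
static uint16_t mxfpScaleBf16Bits(uint16_t v, uint8_t scale) {
  if (scale == 0xFF)
    return 0x7FFF;
  uint16_t scaleBf16 = uint16_t(scale) << 7; // e8m0 becomes the bf16 exponent
  auto toFloat = [](uint16_t b) {
    uint32_t u = uint32_t(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
  };
  float r = toFloat(v) * toFloat(scaleBf16);
  uint32_t u;
  std::memcpy(&u, &r, sizeof(u));
  return uint16_t(u >> 16); // truncate; rounding is ignored in this sketch
}

int main() {
  // 0x2F packs e2m1 +1.0 (high nibble 0b0010) and -6.0 (low nibble 0b1111).
  auto [hi, lo] = decodeMxfp4x2(0x2F);
  // Scale both by 2.0 (e8m0 value 128): expect bf16 2.0 (0x4000) and
  // -12.0 (0xC140).
  std::printf("0x%04x 0x%04x\n", (unsigned)mxfpScaleBf16Bits(hi, 128),
              (unsigned)mxfpScaleBf16Bits(lo, 128));
}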
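Similarly, a minimal sketch of the resize behaviour exercised by the new LinearLayoutTest.Resize case: growing an input dimension pads its basis list with all-zero bases up to log2(newSize) entries, and keeping the size unchanged is a no-op. The helper name and the truncation branch for shrinking are assumptions; the test above does not exercise shrinking.

// Sketch of resizing one input dimension of a linear layout, assuming bases
// are stored as one vector of basis vectors per input dimension.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Basis = std::vector<int32_t>;    // one basis vector (one entry per out-dim)
using InDimBases = std::vector<Basis>; // all bases of one input dimension

InDimBases resizeInDim(InDimBases bases, int32_t newInDimSize, std::size_t rank) {
  assert(newInDimSize > 0 && (newInDimSize & (newInDimSize - 1)) == 0 &&
         "size must be a power of 2");
  std::size_t newNumBases = 0;
  while ((1 << newNumBases) < newInDimSize)
    ++newNumBases;
  if (bases.size() > newNumBases)
    bases.resize(newNumBases); // assumed: shrinking drops trailing bases
  while (bases.size() < newNumBases)
    bases.push_back(Basis(rank, 0)); // growing pads with zero bases
  return bases;
}

int main() {
  // Mirrors init.resize(S("in0"), 8) from the test: {{0,1},{0,2}} gains {0,0}.
  InDimBases in0 = {{0, 1}, {0, 2}};
  auto grown = resizeInDim(in0, 8, /*rank=*/2);
  assert(grown.size() == 3 && grown[2] == Basis({0, 0}));
}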