diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 1bd1db9496ea..17c1073e8b53 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -57,7 +57,9 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) { mlir::registerTritonAMDGPUAccelerateMatmul(); mlir::registerTritonAMDGPUOptimizeEpilogue(); mlir::registerTritonAMDGPUReorderInstructions(); + mlir::registerTritonAMDGPUBypassLDSForDotLayout(); mlir::registerTritonAMDGPUStreamPipeline(); + mlir::registerTritonAMDGPUStreamPipelineV2(); // TODO: register Triton & TritonGPU passes registry.insert<...>(); } diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter) { assert(bool(llvmStruct) && "can not unpack null values"); + + if (isMoeLDSBypass()) { + auto llvmVec = llvmStruct; + auto vecTy = dyn_cast<::mlir::VectorType>(llvmVec.getType()); + if (vecTy) { + auto elemTy = vecTy.getElementType(); + auto elemNo = vecTy.getDimSize(0); + SmallVector<Value> results; + for (int elem = 0; elem < elemNo; ++elem) + results.push_back(extract_element(llvmVec, i32_val(elem))); + return results; + } + } + if (llvmStruct.getType().isIntOrIndexOrFloat() || isa<triton::PointerType>(llvmStruct.getType()) || isa<LLVM::LLVMPointerType>(llvmStruct.getType())) @@ -1550,6 +1564,26 @@ inline Value packLLElements(Location loc, const LLVMTypeConverter *typeConverter, ValueRange resultVals, ConversionPatternRewriter &rewriter, Type type) { + if (isMoeLDSBypass()) { + auto firstType = resultVals[0].getType(); + bool sameTypes = true; + for (int i = 1; i < resultVals.size(); ++i) + if (resultVals[i].getType() != firstType) { + sameTypes = false; + break; + } + if (sameTypes && firstType.isIntOrFloat() && resultVals.size() > 1) { + // Pack into a vector instead of a struct to prevent LLVM from + // splitting the struct into separate values + auto vecTy = vec_ty(firstType, resultVals.size()); + Value llvmVector = rewriter.create<LLVM::UndefOp>(loc, vecTy); + for (const auto &v : llvm::enumerate(resultVals)) + llvmVector = + insert_element(vecTy, llvmVector, v.value(), i32_val(v.index())); + return llvmVector; + } + } + auto structType = dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type)); if (!structType) { @@ -1562,6 +1596,7 @@ inline Value packLLElements(Location loc, emitError(loc) << " size mismatch when packing elements for LLVM struct" << " expected " << elementTypes.size() << " but got " << resultVals.size(); + assert(false); } Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType); for (const auto &v : llvm::enumerate(resultVals)) { diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index b1f1597c5aa7..104f854a4726 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -23,6 +23,10 @@ namespace mlir { namespace triton { +void enableMoeLDSBypass(bool value); + +bool isMoeLDSBypass(); + struct GlobalMemory : public SideEffects::Resource::Base<GlobalMemory> { StringRef getName() final { return "<GlobalMemory>"; } }; diff --git a/include/triton/Dialect/Triton/IR/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h index f34a0fd5955f..a768f279446c 100644 --- a/include/triton/Dialect/Triton/IR/Traits.h +++ b/include/triton/Dialect/Triton/IR/Traits.h @@ -4,8 +4,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Support/LogicalResult.h" - -#include <...> +#include "triton/Dialect/Triton/IR/Types.h" namespace mlir { namespace OpTrait { @@ -58,6 +57,51 @@ class VerifyTensorLayoutsTrait } }; +// Verify that the op is a dot-like operation.
+// A dot-like operation should have three operands. +// The first two operands should share a common dimension, and the result +// should have the dimensions of the two operands that are not shared. +// A dot-like operation can be either 2d or 3d. +// In the 3d case, the first dimension of the operands is the batch dimension. +template <typename ConcreteType> +class DotLike : public TraitBase<ConcreteType, DotLike> { +public: + static LogicalResult verifyTrait(Operation *op) { + if (op->getNumOperands() != 3) + return op->emitOpError("expected 3 operands"); + auto aTy = cast<ShapedType>(op->getOperand(0).getType()); + auto bTy = cast<ShapedType>(op->getOperand(1).getType()); + auto cTy = cast<ShapedType>(op->getOperand(2).getType()); + auto aShape = aTy.getShape(); + auto bShape = bTy.getShape(); + auto cShape = cTy.getShape(); + // Check if all 3d or all 2d + if (aShape.size() != 2 && aShape.size() != 3) + return op->emitOpError("expected operands to be 2d or 3d"); + if (aShape.size() != bShape.size() || aShape.size() != cShape.size()) + return op->emitOpError("expected all operands to have the same rank"); + // Check if the first two operands share a common dimension + if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2]) + return op->emitOpError("expected the last dimension of the first operand " + "to be equal to the second-to-last dimension of " + "the second operand"); + // Check the batch dimension + if (aShape.size() == 3 && + (aShape[0] != cShape[0] || bShape[0] != cShape[0])) + return op->emitOpError("expected the first dimension of the first " + "operand to be equal to the first dimension of " + "the result"); + // Check the output shape + if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] || + cShape[cShape.size() - 1] != bShape[aShape.size() - 1]) + return op->emitOpError( + "expected the output shape to be the concatenation of the " + "second-to-last dimension of the first operand and the last " + "dimension of the second operand"); + return success(); + } +}; + template <typename ConcreteType> class SameOperandsAndResultEncoding : public TraitBase<ConcreteType, SameOperandsAndResultEncoding> { diff --git a/include/triton/Dialect/Triton/IR/TritonInterfaces.td b/include/triton/Dialect/Triton/IR/TritonInterfaces.td index cfc7d0032cce..f51cca0bc254 100644 --- a/include/triton/Dialect/Triton/IR/TritonInterfaces.td +++ b/include/triton/Dialect/Triton/IR/TritonInterfaces.td @@ -5,6 +5,7 @@ include "mlir/IR/OpBase.td" def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">; +def DotLike : NativeOpTrait<"DotLike">; def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">; def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">; def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">; diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index a8ab6caa253d..92d66e5b0344 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -625,6 +625,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> { // def TT_DotOp : TT_Op<"dot", [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>, + DotLike, TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">]> { let summary = "dot"; @@ -640,8 +641,8 @@ let arguments = ( ins - TT_TensorOrMemDesc:$a, - TT_TensorOrMemDesc:$b, + TT_FpIntTensor:$a, + TT_FpIntTensor:$b, TT_FpIntTensor:$c, DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::TF32">:$inputPrecision, DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td
b/include/triton/Dialect/Triton/IR/TritonTypes.td index fd5af9cc8ea3..6ceb4bc47665 100644 --- a/include/triton/Dialect/Triton/IR/TritonTypes.td +++ b/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -25,7 +25,8 @@ def TT_BoolTensor : RankedTensorOf<[I1]>; def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>; // Integer Type -def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">; +def I4 : I<4>; +def TT_Int : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">; def TT_IntTensor : RankedTensorOf<[TT_Int]>; def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>; @@ -106,12 +107,13 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> ArrayRefParameter<"int64_t">:$shape, "Type":$elementType, "Attribute":$encoding, + "Attribute":$memorySpace, "bool":$mutable_memory ); let extraClassDeclaration = [{ MemDescType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const { - return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding()); + return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory()); } bool hasRank() const { return true; } @@ -120,17 +122,19 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> TypeBuilderWithInferredContext<(ins "llvm::ArrayRef<int64_t>":$shape, "Type":$elementType, - "Attribute":$encoding + "Attribute":$encoding, + "Attribute":$memorySpace ), [{ - return $_get(elementType.getContext(), shape, elementType, encoding, /*mutableMemory=*/false); + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false); }]>, TypeBuilderWithInferredContext<(ins "llvm::ArrayRef<int64_t>":$shape, "Type":$elementType, "Attribute":$encoding, + "Attribute":$memorySpace, "bool":$mutableMemory ), [{ - return $_get(elementType.getContext(), shape, elementType, encoding, mutableMemory); + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory); }]> ]; let hasCustomAssemblyFormat = 1; diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index ae23f9d13cea..19d2b9a80782 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -1298,4 +1298,10 @@ elements along the K dim, or they use all elements of the tensor along the K dim }]; } +def TTG_SharedMemorySpace : AttrDef<TritonGPU_Dialect, "SharedMemorySpace"> { + let mnemonic = "shared_memory"; + let description = [{ + Attribute to indicate that the memory descriptor points to shared memory.
+ }]; +} #endif diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 2530009cbd87..a87e1c44ac61 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -144,6 +144,12 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { let arguments = (ins Optional<TT_Tensor>:$src); + let extraClassDeclaration = [{ + bool isSharedMemoryAlloc() { + return getType().getMemorySpace() && + isa<triton::gpu::SharedMemorySpaceAttr>(getType().getMemorySpace()); + } + }]; let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}]; let results = (outs TT_MemDescType:$result); @@ -163,7 +169,7 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree { "mlir::arith::ArithDialect"]; } + def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> { let summary = "accelerate matmul"; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h b/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h similarity index 100% rename from lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h rename to include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.h b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h similarity index 100% rename from lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.h rename to include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h diff --git a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h new file mode 100644 index 000000000000..1dd1fc686034 --- /dev/null +++ b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h @@ -0,0 +1,107 @@ +#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ +#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "llvm/ADT/ArrayRef.h" +#include <list> +#include <vector> + +namespace mlir { +namespace triton { + +/// This fills out the pipelining options, including the schedule and +/// annotations for wait ops. This also does pre-processing by converting some +/// of the loads into async loads so that the IR is ready to be pipelined. +bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::PipeliningOption &options); + +/// Fills out pipelining options for an outer loop pipelining case. This +/// schedules async copies to overlap with the epilogue of a loop. +bool getOuterLoopSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::PipeliningOption &options); + +/// Pipeline the TMA stores in the loop. +bool pipelineTMAStores(scf::ForOp forOp); + +/// This does post-processing on the pipelined loop to try to pipeline wgmma +/// ops. +// TODO: this should be included as part of the pipeline but currently the wgmma +// wait modeling is problematic. +void asyncLaunchDots(scf::ForOp forOp); + +/// Post process the pipelined loop by updating the wait ops with the right +/// number of groups in flight.
+void updateWaits(ModuleOp module); + +class CoarseSchedule { +public: + class ClusterList { + std::list<int> orderClusters; + + public: + using iterator = decltype(orderClusters)::iterator; + ClusterList() = default; + iterator begin() { return orderClusters.begin(); } + iterator end() { return orderClusters.end(); } + size_t size() { return orderClusters.size(); } + iterator newAtBack() { + orderClusters.push_back(orderClusters.size()); + return std::prev(orderClusters.end()); + } + iterator newAtFront() { + orderClusters.push_front(-1); + for (auto &clusterId : orderClusters) { + clusterId++; + } + return orderClusters.begin(); + } + iterator newBefore(iterator cluster) { + auto ret = orderClusters.insert(cluster, *cluster); + for (auto &clusterId : llvm::make_range(cluster, orderClusters.end())) { + clusterId++; + } + return ret; + } + }; + + CoarseSchedule(int numStages) : numStages(numStages) {} + int numStages; + ClusterList clusters; + using Cluster = decltype(clusters)::iterator; + + DenseMap<Operation *, std::pair<int, Cluster>> opToStageAndCluster; + + void insert(Operation *op, int stage, Cluster cluster) { + opToStageAndCluster[op] = {stage, cluster}; + } + + bool insertIfAbsent(Operation *op, int stage, Cluster cluster) { + if (opToStageAndCluster.count(op)) + return false; + insert(op, stage, cluster); + return true; + } + + void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, + bool includeArg); + + void erase(Operation *op) { opToStageAndCluster.erase(op); } + + int count(Operation *op) { return opToStageAndCluster.count(op); } + + std::pair<int, Cluster> operator[](Operation *op) { + return opToStageAndCluster[op]; + } + + SmallVector<std::tuple<Operation *, int, Cluster>> + getOpsInOrder(scf::ForOp forOp); + std::vector<std::pair<Operation *, unsigned>> + createFinalSchedule(scf::ForOp forOp); + void dump(); +}; + +} // namespace triton +} // namespace mlir +#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 114c1814254c..12f44cc378d1 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -23,7 +23,7 @@ class SharedEncodingAttr; SmallVector<unsigned> mmaVersionToInstrShape(int version, const ArrayRef<int64_t> &shape, - TensorOrMemDesc type, + RankedTensorType type, int numWarps); /// Returns true if the Load uses block pointer. @@ -135,6 +135,9 @@ scf::IfOp replaceIfOpWithNewSignature( RewriterBase &rewriter, scf::IfOp loop, TypeRange newResultTypes, SmallVectorImpl<std::tuple<Value, Value>> &replacements); +// Append the given |newOperands| to the |forOp|'s yield op.
+void appendToForOpYield(scf::ForOp forOp, ArrayRef newOperands); + Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping); diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index 486bbf5535c1..1a23c6747e9d 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -67,13 +67,14 @@ def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> { } // -// DotAsync Op +// WarpGroupDot Op // -def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [DeclareOpInterfaceMethods, +def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + DotLike, TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">]> { - let summary = "dot async"; + let summary = "warp group dot"; let description = [{ $d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp @@ -82,17 +83,18 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [DeclareOpInterfaceMethods:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc, + DefaultValuedAttr:$isAsync); let results = (outs TT_FpIntTensor:$d); let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)"; } -def TTNG_DotWaitOp : TTNG_Op<"dot_wait", [DeclareOpInterfaceMethods, - AllTypesMatch<["inputs", "outputs"]>]> { - let summary = "dot wait"; +def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods, + AllTypesMatch<["inputs", "outputs"]>]> { + let summary = "warp group dot wait"; let arguments = (ins Variadic:$inputs, I32Attr:$pendings); let results = (outs Variadic:$outputs); let description = [{ @@ -100,7 +102,7 @@ def TTNG_DotWaitOp : TTNG_Op<"dot_wait", [DeclareOpInterfaceMethods *> results) { AliasInfo aliasInfo; bool pessimistic = true; - // These ops may allocate a new shared memory buffer. auto result = op->getResult(0); + // skip ops that return memdesc in a different memory space. + if (auto memdescTy = dyn_cast(result.getType())) { + if (!isa_and_nonnull( + memdescTy.getMemorySpace())) + return; + } // Only LocalAllocOp creates a new buffer. if (isa(op)) { diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index a129cb1947c6..cae55efac1ed 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -210,7 +210,8 @@ class AllocationAnalysis { // XXX(Keren): Why this hard-coded alignment? size_t kAlignment = 8; for (Value result : op->getResults()) { - if (auto alloc = result.getDefiningOp()) { + auto alloc = result.getDefiningOp(); + if (alloc && alloc.isSharedMemoryAlloc()) { // Bytes could be a different value once we support padding or other // allocation policies. auto allocType = alloc.getType(); diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index 407a5ae1508a..256fcc8e88dd 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -145,17 +145,6 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, } } } - // XXX(Keren): This is a hack as we cannot set side effects for dot ops, but - // on hopper they do have side effects. 
Need to clean it up - if (auto dotOp = dyn_cast(op)) { - for (auto value : dotOp.getOperands()) { - for (auto bufferId : allocation->getBufferIds(value)) { - if (bufferId != Allocation::InvalidBufferId) - curBlockInfo.syncReadIntervals.insert( - allocation->getAllocatedInterval(bufferId)); - } - } - } // Scratch buffer is considered as both shared memory write & read auto bufferId = allocation->getBufferId(op); if (bufferId != Allocation::InvalidBufferId) { diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 32cc43c9d5d2..1ffc6a05f3c8 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -582,6 +582,20 @@ bool supportMMA(Value value, int version) { (elemTy.isInteger(8) && version >= 2); } +bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { + if (isMoeLDSBypass()) { + if (auto blockSrc = llvm::dyn_cast( + srcTy.getEncoding())) { + auto dotOp = llvm::dyn_cast( + dstTy.getEncoding()); + if (dotOp && dotOp.getOpIdx() == 1) { + return true; + } + } + } + return false; +} + bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { auto srcLayout = srcTy.getEncoding(); auto dstLayout = dstTy.getEncoding(); diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp index 690155ee54b5..6a5eb2528e72 100644 --- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -90,17 +90,23 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) { OpBuilder builder(cvtOp); auto srcType = cast(cvtOp.getSrc().getType()); auto dstType = cast(cvtOp.getType()); + if (isMoeLDSBypass() && isBlockedToDotShortcut(srcType, dstType)) { + return; + } auto srcBlocked = dyn_cast(srcType.getEncoding()); auto dstDotOp = dyn_cast(dstType.getEncoding()); if (srcBlocked && dstDotOp) { + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); auto tmpType = MemDescType::get( dstType.getShape(), dstType.getElementType(), triton::gpu::SharedEncodingAttr::get( module.getContext(), dstDotOp, srcType.getShape(), srcBlocked.getOrder(), srcBlocked.getCTALayout(), - srcType.getElementType())); + srcType.getElementType()), + sharedMemorySpace); auto tmp = builder.create( cvtOp.getLoc(), tmpType, cvtOp.getSrc()); addAttrs(tmp, cvtOp->getAttrs()); diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 0287207be51a..05e7936a6648 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -135,6 +135,15 @@ int getNumElementsPerThreads(Type type, if (structType) { numElemsPerThread = structType.getBody().size(); } + + if (isMoeLDSBypass()) { + auto vectorType = + dyn_cast(typeConverter->convertType(type)); + if (vectorType) { + numElemsPerThread = vectorType.getDimSize(0); + } + } + auto encoding = dyn_cast(tensorTy.getEncoding()); if (!(encoding && isa(encoding.getParent()))) return numElemsPerThread; diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 12ab6684c3b3..38ecb288af1a 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -51,6 +51,8 @@ struct LocalAllocOpConversion LogicalResult matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, 
ConversionPatternRewriter &rewriter) const override { + if (!op.isSharedMemoryAlloc()) + return failure(); Location loc = op->getLoc(); Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index 908aa1e2bd2f..ce55bf75d871 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -110,6 +110,10 @@ Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( } unsigned numElementsPerThread = getTotalElemsPerThread(type); + + if (isMoeLDSBypass() && eltType.isIntOrFloat() && numElementsPerThread > 1) + return vec_ty(eltType, numElementsPerThread); + SmallVector types(numElementsPerThread, eltType); return LLVM::LLVMStructType::getLiteral(ctx, types); } diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index e0f6e937723a..e1198fe6c330 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -26,6 +26,12 @@ struct SplatOpConversion : public ConvertOpToLLVMPattern { auto srcType = typeConverter->convertType(tensorTy); if (auto structTy = dyn_cast(srcType)) srcType = structTy.getBody()[0]; + + if (isMoeLDSBypass()) { + if (auto vectorTy = dyn_cast(srcType)) + srcType = vectorTy.getElementType(); + } + // If the type sizes don't match we need to pack constants. if (srcType.isIntOrFloat() && constVal.getType().getIntOrFloatBitWidth() != srcType.getIntOrFloatBitWidth()) { diff --git a/lib/Dialect/Triton/IR/Dialect.cpp b/lib/Dialect/Triton/IR/Dialect.cpp index 8f46e8ca8bc0..2a3c62cad679 100644 --- a/lib/Dialect/Triton/IR/Dialect.cpp +++ b/lib/Dialect/Triton/IR/Dialect.cpp @@ -21,6 +21,15 @@ using namespace mlir::triton; // TritonDialect Dialect Interfaces //===----------------------------------------------------------------------===// +namespace mlir::triton { + +static bool moeLDSBypass = false; + +void enableMoeLDSBypass(bool value) { moeLDSBypass = value; } + +bool isMoeLDSBypass() { return moeLDSBypass; } +} // namespace mlir::triton + namespace { struct TritonInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index ce4f97336919..4b26babef93e 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -224,9 +224,9 @@ LogicalResult TransOp::inferReturnTypes( return failure(); } } - if (isa(argTy)) { - inferredReturnTypes.push_back( - MemDescType::get(retShape, retEltTy, retEncoding)); + if (auto memDescTy = dyn_cast(argTy)) { + inferredReturnTypes.push_back(MemDescType::get( + retShape, retEltTy, retEncoding, memDescTy.getMemorySpace())); } else { inferredReturnTypes.push_back( RankedTensorType::get(retShape, retEltTy, retEncoding)); diff --git a/lib/Dialect/Triton/IR/Types.cpp b/lib/Dialect/Triton/IR/Types.cpp index 0e1df5b744cd..f9356cb73791 100644 --- a/lib/Dialect/Triton/IR/Types.cpp +++ b/lib/Dialect/Triton/IR/Types.cpp @@ -70,16 +70,24 @@ Type MemDescType::parse(AsmParser &parser) { return Type(); } bool mutableMemory = false; + Attribute memorySpace; if (succeeded(parser.parseOptionalComma())) { + if (failed(parser.parseOptionalKeyword(kMutableMemory))) { + if (parser.parseAttribute(memorySpace)) + return Type(); + } else { + mutableMemory = true; + } + } + if (mutableMemory == false && 
succeeded(parser.parseOptionalComma())) { if (parser.parseOptionalKeyword(kMutableMemory)) return Type(); mutableMemory = true; } if (parser.parseGreater()) return Type(); - return MemDescType::get(parser.getContext(), dimensions, elementType, - encoding, mutableMemory); + encoding, memorySpace, mutableMemory); } void MemDescType::print(AsmPrinter &printer) const { @@ -89,6 +97,8 @@ void MemDescType::print(AsmPrinter &printer) const { printer << getElementType(); if (getEncoding()) printer << ", " << getEncoding(); + if (getMemorySpace()) + printer << ", " << getMemorySpace(); if (getMutableMemory()) printer << ", " << kMutableMemory; printer << ">"; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 69067b7068a4..eacd481553d6 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2757,8 +2757,9 @@ struct CanonicalizeConvertFromConvert // for hopper MMAv3 if (mlir::isa(dstType.getEncoding()) && mlir::isa(srcType.getEncoding()) && - llvm::any_of(op.getResult().getUsers(), - [](Operation *dot) { return isa(dot); })) { + llvm::any_of(op.getResult().getUsers(), [](Operation *dot) { + return dot->hasTrait(); + })) { return failure(); } diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index df84c4e628ea..2d048fd95c68 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -208,12 +208,15 @@ class BlockedToMMA : public mlir::RewritePattern { } } + Attribute SharedMemorySpace = + SharedMemorySpaceAttr::get(argType.getContext()); auto CTALayout = getCTALayout(argType.getEncoding()); auto newLayout = SharedEncodingAttr::get(argType.getContext(), argType.getShape(), newOrder, CTALayout, argType.getElementType()); - auto newType = MemDescType::get(argType.getShape(), - argType.getElementType(), newLayout); + auto newType = + MemDescType::get(argType.getShape(), argType.getElementType(), + newLayout, SharedMemorySpace); rewriter.setInsertionPointAfterValue(arg); return rewriter.create(arg.getLoc(), newType, arg); } @@ -296,11 +299,14 @@ class BlockedToMMA : public mlir::RewritePattern { auto newAcc = rewriter.create(oldAcc.getLoc(), newRetType, oldAcc); + Operation *newDot = nullptr; if (versionMajor == 3) { a = getMMAv3Operand(a, rewriter, 0); b = getMMAv3Operand(b, rewriter, 1); + newDot = rewriter.create( + dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getInputPrecision(), + dotOp.getMaxNumImpreciseAcc(), false); } else { - // convert operands int minBitwidth = std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); @@ -319,14 +325,13 @@ class BlockedToMMA : public mlir::RewritePattern { auto newBType = RankedTensorType::get( oldBType.getShape(), oldBType.getElementType(), newBEncoding); b = rewriter.create(b.getLoc(), newBType, b); + newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, newAcc, + dotOp.getInputPrecision(), + dotOp.getMaxNumImpreciseAcc()); } // convert dot instruction - auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, - newAcc, dotOp.getInputPrecision(), - dotOp.getMaxNumImpreciseAcc()); - rewriter.replaceOpWithNewOp(op, oldRetType, - newDot.getResult()); + newDot->getResult(0)); return success(); } }; diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 7b2ab63e8748..9767effa5a74 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ 
b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -12,6 +12,7 @@ add_triton_library(TritonGPUTransforms Pipeliner/SoftwarePipeliner.cpp Pipeliner/TMAStoresPipeline.cpp Pipeliner/PipeliningUtility.cpp + Pipeliner/Schedule.cpp Prefetch.cpp RemoveLayoutConversions.cpp ReorderInstructions.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 4a30bf9f37c8..f41b8ef8687f 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -59,12 +59,12 @@ class SwizzleShmemConvert : public OpRewritePattern { srcTy.getElementType(), /*needTrans=*/true); if (newInnerCvtEnc == cvtEncoding) return failure(); - rewriter.setInsertionPoint(trans); + auto sharedMemorySpace = SharedMemorySpaceAttr::get(getContext()); auto alloc = rewriter.create( trans.getLoc(), MemDescType::get(srcTy.getShape(), srcTy.getElementType(), - newInnerCvtEnc), + newInnerCvtEnc, sharedMemorySpace), trans.getSrc()); auto newTrans = rewriter.create(trans.getLoc(), alloc, ArrayRef({1, 0})); @@ -210,7 +210,7 @@ class FuseTransHopper : public OpRewritePattern { LogicalResult matchAndRewrite(LocalAllocOp allocOp, PatternRewriter &rewriter) const override { if (!allocOp->hasOneUse() || - !isa(*allocOp->getUsers().begin())) + !allocOp->getUsers().begin()->hasTrait()) return failure(); auto dot = *allocOp->getUsers().begin(); @@ -254,7 +254,8 @@ class FuseTransHopper : public OpRewritePattern { allocEncoding.getCTALayout(), srcTy.getElementType()); MemDescType innerTy = - MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc); + MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc, + allocType.getMemorySpace()); auto newAlloc = rewriter.create(allocOp.getLoc(), innerTy, trans.getSrc()); rewriter.replaceOpWithNewOp(allocOp, newAlloc, @@ -267,10 +268,11 @@ class FuseTransHopper : public OpRewritePattern { // dot(convert(lhs #mma) #shared, rhs) #mma -> // dot(convert(lhs #mma) #dot_operand, rhs) #mma, // for fp16 or bf16 MMAv3 dots. 
-struct MMAV3UseRegOperand : public OpRewritePattern { +struct MMAV3UseRegOperand + : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(DotOp dotOp, + LogicalResult matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp dotOp, PatternRewriter &rewriter) const override { auto alloc = dotOp.getOperand(0).getDefiningOp(); if (!alloc || !alloc.getSrc()) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index f3d5aa00ec27..044793dcfcea 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -1,6 +1,3 @@ -#include "PipelineExpander.h" -#include "PipeliningUtility.h" -#include "Schedule.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" @@ -14,6 +11,9 @@ #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "llvm/ADT/MapVector.h" @@ -51,168 +51,10 @@ struct LoadInfo { } // namespace -class CoarseSchedule { -public: - class ClusterList { - std::list orderClusters; - - public: - using iterator = decltype(orderClusters)::iterator; - ClusterList() = default; - iterator begin() { return orderClusters.begin(); } - iterator end() { return orderClusters.end(); } - size_t size() { return orderClusters.size(); } - iterator newAtBack() { - orderClusters.push_back(orderClusters.size()); - return std::prev(orderClusters.end()); - } - iterator newAtFront() { - orderClusters.push_front(-1); - for (auto &clusterId : orderClusters) { - clusterId++; - } - return orderClusters.begin(); - } - iterator newBefore(iterator cluster) { - auto ret = orderClusters.insert(cluster, *cluster); - for (auto &clusterId : llvm::make_range(cluster, orderClusters.end())) { - clusterId++; - } - return ret; - } - }; - - CoarseSchedule(int numStages) : numStages(numStages) {} - int numStages; - ClusterList clusters; - using Cluster = decltype(clusters)::iterator; - - DenseMap> opToStageAndCluster; - - void insert(Operation *op, int stage, Cluster cluster) { - opToStageAndCluster[op] = {stage, cluster}; - } - - bool insertIfAbsent(Operation *op, int stage, Cluster cluster) { - if (opToStageAndCluster.count(op)) - return false; - insert(op, stage, cluster); - return true; - } - - void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, - bool includeArg) { - for (Value operand : op->getOperands()) { - Value v = operand; - llvm::SmallDenseSet seen; - while (auto arg = dyn_cast(v)) { - if (!includeArg) - break; - if (!seen.insert(v).second) - break; - if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { - auto yieldOp = op->getBlock()->getTerminator(); - v = yieldOp->getOperand(arg.getArgNumber() - 1); - continue; - } - break; - } - Operation *defOp = v.getDefiningOp(); - if (defOp && defOp->getBlock() == op->getBlock()) { - if (insertIfAbsent(defOp, stage, cluster)) { - insertDepsOfOp(defOp, stage, cluster, includeArg); - } - } - } - } - - void erase(Operation *op) { 
opToStageAndCluster.erase(op); } - - int count(Operation *op) { return opToStageAndCluster.count(op); } - - std::pair operator[](Operation *op) { - return opToStageAndCluster[op]; - } - - SmallVector> - getOpsInOrder(scf::ForOp forOp) { - SmallVector>, 8> - orderClusters(clusters.size()); - for (auto &op : forOp.getBody()->without_terminator()) { - if (opToStageAndCluster.count(&op) == 0) { - continue; - } - assert(opToStageAndCluster[&op].first < numStages && - "Op with invalid stage!"); - int clusterId = *opToStageAndCluster[&op].second; - assert(clusterId == std::distance(clusters.begin(), - opToStageAndCluster[&op].second) && - "Cluster ID mismatch!"); - orderClusters[clusterId].push_back( - make_tuple(&op, opToStageAndCluster[&op].first, - opToStageAndCluster[&op].second)); - } - SmallVector> opsInOrder; - for (int i = 0; i < orderClusters.size(); i++) { - for (auto [op, stage, cluster] : orderClusters[i]) { - opsInOrder.push_back({op, stage, cluster}); - } - } - - return opsInOrder; - } - - std::vector> - createFinalSchedule(scf::ForOp forOp) { - SmallVector> opsInOrder = - getOpsInOrder(forOp); - std::vector> schedule; - for (auto [op, stage, cluster] : opsInOrder) { - LDBG("Adding op to schedule at stage " << stage << " cluster " << *cluster - << ":" << *op); - schedule.push_back({op, stage}); - } - return schedule; - } - - void dump() { - for (int i = 0; i < numStages; i++) { - LDBG("- Ops in stage " << i); - for (auto &[op, stageAndCluster] : opToStageAndCluster) { - if (i == stageAndCluster.first) { - llvm::outs() << " cluster: " << *stageAndCluster.second << " "; - op->dump(); - } - } - } - } -}; - -static bool isMMAv3Dot(Operation *op) { - auto dot = dyn_cast(op); - if (!dot) - return false; - auto enc = - mlir::dyn_cast(dot.getType().getEncoding()); - return enc && enc.isHopper(); -} - -// Replace the ForOp's yield with a new one with the given operands appended. -static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { - // Fix up the yield op. 
- Operation *yieldOp = forOp.getBody()->getTerminator(); - SmallVector<Value> operands(yieldOp->getOperands()); - operands.append(newOperands.begin(), newOperands.end()); - - OpBuilder builder(yieldOp); - builder.create<scf::YieldOp>(yieldOp->getLoc(), operands); - yieldOp->erase(); -} - static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, - CoarseSchedule &schedule, - CoarseSchedule::Cluster prefetchCluster, + tt::CoarseSchedule &schedule, + tt::CoarseSchedule::Cluster prefetchCluster, llvm::MapVector<Operation *, LoadInfo> &loadToInfo, int numStages) { OpBuilder builder(forOp); @@ -245,9 +87,11 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, tt::MemDescType allocTy = cast<tt::MemDescType>(alloc.getType()); SmallVector<Value> copyOffsets(allocTy.getRank(), zero); copyOffsets[0] = insertIdx; + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); tt::MemDescType subviewTy = tt::MemDescType::get( allocTy.getShape().drop_front(), allocTy.getElementType(), - allocTy.getEncoding(), /*mutableMemory=*/true); + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); auto view = builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, copyOffsets); Operation *copy = builder.create<ttg::AsyncCopyGlobalToLocalOp>( @@ -312,10 +156,12 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, static void createTMAAsyncCopy( scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp, - Value phase, CoarseSchedule &schedule, + Value phase, tt::CoarseSchedule &schedule, llvm::MapVector<Operation *, LoadInfo> &loadToInfo, int numStages) { assert(phase && "Phase value is required for TMA async copy."); OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32); builder.setInsertionPoint(loadOp); Location loc = loadOp.getLoc(); @@ -324,7 +170,7 @@ static void createTMAAsyncCopy( copyOffsets[0] = insertIdx; tt::MemDescType subviewTy = tt::MemDescType::get( allocTy.getShape().drop_front(), allocTy.getElementType(), - allocTy.getEncoding(), /*mutableMemory=*/true); + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); auto view = builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, copyOffsets); @@ -368,11 +214,13 @@ static void createTMAAsyncCopy( } // If all the transitive uses of the given value are used by a convert to -// the same dot operand encoding, return true and get the shared encoding that -// needs to be used to be compatible with users' layouts. +// the same dot operand encoding, return the shared encoding that needs to be +// used to be compatible with users' layouts. If there are incompatible shared +// encodings, set `incompatible` to true. static std::optional<ttg::SharedEncodingAttr> -getSharedEncIfAllUsersAreDotEnc(Value val) { +getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { ttg::SharedEncodingAttr attr; + incompatible = false; for (Operation *user : val.getUsers()) { ttg::SharedEncodingAttr tempAttr; if (user->getNumResults() != 1) @@ -382,7 +230,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val) { // First time we find a shared encoding in the chain, save it and try to // use it if it is compatible with the other users.
tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding()); - if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible) + .has_value()) return std::nullopt; } else { if (!isa(user)) @@ -402,8 +251,10 @@ getSharedEncIfAllUsersAreDotEnc(Value val) { srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); } // Check that the shared encodings needed by the users are compatible. - if (!tempAttr || (attr != nullptr && attr != tempAttr)) + if (attr != nullptr && attr != tempAttr) { + incompatible = true; return std::nullopt; + } attr = tempAttr; } return attr; @@ -506,7 +357,7 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { }; for (Operation &op : forOp.getBody()->without_terminator()) { - if (!isa<tt::DotOp>(op)) + if (!op.hasTrait<OpTrait::DotLike>()) continue; seen.clear(); dfs(&op, 0, &op); @@ -583,7 +434,7 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>> continue; } - if (auto dot = dyn_cast<tt::DotOp>(use)) { + if (use->hasTrait<OpTrait::DotLike>()) { loadInfo.usedByDot = true; if (loadIsMMAv3(op)) { loadInfo.loadIsMMAV3 = true; } else if (isa<tt::ExperimentalDescriptorLoadOp>(op)) { loadInfo.sharedEncoding = getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); - } else { + } else if (auto dot = dyn_cast<tt::DotOp>(use)) { + bool incompatible = false; loadInfo.sharedEncoding = - getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); - + getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible) + .value_or(nullptr); + // If we can't agree on a shared encoding, skip pipelining the load. + if (incompatible) + continue; // HACK: Triton LLVM codegen has a bug where local_loads from #shared to // #mma layout can lead to invalid code if the loaded shape is smaller // than the mma tile (e.g. loading a 128x1 tensor for an MMAv2 dot with @@ -642,7 +497,7 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>> // If we still don't have a shared encoding, try a "generic" shared // encoding. - if (!loadInfo.sharedEncoding && !isMMAv3Dot(use)) { + if (!loadInfo.sharedEncoding && !isa<ttng::WarpGroupDotOp>(use)) { loadInfo.sharedEncoding = getSharedEncoding(op, /*isMMAV3=*/loadInfo.loadIsMMAV3) .value_or(nullptr); @@ -662,7 +517,7 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>> } static llvm::MapVector<Operation *, LoadInfo> -scheduleLoads(scf::ForOp forOp, CoarseSchedule &schedule, +scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, DenseSet<Operation *> &rootUsers, int numStages) { ModuleOp moduleOp = forOp->getParentOfType<ModuleOp>(); tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); @@ -700,7 +555,7 @@ scheduleLoads(scf::ForOp forOp, CoarseSchedule &schedule, unsigned stagesBetweenLoads = ceil(numStages - 2, maxIndirectionLevel + 1); - CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); + tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); // Put the root uses of the loads in the last stage. for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { if (loadToInfo.count(loadOp) == 0) @@ -713,7 +568,7 @@ scheduleLoads(scf::ForOp forOp, CoarseSchedule &schedule, } } - SmallVector<CoarseSchedule::Cluster> loadsClusters; + SmallVector<tt::CoarseSchedule::Cluster> loadsClusters; for (int i = 0; i < maxIndirectionLevel + 1; i++) { loadsClusters.push_back(schedule.clusters.newAtBack()); } @@ -738,10 +593,10 @@ scheduleLoads(scf::ForOp forOp, CoarseSchedule &schedule, // Schedule the prologue and epilogue `if` ops in the loop, pushing them as // close to the loop boundaries as possible. Return the cluster after the // prologue (or the beginning of the loop if there is no prologue).
-static CoarseSchedule::Cluster -schedulePrologueAndEpilogue(scf::ForOp forOp, CoarseSchedule &schedule, +static tt::CoarseSchedule::Cluster +schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule, DenseSet &rootUsers, int numStages) { - CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); // Look for the IfOp that is in the backward slice any of the currently // scheduled ops and put it at the beginning of the loop. @@ -763,14 +618,14 @@ schedulePrologueAndEpilogue(scf::ForOp forOp, CoarseSchedule &schedule, } } } - CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); + tt::CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); for (auto [ifOp, stage] : ifsToStage) { schedule.insert(ifOp, stage, prologueCluster); } // Look for the IfOp that is in the forward slice of the root users and put it // at the end of the loop. - CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + tt::CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); for (auto rootUser : rootUsers) { SetVector forwardSlice; getForwardSlice(rootUser, &forwardSlice); @@ -797,9 +652,9 @@ schedulePrologueAndEpilogue(scf::ForOp forOp, CoarseSchedule &schedule, // Add dependencies of anchor ops to the coarse schedule. Schedule them to // the same stage and ordering cluster as the anchor op. -static void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule, +static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, int numStages) { - SmallVector> + SmallVector> opsInOrder = schedule.getOpsInOrder(forOp); // Schedule dependencies stage by stage. for (int stage = 0; stage < numStages; stage++) { @@ -814,7 +669,7 @@ static void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule, // Find dependencies with distance of 1. They will go to the next stage, // but in the cluster before the current op. static void scheduleDistanceOneDependencies(scf::ForOp forOp, - CoarseSchedule &schedule, + tt::CoarseSchedule &schedule, int numStages) { auto getNestedOperands = [](Operation *op) -> SmallVector { SmallVector operands; @@ -828,7 +683,8 @@ static void scheduleDistanceOneDependencies(scf::ForOp forOp, }; // Mapping from the cluster to the cluster before it. - DenseMap dist1Cluster; + DenseMap + dist1Cluster; for (auto &op : forOp.getBody()->without_terminator()) { if (schedule.count(&op) == 0) continue; @@ -863,14 +719,14 @@ static void scheduleDistanceOneDependencies(scf::ForOp forOp, } } -static void scheduleRemainingToLastStage(scf::ForOp forOp, - CoarseSchedule &schedule, - CoarseSchedule::Cluster afterPrologue, - int numStages) { +static void +scheduleRemainingToLastStage(scf::ForOp forOp, tt::CoarseSchedule &schedule, + tt::CoarseSchedule::Cluster afterPrologue, + int numStages) { // Assign the rest of the ops to the last stage. // Take care of the ordering of the ops - uses cannot be scheduled to the // cluster before the definition. 
- DenseMap<Operation *, CoarseSchedule::Cluster> opToCluster; + DenseMap<Operation *, tt::CoarseSchedule::Cluster> opToCluster; for (auto &op : forOp.getBody()->without_terminator()) { if (schedule.count(&op) == 0) { opToCluster[&op] = afterPrologue; @@ -888,8 +744,8 @@ static void scheduleRemainingToLastStage(scf::ForOp forOp, Operation *op = queue.pop_back_val(); for (auto user : op->getUsers()) { if (opToCluster.count(user)) { - CoarseSchedule::Cluster userCluster = opToCluster[user]; - CoarseSchedule::Cluster opCluster; + tt::CoarseSchedule::Cluster userCluster = opToCluster[user]; + tt::CoarseSchedule::Cluster opCluster; if (schedule.count(op)) opCluster = schedule[op].second; else @@ -910,11 +766,14 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, ttg::SharedEncodingAttr sharedEnc, unsigned distance) { OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); auto ty = cast<RankedTensorType>(loadOp->getResultTypes()[0]); SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end()); bufferShape.insert(bufferShape.begin(), distance); Type memdescType = mlir::triton::MemDescType::get( - bufferShape, ty.getElementType(), sharedEnc, /*mutableMemory*/ true); + bufferShape, ty.getElementType(), sharedEnc, sharedMemorySpace, + /*mutableMemory*/ true); Value alloc = builder.create<ttg::LocalAllocOp>( loadOp->getLoc(), memdescType, Value()); return alloc; @@ -923,6 +782,8 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, // Create an allocation to hold the mbarriers. static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) { OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); Location loc = forOp.getLoc(); auto context = forOp.getContext(); auto barrierCTALayout = ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); auto barrierEncoding = ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); - Type barrierMemDescType = - tt::MemDescType::get({distance}, builder.getI64Type(), barrierEncoding, - /*mutableMemory=*/true); - Type singleBarrierMemDescType = tt::MemDescType::get( - {1}, builder.getI64Type(), barrierEncoding, /*mutableMemory=*/true); + Type barrierMemDescType = tt::MemDescType::get( + {distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = + tt::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); Value barrierAlloc = builder.create<ttg::LocalAllocOp>( loc, barrierMemDescType, Value()); for (unsigned i = 0; i < distance; i++) { @@ -959,7 +821,7 @@ struct AsyncLoad { // multiple loads if the schedule allows it.
static void createTMABarrierAndWait( scf::ForOp &forOp, SmallVector<AsyncLoad> &asyncLoads, Value insertIdx, - Value extractIdx, Value phase, int numBuffers, CoarseSchedule &schedule, + Value extractIdx, Value phase, int numBuffers, tt::CoarseSchedule &schedule, SmallVector<Value> &barriers, const llvm::MapVector<Operation *, LoadInfo> &loadToInfo) { llvm::SmallDenseMap<Operation *, AsyncLoad *> loadToAsyncLoad; @@ -1030,9 +892,12 @@ static void createTMABarrierAndWait( barriers.push_back(barrierAlloc); Location loc = forOp.getLoc(); OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(builder.getContext()); tt::MemDescType barrierTy = tt::MemDescType::get( {1}, builder.getI64Type(), cast<tt::MemDescType>(barrierAlloc.getType()).getEncoding(), + sharedMemorySpace, /*mutableMemory=*/true); builder.setInsertionPoint(group[0]->loadOp); Value barrier = builder.create<ttg::MemDescSubviewOp>( @@ -1059,7 +924,7 @@ static void createTMABarrierAndWait( // Convert load ops into their async version and apply multi-buffering based on // the required number of buffers. static SmallVector<Value> -createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule, +createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, llvm::MapVector<Operation *, LoadInfo> &loadToInfo, SmallVector<Value> &barriers, int numStages) { // Calculate the number of buffers needed for each load. @@ -1147,7 +1012,7 @@ createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule, // Create a cluster for the prefetches. It may end up being empty, but this // is OK. - CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); + tt::CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); for (AsyncLoad &asyncLoad : asyncLoads) { if (auto loadOp = dyn_cast<tt::LoadOp>(asyncLoad.loadOp)) { @@ -1164,13 +1029,15 @@ createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule, if (phase) newYieldOperands.push_back(phase); // Patch the yield with the updated counters. - appendToYield(forOp, newYieldOperands); + appendToForOpYield(forOp, newYieldOperands); return allocs; } static void invalidateBarriers(OpBuilder &builder, SmallVector<Value> &barriers) { + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(builder.getContext()); for (Value barrier : barriers) { int numBarriers = cast<tt::MemDescType>(barrier.getType()).getShape()[0]; for (int i = 0; i < numBarriers; i++) { @@ -1178,6 +1045,7 @@ static void invalidateBarriers(OpBuilder &builder, tt::MemDescType barrierTy = tt::MemDescType::get( {1}, builder.getI64Type(), cast<tt::MemDescType>(barrier.getType()).getEncoding(), + sharedMemorySpace, /*mutableMemory=*/true); Value barrierView = builder.create<ttg::MemDescSubviewOp>( barrier.getLoc(), barrierTy, barrier, idx); @@ -1191,7 +1059,7 @@ bool mlir::triton::preProcessLoopAndGetSchedule( // Schedule the loads and root ops (dot ops) in the loop. This will give us // a scaffold for the final schedule. DenseSet<Operation *> rootUsers; - CoarseSchedule coarseSchedule(numStages); + tt::CoarseSchedule coarseSchedule(numStages); llvm::MapVector<Operation *, LoadInfo> loadToInfo = scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); if (loadToInfo.empty()) @@ -1212,7 +1080,7 @@ bool mlir::triton::preProcessLoopAndGetSchedule( coarseSchedule.dump(); }); - CoarseSchedule::Cluster afterPrologue = + tt::CoarseSchedule::Cluster afterPrologue = schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages); LLVM_DEBUG({ LDBG("Coarse schedule with prologue and epilogue:"); coarseSchedule.dump(); }); @@ -1397,16 +1265,17 @@ void mlir::triton::updateWaits(ModuleOp module) { // also adds some MemDesc's to the wait. The idea is that if you have // // %alloc = ttg.local_alloc ...
-// %a = ttng.dot_async %alloc -// %a1 = ttng.dot_wait %a +// %a = ttng.warp_group_dot %alloc +// %a1 = ttng.warp_group_dot_wait %a // // then we want the wait to depend on %alloc as well as %a. This extends the // live range of %alloc, so that it won't be destroyed until after the dot is // waited on. // -// Specifically, this function finds all dot_async ops that elements of `values` -// depend on. Then it adds the MemDesc operands of those dots to the wait. -static void threadValuesThroughWait(ttng::DotWaitOp wait, +// Specifically, this function finds all warp_group_dot ops that elements of +// `values` depend on. Then it adds the MemDesc operands of those dots to the +// wait. +static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait, MutableArrayRef values) { IRRewriter builder(wait.getContext()); builder.setInsertionPoint(wait); @@ -1423,12 +1292,12 @@ static void threadValuesThroughWait(ttng::DotWaitOp wait, newOperands.insert(values.begin(), values.end()); // Find memdefs depended on by `values` through async dot ops. - SmallVector asyncDots; + SmallVector asyncDots; for (Value v : values) { BackwardSliceOptions options; options.omitBlockArguments = true; options.filter = [&](Operation *op) { - if (auto dot = dyn_cast(op)) { + if (auto dot = dyn_cast(op)) { asyncDots.push_back(dot); return false; } @@ -1438,7 +1307,7 @@ static void threadValuesThroughWait(ttng::DotWaitOp wait, getBackwardSlice(v, &slice, options); } - for (ttng::DotAsyncOp dot : asyncDots) { + for (ttng::WarpGroupDotOp dot : asyncDots) { for (Value operand : dot.getOperands()) { if (isa(operand.getType())) { newOperands.insert(operand); @@ -1448,7 +1317,7 @@ static void threadValuesThroughWait(ttng::DotWaitOp wait, // We can't use replaceWithNewOp because we're changing the number of return // values in the operation. - auto newWait = builder.create( + auto newWait = builder.create( wait.getLoc(), llvm::to_vector(newOperands), wait.getPendings()); auto dominatedByNewWait = [&](OpOperand &operand) { @@ -1469,13 +1338,14 @@ static void threadValuesThroughWait(ttng::DotWaitOp wait, wait->erase(); } -// Determines whether a given MMAv3 dot op, represented as ttng.dot_async, needs -// a wait immediately after it. +// Determines whether a given MMAv3 dot op, represented as ttng.warp_group_dot, +// needs a wait immediately after it. // // In PTX, MMAv3 exists only as an asynchronous op. In Triton, we can represent -// MMAv3 ops as either tt.dot (synchronous) or ttng.dot_async. But even if we -// use ttng.dot_async, the conservative thing is to make a dot "effectively -// synchronous" by inserting a `ttng.dot_wait {pendings=0}` right after it. +// MMAv3 ops as either ttng.warp_group_dot {isAsync=True} or ttng.warp_group_dot +// {isAsync=False}. But even if we use ttng.warp_group_dot {isAsync=True}, the +// conservative thing is to make a dot "effectively synchronous" by inserting a +// `ttng.warp_group_dot_wait {pendings=0}` right after it. // // We can omit the wait and create a "properly async" dot if all of the // following are true. @@ -1487,28 +1357,29 @@ static void threadValuesThroughWait(ttng::DotWaitOp wait, // and will be synced with a `wait 0` at the beginning of the `if` block. // // 3. During iteration i, between the start of the loop up until the first -// `ttng.dot_wait {pendings=0}` op, the result of the dot from iteration i-1 -// is consumed only by other MMAv3 dots as the `c` operand. 
+// `ttng.warp_group_dot_wait {pendings=0}` op, the result of the dot from +// iteration i-1 is consumed only by other MMAv3 dots as the `c` operand. // // This is safe because the following pseudo-PTX is valid: // -// %accum = dot_async %a1, %b1, %c1 -// %accum = dot_async %a2, %b2, %accum +// %accum = warp_group_dot %a1, %b1, %c1 +// %accum = warp_group_dot %a2, %b2, %accum // // That is, the second async dot can use the result of the first one without // an intervening wait. However, the only operation that can legally read -// %accum before the wait is another dot_async, and this only works for the -// `c` operand, not `a` or `b`. See +// %accum before the wait is another warp_group_dot, and this only works for +// the `c` operand, not `a` or `b`. See // https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence -// (ttng::DotAsyncOp corresponds to wgmma.fence followed by one or more -// wgmma.async ops, so our understanding is that the two ttng::DotAsyncOps -// don't have to correspond to wgmma.async ops with the same shapes as -// specified in the docs, because there's an intervening fence.) +// (ttng::WarpGroupDotOp corresponds to wgmma.fence followed by one or more +// wgmma.async ops, so our understanding is that the two +// ttng::WarpGroupDotOps don't have to correspond to wgmma.async ops with +// the same shapes as specified in the docs, because there's an intervening +// fence.) // // If the op can be properly async, this function returns the index of the dot // in the loop's iter_args. (Rule (2) above ensures this is well-defined.) // -static std::optional dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp, +static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, scf::ForOp forOp) { LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp); @@ -1582,16 +1453,17 @@ static std::optional dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp, // Rule 3a: Are the only users of the dot's result from iteration i-1 other // MMAv3 dots? If so, we're done, this dot can be properly async. if (llvm::all_of(iterArg.getUses(), [&](OpOperand &use) { - return isa(use.getOwner()) && + return isa(use.getOwner()) && use.getOperandNumber() == 2; })) { return iterArgIdx; } // Rule 3b: Are all users of the dot's result from iteration i-1 after the - // first `dot_wait {pendings=0}` op? If so, the dot can be properly async, - // but we have to thread its result from iteration i-1 through the wait. - auto waitOps = forOp.getBody()->getOps(); + // first `warp_group_dot_wait {pendings=0}` op? If so, the dot can be + // properly async, but we have to thread its result from iteration i-1 through + // the wait. + auto waitOps = forOp.getBody()->getOps(); auto firstWaitOpIter = llvm::find_if( waitOps, [&](auto waitOp) { return waitOp.getPendings() == 0; }); if (firstWaitOpIter != waitOps.end() && @@ -1602,7 +1474,8 @@ static std::optional dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp, } return (*firstWaitOpIter)->isBeforeInBlock(user); })) { - LDBG("MMAv3 dot can be properly async because it follows a dot_wait " + LDBG("MMAv3 dot can be properly async because it follows a " + "warp_group_dot_wait " "{pendings=0}.\n" << " wait: " << *firstWaitOpIter << "\n" << " dot: " << dotOp); @@ -1617,16 +1490,16 @@ static std::optional dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp, // If necessary, insert a dot-wait inside the loop, waiting for the results of // the properly-async dots from iteration i-1 to complete. 
-// depth 2, so there are at most 2 copies of each dot_async in flight at a
+// depth 2, so there are at most 2 copies of each warp_group_dot in flight at a
 // time.)
 //
-// We can skip inserting the wait if we have a `dot_wait {pendings=0}` somewhere
-// in the loop. To see why, consider:
+// We can skip inserting the wait if we have a `warp_group_dot_wait
+// {pendings=0}` somewhere in the loop. To see why, consider:
 //
-//   dot_async
-//   dot_async; wait 0  // synchronous dot
-//   dot_async
-//   dot_async
+//   warp_group_dot
+//   warp_group_dot; wait 0  // synchronous dot
+//   warp_group_dot
+//   warp_group_dot
 //
 // In this example, there are three properly-async dots, so we'd normally put
 // `wait 3` at the end of the loop, meaning "wait until there are 3 or fewer
@@ -1634,13 +1507,13 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp,
 // completes, there are only *two* pending async dots from this iteration, so
 // this wait would do nothing. This is true in general, no matter where the
 // `wait 0` appears.
-static void insertAsyncDotWaitInLoop(
+static void insertAsyncWarpGroupDotWaitInLoop(
     scf::ForOp forOp,
-    const llvm::MapVector<ttng::DotAsyncOp, int> &properlyAsyncDots) {
+    const llvm::MapVector<ttng::WarpGroupDotOp, int> &properlyAsyncDots) {
   if (properlyAsyncDots.empty())
     return;

-  if (llvm::any_of(forOp.getBody()->getOps<ttng::DotWaitOp>(),
+  if (llvm::any_of(forOp.getBody()->getOps<ttng::WarpGroupDotWaitOp>(),
                    [](auto wait) { return wait.getPendings() == 0; })) {
     return;
   }
@@ -1664,8 +1537,8 @@ static void insertAsyncDotWaitInLoop(
   for (auto [block, users] : blockToUsers) {
     OpBuilder builder(block, block->begin());
-    auto newWait = builder.create<ttng::DotWaitOp>(asyncDot->getLoc(),
-                                                   ArrayRef<Value>{}, 0);
+    auto newWait = builder.create<ttng::WarpGroupDotWaitOp>(
+        asyncDot->getLoc(), ArrayRef<Value>{}, 0);
     threadValuesThroughWait(newWait, users);
   }
@@ -1682,9 +1555,9 @@ static void insertAsyncDotWaitInLoop(
   IRRewriter builder(forOp.getContext());
   auto lastAsyncDot = properlyAsyncDots.back().first;
   builder.setInsertionPointAfter(lastAsyncDot);
-  auto wait = builder.create<ttng::DotWaitOp>(lastAsyncDot->getLoc(),
-                                              /*inputs=*/ArrayRef<Value>{},
-                                              properlyAsyncDots.size());
+  auto wait = builder.create<ttng::WarpGroupDotWaitOp>(
+      lastAsyncDot->getLoc(),
+      /*inputs=*/ArrayRef<Value>{}, properlyAsyncDots.size());

   // Thread the results of the async dots through the wait.
   SmallVector<Value> addlWaitOperands;
@@ -1694,49 +1567,40 @@ static void insertAsyncDotWaitInLoop(
   threadValuesThroughWait(wait, addlWaitOperands);
 }

-// Convert MMAv3 tt::DotOps (i.e. Hopper wgmma) into ttng::DotAsyncOps and
-// insert ttng::DotWaitOps as necessary.
+// Convert MMAv3 ttng::WarpGroupDotOps {isAsync = false} (i.e. Hopper wgmma)
+// into ttng::WarpGroupDotOps {isAsync = true} and insert
+// ttng::WarpGroupDotWaitOps as necessary.
 //
 // We assume we have space for each dot to be pipelined to depth 2, i.e. each
-// dot op in the loop can have at most 2 dot_async ops in flight at once. (Each
-// dot_async op usually corresponds to a series of wgmma.async ops.)
+// dot op in the loop can have at most 2 warp_group_dot ops in flight at once.
+// (Each warp_group_dot op usually corresponds to a series of wgmma.async ops.)
 void triton::asyncLaunchDots(scf::ForOp forOp) {
   LDBG("Original loop:\n" << *forOp);

-  // First, change every MMAv3 tt.dot into ttng.dot_async. The rest of this
-  // function is concerned with inserting ttng.dot_wait ops in the appropriate
-  // places.
-  //
-  // It's not strictly necessary to convert every dot into dot_async:
-  // Synchronous MMAv3 dots can be represented equally well as `tt.dot` or
-  // `ttng.dot_async; wait 0`. But this makes things easier elsewhere.
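The `wait 3` no-op argument above is easy to sanity-check outside the compiler. A minimal standalone C++ sketch (not part of the patch; the loop body is modeled as a list of dots, with `true` marking the dot that is followed by a `wait 0`):

#include <cassert>
#include <vector>

// Model one loop iteration: each entry is an async dot; `true` means a
// `wait 0` immediately follows that dot. Returns how many async dots are
// still pending at the bottom of the iteration.
static int pendingAtLoopEnd(const std::vector<bool> &dots) {
  int pending = 0;
  for (bool followedByWait0 : dots) {
    ++pending;     // each dot is issued asynchronously
    if (followedByWait0)
      pending = 0; // `wait 0` drains everything issued so far
  }
  return pending;
}

int main() {
  // The four-dot example from the comment: dot, dot + wait 0, dot, dot.
  // Three dots are properly async, but only two are still pending at the
  // end, so a trailing `wait 3` ("3 or fewer pending") would be a no-op.
  assert(pendingAtLoopEnd({false, true, false, false}) == 2);
  return 0;
}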
+  // First, change every MMAv3 ttng.warp_group_dot {isAsync=false}
+  // into ttng.warp_group_dot {isAsync=true}.
+  // The rest of this function is concerned with inserting
+  // ttng.warp_group_dot_wait ops in the appropriate places.
   //
   // We call those dots that don't need to be followed immediately by a `wait 0`
   // "properly async", or sometimes just "async".
-  IRRewriter builder(forOp.getContext());
-  for (auto dotOp : llvm::to_vector(forOp.getBody()->getOps<tt::DotOp>())) {
-    if (isMMAv3Dot(dotOp)) {
-      builder.setInsertionPoint(dotOp);
-      builder.replaceOpWithNewOp<ttng::DotAsyncOp>(
-          dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(),
-          dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc());
-    }
-  }
-
+  //
   // For each dot, determine whether it can be properly async, or if it needs a
   // sync immediately after. If it can be properly async, we know its only use
   // is in the loop's `yield` statement; asyncDots maps the op to its index in
   // the yield op.
+  IRRewriter builder(forOp.getContext());
-  llvm::MapVector<ttng::DotAsyncOp, int> properlyAsyncDots;
-  for (auto dotOp : forOp.getBody()->getOps<ttng::DotAsyncOp>()) {
+  llvm::MapVector<ttng::WarpGroupDotOp, int> properlyAsyncDots;
+  for (auto dotOp : forOp.getBody()->getOps<ttng::WarpGroupDotOp>()) {
+    dotOp.setIsAsync(true);
     if (auto iterArgIdx = dotCanBeProperlyAsync(dotOp, forOp)) {
       properlyAsyncDots[dotOp] = *iterArgIdx;
     } else {
       builder.setInsertionPointAfter(dotOp);
-      auto wait =
-          builder.create<ttng::DotWaitOp>(dotOp.getLoc(), ArrayRef<Value>{},
-                                          /*pendings=*/0);
+      auto wait = builder.create<ttng::WarpGroupDotWaitOp>(
+          dotOp.getLoc(), ArrayRef<Value>{},
+          /*pendings=*/0);
       SmallVector<Value> waitOperands = {dotOp.getResult()};
       threadValuesThroughWait(wait, waitOperands);
     }
   }
@@ -1750,7 +1614,7 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
   // iteration's set of asynchronous dots (and their corresponding async copies
   // from global to shmem) can't start until the first iteration's set has
   // completed.
-  insertAsyncDotWaitInLoop(forOp, properlyAsyncDots);
+  insertAsyncWarpGroupDotWaitInLoop(forOp, properlyAsyncDots);

   // Finally, insert a wait after the loop, waiting for dots from the final
   // iteration of the loop.
@@ -1760,7 +1624,7 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
   }

   // Wait until there are 0 outstanding async dot ops.
   builder.setInsertionPointAfter(forOp);
-  auto dotWaitAfterLoop =
-      builder.create<ttng::DotWaitOp>(forOp.getLoc(), ArrayRef<Value>{}, 0);
-  threadValuesThroughWait(dotWaitAfterLoop, waitOperands);
+  auto warpGroupDotWaitAfterLoop = builder.create<ttng::WarpGroupDotWaitOp>(
+      forOp.getLoc(), ArrayRef<Value>{}, 0);
+  threadValuesThroughWait(warpGroupDotWaitAfterLoop, waitOperands);
 }
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp
index 8b3f55bb8b78..d8a34f6946ba 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp
@@ -1,7 +1,7 @@
-#include "PipelineExpander.h"
-#include "PipeliningUtility.h"
-#include "Schedule.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"
+#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
+#include "triton/Dialect/TritonGPU/Transforms/Schedule.h"

 using namespace mlir;
 namespace tt = mlir::triton;
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
index 6dfd0e344a5f..95c6adc21952 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
@@ -31,7 +31,7 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/Support/Debug.h"

-#include "PipelineExpander.h"
+#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"

 #define DEBUG_TYPE "triton-loop-pipelining"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -106,8 +106,8 @@ struct LoopPipelinerInternal {
                              RewriterBase &rewriter);
   /// Emits the epilogue, this creates `maxStage - 1` part which will contain
   /// operations from stages [i; maxStage], where i is the part index.
-  void emitEpilogue(RewriterBase &rewriter,
-                    llvm::SmallVector<Value> &returnValues);
+  LogicalResult emitEpilogue(RewriterBase &rewriter, ForOp newForOp,
+                             llvm::SmallVector<Value> &returnValues);
 };

 bool LoopPipelinerInternal::initializeLoopInfo(
@@ -145,10 +145,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
     LDBG("--no epilogue or predicate set -> BAIL");
     return false;
   }
-  if (dynamicLoop && peelEpilogue) {
-    LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
-    return false;
-  }
   std::vector<std::pair<Operation *, unsigned>> schedule;
   options.getScheduleFn(forOp, schedule);
   if (schedule.empty()) {
@@ -611,15 +607,21 @@ LogicalResult LoopPipelinerInternal::createKernel(
     // If there is a live range spanning across more than 2 stages we need to
     // add extra arg.
     for (unsigned i = 1; i < numVersionReturned; i++) {
+      // TODO: should this use newForOp.getRegionIterArgs() instead of
+      // indexing the raw block arguments?
+      auto yieldOpr =
+          newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
+                                             newForOp.getNumInductionVars()];
       setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
-                      version++);
-      yieldOperands.push_back(
-          newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
-                                             newForOp.getNumInductionVars()]);
+                      version);
+      setValueMapping(yieldOpr, it.first, version);
+      ++version;
+      yieldOperands.push_back(yieldOpr);
     }
     setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
-                    version++);
-    yieldOperands.push_back(mapping.lookupOrDefault(it.first));
+                    version);
+    // Map the new yield operand back to the old value for reverse lookup.
+    auto yieldOpr = mapping.lookupOrDefault(it.first);
+    setValueMapping(yieldOpr, it.first, version + 1);
+    yieldOperands.push_back(yieldOpr);
   }
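The reverse entries registered above (yield operand of the new loop mapped back to the original value) are what the epilogue later uses to recognize which original value, at which version, a yield operand carries. A minimal standalone model of this two-way bookkeeping, with strings standing in for mlir::Value (all names here are illustrative, not from the patch):

#include <cassert>
#include <map>
#include <string>
#include <vector>

// Mini-model of valueMapping: each key owns one value slot per version.
static std::map<std::string, std::vector<std::string>> valueMapping;

static void setValueMapping(const std::string &key, const std::string &el,
                            std::size_t idx) {
  auto &arr = valueMapping[key];
  if (arr.size() <= idx)
    arr.resize(idx + 1);
  arr[idx] = el;
}

int main() {
  // Forward: original loop value -> result of the rewritten loop.
  setValueMapping("%old", "%new_result", 1);
  // Reverse: new yield operand -> original value, for epilogue lookups.
  setValueMapping("%new_yield", "%old", 1);
  assert(valueMapping.at("%new_yield")[1] == "%old");
  return 0;
}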
   // Map the yield operand to the forOp returned value.
   for (const auto &retVal :
@@ -642,12 +644,19 @@ LogicalResult LoopPipelinerInternal::createKernel(
   return success();
 }

-void LoopPipelinerInternal::emitEpilogue(
-    RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
+LogicalResult LoopPipelinerInternal::emitEpilogue(
+    RewriterBase &rewriter, ForOp newForOp,
+    llvm::SmallVector<Value> &returnValues) {
+  Location loc = forOp.getLoc();
+  auto forArgs = forOp.getRegionIterArgs();
+  auto newForArgs = newForOp.getRegionIterArgs();
+  llvm::SmallVector<Type> returnTypes;
+  llvm::for_each(returnValues,
+                 [&](Value v) { returnTypes.push_back(v.getType()); });
+  llvm::SmallVector<Value> oldReturnValues(returnValues.size(), Value());
+
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
   for (int64_t i = 0; i < maxStage; i++) {
-    Location loc = forOp.getLoc();
     Type t = lb.getType();
     Value minusOne =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
@@ -669,7 +678,23 @@ LogicalResult LoopPipelinerInternal::emitEpilogue(
   }
   // Emit `maxStage - 1` epilogue part that includes operations from stages
   // [i; maxStage].
+  llvm::SmallVector<int64_t> yieldVersions(returnValues.size(), 0);
   for (int64_t i = 1; i <= maxStage; i++) {
+    OpBuilder::InsertionGuard g(rewriter);
+    scf::IfOp guardIfOp;
+    if (dynamicLoop) {
+      // Guard this epilogue part with: if (ub > lb(maxStage - i)).
+      Value lbStage = valueMapping[forOp.getInductionVar()][maxStage - i + 1];
+      Value pred = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sgt, ub, lbStage);
+      guardIfOp = rewriter.create<scf::IfOp>(loc, returnTypes, pred,
                                             /*withElseRegion=*/true);
+      // else: forward the incoming values unchanged.
+      rewriter.setInsertionPointToStart(guardIfOp.elseBlock());
+      rewriter.create<scf::YieldOp>(loc, returnValues);
+      // then: the body of this epilogue part.
+      rewriter.setInsertionPointToStart(guardIfOp.thenBlock());
+      // Each guarded part bumps every yield version by one.
+      llvm::for_each(yieldVersions, [&](int64_t &ver) { ver++; });
+    }
     for (Operation *op : opOrder) {
       if (stages[op] < i)
         continue;
@@ -684,29 +709,59 @@ LogicalResult LoopPipelinerInternal::emitEpilogue(
       if (annotateFn)
         annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue,
                    i - 1);
-      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
-        setValueMapping(op->getResult(destId), newOp->getResult(destId),
-                        maxStage - stages[op] + i);
+      unsigned currentVersion = maxStage - stages[op] + i;
+      unsigned nextVersion = currentVersion + 1;
+      for (auto pair : llvm::enumerate(op->getResults())) {
+        Value oldRes = pair.value();
+        Value newRes = newOp->getResult(pair.index());
+        setValueMapping(oldRes, newRes, currentVersion);
         // If the value is a loop carried dependency update the loop argument
         // mapping and keep track of the last version to replace the original
         // forOp uses.
         for (OpOperand &operand :
              forOp.getBody()->getTerminator()->getOpOperands()) {
-          if (operand.get() != op->getResult(destId))
+          if (operand.get() != oldRes)
             continue;
-          unsigned version = maxStage - stages[op] + i + 1;
           // If the version is greater than maxStage it means it maps to the
           // original forOp returned value.
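To make the `[i; maxStage]` comment above concrete: epilogue part i re-executes exactly the ops whose stage is at least i (the `if (stages[op] < i) continue;` filter), so with maxStage = 2, part 1 covers stages 1 and 2 while part 2 covers stage 2 only. A tiny standalone enumeration of that schedule:

#include <cstdio>

int main() {
  const int maxStage = 2;
  // Part i contains the operations scheduled at stages [i, maxStage];
  // this mirrors the `if (stages[op] < i) continue;` filter in the pass.
  for (int i = 1; i <= maxStage; ++i)
    for (int stage = i; stage <= maxStage; ++stage)
      std::printf("epilogue part %d runs stage-%d ops\n", i, stage);
  return 0;
}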
-          if (version > maxStage) {
-            returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
+          int64_t ri = operand.getOperandNumber();
+          oldReturnValues[ri] = oldRes;
+          returnValues[ri] = newRes;
+          setValueMapping(forArgs[ri], newRes, nextVersion);
+        }
+        for (OpOperand &operand :
+             newForOp.getBody()->getTerminator()->getOpOperands()) {
+          auto it = valueMapping.find(operand.get());
+          if (it == valueMapping.end())
             continue;
-          }
-          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
-                          newOp->getResult(destId), version);
+          if (it->second[currentVersion] != oldRes)
+            continue;
+          // Same bookkeeping for the yield operands of the new loop.
+          int64_t ri = operand.getOperandNumber();
+          oldReturnValues[ri] = oldRes;
+          returnValues[ri] = newRes;
+          setValueMapping(newForArgs[ri], newRes, nextVersion);
         }
       }
     }
+    if (dynamicLoop) {
+      // Yield this part's values from the `then` block of the guard.
+      rewriter.create<scf::YieldOp>(loc, returnValues);
+      for (int64_t ri = 0; ri < (int64_t)returnValues.size(); ++ri) {
+        auto ifVal = guardIfOp.getResult(ri);
+        returnValues[ri] = ifVal;
+        // Later parts must see the guarded value, not the raw clone result.
+        if (oldReturnValues[ri])
+          setValueMapping(oldReturnValues[ri], ifVal, i + 1);
+        // Reset the iter-arg mappings to the guarded value as well.
+        if (ri < (int64_t)forArgs.size())
+          setValueMapping(forArgs[ri], ifVal, i + 1);
+        else
+          setValueMapping(newForArgs[ri], ifVal, i + 1);
+      }
+    }
   }
+  return success();
 }

 void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -718,7 +773,11 @@ void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
   valueMapping
       .insert(std::make_pair(key, llvm::SmallVector<Value>(maxStage + 1)))
       .first;
-  it->second[idx] = el;
+  auto &arr = it->second;
+  if (arr.size() == idx)
+    arr.push_back(el);
+  else
+    arr[idx] = el;
 }
 } // namespace

@@ -759,17 +818,19 @@ mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
                             rewriter)))
     return failure();

-  llvm::SmallVector<Value> returnValues =
-      newForOp.getResults().take_front(forOp->getNumResults());
+  llvm::SmallVector<Value> returnValues = newForOp.getResults();
   if (options.peelEpilogue) {
     // 4. Emit the epilogue after the new forOp.
     rewriter.setInsertionPointAfter(newForOp);
-    pipeliner.emitEpilogue(rewriter, returnValues);
+
+    if (failed(pipeliner.emitEpilogue(rewriter, newForOp, returnValues)))
+      return failure();
   }
   // 5. Erase the original loop and replace the uses with the epilogue output.
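Because returnValues now starts from all of newForOp's results, the replacement in the next hunk first drops the extra per-version results before wiring the survivors into the old loop's uses. A small standalone model of that trimming (the pop_back_n below, expressed with a plain vector):

#include <cassert>
#include <vector>

int main() {
  // The rewritten loop returns extra values, one per extra live-range
  // version; only the first forOp.getNumResults() replace the old loop.
  std::vector<int> returnValues = {10, 11, 12, 13}; // newForOp results
  const std::size_t numOldResults = 2;              // forOp results
  returnValues.resize(numOldResults); // equivalent of pop_back_n below
  assert(returnValues.size() == numOldResults);
  return 0;
}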
- if (forOp->getNumResults() > 0) + if (forOp->getNumResults() > 0) { + returnValues.pop_back_n(newForOp.getResults().size() - forOp.getResults().size()); rewriter.replaceOp(forOp, returnValues); - else + } else rewriter.eraseOp(forOp); return newForOp; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index c773d808c8b1..6a31aa8abddf 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -1,4 +1,4 @@ -#include "PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/TypeUtilities.h" @@ -34,11 +34,9 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, OpBuilder::InsertionGuard guard(rewriter); if (mlir::isMemoryEffectFree(op)) return op; - if (isa(op)) + if (isa(op)) return op; - if (isa(op)) - return op; - if (isa(op)) + if (isa(op)) return op; if (auto ifOp = dyn_cast(op)) { rewriter.setInsertionPoint(op); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp new file mode 100644 index 000000000000..1116b70a0262 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp @@ -0,0 +1,92 @@ +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +void tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage, + tt::CoarseSchedule::Cluster cluster, + bool includeArg) { + for (Value operand : op->getOperands()) { + Value v = operand; + llvm::SmallDenseSet seen; + while (auto arg = dyn_cast(v)) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + if (insertIfAbsent(defOp, stage, cluster)) { + insertDepsOfOp(defOp, stage, cluster, includeArg); + } + } + } +} + +SmallVector> +tt::CoarseSchedule::getOpsInOrder(scf::ForOp forOp) { + SmallVector>, 8> + orderClusters(clusters.size()); + for (auto &op : forOp.getBody()->without_terminator()) { + if (opToStageAndCluster.count(&op) == 0) { + continue; + } + assert(opToStageAndCluster[&op].first < numStages && + "Op with invalid stage!"); + int clusterId = *opToStageAndCluster[&op].second; + assert(clusterId == std::distance(clusters.begin(), + opToStageAndCluster[&op].second) && + "Cluster ID mismatch!"); + orderClusters[clusterId].push_back(make_tuple( + &op, opToStageAndCluster[&op].first, opToStageAndCluster[&op].second)); + } + SmallVector> opsInOrder; + for (int i = 0; i < orderClusters.size(); i++) { + for (auto [op, stage, cluster] : orderClusters[i]) { + opsInOrder.push_back({op, 
stage, cluster}); + } + } + + return opsInOrder; +} + +std::vector> +tt::CoarseSchedule::createFinalSchedule(scf::ForOp forOp) { + SmallVector> + opsInOrder = getOpsInOrder(forOp); + std::vector> schedule; + for (auto [op, stage, cluster] : opsInOrder) + schedule.push_back({op, stage}); + return schedule; +} + +void tt::CoarseSchedule::dump() { + for (int i = 0; i < numStages; i++) { + llvm::dbgs() << "\n---- Ops in stage " << i << "\n"; + for (auto &[op, stageAndCluster] : opToStageAndCluster) { + if (i == stageAndCluster.first) { + llvm::dbgs() << " cluster: " << *stageAndCluster.second + << ":\n\t" << *op << "\n"; + } + } + } +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h deleted file mode 100644 index c61e81818201..000000000000 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h +++ /dev/null @@ -1,39 +0,0 @@ -#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ -#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ - -#include "PipelineExpander.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/ArrayRef.h" -#include - -namespace mlir { -namespace triton { - -/// This fill out the pipelining options including schedule and annotations -/// for wait ops. This also does pre-processing by converting some of the -/// loads into async loads so that the IR is ready to be pipelined. -bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages, - mlir::triton::PipeliningOption &options); - -/// Fills out pipelining options for an outer loop pipelining case. This -/// schedules async copies to overlap with the epilogue of a loop. -bool getOuterLoopSchedule(scf::ForOp &forOp, int numStages, - mlir::triton::PipeliningOption &options); - -/// Pipeline the TMA stores in the loop. -bool pipelineTMAStores(scf::ForOp forOp); - -/// This does post-processing on the pipelined loop to try to pipeline wgmma -/// ops. -// TODO: this should be included as part of the pipeline but currently the wgmma -// wait modeling is problematic. -void asyncLaunchDots(scf::ForOp forOp); - -/// Post process the pipelined loop by updating the wait ops with the right -/// number of groups in flight. 
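A sketch of how a scheduling pass is expected to drive the new tt::CoarseSchedule helper. Only insertIfAbsent, insertDepsOfOp, getOpsInOrder, and createFinalSchedule appear in this patch; the constructor and the clusters.newAtBack() factory are assumptions inferred from the members used above:

// Assumed usage sketch; not runnable in isolation (needs an MLIR context).
tt::CoarseSchedule coarse(/*numStages=*/2);                        // assumed ctor
tt::CoarseSchedule::Cluster cluster = coarse.clusters.newAtBack(); // assumed

// Put loads in stage 0 and pull their in-block dependencies along.
for (Operation &op : forOp.getBody()->without_terminator())
  if (isa<tt::LoadOp>(op) && coarse.insertIfAbsent(&op, /*stage=*/0, cluster))
    coarse.insertDepsOfOp(&op, /*stage=*/0, cluster, /*includeArg=*/false);

// Flatten to the (op, stage) list consumed by the PipelineExpander.
std::vector<std::pair<Operation *, unsigned>> schedule =
    coarse.createFinalSchedule(forOp);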
-void updateWaits(ModuleOp module); - -} // namespace triton -} // namespace mlir -#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp index e5ed6ed370d0..8766e82b9f15 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -1,6 +1,3 @@ -#include "PipelineExpander.h" -#include "PipeliningUtility.h" -#include "Schedule.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/TypeUtilities.h" @@ -11,6 +8,9 @@ #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Tools/Sys/GetEnv.hpp" diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp index 6318b178d39f..cf010992bb07 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -1,5 +1,5 @@ -#include "Schedule.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" using namespace mlir; @@ -39,9 +39,11 @@ static Value createAlloc(scf::ForOp &forOp, encoding = ttg::SharedEncodingAttr::get( ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType()); } - - Type memdescType = tt::MemDescType::get(ty.getShape(), ty.getElementType(), - encoding, /*mutableMemory*/ true); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(ty.getContext()); + Type memdescType = + tt::MemDescType::get(ty.getShape(), ty.getElementType(), encoding, + sharedMemorySpace, /*mutableMemory*/ true); Value alloc = builder.create(storeOp->getLoc(), memdescType, Value()); return alloc; diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 85a95aaa7d5e..23140835f645 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -136,8 +136,9 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, builder.create(v.getLoc(), off, 32)); Value newSmem = builder.create( v.getLoc(), - triton::MemDescType::get(shape, elementType, type.getEncoding()), v, - offsetsVal); + triton::MemDescType::get(shape, elementType, type.getEncoding(), + type.getMemorySpace()), + v, offsetsVal); auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp index c0b586d6016a..415f8dca77ed 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -58,6 +58,10 @@ class TritonGPUReduceDataDuplicationPass dstDotOp.getParent() == srcMfmaEncoding) return; } + if (isMoeLDSBypass() && 
isBlockedToDotShortcut(srcType, dstType)) { + return; + } + auto srcOrder = triton::gpu::getOrder(srcEncoding); auto rank = srcOrder.size(); SmallVector sharedOrder; @@ -70,12 +74,14 @@ class TritonGPUReduceDataDuplicationPass } else { sharedOrder = srcOrder; } + auto sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); auto tmpType = triton::MemDescType::get( dstType.getShape(), dstType.getElementType(), triton::gpu::SharedEncodingAttr::get( mod.getContext(), dstDotOp, srcType.getShape(), sharedOrder, - triton::gpu::getCTALayout(srcEncoding), - srcType.getElementType())); + triton::gpu::getCTALayout(srcEncoding), srcType.getElementType()), + sharedMemorySpace); auto tmp = builder.create( cvtOp.getLoc(), tmpType, cvtOp.getSrc()); auto newConvert = builder.create(cvtOp.getLoc(), diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 967d34c8f11f..585d6670f162 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -39,45 +39,6 @@ namespace { // // ----------------------------------------------------------------------------- -// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0)) -class ConvertDotConvert : public RewritePattern { -public: - ConvertDotConvert(MLIRContext *context) - : RewritePattern(ConvertLayoutOp::getOperationName(), 1, context) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - auto dstOp = cast(op); - auto dotOp = dstOp.getSrc().getDefiningOp(); - if (!dotOp) - return failure(); - if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 || - std::distance(dotOp->user_begin(), dotOp->user_end()) != 1) - return failure(); - auto cvtOp = dotOp.getOperand(2).getDefiningOp(); - if (!cvtOp) - return failure(); - if (!cvtOp.getSrc().getDefiningOp()) - return failure(); - RankedTensorType dstTy = dstOp.getType(); - RankedTensorType srcTy = cvtOp.getSrc().getType(); - if (dstTy != srcTy) - return failure(); - - auto _0f = rewriter.create( - op->getLoc(), dstTy.getElementType(), - rewriter.getZeroAttr(dstTy.getElementType())); - auto _0 = rewriter.create(op->getLoc(), dotOp.getType(), _0f); - auto newDot = rewriter.create( - op->getLoc(), dotOp.getType(), dotOp.getOperand(0), dotOp.getOperand(1), - _0, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); - auto newCvt = rewriter.create(op->getLoc(), dstTy, - newDot.getResult()); - rewriter.replaceOpWithNewOp(op, newCvt, cvtOp.getSrc()); - return success(); - } -}; - // The current algorithm works by analyzing the IR and doing a one-shot rewrite // based on the analysis. The algorithm is as follows. // @@ -285,7 +246,7 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { bool isLayoutAnchor(Operation *op) { if (isa(op)) return isExpensiveLoadOrStore(op); - if (isa(op)) + if (isa(op)) return true; // Heuristic: Mark permuting reshape as a layout anchor. 
Its dst can be @@ -402,7 +363,7 @@ SmallVector LayoutPropagation::propagateToUsers(Value value, if (user->hasTrait() || user->hasTrait() || isa(user)) { + ConvertLayoutOp, nvidia_gpu::WarpGroupDotWaitOp>(user)) { setEncoding(user->getResults(), info, changed, user); continue; } @@ -821,7 +782,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) { if (op->hasTrait() || op->hasTrait() || isa(op)) { + ConvertLayoutOp, nvidia_gpu::WarpGroupDotWaitOp>(op)) { Operation *newOp = cloneElementwise(rewriter, op, encoding); for (auto [oldResult, newResult] : llvm::zip(op->getResults(), newOp->getResults())) { @@ -1288,17 +1249,6 @@ class TritonGPURemoveLayoutConversionsPass m.dump(); }); - RewritePatternSet decomposePatterns(context); - decomposePatterns.add(context); - if (applyPatternsAndFoldGreedily(m, std::move(decomposePatterns)) - .failed()) { - signalPassFailure(); - } - LLVM_DEBUG({ - DBGS() << "Module after decomposing dot-converts:\n"; - m.dump(); - }); - // 4. Apply clean up patterns to remove remove dead convert and dead code // generated by the previous transformations. RewritePatternSet cleanUpPatterns2(context); diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 9bf61f01eaef..eaf0a7e2a148 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -25,7 +25,7 @@ using namespace triton; SmallVector mmaVersionToInstrShape(int version, const ArrayRef &shape, - TensorOrMemDesc type, + RankedTensorType type, int numWarps) { if (version == 1) return {16, 16}; @@ -442,8 +442,8 @@ std::optional inferSrcEncoding(Operation *op, Attribute encoding) { if (op->hasTrait() || op->hasTrait() || op->hasTrait() || - isa( - op)) { + isa(op)) { return encoding; } @@ -472,7 +472,7 @@ std::optional inferDstEncoding(Operation *op, Attribute encoding) { op->hasTrait() || op->hasTrait() || isa(op)) + nvidia_gpu::WarpGroupDotWaitOp>(op)) return encoding; if (auto reduceOp = dyn_cast(op)) return inferDstEncoding(reduceOp, encoding); @@ -627,6 +627,16 @@ scf::IfOp replaceIfOpWithNewSignature( return newIf; } +void appendToForOpYield(scf::ForOp forOp, ArrayRef newOperands) { + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands()); + operands.append(newOperands.begin(), newOperands.end()); + + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} + Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping) { Operation *newOp = rewriter.clone(*op, mapping); diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 0b06ee643bb8..36489b2d2bf3 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -32,8 +32,8 @@ namespace mlir { namespace triton { namespace nvidia_gpu { -// -- DotAsyncOp -- -mlir::LogicalResult DotAsyncOp::inferReturnTypes( +// -- WarpGroupDotOp -- +mlir::LogicalResult WarpGroupDotOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { @@ -57,7 +57,7 @@ mlir::LogicalResult DotAsyncOp::inferReturnTypes( return mlir::success(); } -void DotAsyncOp::getEffects( +void WarpGroupDotOp::getEffects( SmallVectorImpl> &effects) { auto a = getA(); @@ -70,8 +70,8 @@ void DotAsyncOp::getEffects( mlir::triton::gpu::SharedMemory::get()); } -// -- 
DotWaitOp -- -LogicalResult DotWaitOp::inferReturnTypes( +// -- WarpGroupDotWaitOp -- +LogicalResult WarpGroupDotWaitOp::inferReturnTypes( ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index c7dd8d595f0d..fb0e7f6fdb18 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -44,7 +44,7 @@ struct FenceInsertionPass return; ModuleOp mod = getOperation(); mod.walk([&](Operation *op) { - if (!isa(op)) + if (!isa(op)) return WalkResult::advance(); OpBuilder builder(op); auto a = op->getOperand(0); @@ -79,7 +79,7 @@ struct FenceInsertionPass static DenseSet> trace; auto op = operand.getDefiningOp(); // avoid redundant insertion - if (op && isa(op)) + if (op && op->hasTrait()) return false; // reach convertlayout if (op && isa(op) && diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 58e2888b717d..6f1bd728db03 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -23,6 +23,8 @@ class TMALoadLowering : public OpRewritePattern { LogicalResult matchAndRewrite(ExperimentalDescriptorLoadOp op, PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); auto loc = op.getLoc(); auto tensorType = op.getResult().getType(); auto order = getOrder(tensorType.getEncoding()); @@ -36,15 +38,16 @@ class TMALoadLowering : public OpRewritePattern { } MemDescType memDescType = MemDescType::get(tensorType.getShape(), tensorType.getElementType(), - encoding, /*mutableMemory=*/true); + encoding, sharedMemorySpace, /*mutableMemory=*/true); Value alloc = rewriter.create(loc, memDescType, Value()); auto barrierCTALayout = CTALayoutAttr::get( /*context=*/tensorType.getContext(), /*CTAsPerCGA=*/{1}, /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); auto barrierEncoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1, 1, {0}, barrierCTALayout); - MemDescType barrierMemDescType = MemDescType::get( - {1}, rewriter.getI64Type(), barrierEncoding, /*mutableMemory=*/true); + MemDescType barrierMemDescType = + MemDescType::get({1}, rewriter.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); Value barrierAlloc = rewriter.create(loc, barrierMemDescType, Value()); rewriter.create(loc, barrierAlloc, 1); @@ -70,6 +73,8 @@ class TMAStoreLowering LogicalResult matchAndRewrite(ExperimentalDescriptorStoreOp op, PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); auto loc = op.getLoc(); auto tensorType = op.getSrc().getType(); auto order = getOrder(tensorType.getEncoding()); @@ -83,7 +88,7 @@ class TMAStoreLowering } MemDescType memDescType = MemDescType::get(tensorType.getShape(), tensorType.getElementType(), - encoding, /*mutableMemory=*/true); + encoding, sharedMemorySpace, /*mutableMemory=*/true); Value alloc = rewriter.create(loc, memDescType, op.getSrc()); rewriter.create(loc, false); rewriter.create( diff --git a/python/src/ir.cc b/python/src/ir.cc index 0befdc491ba8..81b70cf5dd03 100644 
--- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -201,7 +201,11 @@ void init_triton_ir(py::module &&m) { .value("IEEE", InputPrecision::IEEE) .export_values(); - py::class_(m, "context", py::module_local()).def(py::init<>()); + py::class_(m, "context", py::module_local()) + .def(py::init<>()) + .def("enable_moe_lds_bypass", [](MLIRContext &self, bool value) -> void { + enableMoeLDSBypass(value); + }); m.def("load_dialects", [](MLIRContext &context) { DialectRegistry registry; diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index b20f75bc53a6..c3c8cb7f035e 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -26,14 +26,14 @@ def test_descriptor_load_ttgir(): tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ %c0_i32 = arith.constant 0 : i32 %0 = tt.make_range {{end = {SIZE} : i32, start = 0 : i32}} : tensor<{SIZE}xi32, #blocked> - %1 = triton_gpu.local_alloc : () -> !tt.memdesc<{SIZE}xf32, #shared, mutable> - %2 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared, mutable> - triton_nvidia_gpu.init_barrier %2, 1 : <1xi64, #shared, mutable> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> + %2 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.init_barrier %2, 1 : <1xi64, #shared, #triton_gpu.shared_memory, mutable> %true = arith.constant 1 : i1 - triton_nvidia_gpu.barrier_expect %2, {size_in_bytes}, %true : <1xi64, #shared, mutable> - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0_i32] %1, %2, %true : , <1xi64, #shared, mutable> -> <{SIZE}xf32, #shared, mutable> - triton_nvidia_gpu.wait_barrier %2, %c0_i32 : <1xi64, #shared, mutable> - %3 = triton_gpu.local_load %1 : !tt.memdesc<{SIZE}xf32, #shared, mutable> -> tensor<{SIZE}xf32, #blocked> + triton_nvidia_gpu.barrier_expect %2, {size_in_bytes}, %true : <1xi64, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0_i32] %1, %2, %true : , <1xi64, #shared, #triton_gpu.shared_memory, mutable> -> <{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.wait_barrier %2, %c0_i32 : <1xi64, #shared, #triton_gpu.shared_memory, mutable> + %3 = triton_gpu.local_load %1 : !tt.memdesc<{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<{SIZE}xf32, #blocked> %4 = tt.splat %arg0 : !tt.ptr -> tensor<{SIZE}x!tt.ptr, #blocked> %5 = tt.addptr %4, %0 : tensor<{SIZE}x!tt.ptr, #blocked>, tensor<{SIZE}xi32, #blocked> tt.store %5, %3 : tensor<{SIZE}x!tt.ptr, #blocked> diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 5972c93d7f98..d6103abfd207 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3018,9 +3018,9 @@ def convert_fp8_to_fp32(x, device, dtype_str): @pytest.mark.interpreter @pytest.mark.parametrize( "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack", - [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1) - for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] - for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] + [(*shape, 8, False, True, epilogue, input_precision, in_dtype, out_dtype, 2) + for 
shape in [(32, 128, 128)] + for epilogue in ['none'] for input_precision in ['tf32', 'tf32x3', 'ieee'] for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')] if not (input_precision != 'ieee' and (in_dtype in ['float16']))] + @@ -3158,7 +3158,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid 'COL_A': col_a, 'COL_B': col_b, 'BLOCK_M': M, 'BLOCK_K': K, 'BLOCK_N': N, 'ADD_MATRIX': epilogue == 'add-matrix', 'ADD_ROWS': epilogue == 'add-rows', 'ADD_COLS': epilogue == 'add-cols', 'DO_SOFTMAX': epilogue == 'softmax', 'CHAIN_DOT': epilogue == 'chain-dot', 'INPUT_PRECISION': input_precision, 'num_warps': - num_warps, 'num_ctas': num_ctas, 'out_dtype': out_dtype + num_warps, 'num_ctas': num_ctas, 'out_dtype': out_dtype, 'matrix_instr_nonkdim': 16 } if is_hip(): @@ -4852,10 +4852,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> """ if interm_layout is None else f""" - %15 = triton_gpu.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !tt.memdesc<{M}x{N}xi32, #interm> - %16 = triton_gpu.local_load %15 : !tt.memdesc<{M}x{N}xi32, #interm> -> tensor<{M}x{N}xi32, #src> - %17 = triton_gpu.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !tt.memdesc<{M}x{N}xf16, #interm> - %18 = triton_gpu.local_load %17 : !tt.memdesc<{M}x{N}xf16, #interm> -> tensor<{M}x{N}xf16, #src> + %15 = triton_gpu.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !tt.memdesc<{M}x{N}xi32, #interm, #triton_gpu.shared_memory> + %16 = triton_gpu.local_load %15 : !tt.memdesc<{M}x{N}xi32, #interm, #triton_gpu.shared_memory> -> tensor<{M}x{N}xi32, #src> + %17 = triton_gpu.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !tt.memdesc<{M}x{N}xf16, #interm, #triton_gpu.shared_memory> + %18 = triton_gpu.local_load %17 : !tt.memdesc<{M}x{N}xf16, #interm, #triton_gpu.shared_memory> -> tensor<{M}x{N}xf16, #src> %12 = triton_gpu.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> %13 = triton_gpu.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> diff --git a/shuffle_test.py b/shuffle_test.py new file mode 100644 index 000000000000..d044e22bb224 --- /dev/null +++ b/shuffle_test.py @@ -0,0 +1,78 @@ +import triton +import triton.language as tl +import torch + +import numpy as np + + +def to_numpy(x): + return x.cpu().numpy() + + +@triton.jit +def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk + Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + x = tl.load(Xs) + y = tl.load(Ys) + z = tl.dot(x, y) + tl.store(Zs, z) + + +def permute_weight(x: torch.Tensor, numWarps) -> torch.Tensor: + B = x.shape[0] + N = x.shape[1] + K = x.shape[2] + x_ = x.clone() + + mfmaNSize = 16 + kWidth = 8 + numKGroups = 4 + + kRepeats = K // numKGroups // kWidth + nRepeats = N // numWarps // mfmaNSize + # 0 1 2 3 4 5 6 + # B NNNNNNNN NNNNNNNN NNNNNNNNN KKKKKKKK KKKKKKKKKK KKKKKK + x_ = x_.view(B, nRepeats, numWarps, mfmaNSize, kRepeats, numKGroups, kWidth) + + x_ = x_.permute(0, 1, 4, 2, 5, 
3, 6) + + x_ = x_.contiguous() + x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]) + return x_ + + +M = 128 +N = 128 +K = 128 + +x = torch.zeros((M, K), dtype=torch.float16, device="cuda") +for i in range(M): + x[i, i] = 1 +y = torch.zeros((1, N, K), dtype=torch.float16, device="cuda") + +for i in range(K): + for j in range(N): + y[0, j, i] = i + j * K + +ref = torch.matmul(x, y.permute([0, 2, 1]).reshape(K, N)) + +numWarps = 4 +y = permute_weight(y, numWarps) + +np.set_printoptions(threshold=100_000) +# print(to_numpy(y.reshape(N, K).to(torch.int32))) + +z = torch.zeros((M, N), dtype=torch.float32, device="cuda") + +kernel[(1, 1, 1)](x, x.stride(0), x.stride(1), y, y.stride(2), y.stride(1), z, z.stride(0), z.stride(1), M, N, K, + enable_moe_lds_bypass=True, num_warps=numWarps, matrix_instr_nonkdim=16, kpack=2) + +print(to_numpy(z.reshape(N, M).to(torch.int32))) + +np.testing.assert_allclose(to_numpy(ref), to_numpy(z)) diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 2f73e0880f05..8c4e65daabfa 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -41,7 +41,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: alloc tt.func @alloc(%A : !tt.ptr) { // CHECK: %0 -> %0 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, mutable> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return } @@ -49,40 +49,40 @@ tt.func @alloc(%A : !tt.ptr) { tt.func @alloc_init(%A : !tt.ptr) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> // CHECK: %0 -> %0 - %cst1 = triton_gpu.local_alloc %cst0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } // CHECK-LABEL: trans tt.func @trans(%A : !tt.ptr) { // CHECK: %0 -> %0 - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED> + %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: %1 -> %0 - %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED> -> !tt.memdesc<32x16xf16, #A_SHARED_T> + %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory> tt.return } // CHECK-LABEL: subview -tt.func @subview(%A : !tt.memdesc<1x16x16xf16, #A_SHARED>) { +tt.func @subview(%A : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { %index = arith.constant 0 : i32 // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %1 -> %0 - %cst1 = triton_gpu.memdesc_subview %a[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED> -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.memdesc_subview %a[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } // CHECK-LABEL: if_alias tt.func @if_alias(%i1 : i1) { // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: %1 -> %1 - %b = triton_gpu.local_alloc : () -> 
!tt.memdesc<16x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %2 -> %0,%1 - %cst2 = scf.if %i1 -> !tt.memdesc<16x16xf16, #A_SHARED> { - scf.yield %a : !tt.memdesc<16x16xf16, #A_SHARED> + %cst2 = scf.if %i1 -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> { + scf.yield %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } else { - scf.yield %b : !tt.memdesc<16x16xf16, #A_SHARED> + scf.yield %b : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -90,11 +90,11 @@ tt.func @if_alias(%i1 : i1) { // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: %1 -> %1 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: %2 -> %2 - %c = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %c = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %arg6 -> %0 // CHECK-NEXT: %arg7 -> %1 // CHECK-NEXT: %arg8 -> %2 @@ -102,8 +102,8 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-NEXT: %3#1 -> %0,%1 // CHECK-NEXT: %3#2 -> %0,%1,%2 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a, %b_shared = %b, %c_shared = %c) -> - (!tt.memdesc<16x16xf16, #A_SHARED>, !tt.memdesc<16x16xf16, #A_SHARED>, !tt.memdesc<16x16xf16, #A_SHARED>) { - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<16x16xf16, #A_SHARED>, !tt.memdesc<16x16xf16, #A_SHARED>, !tt.memdesc<16x16xf16, #A_SHARED> + (!tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -111,11 +111,11 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_if tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %0 -> %0 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %1 -> %1 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %2 -> %2 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %arg7 -> %0 // CHECK-NEXT: %arg8 -> %1 // CHECK-NEXT: %arg9 -> %2 @@ -123,14 +123,14 @@ tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : // CHECK-NEXT: %3#1 -> %0,%1 // CHECK-NEXT: %3#2 -> %0,%1,%2 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub 
step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> - (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { scf.if %i1 { %index = arith.constant 8 : i32 // CHECK-NEXT: %4 -> %0,%1 - %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<32xf16, #A_SHARED> + %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32xf16, #A_SHARED, #triton_gpu.shared_memory> scf.yield } - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -138,11 +138,11 @@ tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %0 -> %0 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %1 -> %1 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %2 -> %2 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %arg7 -> %0 // CHECK-NEXT: %arg8 -> %1 // CHECK-NEXT: %arg9 -> %2 @@ -150,23 +150,23 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-NEXT: %3#1 -> %1 // CHECK-NEXT: %3#2 -> %2,%6,%6 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> - (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK-NEXT: %arg11 -> %2,%6,%6 // CHECK-NEXT: %4 -> %2,%6,%6 - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK-NEXT: %5 -> %6,%6 - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED> { + %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> { // CHECK-NEXT: %6 -> %6 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = 
triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK-NEXT: %6 -> %6 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -175,29 +175,29 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, %arg4: !tt.ptr) { %idx = arith.constant 0 : i32 // CHECK: %0 -> %0 - %cst = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %1 -> %1 - %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %2 -> %0 - %0 = triton_gpu.memdesc_subview %cst[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<128x32xf16, #A_SHARED> + %0 = triton_gpu.memdesc_subview %cst[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> gpu.barrier // CHECK-NEXT: %3 -> %3 - %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %5 -> %0,%1,%3 // CHECK-NEXT: %6 -> %0,%1,%3 // CHECK-NEXT: %7 -> %0,%1,%3 - cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) -^bb1(%1: index, %2: !tt.memdesc<128x32xf16, #A_SHARED>, %3: !tt.memdesc<128x32xf16, #A_SHARED>, %4: !tt.memdesc<128x32xf16, #A_SHARED>): // 2 preds: ^bb0, ^bb2 + cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) +^bb1(%1: index, %2: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, %3: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, %4: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>): // 2 preds: ^bb0, ^bb2 %5 = arith.cmpi slt, %1, %arg1 : index cf.cond_br %5, ^bb2, ^bb3 ^bb2: // pred: ^bb1 gpu.barrier %8 = arith.addi %1, %arg2 : index - cf.br ^bb1(%8, %4, %2, %3 : index, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) + cf.br ^bb1(%8, %4, %2, %3 : index, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, 
!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) ^bb3: // pred: ^bb1 gpu.barrier // CHECK-NEXT: %10 -> %0 - %9 = triton_gpu.memdesc_subview %0[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<128x32xf16, #A_SHARED> + %9 = triton_gpu.memdesc_subview %0[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index f7926de3a24c..76a6340d7aef 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -80,47 +80,47 @@ tt.func @reusable(%A : !tt.ptr) { // CHECK-LABEL: preallocate tt.func @preallocate(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 4096, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 0, size = 1024 - %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED> + triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 1024 - %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 6144, size = 2048 - %e = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED> - triton_gpu.local_dealloc %a : !tt.memdesc<32x16xf16, #A_SHARED> + %e = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %a : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 8192, size = 2048 - %d = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED> - triton_gpu.local_dealloc %b : !tt.memdesc<32x16xf16, #A_SHARED> + %d = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %b : !tt.memdesc<32x16xf16, 
#A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 10240, size = 2048 - %f = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst4 : !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %c : !tt.memdesc<32x16xf16, #A_SHARED> + %f = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst4 : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %c : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 0, size = 2048 - %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED> + %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 4096 - %g = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED> - triton_gpu.local_dealloc %e : !tt.memdesc<64x16xf16, #A_SHARED> + %g = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %e : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 4096 - %h = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED> - triton_gpu.local_dealloc %d : !tt.memdesc<64x16xf16, #A_SHARED> + %h = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %d : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 4096 - %i = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED> - triton_gpu.local_dealloc %f : !tt.memdesc<64x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst5 : !tt.memdesc<64x16xf16, #A_SHARED> + %i = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %f : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst5 : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 12288 } @@ -130,11 +130,11 @@ tt.func @preallocate(%A : !tt.ptr) { tt.func @unused(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc %cst : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 0, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK: size = 1024 } @@ -143,33 +143,33 @@ tt.func @unused(%A : !tt.ptr) { // CHECK-LABEL: longlive tt.func @longlive(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, 
#triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 512 - %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 512 - %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 512 - %cst6 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst6 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 1024 - %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst4 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst4 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 1024 - %d = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %d = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 4096 } @@ -178,43 +178,43 @@ tt.func @longlive(%A : !tt.ptr) { // CHECK-LABEL: multi_color tt.func @multi_color(%A : !tt.ptr) { // CHECK: offset = 0, size = 64 - %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED> + %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1536, size = 32 - %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED> + %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1664, size = 128 - %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED> + %cst_1 = 
triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory> %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: scratch offset = 128, size = 1152 %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> - %1 = triton_gpu.local_load %cst : !tt.memdesc<4x8xf16, #A_SHARED> -> tensor<4x8xf16, #AL> + %1 = triton_gpu.local_load %cst : !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x8xf16, #AL> // CHECK-NEXT: offset = 0, size = 128 - %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<4x16xf16, #A_SHARED> - %2 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> + %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<4x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x4xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 %3 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> // CHECK-NEXT: offset = 0, size = 256 - %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<4x32xf16, #A_SHARED> + %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<4x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 256, size = 64 - %cst_5 = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED> - %4 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED> -> tensor<4x8xf16, #AL> - %5 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED> -> tensor<4x8xf16, #AL> + %cst_5 = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory> + %4 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x8xf16, #AL> + %5 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x8xf16, #AL> // CHECK-NEXT: offset = 1024, size = 512 - %cst_6 = triton_gpu.local_alloc : () -> !tt.memdesc<8x32xf16, #A_SHARED> + %cst_6 = triton_gpu.local_alloc : () -> !tt.memdesc<8x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1792, size = 128 - %cst_7 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED> - %6 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> + %cst_7 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %6 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 1024, size = 512 - %cst_8 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst_8 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 256, size = 32 - %cst_9 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED> + %cst_9 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst_10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> - %7 = triton_gpu.local_load %cst_1 : !tt.memdesc<16x4xf16, #A_SHARED> -> tensor<16x4xf16, #AL> - %8 = triton_gpu.local_load %cst_4 : !tt.memdesc<4x32xf16, #A_SHARED> -> tensor<4x32xf16, #AL> + %cst_10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %7 = triton_gpu.local_load %cst_1 : !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> 
tensor<16x4xf16, #AL> + %8 = triton_gpu.local_load %cst_4 : !tt.memdesc<4x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x32xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 %9 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> %cst_11 = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #AL> - %10 = triton_gpu.local_load %cst_7 : !tt.memdesc<2x32xf16, #A_SHARED> -> tensor<2x32xf16, #AL> + %10 = triton_gpu.local_load %cst_7 : !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<2x32xf16, #AL> %cst_12 = arith.constant dense<0.000000e+00> : tensor<4x16xf16, #AL> %cst_13 = arith.constant dense<0.000000e+00> : tensor<8x32xf16, #AL> // CHECK-NEXT: size = 1920 @@ -225,25 +225,25 @@ tt.func @multi_color(%A : !tt.ptr) { // CHECK-LABEL: multi_color_multi_rounds tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { // CHECK: offset = 0, size = 32 - %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED> + %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1280, size = 128 - %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED> + %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 8192 - %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<1024x4xf16, #A_SHARED> + %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<1024x4xf16, #A_SHARED, #triton_gpu.shared_memory> %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: scratch offset = 128, size = 1152 %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> - %1 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> + %1 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 1152, size = 128 - %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED> - %2 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> + %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 0, size = 512 - %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> - %3 = triton_gpu.local_load %cst_0 : !tt.memdesc<16x4xf16, #A_SHARED> -> tensor<16x4xf16, #AL> - %4 = triton_gpu.local_load %cst_1 : !tt.memdesc<1024x4xf16, #A_SHARED> -> tensor<1024x4xf16, #AL> + %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %3 = triton_gpu.local_load %cst_0 : !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x4xf16, #AL> + %4 = triton_gpu.local_load %cst_1 : !tt.memdesc<1024x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<1024x4xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 %5 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> - %6 = triton_gpu.local_load %cst_3 : !tt.memdesc<2x32xf16, #A_SHARED> -> tensor<2x32xf16, #AL> + %6 = triton_gpu.local_load %cst_3 : !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<2x32xf16, #AL> // CHECK-NEXT: size = 10240 tt.return } @@ -252,10 +252,10 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { // CHECK-LABEL: alloc tt.func @alloc(%A 
: !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 512 } @@ -264,10 +264,10 @@ tt.func @alloc(%A : !tt.ptr) { // CHECK-LABEL: dealloc tt.func @dealloc(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: offset = 1024, size = 1024 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<32x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 2048 } @@ -288,8 +288,8 @@ tt.func @scratch() { // CHECK-LABEL: trans tt.func @trans(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED> - %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED> -> !tt.memdesc<32x16xf16, #A_SHARED_T> + %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory> tt.return } @@ -297,9 +297,9 @@ tt.func @trans(%A : !tt.ptr) { // CHECK-LABEL: extract_slice tt.func @extract_slice(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> %index = arith.constant 0 : i32 - %cst1 = triton_gpu.memdesc_subview %cst0[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED> -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.memdesc_subview %cst0[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 512 } @@ -309,25 +309,25 @@ tt.func @extract_slice(%A : !tt.ptr) { // CHECK-LABEL: if tt.func @if(%i1 : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst0 : 
!tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 3072 } @@ -337,28 +337,28 @@ tt.func @if(%i1 : i1) { // CHECK-LABEL: if_else tt.func @if_else(%i1 : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 4096, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - 
triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 5120 } @@ -368,13 +368,13 @@ tt.func @if_else(%i1 : i1) { // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return // CHECK-NEXT: size = 24576 @@ -383,18 +383,18 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_if_slice tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %c_shared_init = 
triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { scf.if %i1 { %index = arith.constant 8 : i32 - %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<32xf16, #A_SHARED> + %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32xf16, #A_SHARED, #triton_gpu.shared_memory> scf.yield } - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return // CHECK-NEXT: size = 24576 @@ -404,16 +404,16 @@ tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr // CHECK-LABEL: for_use_ancestor tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - %c0 = tt.trans %c_shared_init {order=array} : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<32x128xf16, #A_SHARED_T> + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c0 = tt.trans %c_shared_init {order=array} : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32x128xf16, #A_SHARED_T, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 24576, size = 8192 - %c1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %b_shared, %a_shared: !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %c1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %b_shared, %a_shared: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return // CHECK-NEXT: size = 32768 @@ -424,28 +424,28 @@ tt.func @for_use_ancestor(%lb : index, 
%ub : index, %step : index, %A : !tt.ptr< // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED>) { - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED> { + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> { // CHECK-NEXT: offset = 24576, size = 8192 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK-NEXT: offset = 32768, size = 8192 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst1 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst1 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK-NEXT: offset = 0, size = 8192 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 40960 } @@ 
-457,7 +457,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: alloc1 tt.func @alloc1(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 512 } @@ -465,7 +465,7 @@ tt.func @alloc1(%A : !tt.ptr) { // CHECK-LABEL: alloc2 tt.func @alloc2(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 1024 } @@ -474,10 +474,10 @@ tt.func @alloc2(%A : !tt.ptr) { tt.func @alloc3(%cond : i1) { scf.if %cond { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK-NEXT: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return // CHECK-NEXT: size = 1024 @@ -499,7 +499,7 @@ tt.func @alloc4(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: single_call tt.func @single_call(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () @@ -510,7 +510,7 @@ tt.func @single_call(%A : !tt.ptr) { // CHECK-LABEL: multiple_calls tt.func @multiple_calls(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> @@ -525,9 +525,9 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> scf.if %cond { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 0, size = 1024 - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () } else { @@ -542,7 +542,7 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: for_calls tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> %lb = arith.constant 0 : index %ub = arith.constant 10 : index @@ -558,7 +558,7 @@ tt.func @for_calls(%A 
: !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_1 tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: virtual offset = 0, size = 1024 tt.call @alloc3(%cond) : (i1) -> () tt.return @@ -568,7 +568,7 @@ tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_2 tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: virtual offset = 0, size = 1024 tt.call @alloc4(%A, %cond) : (!tt.ptr, i1) -> () tt.return diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 747a63959942..6f657cda5555 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -5,7 +5,6 @@ #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> @@ -47,10 +46,10 @@ tt.func @raw_single_block(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -60,14 +59,14 @@ tt.func @war_single_block(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: triton_gpu.local_alloc // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> // CHECK: gpu.barrier // CHECK-NEXT: %4 = triton_gpu.local_alloc - %4 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %4 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } @@ -77,25 +76,25 @@ tt.func 
@war_single_block_local_store(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: triton_gpu.local_alloc // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_store - triton_gpu.local_store %1, %2 : tensor<128x32xf16, #AL> -> !tt.memdesc<128x32xf16, #A_SHARED> + triton_gpu.local_store %1, %2 : tensor<128x32xf16, #AL> -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } // CHECK-LABEL: scratch tt.func @scratch(%arg: tensor<16x16xf16, #AL>) { - %cst0 = triton_gpu.local_alloc %arg : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %arg : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load // CHECK: gpu.barrier // CHECK: tt.reduce - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> %2 = "tt.reduce" (%1) ({ ^bb0(%arg1: f16, %arg2: f16): %add = arith.addf %arg1, %arg2 : f16 @@ -106,34 +105,34 @@ tt.func @scratch(%arg: tensor<16x16xf16, #AL>) { // CHECK-LABEL: async_wait tt.func @async_wait(%arg: tensor<32x16xf16, #AL>) { - %cst0 = triton_gpu.local_alloc %arg : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %arg : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: triton_gpu.async_wait triton_gpu.async_wait {num = 4 : i32} // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<32x16xf16, #A_SHARED> -> tensor<32x16xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<32x16xf16, #AL> tt.return } // CHECK-LABEL: subview tt.func @subview() { %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> - %a = triton_gpu.local_alloc %cst0 : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc %cst0 : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> %index = arith.constant 0 : i32 - %0 = triton_gpu.memdesc_subview %a[%index, %index] : !tt.memdesc<32x16xf16, #A_SHARED> -> !tt.memdesc<16x16xf16, #A_SHARED> + %0 = triton_gpu.memdesc_subview %a[%index, %index] : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %2 = 
triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } // CHECK-LABEL: trans -tt.func @trans(%a: !tt.memdesc<16x32xf16, #A_SHARED>) { +tt.func @trans(%a: !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK-NOT: gpu.barrier - %b = tt.trans %a {order=array} : !tt.memdesc<16x32xf16, #A_SHARED> -> !tt.memdesc<32x16xf16, #A_SHARED_T> + %b = tt.trans %a {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory> tt.return } @@ -143,31 +142,31 @@ tt.func @async_copy_global_to_local(%A : !tt.ptr, %i1 : i1) { %a_ptr = tt.splat %A : !tt.ptr -> tensor<16x16x!tt.ptr, #AL> %mask = tt.splat %i1 : i1 -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, mutable> - %subview = triton_gpu.memdesc_subview %alloc[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, mutable> -> !tt.memdesc<16x16xf16, #A_SHARED, mutable> - %1 = triton_gpu.async_copy_global_to_local %a_ptr, %subview : tensor<16x16x!tt.ptr, #AL> -> !tt.memdesc<16x16xf16, #A_SHARED, mutable> + %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %subview = triton_gpu.memdesc_subview %alloc[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.async_copy_global_to_local %a_ptr, %subview : tensor<16x16x!tt.ptr, #AL> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %4 = triton_gpu.local_load %subview : !tt.memdesc<16x16xf16, #A_SHARED, mutable> -> tensor<16x16xf16, #AL> + %4 = triton_gpu.local_load %subview : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> tt.return } // If branch inserted a barrier for %cst0, but else didn't, then the barrier should be inserted in the parent region // CHECK-LABEL: multi_blocks tt.func @multi_blocks(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } else { - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.yield } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -175,21 +174,21 @@ tt.func @multi_blocks(%i1 : i1) { // 
CHECK-LABEL: multi_blocks_join_barrier tt.func @multi_blocks_join_barrier(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } else { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } // CHECK-NOT: gpu.barrier // CHECK: tt.return - %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -197,25 +196,25 @@ tt.func @multi_blocks_join_barrier(%i1 : i1) { // CHECK-LABEL: multi_blocks_yield tt.func @multi_blocks_yield(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED>) { + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - scf.yield %3 : !tt.memdesc<16x16xf16, #A_SHARED> + %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %3 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } - %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> // CHECK: triton_gpu.local_load // CHECK-NEXT: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %4 = 
triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -223,27 +222,27 @@ tt.func @multi_blocks_yield(%i1 : i1) { // CHECK-LABEL: multi_blocks_entry_no_shared tt.func @multi_blocks_entry_no_shared(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED>) { + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc // CHECK-NEXT: gpu.barrier // CHECK-NEXT: triton_gpu.local_load // CHECK-NEXT: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - %0 = triton_gpu.local_load %cst1 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %0 = triton_gpu.local_load %cst1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK-NOT: gpu.barrier // CHECK: triton_gpu.local_alloc - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - scf.yield %cst1 : !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -251,16 +250,16 @@ tt.func @multi_blocks_entry_no_shared(%i1 : i1) { // CHECK-LABEL: multi_blocks_noelse tt.func @multi_blocks_noelse(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -268,39 +267,39 @@ tt.func @multi_blocks_noelse(%i1 : i1) { // 
// CHECK-LABEL: multi_blocks_nested_scf
tt.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
- %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
+ %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
scf.if %i1 {
scf.if %i2 {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %0 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
+ %0 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
scf.yield
}
scf.yield
} else {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %1 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
+ %1 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
scf.yield
}
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %2 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
+ %2 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
tt.return
}

// CHECK-LABEL: for
tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
- %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) {
+ %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %a0 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %b0 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>
+ %a0 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %b0 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
}
tt.return
}

@@ -310,24 +309,24 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t
// CHECK-LABEL: for_alias
tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
- %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
+ %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) {
+ %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %a1 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %b1 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>
+ %a1 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %b1 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
}
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
+ %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
tt.return
}

@@ -336,63 +335,63 @@ tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %
// CHECK-LABEL: for_reuse
tt.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
- %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
+ %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) {
+ %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_alloc
- %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
+ %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_alloc
- %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %2 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>
+ %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %2 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
}
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
+ %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
tt.return
}

// CHECK-LABEL: for_reuse_nested
tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
- %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
+ %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) {
+ %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_alloc
- %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) {
+ %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_alloc
- %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- %2 = triton_gpu.local_alloc %a2 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>
+ %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ %2 = triton_gpu.local_alloc %a2 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
}
- scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>
+ scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
}
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
+ %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
tt.return
}

@@ -400,25 +399,25 @@ tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<
// CHECK-LABEL: for_for_if
tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
- %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) {
- %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED>) {
- %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED> {
+ %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) {
+ %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) {
+ %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_alloc
- %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED>
+ %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
} else {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_alloc
- %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED>
+ %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
}
- scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED>
+ scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
}
- scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>
+ scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
}
tt.return
}

@@ -427,30 +426,30 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr,
// CHECK-LABEL: for_if_for
tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
- %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
+ %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
// CHECK: gpu.barrier
- %c_blocked = triton_gpu.local_load %c_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
+ %c_blocked = triton_gpu.local_load %c_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
- %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) {
- %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED> {
+ %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) {
+ %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_alloc
- %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED>
- scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED>
+ %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
+ scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
} else {
- %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED>) {
+ %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %c_blocked_next = triton_gpu.local_load %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- scf.yield %c_shared : !tt.memdesc<128x32xf16, #A_SHARED>
+ %c_blocked_next = triton_gpu.local_load %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ scf.yield %c_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
}
- scf.yield %c_shared_ : !tt.memdesc<128x32xf16, #A_SHARED>
+ scf.yield %c_shared_ : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
}
// CHECK-NOT: gpu.barrier
- %b_blocked_next = triton_gpu.local_load %b_shared: !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL>
- scf.yield %a_shared, %b_shared, %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>
+ %b_blocked_next = triton_gpu.local_load %b_shared: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL>
+ scf.yield %a_shared, %b_shared, %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>
}
tt.return
}
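// Reader note: the cf_* tests that follow mirror the scf-based cases above using
// unstructured control flow (cf.cond_br / cf.br), so barrier placement has to be
// derived from block predecessors rather than from scf region structure.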
@@ -458,63 +457,63 @@ tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr,
// CHECK-LABEL: cf_if
tt.func @cf_if(%i1 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
- %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED>
+ %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
cf.cond_br %i1, ^bb1, ^bb2
^bb1: // pred: ^bb0
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
+ %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
cf.br ^bb2
^bb2: // 2 preds: ^bb0, ^bb1
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %1 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
+ %1 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
tt.return
}

tt.func @cf_if_else(%i1 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
- %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED>
+ %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
cf.cond_br %i1, ^bb1, ^bb2
^bb1: // pred: ^bb0
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
- %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED>
- cf.br ^bb3(%1 : !tt.memdesc<16x16xf16, #A_SHARED>)
+ %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
+ %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
+ cf.br ^bb3(%1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>)
^bb2: // pred: ^bb0
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
- %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED>
- cf.br ^bb3(%3 : !tt.memdesc<16x16xf16, #A_SHARED>)
-^bb3(%arg: !tt.memdesc<16x16xf16, #A_SHARED>): // 2 preds: ^bb1, ^bb2
+ %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
+ %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
+ cf.br ^bb3(%3 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>)
+^bb3(%arg: !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>): // 2 preds: ^bb1, ^bb2
cf.br ^bb4
^bb4: // pred: ^bb3
// CHECK: triton_gpu.local_load
- %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
+ %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %5 = triton_gpu.local_load %arg : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
+ %5 = triton_gpu.local_load %arg : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
tt.return
}

tt.func @cf_if_else_return(%i1 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
- %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED>
- %b = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED>
+ %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %b = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
cf.cond_br %i1, ^bb1, ^bb2
^bb1: // pred: ^bb0
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
- %1 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
+ %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
+ %1 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
tt.return
^bb2: // pred: ^bb0
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.local_load
- %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
- %3 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
+ %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
+ %3 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
tt.return
}

@@ -525,38 +524,38 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 :
// CHECK-LABEL: convert_layout1
tt.func @convert_layout1(%A : !tt.ptr) {
// CHECK-NOT: gpu.barrier
- %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED>
- %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
+ %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
tt.return
}

// CHECK-LABEL: convert_layout2
tt.func @convert_layout2(%A : !tt.ptr) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
- %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED>
- %1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED>
+ %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
+ %1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
// CHECK: triton_gpu.local_load
// CHECK-NEXT: gpu.barrier
// CHECK: triton_gpu.local_load
- %3 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
- %4 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
+ %3 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
+ %4 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
tt.return
}

// CHECK-LABEL: convert_layout3
tt.func @convert_layout3(%cond : i1) {
scf.if %cond {
- %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf16, #A_SHARED>
+ %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf16, #A_SHARED, #triton_gpu.shared_memory>
// CHECK: triton_gpu.local_load
// CHECK-NOT: gpu.barrier
- %1 = triton_gpu.local_load %0 : !tt.memdesc<16x64xf16, #A_SHARED> -> tensor<16x64xf16, #AL>
+ %1 = triton_gpu.local_load %0 : !tt.memdesc<16x64xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x64xf16, #AL>
} else {
- %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED>
+ %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
// CHECK: triton_gpu.local_load
// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: triton_gpu.local_alloc
- %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL>
- %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED>
+ %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL>
+ %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
}
tt.return
}

@@ -595,7 +594,7 @@ tt.func @single_call_no_sync(%A : !tt.ptr) {
// CHECK-LABEL: multiple_calls
tt.func @multiple_calls(%A : !tt.ptr) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
- %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED>
+ %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
tt.call @convert_layout1(%A) : (!tt.ptr) -> ()
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
tt.call @convert_layout2(%A) : (!tt.ptr) -> ()

@@ -607,12 +606,12 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) {
scf.if %cond {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
%cst_ = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
- %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED>
+ %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.call
// CHECK-NEXT: gpu.barrier
tt.call @convert_layout1(%A) : (!tt.ptr) -> ()
- %cst1 = triton_gpu.local_alloc %cst_ : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED>
+ %cst1 = triton_gpu.local_alloc %cst_ : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory>
} else {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK: tt.call

@@ -625,7 +624,7 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) {
// CHECK-LABEL: for_calls
tt.func @for_calls(%A : !tt.ptr, %cond : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
- %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED>
+ %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
%lb = arith.constant 0 : index
%ub = arith.constant 10 : index

@@ -641,7 +640,7 @@ tt.func @for_calls(%A : !tt.ptr, %cond : i1) {
// CHECK-LABEL: call_graph_1
tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
- %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> // CHECK: gpu.barrier
+ %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier
// CHECK-NEXT: tt.call
tt.call @convert_layout3(%cond) : (i1) -> ()
tt.return

@@ -653,7 +652,7 @@ tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) {
tt.call @convert_layout4(%A, %cond) : (!tt.ptr, i1) -> ()
// CHECK: tt.call
// CHECK-NEXT: gpu.barrier
- %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED>
+ %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>
tt.return
}

diff --git a/test/Conversion/amd/decompose-unsupported-conversions.mlir b/test/Conversion/amd/decompose-unsupported-conversions.mlir
index b5fc5b72ce47..a6d39bdfb4b1 100644
--- a/test/Conversion/amd/decompose-unsupported-conversions.mlir
+++ b/test/Conversion/amd/decompose-unsupported-conversions.mlir
@@ -7,7 +7,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) {
// CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[WMMA]]> -> tensor<16x16xf16, #[[BLOCKED]]>
- // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[SHARED]]>
+ // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[SHARED]], #triton_gpu.shared_memory>
// CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>>
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
tt.return

diff --git a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
index 5a4ada339200..ce59fa94a408 100644
--- a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
+++ b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
@@ -42,7 +42,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
}

// CHECK-LABEL: wmma_dot_int8_32
- tt.func @wmma_dot_int8_32(%arg0: tensor<16x16xui8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xui8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma>) {
+ tt.func @wmma_dot_int8_32(%arg0: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma>) {
// CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
// CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi8>
// CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32>
@@ -51,13 +51,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32>
// CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: rocdl.wmma.i32.16x16x16.iu8 {{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
- %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xui8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xui8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xi32, #mma>
+ %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xi32, #mma>
// CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
tt.return
}
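// Reader note: as in the int8 case above, the int4 test below checks that sixteen
// 4-bit values are gathered via insertelement into a vector<16xi4> and bitcast to
// vector<2xi32> before the rocdl.wmma iu4 intrinsic consumes them.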
// CHECK-LABEL: wmma_dot_int4_32
- tt.func @wmma_dot_int4_32(%arg0: tensor<16x16xui4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xui4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma>) {
+ tt.func @wmma_dot_int4_32(%arg0: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma>) {
// CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4)>
// CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4>
// CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32>
@@ -66,7 +66,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32>
// CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: rocdl.wmma.i32.16x16x16.iu4 {{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
- %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xui4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xui4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xi32, #mma>
+ %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xi32, #mma>
// CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
tt.return
}

diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 3bf06a836d98..0b7030311d46 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -445,7 +445,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.mlir.addressof @global_smem
// CHECK-NEXT: llvm.getelementptr
// CHECK-NEXT: llvm.mlir.constant
- %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #shared0>
+ %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory>
tt.return
}
}

@@ -475,8 +475,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-NEXT: llvm.getelementptr
%index = arith.constant 1 : i32
%zero = arith.constant 0 : i32
- %0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x16x32xf32, #shared0>
- %1 = triton_gpu.memdesc_subview %0[%index, %zero, %zero] : !tt.memdesc<128x16x32xf32, #shared0> -> !tt.memdesc<16x32xf32, #shared0>
+ %0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x16x32xf32, #shared0, #triton_gpu.shared_memory>
+ %1 = triton_gpu.memdesc_subview %0[%index, %zero, %zero] : !tt.memdesc<128x16x32xf32, #shared0, #triton_gpu.shared_memory> -> !tt.memdesc<16x32xf32, #shared0, #triton_gpu.shared_memory>
tt.return
}
}

@@ -506,7 +506,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
%24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1d0>
%59 = tt.addptr %58, %24 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0>
%66 = tt.addptr %59, %cst_2 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0>
- %71 = triton_gpu.local_alloc : () -> !tt.memdesc<2x64xi64, #shared>
+ %71 = triton_gpu.local_alloc : () -> !tt.memdesc<2x64xi64, #shared, #triton_gpu.shared_memory>
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
// CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
@@ -517,7 +517,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
// CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
// CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
// CHECK: cp.async.commit_group
- %73 = triton_gpu.async_copy_global_to_local %66, %71 : tensor<64x!tt.ptr, #slice1d0> -> !tt.memdesc<2x64xi64, #shared>
+ %73 = triton_gpu.async_copy_global_to_local %66, %71 : tensor<64x!tt.ptr, #slice1d0> -> !tt.memdesc<2x64xi64, #shared, #triton_gpu.shared_memory>
triton_gpu.async_commit_group %73
tt.return
}

@@ -550,7 +550,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
%a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x64x!tt.ptr, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr, #AL>, tensor<16x64xi32, #AL>
- %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf32, #A>
+ %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf32, #A, #triton_gpu.shared_memory>
%index = arith.constant 1 : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
@@ -559,7 +559,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.commit_group
- %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !tt.memdesc<16x64xf32, #A>
+ %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !tt.memdesc<16x64xf32, #A, #triton_gpu.shared_memory>
triton_gpu.async_commit_group
tt.return
}

@@ -592,7 +592,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL>
%a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x32x!tt.ptr, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr, #AL>, tensor<16x32xi32, #AL>
- %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf32, #A>
+ %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf32, #A, #triton_gpu.shared_memory>
%index = arith.constant 1 : i32
// CHECK: llvm.inline_asm
@@ -605,7 +605,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.commit_group
- %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !tt.memdesc<16x32xf32, #A>
+ %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !tt.memdesc<16x32xf32, #A, #triton_gpu.shared_memory>
triton_gpu.async_commit_group
tt.return
}

@@ -637,7 +637,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<32x32xi32, #AL>
%a_init = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL>
%a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL>
- %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<32x32xf32, #A>
+ %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<32x32xf32, #A, #triton_gpu.shared_memory>
%index = arith.constant 1 : i32
// CHECK: llvm.mlir.constant(0 : i32) : i32
@@ -662,7 +662,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.commit_group
- %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !tt.memdesc<32x32xf32, #A>
+ %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !tt.memdesc<32x32xf32, #A, #triton_gpu.shared_memory>
triton_gpu.async_commit_group
tt.return
}

@@ -806,14 +806,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot
tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
- %AA = triton_gpu.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0>
- %BB = triton_gpu.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0>
+ %AA = triton_gpu.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory>
+ %BB = triton_gpu.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory>
// CHECK: llvm.inline_asm
// CHECK: ldmatrix.sync.aligned.m8n8.x4
// CHECK: llvm.inline_asm
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
- %AA_DOT = triton_gpu.local_load %AA : !tt.memdesc<16x16xf16, #shared0> -> tensor<16x16xf16, #dot_operand_a>
- %BB_DOT = triton_gpu.local_load %BB : !tt.memdesc<16x16xf16, #shared0> -> tensor<16x16xf16, #dot_operand_b>
+ %AA_DOT = triton_gpu.local_load %AA : !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_a>
+ %BB_DOT = triton_gpu.local_load %BB : !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_b>
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
// CHECK: llvm.inline_asm
@@ -906,7 +906,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
// CHECK-SAME: !llvm.ptr<3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<3>
- %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
+ %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory>
tt.return
}
}

@@ -963,11 +963,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32},
- %a:!tt.memdesc<128x32xf16, #shared>, %b:!tt.memdesc<32x256xf16, #shared>) {
+ %a:!tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory>, %b:!tt.memdesc<32x256xf16, #shared, #triton_gpu.shared_memory>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
- %a_mat = triton_gpu.local_load %a : !tt.memdesc<128x32xf16, #shared> -> tensor<128x32xf16, #dot_operand_a>
- %b_mat = triton_gpu.local_load %b : !tt.memdesc<32x256xf16, #shared> -> tensor<32x256xf16, #dot_operand_b>
+ %a_mat = triton_gpu.local_load %a : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x32xf16, #dot_operand_a>
+ %b_mat = triton_gpu.local_load %b : !tt.memdesc<32x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<32x256xf16, #dot_operand_b>
%28 = tt.dot %a_mat, %b_mat, %cst : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
%38 = triton_gpu.convert_layout %28 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked>

@@ -989,11 +989,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32},
- %a:!tt.memdesc<32x64xf16, #shared0>, %b:!tt.memdesc<64x64xf16, #shared1>) {
+ %a:!tt.memdesc<32x64xf16, #shared0, #triton_gpu.shared_memory>, %b:!tt.memdesc<64x64xf16, #shared1, #triton_gpu.shared_memory>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma>
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
- %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x64xf16, #shared0> -> tensor<32x64xf16, #dot_operand_a>
- %b_mat = triton_gpu.local_load %b : !tt.memdesc<64x64xf16, #shared1> -> tensor<64x64xf16, #dot_operand_b>
+ %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x64xf16, #shared0, #triton_gpu.shared_memory> -> tensor<32x64xf16, #dot_operand_a>
+ %b_mat = triton_gpu.local_load %b : !tt.memdesc<64x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x64xf16, #dot_operand_b>
%28 = tt.dot %a_mat, %b_mat, %cst : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma>
%38 = triton_gpu.convert_layout %28 : tensor<32x64xf32, #mma> -> tensor<32x64xf32, #blocked>

@@ -1012,11 +1012,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @matmul_fmadot(%ptr:!tt.ptr {tt.divisibility = 16 : i32},
- %a:!tt.memdesc<32x16xf32, #shared>, %b:!tt.memdesc<16x32xf32, #shared>) {
+ %a:!tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory>, %b:!tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
// CHECK: llvm.intr.fmuladd
- %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared> -> tensor<32x16xf32, #dot_operand_a>
- %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared> -> tensor<16x32xf32, #dot_operand_b>
+ %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory> -> tensor<32x16xf32, #dot_operand_a>
+ %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory> -> tensor<16x32xf32, #dot_operand_b>
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = ieee : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked>
%30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked>

@@ -1036,7 +1036,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-LABEL: matmul_tf32dot
tt.func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32},
- %a:!tt.memdesc<32x16xf32, #shared>, %b:!tt.memdesc<16x32xf32, #shared>) {
+ %a:!tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory>, %b:!tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
// CHECK: llvm.inline_asm
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
@@ -1044,8 +1044,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.inline_asm
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
// CHECK-SAME: (i32, i32, i32, i32)
- %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared> -> tensor<32x16xf32, #dot_operand_a>
- %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared> -> tensor<16x32xf32, #dot_operand_b>
+ %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory> -> tensor<32x16xf32, #dot_operand_a>
+ %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory> -> tensor<16x32xf32, #dot_operand_b>
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32

@@ -1240,8 +1240,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
// CHECK-LABEL: test_base_index_cache
tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.tid.x
- %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
- %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
+ %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory>
+ %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory>
tt.return
}
}

@@ -1253,10 +1253,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
// CHECK-LABEL: test_index_cache_different_block
tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
// CHECK: nvvm.read.ptx.sreg.tid.x
- %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
+ %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory>
cf.cond_br %arg1, ^bb1, ^bb2
^bb1: // pred: ^bb0
- %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
+ %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory>
cf.br ^bb2
^bb2: // 2 preds: ^bb0, ^bb1
tt.return

@@ -1550,8 +1550,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
// CHECK: llvm.load
// CHECK-SAME: {alignment = 8 : i64} : !llvm.ptr<3> -> vector<8xi8>
// CHECK-NOT: llvm.load
- tt.func public @vectorize_shmem_load(%shmem : !tt.memdesc<16x16xi8, #shared>) {
- %0 = triton_gpu.local_load %shmem : !tt.memdesc<16x16xi8, #shared> -> tensor<16x16xi8, #blocked>
+ tt.func public @vectorize_shmem_load(%shmem : !tt.memdesc<16x16xi8, #shared, #triton_gpu.shared_memory>) {
+ %0 = triton_gpu.local_load %shmem : !tt.memdesc<16x16xi8, #shared, #triton_gpu.shared_memory> -> tensor<16x16xi8, #blocked>
tt.return
}
}

@@ -1566,7 +1566,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
// CHECK-SAME: {alignment = 64 : i64} : vector<16xi32>, !llvm.ptr<3>
// CHECK-NOT: llvm.store
tt.func public @vectorize_shmem_store(%block : tensor<64x64xi32, #blocked>) {
- %0 = triton_gpu.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !tt.memdesc<64x64xi32, #shared>
+ %0 = triton_gpu.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !tt.memdesc<64x64xi32, #shared, #triton_gpu.shared_memory>
tt.return
}
}

@@ -1591,9 +1591,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
// CHECK: llvm.extractelement {{.*}} : vector<8xi16>
tt.func public @test_local_load_bf16() {
%c0_i32 = arith.constant 0 : i32
- %19 = triton_gpu.local_alloc : () -> !tt.memdesc<1x1x2048xbf16, #shared, mutable>
- %22 = triton_gpu.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x1x2048xbf16, #shared, mutable> -> !tt.memdesc<1x2048xbf16, #shared, mutable>
- %39 = triton_gpu.local_load %22 : !tt.memdesc<1x2048xbf16, #shared, mutable> -> tensor<1x2048xbf16, #blocked>
+ %19 = triton_gpu.local_alloc : () -> !tt.memdesc<1x1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable>
+ %22 = triton_gpu.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable>
+ %39 = triton_gpu.local_load %22 : !tt.memdesc<1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<1x2048xbf16, #blocked>
%40 = arith.extf %39 : tensor<1x2048xbf16, #blocked> to tensor<1x2048xf32, #blocked>
tt.return
}

@@ -1607,8 +1607,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.store
tt.func public @test_local_store(%arg0: tensor<1xf32, #blocked>) {
%c0_i32 = arith.constant 0 : i32
- %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, mutable>
- triton_gpu.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, mutable>
+ %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable>
+ triton_gpu.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable>
tt.return
}
}

@@ -1621,9 +1621,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.store
tt.func public @test_local_store_subview(%arg0: tensor<1xf32, #blocked>) {
%c0_i32 = arith.constant 0 : i32
- %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, mutable>
- %sv = triton_gpu.memdesc_subview %0[%c0_i32] : !tt.memdesc<1xf32, #shared, mutable> -> !tt.memdesc<1xf32, #shared, mutable>
- triton_gpu.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, mutable>
+ %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable>
+ %sv = triton_gpu.memdesc_subview %0[%c0_i32] : !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable>
+ triton_gpu.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable>
tt.return
}
}

diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir
index 743d554a31b6..52b7f04354fa 100644
--- a/test/Conversion/tritongpu_to_llvm_hopper.mlir
+++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
// CHECK-COUNT-128: llvm.fadd
// CHECK: nvgpu.wgmma
// CHECK-COUNT-128: llvm.fadd
- %m = triton_nvidia_gpu.dot_async %a, %b, %c
+ %m = triton_nvidia_gpu.warp_group_dot %a, %b, %c
{maxNumImpreciseAcc = 32 : i32, inputPrecision = 0 : i32} :
!tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
tt.return
@@ -38,7 +38,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
// CHECK: nvgpu.wgmma
// CHECK-NOT: llvm.fadd
// CHECK: llvm.return
- %m = triton_nvidia_gpu.dot_async %a, %b, %c
+ %m = triton_nvidia_gpu.warp_group_dot %a, %b, %c
{maxNumImpreciseAcc = 129 : i32, inputPrecision = 0 : i32} :
!tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
tt.return
@@ -62,7 +62,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
// CHECK: nvgpu.wgmma
// CHECK-COUNT-128: llvm.fadd
// CHECK: llvm.return
- %m = triton_nvidia_gpu.dot_async %a, %b, %c
+ %m = triton_nvidia_gpu.warp_group_dot %a, %b, %c
{maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} :
!tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma>
tt.return
@@ -80,7 +80,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: nvgpu.wgmma %{{.*}}, %{{.*}} {
tt.func @dot_zero_acc(%a: !tt.memdesc<128x64xf16, #shared>, %b: !tt.memdesc<64x64xf16, #shared1>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
- %m = triton_nvidia_gpu.dot_async %a, %b, %cst {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} :
+ %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} :
!tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x64xf16, #shared1> -> tensor<128x64xf32, #mma>
tt.return
}
@@ -98,7 +98,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
%opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
- %m = tt.dot %opA, %b, %cst, inputPrecision = tf32 :
+ %m = triton_nvidia_gpu.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }:
tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>
tt.return
}
@@ -116,7 +116,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32}
tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1>
- %m = tt.dot %a, %b, %cst, inputPrecision = tf32 {maxNumImpreciseAcc = 1073741824 : i32} :
+ %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } :
tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1>
tt.return
}
@@ -208,7 +208,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-COUNT-128: llvm.fadd
tt.func @dot_zero_acc_operand(%a: !tt.memdesc<128x128xf8E5M2, #shared>, %b: !tt.memdesc<128x128xf8E5M2, #shared1>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
- %m = tt.dot %a, %b, %cst, inputPrecision = tf32 {maxNumImpreciseAcc = 64 : i32} :
+ %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} :
!tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x128xf8E5M2, #shared1> -> tensor<128x128xf32, #mma>
tt.return
}

diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir
index 8c4e85aa0859..d37c9ebd901a 100644
--- a/test/TritonGPU/accelerate-matmul.mlir
+++ b/test/TritonGPU/accelerate-matmul.mlir
@@ -23,8 +23,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked2>
// CHECK: scf.for
- // CHECK: tt.dot {{.*}} -> tensor<128x16xf16, #[[MMA]]>
- // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]>
+ // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x16xf16, #[[MMA]]>
+ // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]>
%115 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_0) -> (tensor<128x64xf16, #blocked1>) : i32 {
%172 = tt.dot %170, %171, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked>
%178 = triton_gpu.convert_layout %172 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
@@ -32,8 +32,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
scf.yield %180 : tensor<128x64xf16, #blocked1>
}
// CHECK: scf.for
- // CHECK: tt.dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]>
- // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]>
+ // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]>
+ // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]>
%149 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 {
%166 = tt.dot %164, %165, %cst_2 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2>
%172 = triton_gpu.convert_layout %166 : tensor<128x32xf16, #blocked2> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>

diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir
index cb565d1f054d..5dfd0f2a5f4c 100644
--- a/test/TritonGPU/amd/amd-reorder-instructions.mlir
+++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir
@@ -1,25 +1,501 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s
-// Check that we order load, local_alloc and local_load one after another. This is useful
-// for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers
+// Check that we place local_alloc, local_store (optional) and local_load right after the definition of their operands
+// in cases where local_alloc is in the loop but its operand is not.
+// This is useful for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers
// throughout the computation.
-// CHECK-LABEL: order_load_alloc_local_load
-// CHECK: %[[LOAD:.+]] = tt.load
-// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[LOAD]]
-// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]]
+
+// CHECK-LABEL: hoist_q_out_of_the_loop
+// CHECK: %[[TRUNCF:.+]] = arith.truncf
+// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]]
+// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]]
+// CHECK: scf.for
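+// Reader note: 1.44269502 below is log2(e); the Q scaling for the exp2-based
+// softmax is applied once before the loop, so only the scaled and truncated Q
+// has to stay in registers across loop iterations.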
%c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 { %166 = tt.dot %164, %165, %cst_2 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2> %172 = triton_gpu.convert_layout %166 : tensor<128x32xf16, #blocked2> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index cb565d1f054d..5dfd0f2a5f4c 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -1,25 +1,501 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s -// Check that we order load, local_alloc and local_load one after another. This is useful -// for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers +// Check that we place local_alloc, local_store (optional) and local_load right after the definition of their operands +// in cases where local_alloc is in the loop but its operand is not. +// This is useful for making sure that the Q tensor in FA is hoisted out of the main loop and kept in registers // throughout the computation. -// CHECK-LABEL: order_load_alloc_local_load -// CHECK: %[[LOAD:.+]] = tt.load -// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[LOAD]] -// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]] + +// CHECK-LABEL: hoist_q_out_of_the_loop +// CHECK: %[[TRUNCF:.+]] = arith.truncf +// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]] +// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]] +// CHECK: scf.for +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %12 = tt.splat %3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %41 = tt.load %12 : tensor<256x128x!tt.ptr, #blocked1> + %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1> + %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1> + %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1> + %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { + %73 = tt.splat %3 : !tt.ptr ->
tensor<128x128x!tt.ptr, #blocked2> + %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> + %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> + %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> + %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %107 = arith.addi %arg26, %c128_i64 : i64 + scf.yield %107 : i64 + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} + tt.return + } +} + + +// ----- +// Check that the reordering described in hoist_q_out_of_the_loop is not done in the case where both +// local_alloc and its src tensor's defining op are in the loop. +// CHECK-LABEL: no_hoist_q_type_reordering +// CHECK: scf.for +// CHECK: %[[TRUNCF:.+]] = arith.truncf +// CHECK-NEXT: arith.constant +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i64 = arith.constant 0 : i64 + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %12 = tt.splat %3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %41 = tt.load %12 : tensor<256x128x!tt.ptr, #blocked1> + %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1> + %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1> + %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1> + %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { + %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> + %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> + %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> + %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>)
-> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> + %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> + %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %107 = arith.addi %arg26, %c128_i64 : i64 + scf.yield %107 : i64 + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} + tt.return + } +} + +// ----- #blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> #mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> #shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> + +// CHECK-LABEL: order_load_alloc_local_load_local_store +// CHECK: %[[LOAD:.+]] = tt.load +// CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc +// CHECK: triton_gpu.local_store %[[LOAD]], %[[ALLOC]] +// CHECK: triton_gpu.local_load %[[ALLOC]] module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @order_load_alloc_local_load(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { + tt.func public @order_load_alloc_local_load_local_store(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %10 = triton_gpu.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> + %10 = triton_gpu.local_alloc : () -> !tt.memdesc<32x32xf32, #shared, mutable> + triton_gpu.local_store %9, %10 : tensor<32x32xf32, #blocked> -> !tt.memdesc<32x32xf32, #shared, mutable> %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared, mutable> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> tt.return } } + +// ----- +// Move loads (and independent local_stores) as early as possible. +// For example in the matmul_loop below, the scf.for loop looks like this after the pipeliner: +// scf.for ...
{ +// // stage 1 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// // stage 0 +// %aptr = tt.addptr %aptr, %k +// %a_next = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next = tt.load %bptr +// tt.local_store %a_next +// tt.local_store %b_next +// yield +// } +// +// Solution for num_stages=2 : +// scf.for ... { +// // stage 0.a +// %aptr = tt.addptr %aptr, %k +// %a_next = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next = tt.load %bptr +// // stage 1 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// // stage 0.b +// tt.local_store %a_next +// tt.local_store %b_next +// yield +// } +// +// Solution for num_stages=3 (double-buffered) : +// scf.for ... { +// // stage 1 +// tt.local_store %a_next_1 +// tt.local_store %b_next_1 +// // stage 0 +// %aptr = tt.addptr %aptr, %k +// %a_next_2 = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next_2 = tt.load %bptr +// // stage 2 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// yield +// } + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +#shared3 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared4 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32, triton_gpu.target = "hip:gfx942"} { + +// CHECK-LABEL: tt.func @matmul_loop +// CHECK: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) +// Stage 0.a +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_21]] +// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]] +// CHECK: %[[ADDPTR_25:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[SPLAT_26:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_27:.*]] = tt.load %[[ADDPTR_25]], %[[SPLAT_26]] +// Stage 1 +// CHECK: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG10]] +// CHECK: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}} +// CHECK: %[[DOT_31:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %[[ARG8]] +// Stage 0.b +// CHECK: %[[ADDI_32:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ADDI_32]], %{{.*}} +// CHECK: %[[SELECT_34:.*]] = arith.select %[[CMPI_33]], %[[ADDI_32]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, 
%{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_35]] +// CHECK: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: scf.yield %[[ADDPTR_20]], %[[ADDPTR_25]], %[[DOT_31]], %[[SELECT_34]], %[[MEMDESC_SUBVIEW_35]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: } + + tt.func @matmul_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %11 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %12 = arith.cmpi slt, %arg0, %arg1 : index + %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> + %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> + %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> + %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %17 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %18 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, 
!tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %20 = arith.subi %arg1, %arg2 : index + %21 = arith.cmpi slt, %arg5, %20 : index + %22 = triton_gpu.local_load %arg10 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %23 = triton_gpu.local_load %arg11 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = arith.mulf %23, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %25 = tt.dot %22, %24, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %26 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %27 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %28 = tt.splat %21 : i1 -> tensor<128x32xi1, #blocked1> + %29 = tt.load %26, %28 : tensor<128x32x!tt.ptr, #blocked1> + %30 = tt.splat %21 : i1 -> tensor<32x128xi1, #blocked> + %31 = tt.load %27, %30, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %32 = arith.addi %arg9, %c1_i32 : i32 + %33 = arith.cmpi slt, %32, %c1_i32 : i32 + %34 = arith.select %33, %32, %c0_i32 : i32 + %35 = triton_gpu.memdesc_subview %10[%34, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %36 = triton_gpu.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %31, %36 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %10 : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %11 : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + tt.return %19#2 : tensor<128x128xf32, #mma> + } + + +// This example tests that the tt.load overlaps with an independent ttg.local_store, which +// in turn overlaps with an independent tt.dot.
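+// A sketch of the expected steady-state loop body (mirroring the Stage 0/1/2 CHECK lines
+// below; the value names and "..." are placeholders, not verbatim pass output):
+// scf.for ... {
+//   %next = tt.load ...            // stage 0: fetch a tile for a future iteration
+//   triton_gpu.local_store ...     // stage 1: write the tile loaded last iteration to LDS
+//   tt.dot ...                     // stage 2: compute on the tile stored one iteration ago
+//   yield
+// }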
+// num_stages == 3, double buffered + +// CHECK-LABEL: tt.func @matmul_loop_mb +// CHECK: %{{.*}}:8 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) +// Stage 0 +// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[MULI_29:.*]] = arith.muli %{{.*}}, %{{.*}} +// CHECK: %[[SUBI_30:.*]] = arith.subi %{{.*}}, %[[MULI_29]] +// CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_30]] +// CHECK: %[[SPLAT_32:.*]] = tt.splat %[[CMPI_31]] +// CHECK: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_32]] +// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_31]] +// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_35]] +// Stage 1 +// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} +// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]] +// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]] +// Stage 2 +// CHECK: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[ARG10]] +// CHECK: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}} +// CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]] +// CHECK: scf.yield %[[ADDPTR_28]], %[[ADDPTR_34]], %[[DOT_45]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_33]], %[[LOAD_36]] +// CHECK: } + + tt.func @matmul_loop_mb(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { + %c2 = arith.constant 2 : index + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %8 = 
tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %10 = triton_gpu.local_alloc : () -> !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %11 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %12 = arith.cmpi slt, %arg0, %arg1 : index + %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> + %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> + %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> + %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %17 = arith.addi %arg0, %arg2 : index + %18 = arith.cmpi slt, %17, %arg1 : index + %19 = tt.addptr %4, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %20 = tt.addptr %9, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %21 = tt.splat %18 : i1 -> tensor<128x32xi1, #blocked1> + %22 = tt.load %19, %21 : tensor<128x32x!tt.ptr, #blocked1> + %23 = tt.splat %18 : i1 -> tensor<32x128xi1, #blocked> + %24 = tt.load %20, %23, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %25 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %26 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) { + %28 = arith.muli %arg2, %c2 : index + %29 = arith.subi %arg1, %28 : index + %30 = arith.cmpi slt, %arg5, %29 : index + %31 = triton_gpu.local_load %arg10 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %32 = triton_gpu.local_load %arg11 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %33 = arith.mulf %32, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %34 = tt.dot %31, %33, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %35 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %36 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %37 = tt.splat %30 : i1 -> tensor<128x32xi1, #blocked1> + %38 = tt.load %35, %37 : tensor<128x32x!tt.ptr, #blocked1> + %39 = tt.splat %30 : i1 -> 
tensor<32x128xi1, #blocked> + %40 = tt.load %36, %39, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %41 = arith.addi %arg9, %c1_i32 : i32 + %42 = arith.cmpi slt, %41, %c2_i32 : i32 + %43 = arith.select %42, %41, %c0_i32 : i32 + %44 = triton_gpu.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %45 = triton_gpu.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked> + } + triton_gpu.local_dealloc %10 : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %11 : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + tt.return %27#2 : tensor<128x128xf32, #mma> + } + +// This example shows dependent loads and verifies all are moved early. +// CHECK-LABEL: tt.func @indirect_bmm_vector +// CHECK: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) +// Stage 0 +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_21]] +// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]] +// Stage 1.a +// CHECK: %[[EXPAND_DIMS_25:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} +// CHECK: %[[BROADCAST_26:.*]] = tt.broadcast %[[EXPAND_DIMS_25]] +// CHECK: %[[MULI_27:.*]] = arith.muli %{{.*}}, %[[BROADCAST_26]] +// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %{{.*}}, %[[MULI_27]] +// CHECK: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_29]] +// CHECK: %[[ADDPTR_31:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[SUBI_32:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_32]] +// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_33]] +// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_31]], %[[SPLAT_34]] +// Stage 2 +// CHECK: %[[LOCAL_LOAD_36:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[DOT_38:.*]] = tt.dot %[[LOCAL_LOAD_36]], %[[LOCAL_LOAD_37]], %[[ARG7]] +// Stage 1.b +// CHECK: %[[ADDI_39:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_40:.*]] = arith.cmpi slt, %[[ADDI_39]], %{{.*}} +// CHECK: %[[SELECT_41:.*]] = arith.select %[[CMPI_40]], %[[ADDI_39]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_42:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store 
%[[LOAD_24]], %[[MEMDESC_SUBVIEW_42]] +// CHECK: %[[MEMDESC_SUBVIEW_43:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_43]] +// CHECK: scf.yield %[[DOT_38]], %[[ADDPTR_20]], %[[ADDPTR_31]], %[[SELECT_41]], %[[MEMDESC_SUBVIEW_42]], %[[MEMDESC_SUBVIEW_43]], %[[LOAD_35]] +// CHECK: } + + tt.func @indirect_bmm_vector(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { + %c2 = arith.constant 2 : index + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<1> : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %2 = arith.cmpi sgt, %arg1, %c0 : index + %3 = tt.splat %2 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %4 = tt.load %arg3, %3 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %5 = arith.cmpi sgt, %arg1, %c1 : index + %6 = tt.addptr %arg3, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked1> + %8 = tt.load %arg2, %7 : tensor<16x16x!tt.ptr, #blocked1> + %9 = tt.expand_dims %4 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %10 = tt.broadcast %9 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %11 = arith.muli %arg0, %10 : tensor<16x16xi64, #blocked> + %12 = tt.addptr %arg5, %11 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %13 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked> + %14 = tt.load %12, %13 : tensor<16x16x!tt.ptr, #blocked> + %15 = tt.splat %5 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %16 = tt.load %6, %15 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %17 = triton_gpu.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %18 = triton_gpu.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, 
#mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) { + %20 = arith.subi %arg1, %c2 : index + %21 = arith.cmpi slt, %arg6, %20 : index + %22 = arith.subi %arg1, %c1 : index + %23 = arith.cmpi slt, %arg6, %22 : index + %24 = triton_gpu.local_load %arg11 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %25 = triton_gpu.local_load %arg12 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %26 = tt.dot %24, %25, %arg7 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %27 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> + %28 = tt.addptr %arg9, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %29 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked1> + %30 = tt.load %27, %29 : tensor<16x16x!tt.ptr, #blocked1> + %31 = tt.expand_dims %arg13 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %32 = tt.broadcast %31 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %33 = arith.muli %arg0, %32 : tensor<16x16xi64, #blocked> + %34 = tt.addptr %arg5, %33 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %35 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked> + %36 = tt.load %34, %35 : tensor<16x16x!tt.ptr, #blocked> + %37 = tt.splat %21 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %38 = tt.load %28, %37 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %39 = arith.addi %arg10, %c1_i32 : i32 + %40 = arith.cmpi slt, %39, %c1_i32 : i32 + %41 = arith.select %40, %39, %c0_i32 : i32 + %42 = triton_gpu.memdesc_subview %0[%41, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %43 = triton_gpu.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + triton_gpu.local_dealloc %0 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %1 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + 
tt.return %19#0 : tensor<16x16xf32, #mma> + } +} + +// ----- + +// CHECK-LABEL: sink_convert_dealloc +// CHECK-COUNT-2: triton_gpu.local_dealloc %{{.+}} : !tt.memdesc<4x128x64xf16, #shared, mutable> +// CHECK: triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @sink_convert_dealloc(%arg0: tensor<32x32xf32, #blocked>) attributes {noinline = false} { + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> + %2 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> + triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable> + triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable> + %3 = arith.addf %2, %2 : tensor<32x32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: anchor_barrier +// CHECK: gpu.barrier +// CHECK: tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @anchor_barrier(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> + gpu.barrier + %2 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %1 = triton_gpu.local_alloc %2 : (tensor<32x32xf16, #blocked>) -> !tt.memdesc<4x128x64xf16, #shared, mutable> + triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable> + triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable> + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-stream-pipeline.mlir b/test/TritonGPU/amd/amd-stream-pipeline.mlir deleted file mode 100644 index 4b2de3336413..000000000000 --- a/test/TritonGPU/amd/amd-stream-pipeline.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: triton-opt %s -split-input-file --tritonamdgpu-stream-pipeline | FileCheck %s - -// CHECK-LABEL: @check_stream_pipeline_epilogue -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @check_stream_pipeline_epilogue(%Aptr: tensor<32x32x!tt.ptr, #blocked>, %Bptr : tensor<32x32x!tt.ptr, #blocked>, %arg4 : i32, %arg5 : i1) { - %cst_0 = arith.constant dense<16> : tensor<32x32xi32, #blocked> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> - %cst_5 = arith.constant 
dense<0.000000e+00> : tensor<32x32xf32, #mma> - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - // CHECK: scf.for {{.*}} = %[[LB:.*]] to %[[UB:.*]] step %[[STEP:.*]] iter_args({{.*}}) - %36:3 = scf.for %arg9 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg10 = %cst_5, %arg12 = %Aptr, %arg13 = %Bptr) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked>) : i32 { - %61 = arith.muli %arg9, %arg4 : i32 - %62 = arith.cmpi slt, %arg4, %61 : i32 - %63 = tt.splat %62 : i1 -> tensor<32x32xi1, #blocked> - // This load will not be pipelined - %66 = tt.load %arg12, %63 : tensor<32x32x!tt.ptr, #blocked> - // This load will be pipelined - %70 = tt.load %arg13 : tensor<32x32x!tt.ptr, #blocked> - %71 = triton_gpu.convert_layout %66 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %72 = triton_gpu.convert_layout %70 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %73 = tt.dot %71, %72, %arg10 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> - // This scf.if will make load at %66 non-pipelineable - %74 = scf.if %arg5 -> (tensor<32x32xf32, #blocked>){ - scf.yield %66 : tensor<32x32xf32, #blocked> - } else { - scf.yield %cst_2: tensor<32x32xf32, #blocked> - } - %75 = tt.addptr %arg12, %cst_0 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - %76 = tt.addptr %arg13, %cst_0 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - scf.yield %73, %75, %76 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked> - } - // CHECK: %[[C1:.*]] = arith.constant 1 : i32 - // CHECK: %[[t0:.*]] = arith.subi %[[UB:.*]], %[[C1]] - // CHECK: %[[t1:.*]] = arith.subi %[[t0]], %[[LB]] - // CHECK: %[[t2:.*]] = arith.divui %[[t1]], %[[STEP]] - // CHECK: %[[t3:.*]] = arith.muli %[[t2]], %[[STEP]] - // CHECK: %[[PPLUB:.*]] = arith.addi %[[LB]], %[[t3]] - // CHECK: arith.muli %[[PPLUB]], {{.*}} - tt.return - } -} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index b951f5041cf9..9118bc4f2fc6 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2316,11 +2316,11 @@ tt.func @assertop(%ptr: tensor<1024x!tt.ptr, #blocked>) { #blocked3 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @dot_wait_propagate - tt.func public @dot_wait_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<16x2xf32, #blocked> { + // CHECK-LABEL: @warp_group_dot_wait_propagate + tt.func public @warp_group_dot_wait_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<16x2xf32, #blocked> { // CHECK-NOT: triton_gpu.convert_layout %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> - %b = triton_nvidia_gpu.dot_wait %a {pendings = 0 : i32} : tensor<16x2xf32, #blocked1> + %b = triton_nvidia_gpu.warp_group_dot_wait %a {pendings = 0 : i32} : tensor<16x2xf32, #blocked1> %c = triton_gpu.convert_layout %b : tensor<16x2xf32, #blocked1> -> tensor<16x2xf32, #blocked> tt.return %c : tensor<16x2xf32, #blocked> } diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 
ed24e5f58f45..82fc1ddf7b65 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -165,10 +165,10 @@ tt.func @update_kwidth_slice( module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A // CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: tt.dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> +// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdesc<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !tt.memdesc<128x64xf16, #shared1> - %r = tt.dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } @@ -181,10 +181,10 @@ tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdes module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A_fp8 // CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: tt.dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> +// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !tt.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !tt.memdesc<128x64xf8E5M2, #shared1> - %r = tt.dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> + %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir index bf15adbdb66c..9ed3646d92b2 100644 --- a/test/TritonGPU/fence-inserstion.mlir +++ b/test/TritonGPU/fence-inserstion.mlir @@ -12,7 +12,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %0 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared> %1 = triton_gpu.local_alloc %arg1 : (tensor<128x64xf16, #blocked2>) -> !tt.memdesc<128x64xf16, #shared1> // CHECK: triton_nvidia_gpu.fence_async_shared - %2 = tt.dot %0, %1, %cst : !tt.memdesc<128x128xf16, 
#shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %2 = triton_nvidia_gpu.warp_group_dot %0, %1, %cst : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> tt.return } } @@ -36,10 +36,10 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: triton_nvidia_gpu.fence_async_shared // CHECK: scf.for // CHECK-NOT: triton_nvidia_gpu.fence_async_shared - // CHECK: tt.dot + // CHECK: triton_nvidia_gpu.warp_group_dot scf.for %iv0 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { scf.for %iv1 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { - %2 = tt.dot %0, %1, %cst : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %2 = triton_nvidia_gpu.warp_group_dot %0, %1, %cst : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> } } tt.return diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir new file mode 100644 index 000000000000..ffb90026da26 --- /dev/null +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -0,0 +1,161 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @load_two_users + tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { + %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = 
tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: triton_gpu.local_store + // CHECK: scf.for + // CHECK: tt.dot + // CHECK: tt.dot + // CHECK: tt.load + // CHECK: triton_gpu.local_store + // CHECK: scf.yield + %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { + %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> + %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> + %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> + %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> + %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } + tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } +} + +// ----- + +// CHECK-LABEL: tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de +// CHECK-NOT: triton_gpu.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public 
@_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.get_program_id y : i32 + %3 = tt.load %arg3 : !tt.ptr + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %5 = tt.splat %1 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %6 = arith.addi %5, %4 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %8 = tt.splat %3 : i64 -> tensor<64x1xi64, #blocked> + %9 = arith.extsi %7 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %10 = arith.addi %8, %9 : tensor<64x1xi64, #blocked> + %11 = arith.extsi %arg5 : i32 to i64 + %12 = tt.splat %11 : i64 -> tensor<64x1xi64, #blocked> + %13 = arith.muli %10, %12 : tensor<64x1xi64, #blocked> + %14 = arith.muli %2, %arg5 : i32 + %15 = arith.extsi %14 : i32 to i64 + %16 = tt.splat %15 : i64 -> tensor<64x1xi64, #blocked> + %17 = arith.addi %13, %16 : tensor<64x1xi64, #blocked> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %22 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> + %23 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %24 = arith.muli %20, %22 : tensor<1x64xi32, #blocked> + %25 = arith.muli %21, %23 : tensor<1x64xi32, #blocked1> + %26 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x64xi64, #blocked> + %27 = arith.extsi %24 : tensor<1x64xi32, #blocked> to tensor<1x64xi64, #blocked> + %28 = arith.extsi %25 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> + %29 = tt.broadcast %27 : tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> + %30 = arith.addi %26, %29 : tensor<64x64xi64, #blocked> + %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %33 = tt.splat %3 : i64 -> tensor<32x1xi64, #blocked1> + %34 = arith.extsi %32 : tensor<32x1xi32, #blocked1> to tensor<32x1xi64, #blocked1> + %35 = arith.addi %33, %34 : tensor<32x1xi64, #blocked1> + %36 = tt.splat %11 : i64 -> tensor<32x1xi64, #blocked1> + %37 = arith.muli %35, %36 : tensor<32x1xi64, #blocked1> + %38 = tt.splat %15 : i64 -> tensor<32x1xi64, #blocked1> + %39 = arith.addi %37, %38 : tensor<32x1xi64, #blocked1> + %40 = tt.broadcast %39 : tensor<32x1xi64, 
#blocked1> -> tensor<32x64xi64, #blocked1> + %41 = tt.broadcast %28 : tensor<1x64xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %42 = arith.addi %40, %41 : tensor<32x64xi64, #blocked1> + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %47 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> + %48 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> + %49 = arith.muli %45, %47 : tensor<1x32xi32, #blocked1> + %50 = arith.muli %46, %48 : tensor<1x32xi32, #blocked> + %51 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %52 = arith.extsi %49 : tensor<1x32xi32, #blocked1> to tensor<1x32xi64, #blocked1> + %53 = arith.extsi %50 : tensor<1x32xi32, #blocked> to tensor<1x32xi64, #blocked> + %54 = tt.broadcast %52 : tensor<1x32xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %55 = arith.addi %51, %54 : tensor<32x32xi64, #blocked1> + %56 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> + %57 = tt.addptr %56, %30 : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi64, #blocked> + %58 = tt.splat %arg1 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked1> + %59 = tt.addptr %58, %42 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi64, #blocked1> + %60 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %61 = tt.addptr %60, %55 : tensor<32x32x!tt.ptr, #blocked1>, tensor<32x32xi64, #blocked1> + %62 = tt.load %57 : tensor<64x64x!tt.ptr, #blocked> + %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { + %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> + %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> + %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> + %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> + %77 = triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %78 = triton_gpu.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + scf.yield %79 : tensor<64x32xf32, #mma> + } + %64 = tt.broadcast %17 : tensor<64x1xi64, 
#blocked> -> tensor<64x32xi64, #blocked> + %65 = tt.broadcast %53 : tensor<1x32xi64, #blocked> -> tensor<64x32xi64, #blocked> + %66 = arith.addi %64, %65 : tensor<64x32xi64, #blocked> + %67 = tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> + %68 = tt.addptr %67, %66 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi64, #blocked> + %69 = triton_gpu.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> + tt.store %68, %69 : tensor<64x32x!tt.ptr, #blocked> + tt.return + } +} // end module diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 48fd5f22e870..c62c91d939a4 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -18,11 +18,11 @@ // CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc // CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[ASUB:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !tt.memdesc<2x128x32xf16, #shared, mutable> -> !tt.memdesc<128x32xf16, #shared, mutable> -// CHECK: %[[T_A0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, mutable> +// CHECK-DAG: %[[ASUB:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> +// CHECK: %[[T_A0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, #triton_gpu.shared_memory, mutable> // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK-DAG: %[[BSUB:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_B0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, mutable> +// CHECK: %[[T_B0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] // CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] @@ -303,8 +303,8 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, //// C-HECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] // // C-HECK: %[[MBARRIER_AB_ITER:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] // // C-HECK: triton_nvidia_gpu.mbarrier_wait %[[MBARRIER_AB_ITER]], {{.*}} -// // C-HECK: triton_nvidia_gpu.dot_async %[[arg_a0]], %[[arg_b0]], {{.*}} -// // C-HECK: triton_nvidia_gpu.dot_wait {{.*}} +// // C-HECK: triton_nvidia_gpu.warp_group_dot %[[arg_a0]], %[[arg_b0]], {{.*}} +// // C-HECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} // // C-HECK: %[[EMPTY_BARRIER_B_ITER_ARRIVE:.*]] = triton_nvidia_gpu.extract_mbarrier %[[EMPTY_BARRIER_B]][{{.*}}] // // 
C-HECK: triton_nvidia_gpu.mbarrier_arrive %[[EMPTY_BARRIER_B_ITER_ARRIVE]] // // C-HECK: %[[MBARRIER_AB_NEXT_ITER:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] @@ -332,9 +332,9 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // %a = tt.load %a_tileptr : !tt.ptr, 1> // %b = tt.load %b_tileptr : !tt.ptr, 1> // -// %sa = triton_gpu.local_alloc %a : (tensor<128x32xf16, #BA>) -> !tt.memdesc<128x32xf16, #SA> -// %sb = triton_gpu.local_alloc %b : (tensor<32x128xf16, #BB>) -> !tt.memdesc<32x128xf16, #SB> -// %c = triton_gpu_nvidia.dot_async %sa, %sb, %prev_c : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> +// %sa = triton_gpu.local_alloc %a : (tensor<128x32xf16, #BA>) -> !tt.memdesc<128x32xf16, #SA, #triton_gpu.shared_memory> +// %sb = triton_gpu.local_alloc %b : (tensor<32x128xf16, #BB>) -> !tt.memdesc<32x128xf16, #SB, #triton_gpu.shared_memory> +// %c = triton_nvidia_gpu.warp_group_dot %sa, %sb, %prev_c : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> // // %a_tileptr_next = tt.advance %a_tileptr, [%c0, %c32_i32] : !tt.ptr, 1> // %b_tileptr_next = tt.advance %b_tileptr, [%c32_i32, %c0] : !tt.ptr, 1> @@ -384,21 +384,21 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> // CHECK: scf.for // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} - // CHECK: triton_nvidia_gpu.dot_async - // CHECK-NEXT: triton_nvidia_gpu.dot_wait {{.*}} {pendings = 0 : i32} - // CHECK: triton_nvidia_gpu.dot_async + // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot // CHECK: triton_gpu.async_copy_global_to_local // CHECK: triton_gpu.async_commit_group // CHECK: scf.yield %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>) : i32 { %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %21 = tt.dot %19, %20, %cst_2 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %23 = tt.trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1> -> !tt.memdesc<16x64xf16, #shared> + %23 = tt.trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %25 = tt.dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * 
!tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked> } @@ -431,22 +431,22 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> // CHECK: scf.for // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} - // CHECK: triton_nvidia_gpu.dot_async - // CHECK-NEXT: triton_nvidia_gpu.dot_wait {{.*}} {pendings = 1 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 1 : i32} // CHECK: triton_gpu.async_copy_global_to_local // CHECK: triton_gpu.async_commit_group // CHECK: scf.if - // CHECK: triton_nvidia_gpu.dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} // CHECK: arith.mulf // CHECK: scf.yield // CHECK: scf.yield - // CHECK: triton_nvidia_gpu.dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %acc = tt.dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %cnd = arith.cmpi slt, %arg3, %ext : i32 %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> @@ -501,24 +501,24 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> // CHECK: %[[ALLOC1:.+]] = triton_gpu.local_alloc // CHECK: 
%[[ALLOC2:.+]] = triton_gpu.local_alloc // CHECK: %[[R:.+]]:{{.+}} = scf.for - // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.dot_async{{.*}} + // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.warp_group_dot{{.*}} // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} // CHECK: %[[TRANS:.+]] = tt.trans{{.*}} : !tt.memdesc - // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.dot_async{{.*}} %[[TRANS]] - // CHECK: triton_nvidia_gpu.dot_wait %[[DOT1]], %[[DOT2]], %[[ALLOC1]], %[[ALLOC2]], %[[TRANS]] {pendings = 2 : i32} + // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.warp_group_dot{{.*}} %[[TRANS]] + // CHECK: triton_nvidia_gpu.warp_group_dot_wait %[[DOT1]], %[[DOT2]], %[[ALLOC1]], %[[ALLOC2]], %[[TRANS]] {pendings = 2 : i32} // CHECK: scf.yield - // CHECK: %{{.*}}:2 = triton_nvidia_gpu.dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> + // CHECK: %{{.*}}:2 = triton_nvidia_gpu.warp_group_dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16, %arg6 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>) : i32 { - %21 = tt.dot %19, %20, %arg6 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %arg6 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1> -> !tt.memdesc<16x64xf16, #shared> - %25 = tt.dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> + %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> } @@ -576,13 +576,13 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %22:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %12, %arg6 = %21) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1>) : i32 { %35 = tt.load %arg5 : tensor<128x64x!tt.ptr, #blocked> %36 = tt.load %arg6 : tensor<64x256x!tt.ptr, #blocked1> - %37 = triton_gpu.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !tt.memdesc<128x64xf8E5M2, #shared> - %38 = triton_gpu.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !tt.memdesc<64x256xf8E5M2, #shared1> + %37 = triton_gpu.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !tt.memdesc<128x64xf8E5M2, #shared, #triton_gpu.shared_memory> + %38 = 
triton_gpu.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !tt.memdesc<64x256xf8E5M2, #shared1, #triton_gpu.shared_memory> // CHECK: triton_gpu.local_alloc // CHECK: scf.for - // CHECK: triton_nvidia_gpu.dot_async - // CHECK-NEXT: triton_nvidia_gpu.dot_wait - %39 = tt.dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !tt.memdesc<128x64xf8E5M2, #shared> * !tt.memdesc<64x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait + %39 = triton_nvidia_gpu.warp_group_dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !tt.memdesc<128x64xf8E5M2, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf8E5M2, #shared1, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> %40 = tt.addptr %arg5, %cst_6 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> %41 = tt.addptr %arg6, %cst_5 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> scf.yield %39, %40, %41 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1> @@ -656,35 +656,35 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> // CHECK: %[[LOOP:[^ :]+]]{{.*}} scf.for {{.*}} iter_args(%[[PREV_DOT2:[^ ]+]] - // CHECK-NOT: triton_nvidia_gpu.dot_wait - // CHECK: %[[DOT0:.+]] = triton_nvidia_gpu.dot_async - // CHECK-NOT: triton_nvidia_gpu.dot_wait - // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.dot_async - // CHECK-NEXT: triton_nvidia_gpu.dot_wait + // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait + // CHECK: %[[DOT0:.+]] = triton_nvidia_gpu.warp_group_dot + // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait + // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.warp_group_dot + // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait // CHECK-DAG-SAME: %[[DOT0]] // CHECK-DAG-SAME: %[[DOT1]] // CHECK-DAG-SAME: %[[PREV_DOT2]] // CHECK-SAME: {pendings = 0 : i32} - // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.dot_async - // CHECK-NOT: triton_nvidia_gpu.dot_wait + // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.warp_group_dot + // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait // CHECK: scf.yield %[[DOT2]] - // CHECK: triton_nvidia_gpu.dot_wait %[[LOOP]]#3, %[[LOOP]]#0 {pendings = 0 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot_wait %[[LOOP]]#3, %[[LOOP]]#0 {pendings = 0 : i32} %17:4 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%prev_dot2 = %cst_3, %arg5 = %16, %prev_dot1 = %cst_2, %prev_dot0 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { // This one can be async. 
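// (A note on the rule the dots below exercise, inferred from this test's own
// comments and CHECK lines rather than stated anywhere authoritative: an
// "async" dot here is one that can remain a triton_nvidia_gpu.warp_group_dot
// whose completion is awaited only by a later warp_group_dot_wait. Roughly,
// a dot qualifies when its accumulator result reaches scf.yield unmodified;
// any intervening use of the result forces a warp_group_dot_wait with
// pendings = 0 before that use.)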
- %dot0 = tt.dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %dot0 = triton_nvidia_gpu.warp_group_dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> // This can't be async because its result is modified before it's yielded. - %dot1 = tt.dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %dot1 = triton_nvidia_gpu.warp_group_dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %dot1.1 = arith.addf %dot1, %dot1 : tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1> -> !tt.memdesc<16x64xf16, #shared> + %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> // This dot can be async even though %prev_dot2 is not used directly by an // async dot, because that use follows the synchronous dot above. %prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> - %dot2 = tt.dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %dot2 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %dot2, %26, %dot1.1, %dot0 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 20d8093b03ae..6ed34a96082b 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -1,5 +1,6 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s --check-prefixes=COMMON,CHECK // RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 | FileCheck %s --check-prefix=CHECK-NOCANON +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 @@ -55,7 +56,52 @@ // CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], // CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { + +// AMD-LABEL: tt.func @matmul_loop +// AMD: %[[LOCAL_ALLOC_10:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_11:.*]] = triton_gpu.local_alloc +// AMD: 
%[[CMPI_12:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_13:.*]] = tt.splat %[[CMPI_12]] +// AMD: %[[LOAD_14:.*]] = tt.load %{{.*}}, %[[SPLAT_13]] +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_12]] +// AMD: %[[LOAD_16:.*]] = tt.load %{{.*}}, %[[SPLAT_15]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_17:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_10]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_14]], %[[MEMDESC_SUBVIEW_17]] +// AMD: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_11]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_16]], %[[MEMDESC_SUBVIEW_18]] +// AMD: %[[SUBI_19:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %[[SUBI_19]] step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_18]]) +// AMD: %[[LOCAL_LOAD_23:.*]] = triton_gpu.local_load %[[ARG10]] +// AMD: %[[LOCAL_LOAD_24:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[MULF_25:.*]] = arith.mulf %[[LOCAL_LOAD_24]], %{{.*}} +// AMD: %[[DOT_26:.*]] = tt.dot %[[LOCAL_LOAD_23]], %[[MULF_25]], %[[ARG8]] +// AMD: %[[ADDPTR_27:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// AMD: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// AMD: %[[LOAD_29:.*]] = tt.load %[[ADDPTR_27]] +// AMD: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_28]] +// AMD: %[[ADDI_31:.*]] = arith.addi %[[ARG9]], %{{.*}} +// AMD: %[[CMPI_32:.*]] = arith.cmpi slt, %[[ADDI_31]], %{{.*}} +// AMD: %[[SELECT_33:.*]] = arith.select %[[CMPI_32]], %[[ADDI_31]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_34:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_10]][%[[SELECT_33]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_29]], %[[MEMDESC_SUBVIEW_34]] +// AMD: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_11]][%[[SELECT_33]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_35]] +// AMD: scf.yield %[[ADDPTR_27]], %[[ADDPTR_28]], %[[DOT_26]], %[[SELECT_33]], %[[MEMDESC_SUBVIEW_34]], %[[MEMDESC_SUBVIEW_35]] +// AMD: } +// AMD: %[[CMPI_21:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[IF_22:.*]] = scf.if %[[CMPI_21]] +// AMD: %[[LOCAL_LOAD_23:.*]] = triton_gpu.local_load %{{.*}}#4 +// AMD: %[[LOCAL_LOAD_24:.*]] = triton_gpu.local_load %{{.*}}#5 +// AMD: %[[MULF_25:.*]] = arith.mulf %[[LOCAL_LOAD_24]], %{{.*}} +// AMD: %[[DOT_26:.*]] = tt.dot %[[LOCAL_LOAD_23]], %[[MULF_25]], %{{.*}}#2 +// AMD: scf.yield %[[DOT_26]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#2 +// AMD: } +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_10]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_11]] + +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -146,6 +192,29 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] // CHECK: triton_gpu.async_copy_global_to_local // CHECK scf.yield + +// AMD-LABEL: tt.func @matmul_loop_nested +// AMD: scf.for +// AMD-COUNT-2: triton_gpu.local_alloc +// AMD-COUNT-2: tt.load +// AMD: %[[SUBVIEW0:.*]] = triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW0]] +// AMD: %[[SUBVIEW1:.*]] = 
triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] +// AMD: %[[FOR:.*]]:6 = scf.for +// AMD-COUNT-2: triton_gpu.local_load +// AMD: tt.dot +// AMD-COUNT-2: tt.addptr +// AMD-COUNT-2: tt.load +// AMD: %[[SUBVIEW0:.*]] = triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW0]] +// AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] +// AMD: scf.yield +// AMD: %[[IF_23:.*]] = scf.if +// AMD-COUNT-2: triton_gpu.local_dealloc +// AMD: scf.yield %[[IF_23]] + tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ @@ -216,6 +285,39 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]] // CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] + +// AMD-LABEL: tt.func @matmul_loop_single_pipeline +// AMD: %[[LOAD_10:.*]] = tt.load %{{.*}} +// AMD: %[[CONVERT_LAYOUT_11:.*]] = triton_gpu.convert_layout %[[LOAD_10]] +// AMD: %[[LOCAL_ALLOC_12:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_13:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_13]] +// AMD: %[[LOAD_15:.*]] = tt.load %{{.*}}, %[[SPLAT_14]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_16:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] +// AMD: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:4 = scf.for %[[ARG5:.*]] = %{{.*}} to %[[SUBI_17]] step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[MEMDESC_SUBVIEW_16]]) +// AMD: %[[LOCAL_LOAD_21:.*]] = triton_gpu.local_load %[[ARG9]] +// AMD: %[[DOT_22:.*]] = tt.dot %[[CONVERT_LAYOUT_11]], %[[LOCAL_LOAD_21]], %[[ARG7]] +// AMD: %[[ADDPTR_23:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// AMD: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_23]] +// AMD: %[[ADDI_25:.*]] = arith.addi %[[ARG8]], %{{.*}} +// AMD: %[[CMPI_26:.*]] = arith.cmpi slt, %[[ADDI_25]], %{{.*}} +// AMD: %[[SELECT_27:.*]] = arith.select %[[CMPI_26]], %[[ADDI_25]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_28:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%[[SELECT_27]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_28]] +// AMD: scf.yield %[[ADDPTR_23]], %[[DOT_22]], %[[SELECT_27]], %[[MEMDESC_SUBVIEW_28]] +// AMD: } +// AMD: %[[CMPI_19:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[IF_20:.*]] = scf.if %[[CMPI_19]] +// AMD: %[[LOCAL_LOAD_21:.*]] = triton_gpu.local_load %{{.*}}#3 +// AMD: %[[DOT_22:.*]] = tt.dot %[[CONVERT_LAYOUT_11]], %[[LOCAL_LOAD_21]], %{{.*}}#1 +// AMD: scf.yield %[[DOT_22]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#1 +// AMD: } +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_12]] + tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -268,6 +370,36 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_2]] // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} + +// AMD-LABEL: 
tt.func @indirect_bmm_scalar +// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] +// AMD: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]] +// AMD: %[[MULI_6:.*]] = arith.muli %{{.*}}, %[[LOAD_5]] +// AMD: %[[SPLAT_7:.*]] = tt.splat %[[MULI_6]] +// AMD: %[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]] +// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] +// AMD: %[[CMPI_11:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_13:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_11]] +// AMD: %[[LOAD_15:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_14]] +// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_13]], %[[CMPI_11]] +// AMD: %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[LOAD_16]] +// AMD: %[[SPLAT_18:.*]] = tt.splat %[[MULI_17]] +// AMD: %[[ADDPTR_19:.*]] = tt.addptr %{{.*}}, %[[SPLAT_18]] +// AMD: %[[SPLAT_20:.*]] = tt.splat %[[CMPI_11]] +// AMD: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_19]], %[[SPLAT_20]] +// AMD: %[[MEMDESC_SUBVIEW_22:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_22]] +// AMD: %[[MEMDESC_SUBVIEW_23:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_23]] +// AMD: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[ADDPTR_12]], %[[ARG9:.*]] = %[[ADDPTR_13]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_22]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_23]], %[[ARG13:.*]] = %[[LOAD_15]], %[[ARG14:.*]] = %[[LOAD_21]]) + tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -293,7 +425,7 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr - } + } {tt.num_stages = 3 : i32} tt.return %79#0 : tensor<16x16xf32, #C> } @@ -313,6 +445,60 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[IND_BUFFER_0]] + +// AMD-LABEL: tt.func @indirect_bmm_scalar_dist_one +// AMD: %[[LOAD_0:.*]] = tt.load %{{.*}} +// AMD: %[[ADDPTR_1:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[LOCAL_ALLOC_2:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_3:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_4:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_5:.*]] = tt.splat %[[CMPI_4]] +// AMD: %[[LOAD_6:.*]] = tt.load %{{.*}}, %[[SPLAT_5]] +// AMD: %[[LOAD_7:.*]] = tt.load %[[ADDPTR_1]], %[[CMPI_4]] +// AMD: %[[MULI_8:.*]] = arith.muli %{{.*}}, %[[LOAD_0]] +// AMD: %[[SPLAT_9:.*]] = tt.splat %[[MULI_8]] +// AMD: %[[ADDPTR_10:.*]] = tt.addptr %{{.*}}, %[[SPLAT_9]] +// AMD: %[[SPLAT_11:.*]] = tt.splat %[[CMPI_4]] +// AMD: %[[LOAD_12:.*]] = tt.load %[[ADDPTR_10]], %[[SPLAT_11]] +// 
AMD: %[[ADDPTR_13:.*]] = tt.addptr %[[ADDPTR_1]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_14:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_2]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_6]], %[[MEMDESC_SUBVIEW_14]] +// AMD: %[[MEMDESC_SUBVIEW_15:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_3]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_12]], %[[MEMDESC_SUBVIEW_15]] +// AMD: %[[SUBI_16:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_16]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_13]], %[[ARG10:.*]] = %[[LOAD_7]], %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_14]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_15]]) +// AMD: %[[LOCAL_LOAD_20:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[LOCAL_LOAD_21:.*]] = triton_gpu.local_load %[[ARG13]] +// AMD: %[[DOT_22:.*]] = tt.dot %[[LOCAL_LOAD_20]], %[[LOCAL_LOAD_21]], %[[ARG7]] +// AMD: %[[ADDPTR_23:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_23]] +// AMD: %[[LOAD_25:.*]] = tt.load %[[ARG9]] +// AMD: %[[MULI_26:.*]] = arith.muli %{{.*}}, %[[ARG10]] +// AMD: %[[SPLAT_27:.*]] = tt.splat %[[MULI_26]] +// AMD: %[[ADDPTR_28:.*]] = tt.addptr %{{.*}}, %[[SPLAT_27]] +// AMD: %[[LOAD_29:.*]] = tt.load %[[ADDPTR_28]] +// AMD: %[[ADDPTR_30:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[ADDI_31:.*]] = arith.addi %[[ARG11]], %{{.*}} +// AMD: %[[CMPI_32:.*]] = arith.cmpi slt, %[[ADDI_31]], %{{.*}} +// AMD: %[[SELECT_33:.*]] = arith.select %[[CMPI_32]], %[[ADDI_31]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_34:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_2]][%[[SELECT_33]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_34]] +// AMD: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_3]][%[[SELECT_33]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_29]], %[[MEMDESC_SUBVIEW_35]] +// AMD: scf.yield %[[DOT_22]], %[[ADDPTR_23]], %[[ADDPTR_30]], %[[LOAD_25]], %[[SELECT_33]], %[[MEMDESC_SUBVIEW_34]], %[[MEMDESC_SUBVIEW_35]] +// AMD: } +// AMD: %[[CMPI_18:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[IF_19:.*]] = scf.if %[[CMPI_18]] +// AMD: %[[LOCAL_LOAD_20:.*]] = triton_gpu.local_load %{{.*}}#5 +// AMD: %[[LOCAL_LOAD_21:.*]] = triton_gpu.local_load %{{.*}}#6 +// AMD: %[[DOT_22:.*]] = tt.dot %[[LOCAL_LOAD_20]], %[[LOCAL_LOAD_21]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_22]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 +// AMD: } +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_2]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_3]] + tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -365,6 +551,52 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield + +// AMD-LABEL: tt.func @indirect_bmm_vector +// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] +// AMD: %[[CMPI_5:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_6:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_7:.*]] = tt.splat 
%[[CMPI_2]] +// AMD: %[[LOAD_8:.*]] = tt.load %{{.*}}, %[[SPLAT_7]] +// AMD: %[[EXPAND_DIMS_9:.*]] = tt.expand_dims %[[LOAD_4]] {axis = 1 : i32} +// AMD: %[[BROADCAST_10:.*]] = tt.broadcast %[[EXPAND_DIMS_9]] +// AMD: %[[MULI_11:.*]] = arith.muli %{{.*}}, %[[BROADCAST_10]] +// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %[[MULI_11]] +// AMD: %[[SPLAT_13:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_14:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_13]] +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_5]] +// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_6]], %[[SPLAT_15]] +// AMD: %[[MEMDESC_SUBVIEW_17:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_17]] +// AMD: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_14]], %[[MEMDESC_SUBVIEW_18]] +// AMD: %[[SUBI_19:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_19]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_18]], %[[ARG13:.*]] = %[[LOAD_16]]) +// AMD: %[[LOCAL_LOAD_47:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOCAL_LOAD_48:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[DOT_49:.*]] = tt.dot %[[LOCAL_LOAD_47]], %[[LOCAL_LOAD_48]], %[[ARG7]] +// AMD: %[[ADDPTR_50:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_51:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_52:.*]] = tt.load %[[ADDPTR_50]] +// AMD: %[[EXPAND_DIMS_53:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} +// AMD: %[[BROADCAST_54:.*]] = tt.broadcast %[[EXPAND_DIMS_53]] +// AMD: %[[MULI_55:.*]] = arith.muli %{{.*}}, %[[BROADCAST_54]] +// AMD: %[[ADDPTR_56:.*]] = tt.addptr %{{.*}}, %[[MULI_55]] +// AMD: %[[LOAD_57:.*]] = tt.load %[[ADDPTR_56]] +// AMD: %[[LOAD_58:.*]] = tt.load %[[ADDPTR_51]] +// AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} +// AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_52]], %[[MEMDESC_SUBVIEW_62]] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_57]], %[[MEMDESC_SUBVIEW_63]] +// AMD: scf.yield %[[DOT_49]], %[[ADDPTR_50]], %[[ADDPTR_51]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[MEMDESC_SUBVIEW_63]], %[[LOAD_58]] + tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -392,16 +624,16 @@ tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> - } + } {tt.num_stages = 3 : i32} tt.return %79#0 : tensor<16x16xf32, #C> } -// CHECK-LABEL: tt.func @post_load_inv -// CHECK: scf.for -// CHECK-DAG: %[[IV:.*]] = arith.index_cast -// CHECK: %[[NEXT_IV:.*]] = arith.addi 
%[[IV]], %c1_i32 : i32 -// CHECK: arith.index_cast -// CHECK-NOT: arith.addi %[[NEXT_IV]] +// COMMON-LABEL: tt.func @post_load_inv +// COMMON: scf.for +// COMMON-DAG: %[[IV:.*]] = arith.index_cast +// COMMON: %[[NEXT_IV:.*]] = arith.addi %[[IV]], %c1_i32 : i32 +// COMMON: arith.index_cast +// COMMON-NOT: arith.addi %[[NEXT_IV]] tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, @@ -452,11 +684,12 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %85#0 : tensor<32x32xf32, #C> } -// CHECK-LABEL: tt.func @cross_iter_dep +// COMMON-LABEL: tt.func @cross_iter_dep // TODO: enable pipelining with distance of 2 -// CHECK-NOT: triton_gpu.async_commit_group -// CHECK: scf.for -// CHECK: scf.yield +// COMMON-NOT: triton_gpu.async_commit_group +// COMMON: scf.for +// COMMON: scf.yield + tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, @@ -509,14 +742,14 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %119#0 : tensor<32x32xf32, #C> } -// CHECK-LABEL: tt.func @dep_arg_two_uses -// CHECK: tt.expand_dims -// CHECK: tt.expand_dims -// CHECK: tt.expand_dims %arg5 -// CHECK-NEXT: tt.expand_dims %arg5 -// CHECK: %[[PTR0:.*]] = tt.splat %arg6 -// CHECK: %[[PTR1:.*]] = tt.addptr %[[PTR0]] -// CHECK-NEXT: tt.load %[[PTR1]] +// COMMON-LABEL: tt.func @dep_arg_two_uses +// COMMON: tt.expand_dims +// COMMON: tt.expand_dims +// COMMON: tt.expand_dims %arg5 +// COMMON-NEXT: tt.expand_dims %arg5 +// COMMON: %[[PTR0:.*]] = tt.splat %arg6 +// COMMON: %[[PTR1:.*]] = tt.addptr %[[PTR0]] +// COMMON-NEXT: tt.load %[[PTR1]] tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -582,7 +815,7 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: tt.func @load_two_users tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> @@ -626,9 +859,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> 
!tt.memdesc<64x16xf16, #shared> - %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared> -> !tt.memdesc<16x64xf16, #shared1> - %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> + %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> + %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } @@ -643,8 +876,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}> #shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { -// CHECK-LABEL: tt.func @load_two_users_incompatible_layouts +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +// COMMON-LABEL: tt.func @load_two_users_incompatible_layouts tt.func @load_two_users_incompatible_layouts(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> @@ -671,8 +904,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK-NOT: triton_gpu.insert_slice_async - // CHECK: scf.for + // check that the load didn't get pipelined. 
+ // COMMON-NOT: alloc + // COMMON: scf.for %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> @@ -680,9 +914,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared> - %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared> -> !tt.memdesc<16x64xf16, #shared1> - %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> + %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> + %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } @@ -704,6 +938,15 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: triton_gpu.async_copy_global_to_local // CHECK: triton_gpu.async_commit_group // CHECK: scf.yield + +// AMD-LABEL: tt.func public @nested_loops +// AMD: scf.for +// AMD: triton_gpu.local_alloc +// AMD-NOT: triton_gpu.local_alloc +// AMD: scf.for +// AMD: scf.yield +// AMD-DIS: scf.yield + // // The following code has the structure: // @@ -717,17 +960,12 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // } // ``` // -// Only the outer for should be pipelined. The regression this tests -// causes an assertion to fail while pipelining the outer `for`, in -// particular while predicating the operations scheduled to be emitted -// in the prologue. -// -// We check that there is no allocation before the first occurrence of -// scf.for because that would mean that the first load `%a = load()` -// would be pipelined. +// For CUDA, we pipeline the inner loop first then pipeline the outer +// loop to prefetch the async copy after the inner loop. +// For HIP, we only pipeline the inner loop for now. 
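+// As a comment-only sketch of the two shapes described above (op names
+// follow the CHECK/AMD lines in this file; placement and loop bodies are
+// illustrative, not FileCheck-verified):
+//
+//   CUDA, both loops pipelined:
+//     for {
+//       triton_gpu.async_copy_global_to_local    // prologue for inner loop
+//       for {
+//         triton_gpu.async_wait
+//         tt.dot
+//         triton_gpu.async_copy_global_to_local  // prefetch next tile
+//         triton_gpu.async_commit_group
+//       }
+//     }
+//
+//   HIP, inner loop only:
+//     for {
+//       %buf = triton_gpu.local_alloc            // one alloc per outer iteration
+//       for { triton_gpu.local_load ... tt.dot ... triton_gpu.local_store }
+//       triton_gpu.local_dealloc %buf
+//     }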
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %cst_0 = arith.constant dense<320> : tensor<32x1xi32, #blocked> @@ -789,7 +1027,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> %c64_i32 = arith.constant 64 : i32 @@ -861,9 +1099,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared> - %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared> -> !tt.memdesc<64x32xf32, #shared1> - %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> + %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> + %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> %77 = 
triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> @@ -888,7 +1126,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: scf.for // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]] -// CHECK: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview {{.*}} : !tt.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], mutable> -> !tt.memdesc<16xi64, #[[$SHARED_LAYOUT]], mutable> +// CHECK: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview {{.*}} : !tt.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16xi64, #[[$SHARED_LAYOUT]], #triton_gpu.shared_memory, mutable> // CHECK: %[[IND_BUFFER_1:.*]] = triton_gpu.local_load %[[IND_BUFFER_0]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] @@ -896,13 +1134,39 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_4]] // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} + +// AMD-DIS: #[[$SHARED_LAYOUT:shared.*]] = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// AMD-LABEL: tt.func @indirect_load_shared_layout +// AMD: %{{.*}}:7 = scf.for %[[ARG6:[a-z0-9]*]] = +// AMD-SAME: iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}, %[[ARG14:.*]] = %{{.*}}) +// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG13]] +// AMD: %[[DOT_31:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[LOCAL_LOAD_28]], %[[ARG7]] +// AMD: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_32]] +// AMD: %[[EXPAND_DIMS_36:.*]] = tt.expand_dims %[[ARG14]] {axis = 1 : i32} +// AMD: %[[BROADCAST_37:.*]] = tt.broadcast %[[EXPAND_DIMS_36]] +// AMD: %[[MULI_38:.*]] = arith.muli %{{.*}}, %[[BROADCAST_37]] +// AMD: %[[ADDPTR_39:.*]] = tt.addptr %{{.*}}, %[[MULI_38]] +// AMD: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_39]] +// AMD: %[[LOAD_43:.*]] = tt.load %[[ADDPTR_33]] +// AMD: %[[ADDI_44:.*]] = arith.addi %[[ARG11]], %{{.*}} +// AMD: %[[CMPI_45:.*]] = arith.cmpi slt, %[[ADDI_44]], %{{.*}} +// AMD: %[[SELECT_46:.*]] = arith.select %[[CMPI_45]], %[[ADDI_44]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_47:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_46]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_35]], %[[MEMDESC_SUBVIEW_47]] +// AMD: %[[MEMDESC_SUBVIEW_48:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_46]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_41]], %[[MEMDESC_SUBVIEW_48]] +// AMD: scf.yield %[[DOT_31]], %[[ADDPTR_32]], %[[ADDPTR_33]], %[[SELECT_46]], %[[MEMDESC_SUBVIEW_47]], %[[MEMDESC_SUBVIEW_48]], %[[LOAD_43]] + #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> #C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #A = 
#triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -module attributes {"triton_gpu.target" = "cuda:86", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -930,7 +1194,7 @@ tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibilit %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> - } + } {tt.num_stages = 3 : i32} tt.return %79#0 : tensor<16x16xf32, #C> } } @@ -945,9 +1209,19 @@ tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibilit // CHECK: triton_gpu.async_copy_global_to_local // CHECK: triton_gpu.memdesc_subview // CHECK: tt.return + +// AMD-LABEL: @kernel_yield_constant +// AMD: tt.load +// AMD: triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store +// AMD: scf.for +// AMD: tt.load +// AMD: triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store +// AMD: tt.return #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:86", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func public @kernel_yield_constant(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %cst1 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #mma> @@ -1001,8 +1275,22 @@ module attributes {"triton_gpu.target" = "cuda:86", "triton_gpu.num-ctas" = 1 : // CHECK: %[[B1BUFFER:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] // CHECK: triton_gpu.async_copy_global_to_local {{.*}}, %[[B1BUFFER]] // CHECK: scf.for + +// AMD-LABEL: tt.func public @add_kernel +// AMD: %[[LOAD_11:.*]] = tt.load %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[LOAD_13:.*]] = tt.load %[[ADDPTR_12]], %{{.*}} +// AMD: %[[ADDI_14:.*]] = arith.addi %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[ADDI_14]] +// AMD: %[[ADDI_16:.*]] = arith.addi %[[SPLAT_15]], %{{.*}} +// AMD: %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_16]], %{{.*}} +// AMD: %[[ADDPTR_18:.*]] = tt.addptr %{{.*}}, %[[ADDI_16]] +// AMD: %[[LOAD_19:.*]] = tt.load %[[ADDPTR_18]], %[[CMPI_17]] +// AMD: %[[ADDPTR_20:.*]] = tt.addptr %{{.*}}, %[[ADDI_16]] 
+// AMD: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_20]], %[[CMPI_17]] +// AMD: scf.for #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 @@ -1026,7 +1314,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %15 = arith.addf %12, %14 : tensor<1024xf32, #blocked> %16 = tt.addptr %6, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> tt.store %16, %15, %10 : tensor<1024x!tt.ptr, #blocked> - }{tt.num_stages = 3 : i32} + } {tt.num_stages = 3 : i32} tt.return } } @@ -1067,11 +1355,22 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: %[[COMMIT_2:.*]] = triton_gpu.async_commit_group %[[ASYNC_COPY_5]] // CHECK: scf.yield %[[COMMIT_1]], %[[COMMIT_2]] // CHECK: triton_gpu.local_dealloc %[[BUFFER_1]] + +// AMD-LABEL: tt.func public @nested_loops +// AMD-NOT: triton_gpu.local_alloc +// AMD: scf.for +// AMD: triton_gpu.local_alloc +// AMD: scf.for +// AMD: triton_gpu.local_load +// AMD: tt.dot +// AMD: triton_gpu.local_store +// AMD: scf.yield +// AMD: triton_gpu.local_dealloc #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> #shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> #shared1 = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> %c1_i32 = arith.constant 1 : i32 @@ -1090,9 +1389,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %9 = tt.addptr %7, %8 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { %10 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> - %11 = triton_gpu.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !tt.memdesc<16x16xf32, #shared> - %12 = tt.trans %11 {order = array} : !tt.memdesc<16x16xf32, #shared> -> !tt.memdesc<16x16xf32, #shared1> - %13 = triton_gpu.local_load %12 : !tt.memdesc<16x16xf32, #shared1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %11 = triton_gpu.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> + %12 = tt.trans %11 {order = array} : !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf32, #shared1, 
#triton_gpu.shared_memory> + %13 = triton_gpu.local_load %12 : !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { %14 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> %15 = triton_gpu.convert_layout %14 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> @@ -1115,7 +1414,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : #blocked4 = #triton_gpu.blocked<{sizePerThread = [16, 2, 1], threadsPerWarp = [4, 1, 8], warpsPerCTA = [1, 1, 8], order = [1, 0, 2]}> #blocked5 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { tt.func public @int4_matmul_ampere( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32} @@ -1146,14 +1445,12 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %51 = tt.addptr %50, %47 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> // Check that both loads in the loop are pipelined. - // TODO(jlebar): https://github.com/triton-lang/triton/pull/3472 disables the - // relevant optimization. Once we've reenabled it, we can uncomment this test. // CHECK: scf.for - // COM: CHECK-NOT: tt.load + // CHECK-NOT: tt.load // CHECK: triton_gpu.async_copy_global_to_local - // COM: CHECK-NOT: tt.load - // COM: CHECK: triton_gpu.async_copy_global_to_local - // COM: CHECK-NOT: tt.load + // CHECK-NOT: tt.load + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK-NOT: tt.load // CHECK: scf.yield %54:3 = scf.for %arg9 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg10 = %cst_3, %arg11 = %41, %arg12 = %51) -> (tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>) : i32 { %78 = tt.load %arg11 : tensor<16x128x!tt.ptr, #blocked1> @@ -1182,7 +1479,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // This test triggered some failure in the verifier, so we only // included a simple check for the kernel name. 
-// CHECK-LABEL: @load_convert_layout +// COMMON-LABEL: @load_convert_layout #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> @@ -1192,7 +1489,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -1223,7 +1520,7 @@ tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> - } + } {tt.num_stages = 3 : i32} tt.return %79#0 : tensor<16x16xf32, #C> } } @@ -1233,10 +1530,10 @@ tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 // This test captured some ICE in MatmulLoopPipeline pass, so we only // included a simple check for the kernel name. -// CHECK-LABEL: @matmul_indirect_pipeline +// COMMON-LABEL: @matmul_indirect_pipeline #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { tt.func public @matmul_indirect_pipeline(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %c1_i32 = arith.constant 1 : i32 @@ -1269,18 +1566,18 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %23 = tt.dot %21, %22, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %24 = triton_gpu.convert_layout %23 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %11, %24 : tensor<32x32x!tt.ptr, #blocked> - } + } {tt.num_stages = 3 : i32} tt.return } } // ----- -// CHECK-LABEL: @dont_pipeline_128x1 -// CHECK-NOT: local_load{{.*}}128x1 +// COMMON-LABEL: @dont_pipeline_128x1 +// COMMON-NOT: local_load{{.*}}128x1 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module 
attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func public @dont_pipeline_128x1(%arg6: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %c128_i32 = arith.constant 128 : i32 @@ -1319,8 +1616,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // Check that the dependencies across ops of different nesting does not cause crash or // incorrect schedule that fails to pipeline. -// CHECK-LABEL: @matmul_nested_ops -// CHECK: triton_gpu.local_load +// COMMON-LABEL: @matmul_nested_ops +// COMMON: triton_gpu.local_load #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -1331,7 +1628,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}, @@ -1389,9 +1686,9 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, #mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: dot_prologue_epilogue - // CHECK: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // COMMON-LABEL: dot_prologue_epilogue + // COMMON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} tt.func @dot_prologue_epilogue(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> %cst2 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> @@ -1414,17 +1711,17 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK: %[[C0:.*]] = arith.constant 0 : i32 - // CHECK: scf.for %[[IND_VAR:.*]] = %[[C0]] - // CHECK-NOT load - // CHECK: %[[CND:.*]] = arith.cmpi slt, %[[IND_VAR]], %[[EXT]] - // CHECK: scf.if %[[CND]] - // CHECK: dot - // CHECK: scf.if %[[CND]] - // CHECK: arith.mulf - // CHECK: scf.yield - // CHECK-NOT: 
tt.addptr - // CHECK: scf.yield + // COMMON: %[[C0:.*]] = arith.constant 0 : i32 + // COMMON: scf.for %[[IND_VAR:.*]] = %[[C0]] + // COMMON-NOT: load + // COMMON: %[[CND:.*]] = arith.cmpi slt, %[[IND_VAR]], %[[EXT]] + // COMMON: scf.if %[[CND]] + // COMMON: dot + // COMMON: scf.if %[[CND]] + // COMMON: arith.mulf + // COMMON: scf.yield + // COMMON-NOT: tt.addptr + // COMMON: scf.yield %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> %cnd = arith.cmpi slt, %arg3, %ext : i32 @@ -1435,9 +1732,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : scf.yield %arg5 : tensor<64x16x!tt.ptr, #blocked> } %18 = tt.load %inc_ptr : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %acc = tt.dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> scf.yield %acc_zero : tensor<128x16xf32, #mma1> @@ -1461,7 +1758,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : #mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-NOCANON-LABEL: pipeline_downstream_dependencies // CHECK-NOCANON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} tt.func @pipeline_downstream_dependencies(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { @@ -1500,9 +1797,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, 
#shared1> - %acc = tt.dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %cnd = arith.cmpi slt, %arg3, %ext : i32 %if_ret:2 = scf.if %cnd -> (tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> @@ -1528,8 +1825,20 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[B:.*]] = triton_gpu.local_load // CHECK: arith.select {{.*}}, %[[B]], %[[CONSTANT]] +// AMD-LABEL: @masked_add_kernel +// AMD: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: scf.for +// AMD: arith.select +// AMD: arith.addf +// AMD: %[[A:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: %[[B:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] + #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func public @masked_add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 @@ -1566,11 +1875,11 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: @matmul_tma -// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x128x64xf16, #{{.+}}, mutable> -// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x64x256xf16, #{{.+}}, mutable> -// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3xi64, #{{.+}}, mutable> +// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x128x64xf16, #{{.+}}, #triton_gpu.shared_memory, mutable> +// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x64x256xf16, #{{.+}}, #triton_gpu.shared_memory, mutable> +// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3xi64, #{{.+}}, #triton_gpu.shared_memory, mutable> // CHECK-COUNT-3: 
triton_nvidia_gpu.init_barrier // CHECK-COUNT-4: triton_nvidia_gpu.async_tma_copy_global_to_local // CHECK: scf.for @@ -1586,10 +1895,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> %0:2 = scf.for %arg3 = %c0_i32 to %c256_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { %1 = tt.experimental_descriptor_load %arg0[%c0_i32, %arg5] : !tt.ptr -> tensor<128x64xf16, #blocked> - %2 = triton_gpu.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared> + %2 = triton_gpu.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> %3 = tt.experimental_descriptor_load %arg1[%arg5, %c0_i32] : !tt.ptr -> tensor<64x256xf16, #blocked1> - %4 = triton_gpu.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared> - %5 = tt.dot %2, %4, %arg4, inputPrecision = tf32 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x256xf16, #shared> -> tensor<128x256xf32, #mma> + %4 = triton_gpu.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> + %5 = triton_nvidia_gpu.warp_group_dot %2, %4, %arg4 { inputPrecision = 0 : i32 } : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> %6 = arith.addi %arg5, %c64_i32 : i32 scf.yield %5, %6 : tensor<128x256xf32, #mma>, i32 } diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir index 1e3d4d96708b..74fd2e05551b 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -110,18 +110,18 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %112 = tt.load %111 : tensor<64x128x!tt.ptr, #blocked> %113 = triton_gpu.local_alloc %38 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared> %114 = triton_gpu.local_alloc %90 : (tensor<128x64xf16, #blocked2>) -> !tt.memdesc<128x64xf16, #shared1> - %115 = tt.dot %113, %114, %cst :!tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %115 = triton_nvidia_gpu.warp_group_dot %113, %114, %cst :!tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> %116 = arith.truncf %115 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> %117 = triton_gpu.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !tt.memdesc<64x128xf16, #shared> %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> // The first dot gets converted to dot-async + wait. The second one // doesn't have a wait because the first wait is sufficient. 
- // CHECK: triton_nvidia_gpu.dot_async - // CHECK: triton_nvidia_gpu.dot_wait {{.*}} {pendings = 0 : i32} - // CHECK: triton_nvidia_gpu.dot_async - // CHECK-NOT: triton_nvidia_gpu.dot_wait + // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait // CHECK: scf.yield - %119 = tt.dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> + %119 = triton_nvidia_gpu.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %121 = arith.addf %120, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %122 = arith.extsi %c0_i32 : i32 to i64 diff --git a/test/TritonGPU/reduce-data-duplication.mlir b/test/TritonGPU/reduce-data-duplication.mlir index 7dd91df04187..e98a6108d758 100644 --- a/test/TritonGPU/reduce-data-duplication.mlir +++ b/test/TritonGPU/reduce-data-duplication.mlir @@ -2,7 +2,7 @@ // CHECK: #[[SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false} // CHECK: apply_swizzle -// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[SHARED]]> +// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[SHARED]], #triton_gpu.shared_memory> #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> diff --git a/test/TritonNvidiaGPU/membar.mlir b/test/TritonNvidiaGPU/membar.mlir index 95202f12ee7e..358f53fd7cd6 100644 --- a/test/TritonNvidiaGPU/membar.mlir +++ b/test/TritonNvidiaGPU/membar.mlir @@ -9,8 +9,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: init_barrier tt.func @init_barrier() { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, mutable> + %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> tt.return } } @@ -28,9 +28,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: inval_barrier tt.func @inval_barrier() { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.inval_barrier %alloc : !tt.memdesc<1xi64, #shared0, mutable> + %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + 
triton_nvidia_gpu.inval_barrier %alloc : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> tt.return } } @@ -48,9 +48,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: barrier_expect tt.func @barrier_expect(%pred : i1) { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.barrier_expect %alloc, 16384, %pred : <1xi64, #shared0, mutable> + %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.barrier_expect %alloc, 16384, %pred : <1xi64, #shared0, #triton_gpu.shared_memory, mutable> tt.return } } @@ -68,9 +68,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: wait_barrier tt.func @wait_barrier(%phase : i32) { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.wait_barrier %alloc, %phase : <1xi64, #shared0, mutable> + %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.wait_barrier %alloc, %phase : <1xi64, #shared0, #triton_gpu.shared_memory, mutable> tt.return } } @@ -89,8 +89,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: gpu.barrier // CHECK-NEXT: init_barrier %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, mutable> - triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, mutable> + %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.ptr -> tensor<128x64xf16, #blocked0> tt.return %l : tensor<128x64xf16, #blocked0> } @@ -108,8 +108,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: triton_gpu.local_alloc tt.func public @tma_store(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked0>) { %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, mutable> - triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, mutable> + %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.ptr, 
tensor<128x256xf32, #blocked0> tt.return }
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index bdf2f863b3d6..9b0aa68f1d66 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -27,6 +27,7 @@ class HIPOptions: allowed_dot_input_precisions: Tuple[str] = ("ieee", ) enable_fp_fusion: bool = True matrix_instr_nonkdim: int = 0 + enable_moe_lds_bypass: bool = False kpack: int = 1 allow_flush_denorm: bool = False max_num_imprecise_acc_default: int = 0 @@ -105,6 +106,7 @@ def path_to_rocm_lld(): @staticmethod def make_ttir(mod, metadata, options): + mod.context.enable_moe_lds_bypass(options.enable_moe_lds_bypass) pm = ir.pass_manager(mod.context) pm.enable_debug() passes.common.add_inliner(pm) @@ -133,14 +135,28 @@ def make_ttgir(mod, metadata, options): amd.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack) passes.ttgpuir.add_remove_layout_conversions(pm) amd.passes.ttgpuir.add_optimize_epilogue(pm) + + if options.enable_moe_lds_bypass: + amd.passes.ttgpuir.add_tritongpu_bypass_lds_for_dot_layout_pass(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, True) - if options.num_stages == 0 and amd.has_matrix_core_feature(options.arch): - amd.passes.ttgpuir.add_stream_pipeline(pm) + use_new_pipeliner = os.getenv("TRITON_HIP_USE_NEW_STREAM_PIPELINE", "0") == "1" + if amd.has_matrix_core_feature(options.arch): + if use_new_pipeliner: + # The old pipeliner only supports num_stages = 0/1, whose meaning differs + # from the NVIDIA side. The new pipeliner unifies the num_stages + # interpretation; default to 2 stages if not explicitly set. + num_stages = options.num_stages if options.num_stages != 0 else 2 + amd.passes.ttgpuir.add_stream_pipelinev2(pm, num_stages) + else: + if options.num_stages == 0: + amd.passes.ttgpuir.add_stream_pipeline(pm) passes.common.add_canonicalizer(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, True) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_reduce_data_duplication(pm) - if options.num_stages != 0: + if use_new_pipeliner or options.num_stages != 0: amd.passes.ttgpuir.add_reorder_instructions(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm)
diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index e7a9753b2145..47bebfa61a2d 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -8,6 +8,8 @@ namespace mlir { std::unique_ptr createTritonAMDGPUStreamPipelinePass(); +std::unique_ptr createTritonAMDGPUStreamPipelineV2Pass(int numStages = 2); + std::unique_ptr createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(), int matrixInstructionSize = 0, @@ -20,6 +22,7 @@ std::unique_ptr createTritonAMDGPUReorderInstructionsPass(); std::unique_ptr createTritonAMDGPUVerifier(); std::unique_ptr createTritonAMDGPUOptimizeEpiloguePass(); +std::unique_ptr createTritonAMDGPUBypassLDSForDotLayout();
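The Python driver above selects between the two pipeliners when building TTGIR. For illustration, the same selection can be expressed directly against the C++ factories declared in this header; a minimal sketch, assuming the caller already owns a `PassManager` and has resolved the environment toggle (`useNewPipeliner` and `numStages` are hypothetical parameters, not part of this patch):

```cpp
#include "TritonAMDGPUTransforms/Passes.h"
#include "mlir/Pass/PassManager.h"

// Sketch only: mirrors the make_ttgir() logic above using the C++ factories.
static void addAMDPipelinerPasses(mlir::PassManager &pm, bool useNewPipeliner,
                                  int numStages) {
  if (useNewPipeliner) {
    // The V2 pipeliner unifies num_stages with the NVIDIA interpretation;
    // treat 0 as "use the default of 2".
    pm.addPass(mlir::createTritonAMDGPUStreamPipelineV2Pass(
        numStages != 0 ? numStages : 2));
  } else if (numStages == 0) {
    // The old pipeliner only runs for num_stages = 0.
    pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass());
  }
}
```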
/// Generate the code for registering passes. #define GEN_PASS_REGISTRATION
diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index a818b1ac9da5..8a1ae6f7b76d 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -16,6 +16,25 @@ def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::ModuleOp"> { let dependentDialects = []; } +def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir::ModuleOp"> { + let summary = "Pipeline global loads to shared memory"; + + let description = [{ + Pipeline global loads through registers to shared memory while computing on the + previous tile. + }]; + + let constructor = "mlir::createTritonAMDGPUStreamPipelineV2Pass()"; + + let dependentDialects = []; + + let options = [ + Option<"numStages", "num_stages", + "int32_t", /*default*/"2", + "Number of pipeline stages"> + ]; +} + def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir::ModuleOp"> { let summary = "accelerate matmul"; @@ -53,6 +72,18 @@ def TritonAMDGPUOptimizeEpilogue : Pass<"tritonamdgpu-optimize-epilogue", "mlir: } +def TritonAMDGPUBypassLDSForDotLayout: Pass<"tritonamdgpu-bypass-lds-for-dot-layout", "mlir::ModuleOp"> { + let summary = "Bypass moving data through LDS for dot layout when possible"; + + let description = [{ + Bypass moving data through LDS for dot layout when possible. + }]; + + let constructor = "mlir::createTritonAMDGPUBypassLDSForDotLayout()"; + + let dependentDialects = []; +} + def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", "mlir::ModuleOp"> { let summary = "Reorder instructions";
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 953b01dab08a..9978f94695ae 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -125,6 +125,21 @@ struct ConvertLayoutOpConversion Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); + if (isMoeLDSBypass() && isBlockedToDotShortcut(srcTy, dstTy)) { + auto srcElems = triton::gpu::getTotalElemsPerThread(srcTy); + auto dstElems = triton::gpu::getTotalElemsPerThread(dstTy); + if (srcElems != dstElems) { + llvm::errs() << "incompatible layout conversion: " << op << "\n"; + } + assert(srcElems == dstElems); + auto loc = op.getLoc(); + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value view = + packLLElements(loc, getTypeConverter(), vals, rewriter, dstTy); + rewriter.replaceOp(op, view); + return success(); + } + if (isa(srcLayout) && isa(dstLayout)) { return lowerMfmaToDotOperand(op, adaptor, rewriter);
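The guard in the shortcut above prints a diagnostic and then asserts on a per-thread element-count mismatch. A more defensive formulation (a sketch, not part of the patch; `canBypassLDS` is a hypothetical helper) would simply decline the shortcut and fall back to the generic shared-memory lowering:

```cpp
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Sketch: take the blocked->dot shortcut only when it is provably safe,
// i.e. both layouts hold the same number of elements in each thread.
static bool canBypassLDS(mlir::RankedTensorType srcTy,
                         mlir::RankedTensorType dstTy) {
  using mlir::triton::gpu::getTotalElemsPerThread;
  return getTotalElemsPerThread(srcTy) == getTotalElemsPerThread(dstTy);
}
```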
"triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +using namespace mlir; +namespace ttg = triton::gpu; + +// // convert(val) : mma -> blocked +// // tt.store(ptr, val, mask, ...) : blocked +// // ==> +// // convert(ptr) : blocked -> mma +// // convert(mask) : blocked -> mma +// // tt.store(ptr, val, mask, ...) : mma +// // +// // Store with mma layout directly + +// static Type getNewType(Type type, Attribute encoding) { +// RankedTensorType tensorType = typecast(); +// return RankedTensorType::get(tensorType.getShape(), +// tensorType.getElementType(), encoding); +// } + +// void convertLayout(Attribute encoding, Operation *op) { +// OpBuilder builder(op); +// // Convert operands +// // For load/store with tensor pointers, we don't have to change the +// // operands' type, we do this by changing the outputs' type of +// // `make_tensor_ptr` +// SmallVector newArgs; +// for (auto operand : op->getOperands()) { +// auto tensorType = operand.getType().dyn_cast(); +// if (tensorType && +// !tensorType.getEncoding().isa()) { +// Type newType = getNewType(tensorType, encoding); +// newArgs.push_back( +// builder.create(op->getLoc(), newType, +// operand)); +// } else { +// newArgs.push_back(operand); +// } +// } + +// // Convert output types +// SmallVector newTypes; +// for (auto t : op->getResultTypes()) { +// bool isAsync = isa(op); +// newTypes.push_back(isAsync ? 
+ +static Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); +} + +void convertLayout(Attribute encoding, Operation *op) { + OpBuilder builder(op); + // Convert operands. + // For load/store with tensor pointers, we don't have to change the + // operands' types; we do this by changing the output type of + // `make_tensor_ptr` instead. + SmallVector newArgs; + for (auto operand : op->getOperands()) { + auto tensorType = dyn_cast(operand.getType()); + if (tensorType && + !isa(tensorType.getEncoding())) { + Type newType = getNewType(tensorType, encoding); + newArgs.push_back(builder.create( + op->getLoc(), newType, operand)); + } else { + newArgs.push_back(operand); + } + } + + // Convert output types. + SmallVector newTypes; + for (auto t : op->getResultTypes()) { + bool isAsync = isa(op); + newTypes.push_back(isAsync ? t : getNewType(t, encoding)); + } + + // Construct a new op with the new encoding. + Operation *newOp = builder.create(op->getLoc(), op->getName().getIdentifier(), + newArgs, newTypes, op->getAttrs()); + + // Cast the results back to the original layout. + for (size_t i = 0; i < op->getNumResults(); i++) { + Value newResult = newOp->getResult(i); + if (newTypes[i] != op->getResultTypes()[i]) { + newResult = builder.create( + op->getLoc(), op->getResult(i).getType(), newResult); + } + op->getResult(i).replaceAllUsesWith(newResult); + } + op->erase(); +} + +triton::LoadOp getLoadInst(Operation *op, ModuleOp &mod) { + SmallVector loadOpsVec; + + mod.walk([&](triton::LoadOp loadOp) { + SetVector forwardSlices; + getForwardSlice((Operation *)loadOp, &forwardSlices); + if (std::find(forwardSlices.begin(), forwardSlices.end(), op) != + forwardSlices.end()) { + loadOpsVec.push_back(loadOp); + } + }); + + // Currently, we expect the dot operand to depend only on one tensor + // from global memory (applicable for dot ops that don't depend on other dot + // ops). This condition can be lifted if necessary. + return loadOpsVec.back(); +}
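getLoadInst above searches by computing a forward slice from every load in the module and testing membership, which visits the IR once per load. An equivalent single-traversal formulation (a sketch under the same single-feeding-load assumption, not part of the patch) walks the backward slice of the convert op instead:

```cpp
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

// Sketch: find the tt.load feeding `op` via one backward slice instead of a
// forward slice per load. Keeps the "last load wins" behavior of getLoadInst.
static mlir::triton::LoadOp findFeedingLoad(mlir::Operation *op) {
  llvm::SetVector<mlir::Operation *> backward;
  mlir::BackwardSliceOptions options;
  options.inclusive = false;
  mlir::getBackwardSlice(op, &backward, options);
  mlir::triton::LoadOp result;
  for (mlir::Operation *def : backward)
    if (auto load = llvm::dyn_cast<mlir::triton::LoadOp>(def))
      result = load; // mirrors loadOpsVec.back()
  return result;
}
```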
+ +class BypassLDSForDotLayout : public mlir::RewritePattern { + +public: + explicit BypassLDSForDotLayout(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), + 1, context) {} + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + + auto cvtOp = dyn_cast(op); + auto mod = op->getParentOfType(); + + if (!cvtOp) + return mlir::failure(); + + auto srcType = cast(cvtOp.getOperand().getType()); + auto dstType = cast(cvtOp.getType()); + + if (srcType.getShape().size() != 2) { + return mlir::failure(); + } + + auto srcBlocked = + dyn_cast(srcType.getEncoding()); + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + + if (!(srcBlocked && dstDotOp)) { + return mlir::failure(); + } + + int kDim = dstDotOp.getOpIdx() == 0 ? 1 : 0; + int nonKDim = dstDotOp.getOpIdx() == 0 ? 0 : 1; + + if (dstDotOp.getOpIdx() != 1) { + return mlir::failure(); + } + auto numWarps = triton::gpu::getNumWarpsPerCTA(srcBlocked); + auto numThreads = triton::gpu::getWarpSize(srcBlocked); + if (numThreads != 64) { + return mlir::failure(); + } + + SmallVector newWarpsPerCTA(2); + SmallVector newSizePerThread(2); + SmallVector newThreadsPerWarp(2); + SmallVector newOrder(2); + + // Should we use only one configuration? + auto shape = dstType.getShape(); + newOrder[0] = 1; + newOrder[1] = 0; + + const unsigned targetLoadBitWidth = 128; + const unsigned elemBitWidth = + srcType.getElementType().getIntOrFloatBitWidth(); + + switch (shape[kDim]) { + case 1024: + newSizePerThread[0] = targetLoadBitWidth / elemBitWidth; + newSizePerThread[1] = 1; + newThreadsPerWarp[0] = 64; + newThreadsPerWarp[1] = 1; + newWarpsPerCTA[0] = 1; + newWarpsPerCTA[1] = numWarps; + newOrder[0] = 0; + newOrder[1] = 1; + break; + case 512: + newSizePerThread[0] = targetLoadBitWidth / elemBitWidth; + newSizePerThread[1] = 1; + newThreadsPerWarp[0] = shape[kDim] / newSizePerThread[0]; + newThreadsPerWarp[1] = 64 / newThreadsPerWarp[0]; + newWarpsPerCTA[0] = 1; + newWarpsPerCTA[1] = numWarps; + newOrder[0] = 0; + newOrder[1] = 1; + break; + case 256: + assert(elemBitWidth == 8 && "8 bit dtype should use BLOCK_K=256"); + newSizePerThread[0] = targetLoadBitWidth / elemBitWidth; + newSizePerThread[1] = 1; + newThreadsPerWarp[0] = 16; + newThreadsPerWarp[1] = 4; + newWarpsPerCTA[0] = 1; + newWarpsPerCTA[1] = numWarps; + newOrder[0] = 0; + newOrder[1] = 1; + break; + case 128: + assert(elemBitWidth == 16 && "16 bit dtype should use BLOCK_K=128"); + newSizePerThread[0] = targetLoadBitWidth / elemBitWidth; + newSizePerThread[1] = 1; + newThreadsPerWarp[0] = 16; + newThreadsPerWarp[1] = 4; + newWarpsPerCTA[0] = 1; + newWarpsPerCTA[1] = numWarps; + newOrder[0] = 0; + newOrder[1] = 1; + break; + case 64: + case 32: + case 16: + assert(false && "BLOCK_K must be 128 for fp16 and 256 for int8/fp8"); + default: + return failure(); + } + + auto newBlockedEncoding = triton::gpu::BlockedEncodingAttr::get( + mod.getContext(), newSizePerThread, newThreadsPerWarp, newWarpsPerCTA, + newOrder, srcBlocked.getCTALayout()); + + auto loadInst = getLoadInst(cvtOp, mod); + + auto loadType = dyn_cast(loadInst.getResult().getType()); + if (!loadType || loadType.getEncoding() == newBlockedEncoding) + return failure(); + + assert(loadType.getElementType().getIntOrFloatBitWidth() == elemBitWidth && + "data type unexpectedly changing bitwidth between load and dot"); + + convertLayout(newBlockedEncoding, (Operation *)loadInst); + if (failed(mlir::verify(mod))) { + assert(false); + } + return mlir::success(); + } +};
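To make the switch above concrete, here is the arithmetic for its BLOCK_K = 128 (f16) and BLOCK_K = 256 (i8/fp8) cases, which apply to the B operand (opIdx = 1): each lane issues one 128-bit global load of K-contiguous elements, and a 64-lane wave covers BLOCK_K with 16 lanes along K and 4 along N. A self-contained, illustrative sketch (`bypassLayoutForB` is a hypothetical helper, not part of the pass):

```cpp
#include <array>
#include <cassert>

// Sketch of the layout arithmetic in the switch above.
struct BlockedLayout {
  std::array<unsigned, 2> sizePerThread, threadsPerWarp, warpsPerCTA, order;
};

BlockedLayout bypassLayoutForB(unsigned blockK, unsigned elemBitWidth,
                               unsigned numWarps) {
  const unsigned targetLoadBitWidth = 128;
  assert((blockK == 128 && elemBitWidth == 16) ||
         (blockK == 256 && elemBitWidth == 8));
  // 8 f16 elements per lane, or 16 i8/fp8 elements per lane.
  unsigned perLane = targetLoadBitWidth / elemBitWidth;
  // 16 lanes x perLane elements covers blockK along dim 0 (K for opIdx = 1).
  return {{perLane, 1}, {16, 4}, {1, numWarps}, {0, 1}};
}
```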
+ +class TritonAMDGPUBypassLDSForDotLayoutPass + : public TritonAMDGPUBypassLDSForDotLayoutBase< + TritonAMDGPUBypassLDSForDotLayoutPass> { + +public: + TritonAMDGPUBypassLDSForDotLayoutPass() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + mlir::RewritePatternSet patterns(context); + + patterns.add(context); + + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } + } +}; + +std::unique_ptr mlir::createTritonAMDGPUBypassLDSForDotLayout() { + return std::make_unique(); +}
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 86505386cd0c..de132f04e831 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -60,6 +60,9 @@ SmallVector warpsPerTile(tt::DotOp dotOp, SmallVector shapePerWarp) { auto rank = shape.size(); // Early exit for batched matmul + if (triton::isMoeLDSBypass()) + return {1, static_cast(numWarps)}; + if (rank == 3) return {(unsigned)numWarps, 1, 1};
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index d96860c3ef90..81939eac5960 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -3,8 +3,9 @@ add_triton_library(TritonAMDGPUTransforms OptimizeEpilogue.cpp ReorderInstructions.cpp StreamPipeline.cpp + StreamPipelineV2.cpp MfmaGroup.cpp - + AMDBypassLDSForDotLayout.cpp DEPENDS TritonAMDGPUTransformsIncGen TritonGPUIR
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index f9fac1bf5b0d..9371c8b5f897 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -2,88 +2,377 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dominance.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/Passes.h" -#include "mlir/Transforms/RegionUtils.h" -#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Passes.h" -#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#define GEN_PASS_CLASSES -#include "TritonAMDGPUTransforms/Passes.h" +#include "llvm/ADT/STLExtras.h" using namespace mlir; +namespace ttg = mlir::triton::gpu; + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// -static bool 
willIncreaseRegisterPressure(Operation *op) { - if (isa(op)) - return true; - auto cvt = dyn_cast(op); - if (!cvt) - return false; - if (isa(cvt.getType().getEncoding())) - return true; - return false; +// Return true if the given moduleOp contains a pure matmul problem; i.e., +// single dot in the main loop. +static bool isPureMatmulProblem(ModuleOp moduleOp) { + bool isMatmul = true; + bool foundLoop = false; + moduleOp.walk([&](scf::ForOp forOp) -> void { + int counter = 0; + forOp.walk([&counter](triton::DotOp dotOp) { ++counter; }); + isMatmul = (isMatmul && (counter == 1)); + foundLoop = true; + }); + return foundLoop && isMatmul; } -class TritonAMDGPUReorderInstructionsPass - : public TritonAMDGPUReorderInstructionsBase< - TritonAMDGPUReorderInstructionsPass> { -public: - TritonAMDGPUReorderInstructionsPass() = default; +// Search through block to find earliest insertion point for move op. This can +// be either an atomic op or last usage of source pointer. Search ends when move +// op is encountered. +static llvm::ilist::iterator +findEarlyInsertionPoint(Block *block, Operation *move) { + Value src; + if (auto ld = dyn_cast(move)) + src = ld.getPtr(); - void runOnOperation() override { - ModuleOp m = getOperation(); - mlir::DominanceInfo dom(m); - // Sink conversions into loops when they will increase - // register pressure - DenseMap opToMove; - auto moveAfter = [](Operation *lhs, Operation *rhs) { - lhs->moveAfter(rhs); - }; - m.walk([&](Operation *op) { - if (!willIncreaseRegisterPressure(op)) - return; - auto user_begin = op->user_begin(); - auto user_end = op->user_end(); - if (std::distance(user_begin, user_end) != 1) - return; - if (user_begin->getParentOfType() == - op->getParentOfType()) - return; - opToMove.insert({op, *user_begin}); - }); - for (auto &kv : opToMove) - kv.first->moveBefore(kv.second); - // Move LocalLoadOp and LocalAllocOp immediately after their operands. - m.walk([&](Operation *op) { - if (!isa(op)) { - return; + auto ipnt = block->end(); + for (auto bi = block->begin(); bi != block->end(); ++bi) { + auto *op = &*bi; + if (op == move) // Don't move later than current location + break; + + op->walk([&](Operation *wop) { + if (src) { + // Check for ops accessing src value. + for (auto opr : wop->getOperands()) { + if (opr == src) + ipnt = bi; + } } - Operation *argOp = op->getOperand(0).getDefiningOp(); - if (!argOp) - return; - moveAfter(op, argOp); + // Atomics used for global synchronization. + if (isa(wop)) + ipnt = bi; + // Break at barrier + if (isa(wop)) + ipnt = bi; + // Break at loops. + if (isa(wop)) + ipnt = bi; }); - // Move transpositions just after their definition - opToMove.clear(); - m.walk([&](triton::TransOp op) { - Operation *argOp = op.getSrc().getDefiningOp(); - if (!argOp) + } + return ipnt; +} + +// Return the first user in the same block of the given op. If the user is in a +// nested block then return the op owning the block. Return nullptr if not +// existing. +static Operation *getFirstUseInSameBlock(Operation *op) { + SmallVector usersInSameBlock; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + usersInSameBlock.push_back(ancestor); + } + auto minOpIt = + llvm::min_element(usersInSameBlock, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != usersInSameBlock.end() ? *minOpIt : nullptr; +} + +// Check if the operation opInsideLoop is inside any scf::ForOp and +// opOutsideLoop is not inside the same loop. 
+static bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, + mlir::Operation *opOutsideLoop) { + scf::ForOp parentForOp = opInsideLoop->getParentOfType(); + return parentForOp && !parentForOp->isAncestor(opOutsideLoop); +} + +//===----------------------------------------------------------------------===// +// Reorder mechanisms +//===----------------------------------------------------------------------===// + +// Sink dot layout conversions into loops to decrease register pressure when +// possible. +static void sinkDotConversion(ModuleOp moduleOp) { + DenseMap opToMove; + moduleOp.walk([&](ttg::ConvertLayoutOp op) { + Attribute encoding = op.getType().getEncoding(); + if (!isa_and_nonnull(encoding)) + return; + if (!op->hasOneUse()) + return; + Operation *user = *op->getUsers().begin(); + if (user->getParentOfType() == + op->getParentOfType()) + return; + opToMove[op] = user; + }); + + for (auto &kv : opToMove) + kv.first->moveBefore(kv.second); +} + +// Adjust the placement of shared memory writes and reads so that they +// immediately follow the definition of their operands, for the case where the +// shared memory write is in the loop but its operand is not. +// +// This is a heuristic motivated by fused attention: the Q tensor's shared +// memory read/write operations are hoisted out of the loop, since Q is loop +// invariant and can be loaded once before entering the loop. But it should be +// generally applicable. +// +// There are two possible patterns for this adjustment, depending on whether +// the write to shared memory is performed using an optional `local_alloc` +// argument or a `local_store` instruction. +// +// 1) %1 = some_op ... (typically a load or an operation that scales the tensor +// after loading) +// %2 = local_alloc %1 +// %3 = local_load %2 +// +// 2) %1 = some_op ... +// %2 = local_alloc +// %3 = local_store %1, %2 +// %4 = local_load %2 +static void hoistLocalLoad(ModuleOp moduleOp) { + moduleOp.walk([&](ttg::LocalLoadOp localLoad) { + auto localAlloc = localLoad.getSrc().getDefiningOp(); + if (!localAlloc) + return; + + // Case when localAlloc has operands + if (localAlloc->getNumOperands() == 1) { + if (!localAlloc->hasOneUse()) + return; + + auto srcTensorOp = localAlloc.getSrc().getDefiningOp(); + // Check if localAlloc is in the loop but its src tensor's defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) + return; + + localAlloc->moveAfter(srcTensorOp); + localLoad->moveAfter(localAlloc); + return; + } + + // Case when localAlloc has no operands + assert(localAlloc->getNumOperands() == 0); + auto allocVal = localAlloc->getResult(0); + + // Check if the localAlloc has exactly two uses (localStore and localLoad) + int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); + if (numUses != 2) + return; + + // localStore comes before localLoad in the block. + Operation *localStore = getFirstUseInSameBlock(localAlloc); + if (!isa(localStore)) + return; + + auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); + // Check if localStore is in the loop but its src tensor's defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { + return; + } + + localAlloc->moveAfter(srcTensorOp); + localStore->moveAfter(localAlloc); + localLoad->moveAfter(localStore); + }); +}
+static void moveDownCoversion(ModuleOp moduleOp) { + SmallVector convertOps; + moduleOp.walk([&](ttg::ConvertLayoutOp op) { convertOps.push_back(op); }); + + for (auto op : convertOps) { + Operation *user = getFirstUseInSameBlock(op); + for (auto it = Block::iterator(op), ie = op->getBlock()->end(); + it != ie && &*it != user; ++it) + if (isa(&*it)) + op->moveAfter(&*it); + } +} + +// Move transpositions just after their definition. +static void moveUpTranspose(ModuleOp moduleOp) { + SmallVector transOps; + moduleOp.walk([&](triton::TransOp op) { transOps.push_back(op); }); + + for (auto op : transOps) + if (Operation *argOp = op.getSrc().getDefiningOp()) + op->moveAfter(argOp); +} + +// Schedule global load and local store ops for better GEMM performance. +static void scheduleGlobalLoadLocalStore(ModuleOp m) { + SmallVector moveOps; + // Move global loads early to prefetch. This may increase register pressure + // but it enables issuing global loads early. + m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); + // Move local_stores early if the dependence distance is greater than one + // iteration. GEMM performs best when these precede global loads. + m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); + + for (auto op : llvm::reverse(moveOps)) { + // Gather the use-def chain in the block. + Block *block = op->getBlock(); + bool leadsToLoad = false; + SetVector backwardSet; + + BackwardSliceOptions options; + options.omitBlockArguments = true; + options.inclusive = false; + options.filter = [&](Operation *defOp) -> bool { + Block *defBlock = defOp->getBlock(); + if (!block->findAncestorOpInBlock(*defOp)) + return false; + // Check for a `load` dependent path. + leadsToLoad |= isa(defOp); + // Only move ops residing in the same block. + return defBlock == block; + }; + mlir::getBackwardSlice(op, &backwardSet, options); + backwardSet.insert(op); + + // Don't move a local_store if its source is a load from + // the same iteration. + if (isa(op) && leadsToLoad) + continue; + + auto ipoint = findEarlyInsertionPoint(block, op); + // Remove ops that already precede the insertion point. This is done + // before moves happen to avoid `Operation::isBeforeInBlock` N^2 + // complexity. + + SmallVector dfg = backwardSet.takeVector(); + if (ipoint != block->end()) { + // Move ops to insertion point. + llvm::erase_if( + dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); }); + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveAfter(block, ipoint); + } else { + // Move ops to block begin. + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveBefore(block, block->begin()); + } + } +} + +/** + * Sched-load optimization for matmul kernels with large tile sizes. + * The basic idea of sched-load optimization is to sink the 2nd tt.load + * after local_load so that global_load instructions can be interleaved with + * mfma's. This can help hide the issue latency of global_load instructions + * and improve performance on MI300X. + * + * It's assumed that the IR before this optimization has the following + * structure: + * ```mlir + * scf.for .. + * { + * tileA = tt.load a_ptr + * tileB = tt.load b_ptr + * opA = local_load bufferA + * opB = local_load bufferB + * res = tt.dot opA, opB + * local_store tileA, bufferA + * local_store tileB, bufferB + * } + * ``` + * After this optimization, the IR is transformed to + * ```mlir + * scf.for ..
+ * { + * tileA = tt.load a_ptr + * opA = local_load bufferA + * opB = local_load bufferB + * tileB = tt.load b_ptr <-- 2nd tt.load is sunk here + * res = tt.dot opA, opB + * local_store tileA, bufferA + * local_store tileB, bufferB + * } + * ``` + * For now, we don't have a perfect heuristic for when this optimization + * should be applied. Therefore, we implement a simple heuristic: it is + * applied when the tile sizes of A and B are large enough, i.e. + * nonKDim >= 128 and kDim >= 64. It is also only applied to typical + * matmul kernels, i.e. those with only two tt.load ops and one dotOp inside + * the loop. We are experimenting with how to better control instruction + * scheduling and enable such optimizations. + */ +static void sinkSecondLoad(ModuleOp m) { + m.walk([&](scf::ForOp forOp) -> void { + SetVector loadOps; + triton::DotOp dotOp; + for (Operation &op : forOp) { + if (auto loadOp = dyn_cast(&op)) + loadOps.insert(loadOp); + if (auto curOp = dyn_cast(&op)) + dotOp = curOp; + } + // Only apply the optimization when there are 2 loads in the loop + if (loadOps.size() != 2) + return; + // Only apply the optimization when the tile size is large enough + // 1. nonKDim >= 128 + // 2. kDim >= 64 + auto ldAOp = loadOps[0]; + auto tileAShape = cast(ldAOp.getType()).getShape(); + auto ldBOp = loadOps[1]; + auto tileBShape = cast(ldBOp.getType()).getShape(); + if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128)) + return; + // Only apply the optimization when the move is legal + // 1. Make sure the 2nd loadOp is before the dot + // 2. Make sure the first user of the 2nd loadOp is after the dot. + bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp); + auto firstUser = *ldBOp.getResult().getUsers().begin(); + bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser); + if (isBeforeDotOp && firstUserAfterDotOp) + // move ldBOp right before tt.dot + ldBOp->moveBefore(dotOp); + }); +} + +//===----------------------------------------------------------------------===// +// Pass definition +//===----------------------------------------------------------------------===// + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +namespace { +struct TritonAMDGPUReorderInstructionsPass + : public TritonAMDGPUReorderInstructionsBase< + TritonAMDGPUReorderInstructionsPass> { + void runOnOperation() override { + ModuleOp m = getOperation(); + + hoistLocalLoad(m); + + sinkDotConversion(m); + moveDownCoversion(m); + + moveUpTranspose(m); + + if (isPureMatmulProblem(m)) { + scheduleGlobalLoadLocalStore(m); + sinkSecondLoad(m); + } } }; +} // namespace std::unique_ptr mlir::createTritonAMDGPUReorderInstructionsPass() { return std::make_unique(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 6f9ed6a23b4e..621960753a96 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -6,6 +6,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "llvm/ADT/MapVector.h" @@ -403,8 +404,9 @@ void LoopPipeliner::createBufferTypes() { auto sharedEnc = ttg::SharedEncodingAttr::get( ty.getContext(), dotOpEnc, ty.getShape(), ttg::getOrder(ty.getEncoding()),
CTALayout, eType); - loadsBufferType[loadOp] = - triton::MemDescType::get(bufferShape, eType, sharedEnc); + loadsBufferType[loadOp] = triton::MemDescType::get( + bufferShape, eType, sharedEnc, + triton::gpu::SharedMemorySpaceAttr::get(ty.getContext())); } } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp new file mode 100644 index 000000000000..ae3403a15a9a --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -0,0 +1,760 @@ +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file creates a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces: one that creates +// a modulo schedule, and an expander that rewrites the loop and emits a +// prologue and epilogue. This pass first calls a helper that pre-processes the +// IR to create stream operations and builds a modulo schedule. Then we call +// the expander to generate the prologue and the new loop. +//===----------------------------------------------------------------------===// + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h.inc" + +#define DEBUG_TYPE "tritonamdgpu-stream-pipeline-v2" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace { + +struct LoadInfo { + // Shared layout is used for loads feeding into dot ops. + ttg::SharedEncodingAttr sharedEncoding = nullptr; + // Blocked encoding is used for loads not used by the dot. + ttg::BlockedEncodingAttr blockedEncoding = nullptr; + // The distance of this load's stage to its use's stage. + int distToUse = 0; + bool usedByDot = false; +}; + +} // namespace + +static void createStreamCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value extractIdx, tt::CoarseSchedule &schedule, + tt::CoarseSchedule::Cluster prefetchCluster, + llvm::MapVector &loadToInfo, + int numStages) { + OpBuilder builder(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); + Value other = loadOp.getOther(); + if (!isExpensiveLoadOrStore(loadOp) && loadToInfo[loadOp].blockedEncoding) { + // For inexpensive loads that do not directly feed into dot ops + // we want to use an optimal layout for the data.
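+ // (As a sketch of the intent, not tied to a particular kernel: a load
+ // whose result only feeds elementwise ops is rewritten so that each
+ // thread holds a contiguous chunk sized by the pointer contiguity; the
+ // ptr, mask, and other operands below are all converted to that blocked
+ // encoding before the load is cloned.)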
+ ttg::BlockedEncodingAttr encoding = loadToInfo[loadOp].blockedEncoding; + auto convertBlockLayout = [&](Value src) { + auto ty = cast(src.getType()); + auto newTy = + RankedTensorType::get(ty.getShape(), ty.getElementType(), encoding); + auto cvt = + builder.create(loadOp->getLoc(), newTy, src); + return cvt.getResult(); + }; + src = convertBlockLayout(src); + if (mask) + mask = convertBlockLayout(mask); + if (other) + other = convertBlockLayout(other); + } + + tt::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + Operation *copy = builder.clone(*loadOp); + + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); + + // Extract part. + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + auto subviewTy = tt::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + auto storeOp = + builder.create(loc, copy->getResult(0), viewLoad); + // Clean up old local caches. + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + alloc.replaceAllUsesWith(viewLoad.getResult()); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) + alloc.erase(); + + auto sharedLoad = + builder.create(loc, loadOp.getType(), viewLoad); + auto result = sharedLoad->getResults(); + + // Create a select for non-zero other values. + if (other && !isZeroConst(other)) { + auto select = builder.create( + loc, loadOp.getType(), mask, sharedLoad.getResult(), other); + result = select->getResults(); + } + + loadOp->replaceAllUsesWith(result); + + // Prefetch the load ahead of the dot stage if it is used by the dot. + if (loadToInfo[loadOp].usedByDot) { + assert(numStages >= 2 && "requires num_stages=2 at least"); + schedule.insert(storeOp, numStages - 2, prefetchCluster); + schedule.insert(viewLoad, numStages - 2, prefetchCluster); + } + loadOp.erase(); +} + +// If all the transitive uses of the given value are consumed by a convert to +// the same dot operand encoding, return the shared encoding that needs to be +// used to be compatible with the users' layouts. +static std::optional +getSharedEncIfAllUsersAreDotEnc(Value val) { + ttg::SharedEncodingAttr attr; + for (Operation *user : val.getUsers()) { + ttg::SharedEncodingAttr tempAttr; + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = + dyn_cast(user->getResult(0).getType())) { + // The first time we find a shared encoding in the chain, save it and try + // to use it if it is compatible with the other users.
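+ // A typical chain this recursion accepts (schematic IR, for exposition
+ // only, names hypothetical):
+ //
+ // %v = tt.load ...
+ // %m = local_alloc %v : ... -> !tt.memdesc<..., #shared>
+ // %o = local_load %m : ... -> tensor<..., #dot_operand>
+ //
+ // Every user must bottom out in the same #shared encoding; any mismatch
+ // makes the function return std::nullopt below.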
+ tempAttr = cast(memDesc.getEncoding()); + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + auto dotOpEnc = dyn_cast( + cast(user->getResult(0).getType()).getEncoding()); + if (!dotOpEnc) + return std::nullopt; + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy.getEncoding()); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), dotOpEnc, srcTy.getShape(), + ttg::getOrder(srcTy.getEncoding()), + ttg::getCTALayout(srcTy.getEncoding()), + srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); + } + // Check that the shared encodings needed by the users are compatible. + if (!tempAttr || (attr != nullptr && attr != tempAttr)) + return std::nullopt; + attr = tempAttr; + } + return attr; +} + +static ttg::BlockedEncodingAttr +getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { + Value src = loadOp.getPtr(); + auto ty = cast(src.getType()); + auto mod = loadOp->getParentOfType(); + int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + tt::AxisInfo::DimVectorT contiguity = + axisInfo.getAxisInfo(src)->getContiguity(); + SmallVector order = argSort(contiguity); + unsigned currPerThread = getNumElementsPerThread(loadOp, order, axisInfo); + SmallVector sizePerThread(order.size(), 1); + sizePerThread[order[0]] = currPerThread; + ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(ty.getEncoding()); + return ttg::BlockedEncodingAttr::get(loadOp->getContext(), ty.getShape(), + sizePerThread, order, numWarps, + threadsPerWarp, ctaLayout); +} + +// Create a map from load ops to their indirection levels and the final uses +// of the load op (another load op, or a dot op). +// +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +static llvm::SmallVector> +loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { + llvm::SmallVector> + loadOpToIndLevelAndUse; + DenseSet seen; + + // Recursively visit the given op and its operands to discover all load ops + // and collect their indirection levels and uses. + std::function dfs = + [&](Operation *op, int distance, Operation *use) { + // Skip previously visited ops. + if (!seen.insert(op).second) + return; + + if (isa(op)) { + // TODO: What if there are multiple uses at different distances? + loadOpToIndLevelAndUse.emplace_back(op, distance, use); + use = op; + ++distance; + } + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, distance, use); + } + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!op.hasTrait()) + continue; + seen.clear(); + dfs(&op, 0, &op); + } + + // If the loop has a numStages attribute, also consider pipelining other + // loads that are not directly used by dot ops. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, 0, &op); + } + } + + return loadOpToIndLevelAndUse; +} + +// Go through all load ops to identify those that can be pipelined and assign +// layouts to them.
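+// Worked example (hypothetical IR) of the input this function receives:
+//
+//   %p = tt.load %ptrs       // indirection level 1, use = the load of %v
+//   %v = tt.load %p          // indirection level 0, use = the dot
+//   %d = tt.dot %v, ...
+//
+// loadOpToIndLevelAndUse then holds {(%v, 0, %d), (%p, 1, %v)}, and each
+// entry is considered for a shared or blocked encoding below.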
+static llvm::MapVector +assignMemoryLayouts(llvm::SmallVector> + &loadOpToIndLevelAndUse, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + llvm::MapVector loadToInfo; + + for (auto &[op, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(op)) + // TODO: We'd need to verify that the distance is the same. + continue; + + LoadInfo loadInfo; + auto loadOp = cast(op); + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) { + LDBG("Skip non-tensor load " << *loadOp); + continue; + } + + auto pointeeTy = + cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * pointeeTy.getIntOrFloatBitWidth(); + + // Limit pipelining to loads with a contiguous access width >= 32 bits. + LDBG("Load " << *loadOp << " has width " << width); + if (width < 32) { + LDBG("Skip width<32 load " << *loadOp); + continue; + } + + if (use->hasTrait()) { + // Heuristic to avoid pipelining loads whose result could be used + // directly as a dot operand + RankedTensorType oprTy; + DenseSet seen; + std::function findLoad = [&](Value loadOp, + Value opr) -> bool { + if (loadOp == opr) + return true; + // Skip previously visited values. + if (seen.contains(opr)) + return false; + seen.insert(opr); + + if (Operation *op = opr.getDefiningOp()) { + // Skip processing if the operand already goes through shared memory + // or a layout conversion through shared memory is inevitable + if (!isa(op) && + !op->hasTrait() && + !op->hasTrait()) + return false; + for (Value operand : op->getOperands()) { + if (findLoad(loadOp, operand)) + return true; + } + } + return false; + }; + for (Value opr : use->getOperands()) { + if (findLoad(loadOp.getResult(), opr)) { + oprTy = dyn_cast(opr.getType()); + break; + } + } + bool foundLDSBypass = oprTy && tensorTy.getShape() == oprTy.getShape() && + isBlockedToDotShortcut(tensorTy, oprTy); + if (!foundLDSBypass) { + // Only use shared memory when feeding into a dot op. + loadInfo.usedByDot = true; + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + } + } else if (auto useOp = dyn_cast(use)) { + // The use of this loadOp is another loadOp. If the use is not in the + // loadToInfo already, it means that the use is not valid for pipelining + // for some reason. We should skip this loadOp, too. + // + // Note that we have an assumption that the use of this loadOp has + // already been processed in a previous loop iteration. This assumption + // is held by how loadOpsToIndirectionLevelAndUse recursively collects + // loadOpToIndLevelAndUse using DFS. + if (loadToInfo.count(useOp) == 0) { + continue; + } + } + + // If we still don't have a shared encoding, try a "generic" shared + // encoding. + if (!loadInfo.sharedEncoding) { + // Also pipeline in-register buffers.
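+ // (E.g., hypothetically, a scales tensor that feeds only elementwise
+ // math: it gets a blocked encoding from getBlockedEncoding below and is
+ // multi-buffered in registers instead of going through LDS.)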
+ if (auto loadOp = dyn_cast(op)) { + loadInfo.blockedEncoding = getBlockedEncoding(loadOp, axisInfoAnalysis); + } + } + + loadToInfo[op] = loadInfo; + } + + return loadToInfo; +} + +static llvm::MapVector +scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + ModuleOp moduleOp = forOp->getParentOfType(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + llvm::SmallVector> + loadOpToIndLevelAndUse = loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + if (loadOpToIndLevelAndUse.empty()) + return {}; + + // Check which loads are good for pipelining, and assign them memory layouts. + llvm::MapVector loadToInfo = + assignMemoryLayouts(loadOpToIndLevelAndUse, axisInfoAnalysis); + if (loadToInfo.empty()) + return {}; + + // Filter out load ops that cannot be pipelined. + int resize = 0; + for (int i = 0, e = loadOpToIndLevelAndUse.size(); i < e; ++i) { + auto [loadOp, distance, use] = loadOpToIndLevelAndUse[i]; + if (loadToInfo.count(loadOp) != 0) + loadOpToIndLevelAndUse[resize++] = loadOpToIndLevelAndUse[i]; + } + loadOpToIndLevelAndUse.resize(resize); + + // Calculate the stage distance between applicable loads. + int maxIndirectionLevel = -1; + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) + maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + + // The stage gap between chained loads. This allows us to "spread" loads + // with a step greater than one when the number of stages given by the user + // is large. + assert(numStages >= 2 && "requires num_stages=2 at least"); + unsigned stagesBetweenLoads = + llvm::divideCeil(numStages - 2, maxIndirectionLevel + 1); + LDBG("stagesBetweenLoads = " << stagesBetweenLoads); + + // Put the root uses of the loads in the last stage. + tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); + for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { + // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). + if (!isa(use)) { + schedule.insert(use, numStages - 1, rootUsersCluster); + rootUsers.insert(use); + } + } + + // Create a cluster for load ops at each indirection level. + SmallVector loadsClusters; + for (int i = 0; i <= maxIndirectionLevel; i++) { + loadsClusters.push_back(schedule.clusters.newAtBack()); + } + // Assign stages to the loads. + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + schedule.insert(loadOp, stage, loadsClusters[indLevel]); + } + + // Calculate distance from the load to the use. + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; + } + + LLVM_DEBUG({ + LDBG("Chosen loads to pipeline:"); + for (const auto &[load, info] : loadToInfo) { + LDBG(" - load: " << *load); + LDBG(" distToUse: " << info.distToUse); + LDBG(" usedByDot: " << info.usedByDot); + } + }); + return loadToInfo; +} + +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op.
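+// E.g. (hypothetical): if a tt.load anchored at stage 0 consumes an addptr
+// fed by a splat, both the addptr and the splat are pulled into stage 0 and
+// the load's cluster, so the whole address computation travels with the load
+// it feeds.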
+static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, + int numStages) { + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. + for (int stage = 0; stage < numStages; ++stage) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + schedule.insertDepsOfOp(op, stage, cluster, false); + } + } +} + +// Find dependencies with a distance of 1. They will go to the next stage, +// but in the cluster before the current op. +static void scheduleDistanceOneDependencies(scf::ForOp forOp, + tt::CoarseSchedule &schedule, + int numStages) { + auto getNestedOperands = [](Operation *op) { + SmallVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) + operands.push_back(operand); + } + }); + return operands; + }; + + // Mapping from the cluster to the cluster before it. + DenseMap + dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + auto arg = dyn_cast(operand); + if (!arg || arg.getArgNumber() == 0 || arg.getOwner() != op.getBlock()) + continue; + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (!defOp || schedule.count(defOp) != 0) + continue; + if (isa(defOp)) { + // Exception: schedule loads with a distance of 1 together with the + // current op. + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, true); + } else { + if (dist1Cluster.count(&cluster) == 0) { + dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], true); + } + } + } +} + +static void +scheduleRemainingToLastStage(scf::ForOp forOp, tt::CoarseSchedule &schedule, + tt::CoarseSchedule::Cluster afterPrologue, + int numStages) { + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops: uses cannot be scheduled to a + // cluster before their definition. + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) { + opToCluster[&op] = afterPrologue; + } + } + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. + if (stage == numStages - 1) { + queue.push_back(op); + } + } + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + tt::CoarseSchedule::Cluster userCluster = opToCluster[user]; + tt::CoarseSchedule::Cluster opCluster = schedule[op].second; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, numStages - 1, cluster); + } +} + +// Create an allocation that can hold `distance` copies of the loadOp shape.
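+// E.g. (hypothetical shapes): a pipelined load of tensor<128x64xf16> with
+// distance 2 is backed by a mutable !tt.memdesc<2x128x64xf16, #shared>
+// allocation; the extra leading dimension selects the buffer addressed by
+// extractIdx on each iteration.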
+static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, + ttg::SharedEncodingAttr sharedEnc, unsigned distance) { + OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + auto ty = cast(loadOp->getResultTypes()[0]); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type memdescType = tt::MemDescType::get(bufferShape, ty.getElementType(), + sharedEnc, sharedMemorySpace, + /*mutableMemory=*/true); + return builder.create(loadOp->getLoc(), memdescType, + Value()); +} + +// Convert load ops into loads from shared memory allocations and apply +// multi-buffering based on the required number of buffers. +static SmallVector +createStreamOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, + llvm::MapVector &loadToInfo, + int numStages) { + // Calculate the number of buffers needed for each load. + // TODO: Use the precise number of buffers needed by the particular load. + int numBuffers = -1; + for (auto &[_, info] : loadToInfo) + numBuffers = std::max(numBuffers, info.distToUse); + LDBG("deduced shared memory buffer number = " << numBuffers); + + SmallVector allocs; + SmallVector> loadToAllocs; + for (auto &[loadOp, info] : loadToInfo) { + if (!info.sharedEncoding) + continue; + + Value alloc = createAlloc(forOp, loadOp, info.sharedEncoding, numBuffers); + assert(alloc && "Failed to create alloc for the async load."); + allocs.push_back(alloc); + loadToAllocs.emplace_back(loadOp, alloc); + } + + IRRewriter builder(forOp.getContext()); + builder.setInsertionPoint(forOp); + + Location loc = forOp.getLoc(); + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value extractIdx = minusOne; + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop-carried dependencies. + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, {extractIdx}); + forOp.erase(); + forOp = newForOp; + + // Create one counter for the extract indices to avoid creating a long + // live range. + extractIdx = newForOp.getBody()->getArgument(newOperandIndex); + + builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + + // Create a cluster for prefetching global reads for the dot. + tt::CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); + + for (auto &[op, alloc] : loadToAllocs) { + if (auto loadOp = dyn_cast(op)) + createStreamCopy(forOp, loadOp, alloc, extractIdx, schedule, + prefetchCluster, loadToInfo, numStages); + } + // Patch the yield with the updated counters. + appendToForOpYield(forOp, {extractIdx}); + + return allocs; +} + +static bool preprocessLoopAndBuildSchedule(scf::ForOp &forOp, int numStages, + tt::PipeliningOption &options) { + // Schedule the loads and root ops (dot ops) in the loop. This will give us + // a scaffold for the final schedule.
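+ // For the minimum numStages = 2 the resulting scaffold looks like
+ // (illustrative):
+ //   stage 0: cloned tt.load ops
+ //   stage 0, prefetch cluster: memdesc_subview + local_store
+ //   stage 1: tt.dot and the rest of the loop body
+ // i.e. global loads run one iteration ahead of the compute stage.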
+ DenseSet rootUsers; + tt::CoarseSchedule coarseSchedule(numStages); + llvm::MapVector loadToInfo = + scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); + if (loadToInfo.empty()) + return false; + + LLVM_DEBUG({ + LDBG("Coarse schedule loads only:"); + coarseSchedule.dump(); + }); + + // Convert the loads into shared memory allocations and loads from them. + SmallVector allocs = + createStreamOps(forOp, coarseSchedule, loadToInfo, numStages); + + LLVM_DEBUG({ + LDBG("Coarse schedule with stream loads:"); + coarseSchedule.dump(); + }); + + tt::CoarseSchedule::Cluster afterPrologue = coarseSchedule.clusters.begin(); + + scheduleDependencies(forOp, coarseSchedule, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with dependencies:"); + coarseSchedule.dump(); + }); + + scheduleDistanceOneDependencies(forOp, coarseSchedule, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with dist 1:"); + coarseSchedule.dump(); + }); + + scheduleRemainingToLastStage(forOp, coarseSchedule, afterPrologue, numStages); + LLVM_DEBUG({ + LDBG("Final coarse schedule:"); + coarseSchedule.dump(); + }); + + // Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> schedule = + coarseSchedule.createFinalSchedule(forOp); + + // Fill out the pipeline options. + options.getScheduleFn = + [schedule](scf::ForOp, std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = true; + options.predicateFn = tt::predicateOp; + options.supportDynamicLoops = true; + + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + // Explicitly deallocate created allocations. + for (auto alloc : allocs) + builder.create(forOp.getLoc(), alloc); + return true; +} + +// Return true if the preconditions for pipelining the loop are met. +static bool checkPrecondition(scf::ForOp forOp) { + // Skip loops with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { return !operand.getDefiningOp(); })) + return false; + + // Don't pipeline outer loops. + auto hasNestedLoopInside = [forOp](Operation *op) { + if (op != forOp && isa(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }; + return !forOp->walk(hasNestedLoopInside).wasInterrupted(); +} + +static bool pipelineLoop(scf::ForOp forOp, int numStages) { + if (!checkPrecondition(forOp)) + return false; + + tt::PipeliningOption options; + if (!preprocessLoopAndBuildSchedule(forOp, numStages, options)) + return false; + LDBG("Loop before sending to expander:\n" << *forOp); + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + return succeeded(tt::pipelineForLoop(rewriter, forOp, options)); +} + +namespace { +struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { + PipelinePass() = default; + PipelinePass(int32_t numStages) { this->numStages = numStages; } + + void runOnOperation() override { + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stages <= 1. + if (getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + + for (scf::ForOp forOp : loops) + pipelineLoop(forOp, getNumStagesOrDefault(forOp)); + } + +private: + int getNumStagesOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists, otherwise use the + // global control.
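+ // (For instance, a loop carrying the kNumStagesAttrName attribute with
+ // value 3 is pipelined with 3 stages regardless of the pass-level
+ // numStages option; the exact attribute spelling comes from
+ // kNumStagesAttrName, so this is an illustration rather than a literal
+ // IR form.)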
+ if (auto attr = forOp->getAttrOfType(tt::kNumStagesAttrName)) + return attr.getInt(); + return numStages; + } +}; +} // anonymous namespace + +std::unique_ptr +mlir::createTritonAMDGPUStreamPipelineV2Pass(int numStages) { + return std::make_unique(numStages); +} diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index ddc1feb2aa94..2d83cf51bbe6 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -51,10 +51,14 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { const std::string, int, int); ADD_PASS_WRAPPER_0("add_optimize_epilogue", mlir::createTritonAMDGPUOptimizeEpiloguePass); + ADD_PASS_WRAPPER_0("add_tritongpu_bypass_lds_for_dot_layout_pass", + mlir::createTritonAMDGPUBypassLDSForDotLayout); ADD_PASS_WRAPPER_0("add_reorder_instructions", mlir::createTritonAMDGPUReorderInstructionsPass); ADD_PASS_WRAPPER_0("add_stream_pipeline", mlir::createTritonAMDGPUStreamPipelinePass); + ADD_PASS_WRAPPER_1("add_stream_pipelinev2", + mlir::createTritonAMDGPUStreamPipelineV2Pass, int); } void addControlConstant(llvm::Module *module, const char *name, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index 6977e0597db2..4c95e530b173 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -738,7 +738,8 @@ MemDescType getExpandedDesc(MemDescType descTy) { expandedShape[2] = shape[1]; auto encoding = descTy.getEncoding(); auto expandedEncoding = getExpandedEncoding(encoding); - auto expandedDesc = MemDescType::get(expandedShape, elTy, expandedEncoding); + auto expandedDesc = MemDescType::get(expandedShape, elTy, expandedEncoding, + descTy.getMemorySpace()); return expandedDesc; } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index 4407a50bd0d9..4b8c7c1b37d1 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -34,8 +34,8 @@ class DecomposeLocalLoadToDotOperand auto dstDotOp = dyn_cast( op.getType().getEncoding()); - auto sharedEncoding = - cast(op.getSrc().getType().getEncoding()); + MemDescType srcType = op.getSrc().getType(); + auto sharedEncoding = cast(srcType.getEncoding()); if (!dstDotOp || !sharedEncoding.getHasLeadingOffset()) return failure(); RankedTensorType type = op.getType(); @@ -55,7 +55,8 @@ class DecomposeLocalLoadToDotOperand triton::gpu::SharedEncodingAttr::get( op.getContext(), dstDotOp, type.getShape(), triton::gpu::getOrder(parentEnc), - triton::gpu::getCTALayout(parentEnc), type.getElementType())); + triton::gpu::getCTALayout(parentEnc), type.getElementType()), + srcType.getMemorySpace()); auto tmp = rewriter.create( op.getLoc(), newSharedDescTy, load); auto newConvert = diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp index 374b9ec9e49b..3e915a577c54 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp @@ -23,15 +23,10 @@ LogicalResult 
convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter); -LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, +LogicalResult convertWGMMA(triton::nvidia_gpu::WarpGroupDotOp op, + triton::nvidia_gpu::WarpGroupDotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Value thread); - -LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, - triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, - Value thread); namespace { struct DotOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -59,9 +54,6 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { return convertMMA1688(op, adaptor, getTypeConverter(), rewriter); if (mmaLayout.isAmpere()) return convertMMA16816(op, adaptor, getTypeConverter(), rewriter); - if (mmaLayout.isHopper()) - return convertWGMMA(op, adaptor, getTypeConverter(), rewriter, - getThreadId(rewriter, loc)); llvm::report_fatal_error( "Unsupported MMA kind found when converting DotOp to LLVM."); @@ -76,13 +68,13 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { } }; -struct DotAsyncOpConversion - : public ConvertOpToLLVMPattern { +struct WarpGroupDotOpConversion + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< - triton::nvidia_gpu::DotAsyncOp>::ConvertOpToLLVMPattern; + triton::nvidia_gpu::WarpGroupDotOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(triton::nvidia_gpu::DotAsyncOp op, OpAdaptor adaptor, + matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); // D = A * B + C @@ -100,26 +92,26 @@ struct DotAsyncOpConversion if (!isOuter && mmaLayout && supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) { if (mmaLayout.isHopper()) { - return convertAsyncWGMMA(op, adaptor, getTypeConverter(), rewriter, - getThreadId(rewriter, loc)); + return convertWGMMA(op, adaptor, getTypeConverter(), rewriter, + getThreadId(rewriter, loc)); } llvm::report_fatal_error( - "Unsupported MMA kind found when converting DotAsyncOp to LLVM."); + "Unsupported MMA kind found when converting WarpGroupDotOp to LLVM."); } llvm::report_fatal_error( - "Unsupported DotAsyncOp found when converting TritonGPU to LLVM."); + "Unsupported WarpGroupDotOp found when converting TritonGPU to LLVM."); } }; -struct DotWaitOpConversion - : public ConvertOpToLLVMPattern { +struct WarpGroupDotWaitOpConversion + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< - triton::nvidia_gpu::DotWaitOp>::ConvertOpToLLVMPattern; + triton::nvidia_gpu::WarpGroupDotWaitOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor, + matchAndRewrite(triton::nvidia_gpu::WarpGroupDotWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto pendings = op.getPendings(); Location loc = op.getLoc(); @@ -180,6 +172,6 @@ void mlir::triton::NVIDIA::populateDotOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); } diff --git 
a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 738f0fe040f6..baed96a29704 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -500,28 +500,12 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, return success(); } -LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, +LogicalResult convertWGMMA(triton::nvidia_gpu::WarpGroupDotOp op, + triton::nvidia_gpu::WarpGroupDotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Value thread) { auto AEnc = op.getA().getType().getEncoding(); auto BEnc = op.getB().getType().getEncoding(); - assert((mlir::isa(AEnc))); - assert(mlir::isa(BEnc) && - "Operand B should use Shared layout."); - return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), // - op.getA(), op.getB(), op.getC(), op.getD(), // - adaptor.getA(), adaptor.getB(), adaptor.getC(), // - op.getInputPrecision() == InputPrecision::TF32, - op.getMaxNumImpreciseAcc(), true, thread); -} - -LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, - triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, - Value thread) { - auto AEnc = op.getA().getType().getEncoding(); - auto BEnc = op.getB().getType().getEncoding(); assert(mlir::isa(AEnc) || mlir::isa(AEnc)); assert(mlir::isa(BEnc) && @@ -530,5 +514,5 @@ LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, op.getA(), op.getB(), op.getC(), op.getD(), // adaptor.getA(), adaptor.getB(), adaptor.getC(), op.getInputPrecision() == InputPrecision::TF32, - op.getMaxNumImpreciseAcc(), false, thread); + op.getMaxNumImpreciseAcc(), !op.getIsAsync(), thread); }