diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index caf68475a9ca..7d8b5588dd7b 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -1004,11 +1004,6 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [ "triton::PaddingOption":$padding)> ]; - let extraClassDeclaration = [{ - ArrayRef getTensorShape() { - return getType().getBlockType().getShape(); - } - }]; } // The following ops, including `call`, `func`, and `return` are copied and modified from diff --git a/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td b/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td index a12c6b6fe975..ad24ca7fa7f3 100644 --- a/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td +++ b/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td @@ -26,25 +26,48 @@ def TT_TensorDescInterface : TypeInterface<"TensorDescInterface"> { let methods = [ InterfaceMethod< - /*desc=*/"Returns the block type of the tensor descriptor", + /*desc=*/"Returns the shape of the descriptor block", + /*retType=*/"llvm::ArrayRef", + /*methodName=*/"getShape", + /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/"Returns the element type of the descriptor block", + /*retType=*/"mlir::Type", + /*methodName=*/"getElementType", + /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/"Returns the optional shared memory layout encoding", + /*retType=*/"mlir::Attribute", + /*methodName=*/"getSharedLayout", + /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/"Returns a block tensor type constructed from shape and element type", /*retType=*/"mlir::RankedTensorType", /*methodName=*/"getBlockType", - /*args=*/(ins) + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImpl=*/[{ + return mlir::RankedTensorType::get($_type.getShape(), + $_type.getElementType()); + }] >, InterfaceMethod< - /*desc=*/"Returns the block type with signless integer element type", + /*desc=*/"Returns a block tensor type constructed with signless integer element type", /*retType=*/"mlir::RankedTensorType", /*methodName=*/"getSignlessBlockType", /*args=*/(ins), /*methodBody=*/"", /*defaultImpl=*/[{ - auto resTy = $_type.getBlockType(); - if (auto intTy = llvm::dyn_cast(resTy.getElementType())) { - auto width = resTy.getElementTypeBitWidth(); - auto signlessTy = mlir::IntegerType::get($_type.getContext(), width); - resTy = resTy.clone(signlessTy); + auto shape = $_type.getShape(); + auto elemTy = $_type.getElementType(); + if (auto intTy = llvm::dyn_cast(elemTy)) { + auto width = intTy.getWidth(); + elemTy = mlir::IntegerType::get($_type.getContext(), width); } - return resTy; + return mlir::RankedTensorType::get(shape, elemTy); }] >, ]; diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td index 07d187aa5b04..dc27c705d4a4 100644 --- a/include/triton/Dialect/Triton/IR/TritonTypes.td +++ b/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -106,27 +106,58 @@ def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", [TT_TensorDesc A portable abstraction for TMA descriptors. This is the base tensor descriptor type for tiled tensor memory access. + Shape and elementType describe the block dimensions and data type. + The optional sharedLayout attribute carries the shared memory encoding + (e.g. swizzle pattern) that is assigned during lowering. + For specialized access patterns like im2col, see TensorDescIm2ColType in the TritonNvidiaGPU dialect. }]; let parameters = (ins - "RankedTensorType":$blockType + ArrayRefParameter<"int64_t">:$shape, + "Type":$elementType, + OptionalParameter<"Attribute">:$sharedLayout ); - let assemblyFormat = "`<` $blockType `>`"; - let builders = [ + // Builder from shape + elementType + sharedLayout + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$sharedLayout + ), [{ + return $_get(elementType.getContext(), shape, elementType, sharedLayout); + }]>, // Builder with signedness - TypeBuilder<(ins "RankedTensorType":$blockType, "bool":$isSigned), [{ - if (auto intTy = llvm::dyn_cast(blockType.getElementType())) { + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "bool":$isSigned + ), [{ + if (auto intTy = llvm::dyn_cast(elementType)) { auto sem = isSigned ? IntegerType::Signed : IntegerType::Unsigned; - auto elemTy = IntegerType::get($_ctxt, intTy.getWidth(), sem); - blockType = blockType.clone(elemTy); + elementType = IntegerType::get(elementType.getContext(), intTy.getWidth(), sem); } - return Base::get($_ctxt, blockType); + return $_get(elementType.getContext(), shape, elementType, Attribute{}); + }]>, + // Builder with signedness and shared layout + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$sharedLayout, + "bool":$isSigned + ), [{ + if (auto intTy = llvm::dyn_cast(elementType)) { + auto sem = isSigned ? IntegerType::Signed : IntegerType::Unsigned; + elementType = IntegerType::get(elementType.getContext(), intTy.getWidth(), sem); + } + return $_get(elementType.getContext(), shape, elementType, sharedLayout); }]>, ]; + + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; } #endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td index edf0f27908db..6a7f256f8cec 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td @@ -52,7 +52,9 @@ def TTNG_TensorDescIm2ColType : TTNG_TypeDef<"TensorDescIm2Col", "tensordesc_im2 operations. Parameters: - - blockType: The shape and element type of the data block being accessed + - shape: The block dimensions + - elementType: The element type of the data block + - sharedLayout: Optional shared memory encoding (swizzle pattern, etc.) This type implements TensorDescInterface, sharing common operations with the tiled TensorDescType in the base Triton dialect. @@ -62,28 +64,28 @@ def TTNG_TensorDescIm2ColType : TTNG_TypeDef<"TensorDescIm2Col", "tensordesc_im2 }]; let parameters = (ins - "RankedTensorType":$blockType + ArrayRefParameter<"int64_t">:$shape, + "Type":$elementType, + OptionalParameter<"Attribute">:$sharedLayout ); - let assemblyFormat = [{ - `<` $blockType `>` - }]; - let builders = [ - // Builder with signedness for integer types - TypeBuilder<(ins - "RankedTensorType":$blockType, + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$sharedLayout, "bool":$isSigned ), [{ - if (auto intTy = llvm::dyn_cast(blockType.getElementType())) { + if (auto intTy = llvm::dyn_cast(elementType)) { auto sem = isSigned ? IntegerType::Signed : IntegerType::Unsigned; - auto elemTy = IntegerType::get($_ctxt, intTy.getWidth(), sem); - blockType = blockType.clone(elemTy); + elementType = IntegerType::get(elementType.getContext(), intTy.getWidth(), sem); } - return Base::get($_ctxt, blockType); + return $_get(elementType.getContext(), shape, elementType, sharedLayout); }]> ]; + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; let genVerifyDecl = 1; } diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h index 2cbc80588c18..58d6b7dc26be 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h @@ -32,17 +32,18 @@ inline SmallVector getTMABlockShape(Attribute encoding, mmaEnc.getFp4Padded(), mmaEnc.getTransposed(), packedSize, mode); } -inline SmallVector -getTMABlockShape(RankedTensorType ty, bool packedSize, gpu::TMAMode mode) { +inline SmallVector getTMABlockShape(triton::gpu::MemDescType ty, + bool packedSize, + gpu::TMAMode mode) { auto shapePerCTA = gpu::getShapePerCTA(ty); return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize, mode); } -inline SmallVector getTMABlockShape(triton::gpu::MemDescType ty, +inline SmallVector getTMABlockShape(triton::TensorDescInterface ty, bool packedSize, gpu::TMAMode mode) { - auto shapePerCTA = gpu::getShapePerCTA(ty); - return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize, mode); + auto shapePerCTA = gpu::getShapePerCTA(ty.getSharedLayout(), ty.getShape()); + return getTMABlockShape(ty.getSharedLayout(), shapePerCTA, packedSize, mode); } FailureOr getTMASwizzleMode(Location loc, triton::TensorDescInterface ty); diff --git a/lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp b/lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp index 21d9c77f1980..de8834b11527 100644 --- a/lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp +++ b/lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp @@ -637,7 +637,7 @@ struct GSanTensorDescInfoOpConversion "expected byte-addressable element"); } - unsigned rank = descTy.getBlockType().getRank(); + unsigned rank = descTy.getShape().size(); unsigned elemBytes = elemTy.getIntOrFloatBitWidth() / 8; if (op->getNumResults() != 1 + 2 * rank) { return rewriter.notifyMatchFailure( diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index a6112f854e63..0be189c451f6 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1094,9 +1094,7 @@ void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state, } auto elemTy = ptrTy.getPointeeType(); SmallVector blockShape64(blockShape); - auto blockTy = RankedTensorType::get(blockShape64, elemTy); - auto descTy = - TensorDescType::get(builder.getContext(), blockTy, isSignedInteger); + auto descTy = TensorDescType::get(blockShape64, elemTy, isSignedInteger); auto paddingAttr = PaddingOptionAttr::get(builder.getContext(), padding); return build(builder, state, descTy, base, shape, strides, paddingAttr); } diff --git a/lib/Dialect/Triton/IR/Types.cpp b/lib/Dialect/Triton/IR/Types.cpp index 979ee56b2130..27aca2135802 100644 --- a/lib/Dialect/Triton/IR/Types.cpp +++ b/lib/Dialect/Triton/IR/Types.cpp @@ -24,6 +24,42 @@ void TritonDialect::registerTypes() { >(); } +// Format: !tt.tensordesc<128x64xf16> +// !tt.tensordesc<128x64xf16, #shared> +Type TensorDescType::parse(AsmParser &parser) { + if (failed(parser.parseLess())) + return Type(); + + SmallVector shape; + if (failed(parser.parseDimensionList(shape, /*allowDynamic=*/false))) + return Type(); + + Type elementType; + if (failed(parser.parseType(elementType))) + return Type(); + + Attribute sharedLayout; + if (succeeded(parser.parseOptionalComma())) { + if (failed(parser.parseAttribute(sharedLayout))) + return Type(); + } + + if (failed(parser.parseGreater())) + return Type(); + + return TensorDescType::get(shape, elementType, sharedLayout); +} + +void TensorDescType::print(AsmPrinter &printer) const { + printer << "<"; + for (auto dim : getShape()) + printer << dim << "x"; + printer << getElementType(); + if (getSharedLayout()) + printer << ", " << getSharedLayout(); + printer << ">"; +} + Type PointerType::parse(AsmParser &parser) { if (parser.parseLess()) return Type(); diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp index c7c25d2bde07..30e6dac6f81c 100644 --- a/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp +++ b/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp @@ -65,7 +65,7 @@ struct Descriptor { }; Descriptor unpackDescriptor(TensorDescType type, ValueRange pack) { - int rank = type.getBlockType().getRank(); + int rank = type.getShape().size(); assert(pack.size() == 1 + 2 * static_cast(rank) + 2 && "Expected tensor descriptors to consist of a pointer, " "followed by 'rank' shape values and 'rank' stride values, " @@ -328,7 +328,7 @@ struct RewriteLoadPattern : OpConversionPattern { matchAndRewrite(triton::DescriptorLoadOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - const auto blockShape = op.getDesc().getType().getBlockType().getShape(); + const auto blockShape = op.getDesc().getType().getShape(); auto descTy = op.getDesc().getType(); auto desc = unpackDescriptor(descTy, adaptor.getDesc()); auto offsets = castToI64(rewriter, op.getIndices()); @@ -340,7 +340,7 @@ struct RewriteLoadPattern : OpConversionPattern { newLoad->setAttrs(filterSegmentSizes(op->getAttrs())); Value result = newLoad.getResult(); - if (descTy.getBlockType().getElementType().isF32()) { + if (descTy.getElementType().isF32()) { auto ifOp = scf::IfOp::create(rewriter, loc, result.getType(), desc.roundF32ToTF32, /*withElse=*/true); @@ -367,7 +367,7 @@ struct RewriteStorePattern : OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto descTy = op.getDesc().getType(); - const auto blockShape = descTy.getBlockType().getShape(); + const auto blockShape = descTy.getShape(); auto desc = unpackDescriptor(descTy, adaptor.getDesc()); auto offsets = castToI64(rewriter, op.getIndices()); @@ -458,7 +458,7 @@ struct RewriteScatterPattern std::optional translateReduceKind(DescriptorReduceKind kind, TensorDescType ty) { - auto scalarTy = ty.getBlockType().getElementType(); + auto scalarTy = ty.getElementType(); switch (kind) { case DescriptorReduceKind::ADD: return scalarTy.isInteger() ? RMWOp::ADD : RMWOp::FADD; @@ -496,7 +496,7 @@ struct RewriteReducePattern : OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto descTy = op.getDesc().getType(); - const auto blockShape = descTy.getBlockType().getShape(); + const auto blockShape = descTy.getShape(); auto desc = unpackDescriptor(descTy, adaptor.getDesc()); auto offsets = castToI64(rewriter, op.getIndices()); auto rmwOp = translateReduceKind(op.getKind(), descTy); @@ -504,7 +504,7 @@ struct RewriteReducePattern : OpConversionPattern { std::string msgstring; llvm::raw_string_ostream msg(msgstring); msg << "Cannot fallback on descriptor atomic op, unsupported for type " - << descTy.getBlockType().getElementType(); + << descTy.getElementType(); return op->emitError(msgstring); } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 6115ac5a03a9..443a493c19de 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -4444,6 +4444,8 @@ std::optional triton::gpu::getWarpSpecializeTag(Operation *op) { } PaddedSharedEncodingAttr triton::gpu::getPaddedEncoding(Attribute encoding) { + if (!encoding) + return nullptr; if (auto padded = dyn_cast(encoding)) return padded; if (auto partitioned = dyn_cast(encoding)) diff --git a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp index 52d9a30c3c27..3818477aa611 100644 --- a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp +++ b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp @@ -126,8 +126,8 @@ TensorDescType getTensorDescTypeWithEncoding(Operation *op, Attribute encoding) { auto sharedEnc = cast(encoding); encoding = updateEncodingForShape(op, sharedEnc, existingTy); - auto blockTy = existingTy.cloneWithEncoding(encoding); - return TensorDescType::get(existingTy.getContext(), blockTy); + return TensorDescType::get(existingTy.getShape(), existingTy.getElementType(), + encoding); } struct UseInfo { @@ -283,7 +283,7 @@ AssignDescriptorMemoryLayouts::getUseInfo(Operation *op) { : load.getType().getEncoding(); info.cgaLayout = getCGALayout(encoding); auto shape = load.getResult().getType().getShape(); - auto rank = load.getDesc().getType().getBlockType().getRank(); + auto rank = load.getDesc().getType().getShape().size(); info.shape = expandToRank(shape, rank); return info; } @@ -294,7 +294,7 @@ AssignDescriptorMemoryLayouts::getUseInfo(Operation *op) { : gather.getType().getEncoding(); info.cgaLayout = getCGALayout(encoding); auto shape = gather.getResult().getType().getShape(); - auto rank = gather.getDesc().getType().getBlockType().getRank(); + auto rank = gather.getDesc().getType().getShape().size(); info.shape = expandToRank(shape, rank); return info; } @@ -303,7 +303,7 @@ AssignDescriptorMemoryLayouts::getUseInfo(Operation *op) { auto encoding = store.getSrc().getType().getEncoding(); info.cgaLayout = getCGALayout(encoding); auto shape = store.getSrc().getType().getShape(); - auto rank = store.getDesc().getType().getBlockType().getRank(); + auto rank = store.getDesc().getType().getShape().size(); info.shape = expandToRank(shape, rank); return info; } @@ -353,7 +353,7 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) { auto itr = valueToEncodingInfo.find(typedVal); if (itr != valueToEncodingInfo.end()) info = combineEncodings(*itr->second, info, - typedVal.getType().getBlockType().getRank()); + typedVal.getType().getShape().size()); } auto einfo = internEncoding(encodings, info); diff --git a/lib/Dialect/TritonInstrument/Transforms/GlobalSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/GlobalSanitizer.cpp index 0d12450a7302..841c83d73257 100644 --- a/lib/Dialect/TritonInstrument/Transforms/GlobalSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/GlobalSanitizer.cpp @@ -127,7 +127,7 @@ static DescriptorInfo getDescriptorInfo(Value desc, OpBuilder &builder) { auto elemTy = descTy.getSignlessBlockType().getElementType(); auto basePtrTy = tt::getPointerType(elemTy); - unsigned rank = descTy.getBlockType().getRank(); + unsigned rank = descTy.getShape().size(); SmallVector resultTypes; resultTypes.reserve(1 + 2 * rank); resultTypes.push_back(basePtrTy); diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp index f4611e60184a..d501e8a7ef3d 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp @@ -489,17 +489,58 @@ LogicalResult impl::verifyMMAv5Op(Operation *op) { #define GET_TYPEDEF_CLASSES #include "triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc" +//===----------------------------------------------------------------------===// +// TensorDescIm2ColType Printer/Parser +//===----------------------------------------------------------------------===// +// Format: !ttng.tensordesc_im2col<64x128xf16> +// !ttng.tensordesc_im2col<64x128xf16, #shared> +Type TensorDescIm2ColType::parse(AsmParser &parser) { + if (failed(parser.parseLess())) + return Type(); + + SmallVector shape; + if (failed(parser.parseDimensionList(shape, /*allowDynamic=*/false))) + return Type(); + + Type elementType; + if (failed(parser.parseType(elementType))) + return Type(); + + Attribute sharedLayout; + if (succeeded(parser.parseOptionalComma())) { + if (failed(parser.parseAttribute(sharedLayout))) + return Type(); + } + + if (failed(parser.parseGreater())) + return Type(); + + Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); + return TensorDescIm2ColType::getChecked(loc, parser.getContext(), shape, + elementType, sharedLayout); +} + +void TensorDescIm2ColType::print(AsmPrinter &printer) const { + printer << "<"; + for (auto dim : getShape()) + printer << dim << "x"; + printer << getElementType(); + if (getSharedLayout()) + printer << ", " << getSharedLayout(); + printer << ">"; +} + //===----------------------------------------------------------------------===// // TensorDescIm2ColType Verifier //===----------------------------------------------------------------------===// LogicalResult TensorDescIm2ColType::verify(function_ref emitError, - RankedTensorType blockType) { - // blockType must be rank 2 for im2col mode - if (blockType.getRank() != 2) { + ArrayRef shape, Type elementType, + Attribute sharedLayout) { + if (shape.size() != 2) { return emitError() - << "TensorDescIm2ColType requires rank-2 blockType, got rank " - << blockType.getRank(); + << "TensorDescIm2ColType requires rank-2 shape, got rank " + << shape.size(); } return success(); } diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index dcacce131f23..5bd29e234fa8 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -295,8 +295,8 @@ static LogicalResult verifyTMAEncoding(Operation *op, TensorDescInterface desc, auto nvmma = dyn_cast(enc); if (!nvmma) return op->emitOpError("TMA descriptor must have NVMMA shared layout"); - auto descEnc = dyn_cast_if_present( - desc.getBlockType().getEncoding()); + auto descEnc = + dyn_cast_if_present(desc.getSharedLayout()); // NOTE: Cannot do descEnc != enc as the encodings may differ in rank for // rank-reducing loads if (!descEnc || descEnc.getTransposed() != nvmma.getTransposed() || @@ -344,7 +344,7 @@ static bool isIm2ColDescriptor(Type descType) { static LogicalResult verifyAsyncTMACoords(Operation *op, ValueRange coords, TensorDescInterface desc, bool isIm2Col) { - unsigned blockRank = desc.getBlockType().getRank(); + unsigned blockRank = desc.getShape().size(); if (isIm2Col) { // For IM2COL mode, coordinates are for the full tensor (3D-5D) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp index c1f8c06b43f6..9c5e66f468da 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp @@ -11,8 +11,8 @@ namespace mlir::triton::nvidia_gpu { ttg::SharedEncodingTrait getEncodingFromDescriptor(Operation *op, RankedTensorType tensorType, Value desc) { - auto descBlockType = cast(desc.getType()).getBlockType(); - Attribute encoding = descBlockType.getEncoding(); + auto descType = cast(desc.getType()); + Attribute encoding = descType.getSharedLayout(); if (!encoding) { constexpr auto msg = "Internal Error: Tensor descriptor should have encoding set"; @@ -21,7 +21,7 @@ ttg::SharedEncodingTrait getEncodingFromDescriptor(Operation *op, llvm::report_fatal_error(msg); } auto sharedEnc = cast(encoding); - if (descBlockType.getShape() == tensorType.getShape()) + if (descType.getShape() == tensorType.getShape()) return sharedEnc; return ttg::updateEncodingForShape(op, sharedEnc, tensorType); @@ -35,9 +35,9 @@ bool hasCGABroadcast(ttg::MemDescType memDescType) { } FailureOr getTMASwizzleMode(Location loc, tt::TensorDescInterface ty) { - auto blockType = ty.getBlockType(); - auto encoding = blockType.getEncoding(); - auto mmaEncoding = dyn_cast(encoding); + auto encoding = ty.getSharedLayout(); + auto mmaEncoding = + dyn_cast_if_present(encoding); unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0; if (!mmaEncoding) { auto swizzledEnc = dyn_cast(encoding); @@ -90,14 +90,13 @@ enum TMA_ELEMENT_TYPES { }; FailureOr getTMAElementType(Location loc, tt::TensorDescInterface ty) { - auto blockType = ty.getBlockType(); - auto encoding = blockType.getEncoding(); + auto encoding = ty.getSharedLayout(); bool fp4Padded = isFp4Padded(encoding); if (fp4Padded) return TMA_B4X16_P64; - auto elemTy = blockType.getElementType(); + auto elemTy = ty.getElementType(); if (elemTy.isBF16()) { return TMA_BF16; } else if (elemTy.isF16()) { @@ -138,13 +137,13 @@ LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op, auto elemType = op.getBase().getType().getPointeeType(); auto elemSize = elemType.getIntOrFloatBitWidth() / 8; - auto encoding = op.getType().getBlockType().getEncoding(); + auto encoding = op.getType().getSharedLayout(); auto mmaEncoding = llvm::dyn_cast_or_null(encoding); bool fp4Padded = mmaEncoding && mmaEncoding.getFp4Padded(); int paddingScale = fp4Padded ? 2 : 1; - auto shapePerCTA = gpu::getShapePerCTA(encoding, op.getTensorShape()); + auto shapePerCTA = gpu::getShapePerCTA(encoding, op.getType().getShape()); // MakeTensorDescOp creates tiled descriptors (not im2col) auto blockShape = getTMABlockShape(encoding, shapePerCTA, /*packedSize=*/false, gpu::TMAMode::Tiled); @@ -161,8 +160,8 @@ LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op, unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0; if (!mmaEncoding) { - auto swizzledEnc = dyn_cast( - op.getType().getBlockType().getEncoding()); + auto swizzledEnc = + dyn_cast_if_present(encoding); if (!swizzledEnc || swizzledEnc.getVec() != 1 || swizzledEnc.getPerPhase() != 1 || swizzledEnc.getMaxPhase() != 1) { op->emitError() << "Unhandled encoding type"; diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index 35f358b165c9..da66a9277f24 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -595,19 +595,18 @@ void init_gluon_ir(py::module &&m) { .def("get_tensor_descriptor_layout_type", [](GluonOpBuilder &self, Type blockType, bool isSigned, Attribute layout) -> Type { - auto ctx = self.getContext(); auto blockTy = cast(blockType); - auto blockTyLayout = blockTy.cloneWithEncoding(layout); - return triton::TensorDescType::get(ctx, blockTyLayout, isSigned); + return triton::TensorDescType::get(blockTy.getShape(), + blockTy.getElementType(), + layout, isSigned); }) .def("get_tensor_descriptor_im2col_layout_type", [](GluonOpBuilder &self, Type blockType, bool isSigned, Attribute layout) -> Type { - auto ctx = self.getContext(); auto blockTy = cast(blockType); - auto blockTyLayout = blockTy.cloneWithEncoding(layout); return triton::nvidia_gpu::TensorDescIm2ColType::get( - ctx, blockTyLayout); + blockTy.getShape(), blockTy.getElementType(), layout, + isSigned); }) .def("is_convert_layout_trivial", [](GluonOpBuilder &self, Type resultTy, Value value) -> bool { diff --git a/python/src/ir.cc b/python/src/ir.cc index db769960eb87..58ae278c5a6e 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -183,8 +183,7 @@ py::list getTensorDescMetadata(ModuleOp &mod) { continue; bool isIm2Col = isa(arg.getType()); - auto blockType = descTy.getBlockType(); - auto encoding = blockType.getEncoding(); + auto encoding = descTy.getSharedLayout(); py::dict metadata; if (isa(encoding)) { @@ -195,21 +194,23 @@ py::list getTensorDescMetadata(ModuleOp &mod) { throw py::type_error("invalid TMA descriptor type"); auto tmaMode = isIm2Col ? ttg::TMAMode::Im2Col : ttg::TMAMode::Tiled; auto blockSize = - ttng::getTMABlockShape(blockType, /*packedSize=*/false, tmaMode); + ttng::getTMABlockShape(descTy, /*packedSize=*/false, tmaMode); metadata["swizzle"] = *swizzle; - metadata["elem_size"] = blockType.getElementTypeBitWidth() / 8; + metadata["elem_size"] = + descTy.getElementType().getIntOrFloatBitWidth() / 8; metadata["elem_type"] = *elemType; metadata["block_size"] = std::vector(blockSize.begin(), blockSize.end()); metadata["fp4_padded"] = mmaEncoding && mmaEncoding.getFp4Padded(); metadata["is_im2col"] = isIm2Col; } else { - auto blockShape = blockType.getShape(); + auto blockShape = descTy.getShape(); metadata["block_size"] = std::vector(blockShape.begin(), blockShape.end()); - metadata["elem_bits"] = blockType.getElementTypeBitWidth(); + metadata["elem_bits"] = descTy.getElementType().getIntOrFloatBitWidth(); - if (auto paddedEnc = dyn_cast(encoding)) { + if (auto paddedEnc = + dyn_cast_if_present(encoding)) { py::list intervalPaddingPairs; for (auto [interval, padding] : llvm::zip_equal( paddedEnc.getIntervals(), paddedEnc.getPaddings())) { @@ -220,7 +221,7 @@ py::list getTensorDescMetadata(ModuleOp &mod) { } metadata["interval_padding_pairs"] = intervalPaddingPairs; - auto blockShape = blockType.getShape(); + auto blockShape = descTy.getShape(); } } result.append(std::move(metadata)); @@ -1536,9 +1537,9 @@ void init_triton_ir(py::module &&m) { }) .def("create_tensor_descriptor_type", [](TritonOpBuilder &self, Type blockTy, bool isSigned) -> Type { - auto ctx = self.getContext(); - return triton::TensorDescType::get( - ctx, cast(blockTy), isSigned); + auto rtt = cast(blockTy); + return triton::TensorDescType::get(rtt.getShape(), + rtt.getElementType(), isSigned); }) .def("create_descriptor_load", [](TritonOpBuilder &self, Value desc, std::vector &indices, diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 5ae68243a4b9..0262c1cebcec 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -831,14 +831,14 @@ def test_async_tma(target): #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @async_tma_kernel(%arg0: !tt.tensordesc>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) attributes {noinline = false} { + tt.func public @async_tma_kernel(%arg0: !tt.tensordesc<128x128xf16, #shared>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) attributes {noinline = false} { %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> %c0_i32 = arith.constant 0 : i32 %c0_i32_0 = arith.constant 0 : i32 %true = arith.constant true - ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32_0] %0, %1, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32_0] %0, %1, %true : !tt.tensordesc<128x128xf16, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> %true_1 = arith.constant true ttng.barrier_expect %1, 32768, %true_1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> %c0_i32_2 = arith.constant 0 : i32 @@ -847,7 +847,7 @@ def test_async_tma(target): ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> %c0_i32_4 = arith.constant 0 : i32 %c0_i32_5 = arith.constant 0 : i32 - ttng.async_tma_copy_local_to_global %arg0[%c0_i32_4, %c0_i32_5] %0 : !tt.tensordesc>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + ttng.async_tma_copy_local_to_global %arg0[%c0_i32_4, %c0_i32_5] %0 : !tt.tensordesc<128x128xf16, #shared>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> ttng.async_tma_store_wait {pendings = 0 : i32} tt.return } @@ -891,14 +891,14 @@ def test_async_tma_blackwell(): #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) attributes {noinline = false} { + tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc<1x128xf16, #shared>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) attributes {noinline = false} { %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> %true = arith.constant true %c0_i32 = arith.constant 0 : i32 - ttng.async_tma_gather %arg0[%2, %c0_i32] %0, %1, %true : !tt.tensordesc>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 + ttng.async_tma_gather %arg0[%2, %c0_i32] %0, %1, %true : !tt.tensordesc<1x128xf16, #shared>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 %true_0 = arith.constant true ttng.barrier_expect %1, 32768, %true_0 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> %c0_i32_1 = arith.constant 0 : i32 @@ -906,7 +906,7 @@ def test_async_tma_blackwell(): ttng.wait_barrier %1, %c0_i32_1, %true_2 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> %c0_i32_3 = arith.constant 0 : i32 - ttng.async_tma_scatter %arg0[%2, %c0_i32_3] %0 : !tt.tensordesc>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + ttng.async_tma_scatter %arg0[%2, %c0_i32_3] %0 : !tt.tensordesc<1x128xf16, #shared>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> ttng.async_tma_store_wait {pendings = 0 : i32} tt.return } @@ -3423,12 +3423,12 @@ def test_amd_tdm_load(target): %c128_i32 = arith.constant 128 : i32 %c128_i64 = arith.constant 128 : i64 %c1_i64 = arith.constant 1 : i64 - %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : , <16x64xf16, #shared> %1 = ttg.local_alloc : () -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> %c0_i32 = arith.constant 0 : i32 %c2_i32 = arith.constant 2 : i32 %c1_i32 = arith.constant 1 : i32 - %2 = amdg.async_tdm_copy_global_to_local %0[%c0_i32, %c2_i32] into %1, pred = %c1_i32 : !tt.tensordesc> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> + %2 = amdg.async_tdm_copy_global_to_local %0[%c0_i32, %c2_i32] into %1, pred = %c1_i32 : !tt.tensordesc<16x64xf16, #shared> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> %3 = amdg.async_tdm_wait {num = 0 : i32} %4 = ttg.local_load %1 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> tensor<16x64xf16, #blocked> tt.return @@ -3459,12 +3459,12 @@ def test_amd_host_tdm_load(target): #shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [16, 64]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @amd_host_tdm_load_kernel(%arg0: !tt.tensordesc>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) attributes {noinline = false} { + tt.func public @amd_host_tdm_load_kernel(%arg0: !tt.tensordesc<16x64xf16, #shared>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) attributes {noinline = false} { %0 = ttg.local_alloc : () -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> %c0_i32 = arith.constant 0 : i32 %c2_i32 = arith.constant 2 : i32 %c1_i32 = arith.constant 1 : i32 - %1 = amdg.async_tdm_copy_global_to_local %arg0[%c0_i32, %c2_i32] into %0, pred = %c1_i32 : !tt.tensordesc> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %arg0[%c0_i32, %c2_i32] into %0, pred = %c1_i32 : !tt.tensordesc<16x64xf16, #shared> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> %2 = amdg.async_tdm_wait {num = 0 : i32} %3 = ttg.local_load %0 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> tensor<16x64xf16, #blocked> tt.return @@ -3504,13 +3504,13 @@ def test_amd_tdm_store(target): %c128_i32 = arith.constant 128 : i32 %c128_i64 = arith.constant 128 : i64 %c1_i64 = arith.constant 1 : i64 - %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : , <16x64xf16, #shared> %cst = arith.constant 1.000000e+00 : f16 %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x64xf16, #blocked> %1 = ttg.local_alloc %cst_0 : (tensor<16x64xf16, #blocked>) -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> %c0_i32 = arith.constant 0 : i32 %c2_i32 = arith.constant 2 : i32 - amdg.async_tdm_copy_local_to_global %0[%c0_i32, %c2_i32] from %1 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_copy_local_to_global %0[%c0_i32, %c2_i32] from %1 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<16x64xf16, #shared> %2 = amdg.async_tdm_wait {num = 0 : i32} tt.return } @@ -3553,12 +3553,12 @@ def test_amd_tdm_gather(target): %c128_i32 = arith.constant 128 : i32 %c128_i64 = arith.constant 128 : i64 %c1_i64 = arith.constant 1 : i64 - %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : , <16x64xf16, #shared> %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> %2 = ttg.local_alloc : () -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 - %3 = amdg.async_tdm_gather %0[%1, %c0_i32] to %2, pred = %c1_i32 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> !tt.tensordesc> + %3 = amdg.async_tdm_gather %0[%1, %c0_i32] to %2, pred = %c1_i32 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<16x64xf16, #shared> %4 = amdg.async_tdm_wait {num = 0 : i32} %5 = ttg.local_load %2 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> tensor<16x64xf16, #blocked> tt.return @@ -3602,13 +3602,13 @@ def test_amd_tdm_scatter(target): %c128_i32 = arith.constant 128 : i32 %c128_i64 = arith.constant 128 : i64 %c1_i64 = arith.constant 1 : i64 - %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : , <16x64xf16, #shared> %cst = arith.constant 1.000000e+00 : f16 %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x64xf16, #blocked> %1 = ttg.local_alloc %cst_0 : (tensor<16x64xf16, #blocked>) -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> %2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> %c0_i32 = arith.constant 0 : i32 - %3 = amdg.async_tdm_scatter %0[%2, %c0_i32] from %1 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> !tt.tensordesc> + %3 = amdg.async_tdm_scatter %0[%2, %c0_i32] from %1 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<16x64xf16, #shared> %4 = amdg.async_tdm_wait {num = 0 : i32} tt.return } @@ -3641,16 +3641,16 @@ def test_amd_tdm_load_pred(target): %c64_i32_0 = arith.constant 64 : i32 %c64_i64 = arith.constant 64 : i64 %c1_i64 = arith.constant 1 : i64 - %0 = tt.make_tensor_descriptor %arg0, [%c64_i32, %c64_i32_0], [%c64_i64, %c1_i64] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c64_i32, %c64_i32_0], [%c64_i64, %c1_i64] : , <64x64xf16, #shared> %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> %c0_i32 = arith.constant 0 : i32 %c2_i32 = arith.constant 2 : i32 %c0_i32_1 = arith.constant 0 : i32 - %2 = amdg.async_tdm_copy_global_to_local %0[%c0_i32, %c2_i32] into %1, pred = %c0_i32_1 : !tt.tensordesc> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> + %2 = amdg.async_tdm_copy_global_to_local %0[%c0_i32, %c2_i32] into %1, pred = %c0_i32_1 : !tt.tensordesc<64x64xf16, #shared> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> %c0_i32_2 = arith.constant 0 : i32 %c2_i32_3 = arith.constant 2 : i32 %c1_i32 = arith.constant 1 : i32 - %3 = amdg.async_tdm_copy_global_to_local %0[%c0_i32_2, %c2_i32_3] into %1, pred = %c1_i32 : !tt.tensordesc> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> + %3 = amdg.async_tdm_copy_global_to_local %0[%c0_i32_2, %c2_i32_3] into %1, pred = %c1_i32 : !tt.tensordesc<64x64xf16, #shared> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> tt.return } } @@ -3778,14 +3778,14 @@ def test_amd_tdm_load_mbarrier(target): %c128_i32 = arith.constant 128 : i32 %c128_i64 = arith.constant 128 : i64 %c1_i64 = arith.constant 1 : i64 - %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : , <16x64xf16, #shared> %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> %2 = ttg.local_alloc : () -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> amdg.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> %c0_i32 = arith.constant 0 : i32 %c2_i32 = arith.constant 2 : i32 %c1_i32 = arith.constant 1 : i32 - %3 = amdg.async_tdm_copy_global_to_local %0[%c0_i32, %c2_i32] into %2, pred = %c1_i32, barrier = %1 : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> + %3 = amdg.async_tdm_copy_global_to_local %0[%c0_i32, %c2_i32] into %2, pred = %c1_i32, barrier = %1 : !tt.tensordesc<16x64xf16, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> %4 = ttg.local_load %2 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> tensor<16x64xf16, #blocked> tt.return } @@ -3826,7 +3826,7 @@ def nv_tma_descriptor_load_kernel(input_ptr): %c128_i32_0 = arith.constant 128 : i32 %c128_i64 = arith.constant 128 : i64 %c1_i64 = arith.constant 1 : i64 - %0 = tt.make_tensor_descriptor %arg0, [%c128_i32, %c128_i32_0], [%c128_i64, %c1_i64] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c128_i32, %c128_i32_0], [%c128_i64, %c1_i64] : , <128x128xf32, #shared> %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable> %2 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %2, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -3835,7 +3835,7 @@ def nv_tma_descriptor_load_kernel(input_ptr): %c0_i32 = arith.constant 0 : i32 %c0_i32_1 = arith.constant 0 : i32 %true_2 = arith.constant true - ttng.async_tma_copy_global_to_local %0[%c0_i32, %c0_i32_1] %1, %2, %true_2 : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %0[%c0_i32, %c0_i32_1] %1, %2, %true_2 : !tt.tensordesc<128x128xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable> tt.return } } @@ -3872,11 +3872,11 @@ def nv_tma_descriptor_store_kernel(input_ptr): %c128_i32_0 = arith.constant 128 : i32 %c128_i64 = arith.constant 128 : i64 %c1_i64 = arith.constant 1 : i64 - %0 = tt.make_tensor_descriptor %arg0, [%c128_i32, %c128_i32_0], [%c128_i64, %c1_i64] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c128_i32, %c128_i32_0], [%c128_i64, %c1_i64] : , <128x128xf32, #shared> %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable> %c0_i32 = arith.constant 0 : i32 %c0_i32_1 = arith.constant 0 : i32 - ttng.async_tma_copy_local_to_global %0[%c0_i32, %c0_i32_1] %1 : !tt.tensordesc>, !ttg.memdesc<128x128xf32, #shared, #smem, mutable> + ttng.async_tma_copy_local_to_global %0[%c0_i32, %c0_i32_1] %1 : !tt.tensordesc<128x128xf32, #shared>, !ttg.memdesc<128x128xf32, #shared, #smem, mutable> ttng.async_tma_store_wait {pendings = 0 : i32} tt.return } diff --git a/python/test/unit/cuda/test_tensor_descriptor_cuda.py b/python/test/unit/cuda/test_tensor_descriptor_cuda.py index 0b7c9a9fc7c0..c41d5fef439e 100644 --- a/python/test/unit/cuda/test_tensor_descriptor_cuda.py +++ b/python/test/unit/cuda/test_tensor_descriptor_cuda.py @@ -16,5 +16,5 @@ def kernel(a, b): A = torch.randn(1024, device=device) desc = TensorDescriptor.from_tensor(A, [128]) h = kernel.warmup(desc, 16, grid=(1, )) - assert "%a: !tt.tensordesc>" in h.asm["ttir"] + assert "%a: !tt.tensordesc<128xf32>" in h.asm["ttir"] assert "%b: i32 {tt.divisibility = 16 : i32}" in h.asm["ttir"] diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py index 1899c2bd5128..6ce3ae82b5a3 100644 --- a/python/test/unit/language/test_line_info.py +++ b/python/test/unit/language/test_line_info.py @@ -318,7 +318,7 @@ def kernel_basic(src, N, BLOCK_SIZE: tl.constexpr): @triton.jit def kernel_tensordesc_param(foo): # CHECK-LABEL: tt.func public @kernel_tensordesc_param - # CHECK-SAME: %foo: !tt.tensordesc> + # CHECK-SAME: %foo: !tt.tensordesc<32x64xf16> # CHECK-SAME: %foo.shape.0: i32 # CHECK-SAME: %foo.shape.1: i32 # CHECK-SAME: %foo.stride.0: i64 diff --git a/test/Analysis/test-membar-ttng.mlir b/test/Analysis/test-membar-ttng.mlir index f3d5fbaf2044..17ab5b0c16f7 100644 --- a/test/Analysis/test-membar-ttng.mlir +++ b/test/Analysis/test-membar-ttng.mlir @@ -25,7 +25,7 @@ tt.func @async_store_wait(%arg: tensor<32x16xf16, #AL>) { module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 18944 : i32} { // CHECK-LABEL: tma_special_cases -tt.func @tma_special_cases(%arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>) -> (tensor<256x64xf16, #blocked>){ +tt.func @tma_special_cases(%arg1: !tt.tensordesc<256x64xf16, #shared>, %arg2: !tt.tensordesc<1x64xf16, #shared>) -> (tensor<256x64xf16, #blocked>){ %true = arith.constant 1 : i1 %cx = arith.constant dense<1> : tensor<32xi32> %c0 = arith.constant 0 : i32 @@ -41,7 +41,7 @@ tt.func @tma_special_cases(%arg1: !tt.tensordesc>, % // CHECK-NEXT: ttng.async_tma_copy_global_to_local // CHECK-NEXT: ttng.wait_barrier ttng.barrier_expect %barrier, 49152, %true : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> - ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.tensordesc<256x64xf16, #shared>, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> // CHECK-NEXT: ttg.barrier local @@ -50,7 +50,7 @@ tt.func @tma_special_cases(%arg1: !tt.tensordesc>, % // CHECK-NEXT: ttng.barrier_expect // CHECK-NEXT: ttg.barrier local // CHECK-NEXT: ttng.wait_barrier - ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.tensordesc<256x64xf16, #shared>, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttng.barrier_expect %barrier, 49152, %true : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> @@ -63,7 +63,7 @@ tt.func @tma_special_cases(%arg1: !tt.tensordesc>, % // CHECK-NEXT: ttng.async_tma_copy_global_to_local // CHECK-NEXT: ttng.wait_barrier ttng.barrier_expect %barrier, 49152, %true : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> - ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.tensordesc<256x64xf16, #shared>, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> // CHECK-NEXT: memdesc_subslice @@ -74,7 +74,7 @@ tt.func @tma_special_cases(%arg1: !tt.tensordesc>, % // CHECK-NEXT: ttng.wait_barrier %view = ttg.memdesc_subslice %alloc [0, 0] : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x64xf16, #shared, #ttg.shared_memory, mutable> ttng.barrier_expect %barrier, 49152, %true : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> - ttng.async_tma_gather %arg2[%cx, %c0] %view, %barrier, %true : !tt.tensordesc>, tensor<32xi32>, i32, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<32x64xf16, #shared, #ttg.shared_memory, mutable>, i1 + ttng.async_tma_gather %arg2[%cx, %c0] %view, %barrier, %true : !tt.tensordesc<1x64xf16, #shared>, tensor<32xi32>, i32, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<32x64xf16, #shared, #ttg.shared_memory, mutable>, i1 ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> // CHECK-NEXT: ttg.barrier local @@ -95,7 +95,7 @@ tt.func @tma_special_cases(%arg1: !tt.tensordesc>, % module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 18944 : i32} { // CHECK-LABEL: tma_special_cases_cf -tt.func @tma_special_cases_cf(%arg1: !tt.tensordesc>, %i1 : i1, %arg2: tensor<256x64xf16, #blocked>) -> (tensor<256x64xf16, #blocked>){ +tt.func @tma_special_cases_cf(%arg1: !tt.tensordesc<256x64xf16, #shared>, %i1 : i1, %arg2: tensor<256x64xf16, #blocked>) -> (tensor<256x64xf16, #blocked>){ %true = arith.constant 1 : i1 %c0 = arith.constant 0 : i32 %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> @@ -111,7 +111,7 @@ tt.func @tma_special_cases_cf(%arg1: !tt.tensordesc> // CHECK-NEXT: ttng.wait_barrier // CF-NEXT: cf.br // SCF-NEXT: } else { - ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.tensordesc<256x64xf16, #shared>, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttng.barrier_expect %barrier, 49152, %true : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> } else { diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index beb6e832302d..ebd26997d4f8 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -1214,7 +1214,7 @@ tt.func @loop_memindex_subslice(%arg0: tensor<2x128x128xf16>) { module attributes {ttg.target = "cuda:90", "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: warp_dot_multi_read - tt.func @warp_dot_multi_read(%arg0: !tt.tensordesc>, %arg1: tensor<128x128x!tt.ptr>, %arg2: i32, %arg3: i1, %arg4: tensor<128x256xf32, #mma>, %arg5: tensor<128x128xi1>) { + tt.func @warp_dot_multi_read(%arg0: !tt.tensordesc<1x256x128xf8E5M2, #shared1>, %arg1: tensor<128x128x!tt.ptr>, %arg2: i32, %arg3: i1, %arg4: tensor<128x256xf32, #mma>, %arg5: tensor<128x128xi1>) { %a_tile = ttg.local_alloc : () -> !ttg.memdesc<128x128xf8E5M2, #shared1, #smem, mutable> %b_tile = ttg.local_alloc : () -> !ttg.memdesc<256x128xf8E5M2, #shared1, #smem, mutable> @@ -1228,7 +1228,7 @@ module attributes {ttg.target = "cuda:90", "ttg.num-warps" = 8 : i32} { // CHECK: ttg.barrier local // CHECK-NEXT: ttg.async_copy_global_to_local ttg.async_copy_global_to_local %arg1, %a_tile mask %arg5 {contiguity = 16 : i32} : tensor<128x128x!tt.ptr> -> <128x128xf8E5M2, #shared1, #smem, mutable> - ttng.async_tma_copy_global_to_local %arg0[%arg2, %arg2, %arg2] %b_tile, %barrier, %arg3 : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<256x128xf8E5M2, #shared1, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%arg2, %arg2, %arg2] %b_tile, %barrier, %arg3 : !tt.tensordesc<1x256x128xf8E5M2, #shared1>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<256x128xf8E5M2, #shared1, #smem, mutable> tt.return } } diff --git a/test/Conversion/amd/allocate_shared_memory.mlir b/test/Conversion/amd/allocate_shared_memory.mlir index a02f2beb8919..992a1f07850f 100644 --- a/test/Conversion/amd/allocate_shared_memory.mlir +++ b/test/Conversion/amd/allocate_shared_memory.mlir @@ -34,14 +34,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-LABEL: @ws_tensordesc_2d_capture // CHECK: allocation.offset = 48 : i32 -tt.func @ws_tensordesc_2d_capture(%desc: !tt.tensordesc>) { +tt.func @ws_tensordesc_2d_capture(%desc: !tt.tensordesc<64x64xf16>) { ttg.warp_specialize(%desc) attributes {warpGroupStartIds = array} default { ttg.warp_yield } - partition0(%arg0: !tt.tensordesc>) num_warps(4) { + partition0(%arg0: !tt.tensordesc<64x64xf16>) num_warps(4) { ttg.warp_return - } : (!tt.tensordesc>) -> () + } : (!tt.tensordesc<64x64xf16>) -> () tt.return } @@ -57,14 +57,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-LABEL: @ws_tensordesc_5d_capture // CHECK: allocation.offset = 80 : i32 -tt.func @ws_tensordesc_5d_capture(%desc: !tt.tensordesc>) { +tt.func @ws_tensordesc_5d_capture(%desc: !tt.tensordesc<8x8x8x16x16xf16>) { ttg.warp_specialize(%desc) attributes {warpGroupStartIds = array} default { ttg.warp_yield } - partition0(%arg0: !tt.tensordesc>) num_warps(4) { + partition0(%arg0: !tt.tensordesc<8x8x8x16x16xf16>) num_warps(4) { ttg.warp_return - } : (!tt.tensordesc>) -> () + } : (!tt.tensordesc<8x8x8x16x16xf16>) -> () tt.return } diff --git a/test/Conversion/amd/tritongpu_tdm_stride_order.mlir b/test/Conversion/amd/tritongpu_tdm_stride_order.mlir index b77aa125acba..7e352951bd15 100644 --- a/test/Conversion/amd/tritongpu_tdm_stride_order.mlir +++ b/test/Conversion/amd/tritongpu_tdm_stride_order.mlir @@ -8,7 +8,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %c_stride1 = arith.constant 1 : i64 // expected-error @+2 {{requires at least one dimension to have stride 1}} // expected-error @+1 {{failed to legalize operation}} - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%stride0, %stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%stride0, %stride1] : , <64x64xf16, #shared> tt.return } } @@ -22,7 +22,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %c_stride1 = arith.constant 1 : i64 // expected-error @+2 {{requires at least one dimension to have stride 1}} // expected-error @+1 {{failed to legalize operation}} - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%stride1, %stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%stride1, %stride1] : , <64x64xf16, #shared> tt.return } } @@ -36,7 +36,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %c_stride1 = arith.constant 1 : i64 // expected-error @+2 {{requires shared order [rank-2, rank-1, rank-3, rank-4, ..., 0] because dim[rank-2] has stride 1}} // expected-error @+1 {{failed to legalize operation}} - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride1, %runtime_stride] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride1, %runtime_stride] : , <64x64xf16, #shared> tt.return } } @@ -50,7 +50,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %c_stride1 = arith.constant 1 : i64 // expected-error @+2 {{requires shared order [rank-1, rank-2, ..., 0]}} // expected-error @+1 {{failed to legalize operation}} - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%runtime_stride, %c_stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%runtime_stride, %c_stride1] : , <64x64xf16, #shared> tt.return } } @@ -64,7 +64,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr tt.func public @tdm_1x1x1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { %c_stride1 = arith.constant 1 : i64 %c_shape = arith.constant 1 : i32 - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape, %c_shape], [%c_stride1, %c_stride1, %c_stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape, %c_shape], [%c_stride1, %c_stride1, %c_stride1] : , <1x1x1xf16, #shared> tt.return } } @@ -79,7 +79,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %c_stride1 = arith.constant 1 : i64 %c_shape = arith.constant 1 : i32 %c_shape2 = arith.constant 128 : i32 - %0 = tt.make_tensor_descriptor %arg0, [%c_shape2, %c_shape, %c_shape], [%runtime_stride, %c_stride1, %c_stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape2, %c_shape, %c_shape], [%runtime_stride, %c_stride1, %c_stride1] : , <1x1x1xf16, #shared> tt.return } } @@ -94,7 +94,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %c_shape2 = arith.constant 128 : i32 // expected-error @+2 {{requires all stride 1 dimensions to be consecutive starting from the last dimension}} // expected-error @+1 {{failed to legalize operation}} - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape2, %c_shape], [%c_stride1, %runtime_stride, %c_stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape2, %c_shape], [%c_stride1, %runtime_stride, %c_stride1] : , <1x1x1xf16, #shared> tt.return } } @@ -109,7 +109,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %c_shape2 = arith.constant 128 : i32 // expected-error @+2 {{requires all stride 1 dimensions to be consecutive starting from the last dimension}} // expected-error @+1 {{failed to legalize operation}} - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape, %c_shape2], [%c_stride1, %c_stride1, %runtime_stride] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape, %c_shape2], [%c_stride1, %c_stride1, %runtime_stride] : , <1x1x1xf16, #shared> tt.return } } diff --git a/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir b/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir index 923aa63bc0e8..90fb578ec179 100644 --- a/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir @@ -11,12 +11,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %c_stride1 = arith.constant 1 : i64 %c_offset = arith.constant 0 : i32 %c_pred = arith.constant 1 : i32 - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , <64x64xf16, #shared> %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> // CHECK-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32> // CHECK-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32> // CHECK: "llvm.amdgcn.tensor.load.to.lds"({{.+}}) : (vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>, vector<8xi32>, i32) -> () - %2 = amdg.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, pred = %c_pred : !tt.tensordesc> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> + %2 = amdg.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, pred = %c_pred : !tt.tensordesc<64x64xf16, #shared> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> // CHECK: rocdl.s.wait.tensorcnt 0 %3 = amdg.async_tdm_intrinsic_wait {count = 0 : i32} %4 = ttg.local_load %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked> @@ -36,14 +36,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %c_stride0 = arith.constant 128 : i64 %c_stride1 = arith.constant 1 : i64 %c_offset = arith.constant 0 : i32 - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , <64x64xf16, #shared> %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> %2 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #blocked> ttg.local_store %2, %1 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> // CHECK-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32> // CHECK-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32> // CHECK: "llvm.amdgcn.tensor.store.from.lds"({{.+}}) : (vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>, vector<8xi32>, i32) -> () - amdg.async_tdm_copy_local_to_global %0[%c_offset, %c_offset] from %1: !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_copy_local_to_global %0[%c_offset, %c_offset] from %1: !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x64xf16, #shared> // CHECK: rocdl.s.wait.tensorcnt 0 %3 = amdg.async_tdm_intrinsic_wait {count = 0 : i32} tt.return @@ -72,11 +72,11 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // CHECK: %[[OFFSET_DIM1:.*]] = llvm.mul{{.*}}%[[STRIDE1]] // CHECK: %[[TOTAL_OFFSET:.*]] = llvm.add %[[OFFSET_TMP1]], %[[OFFSET_DIM1]] // CHECK: %[[ADJUSTED_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[TOTAL_OFFSET]]] - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , <64x64xf16, #shared> %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> // CHECK: "llvm.amdgcn.tensor.load.to.lds"({{.+}}) : (vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>, vector<8xi32>, i32) -> () - %2 = amdg.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, pred = %c_pred : !tt.tensordesc> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> + %2 = amdg.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, pred = %c_pred : !tt.tensordesc<64x64xf16, #shared> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> tt.return } } @@ -104,10 +104,10 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // CHECK: %[[OFFSET_DIM1:.*]] = llvm.mul{{.*}}%[[STRIDE1]] // CHECK: %[[TOTAL_OFFSET:.*]] = llvm.add %[[OFFSET_TMP1]], %[[OFFSET_DIM1]] // CHECK: %[[ADJUSTED_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[TOTAL_OFFSET]]] - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , <64x64xf16, #shared> %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> // CHECK: "llvm.amdgcn.tensor.store.from.lds"({{.+}}) : (vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>, vector<8xi32>, i32) -> () - amdg.async_tdm_copy_local_to_global %0[%c_offset, %c_offset] from %1: !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_copy_local_to_global %0[%c_offset, %c_offset] from %1: !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x64xf16, #shared> tt.return } } @@ -136,12 +136,12 @@ module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, "ttg.th // CHECK: %[[TMP2:.*]] = llvm.and %[[TMP]] // CHECK-NOT: llvm.insertelement{{.*}} : vector<8xi32> // CHECK: llvm.insertelement %[[TMP2]] - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , <64x64xf16, #shared> %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> // CHECK: "llvm.amdgcn.tensor.load.to.lds"({{.+}}) : (vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>, vector<8xi32>, i32) -> () - %2 = amdg.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, pred = %c_pred : !tt.tensordesc> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> + %2 = amdg.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, pred = %c_pred : !tt.tensordesc<64x64xf16, #shared> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> tt.return } } @@ -158,13 +158,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %c_stride1 = arith.constant 1 : i64 %c_offset = arith.constant 0 : i32 %c_pred = arith.constant true - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , <64x64xf16, #shared> // CHECK: rocdl.global.prefetch %{{.*}} scope 8 - amdg.tdm_prefetch %0[%c_offset, %c_offset], %c_pred, speculative = false : !tt.tensordesc> + amdg.tdm_prefetch %0[%c_offset, %c_offset], %c_pred, speculative = false : !tt.tensordesc<64x64xf16, #shared> // CHECK: rocdl.global.prefetch %{{.*}} scope 9 - amdg.tdm_prefetch %0[%c_offset, %c_offset], %c_pred, speculative = true : !tt.tensordesc> + amdg.tdm_prefetch %0[%c_offset, %c_offset], %c_pred, speculative = true : !tt.tensordesc<64x64xf16, #shared> tt.return } } @@ -187,11 +187,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tdm_2d_with_padding tt.func public @tdm_2d_with_padding( - %tensorDesc: !tt.tensordesc>, + %tensorDesc: !tt.tensordesc<128x64xf16>, %memDesc: !ttg.memdesc<128x64xf16, #shared, #smem, mutable> ) { %c0_i32 = arith.constant 0 : i32 - amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<128x64xf16> // CHECK: "llvm.amdgcn.tensor.store.from.lds"({{.+}}) : (vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>, vector<8xi32>, i32) -> () tt.return } @@ -204,11 +204,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tdm_5d_with_padding tt.func public @tdm_5d_with_padding( - %tensorDesc: !tt.tensordesc>, + %tensorDesc: !tt.tensordesc<8x8x8x16x16xf16>, %memDesc: !ttg.memdesc<8x8x8x16x16xf16, #shared_5d, #smem_5d, mutable> ) { %c0_i32 = arith.constant 0 : i32 - amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<8x8x8x16x16xf16, #shared_5d, #smem_5d, mutable> -> !tt.tensordesc> + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<8x8x8x16x16xf16, #shared_5d, #smem_5d, mutable> -> !tt.tensordesc<8x8x8x16x16xf16> // CHECK: "llvm.amdgcn.tensor.store.from.lds"({{.+}}) : (vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>, vector<8xi32>, i32) -> () tt.return } diff --git a/test/Conversion/relayout_tritongpu.mlir b/test/Conversion/relayout_tritongpu.mlir index 0dda73163076..08f0771643ab 100644 --- a/test/Conversion/relayout_tritongpu.mlir +++ b/test/Conversion/relayout_tritongpu.mlir @@ -50,21 +50,21 @@ tt.func @tmem_scales_layout() { // CHECK: [[SLICE_PARENT:#.*]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> // CHECK: @async_tma_gather -tt.func @async_tma_gather(%desc: !tt.tensordesc>, %y_offset: i32, +tt.func @async_tma_gather(%desc: !tt.tensordesc<1x128xbf16, #shared>, %y_offset: i32, %bar: !ttg.memdesc<1xi64, #bar_layout, #ttg.shared_memory, mutable>, %result: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, %pred: i1) { %x_offsets = arith.constant dense<1> : tensor<32xi32> // CHECK: [[IDX:%.*]] = ttg.convert_layout %cst : tensor<32xi32, #{{.*}}> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = [[SLICE_PARENT]]}>> - ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.tensordesc>, tensor<32xi32>, i32, !ttg.memdesc<1xi64, #bar_layout, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, i1 + ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32>, i32, !ttg.memdesc<1xi64, #bar_layout, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, i1 tt.return } // CHECK: @async_tma_scatter -tt.func @async_tma_scatter(%desc: !tt.tensordesc>, %y_offset: i32, +tt.func @async_tma_scatter(%desc: !tt.tensordesc<1x128xbf16, #shared>, %y_offset: i32, %src: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>) { %x_offsets = arith.constant dense<1> : tensor<32xi32> // CHECK: [[IDX:%.*]] = ttg.convert_layout %cst : tensor<32xi32, #{{.*}}> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = [[SLICE_PARENT]]}>> - ttng.async_tma_scatter %desc[%x_offsets, %y_offset] %src : !tt.tensordesc>, tensor<32xi32>, i32, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable> + ttng.async_tma_scatter %desc[%x_offsets, %y_offset] %src : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32>, i32, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable> tt.return } diff --git a/test/Conversion/tma_to_llvm.mlir b/test/Conversion/tma_to_llvm.mlir index 34fd0b7ed1af..b522397f784e 100644 --- a/test/Conversion/tma_to_llvm.mlir +++ b/test/Conversion/tma_to_llvm.mlir @@ -12,7 +12,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-LABEL: @tma_gather_simple // CHECK-SAME: i32 [[Y0:%3]] -tt.func @tma_gather_simple(%arg0: !tt.tensordesc>, %arg1: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %arg2: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, %arg5: i1) { +tt.func @tma_gather_simple(%arg0: !tt.tensordesc<1x128xbf16, #shared1>, %arg1: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %arg2: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, %arg5: i1) { // There are 32 indices distributed to 4 warps, so each warp as 8 indices. // CHECK: [[BAR:%.*]] = extractvalue {{.*}} %1, 0 @@ -80,14 +80,14 @@ tt.func @tma_gather_simple(%arg0: !tt.tensordesc>, // CHECK: [[BASEPTR3:%.*]] = getelementptr i8, ptr addrspace(3) [[BASEPTR0]], i64 6144 // CHECK: cp.async.bulk.tensor.2d.tile::gather4 // CHECK-SAME: (i1 [[PRED]], ptr addrspace(3) [[BASEPTR3]], ptr nonnull %0, i32 [[Y1]], i32 [[IDX4]], i32 [[IDX5]], i32 [[IDX6]], i32 [[IDX7]], ptr addrspace(3) [[BAR]]) - ttng.async_tma_gather %arg0[%arg2, %arg3] %arg4, %arg1, %arg5 : !tt.tensordesc>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1 + ttng.async_tma_gather %arg0[%arg2, %arg3] %arg4, %arg1, %arg5 : !tt.tensordesc<1x128xbf16, #shared1>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1 // CHECK-NEXT: ret void tt.return } // CHECK-LABEL: @tma_gather_8_consecutive_indices -tt.func @tma_gather_8_consecutive_indices(%arg0: !tt.tensordesc>, %arg1: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %arg2: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, %arg5: i1) { +tt.func @tma_gather_8_consecutive_indices(%arg0: !tt.tensordesc<1x128xbf16, #shared1>, %arg1: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %arg2: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, %arg5: i1) { // Due to the `sizePerThread = [1, 8]`, each warp now handles 8 consecutive // rows, where each row is divided into 2 segments for a total of 4 gather4s. // @@ -111,26 +111,26 @@ tt.func @tma_gather_8_consecutive_indices(%arg0: !tt.tensordesc>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1 + ttng.async_tma_gather %arg0[%arg2, %arg3] %arg4, %arg1, %arg5 : !tt.tensordesc<1x128xbf16, #shared1>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1 // CHECK-NEXT: ret void tt.return } // CHECK-LABEL: @tma_gather_redundant_indices -tt.func @tma_gather_redundant_indices(%arg0: !tt.tensordesc>, %arg1: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %arg2: tensor<32xi32, #linear>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, %arg5: i1) { +tt.func @tma_gather_redundant_indices(%arg0: !tt.tensordesc<1x128xbf16, #shared1>, %arg1: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %arg2: tensor<32xi32, #linear>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, %arg5: i1) { // Codegen for this case is actually incorrect due to linear layouts // incorrectly handling register broadcasting, but the test outcome is nonetheless // the same. // CHECK-COUNT-4: cp.async.bulk.tensor - ttng.async_tma_gather %arg0[%arg2, %arg3] %arg4, %arg1, %arg5 : !tt.tensordesc>, tensor<32xi32, #linear>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1 + ttng.async_tma_gather %arg0[%arg2, %arg3] %arg4, %arg1, %arg5 : !tt.tensordesc<1x128xbf16, #shared1>, tensor<32xi32, #linear>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1 // CHECK-NEXT: ret void tt.return } // CHECK-LABEL: @tma_gather_redundant_warps -tt.func @tma_gather_redundant_warps(%arg0: !tt.tensordesc>, %arg1: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %arg2: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, %arg5: i1) { +tt.func @tma_gather_redundant_warps(%arg0: !tt.tensordesc<1x128xbf16, #shared1>, %arg1: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %arg2: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, %arg5: i1) { // CHECK: [[WARP_ID:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32 // CHECK: [[WARP_SELECT:%.*]] = and i32 [[WARP_ID]], 2 // CHECK: [[WARP_PRED:%.*]] = icmp eq i32 [[WARP_SELECT]], 0 @@ -140,14 +140,14 @@ tt.func @tma_gather_redundant_warps(%arg0: !tt.tensordesc>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1 + ttng.async_tma_gather %arg0[%arg2, %arg3] %arg4, %arg1, %arg5 : !tt.tensordesc<1x128xbf16, #shared1>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1 // CHECK-NEXT: ret void tt.return } // CHECK-LABEL: @tma_scatter -tt.func @tma_scatter(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, %arg2: i32, %arg3: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>) { +tt.func @tma_scatter(%arg0: !tt.tensordesc<1x128xbf16, #shared1>, %arg1: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, %arg2: i32, %arg3: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>) { // The lowering for `async_tma_scatter` shares practically all of its logic // with `async_tma_gather`, so we don't need to re-test the indexing logic. @@ -158,7 +158,7 @@ tt.func @tma_scatter(%arg0: !tt.tensordesc>, %arg1: // CHECK: [[PTR:%.*]] = getelementptr {{.*}} [[BASE_PTR]] // CHECK-NEXT: "@$0 cp.async.bulk.tensor.2d.tile::scatter4.global.shared::cta.bulk_group [$1, {$2, $3, $4, $5, $6}], [$7];" // CHECK-SAME: (i1 [[PRED]], ptr nonnull %0, i32 %2, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, ptr addrspace(3) [[PTR]]) - ttng.async_tma_scatter %arg0[%arg1, %arg2] %arg3 : !tt.tensordesc>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable> + ttng.async_tma_scatter %arg0[%arg1, %arg2] %arg3 : !tt.tensordesc<1x128xbf16, #shared1>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable> // CHECK: nvvm.cp.async.bulk.commit.group() diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index bfe47653f506..7f47b0c53786 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -129,22 +129,22 @@ tt.func @gather_op() { // CHECK: [[SLICE_PARENT:#.*]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [1, 0]}> // CHECK: @gather4_layout -tt.func @gather4_layout(%arg0: !tt.tensordesc>, %arg1: i32, %arg2: !tt.ptr) { +tt.func @gather4_layout(%arg0: !tt.tensordesc<1x128xf32>, %arg1: i32, %arg2: !tt.ptr) { %cst = arith.constant dense<1> : tensor<32xi32> // CHECK: [[IDX:%.*]] = ttg.convert_layout %cst : tensor<32xi32, #{{.*}}> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = [[SLICE_PARENT]]}>> - %0 = tt.descriptor_gather %arg0[%cst, %arg1] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32x128xf32> + %0 = tt.descriptor_gather %arg0[%cst, %arg1] : (!tt.tensordesc<1x128xf32>, tensor<32xi32>, i32) -> tensor<32x128xf32> %1 = tt.splat %arg2 : !tt.ptr -> tensor<32x128x!tt.ptr> tt.store %1, %0 : tensor<32x128x!tt.ptr> tt.return } // CHECK: @scatter4_layout -tt.func @scatter4_layout(%arg0: !tt.tensordesc>, %arg1: i32, %arg2: !tt.ptr) { +tt.func @scatter4_layout(%arg0: !tt.tensordesc<1x128xf32>, %arg1: i32, %arg2: !tt.ptr) { %cst = arith.constant dense<1> : tensor<32xi32> %0 = tt.splat %arg2 : !tt.ptr -> tensor<32x128x!tt.ptr> %1 = tt.load %0 : tensor<32x128x!tt.ptr> // CHECK: [[IDX:%.*]] = ttg.convert_layout %cst : tensor<32xi32, #{{.*}}> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = [[SLICE_PARENT]]}>> - tt.descriptor_scatter %arg0[%cst, %arg1], %1 : !tt.tensordesc>, tensor<32xi32>, i32, tensor<32x128xf32> + tt.descriptor_scatter %arg0[%cst, %arg1], %1 : !tt.tensordesc<1x128xf32>, tensor<32xi32>, i32, tensor<32x128xf32> tt.return } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 354ec1c6ef99..764be9c30c36 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -2512,10 +2512,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @reinterpret_tensor_descriptor -tt.func private @reinterpret_tensor_descriptor(%arg0: !tt.ptr) -> !tt.tensordesc> { +tt.func private @reinterpret_tensor_descriptor(%arg0: !tt.ptr) -> !tt.tensordesc<128x64xf16, #shared> { // CHECK-NEXT: llvm.addrspacecast %arg0 : !llvm.ptr to !llvm.ptr - %0 = ttng.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc> - tt.return %0 : !tt.tensordesc> + %0 = ttng.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc<128x64xf16, #shared> + tt.return %0 : !tt.tensordesc<128x64xf16, #shared> } } diff --git a/test/Conversion/tritongpu_to_llvm_gsan.mlir b/test/Conversion/tritongpu_to_llvm_gsan.mlir index 2022780ce86d..2fc2179c1d2c 100644 --- a/test/Conversion/tritongpu_to_llvm_gsan.mlir +++ b/test/Conversion/tritongpu_to_llvm_gsan.mlir @@ -42,7 +42,7 @@ module attributes {"ttg.instrumentation_mode" = "gsan", "ttg.num-ctas" = 1 : i32 module attributes {"ttg.instrumentation_mode" = "gsan", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: llvm.func @tma_f16_gsan_merge - tt.func @tma_f16_gsan_merge(%desc: !tt.tensordesc>) { + tt.func @tma_f16_gsan_merge(%desc: !tt.tensordesc<32x64xf16, #shared_f16>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %buf = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x64xf16, #shared_f16, #smem, mutable> @@ -51,7 +51,7 @@ module attributes {"ttg.instrumentation_mode" = "gsan", "ttg.num-ctas" = 1 : i32 // CHECK: %[[COUNT:.*]] = llvm.mlir.constant(32 : i32) : i32 // CHECK: %[[BYTES:.*]] = llvm.mlir.constant(4 : i32) : i32 // CHECK: llvm.call @__triton_gsan_load_tensor(%{{.*}}, %{{.*}}, %[[COUNT]], %[[BYTES]], %{{.*}}, %{{.*}}) - ttng.async_tma_copy_global_to_local %desc[%c0_i32, %c0_i32] %buf, %barrier, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #bar, #smem, mutable> -> !ttg.memdesc<32x64xf16, #shared_f16, #smem, mutable> + ttng.async_tma_copy_global_to_local %desc[%c0_i32, %c0_i32] %buf, %barrier, %true : !tt.tensordesc<32x64xf16, #shared_f16>, !ttg.memdesc<1xi64, #bar, #smem, mutable> -> !ttg.memdesc<32x64xf16, #shared_f16, #smem, mutable> tt.return } } @@ -64,7 +64,7 @@ module attributes {"ttg.instrumentation_mode" = "gsan", "ttg.num-ctas" = 1 : i32 module attributes {"ttg.instrumentation_mode" = "gsan", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: llvm.func @tma_f16_gsan_merge_4warps - tt.func @tma_f16_gsan_merge_4warps(%desc: !tt.tensordesc>) { + tt.func @tma_f16_gsan_merge_4warps(%desc: !tt.tensordesc<128x64xf16, #shared_f16>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %buf = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x64xf16, #shared_f16, #smem, mutable> @@ -73,7 +73,7 @@ module attributes {"ttg.instrumentation_mode" = "gsan", "ttg.num-ctas" = 1 : i32 // CHECK: %[[COUNT_4W:.*]] = llvm.mlir.constant(32 : i32) : i32 // CHECK: %[[BYTES_4W:.*]] = llvm.mlir.constant(4 : i32) : i32 // CHECK: llvm.call @__triton_gsan_load_tensor(%{{.*}}, %{{.*}}, %[[COUNT_4W]], %[[BYTES_4W]], %{{.*}}, %{{.*}}) - ttng.async_tma_copy_global_to_local %desc[%c0_i32, %c0_i32] %buf, %barrier, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #bar, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared_f16, #smem, mutable> + ttng.async_tma_copy_global_to_local %desc[%c0_i32, %c0_i32] %buf, %barrier, %true : !tt.tensordesc<128x64xf16, #shared_f16>, !ttg.memdesc<1xi64, #bar, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared_f16, #smem, mutable> tt.return } } diff --git a/test/Conversion/tritoninstrument_to_llvm.mlir b/test/Conversion/tritoninstrument_to_llvm.mlir index a2156c533fc2..d729958b49eb 100644 --- a/test/Conversion/tritoninstrument_to_llvm.mlir +++ b/test/Conversion/tritoninstrument_to_llvm.mlir @@ -103,9 +103,9 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90"} { // CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 // CHECK: llvm.udiv %{{.*}}, %{{.*}} : i64 tt.func private @experimental_gsan_tensordesc_info( - %desc: !tt.tensordesc> + %desc: !tt.tensordesc<32x32xf32, #shared> ) { - %0:5 = "tti.experimental_gsan_tensordesc_info"(%desc) : (!tt.tensordesc>) -> (!tt.ptr, i64, i64, i64, i64) + %0:5 = "tti.experimental_gsan_tensordesc_info"(%desc) : (!tt.tensordesc<32x32xf32, #shared>) -> (!tt.ptr, i64, i64, i64, i64) tt.return } } diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index 299ff223cd6a..b745b8df87ed 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -139,8 +139,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: "@$0 cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4}], [$5];", "b,r,l,r,r,r" {{.*}} : (i1, !llvm.ptr<3>, !llvm.ptr, i32, i32, !llvm.ptr<3>) -> !llvm.void // CHECK-NOT: cp.async.bulk.tensor.2d.shared // CHECK: return - tt.func @tma_copy_global_to_local(%tma: !tt.tensordesc>, %alloc: !ttg.memdesc<128x128xf32, #shared1, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) { - ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared0, #smem> -> !ttg.memdesc<128x128xf32, #shared1, #smem, mutable> + tt.func @tma_copy_global_to_local(%tma: !tt.tensordesc<128x128xf32, #shared1>, %alloc: !ttg.memdesc<128x128xf32, #shared1, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) { + ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.tensordesc<128x128xf32, #shared1>, !ttg.memdesc<1xi64, #shared0, #smem> -> !ttg.memdesc<128x128xf32, #shared1, #smem, mutable> tt.return } } @@ -155,8 +155,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: tma_copy_barrier_mask_zero // CHECK: cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes // CHECK-NOT: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier - tt.func @tma_copy_barrier_mask_zero(%tma: !tt.tensordesc>, %alloc: !ttg.memdesc<128x128xf32, #shared1, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0_cta, #smem>, %pred: i1) { - ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared0_cta, #smem> -> !ttg.memdesc<128x128xf32, #shared1, #smem, mutable> + tt.func @tma_copy_barrier_mask_zero(%tma: !tt.tensordesc<128x128xf32, #shared1>, %alloc: !ttg.memdesc<128x128xf32, #shared1, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0_cta, #smem>, %pred: i1) { + ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.tensordesc<128x128xf32, #shared1>, !ttg.memdesc<1xi64, #shared0_cta, #smem> -> !ttg.memdesc<128x128xf32, #shared1, #smem, mutable> tt.return } } @@ -176,8 +176,8 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} { // TMA uses shared::cluster when barrier mask is non-zero // CHECK: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes // CHECK-NOT: cp.async.bulk.tensor.2d.shared::cta.global.mbarrier - tt.func @tma_copy_barrier_mask_nonzero(%tma: !tt.tensordesc>, %alloc: !ttg.memdesc<128x128xf32, #shared1_cga, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0_cluster, #smem>, %pred: i1) { - ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared0_cluster, #smem> -> !ttg.memdesc<128x128xf32, #shared1_cga, #smem, mutable> + tt.func @tma_copy_barrier_mask_nonzero(%tma: !tt.tensordesc<128x128xf32, #shared1_cga>, %alloc: !ttg.memdesc<128x128xf32, #shared1_cga, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0_cluster, #smem>, %pred: i1) { + ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.tensordesc<128x128xf32, #shared1_cga>, !ttg.memdesc<1xi64, #shared0_cluster, #smem> -> !ttg.memdesc<128x128xf32, #shared1_cga, #smem, mutable> tt.return } } @@ -193,10 +193,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: cp.async.bulk.tensor.4d.shared::cta.global.im2col.mbarrier::complete_tx::bytes // CHECK-NOT: cp.async.bulk.tensor.4d.shared::cta.global.mbarrier // CHECK: return - tt.func @tma_copy_global_to_local_im2col(%tma: !ttng.tensordesc_im2col>, %alloc: !ttg.memdesc<16x64xf32, #shared1, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) { + tt.func @tma_copy_global_to_local_im2col(%tma: !ttng.tensordesc_im2col<16x64xf32, #shared1>, %alloc: !ttg.memdesc<16x64xf32, #shared1, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) { %off_w = arith.constant 1 : i16 %off_h = arith.constant 2 : i16 - ttng.async_tma_copy_global_to_local %tma[%x, %x, %x, %x] offsets = [%off_w, %off_h] %alloc, %barrier, %pred : !ttng.tensordesc_im2col>, !ttg.memdesc<1xi64, #shared0, #smem> -> !ttg.memdesc<16x64xf32, #shared1, #smem, mutable> + ttng.async_tma_copy_global_to_local %tma[%x, %x, %x, %x] offsets = [%off_w, %off_h] %alloc, %barrier, %pred : !ttng.tensordesc_im2col<16x64xf32, #shared1>, !ttg.memdesc<1xi64, #shared0, #smem> -> !ttg.memdesc<16x64xf32, #shared1, #smem, mutable> tt.return } } @@ -228,10 +228,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK: llvm.mlir.constant(3 : i32) // CHECK: cp.async.bulk.tensor.4d.shared::cta.global.im2col.mbarrier::complete_tx::bytes // CHECK: return - tt.func @tma_copy_global_to_local_im2col_multi_msg(%tma: !ttng.tensordesc_im2col>, %alloc: !ttg.memdesc<64x1024xf32, #shared2, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) { + tt.func @tma_copy_global_to_local_im2col_multi_msg(%tma: !ttng.tensordesc_im2col<64x1024xf32, #shared2>, %alloc: !ttg.memdesc<64x1024xf32, #shared2, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) { %off_w = arith.constant 1 : i16 %off_h = arith.constant 2 : i16 - ttng.async_tma_copy_global_to_local %tma[%x, %x, %x, %x] offsets = [%off_w, %off_h] %alloc, %barrier, %pred : !ttng.tensordesc_im2col>, !ttg.memdesc<1xi64, #shared0, #smem> -> !ttg.memdesc<64x1024xf32, #shared2, #smem, mutable> + ttng.async_tma_copy_global_to_local %tma[%x, %x, %x, %x] offsets = [%off_w, %off_h] %alloc, %barrier, %pred : !ttng.tensordesc_im2col<64x1024xf32, #shared2>, !ttg.memdesc<1xi64, #shared0, #smem> -> !ttg.memdesc<64x1024xf32, #shared2, #smem, mutable> tt.return } } @@ -263,10 +263,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK: llvm.mlir.constant(3 : i32) // CHECK: cp.async.bulk.tensor.4d.shared::cta.global.im2col.mbarrier::complete_tx::bytes // CHECK: return - tt.func @tma_copy_global_to_local_im2col_multi_msg_swizzle(%tma: !ttng.tensordesc_im2col>, %alloc: !ttg.memdesc<64x256xf16, #shared_swz, #smem_swz, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0_swz, #smem_swz>, %pred: i1) { + tt.func @tma_copy_global_to_local_im2col_multi_msg_swizzle(%tma: !ttng.tensordesc_im2col<64x256xf16, #shared_swz>, %alloc: !ttg.memdesc<64x256xf16, #shared_swz, #smem_swz, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0_swz, #smem_swz>, %pred: i1) { %off_w = arith.constant 1 : i16 %off_h = arith.constant 2 : i16 - ttng.async_tma_copy_global_to_local %tma[%x, %x, %x, %x] offsets = [%off_w, %off_h] %alloc, %barrier, %pred : !ttng.tensordesc_im2col>, !ttg.memdesc<1xi64, #shared0_swz, #smem_swz> -> !ttg.memdesc<64x256xf16, #shared_swz, #smem_swz, mutable> + ttng.async_tma_copy_global_to_local %tma[%x, %x, %x, %x] offsets = [%off_w, %off_h] %alloc, %barrier, %pred : !ttng.tensordesc_im2col<64x256xf16, #shared_swz>, !ttg.memdesc<1xi64, #shared0_swz, #smem_swz> -> !ttg.memdesc<64x256xf16, #shared_swz, #smem_swz, mutable> tt.return } } @@ -281,8 +281,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: "@$0 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$1, {$2, $3}], [$4];", "b,l,r,r,r" {{.*}} : (i1, !llvm.ptr, i32, i32, !llvm.ptr<3>) -> !llvm.void // CHECK-NOT: cp.async.bulk.tensor.2d.global.shared::cta.bulk_group // CHECK: nvvm.cp.async.bulk.commit.group - tt.func @tma_copy_local_to_global(%tma: !tt.tensordesc>, %alloc: !ttg.memdesc<128x128xf32, #shared1, #smem>, %x: i32) { - ttng.async_tma_copy_local_to_global %tma[%x, %x] %alloc : !tt.tensordesc>, !ttg.memdesc<128x128xf32, #shared1, #smem> + tt.func @tma_copy_local_to_global(%tma: !tt.tensordesc<128x128xf32, #shared1>, %alloc: !ttg.memdesc<128x128xf32, #shared1, #smem>, %x: i32) { + ttng.async_tma_copy_local_to_global %tma[%x, %x] %alloc : !tt.tensordesc<128x128xf32, #shared1>, !ttg.memdesc<128x128xf32, #shared1, #smem> tt.return } } diff --git a/test/Hopper/WarpSpecialization/ws_code_partition.mlir b/test/Hopper/WarpSpecialization/ws_code_partition.mlir index 912150a5e678..e3379e2dcf71 100644 --- a/test/Hopper/WarpSpecialization/ws_code_partition.mlir +++ b/test/Hopper/WarpSpecialization/ws_code_partition.mlir @@ -179,7 +179,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @_matmul_layernorm_persistent_one_producer_one_consumer_one_epilog(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>, %arg3: !tt.tensordesc>, %arg4: !tt.tensordesc>, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: f32) { + tt.func public @_matmul_layernorm_persistent_one_producer_one_consumer_one_epilog(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<64x256xf16, #shared>, %arg2: !tt.tensordesc<128x256xf16, #shared>, %arg3: !tt.tensordesc<256xf16, #shared>, %arg4: !tt.tensordesc<256xf16, #shared>, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: f32) { %c63_i32 = arith.constant {async_task_id = array} 63 : i32 %c128_i32 = arith.constant {async_task_id = array} 128 : i32 %c0_i32 = arith.constant {async_task_id = array} 0 : i32 @@ -208,9 +208,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %false = arith.constant {async_task_id = array} false %12 = scf.for %arg13 = %c0_i32 to %1 step %c1_i32 iter_args(%arg14 = %cst) -> (tensor<128x256xf32, #mma>) : i32 { %45 = arith.muli %arg13, %c64_i32 {async_task_id = array} : i32 - %46 = tt.descriptor_load %arg0[%11, %45] {async_task_id = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked> + %46 = tt.descriptor_load %arg0[%11, %45] {async_task_id = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked> %47 = ttg.local_alloc %46 {async_task_id = array} : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> - %48 = tt.descriptor_load %arg1[%45, %c0_i32] {async_task_id = array} : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> + %48 = tt.descriptor_load %arg1[%45, %c0_i32] {async_task_id = array} : !tt.tensordesc<64x256xf16, #shared> -> tensor<64x256xf16, #blocked1> %49 = ttg.local_alloc %48 {async_task_id = array} : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> %50 = ttng.warp_group_dot %47, %49, %arg14 {async_task_id = array, inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> -> tensor<128x256xf32, #mma> scf.yield {async_task_id = array} %50 : tensor<128x256xf32, #mma> @@ -234,8 +234,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %21 = arith.addf %20, %10 {async_task_id = array} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> %22 = math.sqrt %21 {async_task_id = array} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> %23 = arith.divf %cst_0, %22 {async_task_id = array} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> - %24 = tt.descriptor_load %arg3[%c0_i32] {async_task_id = array} : !tt.tensordesc> -> tensor<256xf16, #blocked2> - %25 = tt.descriptor_load %arg4[%c0_i32] {async_task_id = array} : !tt.tensordesc> -> tensor<256xf16, #blocked2> + %24 = tt.descriptor_load %arg3[%c0_i32] {async_task_id = array} : !tt.tensordesc<256xf16, #shared> -> tensor<256xf16, #blocked2> + %25 = tt.descriptor_load %arg4[%c0_i32] {async_task_id = array} : !tt.tensordesc<256xf16, #shared> -> tensor<256xf16, #blocked2> %26 = tt.expand_dims %23 {async_task_id = array, axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> %27 = tt.broadcast %26 {async_task_id = array} : tensor<128x1xf32, #mma> -> tensor<128x256xf32, #mma> %28 = arith.mulf %17, %27 {async_task_id = array} : tensor<128x256xf32, #mma> @@ -255,7 +255,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %42 = arith.addf %35, %41 {async_task_id = array} : tensor<128x256xf32, #mma> %43 = arith.truncf %42 {async_task_id = array} : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> %44 = ttg.convert_layout %43 {async_task_id = array} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> - tt.descriptor_store %arg2[%11, %c0_i32], %44 {async_task_id = array} : !tt.tensordesc>, tensor<128x256xf16, #blocked1> + tt.descriptor_store %arg2[%11, %c0_i32], %44 {async_task_id = array} : !tt.tensordesc<128x256xf16, #shared>, tensor<128x256xf16, #blocked1> } {async_task_id = array} tt.return } @@ -285,21 +285,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c64_i32 = arith.constant {async_task_id = array} 64 : i32 %cst = arith.constant {async_task_id = array} dense<0.000000e+00> : tensor<64x128xf32, #mma> %0 = tt.get_program_id x {async_task_id = array} : i32 - %1 = ttng.reinterpret_tensor_descriptor %arg0 {async_task_id = array} : !tt.ptr to !tt.tensordesc> - %2 = ttng.reinterpret_tensor_descriptor %arg2 {async_task_id = array} : !tt.ptr to !tt.tensordesc> - %3 = ttng.reinterpret_tensor_descriptor %arg3 {async_task_id = array} : !tt.ptr to !tt.tensordesc> + %1 = ttng.reinterpret_tensor_descriptor %arg0 {async_task_id = array} : !tt.ptr to !tt.tensordesc<64x64xf8E4M3FN, #shared> + %2 = ttng.reinterpret_tensor_descriptor %arg2 {async_task_id = array} : !tt.ptr to !tt.tensordesc<128x64xf8E4M3FN, #shared> + %3 = ttng.reinterpret_tensor_descriptor %arg3 {async_task_id = array} : !tt.ptr to !tt.tensordesc<128xf32, #shared1> scf.for %arg4 = %0 to %arg1 step %c64_i32 : i32 { %4 = arith.muli %arg4, %c2048_i32 {async_task_id = array} : i32 %5 = scf.for %arg5 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg6 = %cst) -> (tensor<64x128xf32, #mma>) : i32 { - %8 = tt.descriptor_load %1[%4, %arg5] {async_task_id = array} : !tt.tensordesc> -> tensor<64x64xf8E4M3FN, #blocked> + %8 = tt.descriptor_load %1[%4, %arg5] {async_task_id = array} : !tt.tensordesc<64x64xf8E4M3FN, #shared> -> tensor<64x64xf8E4M3FN, #blocked> %9 = ttg.local_alloc %8 {async_task_id = array} : (tensor<64x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<64x64xf8E4M3FN, #shared, #smem> - %10 = tt.descriptor_load %2[%4, %arg5] {async_task_id = array} : !tt.tensordesc> -> tensor<128x64xf8E4M3FN, #blocked> + %10 = tt.descriptor_load %2[%4, %arg5] {async_task_id = array} : !tt.tensordesc<128x64xf8E4M3FN, #shared> -> tensor<128x64xf8E4M3FN, #blocked> %11 = ttg.local_alloc %10 {async_task_id = array} : (tensor<128x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared, #smem> %12 = ttg.memdesc_trans %11 {async_task_id = array, order = array} : !ttg.memdesc<128x64xf8E4M3FN, #shared, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #shared2, #smem> %13 = ttng.warp_group_dot %9, %12, %arg6 {async_task_id = array, inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<64x64xf8E4M3FN, #shared, #smem> * !ttg.memdesc<64x128xf8E4M3FN, #shared2, #smem> -> tensor<64x128xf32, #mma> scf.yield {async_task_id = array} %13 : tensor<64x128xf32, #mma> } {async_task_id = array} - %6 = tt.descriptor_load %3[%4] {async_task_id = array} : !tt.tensordesc> -> tensor<128xf32, #blocked1> + %6 = tt.descriptor_load %3[%4] {async_task_id = array} : !tt.tensordesc<128xf32, #shared1> -> tensor<128xf32, #blocked1> %7 = ttg.convert_layout %6 {async_task_id = array} : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 0, parent = #blocked}>> } {async_task_id = array} tt.return diff --git a/test/Hopper/WarpSpecialization/ws_data_partition.mlir b/test/Hopper/WarpSpecialization/ws_data_partition.mlir index 79fb91ba5e42..575a17b437d1 100644 --- a/test/Hopper/WarpSpecialization/ws_data_partition.mlir +++ b/test/Hopper/WarpSpecialization/ws_data_partition.mlir @@ -69,16 +69,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ ttng.tensormap_create %arg6, %arg2, [%c64_i32, %c128_i32], [%arg8, %arg9], [%3], [%c1_i32, %c1_i32] {async_task_id = array, elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () ttng.tensormap_create %arg6, %arg3, [%c64_i32, %c64_i32], [%arg8, %2], [%3], [%c1_i32, %c1_i32] {async_task_id = array, elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () ttng.tensormap_create %arg6, %arg5, [%c64_i32, %c64_i32], [%arg8, %2], [%3], [%c1_i32, %c1_i32] {async_task_id = array, elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () - %4 = ttng.reinterpret_tensor_descriptor %arg6 {async_task_id = array} : !tt.ptr to !tt.tensordesc> - %5 = ttng.reinterpret_tensor_descriptor %arg6 {async_task_id = array} : !tt.ptr to !tt.tensordesc> - %6 = ttng.reinterpret_tensor_descriptor %arg6 {async_task_id = array} : !tt.ptr to !tt.tensordesc> - %7 = ttng.reinterpret_tensor_descriptor %arg6 {async_task_id = array} : !tt.ptr to !tt.tensordesc> + %4 = ttng.reinterpret_tensor_descriptor %arg6 {async_task_id = array} : !tt.ptr to !tt.tensordesc<128x128xbf16> + %5 = ttng.reinterpret_tensor_descriptor %arg6 {async_task_id = array} : !tt.ptr to !tt.tensordesc<128x128xbf16> + %6 = ttng.reinterpret_tensor_descriptor %arg6 {async_task_id = array} : !tt.ptr to !tt.tensordesc<128x128xbf16> + %7 = ttng.reinterpret_tensor_descriptor %arg6 {async_task_id = array} : !tt.ptr to !tt.tensordesc<128x128xbf16> // CHECK: tt.descriptor_load {{.*}} -> tensor<64x128xbf16 // CHECK: tt.descriptor_load {{.*}} -> tensor<64x128xbf16 - %8 = tt.descriptor_load %4[%0, %1] {async_task_id = array} : !tt.tensordesc> -> tensor<128x128xbf16, #blocked1> + %8 = tt.descriptor_load %4[%0, %1] {async_task_id = array} : !tt.tensordesc<128x128xbf16> -> tensor<128x128xbf16, #blocked1> %9 = ttg.local_alloc %8 {async_task_id = array} : (tensor<128x128xbf16, #blocked1>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> // CHECK: tt.descriptor_load {{.*}} -> tensor<128x128xbf16 - %10 = tt.descriptor_load %5[%1, %1] {async_task_id = array} : !tt.tensordesc> -> tensor<128x128xbf16, #blocked1> + %10 = tt.descriptor_load %5[%1, %1] {async_task_id = array} : !tt.tensordesc<128x128xbf16> -> tensor<128x128xbf16, #blocked1> %11 = ttg.local_alloc %10 {async_task_id = array} : (tensor<128x128xbf16, #blocked1>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> // CHECK: ttng.warp_group_dot {{.*}} : !ttg.memdesc<64x128xbf16, {{.*}} * !ttg.memdesc<128x128xbf16, {{.*}} -> tensor<64x128xf32, {{.*}} // CHECK: ttng.warp_group_dot {{.*}} : !ttg.memdesc<64x128xbf16, {{.*}} * !ttg.memdesc<128x128xbf16, {{.*}} -> tensor<64x128xf32, {{.*}} @@ -87,7 +87,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %14 = ttg.local_alloc %13 {async_task_id = array} : (tensor<128x128xbf16, #mma>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> // CHECK: tt.descriptor_load {{.*}} -> tensor<64x128xbf16 // CHECK: tt.descriptor_load {{.*}} -> tensor<64x128xbf16 - %15 = tt.descriptor_load %6[%0, %1] {async_task_id = array} : !tt.tensordesc> -> tensor<128x128xbf16, #blocked1> + %15 = tt.descriptor_load %6[%0, %1] {async_task_id = array} : !tt.tensordesc<128x128xbf16> -> tensor<128x128xbf16, #blocked1> %16 = ttg.local_alloc %15 {async_task_id = array} : (tensor<128x128xbf16, #blocked1>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> %17 = ttg.memdesc_trans %16 {async_task_id = array, order = array} : !ttg.memdesc<128x128xbf16, #shared, #smem> -> !ttg.memdesc<128x128xbf16, #shared1, #smem> // CHECK: ttng.warp_group_dot {{.*}} : !ttg.memdesc<128x64xbf16, {{.*}} * !ttg.memdesc<64x128xbf16, {{.*}} -> tensor<128x128xf32, {{.*}} diff --git a/test/Hopper/WarpSpecialization/ws_task_id_propagation.mlir b/test/Hopper/WarpSpecialization/ws_task_id_propagation.mlir index 582815c0b0f7..52b9134771f1 100644 --- a/test/Hopper/WarpSpecialization/ws_task_id_propagation.mlir +++ b/test/Hopper/WarpSpecialization/ws_task_id_propagation.mlir @@ -29,7 +29,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-NEXT: tt.descriptor_store %[[OUTPUT:.*]][%[[IV]], %[[IV]]], %{{.*}} {async_task_id = array} // CHECK-NEXT: } {async_task_id = array} - tt.func public @matmul_persistent_tma_ws_cooperative_kernel(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + tt.func public @matmul_persistent_tma_ws_cooperative_kernel(%arg0: !tt.tensordesc<128x64xf16>, %arg1: !tt.tensordesc<64x256xf16>, %arg2: !tt.tensordesc<128x256xf16>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c64_i32 = arith.constant 64 : i32 @@ -38,9 +38,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %1 = tt.get_num_programs x : i32 scf.for %arg6 = %0 to %arg3 step %1 : i32 { %2:2 = scf.for %arg7 = %c0_i32 to %arg5 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { - %5 = tt.descriptor_load %arg0[%arg6, %arg9] {async_task_id = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked> + %5 = tt.descriptor_load %arg0[%arg6, %arg9] {async_task_id = array} : !tt.tensordesc<128x64xf16> -> tensor<128x64xf16, #blocked> %6 = ttg.local_alloc %5 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem> - %7 = tt.descriptor_load %arg1[%arg9, %arg6] {async_task_id = array} : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> + %7 = tt.descriptor_load %arg1[%arg9, %arg6] {async_task_id = array} : !tt.tensordesc<64x256xf16> -> tensor<64x256xf16, #blocked1> %8 = ttg.local_alloc %7 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem> %9 = ttng.warp_group_dot %6, %8, %arg8 {async_task_id = array, inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma> %10 = arith.addi %arg9, %c64_i32 : i32 @@ -48,7 +48,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ } %3 = arith.truncf %2#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> %4 = ttg.convert_layout %3 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> - tt.descriptor_store %arg2[%arg6, %arg6], %4 {async_task_id = array} : !tt.tensordesc>, tensor<128x256xf16, #blocked1> + tt.descriptor_store %arg2[%arg6, %arg6], %4 {async_task_id = array} : !tt.tensordesc<128x256xf16>, tensor<128x256xf16, #blocked1> } tt.return } diff --git a/test/Hopper/WarpSpecialization/ws_task_partition.mlir b/test/Hopper/WarpSpecialization/ws_task_partition.mlir index 17d5aa918c28..b0101fae2d75 100644 --- a/test/Hopper/WarpSpecialization/ws_task_partition.mlir +++ b/test/Hopper/WarpSpecialization/ws_task_partition.mlir @@ -14,7 +14,7 @@ #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_persistent_tma_ws_cooperative_kernel(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + tt.func public @matmul_persistent_tma_ws_cooperative_kernel(%arg0: !tt.tensordesc<128x64xf16>, %arg1: !tt.tensordesc<64x256xf16>, %arg2: !tt.tensordesc<128x256xf16>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c64_i32 = arith.constant 64 : i32 @@ -23,9 +23,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %1 = tt.get_num_programs x : i32 scf.for %arg6 = %0 to %arg3 step %1 : i32 { %2:2 = scf.for %arg7 = %c0_i32 to %arg5 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { - %5 = tt.descriptor_load %arg0[%arg6, %arg9] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> + %5 = tt.descriptor_load %arg0[%arg6, %arg9] : !tt.tensordesc<128x64xf16> -> tensor<128x64xf16, #blocked> %6 = ttg.local_alloc %5 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem> - %7 = tt.descriptor_load %arg1[%arg9, %arg6] : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> + %7 = tt.descriptor_load %arg1[%arg9, %arg6] : !tt.tensordesc<64x256xf16> -> tensor<64x256xf16, #blocked1> %8 = ttg.local_alloc %7 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem> %9 = ttng.warp_group_dot %6, %8, %arg8 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma> %10 = arith.addi %arg9, %c64_i32 : i32 @@ -33,7 +33,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ } %3 = arith.truncf %2#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> %4 = ttg.convert_layout %3 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> - tt.descriptor_store %arg2[%arg6, %arg6], %4 : !tt.tensordesc>, tensor<128x256xf16, #blocked1> + tt.descriptor_store %arg2[%arg6, %arg6], %4 : !tt.tensordesc<128x256xf16>, tensor<128x256xf16, #blocked1> } tt.return } diff --git a/test/NVWS/aref-tmem-insertion.mlir b/test/NVWS/aref-tmem-insertion.mlir index d37890dd8370..22efc8d362e2 100644 --- a/test/NVWS/aref-tmem-insertion.mlir +++ b/test/NVWS/aref-tmem-insertion.mlir @@ -18,7 +18,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @warp_specialize_tma_matmul - tt.func @warp_specialize_tma_matmul(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc>, %arg4: !tt.tensordesc>) { + tt.func @warp_specialize_tma_matmul(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<128x64xf16, #shared>, %arg4: !tt.tensordesc<128x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %c64_i32 = arith.constant 64 : i32 @@ -35,8 +35,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-NEXT: [[TOK2:%.*]] = scf.for {{.*}} iter_args([[TOK:%.*]] = [[ATOK]]) %1 = scf.for %arg5 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg6 = %0) -> (!ttg.async.token) : i32 { %2 = arith.muli %arg5, %c64_i32 {ttg.partition = array} : i32 - %3 = tt.descriptor_load %arg3[%arg1, %2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %4 = tt.descriptor_load %arg4[%arg2, %2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %3 = tt.descriptor_load %arg3[%arg1, %2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %4 = tt.descriptor_load %arg4[%arg2, %2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> %5 = ttg.local_alloc %3 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %6 = ttg.local_alloc %4 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %7 = ttg.memdesc_trans %6 {order = array, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared1, #smem> @@ -56,7 +56,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @matmul_tma_acc_with_unconditional_user - tt.func @matmul_tma_acc_with_unconditional_user(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) { + tt.func @matmul_tma_acc_with_unconditional_user(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<64x128xf16, #shared>) { %c32_i32 = arith.constant 32 : i32 %cst = arith.constant dense<1.000000e+00> : tensor<128x128xf32, #blocked> %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> @@ -73,8 +73,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK: [[TOK1:%.]] = scf.for [[I:%.*]] = [[UB:%.*]] to [[LB:%.*]] step [[STEP:%.*]] iter_args([[TOK:%.*]] = [[ATOK]]) %1 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %0) -> (!ttg.async.token) : i32 { %2:3 = "get_offsets"(%arg2) {ttg.partition = array} : (i32) -> (i32, i32, i32) - %3 = tt.descriptor_load %arg0[%2#0, %2#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %4 = tt.descriptor_load %arg1[%2#1, %2#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %3 = tt.descriptor_load %arg0[%2#0, %2#2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %4 = tt.descriptor_load %arg1[%2#1, %2#2] {ttg.partition = array} : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %5 = ttg.local_alloc %3 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %6 = ttg.local_alloc %4 {ttg.partition = array} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> // CHECK: [[BUF:%.*]] = nvws.aref.buffer [[AREF]], [[TOK]] {ttg.partition = array} @@ -102,7 +102,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @matmul_tma_acc_with_conditional_user - tt.func @matmul_tma_acc_with_conditional_user(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) { + tt.func @matmul_tma_acc_with_conditional_user(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<64x128xf16, #shared>) { %c32_i32 = arith.constant 32 : i32 %cst = arith.constant dense<1.000000e+00> : tensor<128x128xf32, #blocked> %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> @@ -120,8 +120,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK: [[TOK2:%.*]] = scf.for [[I:%.*]] = [[UB:%.*]] to [[LB:%.*]] step [[STEP:%.*]] iter_args([[TOK:%.*]] = [[ATOK]]) %1 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %0) -> (!ttg.async.token) : i32 { %2:3 = "get_offsets"(%arg2) {ttg.partition = array} : (i32) -> (i32, i32, i32) - %3 = tt.descriptor_load %arg0[%2#0, %2#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %4 = tt.descriptor_load %arg1[%2#1, %2#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %3 = tt.descriptor_load %arg0[%2#0, %2#2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %4 = tt.descriptor_load %arg1[%2#1, %2#2] {ttg.partition = array} : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %5 = ttg.local_alloc %3 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %6 = ttg.local_alloc %4 {ttg.partition = array} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %7 = ttng.tc_gen5_mma %5, %6, %result[%arg3], %true, %true {ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> @@ -157,7 +157,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @matmul_tma_acc_with_conditional_def - tt.func @matmul_tma_acc_with_conditional_def(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) { + tt.func @matmul_tma_acc_with_conditional_def(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<64x128xf16, #shared>) { %c32_i32 = arith.constant 32 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %true = arith.constant true @@ -169,8 +169,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %0 = ttng.tmem_store %cst, %result[%token], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> %1 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %0) -> (!ttg.async.token) : i32 { %2:3 = "get_offsets"(%arg2) : (i32) -> (i32, i32, i32) - %3 = tt.descriptor_load %arg0[%2#0, %2#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %4 = tt.descriptor_load %arg1[%2#1, %2#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %3 = tt.descriptor_load %arg0[%2#0, %2#2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %4 = tt.descriptor_load %arg1[%2#1, %2#2] {ttg.partition = array} : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %5 = ttg.local_alloc %3 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %6 = ttg.local_alloc %4 {ttg.partition = array} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %7 = ttng.tc_gen5_mma %5, %6, %result[%arg3], %true, %true {ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> @@ -192,7 +192,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @matmul_tma_acc_with_conditional_def_and_use - tt.func @matmul_tma_acc_with_conditional_def_and_use(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) { + tt.func @matmul_tma_acc_with_conditional_def_and_use(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<64x128xf16, #shared>) { %c32_i32 = arith.constant 32 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %true = arith.constant true @@ -205,8 +205,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK: [[TOK2:%.*]] = scf.for [[I:%.*]] = [[UB:%.*]] to [[LB:%.*]] step [[STEP:%.*]] iter_args([[TOK:%.*]] = [[ATOK]]) %1 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %0) -> (!ttg.async.token) : i32 { %2:3 = "get_offsets"(%arg2) {ttg.partition = array} : (i32) -> (i32, i32, i32) - %3 = tt.descriptor_load %arg0[%2#0, %2#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %4 = tt.descriptor_load %arg1[%2#1, %2#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %3 = tt.descriptor_load %arg0[%2#0, %2#2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %4 = tt.descriptor_load %arg1[%2#1, %2#2] {ttg.partition = array} : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %5 = ttg.local_alloc %3 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %6 = ttg.local_alloc %4 {ttg.partition = array} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %7 = ttng.tc_gen5_mma %5, %6, %result[%arg3], %true, %true {ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> @@ -239,7 +239,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @matmul_tma_acc_with_conditional_def_and_use_no_multibuf_flag - tt.func @matmul_tma_acc_with_conditional_def_and_use_no_multibuf_flag(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) { + tt.func @matmul_tma_acc_with_conditional_def_and_use_no_multibuf_flag(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<64x128xf16, #shared>) { %c32_i32 = arith.constant 32 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %true = arith.constant true @@ -250,8 +250,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %0 = ttng.tmem_store %cst, %result[%token], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> %1:2 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %true, %arg4 = %0) -> (i1, !ttg.async.token) : i32 { %2:3 = "get_offsets"(%arg2) {ttg.partition = array} : (i32) -> (i32, i32, i32) - %3 = tt.descriptor_load %arg0[%2#0, %2#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %4 = tt.descriptor_load %arg1[%2#1, %2#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %3 = tt.descriptor_load %arg0[%2#0, %2#2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %4 = tt.descriptor_load %arg1[%2#1, %2#2] {ttg.partition = array} : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %5 = ttg.local_alloc %3 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %6 = ttg.local_alloc %4 {ttg.partition = array} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %7 = ttng.tc_gen5_mma %5, %6, %result[%arg4], %arg3, %true {ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> @@ -271,7 +271,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @matmul_scaled_rhs_scales_tma - tt.func @matmul_scaled_rhs_scales_tma(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc>, %arg4: !tt.tensordesc>, %arg5: !tt.tensordesc>) { + tt.func @matmul_scaled_rhs_scales_tma(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<128x64xf8E4M3FN, #shared2>, %arg4: !tt.tensordesc<128x64xf8E4M3FN, #shared2>, %arg5: !tt.tensordesc<128x8xi8, #shared3>) { // CHECK: [[CST:%.*]] = arith.constant dense<127> : tensor<128x8xi8 // CHECK: [[CST_0:%.*]] = arith.constant dense<{{.*}}> : tensor<128x128xf32 %cst = arith.constant dense<127> : tensor<128x8xi8, #linear> @@ -301,9 +301,9 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-NEXT: local_alloc [[LHS]] // CHECK-NEXT: local_alloc [[RHS]] %2 = arith.muli %arg6, %c64_i32 {ttg.partition = array} : i32 - %3 = tt.descriptor_load %arg3[%arg1, %2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf8E4M3FN, #blocked1> - %4 = tt.descriptor_load %arg4[%arg2, %2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf8E4M3FN, #blocked1> - %5 = tt.descriptor_load %arg5[%arg1, %c0_i32] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x8xi8, #linear> + %3 = tt.descriptor_load %arg3[%arg1, %2] {ttg.partition = array} : !tt.tensordesc<128x64xf8E4M3FN, #shared2> -> tensor<128x64xf8E4M3FN, #blocked1> + %4 = tt.descriptor_load %arg4[%arg2, %2] {ttg.partition = array} : !tt.tensordesc<128x64xf8E4M3FN, #shared2> -> tensor<128x64xf8E4M3FN, #blocked1> + %5 = tt.descriptor_load %arg5[%arg1, %c0_i32] {ttg.partition = array} : !tt.tensordesc<128x8xi8, #shared3> -> tensor<128x8xi8, #linear> %6 = ttg.local_alloc %3 {ttg.partition = array} : (tensor<128x64xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared2, #smem> %7 = ttg.local_alloc %4 {ttg.partition = array} : (tensor<128x64xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared2, #smem> %8 = ttg.memdesc_trans %7 {order = array, ttg.partition = array} : !ttg.memdesc<128x64xf8E4M3FN, #shared2, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #shared4, #smem> @@ -334,14 +334,14 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @user_partition_has_cycle - tt.func @user_partition_has_cycle(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc>, %arg4: !tt.tensordesc>) { + tt.func @user_partition_has_cycle(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<128x64xf16, #shared>, %arg4: !tt.tensordesc<128x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %c64_i32 = arith.constant 64 : i32 %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 %false = arith.constant false %true = arith.constant true - %0 = tt.descriptor_load %arg3[%c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %0 = tt.descriptor_load %arg3[%c0_i32, %c0_i32] : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> // CHECK: [[BUF:%.*]] = ttng.tmem_alloc // CHECK-NEXT: [[AREF:%.*]] = nvws.aref.create [[BUF]] // CHECK-NEXT: {{.*}}, [[ATOK:%.*]] = nvws.aref.put.enter [[AREF]] : @@ -350,7 +350,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) %2:2 = scf.for %arg5 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg6 = %cst, %arg7 = %token) -> (tensor<128x128xf32, #blocked>, !ttg.async.token) : i32 { %3 = arith.muli %arg5, %c64_i32 {ttg.partition = array} : i32 - %4 = tt.descriptor_load %arg4[%arg2, %3] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %4 = tt.descriptor_load %arg4[%arg2, %3] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> %5 = ttg.local_alloc %4 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %6 = ttg.memdesc_trans %5 {order = array, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared1, #smem> // CHECK: [[BUF:%.*]] = nvws.aref.buffer [[AREF]], [[TOK]] {ttg.partition = array} @@ -375,7 +375,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @matmul_tma_acc_with_conditional_def_and_use_flag - tt.func @matmul_tma_acc_with_conditional_def_and_use_flag(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) { + tt.func @matmul_tma_acc_with_conditional_def_and_use_flag(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<64x128xf16, #shared>) { %c32_i32 = arith.constant 32 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %true = arith.constant true @@ -390,8 +390,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-NEXT: scf.for {{.*}} iter_args({{.*}}, [[TOK:%.*]] = [[ATOK]]) %1:2 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %true, %arg4 = %0) -> (i1, !ttg.async.token) : i32 { %2:3 = "get_offsets"(%arg2) {ttg.partition = array} : (i32) -> (i32, i32, i32) - %3 = tt.descriptor_load %arg0[%2#0, %2#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %4 = tt.descriptor_load %arg1[%2#1, %2#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %3 = tt.descriptor_load %arg0[%2#0, %2#2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %4 = tt.descriptor_load %arg1[%2#1, %2#2] {ttg.partition = array} : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %5 = ttg.local_alloc %3 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %6 = ttg.local_alloc %4 {ttg.partition = array} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> // CHECK: aref.buffer [[AREF]], [[TOK]] {ttg.partition = array} @@ -426,7 +426,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @specialize_mma_only - tt.func @specialize_mma_only(%arg0: !tt.tensordesc>, %arg1: !ttg.memdesc<128x64xf16, #shared, #smem>, %arg2: i32) { + tt.func @specialize_mma_only(%arg0: !tt.tensordesc<64x128xf16, #shared>, %arg1: !ttg.memdesc<128x64xf16, #shared, #smem>, %arg2: i32) { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %true = arith.constant true %c1_i32 = arith.constant 1 : i32 @@ -440,7 +440,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %0 = ttng.tmem_store %cst, %result[%token], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK-NEXT: [[TOK:%.*]] = scf.for {{.*}} iter_args([[TOK:%.*]] = {{.*}}) %1 = scf.for %arg3 = %c0_i32 to %arg2 step %c1_i32 iter_args(%arg4 = %0) -> (!ttg.async.token) : i32 { - %2 = tt.descriptor_load %arg0[%arg3, %arg3] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %2 = tt.descriptor_load %arg0[%arg3, %arg3] {ttg.partition = array} : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> // CHECK: [[BUF:%.*]] = nvws.aref.buffer [[AREF]], [[TOK]] // CHECK-NEXT: tmem_load [[BUF]] %result_2, %token_3 = ttng.tmem_load %result[%arg4] {ttg.partition = array} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> @@ -465,7 +465,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @load_scale_mma_user - tt.func @load_scale_mma_user(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem>, %arg2: !tt.tensordesc>, %arg3: !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>, %arg4: i32) { + tt.func @load_scale_mma_user(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem>, %arg2: !tt.tensordesc<8x128xi8, #shared>, %arg3: !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>, %arg4: i32) { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %true = arith.constant true %c1_i32 = arith.constant 1 : i32 @@ -479,7 +479,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-NEXT: [[SCALE_AREF:%.*]] = nvws.aref.create // CHECK-NEXT: [[TOK1:%.*]] = scf.for %1 = scf.for %arg5 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg6 = %0) -> (!ttg.async.token) : i32 { - %2 = tt.descriptor_load %arg2[%arg5, %arg5] {ttg.partition = array} : !tt.tensordesc> -> tensor<8x128xi8, #blocked1> + %2 = tt.descriptor_load %arg2[%arg5, %arg5] {ttg.partition = array} : !tt.tensordesc<8x128xi8, #shared> -> tensor<8x128xi8, #blocked1> %3 = ttg.local_alloc %2 {ttg.partition = array} : (tensor<8x128xi8, #blocked1>) -> !ttg.memdesc<8x128xi8, #shared, #smem> %4 = ttg.local_load %3 {ttg.partition = array} : !ttg.memdesc<8x128xi8, #shared, #smem> -> tensor<8x128xi8, #linear1> %5 = tt.trans %4 {order = array, ttg.partition = array} : tensor<8x128xi8, #linear1> -> tensor<128x8xi8, #linear> @@ -520,7 +520,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @store_mma_load - tt.func @store_mma_load(%arg0: i32, %arg1: !tt.tensordesc>, %arg2: !ttg.memdesc<64x128xf16, #shared, #smem>) { + tt.func @store_mma_load(%arg0: i32, %arg1: !tt.tensordesc<128x64xf16, #shared>, %arg2: !ttg.memdesc<64x128xf16, #shared, #smem>) { %true = arith.constant true %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 @@ -529,7 +529,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-NEXT: aref.put.enter %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) %0 = scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg4 = %token) -> (!ttg.async.token) : i32 { - %1 = tt.descriptor_load %arg1[%arg3, %arg3] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %1 = tt.descriptor_load %arg1[%arg3, %arg3] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> %2 = arith.addf %1, %1 {ttg.partition = array} : tensor<128x64xf16, #blocked1> %3 = ttg.local_alloc %2 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> // CHECK: make_acc @@ -557,7 +557,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @local_alloc_into_mma - tt.func @local_alloc_into_mma(%arg0: i32, %arg1: tensor<128x64xf16, #blocked1>, %arg2: !tt.tensordesc>) { + tt.func @local_alloc_into_mma(%arg0: i32, %arg1: tensor<128x64xf16, #blocked1>, %arg2: !tt.tensordesc<64x128xf16, #shared>) { %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 %true = arith.constant true @@ -567,7 +567,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) %5 = scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg4 = %token) -> (!ttg.async.token) : i32 { %0 = ttg.local_alloc %arg1 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> - %1 = tt.descriptor_load %arg2[%arg3, %arg3] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %1 = tt.descriptor_load %arg2[%arg3, %arg3] {ttg.partition = array} : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %2 = arith.addf %1, %1 {ttg.partition = array} : tensor<64x128xf16, #blocked1> %3 = ttg.local_alloc %2 {ttg.partition = array} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> // CHECK: aref.buffer @@ -579,7 +579,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { tt.return } - tt.func @shmem_sink_iterator_invalidation(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc>, %arg4: !tt.tensordesc>) { + tt.func @shmem_sink_iterator_invalidation(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<128x64xf16, #shared>, %arg4: !tt.tensordesc<128x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %c64_i32 = arith.constant 64 : i32 %c1_i32 = arith.constant 1 : i32 @@ -595,8 +595,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-NEXT: [[TOK1:%.*]] = scf.for {{.*}} iter_args([[TOK2:%.*]] = [[ATOK]]) %1 = scf.for %arg5 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg6 = %0) -> (!ttg.async.token) : i32 { %2 = arith.muli %arg5, %c64_i32 {ttg.partition = array} : i32 - %3 = tt.descriptor_load %arg4[%arg2, %2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %4 = tt.descriptor_load %arg3[%arg1, %2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %3 = tt.descriptor_load %arg4[%arg2, %2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %4 = tt.descriptor_load %arg3[%arg1, %2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> %5 = ttg.local_alloc %4 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %6 = ttg.local_load %5 {ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem> -> tensor<128x64xf16, #blocked2> %7 = ttg.local_alloc %3 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> @@ -636,7 +636,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { #tmem = #ttng.tensor_memory_encoding #tmem1 = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { - tt.func public @attention_forward(%arg0: !ttg.memdesc<256x64xf16, #shared, #smem>, %arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>, %arg3: f32, %arg4: i32) { + tt.func public @attention_forward(%arg0: !ttg.memdesc<256x64xf16, #shared, #smem>, %arg1: !tt.tensordesc<64x64xf16, #shared>, %arg2: !tt.tensordesc<64x64xf16, #shared>, %arg3: f32, %arg4: i32) { %cst = arith.constant dense<1.000000e+00> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #blocked> %cst_1 = arith.constant dense<0xFF800000> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> @@ -659,7 +659,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-NEXT: [[AREF_P:%.*]] = nvws.aref.create // CHECK-NEXT: [[RET:%.*]]:4 = scf.for {{.*}} iter_args([[A1:%.*]] = {{.*}}, [[A2:%.*]] = {{.*}}, [[TOKS:%.*]] = [[TOK_S]], [[TOKO:%.*]] = [[TOK_O]]) %1:4 = scf.for %arg5 = %c0_i32 to %arg4 step %c64_i32 iter_args(%arg6 = %cst, %arg7 = %cst_1, %arg8 = %token, %arg9 = %0) -> (tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.async.token, !ttg.async.token) : i32 { - %2 = tt.descriptor_load %arg1[%arg5, %c0_i32] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x64xf16, #blocked1> + %2 = tt.descriptor_load %arg1[%arg5, %c0_i32] {ttg.partition = array} : !tt.tensordesc<64x64xf16, #shared> -> tensor<64x64xf16, #blocked1> %3 = ttg.local_alloc %2 {ttg.partition = array} : (tensor<64x64xf16, #blocked1>) -> !ttg.memdesc<64x64xf16, #shared, #smem> %4 = ttg.memdesc_trans %3 {order = array, ttg.partition = array} : !ttg.memdesc<64x64xf16, #shared, #smem> -> !ttg.memdesc<64x64xf16, #shared1, #smem> // CHECK: [[BUF:%.*]] = nvws.aref.buffer [[AREF_S]], [[TOKS]] {ttg.partition = array} @@ -695,7 +695,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %result_8, %token_9 = ttng.tmem_load %result_2[%arg9] {ttg.partition = array} : !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<256x64xf32, #blocked> %18 = arith.mulf %result_8, %17 {ttg.partition = array} : tensor<256x64xf32, #blocked> - %19 = tt.descriptor_load %arg2[%arg5, %c0_i32] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x64xf16, #blocked1> + %19 = tt.descriptor_load %arg2[%arg5, %c0_i32] {ttg.partition = array} : !tt.tensordesc<64x64xf16, #shared> -> tensor<64x64xf16, #blocked1> %20 = ttg.local_alloc %19 {ttg.partition = array} : (tensor<64x64xf16, #blocked1>) -> !ttg.memdesc<64x64xf16, #shared, #smem> %21 = arith.truncf %8 {ttg.partition = array} : tensor<256x64xf32, #blocked> to tensor<256x64xf16, #blocked> // CHECK: {{.*}}, [[TOKP:%.*]] = nvws.aref.put.enter [[AREF_P]] {ttg.partition = array} @@ -800,7 +800,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @if_split_workaround - tt.func @if_split_workaround(%arg0: !tt.tensordesc>, %arg1: tensor<64x128x!tt.ptr, #blocked3> {tt.contiguity = dense<[1, 64]> : tensor<2xi32>, tt.divisibility = dense<16> : tensor<2xi32>}) { + tt.func @if_split_workaround(%arg0: !tt.tensordesc<1x64xf16, #shared>, %arg1: tensor<64x128x!tt.ptr, #blocked3> {tt.contiguity = dense<[1, 64]> : tensor<2xi32>, tt.divisibility = dense<16> : tensor<2xi32>}) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %true = arith.constant true @@ -813,7 +813,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %1:3 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %true, %arg4 = %arg1, %arg5 = %0) -> (i1, tensor<64x128x!tt.ptr, #blocked3>, !ttg.async.token) : i32 { %2:3 = "get_offsets"(%arg2) {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array} : (i32) -> (i32, tensor<64x128xi32, #blocked3>, i32) %3 = tt.splat %2#0 {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array} : i32 -> tensor<128xi32, #blocked2> - %4 = tt.descriptor_gather %arg0[%3, %2#2] {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array} : (!tt.tensordesc>, tensor<128xi32, #blocked2>, i32) -> tensor<128x64xf16, #blocked1> + %4 = tt.descriptor_gather %arg0[%3, %2#2] {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array} : (!tt.tensordesc<1x64xf16, #shared>, tensor<128xi32, #blocked2>, i32) -> tensor<128x64xf16, #blocked1> %5 = tt.addptr %arg4, %2#1 {loop.cluster = 3 : i32, loop.stage = 1 : i32, tt.constancy = dense<1> : tensor<2xi32>, tt.contiguity = dense<[1, 64]> : tensor<2xi32>, tt.divisibility = dense<16> : tensor<2xi32>, ttg.partition = array} : tensor<64x128x!tt.ptr, #blocked3>, tensor<64x128xi32, #blocked3> %6 = tt.load %5 {loop.cluster = 3 : i32, loop.stage = 1 : i32, ttg.partition = array} : tensor<64x128x!tt.ptr, #blocked3> %7 = ttg.local_alloc %4 {loop.cluster = 2 : i32, loop.stage = 2 : i32, ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> diff --git a/test/NVWS/assign_stage_phase.mlir b/test/NVWS/assign_stage_phase.mlir index a5043856530a..63acbe56c9a0 100644 --- a/test/NVWS/assign_stage_phase.mlir +++ b/test/NVWS/assign_stage_phase.mlir @@ -192,7 +192,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @warp_specialize_tma_matmul - tt.func @warp_specialize_tma_matmul(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc>, %arg4: !tt.tensordesc>) { + tt.func @warp_specialize_tma_matmul(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<128x64xf16, #shared>, %arg4: !tt.tensordesc<128x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %c64_i32 = arith.constant 64 : i32 %c1_i32 = arith.constant 1 : i32 @@ -212,8 +212,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %2 = ttng.tmem_store %cst, %1[], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128> scf.for %arg5 = %c0_i32 to %arg0 step %c1_i32 : i32 { %4 = arith.muli %arg5, %c64_i32 {ttg.partition = array} : i32 - %5 = tt.descriptor_load %arg3[%arg1, %4] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %6 = tt.descriptor_load %arg4[%arg2, %4] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %5 = tt.descriptor_load %arg3[%arg1, %4] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %6 = tt.descriptor_load %arg4[%arg2, %4] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> %7 = ttg.local_alloc %5 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %8 = ttg.local_alloc %6 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %9 = ttg.memdesc_trans %8 {order = array, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared1, #smem> @@ -237,7 +237,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @matmul_tma_acc_with_unconditional_user - tt.func @matmul_tma_acc_with_unconditional_user(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) { + tt.func @matmul_tma_acc_with_unconditional_user(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<64x128xf16, #shared>) { // CHECK: [[C1:%.*]] = arith.constant 1 // CHECK: [[C0:%.*]] = arith.constant 0 %c32_i32 = arith.constant 32 : i32 @@ -262,8 +262,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK: [[RET:%.*]]:5 = scf.for {{.*}} iter_args([[TOK:%.*]] = [[ATOK:%.*]], [[S0:%.*]] = [[S]], [[P0:%.*]] = [[P]], [[S1:%.*]] = [[C1]], [[P1:%.*]] = [[C1]]) %3 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %token) -> (!ttg.async.token) : i32 { %4:3 = "get_offsets"(%arg2) {ttg.partition = array} : (i32) -> (i32, i32, i32) - %5 = tt.descriptor_load %arg0[%4#0, %4#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %6 = tt.descriptor_load %arg1[%4#1, %4#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %5 = tt.descriptor_load %arg0[%4#0, %4#2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %6 = tt.descriptor_load %arg1[%4#1, %4#2] {ttg.partition = array} : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %7 = ttg.local_alloc %5 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %8 = ttg.local_alloc %6 {ttg.partition = array} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> // CHECK: nvws.aref.buffer [[AREF]][[[S0]] @@ -319,7 +319,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @assign_stage_buffer - tt.func @assign_stage_buffer(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) { + tt.func @assign_stage_buffer(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<64x128xf16, #shared>) { %c32_i32 = arith.constant 32 : i32 %cst = arith.constant dense<1.000000e+00> : tensor<128x128xf32, #blocked> %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> @@ -337,8 +337,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK: scf.for {{.*}} iter_args([[TOK1:%.*]] = [[TOK]], [[SPUT:%.*]] = {{.*}}, {{.*}} = {{.*}}, {{.*}} = {{.*}}, {{.*}} = {{.*}}) %3 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %token) -> (!ttg.async.token) : i32 { %4:3 = "get_offsets"(%arg2) {ttg.partition = array} : (i32) -> (i32, i32, i32) - %5 = tt.descriptor_load %arg0[%4#0, %4#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %6 = tt.descriptor_load %arg1[%4#1, %4#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %5 = tt.descriptor_load %arg0[%4#0, %4#2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %6 = tt.descriptor_load %arg1[%4#1, %4#2] {ttg.partition = array} : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %7 = ttg.local_alloc %5 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %8 = ttg.local_alloc %6 {ttg.partition = array} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> // CHECK: nvws.aref.buffer [[AREF]][[[SPUT]]], [[TOK1]] @@ -382,7 +382,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @attention_forward - tt.func public @attention_forward(%arg0: !ttg.memdesc<256x64xf16, #shared, #smem>, %arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>, %arg3: f32, %arg4: i32) { + tt.func public @attention_forward(%arg0: !ttg.memdesc<256x64xf16, #shared, #smem>, %arg1: !tt.tensordesc<64x64xf16, #shared>, %arg2: !tt.tensordesc<64x64xf16, #shared>, %arg3: f32, %arg4: i32) { %cst = arith.constant dense<1.000000e+00> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #blocked> %cst_1 = arith.constant dense<0xFF800000> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> @@ -402,7 +402,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %4 = nvws.aref.create %result_5 : <[!ttg.memdesc<1x256x64xf16, #tmem, #ttng.tensor_memory, mutable>]> // CHECK: [[RET:%.*]]:16 = scf.for %5:4 = scf.for %arg5 = %c0_i32 to %arg4 step %c64_i32 iter_args(%arg6 = %cst, %arg7 = %cst_1, %arg8 = %token, %arg9 = %token_4) -> (tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.async.token, !ttg.async.token) : i32 { - %7 = tt.descriptor_load %arg1[%arg5, %c0_i32] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x64xf16, #blocked1> + %7 = tt.descriptor_load %arg1[%arg5, %c0_i32] {ttg.partition = array} : !tt.tensordesc<64x64xf16, #shared> -> tensor<64x64xf16, #blocked1> %8 = ttg.local_alloc %7 {ttg.partition = array} : (tensor<64x64xf16, #blocked1>) -> !ttg.memdesc<64x64xf16, #shared, #smem> %9 = ttg.memdesc_trans %8 {order = array, ttg.partition = array} : !ttg.memdesc<64x64xf16, #shared, #smem> -> !ttg.memdesc<64x64xf16, #shared1, #smem> %10 = nvws.aref.buffer %0, %arg8 {ttg.partition = array} : <[!ttg.memdesc<2x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 2x256x64> @@ -431,7 +431,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %25 = nvws.aref.buffer %1, %arg9 {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64> %result_14, %token_15 = ttng.tmem_load %25[] {ttg.partition = array} : !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64> -> tensor<256x64xf32, #blocked> %26 = arith.mulf %result_14, %24 {ttg.partition = array} : tensor<256x64xf32, #blocked> - %27 = tt.descriptor_load %arg2[%arg5, %c0_i32] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x64xf16, #blocked1> + %27 = tt.descriptor_load %arg2[%arg5, %c0_i32] {ttg.partition = array} : !tt.tensordesc<64x64xf16, #shared> -> tensor<64x64xf16, #blocked1> %28 = ttg.local_alloc %27 {ttg.partition = array} : (tensor<64x64xf16, #blocked1>) -> !ttg.memdesc<64x64xf16, #shared, #smem> %29 = arith.truncf %15 {ttg.partition = array} : tensor<256x64xf32, #blocked> to tensor<256x64xf16, #blocked> %buffers_16, %token_17 = nvws.aref.put.enter %4 {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf16, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<256x64xf16, #tmem, #ttng.tensor_memory, mutable, 1x256x64>, !ttg.async.token @@ -510,7 +510,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { #tmem_scales = #ttng.tensor_memory_scales_encoding<> module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @matmul_tma_acc_with_conditional_user - tt.func @matmul_tma_acc_with_conditional_user(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) { + tt.func @matmul_tma_acc_with_conditional_user(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<64x128xf16, #shared>) { %c32_i32 = arith.constant 32 : i32 %cst = arith.constant dense<1.000000e+00> : tensor<128x128xf32, #blocked> %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> @@ -524,8 +524,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %2 = ttng.tmem_store %cst_0, %1[], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> %3 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %token) -> (!ttg.async.token) : i32 { %4:3 = "get_offsets"(%arg2) {ttg.partition = array} : (i32) -> (i32, i32, i32) - %5 = tt.descriptor_load %arg0[%4#0, %4#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %6 = tt.descriptor_load %arg1[%4#1, %4#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %5 = tt.descriptor_load %arg0[%4#0, %4#2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %6 = tt.descriptor_load %arg1[%4#1, %4#2] {ttg.partition = array} : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %7 = ttg.local_alloc %5 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %8 = ttg.local_alloc %6 {ttg.partition = array} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %9 = nvws.aref.buffer %0, %arg3 {ttg.partition = array} : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> @@ -576,11 +576,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c8_i32 = arith.constant 8 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %0 = arith.extsi %arg3 : i32 to i64 - %1 = tt.make_tensor_descriptor %arg0, [%arg6, %arg8], [%0, %c1_i64] : , > + %1 = tt.make_tensor_descriptor %arg0, [%arg6, %arg8], [%0, %c1_i64] : , <128x128xf8E4M3FN, #shared> %2 = arith.extsi %arg4 : i32 to i64 - %3 = tt.make_tensor_descriptor %arg1, [%arg7, %arg8], [%2, %c1_i64] : , > + %3 = tt.make_tensor_descriptor %arg1, [%arg7, %arg8], [%2, %c1_i64] : , <128x128xf8E4M3FN, #shared> %4 = arith.extsi %arg5 : i32 to i64 - %5 = tt.make_tensor_descriptor %arg2, [%arg6, %arg7], [%4, %c1_i64] : , > + %5 = tt.make_tensor_descriptor %arg2, [%arg6, %arg7], [%4, %c1_i64] : , <128x128xf8E4M3FN, #shared> %6 = tt.get_program_id x : i32 %7 = arith.addi %arg6, %c127_i32 : i32 %8 = arith.divsi %7, %c128_i32 : i32 @@ -637,7 +637,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: arith.select {{.*}} ttg.partition = array // CHECK-NEXT: aref.put.enter {{.*}} ttg.partition = array %buffers_8, %token_9 = nvws.aref.put.enter %16 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]> -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, mutable, 1x128x128>, !ttg.async.token - nvws.descriptor_load %1[%28, %36] 16384 %buffers_8 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, mutable, 1x128x128> + nvws.descriptor_load %1[%28, %36] 16384 %buffers_8 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x128xf8E4M3FN, #shared>, i32, i32, !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, mutable, 1x128x128> nvws.aref.put.exit %16, %token_9 [#nvws.async_op] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]>, !ttg.async.token // CHECK: aref.put.exit {{.*}} ttg.partition = array // CHECK: arith.addi {{.*}} {ttg.partition = array} @@ -650,7 +650,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-NOT: partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]> -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, 1x128x128>, !ttg.async.token %buffers_12, %token_13 = nvws.aref.put.enter %18 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]> -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, mutable, 1x128x128>, !ttg.async.token - nvws.descriptor_load %3[%29, %36] 16384 %buffers_12 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, mutable, 1x128x128> + nvws.descriptor_load %3[%29, %36] 16384 %buffers_12 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x128xf8E4M3FN, #shared>, i32, i32, !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, mutable, 1x128x128> nvws.aref.put.exit %18, %token_13 [#nvws.async_op] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]>, !ttg.async.token %buffers_14, %token_15 = nvws.aref.get.enter %18 {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]> -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, 1x128x128>, !ttg.async.token %37 = ttg.memdesc_trans %buffers_14 {loop.cluster = 0 : i32, loop.stage = 2 : i32, order = array, ttg.partition = array} : !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, 1x128x128> -> !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem, 1x128x128> @@ -670,7 +670,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ nvws.aref.get.exit %19, %token_7 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token %34 = tt.fp_to_fp %result_4 {ttg.partition = array}, rounding = rtne : tensor<128x128xf32, #blocked> -> tensor<128x128xf8E4M3FN, #blocked> %35 = ttg.convert_layout %34 {ttg.partition = array} : tensor<128x128xf8E4M3FN, #blocked> -> tensor<128x128xf8E4M3FN, #blocked1> - tt.descriptor_store %5[%28, %29], %35 {ttg.partition = array} : !tt.tensordesc>, tensor<128x128xf8E4M3FN, #blocked1> + tt.descriptor_store %5[%28, %29], %35 {ttg.partition = array} : !tt.tensordesc<128x128xf8E4M3FN, #shared>, tensor<128x128xf8E4M3FN, #blocked1> } {tt.num_stages = 3 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 0 : i32, ttg.partition = array} tt.return } diff --git a/test/NVWS/hoist_tmem_store.mlir b/test/NVWS/hoist_tmem_store.mlir index e8c2eaa085e5..34932b094840 100644 --- a/test/NVWS/hoist_tmem_store.mlir +++ b/test/NVWS/hoist_tmem_store.mlir @@ -7,7 +7,7 @@ #smem = #ttg.shared_memory #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_nested_persistent_ws_kernel(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + tt.func public @matmul_nested_persistent_ws_kernel(%arg0: !tt.tensordesc<128x128xf8E4M3FN, #shared>, %arg1: !tt.tensordesc<128x128xf8E4M3FN, #shared>, %arg2: !tt.tensordesc<128x128xf8E4M3FN, #shared>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %false = arith.constant false %true = arith.constant true %c128_i32 = arith.constant 128 : i32 @@ -49,9 +49,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ llvm.intr.assume %18 : i1 {ttg.partition = array} %19:2 = scf.for %arg7 = %c0_i32 to %17 step %c1_i32 iter_args(%arg8 = %false, %arg9 = %16) -> (i1, !ttg.async.token) : i32 { %22 = arith.muli %arg7, %c128_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : i32 - %23 = tt.descriptor_load %arg0[%14, %22] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked1> + %23 = tt.descriptor_load %arg0[%14, %22] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x128xf8E4M3FN, #shared> -> tensor<128x128xf8E4M3FN, #blocked1> %24 = ttg.local_alloc %23 {loop.cluster = 0 : i32, loop.stage = 2 : i32, ttg.partition = array} : (tensor<128x128xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> - %25 = tt.descriptor_load %arg1[%15, %22] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked1> + %25 = tt.descriptor_load %arg1[%15, %22] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x128xf8E4M3FN, #shared> -> tensor<128x128xf8E4M3FN, #blocked1> %26 = ttg.local_alloc %25 {loop.cluster = 0 : i32, loop.stage = 2 : i32, ttg.partition = array} : (tensor<128x128xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> %27 = ttg.memdesc_trans %26 {loop.cluster = 0 : i32, loop.stage = 2 : i32, order = array, ttg.partition = array} : !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> -> !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem> %28 = ttng.tc_gen5_mma %24, %27, %result[%arg9], %arg8, %true {loop.cluster = 0 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32, ttg.partition = array} : !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem>, !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> @@ -80,9 +80,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %17 = arith.addi %3, %arg6 {ttg.partition = array} : i32 %19:2 = scf.for %arg7 = %c0_i32 to %17 step %c1_i32 iter_args(%arg8 = %false, %arg9 = %16) -> (i1, !ttg.async.token) : i32 { %22 = arith.muli %arg7, %c128_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : i32 - %23 = tt.descriptor_load %arg0[%14, %22] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked1> + %23 = tt.descriptor_load %arg0[%14, %22] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x128xf8E4M3FN, #shared> -> tensor<128x128xf8E4M3FN, #blocked1> %24 = ttg.local_alloc %23 {loop.cluster = 0 : i32, loop.stage = 2 : i32, ttg.partition = array} : (tensor<128x128xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> - %25 = tt.descriptor_load %arg1[%15, %22] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked1> + %25 = tt.descriptor_load %arg1[%15, %22] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x128xf8E4M3FN, #shared> -> tensor<128x128xf8E4M3FN, #blocked1> %26 = ttg.local_alloc %25 {loop.cluster = 0 : i32, loop.stage = 2 : i32, ttg.partition = array} : (tensor<128x128xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> %27 = ttg.memdesc_trans %26 {loop.cluster = 0 : i32, loop.stage = 2 : i32, order = array, ttg.partition = array} : !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> -> !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem> %28 = ttng.tc_gen5_mma %24, %27, %result[%arg9], %arg8, %true {loop.cluster = 0 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32, ttg.partition = array} : !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem>, !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> diff --git a/test/NVWS/insert_aref.mlir b/test/NVWS/insert_aref.mlir index 363e7fbad48f..90cf2b19d5be 100644 --- a/test/NVWS/insert_aref.mlir +++ b/test/NVWS/insert_aref.mlir @@ -16,7 +16,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // FUNC-LABEL: @warp_specialize_tma_matmul // CHECK: @warp_specialize_tma_matmul - tt.func @warp_specialize_tma_matmul(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc>, %arg4: !tt.tensordesc>) { + tt.func @warp_specialize_tma_matmul(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<128x64xf16, #shared>, %arg4: !tt.tensordesc<128x64xf16, #shared>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 @@ -34,11 +34,11 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK: [[PUT_BUF1:%.*]], [[TOKEN1:%.*]] = nvws.aref.put.enter [[AREF1]] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} // CHECK-NEXT: nvws.descriptor_load {{.*}} 16384 [[PUT_BUF1]] // CHECK: nvws.aref.put.exit [[AREF1]], [[TOKEN1]] [#nvws.async_op] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} - %3 = tt.descriptor_load %arg3[%arg1, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %3 = tt.descriptor_load %arg3[%arg1, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> // CHECK: [[PUT_BUF2:%.*]], [[TOKEN2:%.*]] = nvws.aref.put.enter [[AREF2]] // CHECK-NEXT: nvws.descriptor_load {{.*}} 16384 [[PUT_BUF2]] // CHECK: nvws.aref.put.exit [[AREF2]] - %4 = tt.descriptor_load %arg4[%arg2, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %4 = tt.descriptor_load %arg4[%arg2, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> %5 = ttg.local_alloc %3 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %6 = ttg.local_alloc %4 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> @@ -59,14 +59,14 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @specialize_load_only - tt.func @specialize_load_only(%arg0: !tt.tensordesc>, %arg1: i32) { + tt.func @specialize_load_only(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: i32) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 : i32 { // CHECK: nvws.aref.put.enter // CHECK: nvws.descriptor_load // CHECK: nvws.aref.put.exit - %0 = tt.descriptor_load %arg0[%arg2, %arg2] {loop.cluster = 1 : i32, loop.stage = 0, ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %0 = tt.descriptor_load %arg0[%arg2, %arg2] {loop.cluster = 1 : i32, loop.stage = 0, ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> // CHECK: {{.*}}, [[GET_TOKEN:%.*]] = nvws.aref.get.enter // CHECK: [[REG:%.*]] = ttg.local_load // CHECK: nvws.aref.get.exit {{.*}}, [[GET_TOKEN]] [#nvws.async_op] @@ -113,14 +113,14 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @load_used_as_reg_and_smem - tt.func @load_used_as_reg_and_smem(%arg0: !tt.tensordesc>, %arg1: i32) { + tt.func @load_used_as_reg_and_smem(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: i32) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 : i32 { // CHECK: nvws.aref.put.enter // CHECK: nvws.descriptor_load // CHECK: nvws.aref.put.exit - %0 = tt.descriptor_load %arg0[%arg2, %arg2] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %0 = tt.descriptor_load %arg0[%arg2, %arg2] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> %alloc = ttg.local_alloc %0 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> // CHECK-DAG: [[GET_BUF1:%.*]], [[GET_TOKEN1:%.*]] = nvws.aref.get.enter {{.*}} {loop.cluster = 1 : i32, loop.stage = 1 : i32, ttg.partition = array} // CHECK-DAG: [[REG:%.*]] = ttg.local_load [[GET_BUF1]] {loop.cluster = 1 : i32, loop.stage = 1 : i32, ttg.partition = array} @@ -136,14 +136,14 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @load_used_as_reg_and_smem_same_partition - tt.func @load_used_as_reg_and_smem_same_partition(%arg0: !tt.tensordesc>, %arg1: i32) { + tt.func @load_used_as_reg_and_smem_same_partition(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: i32) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 : i32 { // CHECK: nvws.aref.put.enter // CHECK: nvws.descriptor_load // CHECK: nvws.aref.put.exit - %0 = tt.descriptor_load %arg0[%arg2, %arg2] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %0 = tt.descriptor_load %arg0[%arg2, %arg2] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> %alloc = ttg.local_alloc %0 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> // CHECK: [[GET_BUF:%.*]], [[GET_TOKEN:%.*]] = nvws.aref.get.enter {{.*}} {loop.cluster = 0 : i32, loop.stage = 1 // CHECK: [[REG:%.*]] = ttg.local_load [[GET_BUF]] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} @@ -157,7 +157,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { } // CHECK-LABEL: @matmul_scaled_rhs_scales_tma - tt.func @matmul_scaled_rhs_scales_tma(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc>, %arg4: !tt.tensordesc>, %arg5: !tt.tensordesc>) { + tt.func @matmul_scaled_rhs_scales_tma(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<128x64xf8E4M3FN, #shared3>, %arg4: !tt.tensordesc<128x64xf8E4M3FN, #shared3>, %arg5: !tt.tensordesc<128x8xi8, #shared2>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 @@ -167,12 +167,12 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %result = ttng.tmem_alloc %cst_0 : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory> %0 = scf.for %arg6 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg7 = %cst) -> (tensor<128x128xf32, #blocked>) : i32 { %1 = arith.muli %arg6, %c64_i32 {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : i32 - %2 = tt.descriptor_load %arg3[%arg1, %1] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf8E4M3FN, #blocked1> - %3 = tt.descriptor_load %arg4[%arg2, %1] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf8E4M3FN, #blocked1> + %2 = tt.descriptor_load %arg3[%arg1, %1] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x64xf8E4M3FN, #shared3> -> tensor<128x64xf8E4M3FN, #blocked1> + %3 = tt.descriptor_load %arg4[%arg2, %1] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x64xf8E4M3FN, #shared3> -> tensor<128x64xf8E4M3FN, #blocked1> %5 = ttg.local_alloc %2 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : (tensor<128x64xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared3, #smem> %6 = ttg.local_alloc %3 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : (tensor<128x64xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared3, #smem> // CHECK: [[REG:%.*]] = tt.descriptor_load - %4 = tt.descriptor_load %arg5[%arg1, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x8xi8, #linear> + %4 = tt.descriptor_load %arg5[%arg1, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x8xi8, #shared2> -> tensor<128x8xi8, #linear> // CHECK: tmem_alloc [[REG]] %result_1 = ttng.tmem_alloc %4 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory> %7 = ttg.memdesc_trans %6 {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array, ttg.partition = array} : !ttg.memdesc<128x64xf8E4M3FN, #shared3, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #shared4, #smem> @@ -186,7 +186,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // FUNC-LABEL: @local_alloc_default_partition // CHECK: @local_alloc_default_partition - tt.func @local_alloc_default_partition(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc>, %arg4: !tt.tensordesc>) { + tt.func @local_alloc_default_partition(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<128x128xf16, #shared>, %arg4: !tt.tensordesc<128x128xf16, #shared>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 @@ -211,11 +211,11 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK: [[AREF_LHS_TRANS_GET_BUF:%.*]], {{.*}} = nvws.aref.get.enter [[AREF_LHS_TRANS]] {{.*}}ttg.partition = array} // CHECK: [[LHS:%.*]] = ttg.memdesc_trans [[AREF_LHS_TRANS_GET_BUF]] {{.*}}ttg.partition = array} - %3 = tt.descriptor_load %arg3[%arg1, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x128xf16, #blocked2> + %3 = tt.descriptor_load %arg3[%arg1, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x128xf16, #shared> -> tensor<128x128xf16, #blocked2> %5 = ttg.local_alloc %3 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared1, #smem> %lhs_trans = ttg.memdesc_trans %5 {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array, ttg.partition = array} : !ttg.memdesc<128x128xf16, #shared1, #smem> -> !ttg.memdesc<128x128xf16, #shared, #smem> - %4 = tt.descriptor_load %arg4[%arg2, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x128xf16, #blocked1> + %4 = tt.descriptor_load %arg4[%arg2, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x128xf16, #shared> -> tensor<128x128xf16, #blocked1> %6 = ttg.local_alloc %4 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #smem> %7 = ttg.memdesc_trans %6 {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array, ttg.partition = array} : !ttg.memdesc<128x128xf16, #shared, #smem> -> !ttg.memdesc<128x128xf16, #shared1, #smem> @@ -692,7 +692,7 @@ tt.func @cycle_in_partition(%lb: i32, %ub: i32, %step: i32) { #smem = #ttg.shared_memory #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @inner_loop_fixed_operand(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + tt.func public @inner_loop_fixed_operand(%arg0: !tt.tensordesc<128x128xf8E4M3FN, #shared>, %arg1: !tt.tensordesc<128x128xf8E4M3FN, #shared>, %arg2: !tt.tensordesc<128x128xf8E4M3FN, #shared>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %false = arith.constant false %true = arith.constant true %c128_i32 = arith.constant 128 : i32 @@ -734,11 +734,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %14 = arith.divsi %13, %10 {ttg.partition = array} : i32 %15 = arith.muli %12, %c128_i32 {ttg.partition = array} : i32 %16 = arith.muli %14, %c128_i32 {ttg.partition = array} : i32 - %17 = tt.descriptor_load %arg0[%15, %c0_i32] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked1> + %17 = tt.descriptor_load %arg0[%15, %c0_i32] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x128xf8E4M3FN, #shared> -> tensor<128x128xf8E4M3FN, #blocked1> %18 = ttg.local_alloc %17 {loop.cluster = 0 : i32, loop.stage = 2 : i32, ttg.partition = array} : (tensor<128x128xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> %19:2 = scf.for %arg8 = %c0_i32 to %3 step %c1_i32 iter_args(%arg9 = %false, %arg10 = %arg7) -> (i1, !ttg.async.token) : i32 { %22 = arith.muli %arg8, %c128_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : i32 - %23 = tt.descriptor_load %arg1[%16, %22] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked1> + %23 = tt.descriptor_load %arg1[%16, %22] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x128xf8E4M3FN, #shared> -> tensor<128x128xf8E4M3FN, #blocked1> %24 = ttg.local_alloc %23 {loop.cluster = 0 : i32, loop.stage = 2 : i32, ttg.partition = array} : (tensor<128x128xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> %25 = ttg.memdesc_trans %24 {loop.cluster = 0 : i32, loop.stage = 2 : i32, order = array, ttg.partition = array} : !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> -> !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem> %26 = ttng.tc_gen5_mma %18, %25, %result[%arg10], %arg9, %true {loop.cluster = 0 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32, ttg.partition = array} : !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem>, !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> @@ -747,7 +747,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %result_0, %token_1 = ttng.tmem_load %result[%19#1] {ttg.partition = array} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> %20 = tt.fp_to_fp %result_0 {ttg.partition = array}, rounding = rtne : tensor<128x128xf32, #blocked> -> tensor<128x128xf8E4M3FN, #blocked> %21 = ttg.convert_layout %20 {ttg.partition = array} : tensor<128x128xf8E4M3FN, #blocked> -> tensor<128x128xf8E4M3FN, #blocked1> - tt.descriptor_store %arg2[%15, %16], %21 {ttg.partition = array} : !tt.tensordesc>, tensor<128x128xf8E4M3FN, #blocked1> + tt.descriptor_store %arg2[%15, %16], %21 {ttg.partition = array} : !tt.tensordesc<128x128xf8E4M3FN, #shared>, tensor<128x128xf8E4M3FN, #blocked1> scf.yield {ttg.partition = array} %token_1 : !ttg.async.token } {tt.num_stages = 3 : i32, tt.warp_specialize, ttg.partition = array, ttg.partition.outputs = [array], ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 0 : i32} tt.return diff --git a/test/NVWS/lower_aref.mlir b/test/NVWS/lower_aref.mlir index eb27650f815c..90b99bfd0571 100644 --- a/test/NVWS/lower_aref.mlir +++ b/test/NVWS/lower_aref.mlir @@ -212,7 +212,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { #smem = #ttg.shared_memory #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { - tt.func @warp_specialize_tma_matmul(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc>, %arg4: !tt.tensordesc>) { + tt.func @warp_specialize_tma_matmul(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<128x64xf16, #shared>, %arg4: !tt.tensordesc<128x64xf16, #shared>) { %0 = ub.poison : !ttg.async.token %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %c64_i32 = arith.constant 64 : i32 @@ -243,11 +243,11 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK: ttng.async_tma_copy_global_to_local {{.*}} [[BUF_A_SLICE]], [[TMA_FULL_SLICE]], {{.*}} {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} // CHECK: ttng.async_tma_copy_global_to_local {{.*}} [[BUF_B_SLICE]], [[TMA_FULL_SLICE]], {{.*}} {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} %buffers, %token_2 = nvws.aref.put.enter %4[%c0_i32, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.async.token - nvws.descriptor_load %arg3[%arg1, %11] 16384 %buffers {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + nvws.descriptor_load %arg3[%arg1, %11] 16384 %buffers {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable> nvws.aref.put.exit %4[%c0_i32], %token_2 [#nvws.async_op] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]>, !ttg.async.token %buffers_3, %token_4 = nvws.aref.get.enter %4[%c0_i32, %c0_i32] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.async.token %buffers_5, %token_6 = nvws.aref.put.enter %6[%c0_i32, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.async.token - nvws.descriptor_load %arg4[%arg2, %11] 16384 %buffers_5 {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + nvws.descriptor_load %arg4[%arg2, %11] 16384 %buffers_5 {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable> nvws.aref.put.exit %6[%c0_i32], %token_6 [#nvws.async_op] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]>, !ttg.async.token %buffers_7, %token_8 = nvws.aref.get.enter %6[%c0_i32, %c0_i32] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.async.token @@ -275,7 +275,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> #smem = #ttg.shared_memory module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { - tt.func @load_used_as_reg_and_smem(%arg0: !tt.tensordesc>, %arg1: i32) { + tt.func @load_used_as_reg_and_smem(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: i32) { %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 // CHECK: [[EMPTY:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x1xi64 @@ -288,7 +288,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %1 = nvws.aref.create %0 : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 : i32 { %buffers, %token = nvws.aref.put.enter %1[%c0_i32, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.async.token - nvws.descriptor_load %arg0[%arg2, %arg2] 16384 %buffers {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + nvws.descriptor_load %arg0[%arg2, %arg2] 16384 %buffers {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable> nvws.aref.put.exit %1[%c0_i32], %token [#nvws.async_op] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]>, !ttg.async.token %buffers_0, %token_1 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.async.token %2 = ttg.local_load %buffers_0 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem> -> tensor<128x64xf16, #blocked> @@ -315,7 +315,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> #smem = #ttg.shared_memory module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { - tt.func @load_used_as_reg_and_smem_same_partition(%arg0: !tt.tensordesc>, %arg1: i32) { + tt.func @load_used_as_reg_and_smem_same_partition(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: i32) { %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 // CHECK: [[EMPTY:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x1xi64 @@ -328,7 +328,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %1 = nvws.aref.create %0 : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 : i32 { %buffers, %token = nvws.aref.put.enter %1[%c0_i32, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 1x128x64>, !ttg.async.token - nvws.descriptor_load %arg0[%arg2, %arg2] 16384 %buffers {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 1x128x64> + nvws.descriptor_load %arg0[%arg2, %arg2] 16384 %buffers {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 1x128x64> nvws.aref.put.exit %1[%c0_i32], %token [#nvws.async_op] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]>, !ttg.async.token %buffers_0, %token_1 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, 1x128x64>, !ttg.async.token %2 = ttg.local_load %buffers_0 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem, 1x128x64> -> tensor<128x64xf16, #blocked> @@ -356,7 +356,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @lower_aref_buffer - tt.func @lower_aref_buffer(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) { + tt.func @lower_aref_buffer(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<64x128xf16, #shared>) { %c32_i32 = arith.constant 32 : i32 %cst = arith.constant dense<1.000000e+00> : tensor<128x128xf32, #blocked> %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> @@ -372,8 +372,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK: scf.for {{.*}} iter_args({{.*}} = {{.*}}, [[SPUT:%.*]] = {{.*}}, {{.*}} = {{.*}}, {{.*}} = {{.*}}, {{.*}} = {{.*}}) %3 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %token) -> (!ttg.async.token) : i32 { %4:3 = "get_offsets"(%arg2) {ttg.partition = array} : (i32) -> (i32, i32, i32) - %5 = tt.descriptor_load %arg0[%4#0, %4#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %6 = tt.descriptor_load %arg1[%4#1, %4#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %5 = tt.descriptor_load %arg0[%4#0, %4#2] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %6 = tt.descriptor_load %arg1[%4#1, %4#2] {ttg.partition = array} : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %7 = ttg.local_alloc %5 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %8 = ttg.local_alloc %6 {ttg.partition = array} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> // CHECK: local_alloc @@ -419,7 +419,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @aref_not_in_loop - tt.func @aref_not_in_loop(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc>, %arg4: !tt.tensordesc>) { + tt.func @aref_not_in_loop(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<128x64xf16, #shared>, %arg4: !tt.tensordesc<128x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %c64_i32 = arith.constant 64 : i32 %c1_i32 = arith.constant 1 : i32 @@ -439,8 +439,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %2 = ttng.tmem_store %cst, %1[], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128> scf.for %arg5 = %c0_i32 to %arg0 step %c1_i32 : i32 { %4 = arith.muli %arg5, %c64_i32 {ttg.partition = array} : i32 - %5 = tt.descriptor_load %arg3[%arg1, %4] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %6 = tt.descriptor_load %arg4[%arg2, %4] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %5 = tt.descriptor_load %arg3[%arg1, %4] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %6 = tt.descriptor_load %arg4[%arg2, %4] {ttg.partition = array} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> %7 = ttg.local_alloc %5 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %8 = ttg.local_alloc %6 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %9 = ttg.memdesc_trans %8 {order = array, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared1, #smem> @@ -471,7 +471,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @load_scale_mma_user - tt.func @load_scale_mma_user(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem>, %arg2: !tt.tensordesc>, %arg3: !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>, %arg4: i32) { + tt.func @load_scale_mma_user(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem>, %arg2: !tt.tensordesc<8x128xi8, #shared>, %arg3: !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>, %arg4: i32) { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %true = arith.constant true %c1_i32 = arith.constant 1 : i32 @@ -483,7 +483,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %2 = ttng.tmem_store %cst, %1[], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128> // CHECK: scf.for %3 = scf.for %arg5 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg6 = %token) -> (!ttg.async.token) : i32 { - %5 = tt.descriptor_load %arg2[%arg5, %arg5] {ttg.partition = array} : !tt.tensordesc> -> tensor<8x128xi8, #blocked1> + %5 = tt.descriptor_load %arg2[%arg5, %arg5] {ttg.partition = array} : !tt.tensordesc<8x128xi8, #shared> -> tensor<8x128xi8, #blocked1> %6 = ttg.local_alloc %5 {ttg.partition = array} : (tensor<8x128xi8, #blocked1>) -> !ttg.memdesc<8x128xi8, #shared, #smem> %7 = ttg.local_load %6 {ttg.partition = array} : !ttg.memdesc<8x128xi8, #shared, #smem> -> tensor<8x128xi8, #linear1> %8 = tt.trans %7 {order = array, ttg.partition = array} : tensor<8x128xi8, #linear1> -> tensor<128x8xi8, #linear> @@ -521,7 +521,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { #smem = #ttg.shared_memory #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { - tt.func public @attention_forward(%arg0: !ttg.memdesc<256x64xf16, #shared, #smem>, %arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>, %arg3: f32, %arg4: i32, %arg5: !tt.ptr) { + tt.func public @attention_forward(%arg0: !ttg.memdesc<256x64xf16, #shared, #smem>, %arg1: !tt.tensordesc<64x64xf16, #shared>, %arg2: !tt.tensordesc<64x64xf16, #shared>, %arg3: f32, %arg4: i32, %arg5: !tt.ptr) { %cst = arith.constant dense<1.000000e+00> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #blocked> %cst_1 = arith.constant dense<0xFF800000> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> @@ -552,7 +552,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { ttg.local_store %arg8, %buffers_9 {loop.cluster = 0 : i32, loop.stage = 4 : i32, ttg.partition = array} : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> !ttg.memdesc<256xf32, #shared1, #smem, mutable, 1x256> nvws.aref.put.exit %11, %token_10 [#nvws.async_op] {loop.cluster = 0 : i32, loop.stage = 4 : i32, ttg.partition = array} : <[!ttg.memdesc<1x256xf32, #shared1, #smem, mutable>]>, !ttg.async.token %buffers_11, %token_12 = nvws.aref.put.enter %5 {loop.cluster = 4 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x64x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable, 1x64x64>, !ttg.async.token - nvws.descriptor_load %arg1[%arg6, %c0_i32] 8192 %buffers_11 {loop.cluster = 4 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<64x64xf16, #shared, #smem, mutable, 1x64x64> + nvws.descriptor_load %arg1[%arg6, %c0_i32] 8192 %buffers_11 {loop.cluster = 4 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc<64x64xf16, #shared>, i32, i32, !ttg.memdesc<64x64xf16, #shared, #smem, mutable, 1x64x64> nvws.aref.put.exit %5, %token_12 [#nvws.async_op] {loop.cluster = 4 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x64x64xf16, #shared, #smem, mutable>]>, !ttg.async.token %buffers_13, %token_14 = nvws.aref.get.enter %5 {loop.cluster = 2 : i32, loop.stage = 2 : i32, ttg.partition = array} : <[!ttg.memdesc<1x64x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<64x64xf16, #shared, #smem, 1x64x64>, !ttg.async.token %16 = ttg.memdesc_trans %buffers_13 {loop.cluster = 2 : i32, loop.stage = 2 : i32, order = array, ttg.partition = array} : !ttg.memdesc<64x64xf16, #shared, #smem, 1x64x64> -> !ttg.memdesc<64x64xf16, #shared2, #smem, 1x64x64> @@ -598,7 +598,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %39 = arith.mulf %result_25, %34 {loop.cluster = 0 : i32, loop.stage = 4 : i32, ttg.partition = array} : tensor<256x64xf32, #blocked> %40 = arith.addf %39, %37 {loop.cluster = 0 : i32, loop.stage = 4 : i32, ttg.partition = array} : tensor<256x64xf32, #blocked> %buffers_27, %token_28 = nvws.aref.put.enter %7 {loop.cluster = 2 : i32, loop.stage = 2 : i32, ttg.partition = array} : <[!ttg.memdesc<1x64x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable, 1x64x64>, !ttg.async.token - nvws.descriptor_load %arg2[%arg6, %c0_i32] 8192 %buffers_27 {loop.cluster = 2 : i32, loop.stage = 2 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<64x64xf16, #shared, #smem, mutable, 1x64x64> + nvws.descriptor_load %arg2[%arg6, %c0_i32] 8192 %buffers_27 {loop.cluster = 2 : i32, loop.stage = 2 : i32, ttg.partition = array} : !tt.tensordesc<64x64xf16, #shared>, i32, i32, !ttg.memdesc<64x64xf16, #shared, #smem, mutable, 1x64x64> nvws.aref.put.exit %7, %token_28 [#nvws.async_op] {loop.cluster = 2 : i32, loop.stage = 2 : i32, ttg.partition = array} : <[!ttg.memdesc<1x64x64xf16, #shared, #smem, mutable>]>, !ttg.async.token %buffers_29, %token_30 = nvws.aref.get.enter %7 {loop.cluster = 0 : i32, loop.stage = 4 : i32, ttg.partition = array} : <[!ttg.memdesc<1x64x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<64x64xf16, #shared, #smem, 1x64x64>, !ttg.async.token %41 = arith.truncf %22 {loop.cluster = 0 : i32, loop.stage = 4 : i32, ttg.partition = array} : tensor<256x64xf32, #blocked> to tensor<256x64xf16, #blocked> diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index 9e90a3f13a87..597285242cca 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -429,11 +429,11 @@ tt.func @test_reshape_reduce(%0: tensor<32x4x2xi32>) -> (i32, tensor<16xi32>) { } // CHECK-LABEL: test_rank_reduce_desc_load -tt.func @test_rank_reduce_desc_load(%0: !tt.tensordesc>) -> (tensor<128x64xf16>) { +tt.func @test_rank_reduce_desc_load(%0: !tt.tensordesc<1x128x64xf16>) -> (tensor<128x64xf16>) { %c0 = arith.constant 0 : i32 - // CHECK: %[[R:.+]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<128x64xf16> + // CHECK: %[[R:.+]] = tt.descriptor_load {{.*}} : !tt.tensordesc<1x128x64xf16> -> tensor<128x64xf16> // CHECK: tt.return %[[R]] - %l = tt.descriptor_load %0[%c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x128x64xf16> + %l = tt.descriptor_load %0[%c0, %c0, %c0] : !tt.tensordesc<1x128x64xf16> -> tensor<1x128x64xf16> %r = tt.reshape %l : tensor<1x128x64xf16> -> tensor<128x64xf16> tt.return %r : tensor<128x64xf16> } diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index 25432c0fbf4b..43ea857e1e84 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -428,84 +428,84 @@ tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) { // ----- -tt.func @invalid_desc_load(%arg0: !tt.tensordesc>) { +tt.func @invalid_desc_load(%arg0: !tt.tensordesc<16x16xf32>) { %c = arith.constant 0 : i32 // expected-error @below {{descriptor block and tensor must have the same number of elements}} - tt.descriptor_load %arg0[%c, %c] : !tt.tensordesc> -> tensor<16xf32> + tt.descriptor_load %arg0[%c, %c] : !tt.tensordesc<16x16xf32> -> tensor<16xf32> tt.return } // ----- -tt.func @invalid_desc_load(%arg0: !tt.tensordesc>) { +tt.func @invalid_desc_load(%arg0: !tt.tensordesc<16x16xf32>) { %c = arith.constant 0 : i32 // expected-error @below {{descriptor block and tensor element types must match}} - tt.descriptor_load %arg0[%c, %c] : !tt.tensordesc> -> tensor<16x16xf16> + tt.descriptor_load %arg0[%c, %c] : !tt.tensordesc<16x16xf32> -> tensor<16x16xf16> tt.return } // ----- -tt.func @invalid_desc_store(%arg0: !tt.tensordesc>, %arg1: tensor<32x16xf32>) { +tt.func @invalid_desc_store(%arg0: !tt.tensordesc<16x16xf32>, %arg1: tensor<32x16xf32>) { %c = arith.constant 0 : i32 // expected-error @below {{descriptor block and tensor must have the same number of elements}} - tt.descriptor_store %arg0[%c, %c], %arg1 : !tt.tensordesc>, tensor<32x16xf32> + tt.descriptor_store %arg0[%c, %c], %arg1 : !tt.tensordesc<16x16xf32>, tensor<32x16xf32> tt.return } // ----- -tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc<128xbf16>, %arg1: tensor<32xi32>, %arg2: i32) { // expected-error @below {{block must be a 2D tensor}} - %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32xbf16> + %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc<128xbf16>, tensor<32xi32>, i32) -> tensor<32xbf16> tt.return } // ----- -tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc<2x128xbf16>, %arg1: tensor<32xi32>, %arg2: i32) { // expected-error @below {{block must have exactly 1 row}} - %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32x128xbf16> + %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc<2x128xbf16>, tensor<32xi32>, i32) -> tensor<32x128xbf16> tt.return } // ----- -tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<1x32xi32>, %arg2: i32) { +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc<1x128xbf16>, %arg1: tensor<1x32xi32>, %arg2: i32) { // expected-error @below {{x offsets must be a 1D tensor}} - %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<1x32xi32>, i32) -> tensor<32x128xbf16> + %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc<1x128xbf16>, tensor<1x32xi32>, i32) -> tensor<32x128xbf16> tt.return } // ----- -tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc<1x128xbf16>, %arg1: tensor<32xi32>, %arg2: i32) { // expected-error @below {{result must be a 2D tensor}} - %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<128xbf16> + %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc<1x128xbf16>, tensor<32xi32>, i32) -> tensor<128xbf16> tt.return } // ----- -tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc<1x128xbf16>, %arg1: tensor<32xi32>, %arg2: i32) { // expected-error @below {{result tensor number of columns must match block (128)}} - %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32x64xbf16> + %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc<1x128xbf16>, tensor<32xi32>, i32) -> tensor<32x64xbf16> tt.return } // ----- -tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc<1x128xbf16>, %arg1: tensor<32xi32>, %arg2: i32) { // expected-error @below {{result tensor must have as many rows as indices (32)}} - %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<64x128xbf16> + %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc<1x128xbf16>, tensor<32xi32>, i32) -> tensor<64x128xbf16> tt.return } // ----- -tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc<1x128xbf16>, %arg1: tensor<32xi32>, %arg2: i32) { // expected-error @below {{result tensor element type must match block ('bf16')}} - %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32x128xf32> + %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc<1x128xbf16>, tensor<32xi32>, i32) -> tensor<32x128xf32> tt.return } diff --git a/test/Triton/ops.mlir b/test/Triton/ops.mlir index a41a3b85060b..e026c6b04071 100644 --- a/test/Triton/ops.mlir +++ b/test/Triton/ops.mlir @@ -251,10 +251,10 @@ tt.func @masked_histogram(%0: tensor<512xi32>, %1: tensor<512xi1>) { } // CHECK-LABEL: descriptor_load -tt.func @descriptor_load(%0: !tt.tensordesc>) { - // CHECK: tt.descriptor_load %{{.+}}[%{{.+}}] : !tt.tensordesc> -> tensor<128xf32> +tt.func @descriptor_load(%0: !tt.tensordesc<128xf32>) { + // CHECK: tt.descriptor_load %{{.+}}[%{{.+}}] : !tt.tensordesc<128xf32> -> tensor<128xf32> %c0_i32 = arith.constant 0 : i32 - %1 = tt.descriptor_load %0[%c0_i32] : !tt.tensordesc> -> tensor<128xf32> + %1 = tt.descriptor_load %0[%c0_i32] : !tt.tensordesc<128xf32> -> tensor<128xf32> tt.return } @@ -266,16 +266,16 @@ tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x16xi32>) -> tens } // CHECK-LABEL: @tma_gather -tt.func @tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { - // CHECK-NEXT: %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32x128xbf16> - %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32x128xbf16> +tt.func @tma_gather(%arg0: !tt.tensordesc<1x128xbf16>, %arg1: tensor<32xi32>, %arg2: i32) { + // CHECK-NEXT: %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc<1x128xbf16>, tensor<32xi32>, i32) -> tensor<32x128xbf16> + %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc<1x128xbf16>, tensor<32xi32>, i32) -> tensor<32x128xbf16> tt.return } // CHECK-LABEL: @tma_scatter -tt.func @tma_scatter(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32, %arg3: tensor<32x128xbf16>) { - // CHECK-NEXT: tt.descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc>, tensor<32xi32>, i32, tensor<32x128xbf16> - tt.descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc>, tensor<32xi32>, i32, tensor<32x128xbf16> +tt.func @tma_scatter(%arg0: !tt.tensordesc<1x128xbf16>, %arg1: tensor<32xi32>, %arg2: i32, %arg3: tensor<32x128xbf16>) { + // CHECK-NEXT: tt.descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc<1x128xbf16>, tensor<32xi32>, i32, tensor<32x128xbf16> + tt.descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc<1x128xbf16>, tensor<32xi32>, i32, tensor<32x128xbf16> tt.return } diff --git a/test/Triton/rewrite-tensor-descriptor-to-pointer.mlir b/test/Triton/rewrite-tensor-descriptor-to-pointer.mlir index e0c6fea47dd9..2580eaaeef14 100644 --- a/test/Triton/rewrite-tensor-descriptor-to-pointer.mlir +++ b/test/Triton/rewrite-tensor-descriptor-to-pointer.mlir @@ -7,8 +7,8 @@ module { %c0_i32 = arith.constant 0 : i32 %c128_i32 = arith.constant 128 : i32 %c256_i32 = arith.constant 256 : i32 - %0 = tt.make_tensor_descriptor %arg0, [%c256_i32, %c256_i32], [%c1_i64, %c256_i64] {order = array} : , > - %3 = tt.descriptor_load %0[%arg1, %arg2] : !tt.tensordesc> -> tensor<128x128xf32> + %0 = tt.make_tensor_descriptor %arg0, [%c256_i32, %c256_i32], [%c1_i64, %c256_i64] {order = array} : , <128x128xf32> + %3 = tt.descriptor_load %0[%arg1, %arg2] : !tt.tensordesc<128x128xf32> -> tensor<128x128xf32> tt.return %3 : tensor<128x128xf32> } } @@ -62,8 +62,8 @@ module { %c0_i32 = arith.constant 0 : i32 %c128_i32 = arith.constant 128 : i32 %c256_i32 = arith.constant 256 : i32 - %0 = tt.make_tensor_descriptor %arg0, [%c256_i32, %c256_i32], [%c1_i64, %c256_i64] {order = array} : , > - tt.descriptor_store %0[%arg1, %arg2], %arg3 : !tt.tensordesc>, tensor<128x128xf32> + %0 = tt.make_tensor_descriptor %arg0, [%c256_i32, %c256_i32], [%c1_i64, %c256_i64] {order = array} : , <128x128xf32> + tt.descriptor_store %0[%arg1, %arg2], %arg3 : !tt.tensordesc<128x128xf32>, tensor<128x128xf32> tt.return } } @@ -111,16 +111,16 @@ module { #loc2 = loc("rewrite-tensor-descriptor-to-pointer.mlir":147:28) module { - tt.func public @callee(%tensordesc: !tt.tensordesc> loc("tensordesc"(#loc2))) -> !tt.tensordesc> { - tt.return %tensordesc : !tt.tensordesc> + tt.func public @callee(%tensordesc: !tt.tensordesc<128x128xf32> loc("tensordesc"(#loc2))) -> !tt.tensordesc<128x128xf32> { + tt.return %tensordesc : !tt.tensordesc<128x128xf32> } tt.func public @caller(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { %c1_i64 = arith.constant 1 : i64 %c256_i32 = arith.constant 256 : i32 %c256_i64 = arith.constant 256 : i64 - %0 = tt.make_tensor_descriptor %arg0, [%c256_i32, %c256_i32], [%c256_i64, %c1_i64] {order = array} : , > - %1 = tt.call @callee(%0) : (!tt.tensordesc>) -> !tt.tensordesc> + %0 = tt.make_tensor_descriptor %arg0, [%c256_i32, %c256_i32], [%c256_i64, %c1_i64] {order = array} : , <128x128xf32> + %1 = tt.call @callee(%0) : (!tt.tensordesc<128x128xf32>) -> !tt.tensordesc<128x128xf32> tt.return } } @@ -152,7 +152,7 @@ module { // ----- module { - tt.func public @arg_attr(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}) { + tt.func public @arg_attr(%arg0: !tt.tensordesc<128x128xf32>, %arg1: i32 {tt.divisibility = 16 : i32}) { tt.return } } diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index e473c95f492f..d3b03f7650ba 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -462,9 +462,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { -tt.func @scalar_load_in_bwd_slice(%arg0: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %arg1: !tt.tensordesc>, %arg2: !tt.ptr) -> tensor<128x128xf32, #blocked> { +tt.func @scalar_load_in_bwd_slice(%arg0: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %arg1: !tt.tensordesc<128x128xf8E5M2>, %arg2: !tt.ptr) -> tensor<128x128xf32, #blocked> { %0 = tt.load %arg2 : !tt.ptr - %1 = tt.descriptor_load %arg1[%0, %0] : !tt.tensordesc> -> tensor<128x128xf8E5M2, #blocked1> + %1 = tt.descriptor_load %arg1[%0, %0] : !tt.tensordesc<128x128xf8E5M2> -> tensor<128x128xf8E5M2, #blocked1> %2 = ttg.convert_layout %1 : tensor<128x128xf8E5M2, #blocked1> -> tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %3 = tt.dot %2, %arg0, %cst, inputPrecision = tf32 : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> @@ -601,8 +601,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: identify_load_then_trans tt.func public @identify_load_then_trans( - %arg0: !tt.tensordesc>, - %arg1: !tt.tensordesc>, + %arg0: !tt.tensordesc<128x128xf16>, + %arg1: !tt.tensordesc<128x128xf16>, %arg2: i32, %arg3: i32, %arg4: i32, @@ -610,8 +610,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ ) -> tensor<128x128xf32, #blocked> { // CHECK: %[[DESC0:.*]] = tt.descriptor_load %arg0 // CHECK: %[[DESC1:.*]] = tt.descriptor_load %arg1 - %13 = tt.descriptor_load %arg0[%arg4, %arg2] : !tt.tensordesc> -> tensor<128x128xf16, #blocked2> - %14 = tt.descriptor_load %arg1[%arg3, %arg4] : !tt.tensordesc> -> tensor<128x128xf16, #blocked2> + %13 = tt.descriptor_load %arg0[%arg4, %arg2] : !tt.tensordesc<128x128xf16> -> tensor<128x128xf16, #blocked2> + %14 = tt.descriptor_load %arg1[%arg3, %arg4] : !tt.tensordesc<128x128xf16> -> tensor<128x128xf16, #blocked2> // CHECK: %[[TRANS0:.*]] = tt.trans %[[DESC0]] // CHECK: %[[ALLOC0:.*]] = ttg.local_alloc %[[TRANS0]] %15 = tt.trans %13 {order = array} : tensor<128x128xf16, #blocked2> -> tensor<128x128xf16, #blocked3> diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers-dont-run-mlir-canonicalizer.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers-dont-run-mlir-canonicalizer.mlir index 944e9d388c3e..0637e0306e85 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers-dont-run-mlir-canonicalizer.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers-dont-run-mlir-canonicalizer.mlir @@ -163,20 +163,20 @@ module attributes {"ttg.num-warps" = 4 : i32} { // ----- module attributes {"ttg.num-warps" = 4 : i32} { - tt.func @make_tensor_descriptor(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %n: i32 {tt.divisibility = 16 : i32}) -> !tt.tensordesc> { + tt.func @make_tensor_descriptor(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %n: i32 {tt.divisibility = 16 : i32}) -> !tt.tensordesc<16xf32> { %c1_i64 = arith.constant 1 : i64 %c1_i32 = arith.constant 1 : i32 %ptr = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 - %desc = tt.make_tensor_descriptor %ptr, [%n], [%c1_i64] : !tt.ptr, !tt.tensordesc> - tt.return %desc : !tt.tensordesc> + %desc = tt.make_tensor_descriptor %ptr, [%n], [%c1_i64] : !tt.ptr, !tt.tensordesc<16xf32> + tt.return %desc : !tt.tensordesc<16xf32> } } // CHECK-LABEL: tt.func @make_tensor_descriptor( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_1:.*]]: i32 {tt.divisibility = 16 : i32}) -> !tt.tensordesc> { +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_1:.*]]: i32 {tt.divisibility = 16 : i32}) -> !tt.tensordesc<16xf32> { // CHECK: %[[VAL_2:.*]] = arith.constant 1 : i64 // CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32 // CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr, i32 -// CHECK: %[[VAL_5:.*]] = tt.make_tensor_descriptor %[[VAL_4]], {{\[}}%[[VAL_1]]], {{\[}}%[[VAL_2]]] : , > -// CHECK: tt.return %[[VAL_5]] : !tt.tensordesc> +// CHECK: %[[VAL_5:.*]] = tt.make_tensor_descriptor %[[VAL_4]], {{\[}}%[[VAL_1]]], {{\[}}%[[VAL_2]]] : , <16xf32> +// CHECK: tt.return %[[VAL_5]] : !tt.tensordesc<16xf32> // CHECK: } diff --git a/test/TritonGPU/amd/amd-consan.mlir b/test/TritonGPU/amd/amd-consan.mlir index e4d8131a8f26..d475a14b6ea7 100644 --- a/test/TritonGPU/amd/amd-consan.mlir +++ b/test/TritonGPU/amd/amd-consan.mlir @@ -464,7 +464,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_copy_global_to_local - tt.func public @async_tdm_copy_global_to_local(%desc: !tt.tensordesc>) { + tt.func public @async_tdm_copy_global_to_local(%desc: !tt.tensordesc<32x32xf32>) { // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr @@ -496,7 +496,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_track_visible_writes // CHECK: tt.call @__triton_consan_verify_barrier_arrive // CHECK: tt.call @__triton_consan_update_barrier_state - %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred, barrier = %bar : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred, barrier = %bar : !tt.tensordesc<32x32xf32>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } } @@ -510,8 +510,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_copy_global_to_local_two_bufs_one_barrier tt.func public @async_tdm_copy_global_to_local_two_bufs_one_barrier( - %a: !tt.tensordesc>, - %b: !tt.tensordesc>) { + %a: !tt.tensordesc<32x32xf32>, + %b: !tt.tensordesc<32x32xf32>) { %c0_i32 = arith.constant 0 : i32 %pred = arith.constant 1 : i32 @@ -533,7 +533,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_track_visible_writes // CHECK: tt.call @__triton_consan_verify_barrier_arrive // CHECK: tt.call @__triton_consan_update_barrier_state - %0 = amdg.async_tdm_copy_global_to_local %a[%c0_i32, %c0_i32] into %a_smem, pred = %pred, barrier = %bar : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %0 = amdg.async_tdm_copy_global_to_local %a[%c0_i32, %c0_i32] into %a_smem, pred = %pred, barrier = %bar : !tt.tensordesc<32x32xf32>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> // Second TDM copy: same full instrumentation // CHECK: tt.call @__triton_consan_verify_write_visibility @@ -546,7 +546,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_track_visible_writes // CHECK: tt.call @__triton_consan_verify_barrier_arrive // CHECK: tt.call @__triton_consan_update_barrier_state - %1 = amdg.async_tdm_copy_global_to_local %b[%c0_i32, %c0_i32] into %b_smem, pred = %pred, barrier = %bar : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %b[%c0_i32, %c0_i32] into %b_smem, pred = %pred, barrier = %bar : !tt.tensordesc<32x32xf32>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %c0_phase = arith.constant 0 : i32 amdg.wait_barrier %bar, %c0_phase : !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -564,7 +564,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_copy_global_to_local_no_barrier - tt.func public @async_tdm_copy_global_to_local_no_barrier(%desc: !tt.tensordesc>) { + tt.func public @async_tdm_copy_global_to_local_no_barrier(%desc: !tt.tensordesc<32x32xf32>) { %c0_i32 = arith.constant 0 : i32 %pred = arith.constant 1 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -575,7 +575,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_stage_access_for_commit // CHECK: tt.call @__triton_consan_commit_accesses // CHECK-NOT: tt.call @__triton_consan_verify_barrier_arrive - %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc<32x32xf32> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } } @@ -587,7 +587,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_copy_local_to_global - tt.func public @async_tdm_copy_local_to_global(%desc: !tt.tensordesc>, %ptr: tensor<128x128x!tt.ptr, #blocked>) { + tt.func public @async_tdm_copy_local_to_global(%desc: !tt.tensordesc<32x32xf32>, %ptr: tensor<128x128x!tt.ptr, #blocked>) { %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %shmem = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> @@ -598,7 +598,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_check_outstanding_commits_excl_self_noalias // CHECK: tt.call @__triton_consan_stage_access_for_commit // CHECK: tt.call @__triton_consan_commit_accesses - amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc<32x32xf32> tt.return } } @@ -609,7 +609,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_load_store_no_barrier - tt.func public @async_tdm_load_store_no_barrier(%in_desc: !tt.tensordesc>, %out_desc: !tt.tensordesc>) { + tt.func public @async_tdm_load_store_no_barrier(%in_desc: !tt.tensordesc<32x32xf32>, %out_desc: !tt.tensordesc<32x32xf32>) { %c0_i32 = arith.constant 0 : i32 %pred = arith.constant 1 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -617,11 +617,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_check_outstanding_commits_excl_self_noalias // CHECK: tt.call @__triton_consan_stage_access_for_commit // CHECK: tt.call @__triton_consan_commit_accesses - %1 = amdg.async_tdm_copy_global_to_local %in_desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %in_desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc<32x32xf32> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> // CHECK: tt.call @__triton_consan_check_outstanding_commits_excl_self_noalias // CHECK: tt.call @__triton_consan_stage_access_for_commit // CHECK: tt.call @__triton_consan_commit_accesses - amdg.async_tdm_copy_local_to_global %out_desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_copy_local_to_global %out_desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc<32x32xf32> tt.return } } @@ -633,7 +633,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_copy_local_to_global_with_barrier - tt.func public @async_tdm_copy_local_to_global_with_barrier(%desc: !tt.tensordesc>) { + tt.func public @async_tdm_copy_local_to_global_with_barrier(%desc: !tt.tensordesc<32x32xf32>) { %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -648,7 +648,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_verify_barrier_arrive // CHECK: tt.call @__triton_consan_update_barrier_state // CHECK-NOT: tt.call @__triton_consan_stage_access_for_commit - amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0, barrier = %bar : !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0, barrier = %bar : !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !tt.tensordesc<32x32xf32> tt.return } } @@ -716,11 +716,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @tdm_load_no_barrier_wait - tt.func public @tdm_load_no_barrier_wait(%desc: !tt.tensordesc>) { + tt.func public @tdm_load_no_barrier_wait(%desc: !tt.tensordesc<32x32xf32>) { %c0_i32 = arith.constant 0 : i32 %pred = arith.constant 1 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> - %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc<32x32xf32> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_both amdg.async_tdm_wait {num = 0 : i32} ttg.local_load %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> tensor<32x32xf32, #blocked> @@ -735,10 +735,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @tdm_store_no_barrier_wait - tt.func public @tdm_store_no_barrier_wait(%desc: !tt.tensordesc>) { + tt.func public @tdm_store_no_barrier_wait(%desc: !tt.tensordesc<32x32xf32>) { %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> - amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc<32x32xf32> // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_both amdg.async_tdm_wait {num = 0 : i32} ttg.local_load %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> tensor<32x32xf32, #blocked> @@ -753,12 +753,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @tdm_load_store_no_barrier_wait - tt.func public @tdm_load_store_no_barrier_wait(%desc: !tt.tensordesc>) { + tt.func public @tdm_load_store_no_barrier_wait(%desc: !tt.tensordesc<32x32xf32>) { %c0_i32 = arith.constant 0 : i32 %pred = arith.constant 1 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> - %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> - amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc> + %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc<32x32xf32> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc<32x32xf32> // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_both amdg.async_tdm_wait {num = 0 : i32} ttg.local_load %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> tensor<32x32xf32, #blocked> diff --git a/test/TritonGPU/amd/amd-convert-tensor-ops.mlir b/test/TritonGPU/amd/amd-convert-tensor-ops.mlir index ceca7b44b9fc..edf9372f2eca 100644 --- a/test/TritonGPU/amd/amd-convert-tensor-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-tensor-ops.mlir @@ -1,9 +1,9 @@ // RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-tensor-ops | FileCheck %s // CHECK-LABEL: test_cvt1 -// CHECK: amdg.async_tdm_copy_global_to_local {{.*}}: !tt.tensordesc> -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> +// CHECK: amdg.async_tdm_copy_global_to_local {{.*}}: !tt.tensordesc<128x16xf16, #shared> -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> // CHECK: amdg.async_tdm_wait {num = 0 : i32} -// CHECK: amdg.async_tdm_copy_local_to_global {{.*}} : !ttg.memdesc<128x128xf16, #shared2, #smem, mutable> -> !tt.tensordesc> +// CHECK: amdg.async_tdm_copy_local_to_global {{.*}} : !ttg.memdesc<128x128xf16, #shared2, #smem, mutable> -> !tt.tensordesc<128x128xf16, #shared2> // CHECK: amdg.async_tdm_wait {num = 0 : i32} #blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CGALayout = [[0, 0], [1, 0]]}> @@ -30,17 +30,17 @@ module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %1 = tt.get_program_id y : i32 %2 = arith.muli %0, %c128_i32 : i32 %3 = arith.muli %1, %c128_i32 : i32 - %4 = tt.make_tensor_descriptor %a_ptr, [%c1024_i32, %c256_i32], [%c256_i64, %c1_i64] : , > - %5 = tt.make_tensor_descriptor %b_ptr, [%c256_i32, %c512_i32], [%c512_i64, %c1_i64] : , > - %6 = tt.make_tensor_descriptor %c_ptr, [%c1024_i32, %c512_i32], [%c512_i64, %c1_i64] : , > - %7 = tt.descriptor_load %4[%2, %c0_i32] : !tt.tensordesc> -> tensor<128x16xf16, #blocked> - %8 = tt.descriptor_load %5[%c0_i32, %3] : !tt.tensordesc> -> tensor<16x128xf16, #blocked1> + %4 = tt.make_tensor_descriptor %a_ptr, [%c1024_i32, %c256_i32], [%c256_i64, %c1_i64] : , <128x16xf16, #shared> + %5 = tt.make_tensor_descriptor %b_ptr, [%c256_i32, %c512_i32], [%c512_i64, %c1_i64] : , <16x128xf16, #shared1> + %6 = tt.make_tensor_descriptor %c_ptr, [%c1024_i32, %c512_i32], [%c512_i64, %c1_i64] : , <128x128xf16, #shared2> + %7 = tt.descriptor_load %4[%2, %c0_i32] : !tt.tensordesc<128x16xf16, #shared> -> tensor<128x16xf16, #blocked> + %8 = tt.descriptor_load %5[%c0_i32, %3] : !tt.tensordesc<16x128xf16, #shared1> -> tensor<16x128xf16, #blocked1> %9 = ttg.convert_layout %7 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> %10 = ttg.convert_layout %8 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> %11 = tt.dot %9, %10, %cst : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x128xf32, #mma> %12 = arith.truncf %11 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> %13 = ttg.convert_layout %12 : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #blocked2> - tt.descriptor_store %6[%2, %3], %13 : !tt.tensordesc>, tensor<128x128xf16, #blocked2> + tt.descriptor_store %6[%2, %3], %13 : !tt.tensordesc<128x128xf16, #shared2>, tensor<128x128xf16, #blocked2> tt.return } } diff --git a/test/TritonGPU/amd/amd-optimize-descriptor-encoding.mlir b/test/TritonGPU/amd/amd-optimize-descriptor-encoding.mlir index e5afc40e3635..816df5d8bae9 100644 --- a/test/TritonGPU/amd/amd-optimize-descriptor-encoding.mlir +++ b/test/TritonGPU/amd/amd-optimize-descriptor-encoding.mlir @@ -8,15 +8,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-DAG: #[[$PADDED:.*]] = #ttg.padded_shared<[32:+16] {order = [1, 0], shape = [1, 32]}> // CHECK-LABEL: @descriptor_gather tt.func public @descriptor_gather(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: tensor<32xi32, #blocked> ) -> tensor<32x32xi8, #blocked1> { - // CHECK: tt.make_tensor_descriptor {{.*}} : , > - // CHECK: tt.descriptor_gather {{.*}} : (!tt.tensordesc> + // CHECK: tt.make_tensor_descriptor {{.*}} : , <1x32xi8, #[[$PADDED]]> + // CHECK: tt.descriptor_gather {{.*}} : (!tt.tensordesc<1x32xi8, #[[$PADDED]]> %c1_i64 = arith.constant 1 : i64 %cst = arith.constant dense<32> : tensor<8x1xi32> %c64_i32 = arith.constant 64 : i32 %c8_i32 = arith.constant 8 : i32 %0 = arith.extsi %arg2 : i32 to i64 - %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , > - %2 = tt.descriptor_gather %1[%arg3, %c8_i32] : (!tt.tensordesc>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1> + %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , <1x32xi8> + %2 = tt.descriptor_gather %1[%arg3, %c8_i32] : (!tt.tensordesc<1x32xi8>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1> tt.return %2 : tensor<32x32xi8, #blocked1> } } @@ -29,15 +29,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-DAG: #[[$PADDED:.*]] = #ttg.padded_shared<[32:+16] {order = [1, 0], shape = [1, 32]}> // CHECK-LABEL: @descriptor_scatter tt.func public @descriptor_scatter(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: tensor<32xi32, #blocked>, %arg4: tensor<32x32xi8, #blocked1>) { - // CHECK: tt.make_tensor_descriptor {{.*}} : , > - // CHECK: tt.descriptor_scatter {{.*}} : !tt.tensordesc>, {{.*}} + // CHECK: tt.make_tensor_descriptor {{.*}} : , <1x32xi8, #[[$PADDED]]> + // CHECK: tt.descriptor_scatter {{.*}} : !tt.tensordesc<1x32xi8, #[[$PADDED]]>, {{.*}} %c1_i64 = arith.constant 1 : i64 %cst = arith.constant dense<32> : tensor<8x1xi32> %c64_i32 = arith.constant 64 : i32 %c8_i32 = arith.constant 8 : i32 %0 = arith.extsi %arg2 : i32 to i64 - %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , > - tt.descriptor_scatter %1[%arg3, %c8_i32], %arg4 : !tt.tensordesc>, tensor<32xi32, #blocked>, i32, tensor<32x32xi8, #blocked1> + %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , <1x32xi8> + tt.descriptor_scatter %1[%arg3, %c8_i32], %arg4 : !tt.tensordesc<1x32xi8>, tensor<32xi32, #blocked>, i32, tensor<32x32xi8, #blocked1> tt.return } } @@ -53,13 +53,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-DAG: #[[$PADDED_ALLOC:.*]] = #ttg.padded_shared<[32:+2] {order = [1, 0], shape = [256, 32]}> // CHECK-LABEL: @descriptor_load tt.func public @descriptor_load(%arg0: !tt.ptr, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) { - // CHECK: tt.make_tensor_descriptor {{.*}} : , > - // CHECK: %[[LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<256x32xf32, #[[$BLOCKED]]> + // CHECK: tt.make_tensor_descriptor {{.*}} : , <1x256x32xf32, #[[$PADDED]]> + // CHECK: %[[LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<1x256x32xf32, #[[$PADDED]]> -> tensor<256x32xf32, #[[$BLOCKED]]> // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<256x32xf32, #[[$BLOCKED]]>) -> !ttg.memdesc<256x32xf32, #[[$PADDED_ALLOC]], #smem> %c1_i32 = arith.constant 1 : i32 %c1_i64 = arith.constant 1 : i64 - %0 = tt.make_tensor_descriptor %arg0, [%c1_i32, %arg1, %arg2], [%arg3, %arg4, %c1_i64] : , > - %1 = tt.descriptor_load %0[%c1_i32, %c1_i32, %c1_i32] : !tt.tensordesc> -> tensor<256x32xf32, #blocked> + %0 = tt.make_tensor_descriptor %arg0, [%c1_i32, %arg1, %arg2], [%arg3, %arg4, %c1_i64] : , <1x256x32xf32> + %1 = tt.descriptor_load %0[%c1_i32, %c1_i32, %c1_i32] : !tt.tensordesc<1x256x32xf32> -> tensor<256x32xf32, #blocked> %2 = ttg.local_alloc %1 : (tensor<256x32xf32, #blocked>) -> !ttg.memdesc<256x32xf32, #shared, #smem> tt.return } @@ -74,12 +74,12 @@ tt.func public @descriptor_load(%arg0: !tt.ptr, %arg1: i32, %arg2: i32, %ar // CHECK-DAG: #[[$PADDED:.*]] = #ttg.padded_shared<[64:+8] {order = [1, 0], shape = [64, 64]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250"} { // CHECK-LABEL: @descriptor_kernel_arg -tt.func public @descriptor_kernel_arg(%arg0: !tt.tensordesc>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) { - // CHECK: %arg0: !tt.tensordesc> - // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0[{{.*}}] : !tt.tensordesc> -> tensor<64x64xf16, #[[$BLOCKED]]> +tt.func public @descriptor_kernel_arg(%arg0: !tt.tensordesc<64x64xf16>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) { + // CHECK: %arg0: !tt.tensordesc<64x64xf16, #[[$PADDED]]> + // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0[{{.*}}] : !tt.tensordesc<64x64xf16, #[[$PADDED]]> -> tensor<64x64xf16, #[[$BLOCKED]]> // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<64x64xf16, #[[$BLOCKED]]>) -> !ttg.memdesc<64x64xf16, #[[$PADDED]], #smem> %c1_i32 = arith.constant 1 : i32 - %1 = tt.descriptor_load %arg0[%c1_i32, %c1_i32] : !tt.tensordesc> -> tensor<64x64xf16, #blocked> + %1 = tt.descriptor_load %arg0[%c1_i32, %c1_i32] : !tt.tensordesc<64x64xf16> -> tensor<64x64xf16, #blocked> %2 = ttg.local_alloc %1 : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem> tt.return } @@ -103,22 +103,22 @@ tt.func public @descriptor_load_while(%arg0: !tt.ptr {tt.divisibility = 16 : %c1_i64 = arith.constant 1 : i64 %0 = arith.extsi %arg2 : i32 to i64 - // CHECK: tt.make_tensor_descriptor {{.*}} : , > - %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , > + // CHECK: tt.make_tensor_descriptor {{.*}} : , <1x32xi8, #[[$PADDED_DESC]]> + %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , <1x32xi8> - %2 = scf.while (%arg4 = %1) : (!tt.tensordesc>) -> (!tt.tensordesc>) { - scf.condition(%cond) %arg4 : !tt.tensordesc> + %2 = scf.while (%arg4 = %1) : (!tt.tensordesc<1x32xi8>) -> (!tt.tensordesc<1x32xi8>) { + scf.condition(%cond) %arg4 : !tt.tensordesc<1x32xi8> } do { - ^bb0(%arg4: !tt.tensordesc>): - // CHECK: ^bb0(%[[ARG4:.*]]: !tt.tensordesc>): - // CHECK: tt.descriptor_gather %[[ARG4]][{{.*}}] : (!tt.tensordesc> - %3 = tt.descriptor_gather %arg4[%arg3, %c8_i32] : (!tt.tensordesc>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1> + ^bb0(%arg4: !tt.tensordesc<1x32xi8>): + // CHECK: ^bb0(%[[ARG4:.*]]: !tt.tensordesc<1x32xi8, #[[$PADDED_DESC]]>): + // CHECK: tt.descriptor_gather %[[ARG4]][{{.*}}] : (!tt.tensordesc<1x32xi8, #[[$PADDED_DESC]]> + %3 = tt.descriptor_gather %arg4[%arg3, %c8_i32] : (!tt.tensordesc<1x32xi8>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1> - scf.yield %arg4 : !tt.tensordesc> + scf.yield %arg4 : !tt.tensordesc<1x32xi8> } - // CHECK: %[[GATHER:.*]] = tt.descriptor_gather {{.*}} : (!tt.tensordesc> - %4 = tt.descriptor_gather %1[%arg3, %c8_i32] : (!tt.tensordesc>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1> + // CHECK: %[[GATHER:.*]] = tt.descriptor_gather {{.*}} : (!tt.tensordesc<1x32xi8, #[[$PADDED_DESC]]> + %4 = tt.descriptor_gather %1[%arg3, %c8_i32] : (!tt.tensordesc<1x32xi8>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1> // CHECK: ttg.local_alloc %[[GATHER]] {{.*}} : (tensor<32x32xi8, #blocked1>) -> !ttg.memdesc<32x32xi8, #[[$PADDED_ALLOC]], #smem> %8 = ttg.local_alloc %4 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<32x32xi8, #blocked1>) -> !ttg.memdesc<32x32xi8, #shared, #smem> @@ -137,15 +137,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // CHECK-DAG: #[[$PADDED_B:.*]] = #ttg.padded_shared<[128:+16] { // CHECK-LABEL: @descriptor_load_dot_operand tt.func public @descriptor_load_dot_operand(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i64, %arg5: i64) { - // CHECK: tt.make_tensor_descriptor {{.*}} : , > - // CHECK: tt.make_tensor_descriptor {{.*}} : , , <512x32xf16, #[[$PADDED_A]]> + // CHECK: tt.make_tensor_descriptor {{.*}} : , <32x64xf16, #[[$PADDED_B]] %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 %cst = arith.constant dense<0.000000e+00> : tensor<512x64xf32, #mma> - %0 = tt.make_tensor_descriptor %arg0, [%arg2, %arg3], [%arg4, %c1_i64] : , > - %1 = tt.make_tensor_descriptor %arg1, [%arg3, %arg2], [%arg5, %c1_i64] : , > - %2 = tt.descriptor_load %0[%c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<512x32xf16, #blocked> - %3 = tt.descriptor_load %1[%c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<32x64xf16, #blocked1> + %0 = tt.make_tensor_descriptor %arg0, [%arg2, %arg3], [%arg4, %c1_i64] : , <512x32xf16> + %1 = tt.make_tensor_descriptor %arg1, [%arg3, %arg2], [%arg5, %c1_i64] : , <32x64xf16> + %2 = tt.descriptor_load %0[%c0_i32, %c0_i32] : !tt.tensordesc<512x32xf16> -> tensor<512x32xf16, #blocked> + %3 = tt.descriptor_load %1[%c0_i32, %c0_i32] : !tt.tensordesc<32x64xf16> -> tensor<32x64xf16, #blocked1> %4 = ttg.convert_layout %2 : tensor<512x32xf16, #blocked> -> tensor<512x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> %5 = ttg.convert_layout %3 : tensor<32x64xf16, #blocked1> -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> %6 = tt.dot %4, %5, %cst : tensor<512x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<512x64xf32, #mma> @@ -172,23 +172,23 @@ tt.func public @descriptor_fallback(%arg0: !tt.ptr, %arg1: i32, %arg2: i32, %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %rng = arith.constant 5 : index - // CHECK: tt.make_tensor_descriptor {{.*}} : , > - %0 = tt.make_tensor_descriptor %arg0, [%c1_i32, %arg1, %arg2], [%c1_i64, %arg3, %arg4] : , > - // CHECK: scf.for {{.*}} -> (!tt.tensordesc>) - %1 = scf.for %iv = %c0 to %rng step %c1 iter_args(%iter_desc = %0) -> (!tt.tensordesc>) { - // CHECK: scf.if {{.*}} -> (!tt.tensordesc>) - %2 = scf.if %cond -> (!tt.tensordesc>) { - // CHECK: tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<64x32xf32, #[[$BLOCKED]]> - %3 = tt.descriptor_load %iter_desc[%c1_i32, %c1_i32, %c1_i32] : !tt.tensordesc> -> tensor<64x32xf32, #blocked> + // CHECK: tt.make_tensor_descriptor {{.*}} : , <1x64x32xf32, #[[$PADDED_FALLBACK]]> + %0 = tt.make_tensor_descriptor %arg0, [%c1_i32, %arg1, %arg2], [%c1_i64, %arg3, %arg4] : , <1x64x32xf32> + // CHECK: scf.for {{.*}} -> (!tt.tensordesc<1x64x32xf32, #[[$PADDED_FALLBACK]]>) + %1 = scf.for %iv = %c0 to %rng step %c1 iter_args(%iter_desc = %0) -> (!tt.tensordesc<1x64x32xf32>) { + // CHECK: scf.if {{.*}} -> (!tt.tensordesc<1x64x32xf32, #[[$PADDED_FALLBACK]]>) + %2 = scf.if %cond -> (!tt.tensordesc<1x64x32xf32>) { + // CHECK: tt.descriptor_load {{.*}} : !tt.tensordesc<1x64x32xf32, #[[$PADDED_FALLBACK]]> -> tensor<64x32xf32, #[[$BLOCKED]]> + %3 = tt.descriptor_load %iter_desc[%c1_i32, %c1_i32, %c1_i32] : !tt.tensordesc<1x64x32xf32> -> tensor<64x32xf32, #blocked> %4 = ttg.local_alloc %3 : (tensor<64x32xf32, #blocked>) -> !ttg.memdesc<64x32xf32, #shared, #smem, mutable> - scf.yield %iter_desc : !tt.tensordesc> + scf.yield %iter_desc : !tt.tensordesc<1x64x32xf32> } else { - // CHECK: tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<64x32xf32, #[[$BLOCKED]]> - %5 = tt.descriptor_load %iter_desc[%c0_i32, %c0_i32, %c1_i32] : !tt.tensordesc> -> tensor<64x32xf32, #blocked> + // CHECK: tt.descriptor_load {{.*}} : !tt.tensordesc<1x64x32xf32, #[[$PADDED_FALLBACK]]> -> tensor<64x32xf32, #[[$BLOCKED]]> + %5 = tt.descriptor_load %iter_desc[%c0_i32, %c0_i32, %c1_i32] : !tt.tensordesc<1x64x32xf32> -> tensor<64x32xf32, #blocked> %6 = ttg.local_alloc %5 : (tensor<64x32xf32, #blocked>) -> !ttg.memdesc<64x32xf32, #shared1, #smem, mutable> - scf.yield %iter_desc : !tt.tensordesc> + scf.yield %iter_desc : !tt.tensordesc<1x64x32xf32> } - scf.yield %2 : !tt.tensordesc> + scf.yield %2 : !tt.tensordesc<1x64x32xf32> } tt.return } diff --git a/test/TritonGPU/amd/amd-pipeline-tdm.mlir b/test/TritonGPU/amd/amd-pipeline-tdm.mlir index f474b804bcbe..1053759a2254 100644 --- a/test/TritonGPU/amd/amd-pipeline-tdm.mlir +++ b/test/TritonGPU/amd/amd-pipeline-tdm.mlir @@ -20,15 +20,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %2 = arith.muli %0, %c512_i32 : i32 %3 = arith.muli %1, %c64_i32 : i32 %4 = arith.extsi %K : i32 to i64 - %5 = tt.make_tensor_descriptor %a_ptr, [%M, %K], [%4, %c1_i64] : , > + %5 = tt.make_tensor_descriptor %a_ptr, [%M, %K], [%4, %c1_i64] : , <512x32xf16> %6 = arith.extsi %N : i32 to i64 - %7 = tt.make_tensor_descriptor %b_ptr, [%K, %N], [%6, %c1_i64] : , > - %8 = tt.make_tensor_descriptor %c_ptr, [%M, %N], [%6, %c1_i64] : , > + %7 = tt.make_tensor_descriptor %b_ptr, [%K, %N], [%6, %c1_i64] : , <32x64xf16> + %8 = tt.make_tensor_descriptor %c_ptr, [%M, %N], [%6, %c1_i64] : , <512x64xf16> %9 = arith.addi %K, %c31_i32 : i32 %10 = arith.divsi %9, %c32_i32 : i32 %accumulator:2 = scf.for %accumulator_0 = %c0_i32 to %10 step %c1_i32 iter_args(%arg7 = %c0_i32, %arg8 = %cst) -> (i32, tensor<512x64xf32, #mma>) : i32 { - %13 = tt.descriptor_load %5[%2, %arg7] : !tt.tensordesc> -> tensor<512x32xf16, #blocked> - %14 = tt.descriptor_load %7[%arg7, %3] : !tt.tensordesc> -> tensor<32x64xf16, #blocked1> + %13 = tt.descriptor_load %5[%2, %arg7] : !tt.tensordesc<512x32xf16> -> tensor<512x32xf16, #blocked> + %14 = tt.descriptor_load %7[%arg7, %3] : !tt.tensordesc<32x64xf16> -> tensor<32x64xf16, #blocked1> %15 = ttg.convert_layout %13 : tensor<512x32xf16, #blocked> -> tensor<512x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> %16 = ttg.convert_layout %14 : tensor<32x64xf16, #blocked1> -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> %17 = tt.dot %15, %16, %arg8 : tensor<512x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<512x64xf32, #mma> @@ -37,7 +37,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ } %11 = arith.truncf %accumulator#1 : tensor<512x64xf32, #mma> to tensor<512x64xf16, #mma> %12 = ttg.convert_layout %11 : tensor<512x64xf16, #mma> -> tensor<512x64xf16, #blocked1> - tt.descriptor_store %8[%2, %3], %12 : !tt.tensordesc>, tensor<512x64xf16, #blocked1> + tt.descriptor_store %8[%2, %3], %12 : !tt.tensordesc<512x64xf16>, tensor<512x64xf16, #blocked1> tt.return } } @@ -59,17 +59,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // CHECK-NOT: #ttg.padded_shared // CHECK-LABEL: tt.func @matmul_kernel_make_tensor_descriptor -// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc> -> !ttg.memdesc<512x32xf16, #[[$PADDED_A]], #smem, mutable> +// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<512x32xf16, #[[$PADDED_A]]> -> !ttg.memdesc<512x32xf16, #[[$PADDED_A]], #smem, mutable> // CHECK: ttg.async_commit_group tokens -// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc> -> !ttg.memdesc<32x64xf16, #[[$PADDED_B]], #smem, mutable> +// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<32x64xf16, #[[$PADDED_B]]> -> !ttg.memdesc<32x64xf16, #[[$PADDED_B]], #smem, mutable> // CHECK: ttg.async_commit_group tokens // CHECK: scf.for -// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc> -> !ttg.memdesc<512x32xf16, #[[$PADDED_A]], #smem, mutable> +// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<512x32xf16, #[[$PADDED_A]]> -> !ttg.memdesc<512x32xf16, #[[$PADDED_A]], #smem, mutable> // CHECK: ttg.async_commit_group tokens -// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc> -> !ttg.memdesc<32x64xf16, #[[$PADDED_B]], #smem, mutable> +// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<32x64xf16, #[[$PADDED_B]]> -> !ttg.memdesc<32x64xf16, #[[$PADDED_B]], #smem, mutable> // CHECK: ttg.async_commit_group tokens // CHECK: } -// CHECK: tt.descriptor_store {{.*}} : !tt.tensordesc> +// CHECK: tt.descriptor_store {{.*}} : !tt.tensordesc<512x64xf16, #[[$PADDED_C]]> // ----- @@ -111,15 +111,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %2 = arith.muli %0, %c256_i32 : i32 %3 = arith.muli %1, %c64_i32 : i32 %4 = arith.extsi %K : i32 to i64 - %5 = tt.make_tensor_descriptor %a_ptr, [%M, %K], [%4, %c1_i64] : , > + %5 = tt.make_tensor_descriptor %a_ptr, [%M, %K], [%4, %c1_i64] : , <256x64xf8E5M2> %6 = arith.extsi %N : i32 to i64 - %7 = tt.make_tensor_descriptor %b_ptr, [%K, %N], [%6, %c1_i64] : , > - %8 = tt.make_tensor_descriptor %c_ptr, [%M, %N], [%6, %c1_i64] : , > + %7 = tt.make_tensor_descriptor %b_ptr, [%K, %N], [%6, %c1_i64] : , <64x64xf8E5M2> + %8 = tt.make_tensor_descriptor %c_ptr, [%M, %N], [%6, %c1_i64] : , <256x64xf16> %9 = arith.addi %K, %c63_i32 : i32 %10 = arith.divsi %9, %c64_i32 : i32 %accumulator:2 = scf.for %iv = %c0_i32 to %10 step %c1_i32 iter_args(%k_off = %c0_i32, %acc = %cst) -> (i32, tensor<256x64xf32, #mma>) : i32 { - %a = tt.descriptor_load %5[%2, %k_off] : !tt.tensordesc> -> tensor<256x64xf8E5M2, #blocked> - %b = tt.descriptor_load %7[%k_off, %3] : !tt.tensordesc> -> tensor<64x64xf8E5M2, #blocked1> + %a = tt.descriptor_load %5[%2, %k_off] : !tt.tensordesc<256x64xf8E5M2> -> tensor<256x64xf8E5M2, #blocked> + %b = tt.descriptor_load %7[%k_off, %3] : !tt.tensordesc<64x64xf8E5M2> -> tensor<64x64xf8E5M2, #blocked1> %a_dot = ttg.convert_layout %a : tensor<256x64xf8E5M2, #blocked> -> tensor<256x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> %b_dot = ttg.convert_layout %b : tensor<64x64xf8E5M2, #blocked1> -> tensor<64x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> %d = tt.dot %a_dot, %b_dot, %acc : tensor<256x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<256x64xf32, #mma> @@ -128,21 +128,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ } %out = arith.truncf %accumulator#1 : tensor<256x64xf32, #mma> to tensor<256x64xf16, #mma> %out_blocked = ttg.convert_layout %out : tensor<256x64xf16, #mma> -> tensor<256x64xf16, #blocked1> - tt.descriptor_store %8[%2, %3], %out_blocked : !tt.tensordesc>, tensor<256x64xf16, #blocked1> + tt.descriptor_store %8[%2, %3], %out_blocked : !tt.tensordesc<256x64xf16>, tensor<256x64xf16, #blocked1> tt.return } } // CHECK-LABEL: tt.func @tdm_padding_fp8 -// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc> -> !ttg.memdesc<256x64xf8E5M2, #[[$PADDED_A]], #smem, mutable> -// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc> -> !ttg.memdesc<64x64xf8E5M2, #[[$PADDED_B]], #smem, mutable> +// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<256x64xf8E5M2, #[[$PADDED_A]]> -> !ttg.memdesc<256x64xf8E5M2, #[[$PADDED_A]], #smem, mutable> +// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<64x64xf8E5M2, #[[$PADDED_B]]> -> !ttg.memdesc<64x64xf8E5M2, #[[$PADDED_B]], #smem, mutable> // CHECK: scf.for -// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc> -> !ttg.memdesc<256x64xf8E5M2, #[[$PADDED_A]], #smem, mutable> +// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<256x64xf8E5M2, #[[$PADDED_A]]> -> !ttg.memdesc<256x64xf8E5M2, #[[$PADDED_A]], #smem, mutable> // CHECK: ttg.async_commit_group tokens -// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc> -> !ttg.memdesc<64x64xf8E5M2, #[[$PADDED_B]], #smem, mutable> +// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<64x64xf8E5M2, #[[$PADDED_B]]> -> !ttg.memdesc<64x64xf8E5M2, #[[$PADDED_B]], #smem, mutable> // CHECK: ttg.async_commit_group tokens // CHECK: } -// CHECK: tt.descriptor_store {{.*}} : !tt.tensordesc> +// CHECK: tt.descriptor_store {{.*}} : !tt.tensordesc<256x64xf16, #[[$PADDED_C]]> // ----- @@ -184,15 +184,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %2 = arith.muli %0, %c256_i32 : i32 %3 = arith.muli %1, %c64_i32 : i32 %4 = arith.extsi %K : i32 to i64 - %5 = tt.make_tensor_descriptor %a_ptr, [%M, %K], [%4, %c1_i64] : , > + %5 = tt.make_tensor_descriptor %a_ptr, [%M, %K], [%4, %c1_i64] : , <256x16xf32> %6 = arith.extsi %N : i32 to i64 - %7 = tt.make_tensor_descriptor %b_ptr, [%K, %N], [%6, %c1_i64] : , > - %8 = tt.make_tensor_descriptor %c_ptr, [%M, %N], [%6, %c1_i64] : , > + %7 = tt.make_tensor_descriptor %b_ptr, [%K, %N], [%6, %c1_i64] : , <16x64xf32> + %8 = tt.make_tensor_descriptor %c_ptr, [%M, %N], [%6, %c1_i64] : , <256x64xf32> %9 = arith.addi %K, %c15_i32 : i32 %10 = arith.divsi %9, %c16_i32 : i32 %accumulator:2 = scf.for %iv = %c0_i32 to %10 step %c1_i32 iter_args(%k_off = %c0_i32, %acc = %cst) -> (i32, tensor<256x64xf32, #mma>) : i32 { - %a = tt.descriptor_load %5[%2, %k_off] : !tt.tensordesc> -> tensor<256x16xf32, #blocked> - %b = tt.descriptor_load %7[%k_off, %3] : !tt.tensordesc> -> tensor<16x64xf32, #blocked1> + %a = tt.descriptor_load %5[%2, %k_off] : !tt.tensordesc<256x16xf32> -> tensor<256x16xf32, #blocked> + %b = tt.descriptor_load %7[%k_off, %3] : !tt.tensordesc<16x64xf32> -> tensor<16x64xf32, #blocked1> %a_dot = ttg.convert_layout %a : tensor<256x16xf32, #blocked> -> tensor<256x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> %b_dot = ttg.convert_layout %b : tensor<16x64xf32, #blocked1> -> tensor<16x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> %d = tt.dot %a_dot, %b_dot, %acc, inputPrecision = tf32 : tensor<256x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<256x64xf32, #mma> @@ -200,18 +200,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ scf.yield %next_k, %d : i32, tensor<256x64xf32, #mma> } %out_blocked = ttg.convert_layout %accumulator#1 : tensor<256x64xf32, #mma> -> tensor<256x64xf32, #blocked1> - tt.descriptor_store %8[%2, %3], %out_blocked : !tt.tensordesc>, tensor<256x64xf32, #blocked1> + tt.descriptor_store %8[%2, %3], %out_blocked : !tt.tensordesc<256x64xf32>, tensor<256x64xf32, #blocked1> tt.return } } // CHECK-LABEL: tt.func @tdm_padding_f32 -// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc> -> !ttg.memdesc<256x16xf32, #[[$PADDED_A]], #smem, mutable> -// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc> -> !ttg.memdesc<16x64xf32, #[[$PADDED_B]], #smem, mutable> +// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<256x16xf32, #[[$PADDED_A]]> -> !ttg.memdesc<256x16xf32, #[[$PADDED_A]], #smem, mutable> +// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<16x64xf32, #[[$PADDED_B]]> -> !ttg.memdesc<16x64xf32, #[[$PADDED_B]], #smem, mutable> // CHECK: scf.for -// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc> -> !ttg.memdesc<256x16xf32, #[[$PADDED_A]], #smem, mutable> +// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<256x16xf32, #[[$PADDED_A]]> -> !ttg.memdesc<256x16xf32, #[[$PADDED_A]], #smem, mutable> // CHECK: ttg.async_commit_group tokens -// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc> -> !ttg.memdesc<16x64xf32, #[[$PADDED_B]], #smem, mutable> +// CHECK: async_tdm_copy_global_to_local {{.*}} : !tt.tensordesc<16x64xf32, #[[$PADDED_B]]> -> !ttg.memdesc<16x64xf32, #[[$PADDED_B]], #smem, mutable> // CHECK: ttg.async_commit_group tokens // CHECK: } -// CHECK: tt.descriptor_store {{.*}} : !tt.tensordesc> +// CHECK: tt.descriptor_store {{.*}} : !tt.tensordesc<256x64xf32, #[[$PADDED_C]]> diff --git a/test/TritonGPU/amd/amd-update-async-wait-count-without-token.mlir b/test/TritonGPU/amd/amd-update-async-wait-count-without-token.mlir index 86753cdda4a1..7f9d367254d0 100644 --- a/test/TritonGPU/amd/amd-update-async-wait-count-without-token.mlir +++ b/test/TritonGPU/amd/amd-update-async-wait-count-without-token.mlir @@ -570,7 +570,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-LABEL: tdm_gather_scatter_multiple_instructions tt.func public @tdm_gather_scatter_multiple_instructions( %memDesc: !ttg.memdesc<256x128xf16, #shared, #smem, mutable>, - %tensorDesc: !tt.tensordesc>, + %tensorDesc: !tt.tensordesc<64x128xf16>, %row_indices_i32: tensor<64xi32>, %row_indices_i16: tensor<256xi16>, %pred: i32 @@ -578,13 +578,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c0_i32 = arith.constant 0 : i32 // Gather with 64xi32 indices: 64/8 = 8 instructions - amdg.async_tdm_gather %tensorDesc[%row_indices_i32, %c0_i32] to %memDesc, pred = %pred : tensor<64xi32>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_gather %tensorDesc[%row_indices_i32, %c0_i32] to %memDesc, pred = %pred : tensor<64xi32>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> // Scatter with 64xi32 indices: 64/8 = 8 instructions - amdg.async_tdm_scatter %tensorDesc[%row_indices_i32, %c0_i32] from %memDesc : tensor<64xi32>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_scatter %tensorDesc[%row_indices_i32, %c0_i32] from %memDesc : tensor<64xi32>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> // Gather with 128xi16 indices: 256/16 = 16 instructions - amdg.async_tdm_gather %tensorDesc[%row_indices_i16, %c0_i32] to %memDesc, pred = %pred : tensor<256xi16>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_gather %tensorDesc[%row_indices_i16, %c0_i32] to %memDesc, pred = %pred : tensor<256xi16>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> // Scatter with 128xi16 indices: 256/16 = 16 instructions - amdg.async_tdm_scatter %tensorDesc[%row_indices_i16, %c0_i32] from %memDesc : tensor<256xi16>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_scatter %tensorDesc[%row_indices_i16, %c0_i32] from %memDesc : tensor<256xi16>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> // CHECK: amdg.async_tdm_intrinsic_wait {count = 0 amdg.async_tdm_wait {num = 0 : i32} @@ -610,15 +610,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-LABEL: tdm_load_store_single_instruction tt.func public @tdm_load_store_single_instruction( %memDesc: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>, - %tensorDesc: !tt.tensordesc>, + %tensorDesc: !tt.tensordesc<64x128xf16>, %pred: i32 ) { %c0_i32 = arith.constant 0 : i32 - %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %pred : !tt.tensordesc> -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable> - amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> !tt.tensordesc> - %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %pred : !tt.tensordesc> -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable> - amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> !tt.tensordesc> + %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %pred : !tt.tensordesc<64x128xf16> -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable> + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> + %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %pred : !tt.tensordesc<64x128xf16> -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable> + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> // CHECK: amdg.async_tdm_intrinsic_wait {count = 0 amdg.async_tdm_wait {num = 0 : i32} diff --git a/test/TritonGPU/amd/amd-update-async-wait-count.mlir b/test/TritonGPU/amd/amd-update-async-wait-count.mlir index 29e214b40dec..e51ca357d8a9 100644 --- a/test/TritonGPU/amd/amd-update-async-wait-count.mlir +++ b/test/TritonGPU/amd/amd-update-async-wait-count.mlir @@ -407,13 +407,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: simple_tdm_waitcnt - tt.func public @simple_tdm_waitcnt(%memDesc: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %tensorDesc: !tt.tensordesc>, %mask: i32 + tt.func public @simple_tdm_waitcnt(%memDesc: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %tensorDesc: !tt.tensordesc<128x16xf16>, %mask: i32 ) { %c0_i32 = arith.constant 0 : i32 // Each async_tdm_copy only emits a single instruction (-> counts 1) - %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc> -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> - %2 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc> -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x16xf16> -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> + %2 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x16xf16> -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> // Do not wait on the second tdm => waitcnt 1 // CHECK: amdg.async_tdm_intrinsic_wait {{.*}} {count = 1 @@ -488,17 +488,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: mix_async_copy_and_async_tdm_copy - tt.func public @mix_async_copy_and_async_tdm_copy(%memDesc: !ttg.memdesc<128x8xf16, #shared, #smem, mutable>, %tensorDesc: !tt.tensordesc>, %mask: i32, %ptr: tensor<128x8x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>} + tt.func public @mix_async_copy_and_async_tdm_copy(%memDesc: !ttg.memdesc<128x8xf16, #shared, #smem, mutable>, %tensorDesc: !tt.tensordesc<128x8xf16>, %mask: i32, %ptr: tensor<128x8x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>} ) { %c0_i32 = arith.constant 0 : i32 // Each async_tdm_copy only emits a single instruction (-> counts 1) - %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc> -> !ttg.memdesc<128x8xf16, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x8xf16> -> !ttg.memdesc<128x8xf16, #shared, #smem, mutable> %2 = ttg.async_copy_global_to_local %ptr, %memDesc : tensor<128x8x!tt.ptr, #blocked> -> <128x8xf16, #shared, #smem, mutable> %21 = ttg.async_commit_group tokens %2 - %3 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc> -> !ttg.memdesc<128x8xf16, #shared, #smem, mutable> + %3 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x8xf16> -> !ttg.memdesc<128x8xf16, #shared, #smem, mutable> %4 = ttg.async_copy_global_to_local %ptr, %memDesc : tensor<128x8x!tt.ptr, #blocked> -> <128x8xf16, #shared, #smem, mutable> %5 = ttg.async_copy_global_to_local %ptr, %memDesc : tensor<128x8x!tt.ptr, #blocked> -> <128x8xf16, #shared, #smem, mutable> @@ -558,7 +558,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-LABEL: tdm_gather_scatter_multiple_instructions tt.func public @tdm_gather_scatter_multiple_instructions( %memDesc: !ttg.memdesc<256x128xf16, #shared, #smem, mutable>, - %tensorDesc: !tt.tensordesc>, + %tensorDesc: !tt.tensordesc<64x128xf16>, %row_indices_i32: tensor<64xi32>, %row_indices_i16: tensor<256xi16>, %pred: i32 @@ -566,13 +566,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c0_i32 = arith.constant 0 : i32 // Gather with 64xi32 indices: 64/8 = 8 instructions - %token1 = amdg.async_tdm_gather %tensorDesc[%row_indices_i32, %c0_i32] to %memDesc, pred = %pred : tensor<64xi32>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc> + %token1 = amdg.async_tdm_gather %tensorDesc[%row_indices_i32, %c0_i32] to %memDesc, pred = %pred : tensor<64xi32>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> // Scatter with 64xi32 indices: 64/8 = 8 instructions - %token2 = amdg.async_tdm_scatter %tensorDesc[%row_indices_i32, %c0_i32] from %memDesc : tensor<64xi32>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc> + %token2 = amdg.async_tdm_scatter %tensorDesc[%row_indices_i32, %c0_i32] from %memDesc : tensor<64xi32>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> // Gather with 128xi16 indices: 256/16 = 16 instructions - %token3 = amdg.async_tdm_gather %tensorDesc[%row_indices_i16, %c0_i32] to %memDesc, pred = %pred : tensor<256xi16>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc> + %token3 = amdg.async_tdm_gather %tensorDesc[%row_indices_i16, %c0_i32] to %memDesc, pred = %pred : tensor<256xi16>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> // Scatter with 128xi16 indices: 256/16 = 16 instructions - %token4 = amdg.async_tdm_scatter %tensorDesc[%row_indices_i16, %c0_i32] from %memDesc : tensor<256xi16>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc> + %token4 = amdg.async_tdm_scatter %tensorDesc[%row_indices_i16, %c0_i32] from %memDesc : tensor<256xi16>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> // CHECK: amdg.async_tdm_intrinsic_wait {{.*}} {count = 0 %w1 = amdg.async_tdm_wait %token4 {num = 0 : i32} diff --git a/test/TritonGPU/amd/invalid.mlir b/test/TritonGPU/amd/invalid.mlir index e389a8816eae..1d48578abd8e 100644 --- a/test/TritonGPU/amd/invalid.mlir +++ b/test/TritonGPU/amd/invalid.mlir @@ -209,22 +209,22 @@ module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { tt.func public @interval_not_matching_innermost_block_dimension( - %tensorDesc: !tt.tensordesc>, + %tensorDesc: !tt.tensordesc<128x64xf16>, %memDesc: !ttg.memdesc<128x64xf16, #shared_32, #smem, mutable> ) { %c0_i32 = arith.constant 0 : i32 // expected-error @+1 {{TDM store padding is only supported when padding interval equals the innermost block dimension}} - amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared_32, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared_32, #smem, mutable> -> !tt.tensordesc<128x64xf16> tt.return } tt.func public @tdm_store_two_padding_intervals( - %tensorDesc: !tt.tensordesc>, + %tensorDesc: !tt.tensordesc<128x64xf16>, %memDesc: !ttg.memdesc<128x64xf16, #shared_2_intervals, #smem, mutable> ) { %c0_i32 = arith.constant 0 : i32 // expected-error @+1 {{TDM store only supports single interval paddings}} - amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared_2_intervals, #smem, mutable> -> !tt.tensordesc> + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared_2_intervals, #smem, mutable> -> !tt.tensordesc<128x64xf16> tt.return } } diff --git a/test/TritonGPU/automatic-warp-specialization.mlir b/test/TritonGPU/automatic-warp-specialization.mlir index 037ef5547e30..41e94686a542 100644 --- a/test/TritonGPU/automatic-warp-specialization.mlir +++ b/test/TritonGPU/automatic-warp-specialization.mlir @@ -23,8 +23,8 @@ tt.func @matmul_change_desc_in_prologue( %false = arith.constant false %zero = arith.constant dense<0.0> : tensor<128x128xf32, #acc_layout> %k_tiles = arith.constant 32 : i32 - %a_desc_undef = ub.poison : !tt.tensordesc> - %b_desc_undef = ub.poison : !tt.tensordesc> + %a_desc_undef = ub.poison : !tt.tensordesc<128x64xf16, #shared> + %b_desc_undef = ub.poison : !tt.tensordesc<64x128xf16, #shared> // CHECK-LABEL: ttg.warp_specialize // CHECK-LABEL: default // BASE-NOT: tt.make_tensor_descriptor @@ -45,20 +45,20 @@ tt.func @matmul_change_desc_in_prologue( // PIPELINE-COUNT-2: async_tma_copy_global_to_local // PIPELINE-NOT: async_tma_copy_global_to_local // CHECK-NOT: partition2 - scf.for %k = %c0_i32 to %k_tiles step %c1_i32 iter_args(%acc = %zero, %flag = %true, %a_desc = %a_desc_undef, %b_desc = %b_desc_undef) -> (tensor<128x128xf32, #acc_layout>, i1, !tt.tensordesc>, !tt.tensordesc>) : i32 { + scf.for %k = %c0_i32 to %k_tiles step %c1_i32 iter_args(%acc = %zero, %flag = %true, %a_desc = %a_desc_undef, %b_desc = %b_desc_undef) -> (tensor<128x128xf32, #acc_layout>, i1, !tt.tensordesc<128x64xf16, #shared>, !tt.tensordesc<64x128xf16, #shared>) : i32 { %do_prologue = "prologue_cond"(%k) : (i32) -> i1 - %cur_a_desc, %cur_b_desc = scf.if %do_prologue -> (!tt.tensordesc>, !tt.tensordesc>) { + %cur_a_desc, %cur_b_desc = scf.if %do_prologue -> (!tt.tensordesc<128x64xf16, #shared>, !tt.tensordesc<64x128xf16, #shared>) { %c1_i64 = arith.constant 1 : i64 - %next_a_desc = tt.make_tensor_descriptor %a_base, [%k, %k], [%c1_i64, %c1_i64] : !tt.ptr, !tt.tensordesc> - %next_b_desc = tt.make_tensor_descriptor %b_base, [%k, %k], [%c1_i64, %c1_i64] : !tt.ptr, !tt.tensordesc> - scf.yield %next_a_desc, %next_b_desc : !tt.tensordesc>, !tt.tensordesc> + %next_a_desc = tt.make_tensor_descriptor %a_base, [%k, %k], [%c1_i64, %c1_i64] : !tt.ptr, !tt.tensordesc<128x64xf16, #shared> + %next_b_desc = tt.make_tensor_descriptor %b_base, [%k, %k], [%c1_i64, %c1_i64] : !tt.ptr, !tt.tensordesc<64x128xf16, #shared> + scf.yield %next_a_desc, %next_b_desc : !tt.tensordesc<128x64xf16, #shared>, !tt.tensordesc<64x128xf16, #shared> } else { - scf.yield %a_desc, %b_desc : !tt.tensordesc>, !tt.tensordesc> + scf.yield %a_desc, %b_desc : !tt.tensordesc<128x64xf16, #shared>, !tt.tensordesc<64x128xf16, #shared> } %off_m, %off_n, %off_k = "get_offsets"(%k) : (i32) -> (i32, i32, i32) - %a = tt.descriptor_load %a_desc[%off_m, %off_k] : !tt.tensordesc> -> tensor<128x64xf16, #oper_layout> - %b = tt.descriptor_load %b_desc[%off_n, %off_k] : !tt.tensordesc> -> tensor<64x128xf16, #oper_layout> + %a = tt.descriptor_load %a_desc[%off_m, %off_k] : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #oper_layout> + %b = tt.descriptor_load %b_desc[%off_n, %off_k] : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #oper_layout> %a_shared = ttg.local_alloc %a : (tensor<128x64xf16, #oper_layout>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %b_shared = ttg.local_alloc %b : (tensor<64x128xf16, #oper_layout>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) @@ -70,7 +70,7 @@ tt.func @matmul_change_desc_in_prologue( scf.if %do_epilogue { "acc_user"(%c) : (tensor<128x128xf32, #acc_layout>) -> () } - scf.yield %c, %use_acc, %cur_a_desc, %cur_b_desc : tensor<128x128xf32, #acc_layout>, i1, !tt.tensordesc>, !tt.tensordesc> + scf.yield %c, %use_acc, %cur_a_desc, %cur_b_desc : tensor<128x128xf32, #acc_layout>, i1, !tt.tensordesc<128x64xf16, #shared>, !tt.tensordesc<64x128xf16, #shared> } {tt.warp_specialize, tt.disallow_acc_multi_buffer, tt.num_stages = 2 : i32} tt.return @@ -78,8 +78,8 @@ tt.func @matmul_change_desc_in_prologue( // CHECK-LABEL: @matmul_tma_acc_with_conditional_def_and_use tt.func @matmul_tma_acc_with_conditional_def_and_use( - %a_desc: !tt.tensordesc>, - %b_desc: !tt.tensordesc> + %a_desc: !tt.tensordesc<1x64xf16, #shared>, + %b_desc: !tt.tensordesc<64x128xf16, #shared> ) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 @@ -101,8 +101,8 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use( scf.for %k = %c0_i32 to %k_tiles step %c1_i32 iter_args(%acc = %zero, %flag = %true) -> (tensor<128x128xf32, #acc_layout>, i1) : i32 { %off_m, %off_n, %off_k = "get_offsets"(%k) : (i32) -> (i32, i32, i32) %indices = tt.splat %off_m : i32 -> tensor<128xi32, #indices_layout> - %a = tt.descriptor_gather %a_desc[%indices, %off_k] : (!tt.tensordesc>, tensor<128xi32, #indices_layout>, i32) -> tensor<128x64xf16, #oper_layout> - %b = tt.descriptor_load %b_desc[%off_n, %off_k] : !tt.tensordesc> -> tensor<64x128xf16, #oper_layout> + %a = tt.descriptor_gather %a_desc[%indices, %off_k] : (!tt.tensordesc<1x64xf16, #shared>, tensor<128xi32, #indices_layout>, i32) -> tensor<128x64xf16, #oper_layout> + %b = tt.descriptor_load %b_desc[%off_n, %off_k] : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #oper_layout> %a_shared = ttg.local_alloc %a : (tensor<128x64xf16, #oper_layout>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %b_shared = ttg.local_alloc %b : (tensor<64x128xf16, #oper_layout>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) @@ -120,7 +120,7 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use( // CHECK-LABEL: @matmul_tma_and_regular_load tt.func @matmul_tma_and_regular_load( - %a_desc: !tt.tensordesc>, + %a_desc: !tt.tensordesc<1x64xf16, #shared>, %b_ptr_init: tensor<64x128x!tt.ptr, #b_layout> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[1, 64]> : tensor<2xi32>} ) { %c0_i32 = arith.constant 0 : i32 @@ -159,7 +159,7 @@ tt.func @matmul_tma_and_regular_load( %off_m, %offs_n, %off_k = "get_offsets"(%k) : (i32) -> (i32, tensor<64x128xi32, #b_layout>, i32) %indices = tt.splat %off_m : i32 -> tensor<128xi32, #indices_layout> - %a = tt.descriptor_gather %a_desc[%indices, %off_k] : (!tt.tensordesc>, tensor<128xi32, #indices_layout>, i32) -> tensor<128x64xf16, #oper_layout> + %a = tt.descriptor_gather %a_desc[%indices, %off_k] : (!tt.tensordesc<1x64xf16, #shared>, tensor<128xi32, #indices_layout>, i32) -> tensor<128x64xf16, #oper_layout> %b_ptrs = tt.addptr %b_ptr, %offs_n {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[1, 64]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>} : tensor<64x128x!tt.ptr, #b_layout>, tensor<64x128xi32, #b_layout> %b = tt.load %b_ptrs : tensor<64x128x!tt.ptr, #b_layout> @@ -197,8 +197,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @attention_forward tt.func public @attention_forward( %Q_shared: !ttg.memdesc<256x64xf16, #shared, #smem>, - %K_desc: !tt.tensordesc>, - %V_desc: !tt.tensordesc>, + %K_desc: !tt.tensordesc<64x64xf16, #shared>, + %V_desc: !tt.tensordesc<64x64xf16, #shared>, %qk_scale: f32, %n_tiles: i32, %idx_ptr: !tt.ptr @@ -231,7 +231,7 @@ tt.func public @attention_forward( tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> ) : i32 { - %K = tt.descriptor_load %K_desc[%i, %c0_i32] : !tt.tensordesc> -> tensor<64x64xf16, #load_blocked> + %K = tt.descriptor_load %K_desc[%i, %c0_i32] : !tt.tensordesc<64x64xf16, #shared> -> tensor<64x64xf16, #load_blocked> %K_shared = ttg.local_alloc %K : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem> %K_trans = ttg.memdesc_trans %K_shared {order = array} : !ttg.memdesc<64x64xf16, #shared, #smem> -> !ttg.memdesc<64x64xf16, #shared_T, #smem> @@ -264,7 +264,7 @@ tt.func public @attention_forward( %acc_step = arith.mulf %acc, %alpha_1 : tensor<256x64xf32, #blocked> %acc_corrected = arith.addf %acc_step, %bias : tensor<256x64xf32, #blocked> - %62 = tt.descriptor_load %V_desc[%i, %c0_i32] : !tt.tensordesc> -> tensor<64x64xf16, #load_blocked> + %62 = tt.descriptor_load %V_desc[%i, %c0_i32] : !tt.tensordesc<64x64xf16, #shared> -> tensor<64x64xf16, #load_blocked> %63 = ttg.local_alloc %62 : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem> %P = arith.truncf %softmax : tensor<256x64xf32, #blocked> to tensor<256x64xf16, #blocked> @@ -358,9 +358,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c_ptr = tt.addptr %group_c_ptrs, %g : !tt.ptr, i32 %c_ptr_10 = tt.load %c_ptr : !tt.ptr %c_ptr_11 = tt.int_to_ptr %c_ptr_10 : i64 -> !tt.ptr - %a_desc_12 = tt.make_tensor_descriptor %a_ptr_7, [%gm, %gk], [%stride, %c1_i64] : , > - %b_desc_13 = tt.make_tensor_descriptor %b_ptr_9, [%gn, %gk], [%stride, %c1_i64] : , > - %c_desc_14 = tt.make_tensor_descriptor %c_ptr_11, [%gm, %gn], [%stride, %c1_i64] : , > + %a_desc_12 = tt.make_tensor_descriptor %a_ptr_7, [%gm, %gk], [%stride, %c1_i64] : , <128x64xf16, #shared> + %b_desc_13 = tt.make_tensor_descriptor %b_ptr_9, [%gn, %gk], [%stride, %c1_i64] : , <128x64xf16, #shared> + %c_desc_14 = tt.make_tensor_descriptor %c_ptr_11, [%gm, %gn], [%stride, %c1_i64] : , <128x128xf16, #shared> scf.for %tile_idx = %start_pid to %num_tiles step %c4_i32 : i32 { %tile_m_idx = arith.divsi %tile_idx, %num_n_tiles_1 : i32 %tile_n_idx = arith.remsi %tile_idx, %num_n_tiles_1 : i32 @@ -370,9 +370,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %accumulator_16 = ttng.tmem_store %cst, %accumulator[%accumulator_15], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> %accumulator_17:2 = scf.for %accumulator_20 = %c0_i32 to %1 step %c1_i32 iter_args(%arg11 = %false, %accumulator_21 = %accumulator_16) -> (i1, !ttg.async.token) : i32 { %a = arith.muli %accumulator_20, %c64_i32 : i32 - %a_22 = tt.descriptor_load %a_desc_12[%offs_am, %a] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %a_22 = tt.descriptor_load %a_desc_12[%offs_am, %a] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> %a_23 = ttg.local_alloc %a_22 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> - %b = tt.descriptor_load %b_desc_13[%offs_bn, %a] : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %b = tt.descriptor_load %b_desc_13[%offs_bn, %a] : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> %accumulator_24 = ttg.local_alloc %b : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %accumulator_25 = ttg.memdesc_trans %accumulator_24 {order = array} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared1, #smem> %accumulator_26 = ttng.tc_gen5_mma %a_23, %accumulator_25, %accumulator[%accumulator_21], %arg11, %true : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> @@ -381,7 +381,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %accumulator_18, %accumulator_19 = ttng.tmem_load %accumulator[%accumulator_17#1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> %c = arith.truncf %accumulator_18 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> %2 = ttg.convert_layout %c : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #blocked2> - tt.descriptor_store %c_desc_14[%offs_am, %offs_bn], %2 : !tt.tensordesc>, tensor<128x128xf16, #blocked2> + tt.descriptor_store %c_desc_14[%offs_am, %offs_bn], %2 : !tt.tensordesc<128x128xf16, #shared>, tensor<128x128xf16, #blocked2> } } {tt.warp_specialize} tt.return diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index 9bd824360e37..ec2b7d0ae442 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -207,12 +207,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @descriptor_store - tt.func public @descriptor_store(%arg0: !tt.tensordesc>) { + tt.func public @descriptor_store(%arg0: !tt.tensordesc<2x64xf16>) { %c0_i32 = arith.constant 0 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<2x64xf16, #blocked> // CHECK: %[[C:.+]] = ttg.convert_layout %{{.+}} : tensor<2x64xf16, #{{.+}}> -> tensor<2x64xf16, #[[$LAYOUT]]> - // CHECK: tt.descriptor_store {{.*}}, %[[C]] : !tt.tensordesc>, tensor<2x64xf16, #[[$LAYOUT]]> - tt.descriptor_store %arg0[%c0_i32, %c0_i32], %cst : !tt.tensordesc>, tensor<2x64xf16, #blocked> + // CHECK: tt.descriptor_store {{.*}}, %[[C]] : !tt.tensordesc<2x64xf16>, tensor<2x64xf16, #[[$LAYOUT]]> + tt.descriptor_store %arg0[%c0_i32, %c0_i32], %cst : !tt.tensordesc<2x64xf16>, tensor<2x64xf16, #blocked> tt.return } } diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index b89d933221d3..f4d573f69926 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -3599,15 +3599,15 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK: tt.func @mma_v3_reg_push_elementwise_chained_descritor_load // CHECK: %[[CST_DOTOP:.*]] = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - // CHECK: %[[A_BLOCK:.*]] = tt.descriptor_load %{{.*}} : !tt.tensordesc> -> tensor<128x64xi8, #blocked> + // CHECK: %[[A_BLOCK:.*]] = tt.descriptor_load %{{.*}} : !tt.tensordesc<128x64xsi8> -> tensor<128x64xi8, #blocked> // CHECK: %[[A_DOTOP:.*]] = ttg.convert_layout %[[A_BLOCK]] : tensor<128x64xi8, #blocked> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // CHECK: %[[A_CASTED:.*]] = arith.sitofp %[[A_DOTOP]] : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // CHECK: %[[A_SCALED:.*]] = arith.mulf %[[A_CASTED]], %[[CST_DOTOP]] : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // CHECK: %[[A_NEGATED:.*]] = arith.negf %[[A_SCALED]] : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // CHECK: %[[R:.*]] = ttng.warp_group_dot %[[A_NEGATED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> - tt.func @mma_v3_reg_push_elementwise_chained_descritor_load(%pa: !tt.tensordesc>, %dotb: !ttg.memdesc<64x64xf16, #shared, #smem>, %dotc: tensor<128x64xf32, #mma>, %A_dim1: i32, %A_dim2: i32) -> tensor<128x64xf32, #mma>{ + tt.func @mma_v3_reg_push_elementwise_chained_descritor_load(%pa: !tt.tensordesc<128x64xsi8>, %dotb: !ttg.memdesc<64x64xf16, #shared, #smem>, %dotc: tensor<128x64xf32, #mma>, %A_dim1: i32, %A_dim2: i32) -> tensor<128x64xf32, #mma>{ %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked> - %a_i8 = tt.descriptor_load %pa[%A_dim1, %A_dim2]: !tt.tensordesc> -> tensor<128x64xi8, #blocked> + %a_i8 = tt.descriptor_load %pa[%A_dim1, %A_dim2]: !tt.tensordesc<128x64xsi8> -> tensor<128x64xi8, #blocked> %a_f16 = arith.sitofp %a_i8 : tensor<128x64xi8, #blocked> to tensor<128x64xf16, #blocked> %a_scaled = arith.mulf %a_f16, %cst : tensor<128x64xf16, #blocked> %a_negated = arith.negf %a_scaled : tensor<128x64xf16, #blocked> diff --git a/test/TritonGPU/consan.mlir b/test/TritonGPU/consan.mlir index 9c75abf104bb..b65bdb333ecb 100644 --- a/test/TritonGPU/consan.mlir +++ b/test/TritonGPU/consan.mlir @@ -206,7 +206,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_tma_copy_global_to_local - tt.func public @async_tma_copy_global_to_local(%arg0: !tt.tensordesc>) { + tt.func public @async_tma_copy_global_to_local(%arg0: !tt.tensordesc<32x32xf32, #shared>) { // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr @@ -245,7 +245,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: tt.call @__triton_consan_track_barrier_write_for_buffer // CHECK: tt.call @__triton_consan_verify_barrier_arrive // CHECK: tt.call @__triton_consan_update_barrier_state - ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %bar, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %bar, %true : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } } @@ -259,8 +259,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_tma_copy_global_to_local_two_bufs_one_barrier tt.func public @async_tma_copy_global_to_local_two_bufs_one_barrier( - %a: !tt.tensordesc>, - %b: !tt.tensordesc>) { + %a: !tt.tensordesc<32x32xf32, #shared>, + %b: !tt.tensordesc<32x32xf32, #shared>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 @@ -278,8 +278,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: tt.call @__triton_consan_update_barrier_state // CHECK: ttng.barrier_expect // CHECK-COUNT-2: tt.call @__triton_consan_track_barrier_write_for_buffer - ttng.async_tma_copy_global_to_local %a[%c0_i32, %c0_i32] %a_smem, %bar, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> - ttng.async_tma_copy_global_to_local %b[%c0_i32, %c0_i32] %b_smem, %bar, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %a[%c0_i32, %c0_i32] %a_smem, %bar, %true : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %b[%c0_i32, %c0_i32] %b_smem, %bar, %true : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> ttng.wait_barrier %bar, %c0_i32, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -311,8 +311,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: tt.call @__triton_consan_track_barrier_write_for_buffer{{.*}}({{[^,]+}}, {{[^,]+}}, %true, %[[B_TRACK:.*]], {{[^,]+}}, // CHECK: ttng.async_tma_copy_global_to_local %arg1 tt.func public @async_tma_copy_global_to_local_two_bufs_two_barriers( - %a: !tt.tensordesc>, - %b: !tt.tensordesc>) { + %a: !tt.tensordesc<32x32xf32, #shared>, + %b: !tt.tensordesc<32x32xf32, #shared>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %a_smem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -322,9 +322,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar ttng.init_barrier %bar0, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %bar1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.barrier_expect %bar0, 4096, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable> - ttng.async_tma_copy_global_to_local %a[%c0_i32, %c0_i32] %a_smem, %bar0, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %a[%c0_i32, %c0_i32] %a_smem, %bar0, %true : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> ttng.barrier_expect %bar1, 4096, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable> - ttng.async_tma_copy_global_to_local %b[%c0_i32, %c0_i32] %b_smem, %bar1, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %b[%c0_i32, %c0_i32] %b_smem, %bar1, %true : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> ttng.wait_barrier %bar1, %c0_i32, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable> %va = ttg.local_load %a_smem : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> tensor<32x32xf32, #blocked> %vb = ttg.local_load %b_smem : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> tensor<32x32xf32, #blocked> @@ -343,7 +343,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_tma_copy_local_to_global - tt.func public @async_tma_copy_local_to_global(%arg0: !tt.tensordesc>, %ptr: tensor<128x128x!tt.ptr, #blocked>, %acc: tensor<128x128xf16, #mma>) { + tt.func public @async_tma_copy_local_to_global(%arg0: !tt.tensordesc<32x32xf32, #shared>, %ptr: tensor<128x128x!tt.ptr, #blocked>, %acc: tensor<128x128xf16, #mma>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -354,7 +354,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: tt.call @__triton_consan_check_outstanding_commits // CHECK: tt.call @__triton_consan_stage_access_for_commit // CHECK: tt.call @__triton_consan_commit_accesses - ttng.async_tma_copy_local_to_global %arg0[%c0_i32, %c0_i32] %0 : !tt.tensordesc>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.async_tma_copy_local_to_global %arg0[%c0_i32, %c0_i32] %0 : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } } @@ -369,7 +369,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_tma_store_wait - tt.func public @async_tma_store_wait(%arg0: !tt.tensordesc>, %ptr: tensor<128x128x!tt.ptr, #blocked>, %acc: tensor<128x128xf16, #mma>) { + tt.func public @async_tma_store_wait(%arg0: !tt.tensordesc<32x32xf32, #shared>, %ptr: tensor<128x128x!tt.ptr, #blocked>, %acc: tensor<128x128xf16, #mma>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -392,7 +392,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_tma_gather - tt.func public @async_tma_gather(%arg0: !tt.tensordesc>, %ptr: tensor<128x128x!tt.ptr, #blocked>, %acc: tensor<128x128xf16, #mma>) { + tt.func public @async_tma_gather(%arg0: !tt.tensordesc<1x32xf32, #shared>, %ptr: tensor<128x128x!tt.ptr, #blocked>, %acc: tensor<128x128xf16, #mma>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %x_offsets = arith.constant dense<1> : tensor<32xi32> @@ -411,7 +411,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: tt.call @__triton_consan_clear_read_visibility // CHECK: tt.call @__triton_consan_clear_read_tracking // CHECK: tt.call @__triton_consan_track_barrier_write_for_buffer - ttng.async_tma_gather %arg0[%x_offsets, %c0_i32] %0, %bar, %true : !tt.tensordesc>, tensor<32xi32>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, i1 + ttng.async_tma_gather %arg0[%x_offsets, %c0_i32] %0, %bar, %true : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, i1 tt.return } } @@ -426,7 +426,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_tma_scatter - tt.func public @async_tma_scatter(%arg0: !tt.tensordesc>, %ptr: tensor<128x128x!tt.ptr, #blocked>, %acc: tensor<128x128xf16, #mma>) { + tt.func public @async_tma_scatter(%arg0: !tt.tensordesc<1x32xf32, #shared>, %ptr: tensor<128x128x!tt.ptr, #blocked>, %acc: tensor<128x128xf16, #mma>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %x_offsets = arith.constant dense<1> : tensor<32xi32> @@ -440,7 +440,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: tt.call @__triton_consan_verify_write_visibility // CHECK: tt.call @__triton_consan_check_outstanding_commits - ttng.async_tma_scatter %arg0[%x_offsets, %c0_i32] %0 : !tt.tensordesc>, tensor<32xi32>, i32, !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.async_tma_scatter %arg0[%x_offsets, %c0_i32] %0 : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32>, i32, !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } } @@ -478,7 +478,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @wait_barrier - tt.func public @wait_barrier(%arg0: !tt.tensordesc>) { + tt.func public @wait_barrier(%arg0: !tt.tensordesc<32x32xf32, #shared>) { // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64, #linear> // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr @@ -523,7 +523,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @arrive_barrier - tt.func public @arrive_barrier(%arg0: !tt.tensordesc>) { + tt.func public @arrive_barrier(%arg0: !tt.tensordesc<32x32xf32, #shared>) { // CHECK-DAG: %[[BSTATE_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1xI64(%[[BSTATE_GLOB]], %c0_i64 %true = arith.constant true @@ -588,7 +588,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @tcgen5_mma - tt.func public @tcgen5_mma(%arg0: !tt.tensordesc>) { + tt.func public @tcgen5_mma(%arg0: !tt.tensordesc<32x32xf32, #shared>) { // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 32768], [{{.*}}], shared_mem : tensor<1x2xi64 // CHECK-DAG: %[[SM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr // CHECK-DAG: %[[SM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1024 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr @@ -663,7 +663,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar #tmem1 = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @tcgen5_mma_lhs_in_tmem - tt.func public @tcgen5_mma_lhs_in_tmem(%arg0: !tt.tensordesc>) { + tt.func public @tcgen5_mma_lhs_in_tmem(%arg0: !tt.tensordesc<32x32xf32, #shared>) { // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [32768], [{{.*}}], shared_mem : tensor<1x1xi64 // CHECK-DAG: %[[SM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr // CHECK-DAG: %[[SM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 512 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr @@ -741,7 +741,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @tcgen5_commit - tt.func public @tcgen5_commit(%arg0: !tt.tensordesc>) { + tt.func public @tcgen5_commit(%arg0: !tt.tensordesc<32x32xf32, #shared>) { %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> @@ -1146,7 +1146,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @ws_allocation - tt.func public @ws_allocation(%arg0: !tt.tensordesc>) { + tt.func public @ws_allocation(%arg0: !tt.tensordesc<32x32xf32, #shared>) { // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1x1xi64, // CHECK-DAG: tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 %smem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> @@ -1191,7 +1191,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @ws_buf_ptrs_default - tt.func public @ws_buf_ptrs_default(%arg0: !tt.tensordesc>) { + tt.func public @ws_buf_ptrs_default(%arg0: !tt.tensordesc<32x32xf32, #shared>) { // CHECK-DAG: tti.experimental_buffer_descriptors [0, 32768, 65536, 0], [{{.*}}], shared_mem // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem %smem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<3x128x128xf16, #shared, #smem, mutable> @@ -1227,7 +1227,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @ws_buf_ptrs_partition0 - tt.func public @ws_buf_ptrs_partition0(%arg0: !tt.tensordesc>) { + tt.func public @ws_buf_ptrs_partition0(%arg0: !tt.tensordesc<32x32xf32, #shared>) { // CHECK-DAG: tti.experimental_buffer_descriptors [0, 32768, 65536, 0], [{{.*}}], shared_mem // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem %smem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<3x128x128xf16, #shared, #smem, mutable> diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index cfdd2b4a84b4..f292a2fcb211 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -232,9 +232,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %cst_scales = arith.constant dense<127> : tensor<128x4xi8, #linear> %true = arith.constant true - %desc = tt.make_tensor_descriptor %scale_desc_ptr, [%c1_i32, %c2_i32, %c1_i32, %c32_i32, %c16_i32], [%c1024_i64, %c512_i64, %c512_i64, %c16_i64, %c1_i64] : !tt.ptr, > - // CHECK: %[[DESC_LOAD:.*]] = tt.descriptor_load {{.*}} !tt.tensordesc> -> tensor<1x2x1x32x16xi8, #[[BLOCKED5]]> - %83 = tt.descriptor_load %desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x2x1x32x16xi8, #blocked5> + %desc = tt.make_tensor_descriptor %scale_desc_ptr, [%c1_i32, %c2_i32, %c1_i32, %c32_i32, %c16_i32], [%c1024_i64, %c512_i64, %c512_i64, %c16_i64, %c1_i64] : !tt.ptr, <1x2x1x32x16xi8> + // CHECK: %[[DESC_LOAD:.*]] = tt.descriptor_load {{.*}} !tt.tensordesc<1x2x1x32x16xi8> -> tensor<1x2x1x32x16xi8, #[[BLOCKED5]]> + %83 = tt.descriptor_load %desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc<1x2x1x32x16xi8> -> tensor<1x2x1x32x16xi8, #blocked5> // CHECK: %[[DESC_LA:.*]] = ttg.local_alloc %[[DESC_LOAD]] : (tensor<1x2x1x32x16xi8, #[[BLOCKED5]]>) -> !ttg.memdesc<1x2x1x32x16xi8, #[[SHARED2]], #[[SMEM]]> %84 = ttg.local_alloc %83 : (tensor<1x2x1x32x16xi8, #blocked5>) -> !ttg.memdesc<1x2x1x32x16xi8, #shared2, #smem> // CHECK-NOT: ttg.local_load diff --git a/test/TritonGPU/gsan.mlir b/test/TritonGPU/gsan.mlir index d7408f7d591d..4dc84dff6f16 100644 --- a/test/TritonGPU/gsan.mlir +++ b/test/TritonGPU/gsan.mlir @@ -44,7 +44,7 @@ module attributes {"ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 32 : i32} module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tt.func @instrumented_async_tma_copy - tt.func @instrumented_async_tma_copy(%desc: !tt.tensordesc>) { + tt.func @instrumented_async_tma_copy(%desc: !tt.tensordesc<32x32xf32, #shared>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %buf = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -52,11 +52,11 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} // CHECK: tti.experimental_gsan_tensordesc_info %arg0 // CHECK: tti.experimental_gsan_tensor_access %{{.*}}, false, %{{.*}} // CHECK-NEXT: ttng.async_tma_copy_global_to_local - ttng.async_tma_copy_global_to_local %desc[%c0_i32, %c0_i32] %buf, %barrier, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #bar, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %desc[%c0_i32, %c0_i32] %buf, %barrier, %true : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<1xi64, #bar, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> // CHECK: tti.experimental_gsan_tensordesc_info %arg0 // CHECK: tti.experimental_gsan_tensor_access %{{.*}}, true, %{{.*}} // CHECK-NEXT: ttng.async_tma_copy_local_to_global - ttng.async_tma_copy_local_to_global %desc[%c0_i32, %c0_i32] %buf : !tt.tensordesc>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.async_tma_copy_local_to_global %desc[%c0_i32, %c0_i32] %buf : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } } @@ -70,7 +70,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tt.func @instrumented_async_tma_gather_scatter - tt.func @instrumented_async_tma_gather_scatter(%desc: !tt.tensordesc>) { + tt.func @instrumented_async_tma_gather_scatter(%desc: !tt.tensordesc<1x32xf32, #shared>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %x_offsets = arith.constant dense<1> : tensor<32xi32, #blocked_rows> @@ -79,11 +79,11 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} // CHECK: tti.experimental_gsan_tensordesc_info %arg0 // CHECK: tti.experimental_gsan_tensor_access %{{.*}}, false, %{{.*}} // CHECK-NEXT: ttng.async_tma_gather - ttng.async_tma_gather %desc[%x_offsets, %c0_i32] %buf, %barrier, %true : !tt.tensordesc>, tensor<32xi32, #blocked_rows>, i32, !ttg.memdesc<1xi64, #bar, #smem, mutable>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, i1 + ttng.async_tma_gather %desc[%x_offsets, %c0_i32] %buf, %barrier, %true : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32, #blocked_rows>, i32, !ttg.memdesc<1xi64, #bar, #smem, mutable>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, i1 // CHECK: tti.experimental_gsan_tensordesc_info %arg0 // CHECK: tti.experimental_gsan_tensor_access %{{.*}}, true, %{{.*}} // CHECK-NEXT: ttng.async_tma_scatter - ttng.async_tma_scatter %desc[%x_offsets, %c0_i32] %buf : !tt.tensordesc>, tensor<32xi32, #blocked_rows>, i32, !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.async_tma_scatter %desc[%x_offsets, %c0_i32] %buf : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32, #blocked_rows>, i32, !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } } @@ -106,13 +106,13 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} %c32_i32 = arith.constant 32 : i32 ttng.tensormap_create %raw_desc, %base, [%c32_i32, %c32_i32], [%shape1, %shape0], [%stride0], [%c1_i32, %c1_i32] {elem_type = 0 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 0 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () // CHECK: %[[DESC:.*]] = ttng.reinterpret_tensor_descriptor %arg0 - %desc = ttng.reinterpret_tensor_descriptor %raw_desc : !tt.ptr to !tt.tensordesc> + %desc = ttng.reinterpret_tensor_descriptor %raw_desc : !tt.ptr to !tt.tensordesc<32x32xf32, #shared> %buf = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %barrier = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #bar, #smem, mutable> // CHECK: tti.experimental_gsan_tensordesc_info %[[DESC]] // CHECK: tti.experimental_gsan_tensor_access %{{.*}}, false, %{{.*}} // CHECK-NEXT: ttng.async_tma_copy_global_to_local - ttng.async_tma_copy_global_to_local %desc[%c0_i32, %c0_i32] %buf, %barrier, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #bar, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %desc[%c0_i32, %c0_i32] %buf, %barrier, %true : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<1xi64, #bar, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } } diff --git a/test/TritonGPU/loop-pipeline-async-latencies.mlir b/test/TritonGPU/loop-pipeline-async-latencies.mlir index 58127d8dd0c9..d8567c24b3e6 100644 --- a/test/TritonGPU/loop-pipeline-async-latencies.mlir +++ b/test/TritonGPU/loop-pipeline-async-latencies.mlir @@ -9,7 +9,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: matmul_kernel_tma_persistent -tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { +tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<256x64xf16, #shared>, %arg2: !tt.tensordesc<128x256xf16, #shared>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { %c2_i32 = arith.constant 2 : i32 %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 @@ -98,9 +98,9 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.tensordesc> -> tensor<128x64xf16, #blocked> + %4 = tt.descriptor_load %arg0[%c0_i32, %arg6] {tt.latency = 1 : i32} : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked> %5 = ttg.local_alloc %4 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem> - %6 = tt.descriptor_load %arg1[%c0_i32, %arg6] {tt.latency = 3 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #blocked> + %6 = tt.descriptor_load %arg1[%c0_i32, %arg6] {tt.latency = 3 : i32} : !tt.tensordesc<256x64xf16, #shared> -> tensor<256x64xf16, #blocked> %7 = ttg.local_alloc %6 : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #shared, #smem> %8 = ttg.memdesc_trans %7 {order = array} : !ttg.memdesc<256x64xf16, #shared, #smem> -> !ttg.memdesc<64x256xf16, #shared1, #smem> %9 = ttng.warp_group_dot %5, %8, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma> @@ -129,7 +129,7 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.tensordesc to tensor<128x256xf16, #mma> %12 = ttg.convert_layout %11 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> - tt.descriptor_store %arg2[%c0_i32, %c0_i32], %12 : !tt.tensordesc>, tensor<128x256xf16, #blocked1> + tt.descriptor_store %arg2[%c0_i32, %c0_i32], %12 : !tt.tensordesc<128x256xf16, #shared>, tensor<128x256xf16, #blocked1> } // CHECK: yield %{{.*}}, [[NEXT_LHS_BUF_IDX]], [[LHS_BUF_IDX]], [[LHS_PHASE]], [[NEXT_RHS_BUF_IDX]], [[RHS_BUF_IDX]], [[RHS_PHASE]] scf.yield %9 : tensor<128x256xf32, #mma> diff --git a/test/TritonGPU/loop-pipeline-blackwell.mlir b/test/TritonGPU/loop-pipeline-blackwell.mlir index d817e7e3c5cc..f52158eed27f 100644 --- a/test/TritonGPU/loop-pipeline-blackwell.mlir +++ b/test/TritonGPU/loop-pipeline-blackwell.mlir @@ -183,8 +183,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-SAME: [[LHS_X:%arg[0-9]+]]: // CHECK-SAME: [[RHS_X:%arg[0-9]+]]: tt.func private @pipelined_gather( - %lhs_desc: !tt.tensordesc>, - %rhs_desc: !tt.tensordesc>, + %lhs_desc: !tt.tensordesc<1x128xbf16, #nvmma_128>, + %rhs_desc: !tt.tensordesc<1x32xbf16, #nvmma_64>, %lhs_x_offsets: tensor<32xi32, #blocked1>, %rhs_x_offsets: tensor<128xi32, #blocked1>) -> tensor<32x32xf32, #blocked> { %c0_i32 = arith.constant 0 : i32 @@ -221,8 +221,8 @@ tt.func private @pipelined_gather( // CHECK: [[LHS_VIEW:%.*]] = ttg.memdesc_index [[LHS_BUFS]] // CHECK: [[LHS:%.*]] = ttg.local_load [[LHS_VIEW]] // CHECK: tt.dot [[LHS]], [[RHS]] - %lhs = tt.descriptor_gather %lhs_desc[%lhs_x_offsets, %y] : (!tt.tensordesc>, tensor<32xi32, #blocked1>, i32) -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %rhs = tt.descriptor_gather %rhs_desc[%rhs_x_offsets, %y] : (!tt.tensordesc>, tensor<128xi32, #blocked1>, i32) -> tensor<128x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %lhs = tt.descriptor_gather %lhs_desc[%lhs_x_offsets, %y] : (!tt.tensordesc<1x128xbf16, #nvmma_128>, tensor<32xi32, #blocked1>, i32) -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %rhs = tt.descriptor_gather %rhs_desc[%rhs_x_offsets, %y] : (!tt.tensordesc<1x32xbf16, #nvmma_64>, tensor<128xi32, #blocked1>, i32) -> tensor<128x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %next = tt.dot %lhs, %rhs, %acc : tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> diff --git a/test/TritonGPU/loop-pipeline-cuda.mlir b/test/TritonGPU/loop-pipeline-cuda.mlir index b9b0d46d5e17..8a363a28ca0a 100644 --- a/test/TritonGPU/loop-pipeline-cuda.mlir +++ b/test/TritonGPU/loop-pipeline-cuda.mlir @@ -181,16 +181,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NOT: ttng.wait_barrier // CHECK-COUNT-2: ttng.async_tma_copy_global_to_local // CHECK: scf.yield - tt.func public @matmul_tma(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) -> tensor<128x256xf32, #mma> { + tt.func public @matmul_tma(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: !tt.tensordesc<64x256xf16, #shared>) -> tensor<128x256xf32, #mma> { %c256_i32 = arith.constant 256 : i32 %c0_i32 = arith.constant 0 : i32 %c64_i32 = arith.constant 64 : i32 %c1_i32 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> %0:2 = scf.for %arg3 = %c0_i32 to %c256_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { - %1 = tt.descriptor_load %arg0[%c0_i32, %arg5] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> + %1 = tt.descriptor_load %arg0[%c0_i32, %arg5] : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked> %2 = ttg.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem> - %3 = tt.descriptor_load %arg1[%arg5, %c0_i32] : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> + %3 = tt.descriptor_load %arg1[%arg5, %c0_i32] : !tt.tensordesc<64x256xf16, #shared> -> tensor<64x256xf16, #blocked1> %4 = ttg.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem> %5 = ttng.warp_group_dot %2, %4, %arg4 { inputPrecision = 0 : i32 } : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma> %6 = arith.addi %arg5, %c64_i32 : i32 diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index dc44796ffe98..f556c82e6452 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -453,7 +453,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK: #[[$SHARED:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}> // CHECK-LABEL: tma_store_pipeline - tt.func public @tma_store_pipeline(%arg0: tensor<128x128xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) { + tt.func public @tma_store_pipeline(%arg0: tensor<128x128xf32, #blocked>, %arg1: !tt.tensordesc<128x128xf32, #shared>, %arg2: i32, %arg3: i32) { %c0_i32 = arith.constant 0 : i32 // CHECK: ttg.local_alloc : () -> !ttg.memdesc<128x128xf32, #[[$SHARED]], #smem, mutable> // CHECK: scf.for @@ -463,7 +463,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-NEXT: ttg.local_store // CHECK-NEXT: ttng.fence_async_shared // CHECK-NEXT: ttng.async_tma_copy_local_to_global - tt.descriptor_store %arg1[%1, %1], %arg0 : !tt.tensordesc>, tensor<128x128xf32, #blocked> + tt.descriptor_store %arg1[%1, %1], %arg0 : !tt.tensordesc<128x128xf32, #shared>, tensor<128x128xf32, #blocked> } tt.return } @@ -476,7 +476,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_scatter_pipeline - tt.func public @tma_scatter_pipeline(%arg0: tensor<8x128xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) { + tt.func public @tma_scatter_pipeline(%arg0: tensor<8x128xf32, #blocked>, %arg1: !tt.tensordesc<1x128xf32, #shared>, %arg2: i32, %arg3: i32) { %c0_i32 = arith.constant 0 : i32 scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 { %1 = arith.divsi %arg4, %arg2 : i32 @@ -485,7 +485,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-NEXT: ttg.local_store // CHECK-NEXT: ttng.fence_async_shared // CHECK-NEXT: ttng.async_tma_scatter - tt.descriptor_scatter %arg1[%2, %1], %arg0 : !tt.tensordesc>, tensor<8xi32, #blocked1>, i32, tensor<8x128xf32, #blocked> + tt.descriptor_scatter %arg1[%2, %1], %arg0 : !tt.tensordesc<1x128xf32, #shared>, tensor<8xi32, #blocked1>, i32, tensor<8x128xf32, #blocked> } tt.return } @@ -506,7 +506,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: scf.for scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 { %1 = arith.divsi %arg4, %arg2 : i32 - %desc = tt.make_tensor_descriptor %arg1, [%c128_i32, %c128_i32], [%c128_i64, %c1_i64] : , > + %desc = tt.make_tensor_descriptor %arg1, [%c128_i32, %c128_i32], [%c128_i64, %c1_i64] : , <128x128xf32, #shared> // CHECK: ttng.tensormap_create // CHECK: ttng.tensormap_fenceproxy_acquire // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} @@ -514,7 +514,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-NEXT: ttng.fence_async_shared // CHECK-NEXT: ttng.async_tma_copy_local_to_global // CHECK: scf.yield - tt.descriptor_store %desc[%c0_i32, %1], %arg0 : !tt.tensordesc>, tensor<128x128xf32, #blocked> + tt.descriptor_store %desc[%c0_i32, %1], %arg0 : !tt.tensordesc<128x128xf32, #shared>, tensor<128x128xf32, #blocked> } // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} tt.return @@ -526,7 +526,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_multiple_store_pipeline - tt.func public @tma_multiple_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) { + tt.func public @tma_multiple_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.tensordesc<1xf32, #shared>, %arg2: i32, %arg3: i32) { %c0_i32 = arith.constant 0 : i32 // CHECK: %[[ALLOC:.+]] = ttg.local_alloc : () -> !ttg.memdesc<1xf32, #shared, #smem, mutable> // CHECK: scf.for @@ -541,8 +541,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-NEXT: ttg.local_store %{{.+}}, %[[ALLOC]] // CHECK-NEXT: ttng.fence_async_shared // CHECK-NEXT: ttng.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] - tt.descriptor_store %arg1[%1], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> - tt.descriptor_store %arg1[%2], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> + tt.descriptor_store %arg1[%1], %arg0 : !tt.tensordesc<1xf32, #shared>, tensor<1xf32, #blocked> + tt.descriptor_store %arg1[%2], %arg0 : !tt.tensordesc<1xf32, #shared>, tensor<1xf32, #blocked> } tt.return } @@ -1022,22 +1022,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %6 = arith.muli %4, %c64_i32 : i32 %7 = arith.addi %arg5, %c63_i32 : i32 %8 = arith.divsi %7, %c64_i32 : i32 - %9 = ttng.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc> - %10 = ttng.reinterpret_tensor_descriptor %arg1 : !tt.ptr to !tt.tensordesc> + %9 = ttng.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc<128x64xf8E4M3FN, #shared> + %10 = ttng.reinterpret_tensor_descriptor %arg1 : !tt.ptr to !tt.tensordesc<64x64xf8E4M3FN, #shared> %true = arith.constant true %false = arith.constant false %11:2 = scf.for %arg6 = %c0_i32 to %8 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %c0_i32) -> (tensor<128x64xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 32]}>>, i32) : i32 { - %14 = tt.descriptor_load %9[%5, %arg8] : !tt.tensordesc> -> tensor<128x64xf8E4M3FN, #blocked> + %14 = tt.descriptor_load %9[%5, %arg8] : !tt.tensordesc<128x64xf8E4M3FN, #shared> -> tensor<128x64xf8E4M3FN, #blocked> %15 = ttg.local_alloc %14 : (tensor<128x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared, #ttg.shared_memory> - %16 = tt.descriptor_load %10[%arg8, %6] : !tt.tensordesc> -> tensor<64x64xf8E4M3FN, #blocked> + %16 = tt.descriptor_load %10[%arg8, %6] : !tt.tensordesc<64x64xf8E4M3FN, #shared> -> tensor<64x64xf8E4M3FN, #blocked> %17 = ttg.local_alloc %16 : (tensor<64x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<64x64xf8E4M3FN, #shared1, #ttg.shared_memory> %18 = ttng.warp_group_dot %15, %17, %arg7 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x64xf8E4M3FN, #shared, #ttg.shared_memory> * !ttg.memdesc<64x64xf8E4M3FN, #shared1, #ttg.shared_memory> -> tensor<128x64xf32, #mma> %19 = arith.addi %arg8, %c64_i32 : i32 scf.yield %18, %19 : tensor<128x64xf32, #mma>, i32 } %12 = ttg.convert_layout %11#0 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked> - %13 = ttng.reinterpret_tensor_descriptor %arg2 : !tt.ptr to !tt.tensordesc> - tt.descriptor_store %13[%5, %6], %12 : !tt.tensordesc>, tensor<128x64xf32, #blocked> + %13 = ttng.reinterpret_tensor_descriptor %arg2 : !tt.ptr to !tt.tensordesc<128x64xf32, #nvmma_128> + tt.descriptor_store %13[%5, %6], %12 : !tt.tensordesc<128x64xf32, #nvmma_128>, tensor<128x64xf32, #blocked> tt.return } } diff --git a/test/TritonGPU/loop-schedule.mlir b/test/TritonGPU/loop-schedule.mlir index 0e1fdd6997f8..6f44b84bf8e4 100644 --- a/test/TritonGPU/loop-schedule.mlir +++ b/test/TritonGPU/loop-schedule.mlir @@ -74,7 +74,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.func public @fused_loop(%arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}) { %c10_i32 = arith.constant 10 : i32 %false = arith.constant false - %0 = ub.poison : !tt.tensordesc> + %0 = ub.poison : !tt.tensordesc<64x256xf16> %cst = arith.constant dense<0> : tensor<128x1xi64, #blocked> %c-1_i32 = arith.constant -1 : i32 %c1_i32 = arith.constant 1 : i32 @@ -86,12 +86,12 @@ tt.func public @fused_loop(%arg5: !tt.ptr {tt.divisibility = 16 : i32}, %ar %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> %3 = arith.extsi %arg7 : i32 to i64 - %4 = tt.make_tensor_descriptor %arg5, [%arg7, %arg7], [%3, %c1_i64] : , > + %4 = tt.make_tensor_descriptor %arg5, [%arg7, %arg7], [%3, %c1_i64] : , <64x256xf16> %5 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> %7 = tt.splat %3 : i64 -> tensor<128x1xi64, #blocked> // CHECK: scf.for - %8:9 = scf.for %arg29 = %c0_i32 to %arg7 step %c1_i32 iter_args(%arg30 = %c-1_i32, %arg31 = %4, %arg32 = %c0_i32, %arg33 = %arg5, %arg34 = %cst_0, %arg35 = %c0_i32, %arg36 = %cst, %arg37 = %0, %arg38 = %false) -> (i32, !tt.tensordesc>, i32, !tt.ptr, tensor<128x256xf32, #mma>, i32, tensor<128x1xi64, #blocked>, !tt.tensordesc>, i1) : i32 { + %8:9 = scf.for %arg29 = %c0_i32 to %arg7 step %c1_i32 iter_args(%arg30 = %c-1_i32, %arg31 = %4, %arg32 = %c0_i32, %arg33 = %arg5, %arg34 = %cst_0, %arg35 = %c0_i32, %arg36 = %cst, %arg37 = %0, %arg38 = %false) -> (i32, !tt.tensordesc<64x256xf16>, i32, !tt.ptr, tensor<128x256xf32, #mma>, i32, tensor<128x1xi64, #blocked>, !tt.tensordesc<64x256xf16>, i1) : i32 { %9 = arith.addi %arg30, %c1_i32 : i32 %10 = arith.cmpi eq, %arg30, %c10_i32 : i32 %11 = arith.select %10, %c0_i32, %9 : i32 @@ -101,7 +101,7 @@ tt.func public @fused_loop(%arg5: !tt.ptr {tt.divisibility = 16 : i32}, %ar // CHECK: {_test_marker_0, loop.cluster = 4 : i32, loop.stage = 0 : i32} %13 = arith.select %12, %c0_i32, %arg32 {_test_marker_0} : i32 - %14 = arith.select %12, %arg31, %arg37 : !tt.tensordesc> + %14 = arith.select %12, %arg31, %arg37 : !tt.tensordesc<64x256xf16> %15 = arith.select %12, %c10_i32, %arg35 : i32 %16 = scf.if %12 -> (tensor<128x1xi64, #blocked>) { %32 = arith.muli %cst, %7 : tensor<128x1xi64, #blocked> @@ -117,7 +117,7 @@ tt.func public @fused_loop(%arg5: !tt.ptr {tt.divisibility = 16 : i32}, %ar %22 = tt.load %20 : tensor<128x64x!tt.ptr, #blocked> %23 = ttg.local_alloc %22 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %24 = arith.muli %13, %c64_i32 : i32 - %25 = tt.descriptor_load %14[%24, %15] : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> + %25 = tt.descriptor_load %14[%24, %15] : !tt.tensordesc<64x256xf16> -> tensor<64x256xf16, #blocked1> %26 = ttg.local_alloc %25 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem> %27 = ttng.warp_group_dot %23, %26, %arg34, %arg38 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma> %28 = arith.addi %13, %c1_i32 : i32 @@ -135,7 +135,7 @@ tt.func public @fused_loop(%arg5: !tt.ptr {tt.divisibility = 16 : i32}, %ar "use"(%27) : (tensor<128x256xf32, #mma>) -> () // CHECK: {_test_marker_3, loop.cluster = 5 : i32, loop.stage = 2 : i32} } {_test_marker_3} - scf.yield %11, %14, %28, %30, %27, %15, %16, %14, %31 : i32, !tt.tensordesc>, i32, !tt.ptr, tensor<128x256xf32, #mma>, i32, tensor<128x1xi64, #blocked>, !tt.tensordesc>, i1 + scf.yield %11, %14, %28, %30, %27, %15, %16, %14, %31 : i32, !tt.tensordesc<64x256xf16>, i32, !tt.ptr, tensor<128x256xf32, #mma>, i32, tensor<128x1xi64, #blocked>, !tt.tensordesc<64x256xf16>, i1 } tt.return } diff --git a/test/TritonGPU/matmul-loop-pipeline.mlir b/test/TritonGPU/matmul-loop-pipeline.mlir index f52dd178ad43..39e8012e0bfb 100644 --- a/test/TritonGPU/matmul-loop-pipeline.mlir +++ b/test/TritonGPU/matmul-loop-pipeline.mlir @@ -71,7 +71,7 @@ tt.func public @make_tensor_desc_epilogue(%arg0: i32, %arg1: !tt.ptr, %arg2 // CHECK-NOT: tt.make_tensor_descriptor // CHECK: ttng.tensormap_create // CHECK-NEXT: ttng.tensormap_fenceproxy_acquire - %5 = tt.make_tensor_descriptor %arg1, [%arg2, %arg2], [%c1_i64, %c1_i64] : , > + %5 = tt.make_tensor_descriptor %arg1, [%arg2, %arg2], [%c1_i64, %c1_i64] : , <128x256xf32, #nvmma_128> } {loop.cluster = 5 : i32, loop.stage = 2 : i32} } {tt.num_stages = 3 : i32, tt.scheduled_max_stage = 2 : i32} tt.return diff --git a/test/TritonGPU/partition-scheduling.mlir b/test/TritonGPU/partition-scheduling.mlir index 8340d0dc4ebb..ac0681eda4b8 100644 --- a/test/TritonGPU/partition-scheduling.mlir +++ b/test/TritonGPU/partition-scheduling.mlir @@ -14,8 +14,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @attention_forward tt.func public @attention_forward( %Q_shared: !ttg.memdesc<256x64xf16, #shared, #smem>, - %K_desc: !tt.tensordesc>, - %V_desc: !tt.tensordesc>, + %K_desc: !tt.tensordesc<64x64xf16, #shared>, + %V_desc: !tt.tensordesc<64x64xf16, #shared>, %qk_scale: f32, %n_tiles: i32 ) { @@ -42,7 +42,7 @@ tt.func public @attention_forward( ) : i32 { // CHECK-COUNT-2: ttg.partition = array - %K = tt.descriptor_load %K_desc[%i, %c0_i32] : !tt.tensordesc> -> tensor<64x64xf16, #load_blocked> + %K = tt.descriptor_load %K_desc[%i, %c0_i32] : !tt.tensordesc<64x64xf16, #shared> -> tensor<64x64xf16, #load_blocked> %K_shared = ttg.local_alloc %K : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem> %QK_tmem, %QK_tok = ttng.tmem_alloc : () -> (!ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) @@ -85,7 +85,7 @@ tt.func public @attention_forward( %e = "sum"(%acc_x) : (tensor<256x64xf32, #blocked>) -> tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %next_e_i = arith.addf %e_i, %e : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> - %V = tt.descriptor_load %V_desc[%i, %c0_i32] : !tt.tensordesc> -> tensor<64x64xf16, #load_blocked> + %V = tt.descriptor_load %V_desc[%i, %c0_i32] : !tt.tensordesc<64x64xf16, #shared> -> tensor<64x64xf16, #load_blocked> %V_shared = ttg.local_alloc %V : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem> %P = arith.truncf %softmax : tensor<256x64xf32, #blocked> to tensor<256x64xf16, #blocked> @@ -107,8 +107,8 @@ tt.func public @attention_forward( // CHECK-LABEL: @mma_operand_view tt.func public @mma_operand_view( %Q_shared: !ttg.memdesc<256x64xf16, #shared, #smem>, - %K_desc: !tt.tensordesc>, - %V_desc: !tt.tensordesc>, + %K_desc: !tt.tensordesc<64x64xf16, #shared>, + %V_desc: !tt.tensordesc<64x64xf16, #shared>, %qk_scale: f32, %n_tiles: i32 ) { @@ -124,7 +124,7 @@ tt.func public @mma_operand_view( %QK_tmem, %QK_tok = ttng.tmem_alloc : () -> (!ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) scf.for %i = %c0_i32 to %n_tiles step %c64_i32 : i32 { - %K = tt.descriptor_load %K_desc[%i, %c0_i32] : !tt.tensordesc> -> tensor<64x64xf16, #load_blocked> + %K = tt.descriptor_load %K_desc[%i, %c0_i32] : !tt.tensordesc<64x64xf16, #shared> -> tensor<64x64xf16, #load_blocked> // CHECK: [[K_SHARED:%.*]] = ttg.local_alloc {{.*}}partition = array %K_shared = ttg.local_alloc %K : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem> @@ -152,12 +152,12 @@ tt.func public @mma_operand_view( } // CHECK-LABEL: @optimize_broadcast -tt.func @optimize_broadcast(%arg0: i32, %arg1: !tt.tensordesc>) { +tt.func @optimize_broadcast(%arg0: i32, %arg1: !tt.tensordesc<128x128xf32, #shared_f32>) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 // CHECK: scf.for scf.for %i = %c0_i32 to %arg0 step %c1_i32 : i32 { - %md = tt.descriptor_load %arg1[%c0_i32, %c0_i32] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x128xf32, #load_blocked> + %md = tt.descriptor_load %arg1[%c0_i32, %c0_i32] {ttg.partition = array} : !tt.tensordesc<128x128xf32, #shared_f32> -> tensor<128x128xf32, #load_blocked> %smem = ttg.local_alloc %md {ttg.partition = array} : (tensor<128x128xf32, #load_blocked>) -> !ttg.memdesc<128x128xf32, #shared_f32, #smem> %tmp = ttg.local_load %smem {ttg.partition = array} : !ttg.memdesc<128x128xf32, #shared_f32, #smem> -> tensor<128x128xf32, #load_blocked> "use_memdesc"(%tmp) {ttg.partition = array} : (tensor<128x128xf32, #load_blocked>) -> () @@ -231,31 +231,31 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %false = arith.constant false %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %c32_i32 = arith.constant 32 : i32 - %0 = ub.poison : !tt.tensordesc> - %1 = ub.poison : !tt.tensordesc> + %0 = ub.poison : !tt.tensordesc<128x64xf16, #shared> + %1 = ub.poison : !tt.tensordesc<64x128xf16, #shared> %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) %2 = ttng.tmem_store %cst, %result[%token], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK: scf.for - %3:4 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %true, %arg4 = %0, %arg5 = %1, %arg6 = %2) -> (i1, !tt.tensordesc>, !tt.tensordesc>, !ttg.async.token) : i32 { + %3:4 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %true, %arg4 = %0, %arg5 = %1, %arg6 = %2) -> (i1, !tt.tensordesc<128x64xf16, #shared>, !tt.tensordesc<64x128xf16, #shared>, !ttg.async.token) : i32 { // CHECK-NEXT: "prologue_cond"({{.*}}) {ttg.partition = array} %4 = "prologue_cond"(%arg2) : (i32) -> i1 // CHECK-NEXT: scf.if - %5:2 = scf.if %4 -> (!tt.tensordesc>, !tt.tensordesc>) { + %5:2 = scf.if %4 -> (!tt.tensordesc<128x64xf16, #shared>, !tt.tensordesc<64x128xf16, #shared>) { // CHECK-COUNT-2: ttg.partition = array - %15 = tt.make_tensor_descriptor %arg0, [%arg2, %arg2], [%c1_i64, %c1_i64] : , > - %16 = tt.make_tensor_descriptor %arg1, [%arg2, %arg2], [%c1_i64, %c1_i64] : , > + %15 = tt.make_tensor_descriptor %arg0, [%arg2, %arg2], [%c1_i64, %c1_i64] : , <128x64xf16, #shared> + %16 = tt.make_tensor_descriptor %arg1, [%arg2, %arg2], [%c1_i64, %c1_i64] : , <64x128xf16, #shared> // CHECK-NEXT: scf.yield {ttg.partition = array} - scf.yield %15, %16 : !tt.tensordesc>, !tt.tensordesc> + scf.yield %15, %16 : !tt.tensordesc<128x64xf16, #shared>, !tt.tensordesc<64x128xf16, #shared> } else { // CHECK-NEXT: } else { // CHECK-NEXT: scf.yield {ttg.partition = array} - scf.yield %arg4, %arg5 : !tt.tensordesc>, !tt.tensordesc> + scf.yield %arg4, %arg5 : !tt.tensordesc<128x64xf16, #shared>, !tt.tensordesc<64x128xf16, #shared> // CHECK-NEXT: ttg.partition = array, ttg.partition.outputs = [array, array] } // CHECK-COUNT-5: ttg.partition = array %6:3 = "get_offsets"(%arg2) : (i32) -> (i32, i32, i32) - %7 = tt.descriptor_load %arg4[%6#0, %6#2] : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> - %8 = tt.descriptor_load %arg5[%6#1, %6#2] : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %7 = tt.descriptor_load %arg4[%6#0, %6#2] : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked1> + %8 = tt.descriptor_load %arg5[%6#1, %6#2] : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %9 = ttg.local_alloc %7 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %10 = ttg.local_alloc %8 : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> // CHECK-NEXT: tc_gen5_mma {{.*}} {ttg.partition = array} {{.*}} @@ -278,14 +278,14 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { scf.yield %11 : !ttg.async.token } // CHECK-NEXT: scf.yield {ttg.partition = array} - scf.yield %13, %5#0, %5#1, %14 : i1, !tt.tensordesc>, !tt.tensordesc>, !ttg.async.token + scf.yield %13, %5#0, %5#1, %14 : i1, !tt.tensordesc<128x64xf16, #shared>, !tt.tensordesc<64x128xf16, #shared>, !ttg.async.token // CHECK-NEXT: ttg.partition = array, ttg.partition.outputs = [array, array, array, array] } {tt.disallow_acc_multi_buffer, tt.num_stages = 4 : i32, tt.warp_specialize} tt.return } // CHECK-LABEL: @matmul_tma_acc_with_conditional_def_and_use - tt.func @matmul_tma_acc_with_conditional_def_and_use(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) { + tt.func @matmul_tma_acc_with_conditional_def_and_use(%arg0: !tt.tensordesc<1x64xf16, #shared>, %arg1: !tt.tensordesc<64x128xf16, #shared>) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %true = arith.constant true @@ -299,8 +299,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-COUNT-6: ttg.partition = array %2:3 = "get_offsets"(%arg2) : (i32) -> (i32, i32, i32) %3 = tt.splat %2#0 : i32 -> tensor<128xi32, #blocked2> - %4 = tt.descriptor_gather %arg0[%3, %2#2] : (!tt.tensordesc>, tensor<128xi32, #blocked2>, i32) -> tensor<128x64xf16, #blocked1> - %5 = tt.descriptor_load %arg1[%2#1, %2#2] : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %4 = tt.descriptor_gather %arg0[%3, %2#2] : (!tt.tensordesc<1x64xf16, #shared>, tensor<128xi32, #blocked2>, i32) -> tensor<128x64xf16, #blocked1> + %5 = tt.descriptor_load %arg1[%2#1, %2#2] : !tt.tensordesc<64x128xf16, #shared> -> tensor<64x128xf16, #blocked1> %6 = ttg.local_alloc %4 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %7 = ttg.local_alloc %5 : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> // CHECK-NEXT: ttg.partition = array @@ -342,8 +342,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // CHECK-LABEL: @if_stmt_yield_outputs tt.func @if_stmt_yield_outputs(%lb: i32, %ub: i32, %step: i32, %a0: i32, %b0: i32, - %arg1: !tt.tensordesc> {tt.nv_tma_desc = 1 : i32}, - %arg2: !tt.tensordesc> {tt.nv_tma_desc = 1 : i32}) { + %arg1: !tt.tensordesc<1x128x64xbf16, #shared> {tt.nv_tma_desc = 1 : i32}, + %arg2: !tt.tensordesc<1x64x64xf32, #shared1> {tt.nv_tma_desc = 1 : i32}) { %false = arith.constant false %true = arith.constant true %c0_i32 = arith.constant 0 : i32 @@ -355,7 +355,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // CHECK: scf.for scf.for %arg3 = %lb to %ub step %step : i32 { // CHECK-NEXT: tt.descriptor_load {{.*}} {ttg.partition = array} {{.*}} - %20 = tt.descriptor_load %arg1[%a0, %b0, %c0_i32] : !tt.tensordesc> -> tensor<128x64xbf16, #blocked> + %20 = tt.descriptor_load %arg1[%a0, %b0, %c0_i32] : !tt.tensordesc<1x128x64xbf16, #shared> -> tensor<128x64xbf16, #blocked> %22 = arith.cmpi sge, %arg3, %c3_i32 : i32 // CHECK: scf.if %23 = scf.if %22 -> (tensor<128x64xbf16, #blocked>) { @@ -381,7 +381,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // CHECK: scf.for scf.for %arg3 = %lb to %ub step %step : i32 { - %20 = tt.descriptor_load %arg1[%a0, %b0, %c0_i32] : !tt.tensordesc> -> tensor<128x64xbf16, #blocked> + %20 = tt.descriptor_load %arg1[%a0, %b0, %c0_i32] : !tt.tensordesc<1x128x64xbf16, #shared> -> tensor<128x64xbf16, #blocked> %22 = arith.cmpi sge, %arg3, %c3_i32 : i32 %23 = scf.if %22 -> (tensor<128x64xbf16, #blocked>) { %32 = arith.muli %arg3, %c128_i32 {ttg.partition = array} : i32 @@ -402,7 +402,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // CHECK: scf.for scf.for %arg4 = %lb to %ub step %step : i32 { - %20 = tt.descriptor_load %arg1[%a0, %b0, %c0_i32] : !tt.tensordesc> -> tensor<128x64xbf16, #blocked> + %20 = tt.descriptor_load %arg1[%a0, %b0, %c0_i32] : !tt.tensordesc<1x128x64xbf16, #shared> -> tensor<128x64xbf16, #blocked> %22 = arith.cmpi sge, %arg4, %c3_i32 : i32 // CHECK: scf.if %23 = scf.if %22 -> (tensor<128x64xbf16, #blocked>) { @@ -437,7 +437,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: matmul_nested_persistent_ws_kernel - tt.func public @matmul_nested_persistent_ws_kernel(%a_desc_0: !tt.tensordesc>, %b_desc_1: !tt.tensordesc>, %c_desc_2: !tt.tensordesc>, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + tt.func public @matmul_nested_persistent_ws_kernel(%a_desc_0: !tt.tensordesc<128x128xf8E4M3FN, #shared>, %b_desc_1: !tt.tensordesc<128x128xf8E4M3FN, #shared>, %c_desc_2: !tt.tensordesc<128x128xf8E4M3FN, #shared>, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %false = arith.constant false %true = arith.constant true %c1_i64 = arith.constant 1 : i64 @@ -475,9 +475,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: arith.muli {{.*}}ttg.partition = array} %off_k = arith.muli %accumulator_15, %c128_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 // CHECK: tt.descriptor_load {{.*}}ttg.partition = array} - %a = tt.descriptor_load %a_desc_0[%off_am, %off_k] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked1> + %a = tt.descriptor_load %a_desc_0[%off_am, %off_k] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<128x128xf8E4M3FN, #shared> -> tensor<128x128xf8E4M3FN, #blocked1> %a_17 = ttg.local_alloc %a {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> - %b = tt.descriptor_load %b_desc_1[%off_bn, %off_k] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked1> + %b = tt.descriptor_load %b_desc_1[%off_bn, %off_k] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<128x128xf8E4M3FN, #shared> -> tensor<128x128xf8E4M3FN, #blocked1> %accumulator_18 = ttg.local_alloc %b {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> %accumulator_19 = ttg.memdesc_trans %accumulator_18 {loop.cluster = 0 : i32, loop.stage = 2 : i32, order = array} : !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> -> !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem> // CHECK: ttng.tc_gen5_mma {{.*}}ttg.partition = array} @@ -489,7 +489,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %accumulator_12, %accumulator_13 = ttng.tmem_load %accumulator[%accumulator_11#1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> %c = tt.fp_to_fp %accumulator_12, rounding = rtne : tensor<128x128xf32, #blocked> -> tensor<128x128xf8E4M3FN, #blocked> %c_14 = ttg.convert_layout %c : tensor<128x128xf8E4M3FN, #blocked> -> tensor<128x128xf8E4M3FN, #blocked1> - tt.descriptor_store %c_desc_2[%off_am, %off_bn], %c_14 : !tt.tensordesc>, tensor<128x128xf8E4M3FN, #blocked1> + tt.descriptor_store %c_desc_2[%off_am, %off_bn], %c_14 : !tt.tensordesc<128x128xf8E4M3FN, #shared>, tensor<128x128xf8E4M3FN, #blocked1> } {tt.num_stages = 3 : i32, tt.warp_specialize} tt.return } @@ -506,7 +506,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @attention_persistent_inner_loop_kernel(%desc_q: !tt.tensordesc>, %desc_q_0: i32, %desc_q_1: i32, %desc_q_2: i64, %desc_q_3: i64, %desc_k: !tt.tensordesc>, %desc_k_4: i32, %desc_k_5: i32, %desc_k_6: i64, %desc_k_7: i64, %desc_v: !tt.tensordesc>, %desc_v_8: i32, %desc_v_9: i32, %desc_v_10: i64, %desc_v_11: i64, %desc_acc: !tt.tensordesc>, %desc_acc_12: i32, %desc_acc_13: i32, %desc_acc_14: i64, %desc_acc_15: i64, %l_i_ptr: !tt.ptr {tt.divisibility = 16 : i32}, %m_i_ptr: !tt.ptr {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %qk_scale: f32) attributes {noinline = false} { + tt.func public @attention_persistent_inner_loop_kernel(%desc_q: !tt.tensordesc<128x128xf16, #shared>, %desc_q_0: i32, %desc_q_1: i32, %desc_q_2: i64, %desc_q_3: i64, %desc_k: !tt.tensordesc<128x128xf16, #shared>, %desc_k_4: i32, %desc_k_5: i32, %desc_k_6: i64, %desc_k_7: i64, %desc_v: !tt.tensordesc<128x128xf16, #shared>, %desc_v_8: i32, %desc_v_9: i32, %desc_v_10: i64, %desc_v_11: i64, %desc_acc: !tt.tensordesc<128x128xf16, #shared>, %desc_acc_12: i32, %desc_acc_13: i32, %desc_acc_14: i64, %desc_acc_15: i64, %l_i_ptr: !tt.ptr {tt.divisibility = 16 : i32}, %m_i_ptr: !tt.ptr {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %qk_scale: f32) attributes {noinline = false} { %false = arith.constant false %true = arith.constant true %c1_i32 = arith.constant 1 : i32 @@ -522,14 +522,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: scf.for %tile_idx = scf.for %_ = %c0_i32 to %tiles_per_sm step %c1_i32 iter_args(%tile_idx_20 = %prog_id) -> (i32) : i32 { %off_m = arith.muli %tile_idx_20, %c128_i32 : i32 - %q = tt.descriptor_load %desc_q[%off_m, %c0_i32] : !tt.tensordesc> -> tensor<128x128xf16, #blocked2> + %q = tt.descriptor_load %desc_q[%off_m, %c0_i32] : !tt.tensordesc<128x128xf16, #shared> -> tensor<128x128xf16, #blocked2> %q_21 = ttg.local_alloc %q : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem> %qk_22, %qk_23 = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) %acc, %acc_24 = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) %acc_25 = ttng.tmem_store %cst_17, %acc[%acc_24], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK: scf.for %acc_26:4 = scf.for %acc_30 = %c0_i32 to %N step %c128_i32 iter_args(%arg28 = %cst_16, %arg29 = %cst, %qk_31 = %qk_23, %acc_32 = %acc_25) -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.async.token, !ttg.async.token) : i32 { - %k = tt.descriptor_load %desc_k[%acc_30, %c0_i32] : !tt.tensordesc> -> tensor<128x128xf16, #blocked2> + %k = tt.descriptor_load %desc_k[%acc_30, %c0_i32] : !tt.tensordesc<128x128xf16, #shared> -> tensor<128x128xf16, #blocked2> %k_33 = ttg.local_alloc %k : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem> %k_34 = ttg.memdesc_trans %k_33 {order = array} : !ttg.memdesc<128x128xf16, #shared, #smem> -> !ttg.memdesc<128x128xf16, #shared1, #smem> %qk_35 = ttng.tc_gen5_mma %q_21, %k_34, %qk_22[%qk_31], %false, %true : !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf16, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> @@ -543,7 +543,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %acc_48, %acc_49 = ttng.tmem_load %acc[%acc_32] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> %acc_50 = arith.mulf %acc_48, %acc_47 : tensor<128x128xf32, #blocked> %acc_54 = ttng.tmem_store %acc_50, %acc[%acc_49], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> - %v = tt.descriptor_load %desc_v[%acc_30, %c0_i32] : !tt.tensordesc> -> tensor<128x128xf16, #blocked2> + %v = tt.descriptor_load %desc_v[%acc_30, %c0_i32] : !tt.tensordesc<128x128xf16, #shared> -> tensor<128x128xf16, #blocked2> %v_51 = ttg.local_alloc %v : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem> %acc_55 = ttng.tc_gen5_mma %p_53, %v_51, %acc[%acc_54], %true, %true : !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> diff --git a/test/TritonGPU/pipeline-assign-latencies.mlir b/test/TritonGPU/pipeline-assign-latencies.mlir index 9983a854b49e..8639ed316862 100644 --- a/test/TritonGPU/pipeline-assign-latencies.mlir +++ b/test/TritonGPU/pipeline-assign-latencies.mlir @@ -1033,8 +1033,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @attention_forward tt.func public @attention_forward( %Q_shared: !ttg.memdesc<256x64xf16, #shared, #smem>, - %K_desc: !tt.tensordesc>, - %V_desc: !tt.tensordesc>, + %K_desc: !tt.tensordesc<64x64xf16, #shared>, + %V_desc: !tt.tensordesc<64x64xf16, #shared>, %qk_scale: f32, %n_tiles: i32 ) { @@ -1059,7 +1059,7 @@ tt.func public @attention_forward( tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> ) : i32 { // CHECK: descriptor_load {{.*}} {tt.latency = 2 : i32} - %K = tt.descriptor_load %K_desc[%i, %c0_i32] : !tt.tensordesc> -> tensor<64x64xf16, #load_blocked> + %K = tt.descriptor_load %K_desc[%i, %c0_i32] : !tt.tensordesc<64x64xf16, #shared> -> tensor<64x64xf16, #load_blocked> %K_shared = ttg.local_alloc %K : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem> %K_trans = ttg.memdesc_trans %K_shared {order = array} : !ttg.memdesc<64x64xf16, #shared, #smem> -> !ttg.memdesc<64x64xf16, #shared_T, #smem> // CHECK: tc_gen5_mma {{.*}} {tt.latency = 2 : i32, tt.self_latency = 0 : i32} @@ -1071,7 +1071,7 @@ tt.func public @attention_forward( %acc_corrected = arith.mulf %acc, %alpha_1 : tensor<256x64xf32, #blocked> // CHECK: descriptor_load {{.*}} {tt.latency = 2 : i32} - %V = tt.descriptor_load %V_desc[%i, %c0_i32] : !tt.tensordesc> -> tensor<64x64xf16, #load_blocked> + %V = tt.descriptor_load %V_desc[%i, %c0_i32] : !tt.tensordesc<64x64xf16, #shared> -> tensor<64x64xf16, #load_blocked> %V_shared = ttg.local_alloc %V : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem> %P_tmem = ttng.tmem_alloc %P : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #tmem, #ttng.tensor_memory> %acc_tmem, %acc_tok = ttng.tmem_alloc %acc_corrected : (tensor<256x64xf32, #blocked>) -> (!ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) @@ -1099,7 +1099,7 @@ tt.func public @attention_forward( #smem = #ttg.shared_memory #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @attention_persistent_inner_loop_kernel(%desc_q: !tt.tensordesc>, %desc_q_0: i32, %desc_q_1: i32, %desc_q_2: i64, %desc_q_3: i64, %desc_k: !tt.tensordesc>, %desc_k_4: i32, %desc_k_5: i32, %desc_k_6: i64, %desc_k_7: i64, %desc_v: !tt.tensordesc>, %desc_v_8: i32, %desc_v_9: i32, %desc_v_10: i64, %desc_v_11: i64, %desc_acc: !tt.tensordesc>, %desc_acc_12: i32, %desc_acc_13: i32, %desc_acc_14: i64, %desc_acc_15: i64, %l_i_ptr: !tt.ptr {tt.divisibility = 16 : i32}, %m_i_ptr: !tt.ptr {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %qk_scale: f32) attributes {noinline = false} { + tt.func public @attention_persistent_inner_loop_kernel(%desc_q: !tt.tensordesc<128x128xf16, #shared>, %desc_q_0: i32, %desc_q_1: i32, %desc_q_2: i64, %desc_q_3: i64, %desc_k: !tt.tensordesc<128x128xf16, #shared>, %desc_k_4: i32, %desc_k_5: i32, %desc_k_6: i64, %desc_k_7: i64, %desc_v: !tt.tensordesc<128x128xf16, #shared>, %desc_v_8: i32, %desc_v_9: i32, %desc_v_10: i64, %desc_v_11: i64, %desc_acc: !tt.tensordesc<128x128xf16, #shared>, %desc_acc_12: i32, %desc_acc_13: i32, %desc_acc_14: i64, %desc_acc_15: i64, %l_i_ptr: !tt.ptr {tt.divisibility = 16 : i32}, %m_i_ptr: !tt.ptr {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %qk_scale: f32) attributes {noinline = false} { %false = arith.constant false %true = arith.constant true %c1_i32 = arith.constant 1 : i32 @@ -1113,13 +1113,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %tiles_per_sm = arith.divsi %num_tiles, %num_sm : i32 %tile_idx = scf.for %_ = %c0_i32 to %tiles_per_sm step %c1_i32 iter_args(%tile_idx_20 = %prog_id) -> (i32) : i32 { %off_m = arith.muli %tile_idx_20, %c128_i32 : i32 - %q = tt.descriptor_load %desc_q[%off_m, %c0_i32] : !tt.tensordesc> -> tensor<128x128xf16, #blocked2> + %q = tt.descriptor_load %desc_q[%off_m, %c0_i32] : !tt.tensordesc<128x128xf16, #shared> -> tensor<128x128xf16, #blocked2> %q_21 = ttg.local_alloc %q : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem> %qk_22, %qk_23 = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) %acc, %acc_24 = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) %acc_26:4 = scf.for %acc_30 = %c0_i32 to %N step %c128_i32 iter_args(%arg28 = %cst_16, %arg29 = %cst, %qk_31 = %qk_23, %acc_32 = %acc_24) -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.async.token, !ttg.async.token) : i32 { // CHECK: tt.descriptor_load {{.*}} {tt.latency = 2 : i32} - %k = tt.descriptor_load %desc_k[%acc_30, %c0_i32] : !tt.tensordesc> -> tensor<128x128xf16, #blocked2> + %k = tt.descriptor_load %desc_k[%acc_30, %c0_i32] : !tt.tensordesc<128x128xf16, #shared> -> tensor<128x128xf16, #blocked2> %k_33 = ttg.local_alloc %k : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem> %k_34 = ttg.memdesc_trans %k_33 {order = array} : !ttg.memdesc<128x128xf16, #shared, #smem> -> !ttg.memdesc<128x128xf16, #shared1, #smem> // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {tt.latency = 2 : i32, tt.self_latency = 0 : i32} @@ -1133,7 +1133,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %p_53 = ttg.local_alloc %p : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> %acc_54 = ttng.tmem_store %acc_50, %acc[%acc_49], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK: tt.descriptor_load {{.*}} {tt.latency = 2 : i32} - %v = tt.descriptor_load %desc_v[%acc_30, %c0_i32] : !tt.tensordesc> -> tensor<128x128xf16, #blocked2> + %v = tt.descriptor_load %desc_v[%acc_30, %c0_i32] : !tt.tensordesc<128x128xf16, #shared> -> tensor<128x128xf16, #blocked2> %v_51 = ttg.local_alloc %v : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem> // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {tt.self_latency = 0 : i32} diff --git a/test/TritonGPU/pipeline-loop-nest.mlir b/test/TritonGPU/pipeline-loop-nest.mlir index c59ceb38650a..2adb877fe45c 100644 --- a/test/TritonGPU/pipeline-loop-nest.mlir +++ b/test/TritonGPU/pipeline-loop-nest.mlir @@ -45,10 +45,10 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr, %arg1: !tt.p %20 = arith.muli %18, %c128_i32 : i32 %21 = scf.for %arg8 = %c0_i32 to %6 step %c1_i32 iter_args(%arg9 = %cst) -> (tensor<128x128xf32>) : i32 { %35 = arith.muli %arg8, %c64_i32 : i32 - %36 = ttng.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc> - %37 = tt.descriptor_load %36[%19, %35] : !tt.tensordesc> -> tensor<128x64xf16> - %38 = ttng.reinterpret_tensor_descriptor %arg1 : !tt.ptr to !tt.tensordesc> - %39 = tt.descriptor_load %38[%20, %35] : !tt.tensordesc> -> tensor<128x64xf16> + %36 = ttng.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc<128x64xf16, #shared> + %37 = tt.descriptor_load %36[%19, %35] : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16> + %38 = ttng.reinterpret_tensor_descriptor %arg1 : !tt.ptr to !tt.tensordesc<128x64xf16, #shared> + %39 = tt.descriptor_load %38[%20, %35] : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16> // BLACKWELL: ttg.memdesc_trans // BLACKWELL: [[ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]] // BLACKWELL: ttng.tc_gen5_mma {{%[0-9]+}}, {{%[0-9]+}}, [[ACC_BUF]] @@ -76,8 +76,8 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr, %arg1: !tt.p %31 = arith.muli %28, %c128_i32 : i32 %32 = arith.muli %30, %c128_i32 : i32 %33 = arith.truncf %21 : tensor<128x128xf32> to tensor<128x128xf16> - %34 = ttng.reinterpret_tensor_descriptor %arg2 : !tt.ptr to !tt.tensordesc> - tt.descriptor_store %34[%31, %32], %33 : !tt.tensordesc>, tensor<128x128xf16> + %34 = ttng.reinterpret_tensor_descriptor %arg2 : !tt.ptr to !tt.tensordesc<128x128xf16, #shared> + tt.descriptor_store %34[%31, %32], %33 : !tt.tensordesc<128x128xf16, #shared>, tensor<128x128xf16> scf.yield %22 : i32 } {tt.flatten} tt.return diff --git a/test/TritonGPU/pipeline-lower-loop.mlir b/test/TritonGPU/pipeline-lower-loop.mlir index 474c50344c23..a0837216c1fe 100644 --- a/test/TritonGPU/pipeline-lower-loop.mlir +++ b/test/TritonGPU/pipeline-lower-loop.mlir @@ -766,10 +766,10 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK: ttg.local_dealloc %[[BARRIER]] // CHECK: ttg.local_dealloc %[[A]] tt.func @tma_load_lowering(%lb : index, %ub : index, %step : index, - %desc : !tt.tensordesc>, + %desc : !tt.tensordesc<128x32xf16, #nvmma_64>, %offs : i32) -> () { scf.for %iv = %lb to %ub step %step : index { - %a = tt.descriptor_load %desc[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x32xf16, #A> + %a = tt.descriptor_load %desc[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<128x32xf16, #nvmma_64> -> tensor<128x32xf16, #A> "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () } {tt.scheduled_max_stage = 2 : i32} tt.return @@ -821,11 +821,11 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-DAG: ttg.local_dealloc %[[BARRIER]] // CHECK-DAG: ttg.local_dealloc %[[A]] tt.func @tma_gather_lowering(%lb : index, %ub : index, %step : index, - %desc : !tt.tensordesc>, + %desc : !tt.tensordesc<1x128xf32, #nvmma_128>, %x : tensor<32xi32, #offsets>, %y : i32) -> () { scf.for %iv = %lb to %ub step %step : index { - %a = tt.descriptor_gather %desc[%x, %y] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : (!tt.tensordesc>, tensor<32xi32, #offsets>, i32) -> tensor<32x128xf32, #A> + %a = tt.descriptor_gather %desc[%x, %y] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : (!tt.tensordesc<1x128xf32, #nvmma_128>, tensor<32xi32, #offsets>, i32) -> tensor<32x128xf32, #A> "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<32x128xf32, #A>) -> () } {tt.scheduled_max_stage = 2 : i32} tt.return @@ -852,16 +852,16 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK: ttng.wait_barrier // CHECK: "use3" tt.func @tma_reuse_barrier(%lb : index, %ub : index, %step : index, - %descA : !tt.tensordesc>, - %descB : !tt.tensordesc>, - %descC : !tt.tensordesc>, + %descA : !tt.tensordesc<128x32xf16, #nvmma_64>, + %descB : !tt.tensordesc<128x32xf16, #nvmma_64>, + %descC : !tt.tensordesc<128x32xf16, #nvmma_64>, %offs : i32) -> () { scf.for %iv = %lb to %ub step %step : index { - %a = tt.descriptor_load %descA[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x32xf16, #A> - %b = tt.descriptor_load %descB[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x32xf16, #A> + %a = tt.descriptor_load %descA[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<128x32xf16, #nvmma_64> -> tensor<128x32xf16, #A> + %b = tt.descriptor_load %descB[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<128x32xf16, #nvmma_64> -> tensor<128x32xf16, #A> "use1"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () "use2"(%b) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () - %c = tt.descriptor_load %descC[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x32xf16, #A> + %c = tt.descriptor_load %descC[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<128x32xf16, #nvmma_64> -> tensor<128x32xf16, #A> "use3"(%c) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () } {tt.scheduled_max_stage = 2 : i32} tt.return @@ -888,16 +888,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-NOT: ttg.local_alloc // CHECK: ttng.warp_group_dot tt.func public @tma_pipelining_mmav3(%lb : index, %ub : index, %step : index, - %descA : !tt.tensordesc>, - %descB : !tt.tensordesc>, + %descA : !tt.tensordesc<128x128xf16, #shared>, + %descB : !tt.tensordesc<128x128xf16, #shared>, %offs : i32) -> tensor<128x128xf16, #mma> { %true = arith.constant true %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> %c0_i32 = arith.constant 0 : i32 %res = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> (tensor<128x128xf32, #mma>) : index { - %A = tt.descriptor_load %descA[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x128xf16, #blocked1> + %A = tt.descriptor_load %descA[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<128x128xf16, #shared> -> tensor<128x128xf16, #blocked1> %A_sh = ttg.local_alloc %A {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> - %B = tt.descriptor_load %descB[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x128xf16, #blocked1> + %B = tt.descriptor_load %descB[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<128x128xf16, #shared> -> tensor<128x128xf16, #blocked1> %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma> scf.yield %acc_res : tensor<128x128xf32, #mma> @@ -934,8 +934,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %strides_x: i64, %strides_y: i64) -> (){ scf.for %iv = %lb to %ub step %step : index { - %desc = tt.make_tensor_descriptor %A, [%shape_x, %shape_y], [%strides_x, %strides_y] {loop.cluster = 0 : i32, loop.stage = 1 : i32} : , > - "use"(%desc) {loop.cluster = 0 : i32, loop.stage = 1 : i32} : (!tt.tensordesc>) -> () + %desc = tt.make_tensor_descriptor %A, [%shape_x, %shape_y], [%strides_x, %strides_y] {loop.cluster = 0 : i32, loop.stage = 1 : i32} : , <128x128xf16, #nvmma_128> + "use"(%desc) {loop.cluster = 0 : i32, loop.stage = 1 : i32} : (!tt.tensordesc<128x128xf16, #nvmma_128>) -> () } {tt.scheduled_max_stage = 1 : i32} tt.return } @@ -1432,9 +1432,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c32_i32 = arith.constant 32 : i32 %c32_i64 = arith.constant 32 : i64 %cst_0 = arith.constant dense<127> : tensor<128x4xi8, #linear> - %0 = tt.make_tensor_descriptor %arg6, [%c32_i32, %c32_i32], [%c32_i64, %c1_i64] : , > - %1 = tt.make_tensor_descriptor %arg9, [%c32_i32, %c32_i32, %c32_i32], [%c32_i64, %c32_i64, %c1_i64] : , > - %2 = tt.make_tensor_descriptor %arg12, [%c32_i32, %c32_i32, %c32_i32, %c32_i32, %c16_i32], [%c32_i64, %c32_i64, %c32_i64, %c16_i64, %c1_i64] : , > + %0 = tt.make_tensor_descriptor %arg6, [%c32_i32, %c32_i32], [%c32_i64, %c1_i64] : , <1x128xf8E4M3FN, #shared> + %1 = tt.make_tensor_descriptor %arg9, [%c32_i32, %c32_i32, %c32_i32], [%c32_i64, %c32_i64, %c1_i64] : , <1x64x256xi8, #shared1> + %2 = tt.make_tensor_descriptor %arg12, [%c32_i32, %c32_i32, %c32_i32, %c32_i32, %c16_i32], [%c32_i64, %c32_i64, %c32_i64, %c16_i64, %c1_i64] : , <1x2x1x32x16xi8, #shared2> %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> %4 = ttng.tmem_alloc %cst_0 : (tensor<128x4xi8, #linear>) -> !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory> %5, %acc_tok = ttng.tmem_alloc : () -> (!ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) @@ -1448,11 +1448,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: ttng.tmem_alloc // CHECK: ttng.tc_gen5_mma_scaled - %7 = tt.descriptor_gather %0[%3, %c0_i32] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : (!tt.tensordesc>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, i32) -> tensor<128x128xf8E4M3FN, #blocked2> + %7 = tt.descriptor_gather %0[%3, %c0_i32] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : (!tt.tensordesc<1x128xf8E4M3FN, #shared>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, i32) -> tensor<128x128xf8E4M3FN, #blocked2> %8 = ttg.local_alloc %7 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf8E4M3FN, #blocked2>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> - %9 = tt.descriptor_load %1[%arg30, %c0_i32, %c0_i32] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<64x256xi8, #blocked2> + %9 = tt.descriptor_load %1[%arg30, %c0_i32, %c0_i32] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<1x64x256xi8, #shared1> -> tensor<64x256xi8, #blocked2> %10 = ttg.local_alloc %9 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<64x256xi8, #blocked2>) -> !ttg.memdesc<64x256xi8, #shared3, #smem> - %11 = tt.descriptor_load %2[%arg30, %c0_i32, %c0_i32, %c0_i32, %c0_i32] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<1x2x1x32x16xi8, #blocked3> + %11 = tt.descriptor_load %2[%arg30, %c0_i32, %c0_i32, %c0_i32, %c0_i32] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<1x2x1x32x16xi8, #shared2> -> tensor<1x2x1x32x16xi8, #blocked3> %12 = tt.reshape %11 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : tensor<1x2x1x32x16xi8, #blocked3> -> tensor<2x1x32x4x4xi8, #blocked4> %13 = tt.trans %12 {loop.cluster = 0 : i32, loop.stage = 2 : i32, order = array} : tensor<2x1x32x4x4xi8, #blocked4> -> tensor<2x4x32x1x4xi8, #blocked5> %14 = ttg.convert_layout %13 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : tensor<2x4x32x1x4xi8, #blocked5> -> tensor<2x4x32x1x4xi8, #linear1> @@ -1721,13 +1721,13 @@ tt.func @conditional_store_race_fix(%lb : index, %ub : index, %step : index, #tmem = #ttng.tensor_memory_encoding module attributes {ttg.max_reg_auto_ws = 152 : i32, ttg.min_reg_auto_ws = 24 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @non_pipelined_op - tt.func public @non_pipelined_op(%x_desc: !tt.tensordesc>, %x_desc_0: i32, %x_desc_1: i32, %x_desc_2: i64, %x_desc_3: i64, %y_desc: !tt.tensordesc>, %y_desc_4: i32, %y_desc_5: i32, %y_desc_6: i64, %y_desc_7: i64, %out_desc: !tt.tensordesc>, %out_desc_8: i32, %out_desc_9: i32, %out_desc_10: i64, %out_desc_11: i64, %N: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + tt.func public @non_pipelined_op(%x_desc: !tt.tensordesc<64x64xbf16, #shared>, %x_desc_0: i32, %x_desc_1: i32, %x_desc_2: i64, %x_desc_3: i64, %y_desc: !tt.tensordesc<64x64xbf16, #shared>, %y_desc_4: i32, %y_desc_5: i32, %y_desc_6: i64, %y_desc_7: i64, %out_desc: !tt.tensordesc<64x64xf32, #shared1>, %out_desc_8: i32, %out_desc_9: i32, %out_desc_10: i64, %out_desc_11: i64, %N: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %acc = arith.constant false %true = arith.constant true %c1_i32 = arith.constant 1 : i32 %BLOCK_N = arith.constant 64 : i32 %c0_i32 = arith.constant 0 : i32 - %x = tt.descriptor_load %x_desc[%c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<64x64xbf16, #blocked> + %x = tt.descriptor_load %x_desc[%c0_i32, %c0_i32] : !tt.tensordesc<64x64xbf16, #shared> -> tensor<64x64xbf16, #blocked> %x_12 = ttg.local_alloc %x : (tensor<64x64xbf16, #blocked>) -> !ttg.memdesc<64x64xbf16, #shared, #smem> %num_slices = arith.divsi %N, %BLOCK_N : i32 %acc_13, %acc_14 = ttng.tmem_alloc : () -> (!ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) @@ -1737,7 +1737,7 @@ module attributes {ttg.max_reg_auto_ws = 152 : i32, ttg.min_reg_auto_ws = 24 : i %y = arith.muli %i, %BLOCK_N {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i32 // CHECK: ttng.barrier_expect {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32}, {{.*}} // CHECK: ttng.async_tma_copy_global_to_local {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} {{.*}} - %y_16 = tt.descriptor_load %y_desc[%c0_i32, %y] {loop.cluster = 1 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<64x64xbf16, #blocked> + %y_16 = tt.descriptor_load %y_desc[%c0_i32, %y] {loop.cluster = 1 : i32, loop.stage = 0 : i32} : !tt.tensordesc<64x64xbf16, #shared> -> tensor<64x64xbf16, #blocked> // CHECK: ttng.wait_barrier {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} {{.*}} %y_17 = ttg.local_alloc %y_16 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : (tensor<64x64xbf16, #blocked>) -> !ttg.memdesc<64x64xbf16, #shared, #smem> // CHECK:{{.*}} = ttng.tc_gen5_mma {{.*}} {is_async, loop.cluster = 2 : i32, loop.stage = 0 : i32} {{.*}} @@ -1747,7 +1747,7 @@ module attributes {ttg.max_reg_auto_ws = 152 : i32, ttg.min_reg_auto_ws = 24 : i %acc_19, %acc_20 = ttng.tmem_load %acc_13[%acc_18] {loop.cluster = 0 : i32, loop.stage = 1 : i32} : !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<64x64xf32, #blocked1> %1 = ttg.convert_layout %acc_19 {loop.cluster = 0 : i32, loop.stage = 1 : i32} : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked> // CHECK: tt.descriptor_store {{.*}} {loop.cluster = 1 : i32, loop.stage = 1 : i32} {{.*}} - tt.descriptor_store %out_desc[%c0_i32, %y], %1 {loop.cluster = 0 : i32, loop.stage = 1 : i32} : !tt.tensordesc>, tensor<64x64xf32, #blocked> + tt.descriptor_store %out_desc[%c0_i32, %y], %1 {loop.cluster = 0 : i32, loop.stage = 1 : i32} : !tt.tensordesc<64x64xf32, #shared1>, tensor<64x64xf32, #blocked> scf.yield %acc_20 : !ttg.async.token } {tt.scheduled_max_stage = 1 : i32} tt.return @@ -1764,13 +1764,13 @@ module attributes {ttg.max_reg_auto_ws = 152 : i32, ttg.min_reg_auto_ws = 24 : i #tmem = #ttng.tensor_memory_encoding module attributes {ttg.max_reg_auto_ws = 152 : i32, ttg.min_reg_auto_ws = 24 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @non_pipelined_op_two_stage - tt.func public @non_pipelined_op_two_stage(%x_desc: !tt.tensordesc>, %x_desc_0: i32, %x_desc_1: i32, %x_desc_2: i64, %x_desc_3: i64, %y_desc: !tt.tensordesc>, %y_desc_4: i32, %y_desc_5: i32, %y_desc_6: i64, %y_desc_7: i64, %out_desc: !tt.tensordesc>, %out_desc_8: i32, %out_desc_9: i32, %out_desc_10: i64, %out_desc_11: i64, %N: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + tt.func public @non_pipelined_op_two_stage(%x_desc: !tt.tensordesc<64x64xbf16, #shared>, %x_desc_0: i32, %x_desc_1: i32, %x_desc_2: i64, %x_desc_3: i64, %y_desc: !tt.tensordesc<64x64xbf16, #shared>, %y_desc_4: i32, %y_desc_5: i32, %y_desc_6: i64, %y_desc_7: i64, %out_desc: !tt.tensordesc<64x64xf32, #shared1>, %out_desc_8: i32, %out_desc_9: i32, %out_desc_10: i64, %out_desc_11: i64, %N: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %acc = arith.constant false %true = arith.constant true %c1_i32 = arith.constant 1 : i32 %BLOCK_N = arith.constant 64 : i32 %c0_i32 = arith.constant 0 : i32 - %x = tt.descriptor_load %x_desc[%c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<64x64xbf16, #blocked> + %x = tt.descriptor_load %x_desc[%c0_i32, %c0_i32] : !tt.tensordesc<64x64xbf16, #shared> -> tensor<64x64xbf16, #blocked> %x_12 = ttg.local_alloc %x : (tensor<64x64xbf16, #blocked>) -> !ttg.memdesc<64x64xbf16, #shared, #smem> %num_slices = arith.divsi %N, %BLOCK_N : i32 %acc_13, %acc_14 = ttng.tmem_alloc : () -> (!ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) @@ -1780,7 +1780,7 @@ module attributes {ttg.max_reg_auto_ws = 152 : i32, ttg.min_reg_auto_ws = 24 : i %y = arith.muli %i, %BLOCK_N {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i32 // CHECK: ttng.barrier_expect {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}, {{.*}} // CHECK: ttng.async_tma_copy_global_to_local {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} {{.*}} - %y_16 = tt.descriptor_load %y_desc[%c0_i32, %y] {loop.cluster = 1 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<64x64xbf16, #blocked> + %y_16 = tt.descriptor_load %y_desc[%c0_i32, %y] {loop.cluster = 1 : i32, loop.stage = 0 : i32} : !tt.tensordesc<64x64xbf16, #shared> -> tensor<64x64xbf16, #blocked> // CHECK: ttng.wait_barrier {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} {{.*}} %y_17 = ttg.local_alloc %y_16 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : (tensor<64x64xbf16, #blocked>) -> !ttg.memdesc<64x64xbf16, #shared, #smem> // CHECK:{{.*}} = ttng.tc_gen5_mma {{.*}} {is_async, loop.cluster = 3 : i32, loop.stage = 0 : i32} {{.*}} @@ -1790,7 +1790,7 @@ module attributes {ttg.max_reg_auto_ws = 152 : i32, ttg.min_reg_auto_ws = 24 : i %acc_19, %acc_20 = ttng.tmem_load %acc_13[%acc_18] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<64x64xf32, #blocked1> %1 = ttg.convert_layout %acc_19 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked> // CHECK: tt.descriptor_store {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} {{.*}} - tt.descriptor_store %out_desc[%c0_i32, %y], %1 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !tt.tensordesc>, tensor<64x64xf32, #blocked> + tt.descriptor_store %out_desc[%c0_i32, %y], %1 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !tt.tensordesc<64x64xf32, #shared1>, tensor<64x64xf32, #blocked> scf.yield %acc_20 : !ttg.async.token } {tt.scheduled_max_stage = 2 : i32} tt.return diff --git a/test/TritonGPU/proxy_fence_insertion.mlir b/test/TritonGPU/proxy_fence_insertion.mlir index ef46224d1daf..2630a0d1f185 100644 --- a/test/TritonGPU/proxy_fence_insertion.mlir +++ b/test/TritonGPU/proxy_fence_insertion.mlir @@ -6,7 +6,7 @@ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: fence_write_after_read - tt.func @fence_write_after_read(%arg0: !tt.tensordesc>, %arg1: !ttg.memdesc<1xi64, #shared1, #smem, mutable>) { + tt.func @fence_write_after_read(%arg0: !tt.tensordesc<64x64xf32, #shared>, %arg1: !ttg.memdesc<1xi64, #shared1, #smem, mutable>) { // CHECK: ttg.local_load // CHECK: ttng.fence_async_shared // CHECK: ttng.async_tma_copy_global_to_local @@ -16,7 +16,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %1 = ttg.local_load %0 : !ttg.memdesc<32x64xf32, #shared, #smem, mutable> -> tensor<32x64xf32, #blocked> "test.keep"(%1) : (tensor<32x64xf32, #blocked>) -> () %2 = ttg.local_alloc {allocation.offset = 32 : i32} : () -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable> - ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %2, %arg1, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %2, %arg1, %true : !tt.tensordesc<64x64xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable> tt.return } } @@ -29,7 +29,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: missing_proxy_fence_memdesc_index_alias_single - tt.func @missing_proxy_fence_memdesc_index_alias_single(%arg0: !tt.tensordesc>, %arg1: !ttg.memdesc<1xi64, #shared1, #smem, mutable>) { + tt.func @missing_proxy_fence_memdesc_index_alias_single(%arg0: !tt.tensordesc<64x64xf32, #shared>, %arg1: !ttg.memdesc<1xi64, #shared1, #smem, mutable>) { // Keep the first fence to clear dependencies from local_alloc. // CHECK: ttng.fence_async_shared // CHECK: ttg.local_load @@ -43,7 +43,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { ttng.fence_async_shared {bCluster = false} %2 = ttg.local_load %1 : !ttg.memdesc<64x64xf32, #shared, #smem, mutable> -> tensor<64x64xf32, #blocked> "test.keep"(%2) : (tensor<64x64xf32, #blocked>) -> () - ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %1, %arg1, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %1, %arg1, %true : !tt.tensordesc<64x64xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable> tt.return } } @@ -56,17 +56,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: async_proxy_after_async_proxy - tt.func @async_proxy_after_async_proxy(%arg0: !tt.tensordesc>, %arg1: !ttg.memdesc<1xi64, #shared1, #smem, mutable>) { + tt.func @async_proxy_after_async_proxy(%arg0: !tt.tensordesc<64x64xf32, #shared>, %arg1: !ttg.memdesc<1xi64, #shared1, #smem, mutable>) { // CHECK: ttng.async_tma_copy_global_to_local // CHECK-NOT: ttng.fence_async_shared // CHECK: ttng.async_tma_copy_global_to_local %c0_i32 = arith.constant 0 : i32 %true = arith.constant true %0 = ttg.local_alloc {allocation.offset = 16 : i32} : () -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable> - ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %arg1, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %arg1, %true : !tt.tensordesc<64x64xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable> ttng.async_tma_store_wait {pendings = 0 : i32} %2 = ttg.local_alloc {allocation.offset = 32 : i32} : () -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable> - ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %2, %arg1, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %2, %arg1, %true : !tt.tensordesc<64x64xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable> tt.return } } @@ -78,7 +78,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: missing_proxy_fence_local_store_before_async_tma_copy_local_to_global - tt.func @missing_proxy_fence_local_store_before_async_tma_copy_local_to_global(%arg0: !tt.tensordesc>, %arg1: tensor<128x256xf32, #blocked>) { + tt.func @missing_proxy_fence_local_store_before_async_tma_copy_local_to_global(%arg0: !tt.tensordesc<128x256xf32, #shared>, %arg1: tensor<128x256xf32, #blocked>) { // CHECK: ttng.async_tma_store_wait {pendings = 1 : i32} // CHECK-NEXT: ttg.local_store // CHECK-NEXT: ttng.fence_async_shared @@ -87,7 +87,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %0 = ttg.local_alloc {allocation.offset = 16 : i32} : () -> !ttg.memdesc<128x256xf32, #shared, #smem, mutable> ttng.async_tma_store_wait {pendings = 1 : i32} ttg.local_store %arg1, %0 : tensor<128x256xf32, #blocked> -> !ttg.memdesc<128x256xf32, #shared, #smem, mutable> - ttng.async_tma_copy_local_to_global %arg0[%c0_i32, %c0_i32] %0 : !tt.tensordesc>, !ttg.memdesc<128x256xf32, #shared, #smem, mutable> + ttng.async_tma_copy_local_to_global %arg0[%c0_i32, %c0_i32] %0 : !tt.tensordesc<128x256xf32, #shared>, !ttg.memdesc<128x256xf32, #shared, #smem, mutable> tt.return } } diff --git a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir index a01e35fa8836..f40f5ac4d45f 100644 --- a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir +++ b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir @@ -47,10 +47,10 @@ // CHECK: %[[VAL_32:.*]] = arith.remsi %[[VAL_20]], %[[VAL_25]] : i32 // CHECK: %[[VAL_33:.*]] = arith.divsi %[[VAL_32]], %[[VAL_29]] : i32 // CHECK: %[[VAL_34:.*]] = arith.extsi %[[VAL_5]] : i32 to i64 -// CHECK: %[[VAL_35:.*]] = tt.make_tensor_descriptor %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_5]]], {{\[}}%[[VAL_34]], %[[VAL_14]]] : , > -// CHECK: %[[VAL_36:.*]] = tt.make_tensor_descriptor %[[VAL_1]], {{\[}}%[[VAL_4]], %[[VAL_5]]], {{\[}}%[[VAL_34]], %[[VAL_14]]] : , > +// CHECK: %[[VAL_35:.*]] = tt.make_tensor_descriptor %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_5]]], {{\[}}%[[VAL_34]], %[[VAL_14]]] : , <128x64xf16, #[[$ATTR_2]]> +// CHECK: %[[VAL_36:.*]] = tt.make_tensor_descriptor %[[VAL_1]], {{\[}}%[[VAL_4]], %[[VAL_5]]], {{\[}}%[[VAL_34]], %[[VAL_14]]] : , <256x64xf16, #[[$ATTR_2]]> // CHECK: %[[VAL_37:.*]] = arith.extsi %[[VAL_4]] : i32 to i64 -// CHECK: %[[VAL_38:.*]] = tt.make_tensor_descriptor %[[VAL_2]], {{\[}}%[[VAL_3]], %[[VAL_4]]], {{\[}}%[[VAL_37]], %[[VAL_14]]] : , > +// CHECK: %[[VAL_38:.*]] = tt.make_tensor_descriptor %[[VAL_2]], {{\[}}%[[VAL_3]], %[[VAL_4]]], {{\[}}%[[VAL_37]], %[[VAL_14]]] : , <128x256xf16, #[[$ATTR_2]]> // CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_31]], %[[VAL_10]] : i32 // CHECK: %[[VAL_40:.*]] = arith.muli %[[VAL_33]], %[[VAL_11]] : i32 // CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_5]], %[[VAL_18]] : i32 @@ -68,16 +68,16 @@ // CHECK: %[[VAL_50:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> // CHECK: ttng.barrier_expect %[[VAL_50]], 49152, %[[VAL_49]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_51:.*]] = ttg.memdesc_index %[[VAL_43]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_12]]] %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_12]]] %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : !tt.tensordesc<128x64xf16, #[[$ATTR_2]]>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_52:.*]] = ttg.memdesc_index %[[VAL_44]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_12]]] %[[VAL_52]], %[[VAL_50]], %[[VAL_49]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_12]]] %[[VAL_52]], %[[VAL_50]], %[[VAL_49]] : !tt.tensordesc<256x64xf16, #[[$ATTR_2]]>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_53:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_15]] : i32 // CHECK: %[[VAL_54:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> // CHECK: ttng.barrier_expect %[[VAL_54]], 49152, %[[VAL_53]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_55:.*]] = ttg.memdesc_index %[[VAL_43]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_13]]] %[[VAL_55]], %[[VAL_54]], %[[VAL_53]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_13]]] %[[VAL_55]], %[[VAL_54]], %[[VAL_53]] : !tt.tensordesc<128x64xf16, #[[$ATTR_2]]>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_56:.*]] = ttg.memdesc_index %[[VAL_44]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_13]]] %[[VAL_56]], %[[VAL_54]], %[[VAL_53]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_13]]] %[[VAL_56]], %[[VAL_54]], %[[VAL_53]] : !tt.tensordesc<256x64xf16, #[[$ATTR_2]]>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_57:.*]]:5 = scf.for %[[VAL_58:.*]] = %[[VAL_12]] to %[[VAL_42]] step %[[VAL_15]] iter_args(%[[VAL_59:.*]] = %[[VAL_19]], %[[VAL_60:.*]] = %[[VAL_13]], %[[VAL_61:.*]] = %[[VAL_15]], %[[VAL_62:.*]] = %[[VAL_8]], %[[VAL_63:.*]] = %[[VAL_12]]) -> (tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32) : i32 { // CHECK: %[[VAL_64:.*]] = arith.subi %[[VAL_42]], %[[VAL_7]] : i32 // CHECK: %[[VAL_65:.*]] = arith.cmpi slt, %[[VAL_58]], %[[VAL_64]] : i32 @@ -100,9 +100,9 @@ // CHECK: %[[VAL_81:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_80]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> // CHECK: ttng.barrier_expect %[[VAL_81]], 49152, %[[VAL_65]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_82:.*]] = ttg.memdesc_index %[[VAL_43]]{{\[}}%[[VAL_80]]{{\]}} : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_77]]] %[[VAL_82]], %[[VAL_81]], %[[VAL_65]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_77]]] %[[VAL_82]], %[[VAL_81]], %[[VAL_65]] : !tt.tensordesc<128x64xf16, #[[$ATTR_2]]>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_83:.*]] = ttg.memdesc_index %[[VAL_44]]{{\[}}%[[VAL_80]]{{\]}} : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_77]]] %[[VAL_83]], %[[VAL_81]], %[[VAL_65]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_77]]] %[[VAL_83]], %[[VAL_81]], %[[VAL_65]] : !tt.tensordesc<256x64xf16, #[[$ATTR_2]]>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: scf.yield %[[VAL_76]]#0, %[[VAL_77]], %[[VAL_80]], %[[VAL_68]], %[[VAL_70]] : tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32 // CHECK: } // CHECK: %[[VAL_84:.*]] = ttng.warp_group_dot_wait %[[VAL_85:.*]]#0 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_1]]> @@ -117,7 +117,7 @@ // CHECK: ttg.local_dealloc %[[VAL_43]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_89:.*]] = arith.truncf %[[VAL_84]] : tensor<128x256xf32, #[[$ATTR_1]]> to tensor<128x256xf16, #[[$ATTR_1]]> // CHECK: %[[VAL_90:.*]] = ttg.convert_layout %[[VAL_89]] : tensor<128x256xf16, #[[$ATTR_1]]> -> tensor<128x256xf16, #[[$ATTR_0]]> -// CHECK: tt.descriptor_store %[[VAL_38]]{{\[}}%[[VAL_39]], %[[VAL_40]]], %[[VAL_90]] : !tt.tensordesc>, tensor<128x256xf16, #[[$ATTR_0]]> +// CHECK: tt.descriptor_store %[[VAL_38]]{{\[}}%[[VAL_39]], %[[VAL_40]]], %[[VAL_90]] : !tt.tensordesc<128x256xf16, #[[$ATTR_2]]>, tensor<128x256xf16, #[[$ATTR_0]]> // CHECK: tt.return // CHECK: } module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { @@ -148,18 +148,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %12 = arith.remsi %0, %5 : i32 %13 = arith.divsi %12, %9 : i32 %14 = arith.extsi %arg5 : i32 to i64 - %15 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%14, %c1_i64] : , > - %16 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%14, %c1_i64] : , > + %15 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%14, %c1_i64] : , <128x64xf16, #shared> + %16 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%14, %c1_i64] : , <256x64xf16, #shared> %17 = arith.extsi %arg4 : i32 to i64 - %18 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%17, %c1_i64] : , > + %18 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%17, %c1_i64] : , <128x256xf16, #shared> %19 = arith.muli %11, %c128_i32 : i32 %20 = arith.muli %13, %c256_i32 : i32 %21 = arith.addi %arg5, %c63_i32 : i32 %22 = arith.divsi %21, %c64_i32 : i32 %23:2 = scf.for %arg6 = %c0_i32 to %22 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %c0_i32) -> (tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32) : i32 { - %26 = tt.descriptor_load %15[%19, %arg8] : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %26 = tt.descriptor_load %15[%19, %arg8] : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> %27 = ttg.local_alloc %26 : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> - %28 = tt.descriptor_load %16[%20, %arg8] : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %28 = tt.descriptor_load %16[%20, %arg8] : !tt.tensordesc<256x64xf16, #shared> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> %29 = ttg.local_alloc %28 : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> %30 = ttg.memdesc_trans %29 {order = array} : !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> %31 = ttng.warp_group_dot %27, %30, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> @@ -168,7 +168,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ } %24 = arith.truncf %23#0 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> %25 = ttg.convert_layout %24 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> - tt.descriptor_store %18[%19, %20], %25 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.descriptor_store %18[%19, %20], %25 : !tt.tensordesc<128x256xf16, #shared>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> tt.return } } diff --git a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in index bec32eadcd76..b9ec39f1517a 100644 --- a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in +++ b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in @@ -31,18 +31,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %12 = arith.remsi %0, %5 : i32 %13 = arith.divsi %12, %9 : i32 %14 = arith.extsi %arg5 : i32 to i64 - %15 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%14, %c1_i64] : , > - %16 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%14, %c1_i64] : , > + %15 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%14, %c1_i64] : , <128x64xf16, #shared> + %16 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%14, %c1_i64] : , <256x64xf16, #shared> %17 = arith.extsi %arg4 : i32 to i64 - %18 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%17, %c1_i64] : , > + %18 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%17, %c1_i64] : , <128x256xf16, #shared> %19 = arith.muli %11, %c128_i32 : i32 %20 = arith.muli %13, %c256_i32 : i32 %21 = arith.addi %arg5, %c63_i32 : i32 %22 = arith.divsi %21, %c64_i32 : i32 %23:2 = scf.for %arg6 = %c0_i32 to %22 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %c0_i32) -> (tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32) : i32 { - %26 = tt.descriptor_load %15[%19, %arg8] : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %26 = tt.descriptor_load %15[%19, %arg8] : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> %27 = ttg.local_alloc %26 : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> - %28 = tt.descriptor_load %16[%20, %arg8] : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %28 = tt.descriptor_load %16[%20, %arg8] : !tt.tensordesc<256x64xf16, #shared> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> %29 = ttg.local_alloc %28 : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> %30 = ttg.memdesc_trans %29 {order = array} : !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> %31 = ttng.warp_group_dot %27, %30, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> @@ -51,7 +51,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ } %24 = arith.truncf %23#0 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> %25 = ttg.convert_layout %24 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> - tt.descriptor_store %18[%19, %20], %25 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.descriptor_store %18[%19, %20], %25 : !tt.tensordesc<128x256xf16, #shared>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> tt.return } } diff --git a/test/TritonGPU/samples/simulated-grouped-gemm.mlir b/test/TritonGPU/samples/simulated-grouped-gemm.mlir index e7481bbc1add..4d1502b10464 100644 --- a/test/TritonGPU/samples/simulated-grouped-gemm.mlir +++ b/test/TritonGPU/samples/simulated-grouped-gemm.mlir @@ -41,10 +41,10 @@ // CHECK: %[[VAL_28:.*]] = arith.divsi %[[VAL_27]], %[[VAL_16]] : i32 // CHECK: %[[VAL_29:.*]] = arith.muli %[[VAL_24]], %[[VAL_26]] : i32 // CHECK: %[[VAL_30:.*]] = arith.extsi %[[VAL_5]] : i32 to i64 -// CHECK: %[[VAL_31:.*]] = tt.make_tensor_descriptor %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_5]]], {{\[}}%[[VAL_30]], %[[VAL_17]]] : , > -// CHECK: %[[VAL_32:.*]] = tt.make_tensor_descriptor %[[VAL_1]], {{\[}}%[[VAL_4]], %[[VAL_5]]], {{\[}}%[[VAL_30]], %[[VAL_17]]] : , > +// CHECK: %[[VAL_31:.*]] = tt.make_tensor_descriptor %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_5]]], {{\[}}%[[VAL_30]], %[[VAL_17]]] : , <128x64xf16, #[[$ATTR_2]]> +// CHECK: %[[VAL_32:.*]] = tt.make_tensor_descriptor %[[VAL_1]], {{\[}}%[[VAL_4]], %[[VAL_5]]], {{\[}}%[[VAL_30]], %[[VAL_17]]] : , <256x64xf16, #[[$ATTR_2]]> // CHECK: %[[VAL_33:.*]] = arith.extsi %[[VAL_4]] : i32 to i64 -// CHECK: %[[VAL_34:.*]] = tt.make_tensor_descriptor %[[VAL_2]], {{\[}}%[[VAL_3]], %[[VAL_4]]], {{\[}}%[[VAL_33]], %[[VAL_17]]] : , > +// CHECK: %[[VAL_34:.*]] = tt.make_tensor_descriptor %[[VAL_2]], {{\[}}%[[VAL_3]], %[[VAL_4]]], {{\[}}%[[VAL_33]], %[[VAL_17]]] : , <128x256xf16, #[[$ATTR_2]]> // CHECK: %[[VAL_35:.*]] = arith.divsi %[[VAL_29]], %[[VAL_10]] : i32 // CHECK: %[[VAL_36:.*]] = arith.remsi %[[VAL_29]], %[[VAL_10]] : i32 // CHECK: %[[VAL_37:.*]] = arith.cmpi slt, %[[VAL_22]], %[[VAL_36]] : i32 @@ -63,23 +63,23 @@ // CHECK: %[[VAL_46:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr // CHECK: %[[VAL_47:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr // CHECK: %[[VAL_48:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr -// CHECK: %[[VAL_49:.*]]:13 = scf.for %[[VAL_50:.*]] = %[[VAL_12]] to %[[VAL_43]] step %[[VAL_9]] iter_args(%[[VAL_51:.*]] = %[[VAL_11]], %[[VAL_52:.*]] = %[[VAL_31]], %[[VAL_53:.*]] = %[[VAL_32]], %[[VAL_54:.*]] = %[[VAL_34]], %[[VAL_55:.*]] = %[[VAL_40]], %[[VAL_56:.*]] = %[[VAL_11]], %[[VAL_57:.*]] = %[[VAL_12]], %[[VAL_58:.*]] = %[[VAL_12]], %[[VAL_59:.*]] = %[[VAL_21]], %[[VAL_60:.*]] = %[[VAL_8]], %[[VAL_61:.*]] = %[[VAL_12]], %[[VAL_62:.*]] = %[[VAL_12]], %[[VAL_63:.*]] = %[[VAL_12]]) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_1]]>, i1, i32, i32, i32) : i32 { +// CHECK: %[[VAL_49:.*]]:13 = scf.for %[[VAL_50:.*]] = %[[VAL_12]] to %[[VAL_43]] step %[[VAL_9]] iter_args(%[[VAL_51:.*]] = %[[VAL_11]], %[[VAL_52:.*]] = %[[VAL_31]], %[[VAL_53:.*]] = %[[VAL_32]], %[[VAL_54:.*]] = %[[VAL_34]], %[[VAL_55:.*]] = %[[VAL_40]], %[[VAL_56:.*]] = %[[VAL_11]], %[[VAL_57:.*]] = %[[VAL_12]], %[[VAL_58:.*]] = %[[VAL_12]], %[[VAL_59:.*]] = %[[VAL_21]], %[[VAL_60:.*]] = %[[VAL_8]], %[[VAL_61:.*]] = %[[VAL_12]], %[[VAL_62:.*]] = %[[VAL_12]], %[[VAL_63:.*]] = %[[VAL_12]]) -> (i32, !tt.tensordesc<128x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<256x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<128x256xf16, #[[$ATTR_2]]>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_1]]>, i1, i32, i32, i32) : i32 { // CHECK: %[[VAL_64:.*]] = arith.cmpi eq, %[[VAL_51]], %[[VAL_44]] : i32 // CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_51]], %[[VAL_9]] : i32 // CHECK: %[[VAL_66:.*]] = arith.select %[[VAL_64]], %[[VAL_12]], %[[VAL_65]] : i32 // CHECK: %[[VAL_67:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_12]] : i32 -// CHECK: %[[VAL_68:.*]]:10 = scf.if %[[VAL_67]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32) { +// CHECK: %[[VAL_68:.*]]:10 = scf.if %[[VAL_67]] -> (!tt.tensordesc<128x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<256x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<128x256xf16, #[[$ATTR_2]]>, i32, i32, i32, i32, i32, i32, i32) { // CHECK: %[[VAL_69:.*]] = arith.addi %[[VAL_56]], %[[VAL_9]] : i32 // CHECK: %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_69]], %[[VAL_9]] : i32 // CHECK: %[[VAL_71:.*]] = arith.select %[[VAL_70]], %[[VAL_12]], %[[VAL_69]] : i32 -// CHECK: %[[VAL_72:.*]]:6 = scf.if %[[VAL_70]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32) { +// CHECK: %[[VAL_72:.*]]:6 = scf.if %[[VAL_70]] -> (!tt.tensordesc<128x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<256x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<128x256xf16, #[[$ATTR_2]]>, i32, i32, i32) { // CHECK: %[[VAL_73:.*]] = tt.addptr %[[VAL_0]], %[[VAL_42]] : !tt.ptr, i32 // CHECK: %[[VAL_74:.*]] = arith.muli %[[VAL_61]], %[[VAL_14]] : i32 // CHECK: %[[VAL_75:.*]] = tt.addptr %[[VAL_46]], %[[VAL_74]] : !tt.ptr, i32 // CHECK: %[[VAL_76:.*]] = arith.muli %[[VAL_30]], %[[VAL_6]] : i64 // CHECK: ttng.tensormap_create %[[VAL_75]], %[[VAL_73]], {{\[}}%[[VAL_16]], %[[VAL_14]]], {{\[}}%[[VAL_5]], %[[VAL_3]]], {{\[}}%[[VAL_76]]], {{\[}}%[[VAL_9]], %[[VAL_9]]] {elem_type = 6 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () // CHECK: ttng.tensormap_fenceproxy_acquire %[[VAL_75]] : !tt.ptr -// CHECK: %[[VAL_77:.*]] = ttng.reinterpret_tensor_descriptor %[[VAL_75]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_77:.*]] = ttng.reinterpret_tensor_descriptor %[[VAL_75]] : !tt.ptr to !tt.tensordesc<128x64xf16, #[[$ATTR_2]]> // CHECK: %[[VAL_78:.*]] = arith.addi %[[VAL_61]], %[[VAL_9]] : i32 // CHECK: %[[VAL_79:.*]] = arith.cmpi sge, %[[VAL_78]], %[[VAL_7]] : i32 // CHECK: %[[VAL_80:.*]] = arith.select %[[VAL_79]], %[[VAL_12]], %[[VAL_78]] : i32 @@ -89,7 +89,7 @@ // CHECK: %[[VAL_84:.*]] = arith.muli %[[VAL_30]], %[[VAL_6]] : i64 // CHECK: ttng.tensormap_create %[[VAL_83]], %[[VAL_81]], {{\[}}%[[VAL_16]], %[[VAL_15]]], {{\[}}%[[VAL_5]], %[[VAL_4]]], {{\[}}%[[VAL_84]]], {{\[}}%[[VAL_9]], %[[VAL_9]]] {elem_type = 6 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () // CHECK: ttng.tensormap_fenceproxy_acquire %[[VAL_83]] : !tt.ptr -// CHECK: %[[VAL_85:.*]] = ttng.reinterpret_tensor_descriptor %[[VAL_83]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_85:.*]] = ttng.reinterpret_tensor_descriptor %[[VAL_83]] : !tt.ptr to !tt.tensordesc<256x64xf16, #[[$ATTR_2]]> // CHECK: %[[VAL_86:.*]] = arith.addi %[[VAL_62]], %[[VAL_9]] : i32 // CHECK: %[[VAL_87:.*]] = arith.cmpi sge, %[[VAL_86]], %[[VAL_7]] : i32 // CHECK: %[[VAL_88:.*]] = arith.select %[[VAL_87]], %[[VAL_12]], %[[VAL_86]] : i32 @@ -99,13 +99,13 @@ // CHECK: %[[VAL_92:.*]] = arith.muli %[[VAL_33]], %[[VAL_6]] : i64 // CHECK: ttng.tensormap_create %[[VAL_91]], %[[VAL_89]], {{\[}}%[[VAL_16]], %[[VAL_14]]], {{\[}}%[[VAL_4]], %[[VAL_3]]], {{\[}}%[[VAL_92]]], {{\[}}%[[VAL_9]], %[[VAL_9]]] {elem_type = 6 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () // CHECK: ttng.tensormap_fenceproxy_acquire %[[VAL_91]] : !tt.ptr -// CHECK: %[[VAL_93:.*]] = ttng.reinterpret_tensor_descriptor %[[VAL_91]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_93:.*]] = ttng.reinterpret_tensor_descriptor %[[VAL_91]] : !tt.ptr to !tt.tensordesc<128x256xf16, #[[$ATTR_2]]> // CHECK: %[[VAL_94:.*]] = arith.addi %[[VAL_63]], %[[VAL_9]] : i32 // CHECK: %[[VAL_95:.*]] = arith.cmpi sge, %[[VAL_94]], %[[VAL_7]] : i32 // CHECK: %[[VAL_96:.*]] = arith.select %[[VAL_95]], %[[VAL_12]], %[[VAL_94]] : i32 -// CHECK: scf.yield %[[VAL_77]], %[[VAL_85]], %[[VAL_93]], %[[VAL_80]], %[[VAL_88]], %[[VAL_96]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 +// CHECK: scf.yield %[[VAL_77]], %[[VAL_85]], %[[VAL_93]], %[[VAL_80]], %[[VAL_88]], %[[VAL_96]] : !tt.tensordesc<128x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<256x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<128x256xf16, #[[$ATTR_2]]>, i32, i32, i32 // CHECK: } else { -// CHECK: scf.yield %[[VAL_52]], %[[VAL_53]], %[[VAL_54]], %[[VAL_61]], %[[VAL_62]], %[[VAL_63]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 +// CHECK: scf.yield %[[VAL_52]], %[[VAL_53]], %[[VAL_54]], %[[VAL_61]], %[[VAL_62]], %[[VAL_63]] : !tt.tensordesc<128x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<256x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<128x256xf16, #[[$ATTR_2]]>, i32, i32, i32 // CHECK: } // CHECK: %[[VAL_97:.*]] = arith.addi %[[VAL_55]], %[[VAL_10]] : i32 // CHECK: %[[VAL_98:.*]] = arith.divsi %[[VAL_97]], %[[VAL_41]] : i32 @@ -118,14 +118,14 @@ // CHECK: %[[VAL_105:.*]] = arith.divsi %[[VAL_104]], %[[VAL_101]] : i32 // CHECK: %[[VAL_106:.*]] = arith.muli %[[VAL_103]], %[[VAL_14]] : i32 // CHECK: %[[VAL_107:.*]] = arith.muli %[[VAL_105]], %[[VAL_15]] : i32 -// CHECK: scf.yield %[[VAL_108:.*]]#0, %[[VAL_108]]#1, %[[VAL_108]]#2, %[[VAL_97]], %[[VAL_71]], %[[VAL_106]], %[[VAL_107]], %[[VAL_108]]#3, %[[VAL_108]]#4, %[[VAL_108]]#5 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: scf.yield %[[VAL_108:.*]]#0, %[[VAL_108]]#1, %[[VAL_108]]#2, %[[VAL_97]], %[[VAL_71]], %[[VAL_106]], %[[VAL_107]], %[[VAL_108]]#3, %[[VAL_108]]#4, %[[VAL_108]]#5 : !tt.tensordesc<128x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<256x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<128x256xf16, #[[$ATTR_2]]>, i32, i32, i32, i32, i32, i32, i32 // CHECK: } else { -// CHECK: scf.yield %[[VAL_52]], %[[VAL_53]], %[[VAL_54]], %[[VAL_55]], %[[VAL_56]], %[[VAL_57]], %[[VAL_58]], %[[VAL_61]], %[[VAL_62]], %[[VAL_63]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: scf.yield %[[VAL_52]], %[[VAL_53]], %[[VAL_54]], %[[VAL_55]], %[[VAL_56]], %[[VAL_57]], %[[VAL_58]], %[[VAL_61]], %[[VAL_62]], %[[VAL_63]] : !tt.tensordesc<128x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<256x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<128x256xf16, #[[$ATTR_2]]>, i32, i32, i32, i32, i32, i32, i32 // CHECK: } // CHECK: %[[VAL_109:.*]] = arith.muli %[[VAL_66]], %[[VAL_16]] : i32 -// CHECK: %[[VAL_110:.*]] = tt.descriptor_load %[[VAL_111:.*]]#0{{\[}}%[[VAL_111]]#5, %[[VAL_109]]] : !tt.tensordesc> -> tensor<128x64xf16, #[[$ATTR_0]]> +// CHECK: %[[VAL_110:.*]] = tt.descriptor_load %[[VAL_111:.*]]#0{{\[}}%[[VAL_111]]#5, %[[VAL_109]]] : !tt.tensordesc<128x64xf16, #[[$ATTR_2]]> -> tensor<128x64xf16, #[[$ATTR_0]]> // CHECK: %[[VAL_112:.*]] = ttg.local_alloc %[[VAL_110]] : (tensor<128x64xf16, #[[$ATTR_0]]>) -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_4]]> -// CHECK: %[[VAL_113:.*]] = tt.descriptor_load %[[VAL_111]]#1{{\[}}%[[VAL_111]]#6, %[[VAL_109]]] : !tt.tensordesc> -> tensor<256x64xf16, #[[$ATTR_0]]> +// CHECK: %[[VAL_113:.*]] = tt.descriptor_load %[[VAL_111]]#1{{\[}}%[[VAL_111]]#6, %[[VAL_109]]] : !tt.tensordesc<256x64xf16, #[[$ATTR_2]]> -> tensor<256x64xf16, #[[$ATTR_0]]> // CHECK: %[[VAL_114:.*]] = ttg.local_alloc %[[VAL_113]] : (tensor<256x64xf16, #[[$ATTR_0]]>) -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_4]]> // CHECK: %[[VAL_115:.*]] = ttg.memdesc_trans %[[VAL_114]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_4]]> -> !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]]> // CHECK: %[[VAL_116:.*]] = ttng.warp_group_dot %[[VAL_112]], %[[VAL_115]], %[[VAL_59]], %[[VAL_60]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_4]]> * !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]]> -> tensor<128x256xf32, #[[$ATTR_1]]> @@ -137,9 +137,9 @@ // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} // CHECK: ttg.local_store %[[VAL_120]], %[[VAL_45]] : tensor<128x256xf16, #[[$ATTR_1]]> -> !ttg.memdesc<128x256xf16, #[[$ATTR_2]], #[[$ATTR_4]], mutable> // CHECK: ttng.fence_async_shared {bCluster = false} -// CHECK: ttng.async_tma_copy_local_to_global %[[VAL_111]]#2{{\[}}%[[VAL_111]]#5, %[[VAL_111]]#6] %[[VAL_45]] : !tt.tensordesc>, !ttg.memdesc<128x256xf16, #[[$ATTR_2]], #[[$ATTR_4]], mutable> +// CHECK: ttng.async_tma_copy_local_to_global %[[VAL_111]]#2{{\[}}%[[VAL_111]]#5, %[[VAL_111]]#6] %[[VAL_45]] : !tt.tensordesc<128x256xf16, #[[$ATTR_2]]>, !ttg.memdesc<128x256xf16, #[[$ATTR_2]], #[[$ATTR_4]], mutable> // CHECK: } -// CHECK: scf.yield %[[VAL_66]], %[[VAL_111]]#0, %[[VAL_111]]#1, %[[VAL_111]]#2, %[[VAL_111]]#3, %[[VAL_111]]#4, %[[VAL_111]]#5, %[[VAL_111]]#6, %[[VAL_117]]#0, %[[VAL_119]], %[[VAL_111]]#7, %[[VAL_111]]#8, %[[VAL_111]]#9 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_1]]>, i1, i32, i32, i32 +// CHECK: scf.yield %[[VAL_66]], %[[VAL_111]]#0, %[[VAL_111]]#1, %[[VAL_111]]#2, %[[VAL_111]]#3, %[[VAL_111]]#4, %[[VAL_111]]#5, %[[VAL_111]]#6, %[[VAL_117]]#0, %[[VAL_119]], %[[VAL_111]]#7, %[[VAL_111]]#8, %[[VAL_111]]#9 : i32, !tt.tensordesc<128x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<256x64xf16, #[[$ATTR_2]]>, !tt.tensordesc<128x256xf16, #[[$ATTR_2]]>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_1]]>, i1, i32, i32, i32 // CHECK: } // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} // CHECK: ttg.local_dealloc %[[VAL_45]] : !ttg.memdesc<128x256xf16, #[[$ATTR_2]], #[[$ATTR_4]], mutable> @@ -169,10 +169,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %6 = arith.divsi %5, %c64_i32 : i32 %7 = arith.muli %2, %4 : i32 %8 = arith.extsi %arg5 : i32 to i64 - %9 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%8, %c1_i64] : , > - %10 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%8, %c1_i64] : , > + %9 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%8, %c1_i64] : , <128x64xf16, #nvmma_128> + %10 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%8, %c1_i64] : , <256x64xf16, #nvmma_128> %11 = arith.extsi %arg4 : i32 to i64 - %12 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%11, %c1_i64] : , > + %12 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%11, %c1_i64] : , <128x256xf16, #nvmma_128> %13 = arith.divsi %7, %c132_i32 : i32 %14 = arith.remsi %7, %c132_i32 : i32 %15 = arith.cmpi slt, %0, %14 : i32 @@ -189,24 +189,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %21 = arith.subi %6, %c1_i32 : i32 %true = arith.constant true %false = arith.constant false - %22:10 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %c-1_i32, %arg8 = %9, %arg9 = %10, %arg10 = %12, %arg11 = %17, %arg12 = %c-1_i32, %arg13 = %c0_i32, %arg14 = %c0_i32, %arg15 = %cst, %arg16 = %false) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1) : i32 { + %22:10 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %c-1_i32, %arg8 = %9, %arg9 = %10, %arg10 = %12, %arg11 = %17, %arg12 = %c-1_i32, %arg13 = %c0_i32, %arg14 = %c0_i32, %arg15 = %cst, %arg16 = %false) -> (i32, !tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1) : i32 { %23 = arith.cmpi eq, %arg7, %21 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 %24 = arith.addi %arg7, %c1_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 %25 = arith.select %23, %c0_i32, %24 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 %26 = arith.cmpi eq, %25, %c0_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 - %27:7 = scf.if %26 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) { + %27:7 = scf.if %26 -> (!tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32, i32, i32, i32) { %37 = arith.addi %arg12, %c1_i32 : i32 %38 = arith.cmpi eq, %37, %c1_i32 : i32 - %39:4 = scf.if %38 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32) { + %39:4 = scf.if %38 -> (!tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32) { %51 = tt.addptr %arg0, %19 : !tt.ptr, i32 - %52 = tt.make_tensor_descriptor %51, [%arg3, %arg5], [%8, %c1_i64] : , > + %52 = tt.make_tensor_descriptor %51, [%arg3, %arg5], [%8, %c1_i64] : , <128x64xf16, #nvmma_128> %53 = tt.addptr %arg1, %19 : !tt.ptr, i32 - %54 = tt.make_tensor_descriptor %53, [%arg4, %arg5], [%8, %c1_i64] : , > + %54 = tt.make_tensor_descriptor %53, [%arg4, %arg5], [%8, %c1_i64] : , <256x64xf16, #nvmma_128> %55 = tt.addptr %arg2, %19 : !tt.ptr, i32 - %56 = tt.make_tensor_descriptor %55, [%arg3, %arg4], [%11, %c1_i64] : , > - scf.yield %52, %54, %56, %c0_i32 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + %56 = tt.make_tensor_descriptor %55, [%arg3, %arg4], [%11, %c1_i64] : , <128x256xf16, #nvmma_128> + scf.yield %52, %54, %56, %c0_i32 : !tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32 } else { - scf.yield %arg8, %arg9, %arg10, %37 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + scf.yield %arg8, %arg9, %arg10, %37 : !tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32 } %40 = arith.addi %arg11, %c132_i32 : i32 %41 = arith.divsi %40, %18 : i32 @@ -219,14 +219,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %48 = arith.divsi %47, %44 : i32 %49 = arith.muli %46, %c128_i32 : i32 %50 = arith.muli %48, %c256_i32 : i32 - scf.yield %39#0, %39#1, %39#2, %40, %39#3, %49, %50 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + scf.yield %39#0, %39#1, %39#2, %40, %39#3, %49, %50 : !tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32, i32, i32, i32 } else { - scf.yield %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + scf.yield %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14 : !tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32, i32, i32, i32 } {loop.cluster = 0 : i32, loop.stage = 0 : i32} %28 = arith.muli %25, %c64_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 - %29 = tt.descriptor_load %27#0[%27#5, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %29 = tt.descriptor_load %27#0[%27#5, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<128x64xf16, #nvmma_128> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> %30 = ttg.local_alloc %29 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> - %31 = tt.descriptor_load %27#1[%27#6, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %31 = tt.descriptor_load %27#1[%27#6, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<256x64xf16, #nvmma_128> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> %32 = ttg.local_alloc %31 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> %33 = ttg.memdesc_trans %32 {loop.cluster = 1 : i32, loop.stage = 2 : i32, order = array} : !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> %34 = ttng.warp_group_dot %30, %33, %arg15, %arg16 {inputPrecision = 0 : i32, loop.cluster = 1 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> @@ -234,12 +234,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %36 = scf.if %35 -> (i1) { %37 = arith.truncf %34 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> %38 = ttg.convert_layout %37 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> - tt.descriptor_store %27#2[%27#5, %27#6], %38 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.descriptor_store %27#2[%27#5, %27#6], %38 : !tt.tensordesc<128x256xf16, #nvmma_128>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> scf.yield %false : i1 } else { scf.yield %true : i1 } {loop.cluster = 3 : i32, loop.stage = 2 : i32} - scf.yield %25, %27#0, %27#1, %27#2, %27#3, %27#4, %27#5, %27#6, %34, %36 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1 + scf.yield %25, %27#0, %27#1, %27#2, %27#3, %27#4, %27#5, %27#6, %34, %36 : i32, !tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1 } tt.return } diff --git a/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in b/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in index 9392c5782ed8..b27399c6f2c9 100644 --- a/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in +++ b/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in @@ -26,10 +26,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %6 = arith.divsi %5, %c64_i32 : i32 %7 = arith.muli %2, %4 : i32 %8 = arith.extsi %arg5 : i32 to i64 - %9 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%8, %c1_i64] : , > - %10 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%8, %c1_i64] : , > + %9 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%8, %c1_i64] : , <128x64xf16, #nvmma_128> + %10 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%8, %c1_i64] : , <256x64xf16, #nvmma_128> %11 = arith.extsi %arg4 : i32 to i64 - %12 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%11, %c1_i64] : , > + %12 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%11, %c1_i64] : , <128x256xf16, #nvmma_128> %13 = arith.divsi %7, %c132_i32 : i32 %14 = arith.remsi %7, %c132_i32 : i32 %15 = arith.cmpi slt, %0, %14 : i32 @@ -46,24 +46,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %21 = arith.subi %6, %c1_i32 : i32 %true = arith.constant true %false = arith.constant false - %22:10 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %c-1_i32, %arg8 = %9, %arg9 = %10, %arg10 = %12, %arg11 = %17, %arg12 = %c-1_i32, %arg13 = %c0_i32, %arg14 = %c0_i32, %arg15 = %cst, %arg16 = %false) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1) : i32 { + %22:10 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %c-1_i32, %arg8 = %9, %arg9 = %10, %arg10 = %12, %arg11 = %17, %arg12 = %c-1_i32, %arg13 = %c0_i32, %arg14 = %c0_i32, %arg15 = %cst, %arg16 = %false) -> (i32, !tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1) : i32 { %23 = arith.cmpi eq, %arg7, %21 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 %24 = arith.addi %arg7, %c1_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 %25 = arith.select %23, %c0_i32, %24 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 %26 = arith.cmpi eq, %25, %c0_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 - %27:7 = scf.if %26 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) { + %27:7 = scf.if %26 -> (!tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32, i32, i32, i32) { %37 = arith.addi %arg12, %c1_i32 : i32 %38 = arith.cmpi eq, %37, %c1_i32 : i32 - %39:4 = scf.if %38 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32) { + %39:4 = scf.if %38 -> (!tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32) { %51 = tt.addptr %arg0, %19 : !tt.ptr, i32 - %52 = tt.make_tensor_descriptor %51, [%arg3, %arg5], [%8, %c1_i64] : , > + %52 = tt.make_tensor_descriptor %51, [%arg3, %arg5], [%8, %c1_i64] : , <128x64xf16, #nvmma_128> %53 = tt.addptr %arg1, %19 : !tt.ptr, i32 - %54 = tt.make_tensor_descriptor %53, [%arg4, %arg5], [%8, %c1_i64] : , > + %54 = tt.make_tensor_descriptor %53, [%arg4, %arg5], [%8, %c1_i64] : , <256x64xf16, #nvmma_128> %55 = tt.addptr %arg2, %19 : !tt.ptr, i32 - %56 = tt.make_tensor_descriptor %55, [%arg3, %arg4], [%11, %c1_i64] : , > - scf.yield %52, %54, %56, %c0_i32 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + %56 = tt.make_tensor_descriptor %55, [%arg3, %arg4], [%11, %c1_i64] : , <128x256xf16, #nvmma_128> + scf.yield %52, %54, %56, %c0_i32 : !tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32 } else { - scf.yield %arg8, %arg9, %arg10, %37 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + scf.yield %arg8, %arg9, %arg10, %37 : !tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32 } %40 = arith.addi %arg11, %c132_i32 : i32 %41 = arith.divsi %40, %18 : i32 @@ -76,14 +76,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %48 = arith.divsi %47, %44 : i32 %49 = arith.muli %46, %c128_i32 : i32 %50 = arith.muli %48, %c256_i32 : i32 - scf.yield %39#0, %39#1, %39#2, %40, %39#3, %49, %50 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + scf.yield %39#0, %39#1, %39#2, %40, %39#3, %49, %50 : !tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32, i32, i32, i32 } else { - scf.yield %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + scf.yield %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14 : !tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32, i32, i32, i32 } {loop.cluster = 0 : i32, loop.stage = 0 : i32} %28 = arith.muli %25, %c64_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 - %29 = tt.descriptor_load %27#0[%27#5, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %29 = tt.descriptor_load %27#0[%27#5, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<128x64xf16, #nvmma_128> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> %30 = ttg.local_alloc %29 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> - %31 = tt.descriptor_load %27#1[%27#6, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %31 = tt.descriptor_load %27#1[%27#6, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<256x64xf16, #nvmma_128> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> %32 = ttg.local_alloc %31 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> %33 = ttg.memdesc_trans %32 {loop.cluster = 1 : i32, loop.stage = 2 : i32, order = array} : !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> %34 = ttng.warp_group_dot %30, %33, %arg15, %arg16 {inputPrecision = 0 : i32, loop.cluster = 1 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> @@ -91,12 +91,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %36 = scf.if %35 -> (i1) { %37 = arith.truncf %34 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> %38 = ttg.convert_layout %37 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> - tt.descriptor_store %27#2[%27#5, %27#6], %38 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.descriptor_store %27#2[%27#5, %27#6], %38 : !tt.tensordesc<128x256xf16, #nvmma_128>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> scf.yield %false : i1 } else { scf.yield %true : i1 } {loop.cluster = 3 : i32, loop.stage = 2 : i32} - scf.yield %25, %27#0, %27#1, %27#2, %27#3, %27#4, %27#5, %27#6, %34, %36 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1 + scf.yield %25, %27#0, %27#1, %27#2, %27#3, %27#4, %27#5, %27#6, %34, %36 : i32, !tt.tensordesc<128x64xf16, #nvmma_128>, !tt.tensordesc<256x64xf16, #nvmma_128>, !tt.tensordesc<128x256xf16, #nvmma_128>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1 } tt.return } diff --git a/test/TritonNvidiaGPU/invalid.mlir b/test/TritonNvidiaGPU/invalid.mlir index 71ee17b5c6ed..94d96fb9b22b 100644 --- a/test/TritonNvidiaGPU/invalid.mlir +++ b/test/TritonNvidiaGPU/invalid.mlir @@ -72,12 +72,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { -tt.func @async_tma_gather(%desc: !tt.tensordesc>, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, +tt.func @async_tma_gather(%desc: !tt.tensordesc<1x128xbf16, #shared>, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, %bar: !ttg.memdesc<2xi32, #shared1, #ttg.shared_memory, mutable>, %result: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, %pred: i1) { // expected-error @below {{barrier allocation must be a descriptor of Nxi64 type with N <= number of CTAs}} - ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.tensordesc>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<2xi32, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, i1 + ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<2xi32, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, i1 tt.return } } @@ -90,12 +90,12 @@ tt.func @async_tma_gather(%desc: !tt.tensordesc>, %x #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32} { -tt.func @async_tma_gather(%desc: !tt.tensordesc>, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, +tt.func @async_tma_gather(%desc: !tt.tensordesc<1x128xbf16, #shared>, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, %bar: !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>, %result: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory>, %pred: i1) { // expected-error @below {{cannot store into immutable memory}} - ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.tensordesc>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory>, i1 + ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory>, i1 tt.return } } @@ -121,13 +121,13 @@ tt.func @wgmma(%a: tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kW #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @async_tma_copy_global_to_local(%arg0: !tt.tensordesc>) -> tensor<256x32xf32, #blocked> { + tt.func public @async_tma_copy_global_to_local(%arg0: !tt.tensordesc<1x256x32xf32, #shared>) -> tensor<256x32xf32, #blocked> { %true = arith.constant true %c32_i32 = arith.constant 32 : i32 %0 = ttg.local_alloc : () -> !ttg.memdesc<256x32xf32, #shared, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> // expected-error @below {{TMA descriptor must have NVMMA shared layout}} - ttng.async_tma_copy_global_to_local %arg0[%c32_i32, %c32_i32, %c32_i32] %0, %1, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<256x32xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c32_i32, %c32_i32, %c32_i32] %0, %1, %true : !tt.tensordesc<1x256x32xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<256x32xf32, #shared, #smem, mutable> } } @@ -138,13 +138,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @async_tma_copy_global_to_local(%arg0: !tt.tensordesc>) -> tensor<256x32xf32, #blocked> { + tt.func public @async_tma_copy_global_to_local(%arg0: !tt.tensordesc<1x256x32xf32, #shared>) -> tensor<256x32xf32, #blocked> { %true = arith.constant true %c32_i32 = arith.constant 32 : i32 %0 = ttg.local_alloc : () -> !ttg.memdesc<256x32xf32, #shared, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable> // expected-error @below {{TMA descriptor layout must not be transposed}} - ttng.async_tma_copy_global_to_local %arg0[%c32_i32, %c32_i32, %c32_i32] %0, %1, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<256x32xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c32_i32, %c32_i32, %c32_i32] %0, %1, %true : !tt.tensordesc<1x256x32xf32, #shared>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<256x32xf32, #shared, #smem, mutable> } } @@ -156,13 +156,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #shared_mbar = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @async_tma_copy_global_to_local(%arg0: !tt.tensordesc>) { + tt.func public @async_tma_copy_global_to_local(%arg0: !tt.tensordesc<1x256x64xf32, #nvmma32>) { %true = arith.constant true %c32_i32 = arith.constant 32 : i32 %0 = ttg.local_alloc : () -> !ttg.memdesc<256x64xf32, #nvmma64, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_mbar, #smem, mutable> // expected-error @below {{TMA descriptor layout must match shared layout}} - ttng.async_tma_copy_global_to_local %arg0[%c32_i32, %c32_i32, %c32_i32] %0, %1, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared_mbar, #smem, mutable> -> !ttg.memdesc<256x64xf32, #nvmma64, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c32_i32, %c32_i32, %c32_i32] %0, %1, %true : !tt.tensordesc<1x256x64xf32, #nvmma32>, !ttg.memdesc<1xi64, #shared_mbar, #smem, mutable> -> !ttg.memdesc<256x64xf32, #nvmma64, #smem, mutable> tt.return } } @@ -173,13 +173,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @tma_im2col_missing_offsets(%arg0: !ttng.tensordesc_im2col>) { + tt.func public @tma_im2col_missing_offsets(%arg0: !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable> // expected-error @below {{IM2COL mode requires offsets to be provided}} - ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] %0, %1, %true : !ttng.tensordesc_im2col>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] %0, %1, %true : !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> tt.return } } @@ -191,14 +191,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @tma_im2col_wrong_offset_count(%arg0: !ttng.tensordesc_im2col>) { + tt.func public @tma_im2col_wrong_offset_count(%arg0: !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %c1_i16 = arith.constant 1 : i16 %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable> // expected-error @below {{IM2COL mode with 4D coordinates requires 2 offsets, but got 1}} - ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] offsets = [%c1_i16] %0, %1, %true : !ttng.tensordesc_im2col>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] offsets = [%c1_i16] %0, %1, %true : !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> tt.return } } @@ -210,14 +210,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @tma_tiled_with_offsets(%arg0: !tt.tensordesc>) { + tt.func public @tma_tiled_with_offsets(%arg0: !tt.tensordesc<64x128xf16, #nvmma_128>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %c1_i16 = arith.constant 1 : i16 %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable> // expected-error @below {{TILED mode does not support offsets}} - ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] offsets = [%c1_i16] %0, %1, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] offsets = [%c1_i16] %0, %1, %true : !tt.tensordesc<64x128xf16, #nvmma_128>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> tt.return } } @@ -229,13 +229,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @tma_im2col_2d_invalid(%arg0: !ttng.tensordesc_im2col>) { + tt.func public @tma_im2col_2d_invalid(%arg0: !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable> // expected-error @below {{IM2COL mode requires at least 3D coordinates, but got 2D}} - ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %1, %true : !ttng.tensordesc_im2col>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %1, %true : !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> tt.return } } @@ -605,13 +605,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar #shared_bar = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @async_tma_copy_multicast_requires_broadcast(%arg0: !tt.tensordesc>) { + tt.func public @async_tma_copy_multicast_requires_broadcast(%arg0: !tt.tensordesc<64x128xf16, #nvmma_no_broadcast>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_no_broadcast, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_bar, #smem, mutable> // expected-error @below {{multicast requires the shared layout to broadcast across CTAs}} - ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %1, %true {multicast} : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared_bar, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_no_broadcast, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %1, %true {multicast} : !tt.tensordesc<64x128xf16, #nvmma_no_broadcast>, !ttg.memdesc<1xi64, #shared_bar, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_no_broadcast, #smem, mutable> tt.return } } @@ -620,8 +620,8 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // Test invalid TensorDescIm2ColType: rank-3 blockType (must be rank-2) module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { - // expected-error @below {{TensorDescIm2ColType requires rank-2 blockType, got rank 3}} - tt.func @tensordesc_im2col_wrong_rank(%desc: !ttng.tensordesc_im2col>) { + // expected-error @below {{TensorDescIm2ColType requires rank-2 shape, got rank 3}} + tt.func @tensordesc_im2col_wrong_rank(%desc: !ttng.tensordesc_im2col<32x64x128xf16>) { tt.return } } diff --git a/test/TritonNvidiaGPU/membar-cluster.mlir b/test/TritonNvidiaGPU/membar-cluster.mlir index 0a210c0a0d15..c49082d9481f 100644 --- a/test/TritonNvidiaGPU/membar-cluster.mlir +++ b/test/TritonNvidiaGPU/membar-cluster.mlir @@ -383,7 +383,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: ttng.cluster_barrier // CHECK-NEXT: ttng.fence_mbarrier_init_release_cluster // CHECK-NEXT: ttng.async_tma_copy_global_to_local - tt.func @convert_layout_trivial_then_tma_multicast_cluster_barrier(%input: tensor<64x128xf16, #blockedTmaSrc>, %desc: !tt.tensordesc>) -> tensor<64x128xf16, #blockedTmaDst> { + tt.func @convert_layout_trivial_then_tma_multicast_cluster_barrier(%input: tensor<64x128xf16, #blockedTmaSrc>, %desc: !tt.tensordesc<64x128xf16, #nvmmaTma>) -> tensor<64x128xf16, #blockedTmaDst> { %c0 = arith.constant 0 : i32 %true = arith.constant true %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #barrierEncTma, #smem, mutable> @@ -391,7 +391,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %cvt = ttg.convert_layout %input : tensor<64x128xf16, #blockedTmaSrc> -> tensor<64x128xf16, #blockedTmaDst> %dst = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmmaTma, #smem, mutable> ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %dst, %barrier, %true {multicast} : - !tt.tensordesc>, !ttg.memdesc<1xi64, #barrierEncTma, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmmaTma, #smem, mutable> + !tt.tensordesc<64x128xf16, #nvmmaTma>, !ttg.memdesc<1xi64, #barrierEncTma, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmmaTma, #smem, mutable> ttng.wait_barrier %barrier, %c0 deps %dst : !ttg.memdesc<1xi64, #barrierEncTma, #smem, mutable>, !ttg.memdesc<64x128xf16, #nvmmaTma, #smem, mutable> @@ -487,7 +487,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-NEXT: ttng.cluster_barrier {relaxed = true} // CHECK-NEXT: ttng.async_tma_copy_global_to_local // CHECK: tt.return - tt.func @cluster_tma_multicast_with_per_cta_barrier(%desc: !tt.tensordesc>) -> tensor<64x128xf16, #blocked> { + tt.func @cluster_tma_multicast_with_per_cta_barrier(%desc: !tt.tensordesc<64x128xf16, #nvmma>) -> tensor<64x128xf16, #blocked> { %c0 = arith.constant 0 : i32 %true = arith.constant true %cst = arith.constant dense<0.000000e+00> : tensor<64x128xf16, #blocked> @@ -495,7 +495,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %barrier = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable> ttng.init_barrier %barrier, 1 : !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable> ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %buf, %barrier, %true {multicast} : - !tt.tensordesc>, !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> + !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttng.wait_barrier %barrier, %c0 deps %buf : !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable>, !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> @@ -513,7 +513,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-NEXT: ttng.fence_mbarrier_init_release_cluster // CHECK-NEXT: ttng.cluster_barrier {relaxed = true} // CHECK: tt.return - tt.func @no_cluster_tma_without_multicast(%desc: !tt.tensordesc>) -> tensor<64x128xf16, #blocked> { + tt.func @no_cluster_tma_without_multicast(%desc: !tt.tensordesc<64x128xf16, #nvmma>) -> tensor<64x128xf16, #blocked> { %c0 = arith.constant 0 : i32 %true = arith.constant true %cst = arith.constant dense<0.000000e+00> : tensor<64x128xf16, #blocked> @@ -521,7 +521,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> ttng.init_barrier %barrier, 1 : !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %buf, %barrier, %true : - !tt.tensordesc>, !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> + !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttng.wait_barrier %barrier, %c0 deps %buf : !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable>, !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> @@ -555,7 +555,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: ttng.wait_barrier // CHECK: ttng.cluster_barrier // CHECK: tt.return - tt.func @no_cluster_when_same_allocation(%desc: !tt.tensordesc>) -> tensor<64x128xf16, #blocked> { + tt.func @no_cluster_when_same_allocation(%desc: !tt.tensordesc<64x128xf16, #nvmma>) -> tensor<64x128xf16, #blocked> { %c0 = arith.constant 0 : i32 %true = arith.constant true %cst = arith.constant dense<0.000000e+00> : tensor<64x128xf16, #blocked> @@ -563,7 +563,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> ttng.init_barrier %barrier, 1 : !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %buf, %barrier, %true {multicast} : - !tt.tensordesc>, !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> + !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttng.wait_barrier %barrier, %c0 deps %buf : !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable>, !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> @@ -609,7 +609,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttng.tw // CHECK-NOT: ttng.cluster_barrier // CHECK: ttng.tc_gen5_mma // CHECK: ttng.wait_barrier - tt.func @example_matmul(%a_desc: !tt.tensordesc>, %b_desc: !tt.tensordesc>) { + tt.func @example_matmul(%a_desc: !tt.tensordesc<256x16xf16, #sharedA>, %b_desc: !tt.tensordesc<16x64xf16, #sharedB>) { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 %true = arith.constant true @@ -630,9 +630,9 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttng.tw ttng.barrier_expect %bTMA, 5120, %true : !ttg.memdesc<1xi64, #barrierTMA, #smem, mutable> %offs = arith.muli %k_i32, %c16 : i32 ttng.async_tma_copy_global_to_local %a_desc[%c0, %offs] %smem_a, %bTMA, %true : - !tt.tensordesc>, !ttg.memdesc<1xi64, #barrierTMA, #smem, mutable> -> !ttg.memdesc<256x16xf16, #sharedA, #smem, mutable> + !tt.tensordesc<256x16xf16, #sharedA>, !ttg.memdesc<1xi64, #barrierTMA, #smem, mutable> -> !ttg.memdesc<256x16xf16, #sharedA, #smem, mutable> ttng.async_tma_copy_global_to_local %b_desc[%offs, %c0] %smem_b, %bTMA, %true : - !tt.tensordesc>, !ttg.memdesc<1xi64, #barrierTMA, #smem, mutable> -> !ttg.memdesc<16x64xf16, #sharedB, #smem, mutable> + !tt.tensordesc<16x64xf16, #sharedB>, !ttg.memdesc<1xi64, #barrierTMA, #smem, mutable> -> !ttg.memdesc<16x64xf16, #sharedB, #smem, mutable> ttng.wait_barrier %bTMA, %phase, %true deps %smem_a, %smem_b : !ttg.memdesc<1xi64, #barrierTMA, #smem, mutable>, !ttg.memdesc<256x16xf16, #sharedA, #smem, mutable>, @@ -703,7 +703,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: ttg.local_alloc // CHECK-NEXT: ttng.cluster_barrier // CHECK-NEXT: ttng.async_tma_copy_global_to_local - tt.func @cluster_barrier_between_lifetimes_same_offset(%desc: !tt.tensordesc>) -> tensor<64x128xf16, #blocked> { + tt.func @cluster_barrier_between_lifetimes_same_offset(%desc: !tt.tensordesc<64x128xf16, #nvmma>) -> tensor<64x128xf16, #blocked> { %c0 = arith.constant 0 : i32 %true = arith.constant true @@ -712,7 +712,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // a lifetime start %a = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %a, %barrier, %true {multicast} : - !tt.tensordesc>, !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> + !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttng.wait_barrier %barrier, %c0 deps %a : !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable>, !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> @@ -723,7 +723,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // b lifetime start %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %b, %barrier, %true {multicast} : - !tt.tensordesc>, !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> + !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttng.wait_barrier %barrier, %c0 deps %b : !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable>, !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> diff --git a/test/TritonNvidiaGPU/membar.mlir b/test/TritonNvidiaGPU/membar.mlir index 96a07161e756..571cb93d4839 100644 --- a/test/TritonNvidiaGPU/membar.mlir +++ b/test/TritonNvidiaGPU/membar.mlir @@ -86,7 +86,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @tma_load(%arg0: !tt.tensordesc>, %arg1: i32) -> tensor<128x64xf16, #blocked0> { + tt.func public @tma_load(%arg0: !tt.tensordesc<128x64xf16, #shared>, %arg1: i32) -> tensor<128x64xf16, #blocked0> { // CHECK-LABEL: tma_load // CHECK: local_dealloc // CHECK-NEXT: local_alloc @@ -96,7 +96,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared1, #smem, mutable> ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared1, #smem, mutable> - %l = tt.descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc> -> tensor<128x64xf16, #blocked0> + %l = tt.descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc<128x64xf16, #shared> -> tensor<128x64xf16, #blocked0> tt.return %l : tensor<128x64xf16, #blocked0> } } @@ -114,11 +114,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-NEXT: ttg.local_dealloc // CHECK-NEXT: ttg.barrier local // CHECK-NEXT: ttg.local_alloc - tt.func public @tma_store(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked0>) { + tt.func public @tma_store(%arg0: !tt.tensordesc<128x256xf32, #nvmma32>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked0>) { %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> - tt.descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc>, tensor<128x256xf32, #blocked0> + tt.descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc<128x256xf32, #nvmma32>, tensor<128x256xf32, #blocked0> tt.return } } diff --git a/test/TritonNvidiaGPU/ops.mlir b/test/TritonNvidiaGPU/ops.mlir index 16156f5746c4..84c8e31af7bd 100644 --- a/test/TritonNvidiaGPU/ops.mlir +++ b/test/TritonNvidiaGPU/ops.mlir @@ -67,12 +67,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-SAME: [[BAR:%arg[0-9]+]]: // CHECK-SAME: [[RESULT:%arg[0-9]+]]: // CHECK-SAME: [[PRED:%arg[0-9]+]]: - tt.func @async_tma_gather(%desc: !tt.tensordesc>, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, + tt.func @async_tma_gather(%desc: !tt.tensordesc<1x128xbf16, #shared>, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, %bar: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, %result: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, %pred: i1) { - // CHECK-NEXT: ttng.async_tma_gather [[DESC]][[[X_OFFSETS]], [[Y_OFFSET]]] [[RESULT]], [[BAR]], [[PRED]] : !tt.tensordesc>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<1xi64, #shared2, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared, #smem, mutable>, i1 - ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.tensordesc>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, i1 + // CHECK-NEXT: ttng.async_tma_gather [[DESC]][[[X_OFFSETS]], [[Y_OFFSET]]] [[RESULT]], [[BAR]], [[PRED]] : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<1xi64, #shared2, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared, #smem, mutable>, i1 + ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, i1 tt.return } @@ -81,10 +81,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-SAME: [[X_OFFSETS:%arg[0-9]+]]: // CHECK-SAME: [[Y_OFFSET:%arg[0-9]+]]: // CHECK-SAME: [[SRC:%arg[0-9]+]]: - tt.func @async_tma_scatter(%desc: !tt.tensordesc>, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, + tt.func @async_tma_scatter(%desc: !tt.tensordesc<1x128xbf16, #shared>, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, %src: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>) { - // CHECK-NEXT: ttng.async_tma_scatter [[DESC]][[[X_OFFSETS]], [[Y_OFFSET]]] [[SRC]] : !tt.tensordesc>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<32x128xbf16, #shared, #smem, mutable> - ttng.async_tma_scatter %desc[%x_offsets, %y_offset] %src : !tt.tensordesc>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable> + // CHECK-NEXT: ttng.async_tma_scatter [[DESC]][[[X_OFFSETS]], [[Y_OFFSET]]] [[SRC]] : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<32x128xbf16, #shared, #smem, mutable> + ttng.async_tma_scatter %desc[%x_offsets, %y_offset] %src : !tt.tensordesc<1x128xbf16, #shared>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable> tt.return } @@ -131,20 +131,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @tma_load_im2col_3d // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}] {{.*}} : !ttng.tensordesc_im2col - tt.func public @tma_load_im2col_3d(%desc: !ttng.tensordesc_im2col>) { + tt.func public @tma_load_im2col_3d(%desc: !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>) { %true = arith.constant true %c0 = arith.constant 0 : i32 %off = arith.constant 1 : i16 %buf = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared3, #smem, mutable> ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared3, #smem, mutable> - ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0] offsets = [%off] %buf, %bar, %true : !ttng.tensordesc_im2col>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> + ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0] offsets = [%off] %buf, %bar, %true : !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> tt.return } // CHECK-LABEL: @tma_load_im2col_4d // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}, {{.*}}] {{.*}} : !ttng.tensordesc_im2col - tt.func public @tma_load_im2col_4d(%desc: !ttng.tensordesc_im2col>) { + tt.func public @tma_load_im2col_4d(%desc: !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>) { %true = arith.constant true %c0 = arith.constant 0 : i32 %off1 = arith.constant 1 : i16 @@ -152,13 +152,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %buf = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared3, #smem, mutable> ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared3, #smem, mutable> - ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0, %c0] offsets = [%off1, %off2] %buf, %bar, %true : !ttng.tensordesc_im2col>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> + ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0, %c0] offsets = [%off1, %off2] %buf, %bar, %true : !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> tt.return } // CHECK-LABEL: @tma_load_im2col_5d // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}, {{.*}}, {{.*}}] {{.*}} : !ttng.tensordesc_im2col - tt.func public @tma_load_im2col_5d(%desc: !ttng.tensordesc_im2col>) { + tt.func public @tma_load_im2col_5d(%desc: !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>) { %true = arith.constant true %c0 = arith.constant 0 : i32 %off1 = arith.constant 1 : i16 @@ -167,26 +167,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %buf = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared3, #smem, mutable> ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared3, #smem, mutable> - ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0, %c0, %c0] offsets = [%off1, %off2, %off3] %buf, %bar, %true : !ttng.tensordesc_im2col>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> + ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0, %c0, %c0] offsets = [%off1, %off2, %off3] %buf, %bar, %true : !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> tt.return } // CHECK-LABEL: @tma_load_tiled_mode // CHECK: ttng.async_tma_copy_global_to_local {{.*}}[{{.*}}, {{.*}}] %{{.*}}, %{{.*}}, {{.*}} : !tt.tensordesc // CHECK-NOT: offsets - tt.func public @tma_load_tiled_mode(%desc: !tt.tensordesc>) { + tt.func public @tma_load_tiled_mode(%desc: !tt.tensordesc<64x128xf16, #nvmma_128>) { %true = arith.constant true %c0 = arith.constant 0 : i32 %buf = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared3, #smem, mutable> ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared3, #smem, mutable> - ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %buf, %bar, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> + ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %buf, %bar, %true : !tt.tensordesc<64x128xf16, #nvmma_128>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable> tt.return } // CHECK-LABEL: @tensordesc_im2col - // CHECK-SAME: !ttng.tensordesc_im2col> - tt.func public @tensordesc_im2col(%desc: !ttng.tensordesc_im2col>) { + // CHECK-SAME: !ttng.tensordesc_im2col<64x128xf16, {{.*}}> + tt.func public @tensordesc_im2col(%desc: !ttng.tensordesc_im2col<64x128xf16, #nvmma_128>) { // CHECK: tt.return tt.return } diff --git a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir index db185179ae88..aff7ca9c66af 100644 --- a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir +++ b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir @@ -7,15 +7,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-DAG: #[[NVMMA_32:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}> tt.func public @tma_gather(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: tensor<32xi32, #blocked> ) -> tensor<32x32xi8, #blocked1> { - // CHECK: tt.make_tensor_descriptor {{.*}} : , > - // CHECK: tt.descriptor_gather {{.*}} : (!tt.tensordesc> + // CHECK: tt.make_tensor_descriptor {{.*}} : , <1x32xi8, #[[NVMMA_32]]> + // CHECK: tt.descriptor_gather {{.*}} : (!tt.tensordesc<1x32xi8, #[[NVMMA_32]]> %c1_i64 = arith.constant 1 : i64 %cst = arith.constant dense<32> : tensor<8x1xi32> %c64_i32 = arith.constant 64 : i32 %c8_i32 = arith.constant 8 : i32 %0 = arith.extsi %arg2 : i32 to i64 - %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , > - %2 = tt.descriptor_gather %1[%arg3, %c8_i32] : (!tt.tensordesc>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1> + %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , <1x32xi8> + %2 = tt.descriptor_gather %1[%arg3, %c8_i32] : (!tt.tensordesc<1x32xi8>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1> tt.return %2 : tensor<32x32xi8, #blocked1> } } @@ -28,15 +28,15 @@ tt.func public @tma_gather(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-DAG: #[[NVMMA_32:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}> tt.func public @tma_scatter(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: tensor<32xi32, #blocked>, %arg4: tensor<32x32xi8, #blocked1>) { - // CHECK: tt.make_tensor_descriptor {{.*}} : , > - // CHECK: tt.descriptor_scatter {{.*}} : !tt.tensordesc>, {{.*}} + // CHECK: tt.make_tensor_descriptor {{.*}} : , <1x32xi8, #[[NVMMA_32]]> + // CHECK: tt.descriptor_scatter {{.*}} : !tt.tensordesc<1x32xi8, #[[NVMMA_32]]>, {{.*}} %c1_i64 = arith.constant 1 : i64 %cst = arith.constant dense<32> : tensor<8x1xi32> %c64_i32 = arith.constant 64 : i32 %c8_i32 = arith.constant 8 : i32 %0 = arith.extsi %arg2 : i32 to i64 - %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , > - tt.descriptor_scatter %1[%arg3, %c8_i32], %arg4 : !tt.tensordesc>, tensor<32xi32, #blocked>, i32, tensor<32x32xi8, #blocked1> + %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , <1x32xi8> + tt.descriptor_scatter %1[%arg3, %c8_i32], %arg4 : !tt.tensordesc<1x32xi8>, tensor<32xi32, #blocked>, i32, tensor<32x32xi8, #blocked1> tt.return } } @@ -52,13 +52,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-DAG: #[[SWIZZLE_MMA:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32, rank = 3}> // CHECK-DAG: #[[SWIZZLE_2D:.*]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> tt.func public @tma_scatter(%arg0: !tt.ptr, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) { - // CHECK: tt.make_tensor_descriptor {{.*}} : , > - // CHECK: %[[LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<256x32xf32, #[[BLOCKED]]> + // CHECK: tt.make_tensor_descriptor {{.*}} : , <1x256x32xf32, #[[SWIZZLE_MMA]]> + // CHECK: %[[LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<1x256x32xf32, #[[SWIZZLE_MMA]]> -> tensor<256x32xf32, #[[BLOCKED]]> // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<256x32xf32, #[[BLOCKED]]>) -> !ttg.memdesc<256x32xf32, #[[SWIZZLE_2D]], #smem> %c1_i32 = arith.constant 1 : i32 %c1_i64 = arith.constant 1 : i64 - %0 = tt.make_tensor_descriptor %arg0, [%c1_i32, %arg1, %arg2], [%arg3, %arg4, %c1_i64] : , > - %1 = tt.descriptor_load %0[%c1_i32, %c1_i32, %c1_i32] : !tt.tensordesc> -> tensor<256x32xf32, #blocked> + %0 = tt.make_tensor_descriptor %arg0, [%c1_i32, %arg1, %arg2], [%arg3, %arg4, %c1_i64] : , <1x256x32xf32> + %1 = tt.descriptor_load %0[%c1_i32, %c1_i32, %c1_i32] : !tt.tensordesc<1x256x32xf32> -> tensor<256x32xf32, #blocked> %2 = ttg.local_alloc %1 : (tensor<256x32xf32, #blocked>) -> !ttg.memdesc<256x32xf32, #shared, #smem> tt.return } @@ -73,12 +73,12 @@ tt.func public @tma_scatter(%arg0: !tt.ptr, %arg1: i32, %arg2: i32, %arg3: module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-DAG: #[[BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> // CHECK-DAG: #[[NVMMA_64:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}> -tt.func public @descriptor_kernel_arg(%arg0: !tt.tensordesc>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) { - // CHECK: %arg0: !tt.tensordesc> - // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0[{{.*}}] : !tt.tensordesc> -> tensor<64x64xf16, #[[BLOCKED]]> +tt.func public @descriptor_kernel_arg(%arg0: !tt.tensordesc<64x64xf16>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) { + // CHECK: %arg0: !tt.tensordesc<64x64xf16, #[[NVMMA_64]]> + // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0[{{.*}}] : !tt.tensordesc<64x64xf16, #[[NVMMA_64]]> -> tensor<64x64xf16, #[[BLOCKED]]> // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<64x64xf16, #[[BLOCKED]]>) -> !ttg.memdesc<64x64xf16, #[[NVMMA_64]], #smem> %c1_i32 = arith.constant 1 : i32 - %1 = tt.descriptor_load %arg0[%c1_i32, %c1_i32] : !tt.tensordesc> -> tensor<64x64xf16, #blocked> + %1 = tt.descriptor_load %arg0[%c1_i32, %c1_i32] : !tt.tensordesc<64x64xf16> -> tensor<64x64xf16, #blocked> %2 = ttg.local_alloc %1 : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem> tt.return } @@ -101,22 +101,22 @@ tt.func public @tma_load_while(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %c1_i64 = arith.constant 1 : i64 %0 = arith.extsi %arg2 : i32 to i64 - // CHECK: tt.make_tensor_descriptor {{.*}} : , > - %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , > + // CHECK: tt.make_tensor_descriptor {{.*}} : , <1x32xi8, #[[NVMMA_32]]> + %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : , <1x32xi8> - %2 = scf.while (%arg4 = %1) : (!tt.tensordesc>) -> (!tt.tensordesc>) { - scf.condition(%cond) %arg4 : !tt.tensordesc> + %2 = scf.while (%arg4 = %1) : (!tt.tensordesc<1x32xi8>) -> (!tt.tensordesc<1x32xi8>) { + scf.condition(%cond) %arg4 : !tt.tensordesc<1x32xi8> } do { - ^bb0(%arg4: !tt.tensordesc>): - // CHECK: ^bb0(%[[ARG4:.*]]: !tt.tensordesc>): - // CHECK: tt.descriptor_gather %[[ARG4]][{{.*}}] : (!tt.tensordesc> - %3 = tt.descriptor_gather %arg4[%arg3, %c8_i32] : (!tt.tensordesc>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1> + ^bb0(%arg4: !tt.tensordesc<1x32xi8>): + // CHECK: ^bb0(%[[ARG4:.*]]: !tt.tensordesc<1x32xi8, #[[NVMMA_32]]>): + // CHECK: tt.descriptor_gather %[[ARG4]][{{.*}}] : (!tt.tensordesc<1x32xi8, #[[NVMMA_32]]> + %3 = tt.descriptor_gather %arg4[%arg3, %c8_i32] : (!tt.tensordesc<1x32xi8>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1> - scf.yield %arg4 : !tt.tensordesc> + scf.yield %arg4 : !tt.tensordesc<1x32xi8> } - // CHECK: %[[GATHER:.*]] = tt.descriptor_gather {{.*}} : (!tt.tensordesc> - %4 = tt.descriptor_gather %1[%arg3, %c8_i32] : (!tt.tensordesc>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1> + // CHECK: %[[GATHER:.*]] = tt.descriptor_gather {{.*}} : (!tt.tensordesc<1x32xi8, #[[NVMMA_32]]> + %4 = tt.descriptor_gather %1[%arg3, %c8_i32] : (!tt.tensordesc<1x32xi8>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1> // CHECK: ttg.local_alloc %[[GATHER]] {{.*}} : (tensor<32x32xi8, #blocked1>) -> !ttg.memdesc<32x32xi8, #[[NVMMA_32]], #smem> %8 = ttg.local_alloc %4 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<32x32xi8, #blocked1>) -> !ttg.memdesc<32x32xi8, #shared, #smem> diff --git a/test/TritonNvidiaGPU/tma_lowering.mlir b/test/TritonNvidiaGPU/tma_lowering.mlir index d62195c03e27..cecba988f44a 100644 --- a/test/TritonNvidiaGPU/tma_lowering.mlir +++ b/test/TritonNvidiaGPU/tma_lowering.mlir @@ -11,8 +11,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: ttng.wait_barrier // CHECK: ttng.inval_barrier // CHECK: ttg.local_load - tt.func public @tma_load(%arg0: !tt.tensordesc>, %arg1: i32) -> tensor<128x64xf16, #blocked> { - %l = tt.descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> + tt.func public @tma_load(%arg0: !tt.tensordesc<128x64xf16, #nvmma_128>, %arg1: i32) -> tensor<128x64xf16, #blocked> { + %l = tt.descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc<128x64xf16, #nvmma_128> -> tensor<128x64xf16, #blocked> tt.return %l : tensor<128x64xf16, #blocked> } } @@ -26,8 +26,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: ttg.local_alloc {{.*}} -> !ttg.memdesc<128x256xf32, #shared, #smem> // CHECK: ttng.fence_async_shared {bCluster = false} // CHECK: ttng.async_tma_copy_local_to_global - tt.func public @tma_store(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked>) { - tt.descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc>, tensor<128x256xf32, #blocked> + tt.func public @tma_store(%arg0: !tt.tensordesc<128x256xf32, #nvmma_128>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked>) { + tt.descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc<128x256xf32, #nvmma_128>, tensor<128x256xf32, #blocked> tt.return } } @@ -41,15 +41,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: %1 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr // CHECK: ttng.tensormap_create %1, %arg0, [%c32_i32, %c8_i32], [%arg2, %arg1], [%0], [%c1_i32, %c1_i32] {elem_type = 0 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 1 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () // CHECK: ttng.tensormap_fenceproxy_acquire %1 : !tt.ptr - // CHECK: ttng.reinterpret_tensor_descriptor %1 : !tt.ptr to !tt.tensordesc> - tt.func public @make_tensor_descriptor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32} ) -> !tt.tensordesc> { + // CHECK: ttng.reinterpret_tensor_descriptor %1 : !tt.ptr to !tt.tensordesc<8x32xi8, #shared> + tt.func public @make_tensor_descriptor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32} ) -> !tt.tensordesc<8x32xi8, #nvmma_32> { %c1_i64 = arith.constant 1 : i64 %cst = arith.constant dense<32> : tensor<8x1xi32> %c64_i32 = arith.constant 64 : i32 %c8_i32 = arith.constant 8 : i32 %0 = arith.extsi %arg2 : i32 to i64 - %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : !tt.ptr, !tt.tensordesc> - tt.return %1 : !tt.tensordesc> + %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : !tt.ptr, !tt.tensordesc<8x32xi8, #nvmma_32> + tt.return %1 : !tt.tensordesc<8x32xi8, #nvmma_32> } } @@ -62,7 +62,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-LABEL: @tma_gather -tt.func @tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32, #blocked>, %arg2: i32) -> tensor<32x128xbf16, #blocked1> { +tt.func @tma_gather(%arg0: !tt.tensordesc<1x128xbf16, #nvmma_128>, %arg1: tensor<32xi32, #blocked>, %arg2: i32) -> tensor<32x128xbf16, #blocked1> { // CHECK: [[RESULT:%.*]] = ttg.local_alloc // CHECK: [[BARRIER:%.*]] = ttg.local_alloc // CHECK: ttng.init_barrier [[BARRIER]] @@ -70,18 +70,18 @@ tt.func @tma_gather(%arg0: !tt.tensordesc>, %arg1 // CHECK: ttng.wait_barrier [[BARRIER]] // CHECK: ttng.inval_barrier [[BARRIER]] // CHECK: [[OUT:%.*]] = ttg.local_load [[RESULT]] - %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32, #blocked>, i32) -> tensor<32x128xbf16, #blocked1> + %0 = tt.descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc<1x128xbf16, #nvmma_128>, tensor<32xi32, #blocked>, i32) -> tensor<32x128xbf16, #blocked1> // CHECK: return [[OUT]] tt.return %0 : tensor<32x128xbf16, #blocked1> } // CHECK-LABEL: @tma_scatter -tt.func @tma_scatter(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32, #blocked>, %arg2: i32, %arg3: tensor<32x128xbf16, #blocked1>) { +tt.func @tma_scatter(%arg0: !tt.tensordesc<1x128xbf16, #nvmma_128>, %arg1: tensor<32xi32, #blocked>, %arg2: i32, %arg3: tensor<32x128xbf16, #blocked1>) { // CHECK-NEXT: [[SRC:%.*]] = ttg.local_alloc %arg3 // CHECK-NEXT: ttng.fence_async_shared {bCluster = false} // CHECK-NEXT: ttng.async_tma_scatter %arg0[%arg1, %arg2] [[SRC]] // CHECK-NEXT: ttng.async_tma_store_wait - tt.descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc>, tensor<32xi32, #blocked>, i32, tensor<32x128xbf16, #blocked1> + tt.descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc<1x128xbf16, #nvmma_128>, tensor<32xi32, #blocked>, i32, tensor<32x128xbf16, #blocked1> tt.return } @@ -94,11 +94,11 @@ tt.func @tma_scatter(%arg0: !tt.tensordesc>, %arg // CHECK: #[[$NVMMA:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABLE: @rank_reducing_load - tt.func public @rank_reducing_load(%arg0: !tt.tensordesc>) -> tensor<256x32xf32, #blocked> { + tt.func public @rank_reducing_load(%arg0: !tt.tensordesc<1x256x32xf32, #nvmma_128>) -> tensor<256x32xf32, #blocked> { %c32_i32 = arith.constant 32 : i32 // CHECK: %[[A:.+]] = ttg.local_alloc : () -> !ttg.memdesc<256x32xf32, #[[$NVMMA]], #smem, mutable> // CHECK: tng.async_tma_copy_global_to_local %{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}] %[[A]], - %l = tt.descriptor_load %arg0[%c32_i32, %c32_i32, %c32_i32] : !tt.tensordesc> -> tensor<256x32xf32, #blocked> + %l = tt.descriptor_load %arg0[%c32_i32, %c32_i32, %c32_i32] : !tt.tensordesc<1x256x32xf32, #nvmma_128> -> tensor<256x32xf32, #blocked> tt.return %l : tensor<256x32xf32, #blocked> } } @@ -112,8 +112,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // CHECK: #[[$NVMMA:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @tma_load_alloc_user - tt.func public @tma_load_alloc_user(%arg0: !tt.tensordesc>, %arg1: i32) -> (tensor<64x64xf32, #blocked>, !ttg.memdesc<64x64xf32, #shared, #smem, mutable>) { - %0 = tt.descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc> -> tensor<64x64xf32, #blocked> + tt.func public @tma_load_alloc_user(%arg0: !tt.tensordesc<64x64xf32, #nvmma_128>, %arg1: i32) -> (tensor<64x64xf32, #blocked>, !ttg.memdesc<64x64xf32, #shared, #smem, mutable>) { + %0 = tt.descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc<64x64xf32, #nvmma_128> -> tensor<64x64xf32, #blocked> // CHECK: %[[A:.+]] = ttg.local_alloc : () -> !ttg.memdesc<64x64xf32, #[[$NVMMA]], #smem, mutable> // CHECK: tng.async_tma_copy_global_to_local %{{.+}}[%{{.+}}, %{{.+}}] %[[A]], %1 = ttg.local_alloc %0 : (tensor<64x64xf32, #blocked>) -> !ttg.memdesc<64x64xf32, #shared, #smem, mutable> @@ -135,13 +135,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @tma_load_double_use - tt.func public @tma_load_double_use(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) -> tensor<64x32xf32, #mma1> { + tt.func public @tma_load_double_use(%arg0: !tt.tensordesc<64x32xf32, #shared>, %arg1: !tt.tensordesc<64x64xf32, #shared1>) -> tensor<64x32xf32, #mma1> { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma1> %c32_i32 = arith.constant 32 : i32 %c64_i32 = arith.constant 64 : i32 // CHECK: %[[A:.+]] = ttg.local_alloc : () -> !ttg.memdesc<64x32xf32 - %0 = tt.descriptor_load %arg0[%c64_i32, %c32_i32] : !tt.tensordesc> -> tensor<64x32xf32, #blocked> + %0 = tt.descriptor_load %arg0[%c64_i32, %c32_i32] : !tt.tensordesc<64x32xf32, #shared> -> tensor<64x32xf32, #blocked> // CHECK: %[[B:.+]] = ttg.local_load %[[A]] // CHECK: %[[C:.+]] = ttg.local_alloc %[[B]] %1 = ttg.local_alloc %0 : (tensor<64x32xf32, #blocked>) -> !ttg.memdesc<64x32xf32, #shared1, #smem> diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index 3d7696862d76..a9412a0ec596 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -44,7 +44,7 @@ namespace mlir::triton::amdgpu { /// 2D tensors: group0 (4) + group1 (8) = 12 dwords /// 3D-5D tensors: group0 (4) + group1 (8) + group2 (4) + group3 (4) = 20 dwords inline int getTensorDescNumDwords(triton::TensorDescType type) { - auto shape = type.getBlockType().getShape(); + auto shape = type.getShape(); return (shape.size() > 2) ? (4 + 8 + 4 + 4) : (4 + 8); } } // namespace mlir::triton::amdgpu diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index f65d6a127656..f73187e80cc7 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -598,7 +598,7 @@ LogicalResult AsyncTDMCopyGlobalToLocalOp::verify() { auto smemTy = getResult().getType(); // Check that every dimension of the block shape is <= 2^16 - auto blockShape = tensorDescTy.getBlockType().getShape(); + auto blockShape = tensorDescTy.getShape(); auto verifyResult = verifyTDMBlockSize(getOperation(), blockShape); if (failed(verifyResult)) return verifyResult; @@ -674,7 +674,7 @@ LogicalResult AsyncTDMCopyLocalToGlobalOp::verify() { auto smemTy = getSrc().getType(); // Check that every dimension of the block shape is <= 2^16 - auto blockShape = tensorDescTy.getBlockType().getShape(); + auto blockShape = tensorDescTy.getShape(); auto verifyResult = verifyTDMBlockSize(getOperation(), blockShape); if (failed(verifyResult)) return verifyResult; @@ -722,7 +722,7 @@ LogicalResult AsyncTDMScatterOp::verify() { auto smemTy = getSrc().getType(); // TDM scatter mode only supports 2D tensors - auto blockShape = tensorDescTy.getBlockType().getShape(); + auto blockShape = tensorDescTy.getShape(); if (blockShape.size() != 2) return emitOpError("TDM scatter only supports 2D tensors, got ") << blockShape.size() << "D"; @@ -772,7 +772,7 @@ LogicalResult AsyncTDMGatherOp::verify() { auto smemTy = getDst().getType(); // TDM gather mode only supports 2D tensors - auto blockShape = tensorDescTy.getBlockType().getShape(); + auto blockShape = tensorDescTy.getShape(); if (blockShape.size() != 2) return emitOpError("TDM gather only supports 2D tensors, got ") << blockShape.size() << "D"; @@ -880,9 +880,8 @@ LogicalResult TDMPrefetchOp::inferReturnTypes( } auto descType = cast(ad.getDesc().getType()); - auto blockType = descType.getBlockType(); - auto blockShape = blockType.getShape(); - auto elementType = blockType.getElementType(); + auto blockShape = descType.getShape(); + auto elementType = descType.getElementType(); // Lookup the module to get the number of threads per warp, number of warps // and number of CTAs diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 09b21d13f7fc..e12203013a7d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1193,17 +1193,16 @@ struct AsyncTDMCopyGlobalToLocalOpConversion auto b = TritonLLVMOpBuilder(loc, rewriter); auto tensorDescTy = op.getDesc().getType(); - auto descBlockTy = tensorDescTy.getBlockType(); - auto encoding = descBlockTy.getEncoding(); + auto encoding = tensorDescTy.getSharedLayout(); Type elementType = - getTypeConverter()->convertType(descBlockTy.getElementType()); + getTypeConverter()->convertType(tensorDescTy.getElementType()); // Use descBlockTy to query shared layout because TDM lowering logic expects // the descriptor's dimensionality. For rank-reducing loads, destination // shared memory may have fewer dimensions than the descriptor block type. triton::LinearLayout sharedLayout = isPaddedEncoding(encoding) - ? paddedLinearLayout(descBlockTy.getShape(), encoding) - : toLinearLayout(descBlockTy); + ? paddedLinearLayout(tensorDescTy.getShape(), encoding) + : toLinearLayout(tensorDescTy.getShape(), encoding); // Extract padding information if present unsigned padInterval = 0; unsigned padAmount = 0; @@ -1224,8 +1223,7 @@ struct AsyncTDMCopyGlobalToLocalOpConversion SmallVector desc = unpackLLElements(loc, adaptor.getDesc(), rewriter); - SmallVector blockShape = - llvm::to_vector(tensorDescTy.getBlockType().getShape()); + SmallVector blockShape = llvm::to_vector(tensorDescTy.getShape()); // 2D tensors: 12 dwords (group0: 4, group1: 8) // 3D-5D tensors: 20 dwords (group0: 4, group1: 8, group2: 4, group3: 4) @@ -1251,10 +1249,10 @@ struct AsyncTDMCopyGlobalToLocalOpConversion auto ctaId = targetInfo.getClusterCTAId(rewriter, loc); - auto shapePerCTA = triton::gpu::getShapePerCTA(descBlockTy); + auto shapePerCTA = + triton::gpu::getShapePerCTA(encoding, tensorDescTy.getShape()); auto sharedOrder = triton::gpu::getOrder( - cast(descBlockTy.getEncoding()), - shapePerCTA); + cast(encoding), shapePerCTA); bool isRowMajor = sharedOrder[0] == (sharedOrder.size() - 1); mlir::LLVM::AMD::emitTDMLoadStore( @@ -1293,8 +1291,7 @@ struct AsyncTDMCopyLocalToGlobalOpConversion SmallVector desc = unpackLLElements(loc, adaptor.getDesc(), rewriter); - SmallVector blockShape = - llvm::to_vector(tensorDescTy.getBlockType().getShape()); + SmallVector blockShape = llvm::to_vector(tensorDescTy.getShape()); // 2D tensors: 12 dwords (group0: 4, group1: 8) // 3D-5D tensors: 20 dwords (group0: 4, group1: 8, group2: 4, group3: 4) @@ -1380,8 +1377,7 @@ struct AsyncTDMScatterOpConversion SmallVector desc = unpackLLElements(loc, adaptor.getDesc(), rewriter); - SmallVector blockShape = - llvm::to_vector(tensorDescTy.getBlockType().getShape()); + SmallVector blockShape = llvm::to_vector(tensorDescTy.getShape()); // Scatter only supports 2D tensors assert(blockShape.size() == 2 && @@ -1467,8 +1463,7 @@ struct AsyncTDMGatherOpConversion SmallVector desc = unpackLLElements(loc, adaptor.getDesc(), rewriter); - SmallVector blockShape = - llvm::to_vector(tensorDescTy.getBlockType().getShape()); + SmallVector blockShape = llvm::to_vector(tensorDescTy.getShape()); // Gather only supports 2D tensors assert(blockShape.size() == 2 && @@ -2479,10 +2474,9 @@ struct TDMPrefetchConversion auto b = TritonLLVMOpBuilder(loc, rewriter); auto tdescType = op.getDesc().getType(); - auto tensorType = tdescType.getBlockType(); - SmallVector blockShape = llvm::to_vector(tensorType.getShape()); + SmallVector blockShape = llvm::to_vector(tdescType.getShape()); Type elementType = - getTypeConverter()->convertType(tensorType.getElementType()); + getTypeConverter()->convertType(tdescType.getElementType()); SmallVector desc = unpackLLElements(loc, adaptor.getDesc(), rewriter); SmallVector offset = adaptor.getIndices(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp index 8e74cd004e83..724469c6fcdb 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -108,8 +108,7 @@ struct MakeTensorDescOpConversion auto result = op.getResult(); auto tensorDescTy = result.getType(); - auto blockTy = tensorDescTy.getBlockType(); - auto sharedEnc = blockTy.getEncoding(); + auto sharedEnc = tensorDescTy.getSharedLayout(); if (!sharedEnc) { return rewriter.notifyMatchFailure( op, "Descriptor has no shared memory layout assigned."); @@ -125,8 +124,8 @@ struct MakeTensorDescOpConversion } Type elementType = - getTypeConverter()->convertType(blockTy.getElementType()); - SmallVector blockShape = to_vector(blockTy.getShape()); + getTypeConverter()->convertType(tensorDescTy.getElementType()); + SmallVector blockShape = to_vector(tensorDescTy.getShape()); int numWarps = lookupNumWarps(op); auto shapePerCTA = triton::gpu::getShapePerCTA(sharedEnc, blockShape); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp index a3e565bd4a2b..f275f433a10f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp @@ -465,12 +465,13 @@ composePaddedLayoutWMMA(int opIdx, unsigned vecWidth, ttg::SharedEncodingTrait getEncodingFromDescriptor(Operation *op, RankedTensorType tensorType, Value desc) { - auto descBlockType = cast(desc.getType()).getBlockType(); - auto encoding = cast(descBlockType.getEncoding()); - if (!encoding) { + auto descTy = cast(desc.getType()); + auto sharedLayout = descTy.getSharedLayout(); + if (!sharedLayout) { emitError(op->getLoc()) << "Missing encoding on the tensor descriptor"; return {}; } + auto encoding = cast(sharedLayout); return ttg::updateEncodingForShape(op, encoding, tensorType); } diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp index 29820cf3ad84..6e19a7f2653d 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp @@ -188,8 +188,7 @@ static SmallVector getShape(Type type) { else if (auto tensorType = dyn_cast(type)) return {tensorType.getShape().begin(), tensorType.getShape().end()}; else if (auto tensorDescType = dyn_cast(type)) - return {tensorDescType.getBlockType().getShape().begin(), - tensorDescType.getBlockType().getShape().end()}; + return {tensorDescType.getShape().begin(), tensorDescType.getShape().end()}; else if (auto ptrType = dyn_cast(type)) return getShape(ptrType.getPointeeType()); return {}; @@ -837,15 +836,11 @@ static Operation *sliceOp(Operation *op, int offset, IRMapping &mappings, type.getEncoding()); newV.setType(newType); } else if (auto type = dyn_cast(v.getType())) { - auto blockType = type.getBlockType(); - SmallVector shape{blockType.getShape().begin(), - blockType.getShape().end()}; + SmallVector shape(type.getShape()); int sliceSize = shape[dim] / numOfPartitions; shape[dim] = sliceSize; - auto newBlockType = RankedTensorType::get( - shape, blockType.getElementType(), blockType.getEncoding()); - auto newType = - TensorDescType::get(builder.getContext(), newBlockType); + auto newType = TensorDescType::get(shape, type.getElementType(), + type.getSharedLayout()); newV.setType(newType); } } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index e6ed46e3542a..bcc32d7dc29f 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1264,7 +1264,7 @@ struct AsyncTMACopyGlobalToLocalOpConversion auto offsets = applyLinearLayout(loc, rewriter, msgToOffset, {{kMsg, copyIdxVal}, {kBlock, ctaId}}); int operandIdx = 3; - auto encoding = op.getDesc().getType().getBlockType().getEncoding(); + auto encoding = op.getDesc().getType().getSharedLayout(); bool fp4Padded = nvidia_gpu::isFp4Padded(encoding); for (int i = 0; i < rank; i++) { Value coord = adaptor.getCoord()[rank - i - 1];