diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index fa1d5fcdb02f..aadaf34c841d 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -356,12 +356,12 @@ class SharedMemoryObject { RewriterBase &rewriter) const; // Returns a mask representing all the bits of the memdesc offsets that - // may be modified by an affine offset coming from a memdesc_subview. + // may be modified by an affine offset coming from a memdesc_subslice. // The offsets are considered to be in the type of the memdesc. // For padded layouts, we return the offsets without padding. static uint64_t getMaskSpanOffsets(triton::gpu::MemDescType srcTy); - // Returns whether the shared memory access had a memdesc_subview + // Returns whether the shared memory access had a memdesc_subslice // that is rank-preserving (soon to be called memdesc_slice) static bool isAffineSharedMemoryAccess(triton::gpu::MemDescType srcTy) { return getMaskSpanOffsets(srcTy) != 0; diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 5bbf8d5c2913..e8064f9069aa 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -200,38 +200,57 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc"> { // Use qualified() otherwise "!ttg.memdesc" is printed as "". let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}]; } - -def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure, MemDescViewTrait]> { +def TTG_MemDescIndexOp : TTG_Op<"memdesc_index", [Pure, MemDescViewTrait]> { let summary = "take a subview of the descriptor."; let description = [{ - This operation returns a new descriptor representing a subview of the buffer. + This operation returns a new descriptor pointing to the `i`-th element of the + input descriptor along the 0-th dimension. 
+ It doesn't affect the underlying memory. For example, suppose that - the input shape is 2x4x16xf16, - the output shape is 4x16xf16, and - - offsets = [1, 0, 0]. + - index = 1. + Then the output descriptor is equivalent to input[1], where input is the logical tensor. - Then in Python syntax, the subview covers input[1]. + When the input is of rank 1 (i.e., shape=[k]), the output will have shape=[1]. + }]; - Just one dimension may be split (at most one non-zero offset). + let arguments = (ins TTG_MemDescType:$src, I32:$index); - When the input shape and the output shape have different rank: - Or the output shape is a tensor of 1D tensor of 1 element: - - The rank of the output must be 1D smaller than the input. - - We assume the input is split along the 0th dimension. - - The offset along the 0th dimension may be a runtime value. - When the input and the output have the same rank: - - The offset must be a compile-time constant - - Larger or equal to the tile of the tensor (or zero) - - That does not split the input along the swizzling pattern (if any) - }]; - let arguments = ( - ins TTG_MemDescType:$src, Variadic:$offsets); + let results = (outs TTG_MemDescType:$result); + + let assemblyFormat = [{$src `,` $index attr-dict `:` qualified(type($src)) `->` qualified(type($result))}]; + + let hasVerifier = 1; +} +def TTG_MemDescSubsliceOp : TTG_Op<"memdesc_subslice", [Pure, MemDescViewTrait]> { + let summary = "take a subview of the descriptor."; + + let description = [{ + This operation returns a new descriptor representing a subview of the logical tensor. + It doesn't affect the underlying memory. + + For example, suppose that + - the input shape is 32x16xf16, + - the output shape is 8x16xf16, and + - offsets = [8, 0]. + Then in Python syntax, the subview covers input[8:8+8, 0:16+0] where input is + the logical tensor. + + The offsets must be larger than or equal to the tile of the tensor (or zero).
+ }]; + let arguments = (ins TTG_MemDescType:$src, DenseI32ArrayAttr:$offsets); // Use qualified() otherwise "!ttg.memdesc" is printed as "". - let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}]; + // Render offsets inline as %src[0, 0] via a custom directive, but keep + // the overall parse/print generated from this assemblyFormat. + let assemblyFormat = [{ + $src `[` custom($offsets) `]` attr-dict `:` qualified(type($src)) + `->` qualified(type($result)) + }]; let results = (outs TTG_MemDescType:$result); diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h index 14afd7da9e0e..ced09049db73 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h @@ -133,11 +133,11 @@ gpu::SharedEncodingTrait getSharedEncoding(Operation *loadOp); // specified. int getNumStagesOrDefault(scf::ForOp forOp, int defaultNumStages); -// Given a result of MemDescSubview, or Alloca, create a MemDescSubview with a +// Given a result of MemDescIndex, or Alloca, create a MemDescIndex with a // single buffer slice (leading dimension equal to 1), at the given index. TypedValue createSingleBufferView(OpBuilder &builder, Value alloc, Value idx); -// Given a result of MemDescSubview, or Alloca, create a MemDescSubview with a +// Given a result of MemDescIndex, or Alloca, create a MemDescIndex with a // single buffer slice (leading dimension equal to 1), at the given index. 
TypedValue createSingleBufferView(OpBuilder &builder, Value alloc, int idx); diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index 5d9410524ded..18133512e29d 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -674,8 +674,8 @@ def TTNG_TMEMSubSliceOp : TTNG_Op<"tmem_subslice", [Pure]> { let description = [{ This operation takes a subslice of a tensor memory allocation and returns a new descriptor containing the address and a view of the subslice. - This is similar to ttg.memdesc_subview except the offset needs to be static and we can only - slice alog the inner dimension of a 2D memdesc as this is the only one we can do for TMem. + This is similar to ttg.memdesc_subslice except we can only slice along the inner dimension + of a 2D memdesc as this is the only one we can do for TMem. }]; let arguments = (ins TTG_MemDescType:$src, I32Attr:$N); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 41e5b6e4f6d6..c9ac807dab02 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -705,7 +705,7 @@ bool emitTransferBetweenRegistersAndShared( {kBlock, blockId}}, regIds); - // Compute affine offset given by memdesc_subview + // Compute affine offset given by memdesc_subslice auto offset = smemObj.getShmemOffset(loc, rewriter, sharedTy); SmallVector vecAddrVec; for (auto &indices : indicesVec) { @@ -1153,7 +1153,7 @@ Value SharedMemoryObject::getShmemOffset(Location loc, RewriterBase &rewriter, auto ctx = srcTy.getContext(); auto b = TritonLLVMOpBuilder(loc, rewriter); - // If it did not have a memdesc_subview, we don't need to compute the offset + // If it did not have a memdesc_subslice we don't need to compute the offset // as it is zero if (!isAffineSharedMemoryAccess(srcTy)) { 
return b.i32_val(0); diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 0316fb266d4a..966a495efade 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -465,13 +465,46 @@ struct BroadcastOpConversion } }; -struct MemDescSubviewOpConversion - : public ConvertOpToLLVMPattern { +struct MemDescIndexOpConversion + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< - triton::gpu::MemDescSubviewOp>::ConvertOpToLLVMPattern; + triton::gpu::MemDescIndexOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor, + matchAndRewrite(triton::gpu::MemDescIndexOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto *ctx = op->getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcTy = op.getSrc().getType(); + auto destTy = op.getResult().getType(); + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + + auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto base = smemObj.getBase(); + auto elemPtrTy = base.getType(); + Value stride = smemObj.getStrides(srcTy, loc, rewriter).front(); + Value offset = b.mul(op.getIndex(), stride); + auto prevOffsets = smemObj.getOffsets(); + SmallVector offsetVals(prevOffsets.end() - destTy.getRank(), + prevOffsets.end()); + // Advance the pointer and keep the opOffsets as the new shape + smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset), + llvmElemTy, offsetVals); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; + +struct MemDescSubsliceOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::MemDescSubsliceOp>::ConvertOpToLLVMPattern; + + LogicalResult + 
matchAndRewrite(triton::gpu::MemDescSubsliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); auto *ctx = op->getContext(); @@ -484,40 +517,17 @@ struct MemDescSubviewOpConversion auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), llvmElemTy, rewriter); - SmallVector opOffsetVals = op.getOffsets(); - // We assume we always create a subview of the last dimensions - // Compute total offset - auto rankReduced = srcTy.getRank() - destTy.getRank(); + auto opOffsetVals = op.getOffsets(); auto base = smemObj.getBase(); auto elemPtrTy = base.getType(); - auto is1d = srcTy.getRank() == 1 && destTy.getRank() == 1 && - destTy.getDimSize(0) == 1; - if (rankReduced || is1d) { - auto smemStrides = smemObj.getStrides(srcTy, loc, rewriter); - SmallVector opSmemStrides(smemStrides.end() - opOffsetVals.size(), - smemStrides.end()); - // We are splitting the pipelining dimension which may not be a power of 2 - // so we can't use LinearLayouts - auto offset = dot(rewriter, loc, opOffsetVals, opSmemStrides); - // Remove the first offsets - SmallVector offsetVals; - for (int i = rankReduced; i < opOffsetVals.size(); i++) { - offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i])); - } - // Advance the pointer and keep the opOffsets as the new shape - smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset), - llvmElemTy, offsetVals); - } else { - // Accumulate the logical offsets - SmallVector offsetVals; - for (auto [oldOff, newOff] : - llvm::zip(smemObj.getOffsets(), opOffsetVals)) { - offsetVals.push_back(b.add(oldOff, newOff)); - } - smemObj = SharedMemoryObject(base, llvmElemTy, offsetVals); + // Accumulate the logical offsets + SmallVector offsetVals; + for (auto [oldOffVal, opOff] : + llvm::zip(smemObj.getOffsets(), opOffsetVals)) { + offsetVals.push_back(b.add(oldOffVal, b.i32_val(opOff))); } - + smemObj = SharedMemoryObject(base, llvmElemTy, offsetVals); auto retVal = 
getStructFromSharedMemoryObject(loc, smemObj, rewriter); rewriter.replaceOp(op, retVal); return success(); @@ -563,6 +573,7 @@ void mlir::triton::populateViewOpToLLVMPatterns( typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); + patterns.add( + typeConverter, benefit); patterns.add(typeConverter, benefit); } diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 2b73905ac03a..7639f7be9aba 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -13,6 +13,29 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/LogicalResult.h" +// Provide custom directive handlers for declarative assemblyFormat. +// They must be visible before including the generated op classes. +static mlir::ParseResult parseOffsets(mlir::OpAsmParser &p, + mlir::DenseI32ArrayAttr &attr) { + llvm::SmallVector values; + if (p.parseCommaSeparatedList([&]() { + int32_t v; + if (p.parseInteger(v)) + return mlir::failure(); + values.push_back(v); + return mlir::success(); + })) + return mlir::failure(); + attr = p.getBuilder().getDenseI32ArrayAttr(values); + return mlir::success(); +} + +static void printOffsets(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::DenseI32ArrayAttr attr) { + auto vals = attr.asArrayRef(); + llvm::interleaveComma(vals, p, [&](int32_t v) { p << v; }); +} + #define GET_OP_CLASSES #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" @@ -475,7 +498,7 @@ LogicalResult MemDescReshapeOp::verify() { } auto srcShape = srcType.getShape(); if (srcType.getAllocShape().take_back(srcShape.size()) != srcShape) { - return emitError("NYI: memdesc_reshape of memdesc_subviews"); + return emitError("NYI: memdesc_reshape of memdesc_subslice"); } MemDescType expectedTy; @@ -522,8 +545,7 @@ static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef srcShape, mmaEncoding.getElementBitWidth(), mmaEncoding.getFp4Padded(), CTALayout); // Big 
guns, check linear layouts are equivalent - // We disallow reshaping memdesc_subviews in the verifier - // We disallow reshaping memdesc_subviews in the verifier + // We disallow reshaping memdesc_subslice in the verifier auto srcLL = toLinearLayout(srcShape, srcEnc, srcShape); auto dstLL = toLinearLayout(dstShape, dstEnc, dstShape); if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) { @@ -662,41 +684,41 @@ LogicalResult AsyncCopyGlobalToLocalOp::verify() { return success(); } -LogicalResult MemDescSubviewOp::verify() { +LogicalResult MemDescIndexOp::verify() { auto srcTy = getSrc().getType(); auto dstTy = getType(); - if (srcTy.getElementType() != dstTy.getElementType()) { return emitError("result element type must match desc element type"); } - if (getOffsets().size() != srcTy.getRank()) { - return emitError("offsets must have the same rank as input"); + bool is1D = + srcTy.getRank() == 1 && dstTy.getRank() == 1 && dstTy.getDimSize(0) == 1; + bool correctRank = srcTy.getRank() == dstTy.getRank() + 1 || is1D; + if (!correctRank) { + return emitError( + "result rank must be less than or equal to input rank or 1D -> 1D"); } - if (srcTy.getRank() < dstTy.getRank()) { - return emitError("result rank must be less than or equal to input rank"); + if (srcTy.getAllocShape().size() != srcTy.getRank()) { + return emitError("We don't allow taking memdesc_index of a memdesc_index"); } - auto rankDiff = srcTy.getRank() - dstTy.getRank(); - for (int i = 0; i < dstTy.getRank(); i++) { - if (dstTy.getDimSize(i) > srcTy.getDimSize(i + rankDiff)) { - return emitError( - "result shape cannot be larger than input shape at dimension ") - << i; - } + + if (!is1D && ArrayRef(srcTy.getShape()).take_back(dstTy.getRank()) != + dstTy.getShape()) { + return emitError("result shape must equal to srcShape[1:]"); + } + + bool isSubview = srcTy.getAllocShape() != srcTy.getShape(); + if (isSubview) { + return emitError("We don't support memdesc_index of a subview"); } auto srcEnc = 
srcTy.getEncoding(); auto dstEnc = dstTy.getEncoding(); - if (!!srcEnc != !!dstEnc) { + if (bool(srcEnc) != bool(dstEnc)) { return emitError("src and result must both have or not have an encoding"); } - if (!isa(srcEnc) && - !isa(srcEnc)) { - return emitError("src encoding must be SharedEncodingTrait"); - } - if (!isa(dstEnc) && - !isa(srcEnc)) { - return emitError("result encoding must be SharedEncodingTrait"); + if (isa(srcEnc) != isa(dstEnc)) { + return emitError("src and dst must have the same type of encoding"); } if (isa(srcEnc)) { @@ -705,69 +727,41 @@ LogicalResult MemDescSubviewOp::verify() { return emitError("only 3D -> 2D subviews are supported for " "TensorMemoryEncodingAttr"); } - for (int i = 1; i < srcTy.getRank(); i++) { - if (auto constOp = getOffsets()[i].getDefiningOp()) { - if (!isa(constOp.getValue()) || - cast(constOp.getValue()).getInt() != 0) { - return emitError("only first offset can be non-zero for the subview" - "of TensorMemoryEncodingAttr"); - } - } else { - return emitError( - "offsets other than the first one must be constant zeros"); - } - } return success(); } + return success(); +} - assert(isa(srcEnc)); +LogicalResult MemDescSubsliceOp::verify() { + auto srcTy = getSrc().getType(); + auto dstTy = getType(); - // corner case: 1D -> 1D into a 1 element tensor (we don't have 0D tensors) - if (srcTy.getRank() == 1 && dstTy.getRank() == 1 && - dstTy.getDimSize(0) == 1) { - return success(); + if (srcTy.getElementType() != dstTy.getElementType()) { + return emitError("result element type must match desc element type"); + } + if (getOffsets().size() != srcTy.getRank()) { + return emitError("offsets must have the same rank as input"); } - - // There are two cases: - // 1. The subview is rank-reducing - // - We split along the first dimension. 
It can be with non-constant offsets if (srcTy.getRank() != dstTy.getRank()) { - if (srcTy.getRank() - dstTy.getRank() != 1) { - return emitError( - "only nD -> (n-1)D rank-reducing subviews are supported"); - } - for (auto offset : getOffsets().take_back(dstTy.getRank())) { - APInt value; - if (!matchPattern(offset, m_ConstantInt(&value))) { - return emitError("only constant values are allowed outside the front " - "dimension in a rank-reducing subview"); - } - if (!value.isZero()) { - return emitError( - "only first offset can be non-zero for a rank-reducing subview"); - } - } - return success(); + return emitError("result rank must equal to input rank"); + } + + auto srcEnc = srcTy.getEncoding(); + auto dstEnc = dstTy.getEncoding(); + if (bool(srcEnc) != bool(dstEnc)) { + return emitError("src and result must both have or not have an encoding"); + } + if (!isa(srcEnc) || !isa(dstEnc)) { + return emitError("src and dst must both be of shared memory encoding"); } - assert(srcTy.getRank() == dstTy.getRank()); - // 2. 
The src is non-rank-reducing - // - We split along at most one dim, but just with constant values - // - The values where the split happens must not be within the swizzling - // pattern - // Check which dimensions we are splitting along + SetVector splitDims{}; for (int i = 0; i < srcTy.getRank(); i++) { if (srcTy.getDimSize(i) != dstTy.getDimSize(i)) { splitDims.insert(i); } } - SmallVector offsets; - for (auto offset : getOffsets()) { - APInt value; - if (!matchPattern(offset, m_ConstantInt(&value))) - return emitError("only constant values are allowed for the split"); - offsets.push_back(value.getSExtValue()); - } + SmallVector offsets(getOffsets().begin(), getOffsets().end()); // Identity subview if (splitDims.empty()) { return success(); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index b9b55e70b445..00a9e810156e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -598,15 +598,8 @@ triton::createSingleBufferView(OpBuilder &builder, Value alloc, Value idx) { shape, allocDescType.getElementType(), allocDescType.getEncoding(), allocDescType.getMemorySpace(), allocDescType.getMutableMemory(), /*allocShape=*/allocDescType.getAllocShape()); - SmallVector idxs = {idx}; - if (allocDescType.getShape().size() > 1) { - Value zero = builder.create(alloc.getLoc(), 0, 32); - for (unsigned i = 1; i < allocDescType.getShape().size(); i++) { - idxs.push_back(zero); - } - } - return builder.create(alloc.getLoc(), viewDescType, - alloc, idxs); + return builder.create(alloc.getLoc(), viewDescType, + alloc, idx); } TypedValue diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp index 7cdecf15071d..7430b285c7b4 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp +++ 
b/lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp @@ -281,7 +281,7 @@ SmallVector splitLhs(OpBuilder &builder, } // Split the RHS of a RSWGMMADot operation into multiple multiple -// tensors of size newKxN via MemDescSubview +// tensors of size newKxN via MemDescSubslice SmallVector splitRhs(OpBuilder &builder, TypedValue rhs, int64_t newK) { auto loc = rhs.getLoc(); @@ -291,18 +291,15 @@ SmallVector splitRhs(OpBuilder &builder, auto nSplits = type.getShape()[kDim] / newK; auto shape = llvm::to_vector(type.getShape()); shape[kDim] = newK; - SmallVector offsetsVal; - for (int i = 0; i < rank; i++) { - offsetsVal.push_back(builder.create(loc, 0, 32)); - } + SmallVector offsets(rank, 0); auto newType = ttg::MemDescType::get( shape, type.getElementType(), type.getEncoding(), type.getMemorySpace(), /*isMutable=*/false, type.getAllocShape()); SmallVector ret; for (int i = 0; i < nSplits; i++) { - offsetsVal[kDim] = builder.create(loc, i * newK, 32); - Value newSmem = builder.create( - loc, newType, rhs, offsetsVal); + offsets[kDim] = i * newK; + Value newSmem = + builder.create(loc, newType, rhs, offsets); ret.push_back(newSmem); } return ret; @@ -431,11 +428,11 @@ static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, return true; } // If it's a shmem operand, it must either be defined outside the loop, or - // come from an MemDescSubview op. Only ConvertLayout and view ops are + // come from a MemDescIndex op. Only ConvertLayout and view ops are // allowed in between.
Value transitiveOperand = operand; while (isa_and_nonnull( + ttg::MemDescReshapeOp, ttg::MemDescSubsliceOp>( transitiveOperand.getDefiningOp()) || isa(transitiveOperand)) { auto blockArg = dyn_cast(transitiveOperand); @@ -448,7 +445,7 @@ static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, } } return forOp.isDefinedOutsideOfLoop(transitiveOperand) || - transitiveOperand.getDefiningOp(); + transitiveOperand.getDefiningOp(); }; // Rule 1: All shmem operands are multi-buffered. diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 248871ebdbcb..e15f53e4bb14 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -122,7 +122,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, auto type = cast(v.getType()); SmallVector shape{type.getShape().begin(), type.getShape().end()}; auto rank = shape.size(); - SmallVector offset(rank, 0); + SmallVector offset(rank, 0); Type elementType = type.getElementType(); // k => (prefetchWidth, k - prefetchWidth) @@ -136,16 +136,12 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, if (offsetK) offset[kIdx] = *offsetK; - SmallVector offsetsVal; - for (int64_t off : offset) - offsetsVal.push_back( - builder.create(v.getLoc(), off, 32)); - Value newSmem = builder.create( + Value newSmem = builder.create( v.getLoc(), triton::gpu::MemDescType::get( shape, elementType, type.getEncoding(), type.getMemorySpace(), type.getMutableMemory(), type.getAllocShape()), - v, offsetsVal); + v, offset); auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 82ebe2354933..fc6334da9db4 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp 
@@ -1307,11 +1307,11 @@ void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) { ttg::LocalAllocOp findShmemAlloc(Value operand) { // If it's a shmem operand, it must either be defined outside the loop, or - // come from an MemDescSubview op. Only ConvertLayout and Trans ops are + // come from a MemDescIndex op. Only ConvertLayout and MemdescView ops are // allowed in between. Value transitiveOperand = operand; while (isa_and_nonnull( + ttg::MemDescReshapeOp, ttg::MemDescSubsliceOp>( transitiveOperand.getDefiningOp()) || isa(transitiveOperand)) { if (auto blockArg = dyn_cast(transitiveOperand)) { @@ -1324,7 +1324,7 @@ ttg::LocalAllocOp findShmemAlloc(Value operand) { transitiveOperand = transitiveOperand.getDefiningOp()->getOperand(0); } } - if (auto subView = dyn_cast_or_null( + if (auto subView = dyn_cast_or_null( transitiveOperand.getDefiningOp())) { // Multi-buffered operand return dyn_cast_or_null( @@ -1488,14 +1488,22 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse, // `subview(old_op)` is replaced by a new `subview(val)`.
builder.setInsertionPoint(user); Value newVal; - if (auto subview = dyn_cast(user)) { + if (auto subview = dyn_cast(user)) { ttg::MemDescType oldType = subview.getType(); bool isMutable = cast(val.getType()).getMutableMemory(); Type newDstType = ttg::MemDescType::get( oldType.getShape(), oldType.getElementType(), oldType.getEncoding(), - oldType.getMemorySpace(), isMutable); - newVal = builder.create( - subview.getLoc(), newDstType, val, subview.getOffsets()); + oldType.getMemorySpace(), isMutable, oldType.getAllocShape()); + newVal = builder.create(subview.getLoc(), newDstType, + val, subview.getIndex()); + } else if (auto subslice = dyn_cast(user)) { + ttg::MemDescType oldType = subslice.getType(); + bool isMutable = cast(val.getType()).getMutableMemory(); + Type newDstType = ttg::MemDescType::get( + oldType.getShape(), oldType.getElementType(), oldType.getEncoding(), + oldType.getMemorySpace(), isMutable, oldType.getAllocShape()); + newVal = builder.create( + subslice.getLoc(), newDstType, val, subslice.getOffsets()); } else if (auto trans = dyn_cast(user)) { newVal = builder.create(trans.getLoc(), val, trans.getOrder()); diff --git a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp index de39f22f21d2..4f98894371a7 100644 --- a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp @@ -226,10 +226,7 @@ int UseInfo::getMaxUseDistance(const Partition &partition) { namespace { struct AsyncRef { Value getValueView(ImplicitLocOpBuilder &b, Value idx) const { - Value zero = b.create(b.getI32IntegerAttr(0)); - SmallVector offsets(allocType.getRank(), zero); - offsets.front() = idx; - return b.create(viewType, alloc, offsets); + return b.create(viewType, alloc, idx); } Value getReadyView(ImplicitLocOpBuilder &b, Value idx) const { 
return createSingleBufferView(b, readyBars, idx); diff --git a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp index d2382133c6bc..5b0a825e1a44 100644 --- a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp @@ -28,16 +28,12 @@ bool canAllocBeInstrumented(triton::gpu::LocalAllocOp op) { return false; } if (llvm::all_of(op->getUsers(), [](Operation *user) { - return !isa(user); + return !isa(user); })) { return true; } if (llvm::all_of(op->getUsers(), [](Operation *user) { - auto subview = dyn_cast(user); - return subview && llvm::all_of(subview.getOffsets().drop_front(), - [](Value offset) { - return isConstantIntValue(offset, 0); - }); + return isa(user); })) { return true; } @@ -46,10 +42,10 @@ bool canAllocBeInstrumented(triton::gpu::LocalAllocOp op) { return false; } -// Interpret local_allocs that are used in ttg.memdesc_subview as multibuffered +// Interpret local_allocs that are used in ttg.memdesc_index as multibuffered bool isMultiBuffered(triton::gpu::LocalAllocOp op) { return llvm::any_of(op->getUsers(), [](Operation *user) { - return isa(user); + return isa(user); }); } diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp index 367945ea09d0..052df2a34b33 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp @@ -58,6 +58,63 @@ struct AccessRange { unsigned rankOffset = 0; }; +std::pair findBufferAccess(Value a); + +std::pair +findBufferAccessMemdescSubview(Operation *subview) { + OpBuilder builder(subview->getContext()); + Location loc = subview->getLoc(); + TypedValue src; + SmallVector shape; + SmallVector offsets; + if (auto indexOp = dyn_cast(subview)) { + src = indexOp.getSrc(); + shape = 
to_vector(indexOp.getType().getShape()); + offsets = {indexOp.getIndex()}; + for (auto i : llvm::seq(std::max(0, shape.size() - 1))) + offsets.push_back(builder.create(loc, 0, 32)); + } else { + auto subsliceOp = cast(subview); + src = subsliceOp.getSrc(); + shape = to_vector(subsliceOp.getType().getShape()); + for (auto offset : subsliceOp.getOffsets()) + offsets.push_back(builder.create(loc, offset, 32)); + } + auto [alloc, parentAccess] = findBufferAccess(src); + if (!alloc) + return {}; + // Handle subview of a subview. The first `rankOffset` access sizes are + // the same as in the parent access. + AccessRange childAccess; + for (auto i : llvm::seq(parentAccess.rankOffset)) + childAccess.ranges.push_back(parentAccess.ranges[i]); + + // The subview may have a smaller rank, in which case its access size is + // just 1 for the higher dims. + childAccess.rankOffset = src.getType().getRank() - shape.size(); + for (auto [i, offset] : llvm::enumerate(offsets)) { + auto parentRange = parentAccess.ranges[i + parentAccess.rankOffset]; + if (!parentRange) { + childAccess.ranges.push_back({}); + continue; + } + + // If the offset is not known, then the entire dim may be accessed. + APInt value; + if (!matchPattern(offset, m_ConstantInt(&value))) { + childAccess.ranges.push_back({}); + continue; + } + + uint64_t accessStart = parentRange->start() + value.getSExtValue(); + uint64_t accessSize = 1; + if (i >= childAccess.rankOffset) + accessSize = shape[i - childAccess.rankOffset]; + childAccess.ranges.push_back({{accessStart, accessStart + accessSize}}); + } + return {alloc, std::move(childAccess)}; +} + // Simple local alias analysis that looks for a single underlying allocation and // an access subrange. std::pair findBufferAccess(Value a) { @@ -90,41 +147,8 @@ std::pair findBufferAccess(Value a) { } // Subviews can reduce the access sizes. 
- if (auto subview = dyn_cast(defOp)) { - auto [alloc, parentAccess] = findBufferAccess(subview.getSrc()); - if (!alloc) - return {}; - // Handle subview of a subview. The first `rankOffset` access sizes are - // the same as in the parent access. - AccessRange childAccess; - for (auto i : llvm::seq(parentAccess.rankOffset)) - childAccess.ranges.push_back(parentAccess.ranges[i]); - - // The subview may have a smaller rank, in which case its access size is - // just 1 for the higher dims. - childAccess.rankOffset = - subview.getSrc().getType().getRank() - subview.getType().getRank(); - for (auto [i, offset] : llvm::enumerate(subview.getOffsets())) { - auto parentRange = parentAccess.ranges[i + parentAccess.rankOffset]; - if (!parentRange) { - childAccess.ranges.push_back({}); - continue; - } - - // If the offset is not known, then the entire dim may be accessed. - APInt value; - if (!matchPattern(offset, m_ConstantInt(&value))) { - childAccess.ranges.push_back({}); - continue; - } - - uint64_t accessStart = parentRange->start() + value.getSExtValue(); - uint64_t accessSize = 1; - if (i >= childAccess.rankOffset) - accessSize = subview.getType().getShape()[i - childAccess.rankOffset]; - childAccess.ranges.push_back({{accessStart, accessStart + accessSize}}); - } - return {alloc, std::move(childAccess)}; + if (isa(defOp)) { + return findBufferAccessMemdescSubview(defOp); } // Subslice is a subview only on the N dimension. 
diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp index 48f4ef873fd2..7d18772e1b70 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp @@ -120,7 +120,7 @@ static Interval getLiveIntervals(Value value, Liveness &liveness, SmallVector users(value.getUsers()); while (!users.empty()) { Operation *user = users.pop_back_val(); - if (!isa(user)) + if (!isa(user)) continue; auto usersLivness = liveness.resolveLiveness(user->getResult(0)); liveOperations.insert(liveOperations.end(), usersLivness.begin(), @@ -179,8 +179,8 @@ static Operation *getAlloc(Value value) { while (true) { if (auto allocOp = value.getDefiningOp()) return allocOp; - if (auto subviewOp = value.getDefiningOp()) { - value = subviewOp.getSrc(); + if (auto indexOp = value.getDefiningOp()) { + value = indexOp.getSrc(); continue; } if (auto reinterpOp = value.getDefiningOp()) { diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index 71c3b1cef251..0ccaafde4508 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -383,11 +383,16 @@ void init_gluon_ir(py::module &&m) { return self.create(memDesc); }) - .def("create_memdesc_subview", + .def("create_memdesc_index", [](GluonOpBuilder &self, Type resultType, Value src, - std::vector &offsets) -> Value { - return self.create(resultType, src, - offsets); + Value index) -> Value { + return self.create(resultType, src, index); + }) + .def("create_memdesc_subslice", + [](GluonOpBuilder &self, Type resultType, Value src, + std::vector &offsets) -> Value { + return self.create(resultType, src, + offsets); }) .def("create_memdesc_trans", [](GluonOpBuilder &self, Value src, diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index f53d355a3e0d..91f24f7e005f 100644 --- a/python/test/gluon/test_frontend.py +++ 
b/python/test/gluon/test_frontend.py @@ -201,9 +201,8 @@ def test_tensor_memory(fresh_knobs): %4 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc) %5 = ub.poison : i32 loc(#loc) scf.for %ivar = %2 to %3 step %4 : i32 { - %c0_i32_4 = arith.constant 0 : i32 loc(#loc) - %6 = ttg.memdesc_subview %result_2[%ivar, %c0_i32_4, %c0_i32_4] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> loc(#loc) - %result_5 = ttng.tmem_load %6 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> -> tensor<128x128xf32, #blocked> loc(#loc) + %6 = ttg.memdesc_index %result_2, %ivar : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> loc(#loc) + %result_4 = ttng.tmem_load %6 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> -> tensor<128x128xf32, #blocked> loc(#loc) } loc(#loc) tt.return loc(#loc) } loc(#loc) @@ -238,13 +237,9 @@ def test_shared_memory_subview(fresh_knobs): module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { tt.func public @shared_memory_subview_kernel() attributes {noinline = false} { %0 = ttg.local_alloc : () -> !ttg.memdesc<256x256xi32, #shared, #smem, mutable> loc(#loc) - %c0_i32 = arith.constant 0 : i32 loc(#loc) - %c128_i32 = arith.constant 128 : i32 loc(#loc) - %1 = ttg.memdesc_subview %0[%c0_i32, %c128_i32] : !ttg.memdesc<256x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi32, #shared, #smem, mutable, 256x256> loc(#loc) + %1 = ttg.memdesc_subslice %0[0, 128] : !ttg.memdesc<256x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi32, #shared, #smem, mutable, 256x256> loc(#loc) %2 = ttg.local_load %1 : !ttg.memdesc<256x128xi32, #shared, #smem, mutable, 256x256> -> tensor<256x128xi32, #blocked> loc(#loc) - %c0_i32_0 = arith.constant 0 : i32 loc(#loc) 
- %c128_i32_1 = arith.constant 128 : i32 loc(#loc) - %3 = ttg.memdesc_subview %0[%c128_i32_1, %c0_i32_0] : !ttg.memdesc<256x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<128x256xi32, #shared, #smem, mutable, 256x256> loc(#loc) + %3 = ttg.memdesc_subslice %0[128, 0] : !ttg.memdesc<256x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<128x256xi32, #shared, #smem, mutable, 256x256> loc(#loc) %4 = tt.trans %2 {order = array} : tensor<256x128xi32, #blocked> -> tensor<128x256xi32, #blocked1> loc(#loc) ttg.local_store %4, %3 : tensor<128x256xi32, #blocked1> -> !ttg.memdesc<128x256xi32, #shared, #smem, mutable, 256x256> loc(#loc) tt.return loc(#loc) @@ -284,8 +279,7 @@ def test_shared_memory_index(fresh_knobs): %3 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc) %4 = ub.poison : i32 loc(#loc) scf.for %ivar = %1 to %2 step %3 : i32 { - %c0_i32_0 = arith.constant 0 : i32 loc(#loc) - %5 = ttg.memdesc_subview %0[%ivar, %c0_i32_0] : !ttg.memdesc<4x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256> loc(#loc) + %5 = ttg.memdesc_index %0, %ivar : !ttg.memdesc<4x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256> loc(#loc) %6 = ttg.local_load %5 : !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256> -> tensor<256xi32, #blocked> loc(#loc) } loc(#loc) tt.return loc(#loc) @@ -329,8 +323,7 @@ def test_shared_memory_cast(fresh_knobs): tt.func public @shared_memory_cast_kernel() attributes {noinline = false} { %0 = ttg.local_alloc : () -> !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> %c0_i32 = arith.constant 0 : i32 - %c0_i32_0 = arith.constant 0 : i32 - %1 = ttg.memdesc_subview %0[%c0_i32_0, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128> + %1 = ttg.memdesc_index %0, %c0_i32 : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128> %2 = ttg.memdesc_trans 
%1 {order = array} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256> tt.call @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) -> () %3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> @@ -793,8 +786,7 @@ def test_tmem_index_constexpr(): tt.func public @tmem_index_kernel() attributes {noinline = false} { %result = ttng.tmem_alloc : () -> !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable> %c0_i32 = arith.constant 0 : i32 - %c0_i32_0 = arith.constant 0 : i32 - %0 = ttg.memdesc_subview %result[%c0_i32, %c0_i32_0, %c0_i32_0] : !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<256x256xi32, #tmem, #ttng.tensor_memory, mutable, 2x256x256> + %0 = ttg.memdesc_index %result, %c0_i32 : !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<256x256xi32, #tmem, #ttng.tensor_memory, mutable, 2x256x256> tt.return } } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index c9b4a6a93705..cc5f42c4c6a4 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6514,7 +6514,7 @@ def test_split_subview(M, N, M_tile_size, N_tile_size, device='cuda'): %c0_i32 = arith.constant 0 : i32 %12 = ttg.local_alloc : () -> !ttg.memdesc<1x{M}x{N}xf16, #shared, #smem, mutable> - %13 = ttg.memdesc_subview %12[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x{M}x{N}xf16, #shared, #smem, mutable> -> !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable> + %13 = ttg.memdesc_index %12, %c0_i32 : !ttg.memdesc<1x{M}x{N}xf16, #shared, #smem, mutable> -> !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable> ttg.local_store %11, %13 : tensor<{M}x{N}xf16, #blocked> -> !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable> """ @@ -6525,10 +6525,7 @@ 
def test_split_subview(M, N, M_tile_size, N_tile_size, device='cuda'): m_offset = m * M_tile_size n_offset = n * N_tile_size ir += f""" - %off0_{m}_{n} = arith.constant {m_offset} : i32 - %off1_{m}_{n} = arith.constant {n_offset} : i32 - - %view{linear_idx} = ttg.memdesc_subview %13[%off0_{m}_{n}, %off1_{m}_{n}] : !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable> -> !ttg.memdesc<{M_tile_size}x{N_tile_size}xf16, #shared, #smem, mutable, {M}x{N}> + %view{linear_idx} = ttg.memdesc_subslice %13[{m_offset}, {n_offset}] : !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable> -> !ttg.memdesc<{M_tile_size}x{N_tile_size}xf16, #shared, #smem, mutable, {M}x{N}> %data{linear_idx} = ttg.local_load %view{linear_idx} : !ttg.memdesc<{M_tile_size}x{N_tile_size}xf16, #shared, #smem, mutable, {M}x{N}> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> %inc{linear_idx} = arith.constant dense<{linear_idx}.0> : tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> diff --git a/python/triton/experimental/gluon/language/_semantic.py b/python/triton/experimental/gluon/language/_semantic.py index d9eb05f9eaf2..e50d6bf266ff 100644 --- a/python/triton/experimental/gluon/language/_semantic.py +++ b/python/triton/experimental/gluon/language/_semantic.py @@ -202,25 +202,25 @@ def set_auto_layout(self, value, layout): res_ty = ttgl.distributed_type(src_ty.element_ty, src_ty.shape, layout) return self.tensor(handle, res_ty) - def _memdesc_subview(self, mem_desc, offsets, shape): + def memdesc_slice(self, mem_desc, start, length, dim): + offsets = [0] * mem_desc.rank + offsets[dim] = start + shape = list(mem_desc.shape) + shape[dim] = length layout = mem_desc.layout ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape) builder = self.builder - handle = builder.create_memdesc_subview(ty.to_ir(builder), mem_desc.handle, offsets) + handle = builder.create_memdesc_subslice(ty.to_ir(builder), mem_desc.handle, offsets) return ttgl.shared_memory_descriptor(handle, 
**ty.__dict__) - def memdesc_slice(self, mem_desc, start, length, dim): - offsets = [self.builder.get_int32(0)] * mem_desc.rank - offsets[dim] = self.to_tensor(start).handle - shape = list(mem_desc.shape) - shape[dim] = length - return self._memdesc_subview(mem_desc, offsets, shape) - def memdesc_index(self, mem_desc, index): shape = mem_desc.shape[1:] - offsets = [self.builder.get_int32(0)] * mem_desc.rank - offsets[0] = self.to_tensor(index).handle - return self._memdesc_subview(mem_desc, offsets, shape) + index = self.to_tensor(index).handle + layout = mem_desc.layout + ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape) + builder = self.builder + handle = builder.create_memdesc_index(ty.to_ir(builder), mem_desc.handle, index) + return ttgl.shared_memory_descriptor(handle, **ty.__dict__) def memdesc_trans(self, mem_desc, order): assert len(order) == len( diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py index ca795ef354b8..bd119430bbf2 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py @@ -243,12 +243,10 @@ def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descrip """ index = _semantic.to_tensor(index) builder = _semantic.builder - offsets = [builder.get_int32(0)] * self.rank - offsets[0] = index.handle shape = self.shape[1:] layout = self.layout ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape) - ret.handle = builder.create_memdesc_subview(ret.type.to_ir(builder), self.handle, offsets) + ret.handle = builder.create_memdesc_index(ret.type.to_ir(builder), self.handle, index.handle) return ret @builtin diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 380ad08a3c1b..470ac288e584 100644 --- 
a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -62,7 +62,7 @@ tt.func @subview(%A : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory>) // expected-remark @below {{%0 -> %0}} %a = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // expected-remark @below {{%1 -> %0}} - %cst1 = ttg.memdesc_subview %a[%index, %index, %index] : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %cst1 = ttg.memdesc_index %a, %index : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return } @@ -119,7 +119,7 @@ tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : %zero = arith.constant 0 : i32 %index = arith.constant 8 : i32 // expected-remark @below {{%4 -> %0,%1}} - %cst0 = ttg.memdesc_subview %a_shared[%index, %zero] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %cst0 = ttg.memdesc_index %a_shared, %index : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable> scf.yield } scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> @@ -163,13 +163,12 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, } tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, %arg4: !tt.ptr) { - %idx = arith.constant 0 : i32 // expected-remark @below {{%0 -> %0}} %cst = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // expected-remark @below {{%1 -> %1}} %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, 
#ttg.shared_memory, mutable> // expected-remark @below {{%2 -> %0}} - %0 = ttg.memdesc_subview %cst[%idx, %idx] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %0 = ttg.memdesc_subslice %cst [0, 0] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> gpu.barrier // expected-remark @below {{%3 -> %3}} %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> @@ -187,7 +186,7 @@ tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, % ^bb3: // pred: ^bb1 gpu.barrier // expected-remark @below {{%10 -> %0}} - %9 = ttg.memdesc_subview %0[%idx, %idx] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %9 = ttg.memdesc_subslice %0 [0, 0] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return } @@ -199,9 +198,8 @@ tt.func @poison_memdesc(%arg0: i1) { %1 = ub.poison : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> cf.br ^bb2(%1 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) ^bb2(%2: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>): - %c0_i32 = arith.constant 0 : i32 // expected-remark @below {{%3 -> %0}} - %3 = ttg.memdesc_subview %2[%c0_i32, %c0_i32] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %3 = ttg.memdesc_subslice %2 [0, 0] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return } diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index a44132738023..cc72531e623c 100644 --- 
a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -326,7 +326,7 @@ tt.func @extract_slice(%A : !tt.ptr) { // expected-remark @below {{offset = 0, size = 512}} %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %index = arith.constant 0 : i32 - %cst1 = ttg.memdesc_subview %cst0[%index, %index, %index] : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %cst1 = ttg.memdesc_index %cst0, %index : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return } @@ -445,7 +445,7 @@ tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr scf.if %i1 { %zero = arith.constant 0 : i32 %index = arith.constant 8 : i32 - %cst0 = ttg.memdesc_subview %a_shared[%index, %zero] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %cst0 = ttg.memdesc_index %a_shared, %index : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable> scf.yield } scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> @@ -834,7 +834,7 @@ tt.func @aliasing_in_partition() { // expected-remark @below {{offset = 0, size = 16}} %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> %c0_i32 = arith.constant 0 : i32 - %1 = ttg.memdesc_subview %0[%c0_i32] : !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> -> !ttg.memdesc<1xi64, #A_SHARED, #smem, mutable> + %1 = ttg.memdesc_index %0, %c0_i32 : !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> -> !ttg.memdesc<1xi64, #A_SHARED, #smem, mutable> // 
expected-remark @below {{offset = 16, size = 16}} %2 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> "use"(%1) : (!ttg.memdesc<1xi64, #A_SHARED, #smem, mutable>) -> () diff --git a/test/Analysis/test-membar-ttng.mlir b/test/Analysis/test-membar-ttng.mlir index 20b62ae7ba4d..ddcfdf59c74d 100644 --- a/test/Analysis/test-membar-ttng.mlir +++ b/test/Analysis/test-membar-ttng.mlir @@ -63,12 +63,12 @@ tt.func @tma_special_cases(%arg1: !tt.tensordesc>) - ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> - // CHECK-NEXT: memdesc_subview + // CHECK-NEXT: memdesc_subslice // CHECK-NEXT: ttng.barrier_expect // CHECK-NEXT: ttng.async_tma_gather // CHECK-NEXT: gpu.barrier // CHECK-NEXT: ttng.wait_barrier - %view = ttg.memdesc_subview %alloc[%c0, %c0] : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x64xf16, #shared, #ttg.shared_memory, mutable> + %view = ttg.memdesc_subslice %alloc [0, 0] : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x64xf16, #shared, #ttg.shared_memory, mutable> ttng.barrier_expect %barrier, 49152, %true : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> ttng.async_tma_gather %arg1[%cx, %c0] %view, %barrier, %true : !tt.tensordesc>, tensor<32xi32>, i32, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<32x64xf16, #shared, #ttg.shared_memory, mutable>, i1 ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index deb1cdb63dde..790cb9f33a27 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -119,8 +119,7 @@ tt.func 
@async_wait(%arg: tensor<32x16xf16, #AL>) { tt.func @subview() { %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> %a = ttg.local_alloc %cst0 : (tensor<32x16xf16, #AL>) -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> - %index = arith.constant 0 : i32 - %0 = ttg.memdesc_subview %a[%index, %index] : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %0 = ttg.memdesc_subslice %a [0, 0] : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: ttg.local_load %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> @@ -144,7 +143,7 @@ tt.func @async_copy_global_to_local(%A : !tt.ptr, %i1 : i1) { %mask = tt.splat %i1 : i1 -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> %alloc = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> - %subview = ttg.memdesc_subview %alloc[%index, %index, %index] : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %subview = ttg.memdesc_index %alloc, %index : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %1 = ttg.async_copy_global_to_local %a_ptr, %subview : tensor<16x16x!tt.ptr, #AL> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: gpu.barrier // CHECK-NEXT: ttg.local_load @@ -879,7 +878,7 @@ tt.func @membar_alias_through_warp_specialize() { // CHECK: partition0 partition0(%arg0: !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>) num_warps(2) { %c0 = arith.constant 0 : i32 - %1 = ttg.memdesc_subview %arg0[%c0, %c0] : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared, 
#ttg.shared_memory, mutable> + %1 = ttg.memdesc_subslice %arg0 [0, 0] : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> %c = arith.constant dense<0.0> : tensor<16x16xf16> // CHECK: local_store ttg.local_store %c, %1 : tensor<16x16xf16> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> @@ -891,7 +890,7 @@ tt.func @membar_alias_through_warp_specialize() { // CHECK: partition1 partition1(%arg0: !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>) num_warps(2) { %c0 = arith.constant 0 : i32 - %1 = ttg.memdesc_subview %arg0[%c0, %c0] : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> + %1 = ttg.memdesc_subslice %arg0 [0, 0] : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> %c = arith.constant dense<0.0> : tensor<16x16xf16> // CHECK: local_store ttg.local_store %c, %1 : tensor<16x16xf16> -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> diff --git a/test/Conversion/amd/amdgpu_membar.mlir b/test/Conversion/amd/amdgpu_membar.mlir index 57a5bea9934d..0828a4061a16 100644 --- a/test/Conversion/amd/amdgpu_membar.mlir +++ b/test/Conversion/amd/amdgpu_membar.mlir @@ -11,8 +11,8 @@ tt.func @pipelined_async_copy_local_to_global(%A: !tt.ptr) { %index_1 = arith.constant 1 : i32 %a_ptr = tt.splat %A : !tt.ptr -> tensor<16x16x!tt.ptr, #AL> %alloc = ttg.local_alloc : () -> !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> - %tile_a = ttg.memdesc_subview %alloc[%index_0, %index_0, %index_0] : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> - %tile_b = ttg.memdesc_subview %alloc[%index_1, %index_0, %index_0] : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, 
#ttg.shared_memory, mutable> + %tile_a = ttg.memdesc_index %alloc, %index_0 : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %tile_b = ttg.memdesc_index %alloc, %index_1 : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // Load TileA %1 = ttg.async_copy_global_to_local %a_ptr, %tile_a: tensor<16x16x!tt.ptr, #AL> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // Wait for TileA @@ -36,8 +36,8 @@ tt.func @pipelined_async_copy_local_to_global_2(%A: !tt.ptr) { %index_1 = arith.constant 1 : i32 %a_ptr = tt.splat %A : !tt.ptr -> tensor<16x16x!tt.ptr, #AL> %alloc = ttg.local_alloc : () -> !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> - %tile_a = ttg.memdesc_subview %alloc[%index_0, %index_0, %index_0] : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> - %tile_b = ttg.memdesc_subview %alloc[%index_1, %index_0, %index_0] : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %tile_a = ttg.memdesc_index %alloc, %index_0 : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %tile_b = ttg.memdesc_index %alloc, %index_1 : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // Load Tile %1 = ttg.async_copy_global_to_local %a_ptr, %tile_a: tensor<16x16x!tt.ptr, #AL> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // Wait for TileA @@ -63,12 +63,12 @@ tt.func @pipelined_async_copy_local_to_global_3(%A: !tt.ptr, %B: !tt.ptr -> tensor<16x16x!tt.ptr, #AL> %alloc_a = ttg.local_alloc : () -> !ttg.memdesc<2x16x16xf16, #A_SHARED, 
#ttg.shared_memory, mutable> - %tile_a_1 = ttg.memdesc_subview %alloc_a[%index_0, %index_0, %index_0] : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> - %tile_a_2 = ttg.memdesc_subview %alloc_a[%index_1, %index_0, %index_0] : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %tile_a_1 = ttg.memdesc_index %alloc_a, %index_0 : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %tile_a_2 = ttg.memdesc_index %alloc_a, %index_1 : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %alloc_b = ttg.local_alloc : () -> !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> - %tile_b_1 = ttg.memdesc_subview %alloc_b[%index_0, %index_0, %index_0] : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> - %tile_b_2 = ttg.memdesc_subview %alloc_b[%index_1, %index_0, %index_0] : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %tile_b_1 = ttg.memdesc_index %alloc_b, %index_0 : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %tile_b_2 = ttg.memdesc_index %alloc_b, %index_1 : !ttg.memdesc<2x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // Load TileA_1 %1 = ttg.async_copy_global_to_local %a_ptr, %tile_a_1: tensor<16x16x!tt.ptr, #AL> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index 
795e9c99964a..b308caf04e7e 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -373,7 +373,7 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n %c16_i32 = arith.constant 16 : i32 // CHECK-COUNT-16: llvm.store {{.*}} : vector<1xf16>, !llvm.ptr<3> %0 = ttg.local_alloc %arg0 : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> - %1 = ttg.memdesc_subview %0[%c0_i32, %c16_i32] : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 64x64> + %1 = ttg.memdesc_subslice %0 [0, 16] : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 64x64> // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16> %2 = ttg.local_load %1 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 64x64> -> tensor<64x16xf16, #blocked> // CHECK-COUNT-4: llvm.store {{.*}} : vector<1xf16>, !llvm.ptr<3> @@ -438,7 +438,7 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n // CHECK-NEXT: %[[ADD2:.+]] = llvm.add %[[ADD]], %[[ADD1]] : i32 // CHECK: llvm.getelementptr inbounds %{{.+}}[%[[ADD2]]] - %1 = ttg.memdesc_subview %arg0[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> + %1 = ttg.memdesc_index %arg0, %c1_i32 : !ttg.memdesc<2x64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> %2 = ttg.local_load %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked> ttg.local_store %2, %1 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> tt.return diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index daa9ce4fdb7e..86118319d565 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -559,20 +559,12 @@ 
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: llvm.mlir.constant(1 : i32) : i32 // CHECK-NEXT: llvm.mlir.constant(32 : i32) : i32 // CHECK-NEXT: llvm.mlir.constant(512 : i32) : i32 - // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 // CHECK-NEXT: llvm.mul - // CHECK-NEXT: llvm.add - // CHECK-NEXT: llvm.mul - // CHECK-NEXT: llvm.add - // CHECK-NEXT: llvm.mul - // CHECK-NEXT: llvm.add - // CHECK-NEXT: llvm.add - // CHECK-NEXT: llvm.add // CHECK-NEXT: llvm.getelementptr %index = arith.constant 1 : i32 %zero = arith.constant 0 : i32 %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16x32xf32, #shared0, #smem, mutable> - %1 = ttg.memdesc_subview %0[%index, %zero, %zero] : !ttg.memdesc<128x16x32xf32, #shared0, #smem, mutable> -> !ttg.memdesc<16x32xf32, #shared0, #smem, mutable> + %1 = ttg.memdesc_index %0, %index : !ttg.memdesc<128x16x32xf32, #shared0, #smem, mutable> -> !ttg.memdesc<16x32xf32, #shared0, #smem, mutable> tt.return } } @@ -605,7 +597,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { %59 = tt.addptr %58, %24 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> %66 = tt.addptr %59, %cst_2 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> %71 = ttg.local_alloc : () -> !ttg.memdesc<2x64xi64, #shared2D, #smem, mutable> - %subview = ttg.memdesc_subview %71[%c0_i32, %c0_i32] : + %subview = ttg.memdesc_index %71, %c0_i32 : !ttg.memdesc<2x64xi64, #shared2D, #smem, mutable> -> !ttg.memdesc<64xi64, #shared1D, #smem, mutable> // CHECK: llvm.inline_asm has_side_effects asm_dialect = att @@ -2039,7 +2031,7 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num- tt.func public @test_local_load_bf16() { %c0_i32 = arith.constant 0 : i32 %19 = ttg.local_alloc : () -> !ttg.memdesc<1x1x2048xbf16, #shared, #smem, mutable> - %22 = ttg.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x1x2048xbf16, #shared, #smem, mutable> -> !ttg.memdesc<1x2048xbf16, 
#shared, #smem, mutable> + %22 = ttg.memdesc_index %19, %c0_i32 : !ttg.memdesc<1x1x2048xbf16, #shared, #smem, mutable> -> !ttg.memdesc<1x2048xbf16, #shared, #smem, mutable> %39 = ttg.local_load %22 : !ttg.memdesc<1x2048xbf16, #shared, #smem, mutable> -> tensor<1x2048xbf16, #blocked> %40 = arith.extf %39 : tensor<1x2048xbf16, #blocked> to tensor<1x2048xf32, #blocked> tt.return @@ -2090,7 +2082,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.func public @test_local_store_subview(%arg0: tensor<1xf32, #blocked>) { %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xf32, #shared, #smem, mutable> - %sv = ttg.memdesc_subview %0[%c0_i32] : !ttg.memdesc<1xf32, #shared, #smem, mutable> -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + %sv = ttg.memdesc_index %0, %c0_i32 : !ttg.memdesc<1xf32, #shared, #smem, mutable> -> !ttg.memdesc<1xf32, #shared, #smem, mutable> ttg.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !ttg.memdesc<1xf32, #shared, #smem, mutable> tt.return } diff --git a/test/NVWS/lower_aref.mlir b/test/NVWS/lower_aref.mlir index 7c1ba32e8659..89db9dc0ba17 100644 --- a/test/NVWS/lower_aref.mlir +++ b/test/NVWS/lower_aref.mlir @@ -19,9 +19,9 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w // CHECK: [[EMPTY0:%.*]] = ttg.local_alloc // CHECK-NEXT: [[FULL0:%.*]] = ttg.local_alloc // CHECK-NEXT: scf.for - // CHECK-NEXT: [[EMPTYSLICE:%.*]] = ttg.memdesc_subview [[EMPTY0]] + // CHECK-NEXT: [[EMPTYSLICE:%.*]] = ttg.memdesc_index [[EMPTY0]] // CHECK-NEXT: ttng.init_barrier [[EMPTYSLICE]], 1 - // CHECK-NEXT: [[FULLSLICE:%.*]] = ttg.memdesc_subview [[FULL0]] + // CHECK-NEXT: [[FULLSLICE:%.*]] = ttg.memdesc_index [[FULL0]] // CHECK-NEXT: ttng.init_barrier [[FULLSLICE]], 129 // CHECK-NEXT: } %aref0 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> @@ -29,9 +29,9 @@ 
module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w // CHECK: [[EMPTY1:%.*]] = ttg.local_alloc // CHECK-NEXT: [[FULL1:%.*]] = ttg.local_alloc // CHECK-NEXT: scf.for - // CHECK-NEXT: [[EMPTYSLICE:%.*]] = ttg.memdesc_subview [[EMPTY1]] + // CHECK-NEXT: [[EMPTYSLICE:%.*]] = ttg.memdesc_index [[EMPTY1]] // CHECK-NEXT: ttng.init_barrier [[EMPTYSLICE]], 256 - // CHECK-NEXT: [[FULLSLICE:%.*]] = ttg.memdesc_subview [[FULL1]] + // CHECK-NEXT: [[FULLSLICE:%.*]] = ttg.memdesc_index [[FULL1]] // CHECK-NEXT: ttng.init_barrier [[FULLSLICE]], 128 // CHECK-NEXT: } %aref1 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> @@ -42,7 +42,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w scf.for %i = %lb to %ub step %c1_i32 : i32{ // CHECK-NEXT: [[EMPTYIDX:%.*]] = arith.remsi [[IDX0]], [[C3]] - // CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_subview [[EMPTY0]][[[EMPTYIDX]]] + // CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_index [[EMPTY0]], [[EMPTYIDX]] // CHECK-NEXT: [[PHASE_DIV:%.*]] = arith.divsi [[IDX0]], [[C3]] // CHECK-NEXT: [[PHASE_AND:%.*]] = arith.andi [[PHASE_DIV]], [[C1]] // CHECK-NEXT: [[PHASE_XOR:%.*]] = arith.xori [[PHASE_AND]], [[C1]] @@ -50,10 +50,10 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w %1:2 = nvws.aref.put.enter %aref0[%c0_i32] {aref_tag = "put0"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem> // CHECK-NEXT: [[STAGE:%.*]] = arith.remsi [[IDX0]], [[C3]] - // CHECK-NEXT: [[BUFA:%.*]] = ttg.memdesc_subview %arg0[[[STAGE]],{{.*}},{{.*}}] - // CHECK-NEXT: [[BUFB:%.*]] = ttg.memdesc_subview %arg1[[[STAGE]],{{.*}},{{.*}}] + // CHECK-NEXT: [[BUFA:%.*]] = ttg.memdesc_index %arg0, [[STAGE]] + // CHECK-NEXT: [[BUFB:%.*]] = ttg.memdesc_index %arg1, 
[[STAGE]] // CHECK-NEXT: [[FULLIDX:%.*]] = arith.remsi [[IDX2]], [[C3]] - // CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_subview [[FULL0]][[[FULLIDX]]] + // CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_index [[FULL0]], [[FULLIDX]] // CHECK-NEXT: ttng.barrier_expect [[FULLMBAR]], 0 // CHECK-NEXT: [[IDX0a:%.*]] = arith.addi [[IDX0]], [[C1]] // CHECK-NEXT: "tma_load"([[BUFA]]) @@ -62,7 +62,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w "cp_async"(%1#1) : (!ttg.memdesc<16x32xf16, #shared0, #smem>) -> () // CHECK-NEXT: [[FULLIDX:%.*]] = arith.remsi [[IDX2]], [[C3]] - // CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_subview [[FULL0]][[[FULLIDX]]] + // CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_index [[FULL0]], [[FULLIDX]] // CHECK-NEXT: nvws.async_complete [[FULLMBAR]], async_op = // CHECK-NEXT: nvws.async_complete [[FULLMBAR]], async_op = // CHECK-NEXT: [[IDX2a:%.*]] = arith.addi [[IDX2]], [[C1]] @@ -121,21 +121,21 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w scf.for %i = %lb to %ub step %c1_i32 : i32{ // CHECK-NEXT: [[FULLIDX:%.*]] = arith.remsi [[IDX0]], [[C3]] - // CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_subview [[FULL0]][[[FULLIDX]]] + // CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_index [[FULL0]], [[FULLIDX]] // CHECK-NEXT: [[PHASE_DIV:%.*]] = arith.divsi [[IDX0]], [[C3]] // CHECK-NEXT: [[PHASE_AND:%.*]] = arith.andi [[PHASE_DIV]], [[C1]] // CHECK-NEXT: ttng.wait_barrier [[FULLMBAR]], [[PHASE_AND]] %2:2 = nvws.aref.get.enter %aref0[%c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem> // CHECK-NEXT: [[STAGE:%.*]] = arith.remsi [[IDX0]], [[C3]] - // CHECK-NEXT: [[BUFA:%.*]] = ttg.memdesc_subview %arg0[[[STAGE]],{{.*}},{{.*}}] - // CHECK-NEXT: [[BUFB:%.*]] = ttg.memdesc_subview %arg1[[[STAGE]],{{.*}},{{.*}}] + // CHECK-NEXT: [[BUFA:%.*]] = 
ttg.memdesc_index %arg0, [[STAGE]] + // CHECK-NEXT: [[BUFB:%.*]] = ttg.memdesc_index %arg1, [[STAGE]] // CHECK-NEXT: arith.addi // CHECK-NEXT: "tc5mma"([[BUFA]], [[BUFB]]) "tc5mma"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () // CHECK-NEXT: [[EMPTYIDX:%.*]] = arith.remsi [[IDX2]], [[C3]] - // CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_subview [[EMPTY0]][[[EMPTYIDX]]] + // CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_index [[EMPTY0]], [[EMPTYIDX]] // CHECK-NEXT: nvws.async_complete [[EMPTYMBAR]], async_op = // CHECK-NEXT: arith.addi nvws.aref.get.exit %aref0[%c0_i32] [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> @@ -150,7 +150,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w "tmem_load"(%3#0, %3#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () // CHECK: arith.remsi [[IDX3]], [[C3]] - // CHECK-NEXT: ttg.memdesc_subview + // CHECK-NEXT: ttg.memdesc_index // CHECK-NEXT: nvws.async_complete {{.*}}, async_op = nvws.aref.get.exit %aref1[%c0_i32] [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> } diff --git a/test/TritonGPU/amd/amd-block-pingpong.mlir b/test/TritonGPU/amd/amd-block-pingpong.mlir index 2fd8c9ae066d..c076ac92ea47 100644 --- a/test/TritonGPU/amd/amd-block-pingpong.mlir +++ b/test/TritonGPU/amd/amd-block-pingpong.mlir @@ -51,8 +51,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> 
-> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> %25:6 = scf.for %arg5 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg6 = %cst, %arg7 = %13, %arg8 = %20, %arg9 = %c0_i32, %arg10 = %23, %arg11 = %24) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) : i32 { %26 = tt.addptr %arg7, %cst_1 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %27 = tt.load %26 : tensor<128x64x!tt.ptr, #blocked1> @@ -65,9 +65,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %34 = arith.addi %arg9, %c1_i32 : i32 %35 = arith.cmpi slt, %34, %c1_i32 : i32 %36 = arith.select %35, %34, %c0_i32 : i32 - %37 = ttg.memdesc_subview %21[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %37 = ttg.memdesc_index %21, %36 : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> ttg.local_store %27, %37 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %38 = ttg.memdesc_subview %22[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> + %38 = ttg.memdesc_index %22, %36 : !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> ttg.local_store %29, 
%38 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> scf.yield %33, %26, %28, %36, %37, %38 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> } @@ -165,8 +165,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> %25:6 = scf.for %arg5 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg6 = %cst, %arg7 = %13, %arg8 = %20, %arg9 = %c0_i32, %arg10 = %23, %arg11 = %24) -> (tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { %26 = tt.addptr %arg7, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> @@ 
-178,9 +178,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %33 = arith.addi %arg9, %c1_i32 : i32 %34 = arith.cmpi slt, %33, %c1_i32 : i32 %35 = arith.select %34, %33, %c0_i32 : i32 - %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %36 = ttg.memdesc_index %21, %35 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_index %22, %35 : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %29, %37 : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> } @@ -264,8 +264,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, 
#ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> %25:6 = scf.for %arg5 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg6 = %cst, %arg7 = %13, %arg8 = %20, %arg9 = %c0_i32, %arg10 = %23, %arg11 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { %26 = tt.addptr %arg7, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> @@ -277,9 +277,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %33 = arith.addi %arg9, %c1_i32 : i32 %34 = arith.cmpi slt, %33, %c1_i32 : i32 %35 = arith.select %34, %33, %c0_i32 : i32 - %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %36 = ttg.memdesc_index %21, %35 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, 
mutable> + %37 = ttg.memdesc_index %22, %35 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %29, %37 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> } @@ -332,8 +332,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> %25:6 = scf.for %arg5 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg6 = %cst, %arg7 = %13, %arg8 = %20, %arg9 = %c0_i32, %arg10 = %23, %arg11 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, 
mutable>, !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>) : i32 { %26 = tt.addptr %arg7, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> @@ -347,9 +347,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %33 = arith.addi %arg9, %c1_i32 : i32 %34 = arith.cmpi slt, %33, %c1_i32 : i32 %35 = arith.select %34, %33, %c0_i32 : i32 - %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %36 = ttg.memdesc_index %21, %35 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_index %22, %35 : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %cast2, %37 : tensor<64x128xi16, #blocked> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> } @@ -405,8 +405,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x16xf16, #shared, #ttg.shared_memory, mutable> %22 = 
ttg.local_alloc : () -> !ttg.memdesc<1x16x256xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x256x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x16x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable> %25:6 = scf.for %arg5 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg6 = %cst, %arg7 = %13, %arg8 = %20, %arg9 = %c0_i32, %arg10 = %23, %arg11 = %24) -> (tensor<256x256xf32, #mma>, tensor<256x16x!tt.ptr, #blocked1>, tensor<16x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { %26 = tt.addptr %arg7, %cst_1 : tensor<256x16x!tt.ptr, #blocked1>, tensor<256x16xi32, #blocked1> %27 = tt.load %26 : tensor<256x16x!tt.ptr, #blocked1> @@ -418,9 +418,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %33 = arith.addi %arg9, %c1_i32 : i32 %34 = arith.cmpi slt, %33, %c1_i32 : i32 %35 = arith.select %34, %33, %c0_i32 : i32 - %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable> + %36 = ttg.memdesc_index %21, %35 : !ttg.memdesc<1x256x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %36 : tensor<256x16xf16, 
#blocked1> -> !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable> - %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_index %22, %35 : !ttg.memdesc<1x16x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %29, %37 : tensor<16x256xf16, #blocked> -> !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x256xf32, #mma>, tensor<256x16x!tt.ptr, #blocked1>, tensor<16x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable> } @@ -472,8 +472,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> %25:6 = scf.for %arg5 = %c0_i32 to %c64_i32 step 
%c1_i32 iter_args(%arg6 = %cst, %arg7 = %13, %arg8 = %20, %arg9 = %c0_i32, %arg10 = %23, %arg11 = %24) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { %26 = arith.cmpi eq, %arg5, %c0_i32: i32 %27 = scf.if %26 -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> { @@ -481,7 +481,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %29 = tt.broadcast %28 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %30 = tt.load %29 : tensor<128x64x!tt.ptr, #blocked1> %31 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> - %32 = ttg.memdesc_subview %31[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %32 = ttg.memdesc_index %31, %c0_i32 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %30, %32 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> %33 = ttg.local_load %32 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> scf.yield %33 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> @@ -499,9 +499,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %42 = arith.addi %arg9, %c1_i32 : i32 %43 = arith.cmpi slt, %42, %c1_i32 : i32 %44 = arith.select %43, %42, %c0_i32 : i32 - %45 = ttg.memdesc_subview %21[%44, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %45 = 
ttg.memdesc_index %21, %44 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %35, %45 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> - %46 = ttg.memdesc_subview %22[%44, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %46 = ttg.memdesc_index %22, %44 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %37, %46 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %41, %34, %36, %44, %45, %46 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> } @@ -588,8 +588,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = 
ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> %25:6 = scf.for %arg5 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg6 = %cst, %arg7 = %13, %arg8 = %20, %arg9 = %c0_i32, %arg10 = %23, %arg11 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { %26 = tt.addptr %arg7, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> @@ -602,9 +602,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %34 = arith.addi %arg9, %c1_i32 : i32 %35 = arith.cmpi slt, %34, %c1_i32 : i32 %36 = arith.select %35, %34, %c0_i32 : i32 - %37 = ttg.memdesc_subview %21[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_index %21, %36 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %37 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %38 = ttg.memdesc_subview %22[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %38 = ttg.memdesc_index %22, %36 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %29, %38 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %33, %26, %28, %36, %37, %38 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, 
tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> } @@ -705,8 +705,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> %25:6 = scf.for %arg5 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg6 = %cst, %arg7 = %13, %arg8 = %20, %arg9 = %c0_i32, %arg10 = %23, %arg11 = %24) -> (tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { %26 = tt.addptr %arg7, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> @@ -719,9 +719,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %34 = arith.addi %arg9, %c1_i32 : i32 %35 = arith.cmpi slt, %34, 
%c1_i32 : i32 %36 = arith.select %35, %34, %c0_i32 : i32 - %37 = ttg.memdesc_subview %21[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_index %21, %36 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %37 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %38 = ttg.memdesc_subview %22[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + %38 = ttg.memdesc_index %22, %36 : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %29, %38 : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %33, %26, %28, %36, %37, %38 : tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> } @@ -780,8 +780,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> 
!ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> %25:6 = scf.for %arg5 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg6 = %cst, %arg7 = %13, %arg8 = %20, %arg9 = %c0_i32, %arg10 = %23, %arg11 = %24) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { // This swaps the assumption on the ordering of the local load and // global load from the base test to ensure the one ping pong cluster @@ -796,9 +796,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %33 = arith.addi %arg9, %c1_i32 : i32 %34 = arith.cmpi slt, %33, %c1_i32 : i32 %35 = arith.select %34, %33, %c0_i32 : i32 - %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %36 = ttg.memdesc_index %21, %35 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %29, %36 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> - %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_index %22, %35 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, 
mutable> ttg.local_store %31, %37 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %32, %28, %30, %35, %36, %37 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> } @@ -860,8 +860,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> %25:6 = scf.for %arg5 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg6 = %cst, %arg7 = %13, %arg8 = %20, %arg9 = %c0_i32, %arg10 = %23, %arg11 = %24) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { %26 = tt.addptr %arg7, %cst_1 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, 
#blocked1> %27 = tt.load %26 : tensor<128x64x!tt.ptr, #blocked1> @@ -874,9 +874,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %34 = arith.addi %arg9, %c1_i32 : i32 %35 = arith.cmpi slt, %34, %c1_i32 : i32 %36 = arith.select %35, %34, %c0_i32 : i32 - %37 = ttg.memdesc_subview %21[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_index %21, %36 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %37 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> - %38 = ttg.memdesc_subview %22[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %38 = ttg.memdesc_index %22, %36 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %29, %38 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %33, %26, %28, %36, %37, %38 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> } @@ -944,9 +944,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %11 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable> %12 = tt.load %2 : tensor<256x64x!tt.ptr, #blocked1> %13 = tt.load %8 : tensor<64x128x!tt.ptr, #blocked> - %14 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable> + %14 = 
ttg.memdesc_index %10, %c0_i32 : !ttg.memdesc<1x256x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable> ttg.local_store %12, %14 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable> - %15 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> + %15 = ttg.memdesc_index %11, %c0_i32 : !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> ttg.local_store %13, %15 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> %16:6 = scf.for %arg3 = %c0_i32 to %c192_i32 step %c64_i32 iter_args(%arg4 = %c0_i64, %arg5 = %c0_i64, %arg6 = %cst, %arg7 = %c0_i32, %arg8 = %14, %arg9 = %15) -> (i64, i64, tensor<256x128xf32, #mma>, i32, !ttg.memdesc<256x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) : i32 { %22 = arith.addi %arg4, %c64_i64 : i64 @@ -972,9 +972,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %42 = arith.addi %arg7, %c1_i32 : i32 %43 = arith.cmpi slt, %42, %c1_i32 : i32 %44 = arith.select %43, %42, %c0_i32 : i32 - %45 = ttg.memdesc_subview %10[%44, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable> + %45 = ttg.memdesc_index %10, %44 : !ttg.memdesc<1x256x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable> ttg.local_store %30, %45 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable> - %46 = ttg.memdesc_subview %11[%44, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> + %46 = ttg.memdesc_index %11, %44 : !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> ttg.local_store %39, %46 
: tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> scf.yield %22, %23, %41, %44, %45, %46 : i64, i64, tensor<256x128xf32, #mma>, i32, !ttg.memdesc<256x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> } @@ -1067,8 +1067,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> %25:6 = scf.for %arg4 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg5 = %cst, %arg6 = %13, %arg7 = %20, %arg8 = %c0_i32, %arg9 = %23, %arg10 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { %26 = tt.addptr %arg6, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> @@ -1080,9 +1080,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, 
"ttg.num-warps" = 8 : i32, ttg.targ %33 = arith.addi %arg8, %c1_i32 : i32 %34 = arith.cmpi slt, %33, %c1_i32 : i32 %35 = arith.select %34, %33, %c0_i32 : i32 - %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %36 = ttg.memdesc_index %21, %35 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_index %22, %35 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %29, %37 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> %38 = arith.cmpi eq, %arg4, %c63_i32: i32 %39 = scf.if %38 -> tensor<256x128xf32, #mma> { @@ -1190,8 +1190,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + %23 = 
ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> %25:6 = scf.for %arg4 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg5 = %cst, %arg6 = %13, %arg7 = %20, %arg8 = %c0_i32, %arg9 = %23, %arg10 = %24) -> (tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { %26 = tt.addptr %arg6, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> @@ -1203,9 +1203,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %33 = arith.addi %arg8, %c1_i32 : i32 %34 = arith.cmpi slt, %33, %c1_i32 : i32 %35 = arith.select %34, %33, %c0_i32 : i32 - %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %36 = ttg.memdesc_index %21, %35 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_index %22, %35 : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %29, %37 : tensor<64x256xf16, 
#blocked> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> %38 = arith.cmpi eq, %arg4, %c63_i32: i32 %39 = scf.if %38 -> tensor<256x256xf32, #mma> { @@ -1266,8 +1266,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> %25 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> %26 = tt.broadcast %25 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x128x!tt.ptr, #mma> %27 = tt.load %26: tensor<128x128x!tt.ptr, #mma> @@ -1280,16 +1280,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %34 = ttg.local_load %arg11 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> %35 = tt.dot %33, %34, %arg6 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x128xf32, #mma> %36 = ttg.local_alloc : () -> 
!ttg.memdesc<1x128x128xf32, #shared, #ttg.shared_memory, mutable> - %37 = ttg.memdesc_subview %36[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf32, #shared, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_index %36, %c0_i32 : !ttg.memdesc<1x128x128xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf32, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %37 : tensor<128x128xf32, #mma> -> !ttg.memdesc<128x128xf32, #shared, #ttg.shared_memory, mutable> %38 = ttg.local_load %37 : !ttg.memdesc<128x128xf32, #shared, #ttg.shared_memory, mutable> -> tensor<128x128xf32, #mma> %39 = arith.addf %35, %38: tensor<128x128xf32, #mma> %40 = arith.addi %arg9, %c1_i32 : i32 %41 = arith.cmpi slt, %40, %c1_i32 : i32 %42 = arith.select %41, %40, %c0_i32 : i32 - %43 = ttg.memdesc_subview %21[%42, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %43 = ttg.memdesc_index %21, %42 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %30, %43 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> - %44 = ttg.memdesc_subview %22[%42, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %44 = ttg.memdesc_index %22, %42 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %32, %44 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %39, %29, %31, %42, %43, %44: tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, 
#ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> } @@ -1352,8 +1352,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> %25 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> %26 = tt.broadcast %25 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x128x!tt.ptr, #mma> %27 = tt.load %26: tensor<128x128x!tt.ptr, #mma> @@ -1374,7 +1374,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %38 = arith.cmpi eq, %30, %c63_i32: i32 %39 = scf.if %38 -> tensor<128x128xf32, #mma> { %40 = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf32, #shared, #ttg.shared_memory, mutable> - %41 = ttg.memdesc_subview %40[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf32, #shared, #ttg.shared_memory, mutable> + %41 = ttg.memdesc_index %40, %c0_i32 : !ttg.memdesc<1x128x128xf32, #shared, #ttg.shared_memory, mutable> -> 
!ttg.memdesc<128x128xf32, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %41 : tensor<128x128xf32, #mma> -> !ttg.memdesc<128x128xf32, #shared, #ttg.shared_memory, mutable> %42 = ttg.local_load %41 : !ttg.memdesc<128x128xf32, #shared, #ttg.shared_memory, mutable> -> tensor<128x128xf32, #mma> %43 = arith.addf %37, %42: tensor<128x128xf32, #mma> @@ -1385,9 +1385,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %44 = arith.addi %arg9, %c1_i32 : i32 %45 = arith.cmpi slt, %44, %c1_i32 : i32 %46 = arith.select %45, %44, %c0_i32 : i32 - %47 = ttg.memdesc_subview %21[%46, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %47 = ttg.memdesc_index %21, %46 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %32, %47 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> - %48 = ttg.memdesc_subview %22[%46, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %48 = ttg.memdesc_index %22, %46 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %34, %48 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %39, %31, %33, %46, %47, %48: tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> } @@ -1473,8 +1473,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, 
#blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> %25 = tt.splat %arg2 : !tt.ptr -> tensor<256x1x!tt.ptr, #mma> %26 = tt.broadcast %25 : tensor<256x1x!tt.ptr, #mma> -> tensor<256x128x!tt.ptr, #mma> %27 = tt.load %26: tensor<256x128x!tt.ptr, #mma> @@ -1495,7 +1495,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %38 = arith.cmpi eq, %30, %c63_i32: i32 %39 = scf.if %38 -> tensor<256x128xf32, #mma> { %40 = ttg.local_alloc : () -> !ttg.memdesc<1x256x128xf32, #shared, #ttg.shared_memory, mutable> - %41 = ttg.memdesc_subview %40[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x128xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> + %41 = ttg.memdesc_index %40, %c0_i32 : !ttg.memdesc<1x256x128xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %41 : tensor<256x128xf32, #mma> -> !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> %42 = ttg.local_load %41 : !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> -> 
tensor<256x128xf32, #mma> %43 = arith.addf %37, %42: tensor<256x128xf32, #mma> @@ -1506,9 +1506,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %44 = arith.addi %arg9, %c1_i32 : i32 %45 = arith.cmpi slt, %44, %c1_i32 : i32 %46 = arith.select %45, %44, %c0_i32 : i32 - %47 = ttg.memdesc_subview %21[%46, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %47 = ttg.memdesc_index %21, %46 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %32, %47 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %48 = ttg.memdesc_subview %22[%46, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %48 = ttg.memdesc_index %22, %46 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %34, %48 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %39, %31, %33, %46, %47, %48: tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> } @@ -1609,8 +1609,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : 
!ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> %25 = tt.splat %arg2 : !tt.ptr -> tensor<256x1x!tt.ptr, #mma> %26 = tt.broadcast %25 : tensor<256x1x!tt.ptr, #mma> -> tensor<256x256x!tt.ptr, #mma> %27 = tt.load %26: tensor<256x256x!tt.ptr, #mma> @@ -1631,7 +1631,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %38 = arith.cmpi eq, %30, %c63_i32: i32 %39 = scf.if %38 -> tensor<256x256xf32, #mma> { %40 = ttg.local_alloc : () -> !ttg.memdesc<1x256x256xf32, #shared, #ttg.shared_memory, mutable> - %41 = ttg.memdesc_subview %40[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x256xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x256xf32, #shared, #ttg.shared_memory, mutable> + %41 = ttg.memdesc_index %40, %c0_i32 : !ttg.memdesc<1x256x256xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x256xf32, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %41 : tensor<256x256xf32, #mma> -> !ttg.memdesc<256x256xf32, #shared, #ttg.shared_memory, mutable> %42 = ttg.local_load %41 : !ttg.memdesc<256x256xf32, #shared, #ttg.shared_memory, mutable> -> tensor<256x256xf32, #mma> %43 = arith.addf %37, %42: tensor<256x256xf32, #mma> @@ -1642,9 +1642,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %44 = arith.addi %arg9, %c1_i32 : i32 %45 = arith.cmpi slt, %44, %c1_i32 : i32 
%46 = arith.select %45, %44, %c0_i32 : i32 - %47 = ttg.memdesc_subview %21[%46, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %47 = ttg.memdesc_index %21, %46 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %32, %47 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %48 = ttg.memdesc_subview %22[%46, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + %48 = ttg.memdesc_index %22, %46 : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %34, %48 : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %39, %31, %33, %46, %47, %48: tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> } @@ -1698,8 +1698,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> 
!ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> %25 = tt.splat %arg2 : !tt.ptr -> tensor<256x1x!tt.ptr, #mma> %26 = tt.broadcast %25 : tensor<256x1x!tt.ptr, #mma> -> tensor<256x128x!tt.ptr, #mma> %27 = tt.load %26: tensor<256x128x!tt.ptr, #mma> @@ -1722,7 +1722,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ scf.yield %37 : tensor<256x128xf32, #mma> } else { %40 = ttg.local_alloc : () -> !ttg.memdesc<1x256x128xf32, #shared, #ttg.shared_memory, mutable> - %41 = ttg.memdesc_subview %40[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x128xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> + %41 = ttg.memdesc_index %40, %c0_i32 : !ttg.memdesc<1x256x128xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %41 : tensor<256x128xf32, #mma> -> !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> %42 = ttg.local_load %41 : !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> -> tensor<256x128xf32, #mma> %43 = arith.addf %37, %42: tensor<256x128xf32, #mma> @@ -1731,9 +1731,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %44 = arith.addi %arg9, %c1_i32 : i32 %45 = arith.cmpi slt, %44, %c1_i32 : i32 %46 = arith.select %45, %44, %c0_i32 : i32 - %47 = ttg.memdesc_subview %21[%46, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %47 = ttg.memdesc_index %21, %46 : !ttg.memdesc<1x256x64xf16, 
#shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %32, %47 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %48 = ttg.memdesc_subview %22[%46, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %48 = ttg.memdesc_index %22, %46 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %34, %48 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %39, %31, %33, %46, %47, %48: tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> } @@ -1787,8 +1787,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> - %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_index %21, %c0_i32 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_index %22, %c0_i32 : !ttg.memdesc<1x64x128xf16, 
#shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> %25 = tt.splat %arg2 : !tt.ptr -> tensor<256x1x!tt.ptr, #mma> %26 = tt.broadcast %25 : tensor<256x1x!tt.ptr, #mma> -> tensor<256x128x!tt.ptr, #mma> %27 = tt.load %26: tensor<256x128x!tt.ptr, #mma> @@ -1809,14 +1809,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %38 = arith.cmpi eq, %30, %c63_i32: i32 %39 = scf.if %38 -> tensor<256x128xf32, #mma> { %40 = ttg.local_alloc : () -> !ttg.memdesc<1x256x128xf32, #shared, #ttg.shared_memory, mutable> - %41 = ttg.memdesc_subview %40[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x128xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> + %41 = ttg.memdesc_index %40, %c0_i32 : !ttg.memdesc<1x256x128xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %41 : tensor<256x128xf32, #mma> -> !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> %42 = ttg.local_load %41 : !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> -> tensor<256x128xf32, #mma> %43 = arith.subf %37, %42: tensor<256x128xf32, #mma> scf.yield %43 : tensor<256x128xf32, #mma> } else { %44 = ttg.local_alloc : () -> !ttg.memdesc<1x256x128xf32, #shared, #ttg.shared_memory, mutable> - %45 = ttg.memdesc_subview %44[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x128xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> + %45 = ttg.memdesc_index %44, %c0_i32 : !ttg.memdesc<1x256x128xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> ttg.local_store %27, %45 : tensor<256x128xf32, #mma> -> !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> %46 = ttg.local_load %45 : !ttg.memdesc<256x128xf32, #shared, #ttg.shared_memory, mutable> -> 
tensor<256x128xf32, #mma> %47 = arith.addf %37, %46: tensor<256x128xf32, #mma> @@ -1825,9 +1825,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %48 = arith.addi %arg9, %c1_i32 : i32 %49 = arith.cmpi slt, %48, %c1_i32 : i32 %50 = arith.select %49, %48, %c0_i32 : i32 - %51 = ttg.memdesc_subview %21[%50, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %51 = ttg.memdesc_index %21, %50 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> ttg.local_store %32, %51 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> - %52 = ttg.memdesc_subview %22[%50, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %52 = ttg.memdesc_index %22, %50 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> ttg.local_store %34, %52 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> scf.yield %39, %31, %33, %50, %51, %52: tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> } @@ -1866,11 +1866,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %6 = arith.addi %arg30, %c1_i32 : i32 %7 = arith.cmpi slt, %6, %c3_i32 : i32 %8 = arith.select %7, %6, %c0_i32 : i32 - %9 = ttg.memdesc_subview %arg22[%8, %c0_i32, %c0_i32] : !ttg.memdesc<3x256x32xbf16, #shared, #smem, mutable> -> !ttg.memdesc<256x32xbf16, #shared, #smem, mutable> + %9 = ttg.memdesc_index %arg22, %8 : !ttg.memdesc<3x256x32xbf16, #shared, 
#smem, mutable> -> !ttg.memdesc<256x32xbf16, #shared, #smem, mutable> %10 = ttg.async_copy_global_to_local %4, %9 : tensor<256x32x!tt.ptr, #blocked> -> <256x32xbf16, #shared, #smem, mutable> %11 = ttg.async_commit_group %10 %12 = ttg.local_load %arg31 token %arg33 : !ttg.memdesc<256x32xbf16, #shared, #smem, mutable> -> tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> - %13 = ttg.memdesc_subview %arg23[%8, %c0_i32, %c0_i32] : !ttg.memdesc<3x32x256xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x256xbf16, #shared1, #smem, mutable> + %13 = ttg.memdesc_index %arg23, %8 : !ttg.memdesc<3x32x256xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x256xbf16, #shared1, #smem, mutable> %14 = ttg.async_copy_global_to_local %5, %13 : tensor<32x256x!tt.ptr, #blocked1> -> <32x256xbf16, #shared1, #smem, mutable> %15 = ttg.async_commit_group %14 %16 = ttg.local_load %arg34 token %arg36 : !ttg.memdesc<32x256xbf16, #shared1, #smem, mutable> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> @@ -1944,19 +1944,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %12 = arith.addi %arg36, %c1_i32 : i32 %13 = arith.cmpi slt, %12, %c2_i32 : i32 %14 = arith.select %13, %12, %c0_i32 : i32 - %15 = ttg.memdesc_subview %2[%14, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x8xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x8xi8, #shared, #smem, mutable> + %15 = ttg.memdesc_index %2, %14 : !ttg.memdesc<2x256x8xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x8xi8, #shared, #smem, mutable> %16 = ttg.async_copy_global_to_local %10, %15 : tensor<256x8x!tt.ptr, #blocked> -> <256x8xi8, #shared, #smem, mutable> %17 = ttg.async_commit_group %16 %18 = ttg.local_load %arg41 token %7 : !ttg.memdesc<256x8xi8, #shared, #smem, mutable> -> tensor<256x8xi8, #linear> - %19 = ttg.memdesc_subview %3[%14, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x8xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x8xi8, #shared, #smem, mutable> + %19 = 
ttg.memdesc_index %3, %14 : !ttg.memdesc<2x256x8xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x8xi8, #shared, #smem, mutable> %20 = ttg.async_copy_global_to_local %11, %19 : tensor<256x8x!tt.ptr, #blocked> -> <256x8xi8, #shared, #smem, mutable> %21 = ttg.async_commit_group %20 %22 = ttg.local_load %arg42 token %7 : !ttg.memdesc<256x8xi8, #shared, #smem, mutable> -> tensor<256x8xi8, #linear1> - %23 = ttg.memdesc_subview %0[%14, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x128xi8, #shared1, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared1, #smem, mutable> + %23 = ttg.memdesc_index %0, %14 : !ttg.memdesc<2x256x128xi8, #shared1, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared1, #smem, mutable> %24 = ttg.async_copy_global_to_local %8, %23 : tensor<256x128x!tt.ptr, #blocked1> -> <256x128xi8, #shared1, #smem, mutable> %25 = ttg.async_commit_group %24 %26 = ttg.local_load %arg43 token %7 : !ttg.memdesc<256x128xi8, #shared1, #smem, mutable> -> tensor<256x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> - %27 = ttg.memdesc_subview %1[%14, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x256xi8, #shared2, #smem, mutable> -> !ttg.memdesc<128x256xi8, #shared2, #smem, mutable> + %27 = ttg.memdesc_index %1, %14 : !ttg.memdesc<2x128x256xi8, #shared2, #smem, mutable> -> !ttg.memdesc<128x256xi8, #shared2, #smem, mutable> %28 = ttg.async_copy_global_to_local %9, %27 : tensor<128x256x!tt.ptr, #blocked2> -> <128x256xi8, #shared2, #smem, mutable> %29 = ttg.async_commit_group %28 %30 = ttg.local_load %arg44 token %7 : !ttg.memdesc<128x256xi8, #shared2, #smem, mutable> -> tensor<128x256xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> diff --git a/test/TritonGPU/amd/amd-fold-true-cmpi.mlir b/test/TritonGPU/amd/amd-fold-true-cmpi.mlir index 5e923fa65946..6f7fe806addc 100644 --- a/test/TritonGPU/amd/amd-fold-true-cmpi.mlir +++ b/test/TritonGPU/amd/amd-fold-true-cmpi.mlir @@ -86,9 +86,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %14 = 
tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> - %17 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> + %17 = ttg.memdesc_index %10, %c0_i32 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> ttg.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> - %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + %18 = ttg.memdesc_index %11, %c0_i32 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> ttg.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> %19 = arith.subi %arg1, %arg2 : index %20:6 = scf.for %arg5 = %arg0 to %19 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>) { @@ -104,9 +104,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %41 = arith.addi %arg9, %c1_i32 : i32 %42 = arith.cmpi slt, %41, %c1_i32 : i32 %43 = arith.select %42, %41, %c0_i32 : i32 - %44 = ttg.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> + %44 = ttg.memdesc_index %10, %43 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> ttg.local_store %35, %44 : tensor<128x32xf16, #blocked1> -> 
!ttg.memdesc<128x32xf16, #shared, #smem, mutable> - %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + %45 = ttg.memdesc_index %11, %43 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> ttg.local_store %37, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> scf.yield %33, %34, %40, %43, %44, %45 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> } diff --git a/test/TritonGPU/amd/amd-range-analysis.mlir b/test/TritonGPU/amd/amd-range-analysis.mlir index 27d470fd55fb..9a00df00eadd 100644 --- a/test/TritonGPU/amd/amd-range-analysis.mlir +++ b/test/TritonGPU/amd/amd-range-analysis.mlir @@ -1352,9 +1352,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}} %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> - %17 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> + %17 = ttg.memdesc_index %10, %c0_i32 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> ttg.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> - %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + %18 = ttg.memdesc_index %11, %c0_i32 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> ttg.local_store %16, %18 : 
tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} %19 = arith.subi %arg1, %arg2 : index @@ -1374,9 +1374,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} // expected-remark@+1 {{non-neg}} %43 = arith.select %42, %41, %c0_i32 : i32 - %44 = ttg.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> + %44 = ttg.memdesc_index %10, %43 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> ttg.local_store %35, %44 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> - %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + %45 = ttg.memdesc_index %11, %43 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> ttg.local_store %37, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> scf.yield %33, %34, %40, %43, %44, %45 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> } diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index a14de25e3790..7abe793e4928 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -556,10 +556,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %6 = tt.load %0 {amd.pipeliner_part = "prologue"} : tensor<32x128x!tt.ptr, #blocked> 
%7 = tt.load %1 {amd.pipeliner_part = "prologue"} : tensor<128x32x!tt.ptr, #blocked1> - %8 = ttg.memdesc_subview %2[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable> - %9 = ttg.memdesc_subview %3[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x32xf8E5M2FNUZ, #shared1, #smem, mutable> -> !ttg.memdesc<128x32xf8E5M2FNUZ, #shared1, #smem, mutable> - %10 = ttg.memdesc_subview %2[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable> - %11 = ttg.memdesc_subview %3[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x32xf8E5M2FNUZ, #shared1, #smem, mutable> -> !ttg.memdesc<128x32xf8E5M2FNUZ, #shared1, #smem, mutable> + %8 = ttg.memdesc_index %2, %c0_i32 : !ttg.memdesc<2x32x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable> + %9 = ttg.memdesc_index %3, %c0_i32 : !ttg.memdesc<2x128x32xf8E5M2FNUZ, #shared1, #smem, mutable> -> !ttg.memdesc<128x32xf8E5M2FNUZ, #shared1, #smem, mutable> + %10 = ttg.memdesc_index %2, %c1_i32 : !ttg.memdesc<2x32x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable> + %11 = ttg.memdesc_index %3, %c1_i32 : !ttg.memdesc<2x128x32xf8E5M2FNUZ, #shared1, #smem, mutable> -> !ttg.memdesc<128x32xf8E5M2FNUZ, #shared1, #smem, mutable> ttg.local_store %4, %8 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable> ttg.local_store %5, %9 : tensor<128x32xf8E5M2FNUZ, #blocked1> -> !ttg.memdesc<128x32xf8E5M2FNUZ, #shared1, #smem, mutable> diff --git a/test/TritonGPU/amd/in-thread-transpose.mlir b/test/TritonGPU/amd/in-thread-transpose.mlir index 5fde79dcea8e..31cbdbafcc00 100644 --- a/test/TritonGPU/amd/in-thread-transpose.mlir +++ b/test/TritonGPU/amd/in-thread-transpose.mlir @@ -177,9 +177,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ %7 = ttg.local_alloc : () -> 
!ttg.memdesc<1x64x64xf16, #shared1, #smem, mutable> %8 = tt.load %0, %cst_1 : tensor<64x64x!tt.ptr, #blocked> %9 = tt.load %1, %cst_1 : tensor<64x64x!tt.ptr, #blocked> - %10 = ttg.memdesc_subview %6[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> + %10 = ttg.memdesc_index %6, %c0_i32 : !ttg.memdesc<1x64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> ttg.local_store %8, %10 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> - %11 = ttg.memdesc_subview %7[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x64xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared1, #smem, mutable> + %11 = ttg.memdesc_index %7, %c0_i32 : !ttg.memdesc<1x64x64xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared1, #smem, mutable> ttg.local_store %9, %11 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared1, #smem, mutable> %12 = arith.subi %3, %c1_i32 : i32 %13:6 = scf.for %arg9 = %c0_i32 to %12 step %c1_i32 iter_args(%arg10 = %cst_0, %arg11 = %0, %arg12 = %1, %arg13 = %c0_i32, %arg14 = %10, %arg15 = %11) -> (tensor<64x64xf32, #mma>, tensor<64x64x!tt.ptr, #blocked>, tensor<64x64x!tt.ptr, #blocked>, i32, !ttg.memdesc<64x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x64xf16, #shared1, #smem, mutable>) : i32 { @@ -193,9 +193,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ %28 = arith.addi %arg13, %c1_i32 : i32 %29 = arith.cmpi slt, %28, %c1_i32 : i32 %30 = arith.select %29, %28, %c0_i32 : i32 - %31 = ttg.memdesc_subview %6[%30, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> + %31 = ttg.memdesc_index %6, %30 : !ttg.memdesc<1x64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> ttg.local_store %23, %31 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, 
#shared, #smem, mutable> - %32 = ttg.memdesc_subview %7[%30, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x64xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared1, #smem, mutable> + %32 = ttg.memdesc_index %7, %30 : !ttg.memdesc<1x64x64xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared1, #smem, mutable> ttg.local_store %25, %32 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared1, #smem, mutable> scf.yield %27, %21, %22, %30, %31, %32 : tensor<64x64xf32, #mma>, tensor<64x64x!tt.ptr, #blocked>, tensor<64x64x!tt.ptr, #blocked>, i32, !ttg.memdesc<64x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x64xf16, #shared1, #smem, mutable> } diff --git a/test/TritonGPU/consan-negative.mlir b/test/TritonGPU/consan-negative.mlir index f386fd233541..950dcca6dc06 100644 --- a/test/TritonGPU/consan-negative.mlir +++ b/test/TritonGPU/consan-negative.mlir @@ -33,7 +33,7 @@ module attributes { "ttg.num-ctas" = 1 : i32, #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> #smem = #ttg.shared_memory -// Test 2: local_alloc used in a non-trivial memdesc_subview should emit a warning. +// Test 2: local_alloc used in a non-trivial memdesc_index should emit a warning. 
module attributes { "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, @@ -48,7 +48,7 @@ module attributes { "ttg.num-ctas" = 1 : i32, // expected-warning@+1 {{Allocation is used in an inconsistent way, cannot instrument}} %alloc = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<2x32x32xf32, #shared, #smem, mutable> - %sub = ttg.memdesc_subview %alloc[%c1, %c0, %c0] + %sub = ttg.memdesc_index %alloc, %c1 : !ttg.memdesc<2x32x32xf32, #shared, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> "memdesc_use" (%alloc) : (!ttg.memdesc<2x32x32xf32, #shared, #smem, mutable>) -> () diff --git a/test/TritonGPU/consan.mlir b/test/TritonGPU/consan.mlir index a65eee295b6a..56920afc8825 100644 --- a/test/TritonGPU/consan.mlir +++ b/test/TritonGPU/consan.mlir @@ -75,7 +75,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: %[[WRT_BARS_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 32 : i32} : !tt.ptr %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<3x32x32xf32, #shared, #smem, mutable> - %1 = ttg.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x32x32xf32, #shared, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %1 = ttg.memdesc_index %0, %c0_i32 : !ttg.memdesc<3x32x32xf32, #shared, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> tt.return @@ -109,9 +109,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar %c0 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<2x32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 8192 : i32} : () -> !ttg.memdesc<4x1xi64, 
#shared1, #smem, mutable> - %bar_sub = ttg.memdesc_subview %bar[%c0, %c0] : !ttg.memdesc<4x1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> + %bar_sub = ttg.memdesc_index %bar, %c0 : !ttg.memdesc<4x1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %bar_sub, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> - %buf_sub = ttg.memdesc_subview %0[%c0, %c0, %c0] : !ttg.memdesc<2x32x32xf32, #shared, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %buf_sub = ttg.memdesc_index %0, %c0 : !ttg.memdesc<2x32x32xf32, #shared, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } diff --git a/test/TritonGPU/invalid.mlir b/test/TritonGPU/invalid.mlir index 8de96a73c114..d9177577c972 100644 --- a/test/TritonGPU/invalid.mlir +++ b/test/TritonGPU/invalid.mlir @@ -5,7 +5,7 @@ tt.func public @non_trivial_block(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{non-trivial block}} - %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x8xf32, #shared, #smem> + %a = ttg.memdesc_subslice %arg0 [0, 0] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x8xf32, #shared, #smem> tt.return } @@ -16,7 +16,7 @@ tt.func public @non_trivial_block(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) tt.func public @miss_encoding(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{,}} - %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc<8x16xf16> + %a = ttg.memdesc_subslice %arg0 [0, 0] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc<8x16xf16> tt.return } @@ -27,7 +27,7 @@ tt.func public @miss_encoding(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { tt.func public @miss_memory_space(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{,}} - 
%a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared> -> !ttg.memdesc<8x16xf16> + %a = ttg.memdesc_subslice %arg0 [0, 0] : !ttg.memdesc<8x16xf32, #shared> -> !ttg.memdesc<8x16xf16> tt.return } @@ -38,7 +38,7 @@ tt.func public @miss_memory_space(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) tt.func public @subview_element_ty(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{element type}} - %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x16xf16, #shared, #smem> + %a = ttg.memdesc_subslice %arg0 [0, 0] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x16xf16, #shared, #smem> tt.return } @@ -49,7 +49,7 @@ tt.func public @subview_element_ty(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem> tt.func public @too_many_offsets(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{offsets}} - %a = ttg.memdesc_subview %arg0[%zero, %zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc + %a = ttg.memdesc_subslice %arg0 [0, 0, 0] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x16xf32, #shared, #smem> tt.return } @@ -58,9 +58,8 @@ tt.func public @too_many_offsets(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> #smem = #ttg.shared_memory tt.func public @too_few_offsets(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { - %zero = arith.constant 0 : i32 // expected-error @+1 {{offsets}} - %a = ttg.memdesc_subview %arg0[%zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc + %a = ttg.memdesc_subslice %arg0 [0] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x16xf32, #shared, #smem> tt.return } @@ -68,21 +67,20 @@ tt.func public @too_few_offsets(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = 
[0, 1]}> #smem = #ttg.shared_memory -tt.func public @result_rank_too_large(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { +tt.func public @result_rank_too_large(%arg0: !ttg.memdesc<3x8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{result rank}} - %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<3x8x16xf32, #shared, #smem> + %a = ttg.memdesc_index %arg0, %zero : !ttg.memdesc<3x8x16xf32, #shared, #smem> -> !ttg.memdesc<3x8x16xf32, #shared, #smem> tt.return } - // ----- -#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 16, order = [0, 1]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0]}> #smem = #ttg.shared_memory -tt.func public @subview_along_swizzling(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { +tt.func public @result_1d_to_1d(%arg0: !ttg.memdesc<8xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 - // expected-error @+1 {{swizzling pattern}} - %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x4xf32, #shared, #smem> + // expected-error @+1 {{1D -> 1D}} + %a = ttg.memdesc_index %arg0, %zero : !ttg.memdesc<8xf32, #shared, #smem> -> !ttg.memdesc<2xf32, #shared, #smem> tt.return } @@ -90,22 +88,21 @@ tt.func public @subview_along_swizzling(%arg0: !ttg.memdesc<8x16xf32, #shared, # #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 16, order = [0, 1]}> #smem = #ttg.shared_memory -tt.func public @subview_along_swizzling(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>, %index: i32) { +tt.func public @subview_along_swizzling(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 - // expected-error @+1 {{constant}} - %a = ttg.memdesc_subview %arg0[%zero, %index] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x4xf32, #shared, #smem> + // expected-error @+1 {{swizzling pattern}} + %a = ttg.memdesc_subslice %arg0 [0, 0] : 
!ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x4xf32, #shared, #smem> tt.return } + // ----- #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 16, order = [0, 1]}> #smem = #ttg.shared_memory tt.func public @subview_along_swizzling(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>, %index: i32) { - %zero = arith.constant 0 : i32 - %c_2 = arith.constant 2 : i32 // expected-error @+1 {{tile}} - %a = ttg.memdesc_subview %arg0[%c_2, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<4x16xf32, #shared, #smem> + %a = ttg.memdesc_subslice %arg0 [2, 0] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<4x16xf32, #shared, #smem> tt.return } @@ -116,7 +113,7 @@ tt.func public @subview_along_swizzling(%arg0: !ttg.memdesc<8x16xf32, #shared, # tt.func public @result_dim_too_large(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{result shape}} - %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<32xf32, #shared, #smem> + %a = ttg.memdesc_index %arg0, %zero : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<32xf32, #shared, #smem> tt.return } diff --git a/test/TritonGPU/load-mma-specialization.mlir b/test/TritonGPU/load-mma-specialization.mlir index 221418ce1c1b..fd3328b00f48 100644 --- a/test/TritonGPU/load-mma-specialization.mlir +++ b/test/TritonGPU/load-mma-specialization.mlir @@ -58,22 +58,22 @@ tt.func @warp_specialize_tma_matmul( // CHECK-DAG: [[C2:%.*]] = arith.constant 2 : i32 // CHECK: [[ACC_BUFS:%.*]], [[ACC_TOK:.*]] = ttng.tmem_alloc : () -> (!ttg.memdesc<1x128x128xf32, [[ACC_TMEM]], #ttng.tensor_memory, mutable> - // CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][[[C0]], [[C0]], [[C0]]] + // CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]], [[C0]] // CHECK-NEXT: [[INIT_TOK:%.*]] = ttng.tmem_store [[ZERO]], [[ACC_BUF]][[[ACC_TOK]]] // CHECK-NEXT: [[A_BUFS:%.*]] = ttg.local_alloc : () -> 
!ttg.memdesc<2x128x64xf16, [[SHARED]] // CHECK-NEXT: [[B_BUFS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x64xf16, [[SHARED]] // CHECK-NEXT: [[READY_MBARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 - // CHECK-NEXT: [[READY_MBAR0:%.*]] = ttg.memdesc_subview [[READY_MBARS]][[[C0]]] + // CHECK-NEXT: [[READY_MBAR0:%.*]] = ttg.memdesc_index [[READY_MBARS]], [[C0]] // CHECK-NEXT: ttng.init_barrier [[READY_MBAR0]], 1 - // CHECK-NEXT: [[READY_MBAR1:%.*]] = ttg.memdesc_subview [[READY_MBARS]][[[C1]]] + // CHECK-NEXT: [[READY_MBAR1:%.*]] = ttg.memdesc_index [[READY_MBARS]], [[C1]] // CHECK-NEXT: ttng.init_barrier [[READY_MBAR1]], 1 // CHECK-NEXT: [[OPER_MBARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 - // CHECK-NEXT: [[OPER_MBAR0:%.*]] = ttg.memdesc_subview [[OPER_MBARS]][[[C0]]] + // CHECK-NEXT: [[OPER_MBAR0:%.*]] = ttg.memdesc_index [[OPER_MBARS]], [[C0]] // CHECK-NEXT: ttng.init_barrier [[OPER_MBAR0]], 1 - // CHECK-NEXT: [[OPER_MBAR1:%.*]] = ttg.memdesc_subview [[OPER_MBARS]][[[C1]]] + // CHECK-NEXT: [[OPER_MBAR1:%.*]] = ttg.memdesc_index [[OPER_MBARS]], [[C1]] // CHECK-NEXT: ttng.init_barrier [[OPER_MBAR1]], 1 // CHECK-NEXT: ttng.arrive_barrier [[READY_MBAR0]], 1 @@ -82,7 +82,7 @@ tt.func @warp_specialize_tma_matmul( // CHECK-NEXT: [[LAST_ITER:%.*]] = arith.subi [[K_TILES]], [[C1]] // CHECK-NEXT: [[DONE_MBAR:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64 - // CHECK-NEXT: [[DONE_MBAR0:%.*]] = ttg.memdesc_subview [[DONE_MBAR]][[[C0]]] + // CHECK-NEXT: [[DONE_MBAR0:%.*]] = ttg.memdesc_index [[DONE_MBAR]], [[C0]] // CHECK-NEXT: ttng.init_barrier [[DONE_MBAR0]], 1 // CHECK-NEXT: [[LAST:%.*]]:3 = scf.for [[K:%arg[0-9]+]] = [[C0]] to [[K_TILES]] step [[C1]] @@ -95,15 +95,15 @@ tt.func @warp_specialize_tma_matmul( // CHECK-NEXT: [[OFF_K:%.*]] = arith.muli [[K]], [[BLOCK_K]] %off_k = arith.muli %k, %BLOCK_K : i32 - // CHECK-NEXT: [[READY_MBAR:%.*]] = ttg.memdesc_subview [[READY_MBARS]][[[IDX]]] + // CHECK-NEXT: [[READY_MBAR:%.*]] = ttg.memdesc_index 
[[READY_MBARS]], [[IDX]] // CHECK-NEXT: ttng.wait_barrier [[READY_MBAR]], [[PHASE]] {ttg.partition = 2 : i32} - // CHECK-NEXT: [[OPER_MBAR:%.*]] = ttg.memdesc_subview [[OPER_MBARS]][[[IDX]]] + // CHECK-NEXT: [[OPER_MBAR:%.*]] = ttg.memdesc_index [[OPER_MBARS]], [[IDX]] // CHECK-NEXT: ttng.barrier_expect [[OPER_MBAR]], 32768 {ttg.partition = 2 : i32} - // CHECK-NEXT: [[A_BUF:%.*]] = ttg.memdesc_subview [[A_BUFS]][[[IDX]], [[C0]], [[C0]]] + // CHECK-NEXT: [[A_BUF:%.*]] = ttg.memdesc_index [[A_BUFS]], [[IDX]] // CHECK-NEXT: ttng.async_tma_copy_global_to_local [[A_DESC]][[[OFF_M]], [[OFF_K]]] [[A_BUF]], [[OPER_MBAR]], [[TRUE]] {ttg.partition = 2 : i32} %a_reg = tt.descriptor_load %a_desc[%off_m, %off_k] : !tt.tensordesc> -> tensor<128x64xf16, #oper_layout> - // CHECK-NEXT: [[B_BUF:%.*]] = ttg.memdesc_subview [[B_BUFS]][[[IDX]], [[C0]], [[C0]]] + // CHECK-NEXT: [[B_BUF:%.*]] = ttg.memdesc_index [[B_BUFS]], [[IDX]] // CHECK-NEXT: ttng.async_tma_copy_global_to_local [[B_DESC]][[[OFF_N]], [[OFF_K]]] [[B_BUF]], [[OPER_MBAR]], [[TRUE]] {ttg.partition = 2 : i32} %b_reg = tt.descriptor_load %b_desc[%off_n, %off_k] : !tt.tensordesc> -> tensor<128x64xf16, #oper_layout> @@ -166,11 +166,11 @@ tt.func @unsupported_load() { %k_tiles = arith.constant 32 : i32 // CHECK: [[ACC_ALLOC:%.*]], %{{.*}} = ttng.tmem_alloc : () -> (!ttg.memdesc<1x128x128xf32 - // CHECK-NEXT: [[ACC:%.*]] = ttg.memdesc_subview [[ACC_ALLOC]][%c0_i32 + // CHECK-NEXT: [[ACC:%.*]] = ttg.memdesc_index [[ACC_ALLOC]], %c0_i32 // CHECK-NEXT: tmem_store [[ZERO]], [[ACC]] // CHECK-NEXT: [[DONE_MBAR:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64 - // CHECK-NEXT: [[DONE_MBAR0:%.*]] = ttg.memdesc_subview [[DONE_MBAR]][%c0_i32] + // CHECK-NEXT: [[DONE_MBAR0:%.*]] = ttg.memdesc_index [[DONE_MBAR]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[DONE_MBAR0]], 1 // CHECK-NEXT: scf.for @@ -300,21 +300,21 @@ tt.func @matmul_tma_acc_with_unconditional_user( %k_tiles = arith.constant 32 : i32 // CHECK: [[ACC_BUFS:%.*]], 
[[ACC_TOK:%.*]] = ttng.tmem_alloc : () -> (!ttg.memdesc<2x128x128xf32 - // CHECK-NEXT: [[ACC_BUF0:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][%c0_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: [[ACC_BUF0:%.*]] = ttg.memdesc_index [[ACC_BUFS]], %c0_i32 // CHECK-NEXT: [[INIT_TOK:%.*]] = ttng.tmem_store [[ZERO]], [[ACC_BUF0]][[[ACC_TOK]]] // CHECK-COUNT-2: ttg.local_alloc : () -> !ttg.memdesc<2xi64 // CHECK: [[ACC_READY_BUFS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 - // CHECK-NEXT: [[ACC_READY_BUF0:%.*]] = ttg.memdesc_subview [[ACC_READY_BUFS]][%c0_i32] + // CHECK-NEXT: [[ACC_READY_BUF0:%.*]] = ttg.memdesc_index [[ACC_READY_BUFS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[ACC_READY_BUF0]], 1 - // CHECK-NEXT: [[ACC_READY_BUF1:%.*]] = ttg.memdesc_subview [[ACC_READY_BUFS]][%c1_i32] + // CHECK-NEXT: [[ACC_READY_BUF1:%.*]] = ttg.memdesc_index [[ACC_READY_BUFS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[ACC_READY_BUF1]], 1 // CHECK-NEXT: [[ACC_EMPTY_BUFS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 - // CHECK-NEXT: [[ACC_EMPTY_BUF0:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][%c0_i32] + // CHECK-NEXT: [[ACC_EMPTY_BUF0:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[ACC_EMPTY_BUF0]], 1 - // CHECK-NEXT: [[ACC_EMPTY_BUF1:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][%c1_i32] + // CHECK-NEXT: [[ACC_EMPTY_BUF1:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[ACC_EMPTY_BUF1]], 1 // CHECK-NEXT: ttng.arrive_barrier [[ACC_EMPTY_BUF0]], 1 @@ -340,9 +340,9 @@ tt.func @matmul_tma_acc_with_unconditional_user( %b_shared = ttg.local_alloc %b : (tensor<64x128xf16, #oper_layout>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) - // CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][[[ACC_INDEX]], %c0_i32, %c0_i32] + // 
CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]], [[ACC_INDEX]] - // CHECK-NEXT: [[CUR_ACC_READY_BAR:%.*]] = ttg.memdesc_subview [[ACC_READY_BUFS]][[[ACC_INDEX]]] + // CHECK-NEXT: [[CUR_ACC_READY_BAR:%.*]] = ttg.memdesc_index [[ACC_READY_BUFS]], [[ACC_INDEX]] // CHECK-NEXT: [[MMA_TOK:%.*]] = ttng.tc_gen5_mma %{{[0-9]+}}, %{{[0-9]+}}, [[ACC_BUF]][], %true, %true, {{.*}}, [[CUR_ACC_READY_BAR]][%true] {is_async, ttg.partition = 1 : i32} %mma_tok = ttng.tc_gen5_mma %a_shared, %b_shared, %c_tmem[%c_tok], %true, %true : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> @@ -350,7 +350,7 @@ tt.func @matmul_tma_acc_with_unconditional_user( // CHECK-NEXT: [[C:%.*]], [[LOAD_TOK:%.*]] = ttng.tmem_load [[ACC_BUF]][] {ttg.partition = 0 : i32} %c, %load_tok = ttng.tmem_load %c_tmem[%mma_tok] : !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #acc_layout> - // CHECK-NEXT: [[CUR_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[ACC_INDEX]]] + // CHECK-NEXT: [[CUR_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], [[ACC_INDEX]] // CHECK-NEXT: ttng.arrive_barrier [[CUR_ACC_EMPTY_BAR]], 1 {ttg.partition = 0 : i32} "acc_user"(%c) : (tensor<128x128xf32, #acc_layout>) -> () @@ -362,8 +362,8 @@ tt.func @matmul_tma_acc_with_unconditional_user( // CHECK-NEXT: "acc_user"([[C]]) {ttg.partition = 0 : i32} - // CHECK-NEXT: [[NEXT_ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][[[ACC_NEXT_INDEX]], %c0_i32, %c0_i32] - // CHECK-NEXT: [[NEXT_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[ACC_NEXT_INDEX]]] + // CHECK-NEXT: [[NEXT_ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]], [[ACC_NEXT_INDEX]] + // CHECK-NEXT: [[NEXT_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], [[ACC_NEXT_INDEX]] // CHECK-NEXT: ttng.wait_barrier [[NEXT_ACC_EMPTY_BAR]], [[ACC_NEXT_PHASE]], %true {ttg.partition = 1 : i32} // 
CHECK-NEXT: [[STORE_TOK:%.*]] = ttng.tmem_store [[ACC_RESET]], [[NEXT_ACC_BUF]][], %true {ttg.partition = 1 : i32} @@ -429,8 +429,8 @@ tt.func @matmul_tma_acc_with_conditional_user( %b_shared = ttg.local_alloc %b : (tensor<64x128xf16, #oper_layout>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) - // CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][[[ACC_INDEX]], %c0_i32, %c0_i32] - // CHECK-NEXT: [[CUR_ACC_READY_BAR:%.*]] = ttg.memdesc_subview [[ACC_READY_BUFS]][[[ACC_INDEX]]] + // CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]], [[ACC_INDEX]] + // CHECK-NEXT: [[CUR_ACC_READY_BAR:%.*]] = ttg.memdesc_index [[ACC_READY_BUFS]], [[ACC_INDEX]] // CHECK-NEXT: [[DO_EPILOGUE:%.*]] = arith.cmpi // CHECK-NEXT: [[MMA_TOK:%.*]] = ttng.tc_gen5_mma %{{[0-9]+}}, %{{[0-9]+}}, [[ACC_BUF]][], %true, %true, {{.*}}, [[CUR_ACC_READY_BAR]][[[DO_EPILOGUE]]] {is_async, ttg.partition = 1 : i32} %mma_tok = ttng.tc_gen5_mma %a_shared, %b_shared, %c_tmem[%c_tok], %true, %true : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> @@ -444,7 +444,7 @@ tt.func @matmul_tma_acc_with_conditional_user( // CHECK-NEXT: [[C:%.*]], [[USER_TOK:%.*]] = ttng.tmem_load [[ACC_BUF]][] // CHECK-NEXT: "acc_user"([[C]]) "acc_user"(%c) : (tensor<128x128xf32, #acc_layout>) -> () - // CHECK-NEXT: [[CUR_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[ACC_INDEX]]] + // CHECK-NEXT: [[CUR_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], [[ACC_INDEX]] // CHECK-NEXT: ttng.arrive_barrier [[CUR_ACC_EMPTY_BAR]], 1 {ttg.partition = 0 : i32} // CHECK-NEXT: } } @@ -457,8 +457,8 @@ tt.func @matmul_tma_acc_with_conditional_user( // CHECK-NEXT: [[EPILOGUE_ACC_NEXT_INDEX:%.*]] = arith.select [[DO_EPILOGUE]], 
[[ACC_NEXT_INDEX]], [[ACC_INDEX]] // CHECK-NEXT: [[EPILOGUE_ACC_NEXT_PHASE:%.*]] = arith.select [[DO_EPILOGUE]], [[ACC_NEXT_PHASE]], [[ACC_PHASE]] - // CHECK-NEXT: [[ACC_NEXT_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][[[EPILOGUE_ACC_NEXT_INDEX]], %c0_i32, %c0_i32] - // CHECK-NEXT: [[NEXT_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[EPILOGUE_ACC_NEXT_INDEX]]] + // CHECK-NEXT: [[ACC_NEXT_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]], [[EPILOGUE_ACC_NEXT_INDEX]] + // CHECK-NEXT: [[NEXT_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], [[EPILOGUE_ACC_NEXT_INDEX]] // CHECK-NEXT: ttng.wait_barrier [[NEXT_ACC_EMPTY_BAR]], [[EPILOGUE_ACC_NEXT_PHASE]], [[DO_EPILOGUE]] {ttg.partition = 1 : i32} // CHECK-NEXT: ttng.tmem_store [[ACC_RESET]], [[ACC_NEXT_BUF]][], %true {ttg.partition = 1 : i32} @@ -524,9 +524,9 @@ tt.func @matmul_tma_acc_with_conditional_def( %b_shared = ttg.local_alloc %b : (tensor<64x128xf16, #oper_layout>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) - // CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][[[ACC_INDEX]], %c0_i32, %c0_i32] + // CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]], [[ACC_INDEX]] - // CHECK-NEXT: [[CUR_ACC_READY_BAR:%.*]] = ttg.memdesc_subview [[ACC_READY_BUFS]][[[ACC_INDEX]]] + // CHECK-NEXT: [[CUR_ACC_READY_BAR:%.*]] = ttg.memdesc_index [[ACC_READY_BUFS]], [[ACC_INDEX]] // CHECK-NEXT: [[MMA_TOK:%.*]] = ttng.tc_gen5_mma %{{[0-9]+}}, %{{[0-9]+}}, [[ACC_BUF]][], %true, %true, {{.*}}, [[CUR_ACC_READY_BAR]][%true] {is_async, ttg.partition = 1 : i32} %mma_tok = ttng.tc_gen5_mma %a_shared, %b_shared, %c_tmem[%c_tok], %true, %true : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> %c, %load_tok = ttng.tmem_load 
%c_tmem[%mma_tok] : !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #acc_layout> @@ -537,7 +537,7 @@ tt.func @matmul_tma_acc_with_conditional_def( // CHECK-NEXT: ttng.wait_barrier [[CUR_ACC_READY_BAR]], [[ACC_PHASE]] {ttg.partition = 0 : i32} // CHECK-NEXT: [[C:%.*]], [[LOAD_TOK:%.*]] = ttng.tmem_load [[ACC_BUF]][] {ttg.partition = 0 : i32} - // CHECK-NEXT: [[CUR_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[ACC_INDEX]]] + // CHECK-NEXT: [[CUR_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], [[ACC_INDEX]] // CHECK-NEXT: ttng.arrive_barrier [[CUR_ACC_EMPTY_BAR]], 1 {ttg.partition = 0 : i32} // CHECK-NEXT: [[ACC_INDEX_INCR:%.*]] = arith.addi [[ACC_INDEX]], %c1_i32 @@ -549,8 +549,8 @@ tt.func @matmul_tma_acc_with_conditional_def( // CHECK-NEXT: "acc_user"([[C]]) "acc_user"(%c) : (tensor<128x128xf32, #acc_layout>) -> () - // CHECK-NEXT: [[NEXT_ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][[[ACC_NEXT_INDEX]], %c0_i32, %c0_i32] - // CHECK-NEXT: [[NEXT_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[ACC_NEXT_INDEX]]] + // CHECK-NEXT: [[NEXT_ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]], [[ACC_NEXT_INDEX]] + // CHECK-NEXT: [[NEXT_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], [[ACC_NEXT_INDEX]] // CHECK-NEXT: ttng.wait_barrier [[NEXT_ACC_EMPTY_BAR]], [[ACC_NEXT_PHASE]], %true {ttg.partition = 1 : i32} // CHECK-NEXT: [[STORE_TOK:%.*]] = ttng.tmem_store [[ZERO]], [[NEXT_ACC_BUF]][], [[DO_EPILOGUE]] {ttg.partition = 1 : i32} @@ -615,9 +615,9 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use( %b_shared = ttg.local_alloc %b : (tensor<64x128xf16, #oper_layout>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) - // CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][[[ACC_INDEX]], %c0_i32, %c0_i32] + // 
CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]], [[ACC_INDEX]] - // CHECK-NEXT: [[CUR_ACC_READY_BAR:%.*]] = ttg.memdesc_subview [[ACC_READY_BUFS]][[[ACC_INDEX]]] + // CHECK-NEXT: [[CUR_ACC_READY_BAR:%.*]] = ttg.memdesc_index [[ACC_READY_BUFS]], [[ACC_INDEX]] // CHECK-NEXT: [[DO_EPILOGUE:%.*]] = arith.cmpi // CHECK-NEXT: [[MMA_TOK:%.*]] = ttng.tc_gen5_mma %{{[0-9]+}}, %{{[0-9]+}}, [[ACC_BUF]][], %true, %true, {{.*}}, [[CUR_ACC_READY_BAR]][[[DO_EPILOGUE]]] {is_async, ttg.partition = 1 : i32} %mma_tok = ttng.tc_gen5_mma %a_shared, %b_shared, %c_tmem[%c_tok], %true, %true : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> @@ -632,7 +632,7 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use( // CHECK-NEXT: [[C:%.*]], [[USER_TOK:%.*]] = ttng.tmem_load [[ACC_BUF]][] // CHECK-NEXT: "acc_user"([[C]]) "acc_user"(%c) : (tensor<128x128xf32, #acc_layout>) -> () - // CHECK-NEXT: [[CUR_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[ACC_INDEX]]] + // CHECK-NEXT: [[CUR_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], [[ACC_INDEX]] // CHECK-NEXT: ttng.arrive_barrier [[CUR_ACC_EMPTY_BAR]], 1 {ttg.partition = 0 : i32} // CHECK-NEXT: } } @@ -645,8 +645,8 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use( // CHECK-NEXT: [[EPILOGUE_ACC_NEXT_INDEX:%.*]] = arith.select [[DO_EPILOGUE]], [[ACC_NEXT_INDEX]], [[ACC_INDEX]] // CHECK-NEXT: [[EPILOGUE_ACC_NEXT_PHASE:%.*]] = arith.select [[DO_EPILOGUE]], [[ACC_NEXT_PHASE]], [[ACC_PHASE]] - // CHECK-NEXT: [[NEXT_ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][[[EPILOGUE_ACC_NEXT_INDEX]], %c0_i32, %c0_i32] - // CHECK-NEXT: [[NEXT_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[EPILOGUE_ACC_NEXT_INDEX]]] + // CHECK-NEXT: [[NEXT_ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]], [[EPILOGUE_ACC_NEXT_INDEX]] + // CHECK-NEXT: [[NEXT_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], 
[[EPILOGUE_ACC_NEXT_INDEX]] // CHECK-NEXT: ttng.wait_barrier [[NEXT_ACC_EMPTY_BAR]], [[EPILOGUE_ACC_NEXT_PHASE]], [[DO_EPILOGUE]] {ttg.partition = 1 : i32} // CHECK-NEXT: [[STORE_TOK:%.*]] = ttng.tmem_store [[ZERO]], [[NEXT_ACC_BUF]][], [[DO_EPILOGUE]] {ttg.partition = 1 : i32} @@ -687,17 +687,17 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use_no_multibuf_flag( %k_tiles = arith.constant 32 : i32 // CHECK: [[ACC_BUFS:%.*]], [[ACC_TOK:%.*]] = ttng.tmem_alloc : () -> (!ttg.memdesc<1x128x128xf32, - // CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][%c0_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]], %c0_i32 // CHECK-NEXT: [[INIT_TOK:%.*]] = ttng.tmem_store [[ZERO]], [[ACC_BUF]][[[ACC_TOK]]], %true // CHECK-COUNT-2: ttg.local_alloc : () -> !ttg.memdesc<2xi64 // CHECK: [[ACC_READY_BUFS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64 - // CHECK-NEXT: [[ACC_READY_BUF0:%.*]] = ttg.memdesc_subview [[ACC_READY_BUFS]][%c0_i32] + // CHECK-NEXT: [[ACC_READY_BUF0:%.*]] = ttg.memdesc_index [[ACC_READY_BUFS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[ACC_READY_BUF0]], 1 // CHECK-NEXT: [[ACC_EMPTY_BUFS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64 - // CHECK-NEXT: [[ACC_EMPTY_BUF0:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][%c0_i32] + // CHECK-NEXT: [[ACC_EMPTY_BUF0:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[ACC_EMPTY_BUF0]], 1 // CHECK-NEXT: ttng.arrive_barrier [[ACC_EMPTY_BUF0]], 1 @@ -925,7 +925,7 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use_flag( %k_tiles = arith.constant 32 : i32 // CHECK: [[ACC_BUFS:%.*]], [[ACC_TOK:%.*]] = ttng.tmem_alloc : () -> (!ttg.memdesc<2x128x128xf32, - // CHECK-NEXT: [[ACC_BUF0:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][%c0_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: [[ACC_BUF0:%.*]] = ttg.memdesc_index [[ACC_BUFS]], %c0_i32 // CHECK-NEXT: ttng.tmem_store [[ZERO]], [[ACC_BUF0]] // CHECK-COUNT-2: ttg.local_alloc : () -> 
!ttg.memdesc<4x{{.*}}xf16, @@ -933,15 +933,15 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use_flag( // CHECK-COUNT-4: ttng.arrive_barrier // CHECK: [[ACC_READY_BUFS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 - // CHECK-NEXT: [[ACC_READY_BUF0:%.*]] = ttg.memdesc_subview [[ACC_READY_BUFS]][%c0_i32] + // CHECK-NEXT: [[ACC_READY_BUF0:%.*]] = ttg.memdesc_index [[ACC_READY_BUFS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[ACC_READY_BUF0]], 1 - // CHECK-NEXT: [[ACC_READY_BUF1:%.*]] = ttg.memdesc_subview [[ACC_READY_BUFS]][%c1_i32] + // CHECK-NEXT: [[ACC_READY_BUF1:%.*]] = ttg.memdesc_index [[ACC_READY_BUFS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[ACC_READY_BUF1]], 1 // CHECK-NEXT: [[ACC_EMPTY_BUFS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 - // CHECK-NEXT: [[ACC_EMPTY_BUF0:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][%c0_i32] + // CHECK-NEXT: [[ACC_EMPTY_BUF0:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[ACC_EMPTY_BUF0]], 1 - // CHECK-NEXT: [[ACC_EMPTY_BUF1:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][%c1_i32] + // CHECK-NEXT: [[ACC_EMPTY_BUF1:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[ACC_EMPTY_BUF1]], 1 // CHECK-NEXT: ttng.arrive_barrier [[ACC_EMPTY_BUF0]], 1 @@ -968,8 +968,8 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use_flag( %b_shared = ttg.local_alloc %b : (tensor<64x128xf16, #oper_layout>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) - // CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]][[[ACC_INDEX]], %c0_i32, %c0_i32] - // CHECK-NEXT: [[CUR_ACC_READY_BUF:%.*]] = ttg.memdesc_subview [[ACC_READY_BUFS]][[[ACC_INDEX]]] + // CHECK-NEXT: [[ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]], [[ACC_INDEX]] + // CHECK-NEXT: [[CUR_ACC_READY_BUF:%.*]] = ttg.memdesc_index 
[[ACC_READY_BUFS]], [[ACC_INDEX]] // CHECK-NEXT: [[DO_EPILOGUE:%.*]] = arith.cmpi eq, [[K:%.*]], %c0_i32 // CHECK-NEXT: ttng.tc_gen5_mma %{{[0-9]+}}, %{{[0-9]+}}, [[ACC_BUF]][], [[FLAG]], %true, {{.*}}, [[CUR_ACC_READY_BUF]][[[DO_EPILOGUE]]] {is_async, ttg.partition = 1 : i32} @@ -989,7 +989,7 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use_flag( // CHECK-NEXT: [[C:%.*]], [[USER_TOK:%.*]] = ttng.tmem_load [[ACC_BUF]][] // CHECK-NEXT: "acc_user"([[C]]) "acc_user"(%c) : (tensor<128x128xf32, #acc_layout>) -> () - // CHECK-NEXT: [[CUR_ACC_EMPTY_BUF:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[ACC_INDEX]]] + // CHECK-NEXT: [[CUR_ACC_EMPTY_BUF:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], [[ACC_INDEX]] // CHECK-NEXT: ttng.arrive_barrier [[CUR_ACC_EMPTY_BUF]], 1 {ttg.partition = 0 : i32} // CHECK-NEXT: } } @@ -1002,7 +1002,7 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use_flag( // CHECK-NEXT: [[EPILOGUE_ACC_NEXT_INDEX:%.*]] = arith.select [[DO_EPILOGUE]], [[ACC_NEXT_INDEX]], [[ACC_INDEX]] // CHECK-NEXT: [[EPILOGUE_ACC_NEXT_PHASE:%.*]] = arith.select [[DO_EPILOGUE]], [[ACC_NEXT_PHASE]], [[ACC_PHASE]] - // CHECK-NEXT: [[NEXT_ACC_EMPTY_BUF:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[EPILOGUE_ACC_NEXT_INDEX]]] + // CHECK-NEXT: [[NEXT_ACC_EMPTY_BUF:%.*]] = ttg.memdesc_index [[ACC_EMPTY_BUFS]], [[EPILOGUE_ACC_NEXT_INDEX]] // CHECK-NEXT: ttng.wait_barrier [[NEXT_ACC_EMPTY_BUF]], [[EPILOGUE_ACC_NEXT_PHASE]], [[DO_EPILOGUE]] {ttg.partition = 1 : i32} // CHECK: arith.addi @@ -1056,11 +1056,11 @@ tt.func @specialize_mma_only(%rhs_desc: !tt.tensordesc !ttg.memdesc<3xi64, // CHECK: [[EMPTY_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64, - // CHECK-NEXT: [[EMPTY_BAR0:%.*]] = ttg.memdesc_subview [[EMPTY_BARS]][%c0_i32] + // CHECK-NEXT: [[EMPTY_BAR0:%.*]] = ttg.memdesc_index [[EMPTY_BARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[EMPTY_BAR0]], 1 // CHECK-NEXT: [[READY_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64, - // CHECK-NEXT: 
[[READY_BAR0:%.*]] = ttg.memdesc_subview [[READY_BARS]][%c0_i32] + // CHECK-NEXT: [[READY_BAR0:%.*]] = ttg.memdesc_index [[READY_BARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[READY_BAR0]], 1 // CHECK-NEXT: ttng.arrive_barrier [[READY_BAR0]], 1 @@ -1163,28 +1163,28 @@ tt.func @store_mma_load( %true = arith.constant true // CHECK: [[LHS_EMPTY_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64, - // CHECK: [[LHS_EMPTY_BAR0:%.*]] = ttg.memdesc_subview [[LHS_EMPTY_BARS]][%c0_i32] - // CHECK: [[LHS_EMPTY_BAR1:%.*]] = ttg.memdesc_subview [[LHS_EMPTY_BARS]][%c1_i32] + // CHECK: [[LHS_EMPTY_BAR0:%.*]] = ttg.memdesc_index [[LHS_EMPTY_BARS]], %c0_i32 + // CHECK: [[LHS_EMPTY_BAR1:%.*]] = ttg.memdesc_index [[LHS_EMPTY_BARS]], %c1_i32 // CHECK: [[LHS_READY_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64, // CHECK: arrive_barrier [[LHS_EMPTY_BAR0]] // CHECK: arrive_barrier [[LHS_EMPTY_BAR1]] // CHECK-NOT: arrive_barrier // CHECK: [[MMA_ENTRY_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64, - // CHECK: [[MMA_ENTRY_BAR:%.*]] = ttg.memdesc_subview [[MMA_ENTRY_BARS]][%c0_i32] + // CHECK: [[MMA_ENTRY_BAR:%.*]] = ttg.memdesc_index [[MMA_ENTRY_BARS]], %c0_i32 // CHECK: [[MMA_EXIT_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64, - // CHECK: [[MMA_EXIT_BAR:%.*]] = ttg.memdesc_subview [[MMA_EXIT_BARS]][%c0_i32] + // CHECK: [[MMA_EXIT_BAR:%.*]] = ttg.memdesc_index [[MMA_EXIT_BARS]], %c0_i32 // CHECK-NOT: arrive_barrier // CHECK: [[LHS_SHARED:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<128x64xf16, // CHECK: scf.for scf.for %i = %c0 to %ub step %c1 : i32 { - // CHECK-NEXT: [[LOAD_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[LHS_EMPTY_BARS]] + // CHECK-NEXT: [[LOAD_EMPTY_BAR:%.*]] = ttg.memdesc_index [[LHS_EMPTY_BARS]] // CHECK-NEXT: wait_barrier [[LOAD_EMPTY_BAR]]{{.*}}partition = 2 - // CHECK-NEXT: [[LOAD_READY_BAR:%.*]] = ttg.memdesc_subview [[LHS_READY_BARS]] + // CHECK-NEXT: [[LOAD_READY_BAR:%.*]] = ttg.memdesc_index [[LHS_READY_BARS]] // CHECK-NEXT: 
barrier_expect [[LOAD_READY_BAR]]{{.*}}partition = 2 - // CHECK-NEXT: [[LOAD_BUF:%.*]] = ttg.memdesc_subview + // CHECK-NEXT: [[LOAD_BUF:%.*]] = ttg.memdesc_index // CHECK-NEXT: async_tma_copy_global_to_local{{.*}}partition = 2 %lhs = tt.descriptor_load %lhs_desc[%i, %i] : !tt.tensordesc> -> tensor<128x64xf16, #oper_layout> @@ -1200,7 +1200,7 @@ tt.func @store_mma_load( // CHECK-NEXT: [[ACC:%.*]] = "make_acc"() %acc = "make_acc"() : () -> tensor<128x128xf32, #acc_layout> - // CHECK-NEXT: [[ACC_TMEM:%.*]] = ttg.memdesc_subview + // CHECK-NEXT: [[ACC_TMEM:%.*]] = ttg.memdesc_index // CHECK-NEXT: tmem_store [[ACC]], [[ACC_TMEM]][], %true {{.*}}partition = 0 // CHECK-NEXT: arrive_barrier [[MMA_ENTRY_BAR]], {{.*}}partition = 0 %acc_tmem, %acc_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) @@ -1348,25 +1348,25 @@ tt.func public @attention_forward( // CHECK: [[QK_TMEM:%.*]], [[PV_TOK:%.*]] = ttng.tmem_alloc : () -> (!ttg.memdesc<2x256x64xf32, // CHECK-NEXT: [[PV_TMEM:%.*]], [[QK_TOK:%.*]] = ttng.tmem_alloc : () -> (!ttg.memdesc<1x256x64xf32, - // CHECK-NEXT: [[PV_0:%.*]] = ttg.memdesc_subview [[PV_TMEM]][%c0_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: [[PV_0:%.*]] = ttg.memdesc_index [[PV_TMEM]], %c0_i32 // CHECK-NEXT: ttng.tmem_store [[ZERO]], [[PV_0]] // CHECK-NEXT: [[K_BUFS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x64x64xf16, // CHECK-NEXT: [[K_EMPTY_MBARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64 - // CHECK-NEXT: [[K_EMPTY_BAR0:%.*]] = ttg.memdesc_subview [[K_EMPTY_MBARS]][%c0_i32] + // CHECK-NEXT: [[K_EMPTY_BAR0:%.*]] = ttg.memdesc_index [[K_EMPTY_MBARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[K_EMPTY_BAR0]], 1 - // CHECK-NEXT: [[K_EMPTY_BAR1:%.*]] = ttg.memdesc_subview [[K_EMPTY_MBARS]][%c1_i32] + // CHECK-NEXT: [[K_EMPTY_BAR1:%.*]] = ttg.memdesc_index [[K_EMPTY_MBARS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[K_EMPTY_BAR1]], 1 - // 
CHECK-NEXT: [[K_EMPTY_BAR2:%.*]] = ttg.memdesc_subview [[K_EMPTY_MBARS]][%c2_i32] + // CHECK-NEXT: [[K_EMPTY_BAR2:%.*]] = ttg.memdesc_index [[K_EMPTY_MBARS]], %c2_i32 // CHECK-NEXT: ttng.init_barrier [[K_EMPTY_BAR2]], 1 // CHECK-NEXT: [[K_READY_MBARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64 - // CHECK-NEXT: [[K_READY_BAR0:%.*]] = ttg.memdesc_subview [[K_READY_MBARS]][%c0_i32] + // CHECK-NEXT: [[K_READY_BAR0:%.*]] = ttg.memdesc_index [[K_READY_MBARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[K_READY_BAR0]], 1 - // CHECK-NEXT: [[K_READY_BAR1:%.*]] = ttg.memdesc_subview [[K_READY_MBARS]][%c1_i32] + // CHECK-NEXT: [[K_READY_BAR1:%.*]] = ttg.memdesc_index [[K_READY_MBARS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[K_READY_BAR1]], 1 - // CHECK-NEXT: [[K_READY_BAR2:%.*]] = ttg.memdesc_subview [[K_READY_MBARS]][%c2_i32] + // CHECK-NEXT: [[K_READY_BAR2:%.*]] = ttg.memdesc_index [[K_READY_MBARS]], %c2_i32 // CHECK-NEXT: ttng.init_barrier [[K_READY_BAR2]], 1 // CHECK-NEXT: ttng.arrive_barrier [[K_EMPTY_BAR0]], 1 @@ -1376,19 +1376,19 @@ tt.func public @attention_forward( // CHECK-NEXT: [[V_BUFS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x64x64xf16, // CHECK-NEXT: [[V_EMPTY_MBARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64 - // CHECK-NEXT: [[V_EMPTY_BAR0:%.*]] = ttg.memdesc_subview [[V_EMPTY_MBARS]][%c0_i32] + // CHECK-NEXT: [[V_EMPTY_BAR0:%.*]] = ttg.memdesc_index [[V_EMPTY_MBARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[V_EMPTY_BAR0]], 1 - // CHECK-NEXT: [[V_EMPTY_BAR1:%.*]] = ttg.memdesc_subview [[V_EMPTY_MBARS]][%c1_i32] + // CHECK-NEXT: [[V_EMPTY_BAR1:%.*]] = ttg.memdesc_index [[V_EMPTY_MBARS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[V_EMPTY_BAR1]], 1 - // CHECK-NEXT: [[V_EMPTY_BAR2:%.*]] = ttg.memdesc_subview [[V_EMPTY_MBARS]][%c2_i32] + // CHECK-NEXT: [[V_EMPTY_BAR2:%.*]] = ttg.memdesc_index [[V_EMPTY_MBARS]], %c2_i32 // CHECK-NEXT: ttng.init_barrier [[V_EMPTY_BAR2]], 1 // CHECK-NEXT: [[V_READY_MBARS:%.*]] = ttg.local_alloc : () -> 
!ttg.memdesc<3xi64 - // CHECK-NEXT: [[V_READY_BAR0:%.*]] = ttg.memdesc_subview [[V_READY_MBARS]][%c0_i32] + // CHECK-NEXT: [[V_READY_BAR0:%.*]] = ttg.memdesc_index [[V_READY_MBARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[V_READY_BAR0]], 1 - // CHECK-NEXT: [[V_READY_BAR1:%.*]] = ttg.memdesc_subview [[V_READY_MBARS]][%c1_i32] + // CHECK-NEXT: [[V_READY_BAR1:%.*]] = ttg.memdesc_index [[V_READY_MBARS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[V_READY_BAR1]], 1 - // CHECK-NEXT: [[V_READY_BAR2:%.*]] = ttg.memdesc_subview [[V_READY_MBARS]][%c2_i32] + // CHECK-NEXT: [[V_READY_BAR2:%.*]] = ttg.memdesc_index [[V_READY_MBARS]], %c2_i32 // CHECK-NEXT: ttng.init_barrier [[V_READY_BAR2]], 1 // CHECK-NEXT: ttng.arrive_barrier [[V_EMPTY_BAR0]], 1 @@ -1396,26 +1396,26 @@ tt.func public @attention_forward( // CHECK-NEXT: ttng.arrive_barrier [[V_EMPTY_BAR2]], 1 // CHECK-NEXT: [[QK_READY_MBARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 - // CHECK-NEXT: [[QK_READY_BAR0:%.*]] = ttg.memdesc_subview [[QK_READY_MBARS]][%c0_i32] + // CHECK-NEXT: [[QK_READY_BAR0:%.*]] = ttg.memdesc_index [[QK_READY_MBARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[QK_READY_BAR0]], 1 - // CHECK-NEXT: [[QK_READY_BAR1:%.*]] = ttg.memdesc_subview [[QK_READY_MBARS]][%c1_i32] + // CHECK-NEXT: [[QK_READY_BAR1:%.*]] = ttg.memdesc_index [[QK_READY_MBARS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[QK_READY_BAR1]], 1 // CHECK-NEXT: [[QK_EMPTY_MBARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 - // CHECK-NEXT: [[QK_EMPTY_BAR0:%.*]] = ttg.memdesc_subview [[QK_EMPTY_MBARS]][%c0_i32] + // CHECK-NEXT: [[QK_EMPTY_BAR0:%.*]] = ttg.memdesc_index [[QK_EMPTY_MBARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[QK_EMPTY_BAR0]], 1 - // CHECK-NEXT: [[QK_EMPTY_BAR1:%.*]] = ttg.memdesc_subview [[QK_EMPTY_MBARS]][%c1_i32] + // CHECK-NEXT: [[QK_EMPTY_BAR1:%.*]] = ttg.memdesc_index [[QK_EMPTY_MBARS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[QK_EMPTY_BAR1]], 1 // CHECK-NEXT: ttng.arrive_barrier 
[[QK_EMPTY_BAR0]], 1 // CHECK-NEXT: ttng.arrive_barrier [[QK_EMPTY_BAR1]], 1 // CHECK-NEXT: [[PV_EMPTY_MBARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64 - // CHECK-NEXT: [[PV_EMPTY_BAR0:%.*]] = ttg.memdesc_subview [[PV_EMPTY_MBARS]][%c0_i32] + // CHECK-NEXT: [[PV_EMPTY_BAR0:%.*]] = ttg.memdesc_index [[PV_EMPTY_MBARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[PV_EMPTY_BAR0]], 1 // CHECK-NEXT: [[PV_READY_MBARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64 - // CHECK-NEXT: [[PV_READY_BAR0:%.*]] = ttg.memdesc_subview [[PV_READY_MBARS]][%c0_i32] + // CHECK-NEXT: [[PV_READY_BAR0:%.*]] = ttg.memdesc_index [[PV_READY_MBARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[PV_READY_BAR0]], 1 // CHECK-NEXT: ttng.arrive_barrier [[PV_READY_BAR0]], 1 @@ -1424,11 +1424,11 @@ tt.func public @attention_forward( // CHECK-NEXT: [[P_BUF:%.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<256x64xf16, // CHECK-NEXT: [[P_EMPTY_MBARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64 - // CHECK-NEXT: [[P_EMPTY_BAR0:%.*]] = ttg.memdesc_subview [[P_EMPTY_MBARS]][%c0_i32] + // CHECK-NEXT: [[P_EMPTY_BAR0:%.*]] = ttg.memdesc_index [[P_EMPTY_MBARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[P_EMPTY_BAR0]], 1 // CHECK-NEXT: [[P_READY_MBARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64 - // CHECK-NEXT: [[P_READY_BAR0:%.*]] = ttg.memdesc_subview [[P_READY_MBARS]][%c0_i32] + // CHECK-NEXT: [[P_READY_BAR0:%.*]] = ttg.memdesc_index [[P_READY_MBARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[P_READY_BAR0]], 1 // CHECK-NEXT: ttng.arrive_barrier [[P_EMPTY_MBARS]], 1 @@ -1455,11 +1455,11 @@ tt.func public @attention_forward( tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> ) : i32 { - // CHECK-NEXT: [[K_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[K_EMPTY_MBARS]][[[K_INDEX]]] + // CHECK-NEXT: [[K_EMPTY_BAR:%.*]] = ttg.memdesc_index [[K_EMPTY_MBARS]], [[K_INDEX]] // CHECK-NEXT: wait_barrier [[K_EMPTY_BAR]], [[K_PHASE]] {ttg.partition = 2 : i32} - // CHECK-NEXT: [[K_READY_BAR:%.*]] = 
ttg.memdesc_subview [[K_READY_MBARS]][[[K_INDEX]]] + // CHECK-NEXT: [[K_READY_BAR:%.*]] = ttg.memdesc_index [[K_READY_MBARS]], [[K_INDEX]] // CHECK-NEXT: barrier_expect [[K_READY_BAR]], 8192 {ttg.partition = 2 : i32} - // CHECK-NEXT: [[K_BUF:%.*]] = ttg.memdesc_subview [[K_BUFS]][[[K_INDEX]], %c0_i32, %c0_i32] + // CHECK-NEXT: [[K_BUF:%.*]] = ttg.memdesc_index [[K_BUFS]], [[K_INDEX]] // CHECK-NEXT: async_tma_copy_global_to_local [[K_DESC]][[[I]], %c0_i32] [[K_BUF]], [[K_READY_BAR]], %true {ttg.partition = 2 : i32} %K = tt.descriptor_load %K_desc[%i, %c0_i32] : !tt.tensordesc> -> tensor<64x64xf16, #load_blocked> %K_shared = ttg.local_alloc %K : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem> @@ -1467,10 +1467,10 @@ tt.func public @attention_forward( // CHECK-NEXT: [[K_TRANS:%.*]] = ttg.memdesc_trans [[K_BUF]] {order = array, ttg.partition = 1 : i32} %K_trans = ttg.memdesc_trans %K_shared {order = array} : !ttg.memdesc<64x64xf16, #shared, #smem> -> !ttg.memdesc<64x64xf16, #shared_T, #smem> // CHECK-NEXT: wait_barrier [[K_READY_BAR]], [[K_PHASE]] {ttg.partition = 1 : i32} - // CHECK-NEXT: [[QK_BUF:%.*]] = ttg.memdesc_subview [[QK_TMEM]][[[QK_INDEX]], %c0_i32, %c0_i32] - // CHECK-NEXT: [[QK_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[QK_EMPTY_MBARS]][[[QK_INDEX]]] + // CHECK-NEXT: [[QK_BUF:%.*]] = ttg.memdesc_index [[QK_TMEM]], [[QK_INDEX]] + // CHECK-NEXT: [[QK_EMPTY_BAR:%.*]] = ttg.memdesc_index [[QK_EMPTY_MBARS]], [[QK_INDEX]] // CHECK-NEXT: wait_barrier [[QK_EMPTY_BAR]], [[QK_PHASE]], %true {ttg.partition = 1 : i32} - // CHECK-NEXT: [[QK_READY_BAR:%.*]] = ttg.memdesc_subview [[QK_READY_MBARS]][[[QK_INDEX]]] + // CHECK-NEXT: [[QK_READY_BAR:%.*]] = ttg.memdesc_index [[QK_READY_MBARS]], [[QK_INDEX]] // CHECK-NEXT: tc_gen5_mma [[Q_SHARED]], [[K_TRANS]], [[QK_BUF]][], %false, %true, [[K_EMPTY_BAR]][%true], [[QK_READY_BAR]][%true] {is_async, ttg.partition = 1 : i32} %QK_tmem, %QK_tok = ttng.tmem_alloc : () -> (!ttg.memdesc<256x64xf32, #tmem_acc, 
#ttng.tensor_memory, mutable>, !ttg.async.token) %QK_mma_tok = ttng.tc_gen5_mma %Q_shared, %K_trans, %QK_tmem[%QK_tok], %false, %true : !ttg.memdesc<256x64xf16, #shared, #smem>, !ttg.memdesc<64x64xf16, #shared_T, #smem>, !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable> @@ -1524,11 +1524,11 @@ tt.func public @attention_forward( // CHECK-NEXT: [[ACC_CORRECTED:%.*]] = arith.mulf [[PV]], [[ALPHA_1]] {ttg.partition = 3 : i32} %acc_corrected = arith.mulf %acc, %alpha_1 : tensor<256x64xf32, #blocked> - // CHECK-NEXT: [[V_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[V_EMPTY_MBARS]][[[V_INDEX]]] + // CHECK-NEXT: [[V_EMPTY_BAR:%.*]] = ttg.memdesc_index [[V_EMPTY_MBARS]], [[V_INDEX]] // CHECK-NEXT: wait_barrier [[V_EMPTY_BAR]], [[V_PHASE]] {ttg.partition = 2 : i32} - // CHECK-NEXT: [[V_READY_BAR:%.*]] = ttg.memdesc_subview [[V_READY_MBARS]][[[V_INDEX]]] + // CHECK-NEXT: [[V_READY_BAR:%.*]] = ttg.memdesc_index [[V_READY_MBARS]], [[V_INDEX]] // CHECK-NEXT: barrier_expect [[V_READY_BAR]], 8192 {ttg.partition = 2 : i32} - // CHECK-NEXT: [[V_BUF:%.*]] = ttg.memdesc_subview [[V_BUFS]][[[V_INDEX]], %c0_i32, %c0_i32] + // CHECK-NEXT: [[V_BUF:%.*]] = ttg.memdesc_index [[V_BUFS]], [[V_INDEX]] // CHECK-NEXT: async_tma_copy_global_to_local [[V_DESC]][[[I]], %c0_i32] [[V_BUF]], [[V_READY_BAR]], %true {ttg.partition = 2 : i32} %V = tt.descriptor_load %V_desc[%i, %c0_i32] : !tt.tensordesc> -> tensor<64x64xf16, #load_blocked> %V_shared = ttg.local_alloc %V : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem> diff --git a/test/TritonGPU/loop-pipeline-async-latencies.mlir b/test/TritonGPU/loop-pipeline-async-latencies.mlir index ebae48fed29c..f60f1491ca33 100644 --- a/test/TritonGPU/loop-pipeline-async-latencies.mlir +++ b/test/TritonGPU/loop-pipeline-async-latencies.mlir @@ -20,39 +20,39 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.tensordesc !ttg.memdesc<4x256x64xf16, // CHECK: [[LHS_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64, 
- // CHECK-NEXT: [[LHS_BAR0:%.*]] = ttg.memdesc_subview [[LHS_BARS]][%c0_i32] + // CHECK-NEXT: [[LHS_BAR0:%.*]] = ttg.memdesc_index [[LHS_BARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[LHS_BAR0]] - // CHECK-NEXT: [[LHS_BAR1:%.*]] = ttg.memdesc_subview [[LHS_BARS]][%c1_i32] + // CHECK-NEXT: [[LHS_BAR1:%.*]] = ttg.memdesc_index [[LHS_BARS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[LHS_BAR1]] // CHECK: [[RHS_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<4xi64, - // CHECK-NEXT: [[RHS_BAR0:%.*]] = ttg.memdesc_subview [[RHS_BARS]][%c0_i32] + // CHECK-NEXT: [[RHS_BAR0:%.*]] = ttg.memdesc_index [[RHS_BARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[RHS_BAR0]] - // CHECK-NEXT: [[RHS_BAR1:%.*]] = ttg.memdesc_subview [[RHS_BARS]][%c1_i32] + // CHECK-NEXT: [[RHS_BAR1:%.*]] = ttg.memdesc_index [[RHS_BARS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[RHS_BAR1]] - // CHECK-NEXT: [[RHS_BAR2:%.*]] = ttg.memdesc_subview [[RHS_BARS]][%c2_i32] + // CHECK-NEXT: [[RHS_BAR2:%.*]] = ttg.memdesc_index [[RHS_BARS]], %c2_i32 // CHECK-NEXT: ttng.init_barrier [[RHS_BAR2]] - // CHECK-NEXT: [[RHS_BAR3:%.*]] = ttg.memdesc_subview [[RHS_BARS]][%c3_i32] + // CHECK-NEXT: [[RHS_BAR3:%.*]] = ttg.memdesc_index [[RHS_BARS]], %c3_i32 // CHECK-NEXT: ttng.init_barrier [[RHS_BAR3]] // CHECK: [[MASK0:%.*]] = arith.cmpi sgt, %arg3, %c0_i32 // CHECK-NEXT: ttng.barrier_expect [[RHS_BAR0]], 32768, [[MASK0]] - // CHECK-NEXT: [[RHS_BUF0:%.*]] = ttg.memdesc_subview [[RHS_BUFFERS]][%c0_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: [[RHS_BUF0:%.*]] = ttg.memdesc_index [[RHS_BUFFERS]], %c0_i32 // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg1[%c0_i32, %c0_i32] [[RHS_BUF0]], [[RHS_BAR0]], [[MASK0]] // CHECK: [[MASK1:%.*]] = arith.cmpi sgt, %arg3, %c1_i32 // CHECK-NEXT: ttng.barrier_expect [[RHS_BAR1]], 32768, [[MASK1]] - // CHECK-NEXT: [[RHS_BUF1:%.*]] = ttg.memdesc_subview [[RHS_BUFFERS]][%c1_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: [[RHS_BUF1:%.*]] = ttg.memdesc_index [[RHS_BUFFERS]], %c1_i32 // 
CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg1[%c0_i32, %c1_i32] [[RHS_BUF1]], [[RHS_BAR1]], [[MASK1]] // CHECK: [[MASK2:%.*]] = arith.cmpi sgt, %arg3, %c2_i32 // CHECK-NEXT: ttng.barrier_expect [[LHS_BAR0]], 16384, [[MASK0]] - // CHECK-NEXT: [[LHS_BUF0:%.*]] = ttg.memdesc_subview [[LHS_BUFFERS]][%c0_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: [[LHS_BUF0:%.*]] = ttg.memdesc_index [[LHS_BUFFERS]], %c0_i32 // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] [[LHS_BUF0]], [[LHS_BAR0]], [[MASK0]] // CHECK: ttng.barrier_expect [[RHS_BAR2]], 32768, [[MASK2]] - // CHECK-NEXT: [[RHS_BUF2:%.*]] = ttg.memdesc_subview [[RHS_BUFFERS]][%c2_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: [[RHS_BUF2:%.*]] = ttg.memdesc_index [[RHS_BUFFERS]], %c2_i32 // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg1[%c0_i32, %c2_i32] [[RHS_BUF2]], [[RHS_BAR2]], [[MASK2]] %true = arith.constant true @@ -92,10 +92,10 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.tensordesc> -> tensor<128x64xf16, #blocked> @@ -108,20 +108,20 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.tensordesc (tensor<32x32xf32, #mma>) : i32 { // CHECK: ttng.wait_barrier - // CHECK: [[RHS_VIEW:%.*]] = ttg.memdesc_subview [[RHS_BUFS]] + // CHECK: [[RHS_VIEW:%.*]] = ttg.memdesc_index [[RHS_BUFS]] // CHECK: [[RHS:%.*]] = ttg.local_load [[RHS_VIEW]] - // CHECK: [[LHS_VIEW:%.*]] = ttg.memdesc_subview [[LHS_BUFS]] + // CHECK: [[LHS_VIEW:%.*]] = ttg.memdesc_index [[LHS_BUFS]] // CHECK: [[LHS:%.*]] = ttg.local_load [[LHS_VIEW]] // CHECK: tt.dot [[LHS]], [[RHS]] %lhs = tt.descriptor_gather %lhs_desc[%lhs_x_offsets, %y] : (!tt.tensordesc>, tensor<32xi32, #blocked1>, i32) -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> diff --git a/test/TritonGPU/loop-pipeline-expand.mlir b/test/TritonGPU/loop-pipeline-expand.mlir index 2cdb31ed3b60..46622ac93ad9 100644 --- a/test/TritonGPU/loop-pipeline-expand.mlir +++ b/test/TritonGPU/loop-pipeline-expand.mlir @@ -15,9 +15,9 @@ 
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: ttg.local_alloc : () -> !ttg.memdesc<4x256x32xf32 // CHECK: ttg.local_alloc : () -> !ttg.memdesc<4x32x128xf32 %0:3 = scf.for %arg5 = %c0_i32 to %c128_i32 step %c1_i32 iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2) -> (tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr, #blocked>, tensor<32x128x!tt.ptr, #blocked1>) : i32 { - // CHECK: ttg.memdesc_subview {{.*}} : !ttg.memdesc<4x256x32xf32 + // CHECK: ttg.memdesc_index {{.*}} : !ttg.memdesc<4x256x32xf32 // CHECK: ttg.async_wait {{.*}} {num = 4 : i32} - // CHECK: ttg.memdesc_subview {{.*}} : !ttg.memdesc<4x32x128xf32 + // CHECK: ttg.memdesc_index {{.*}} : !ttg.memdesc<4x32x128xf32 // CHECK: ttng.warp_group_dot {{.*}} {inputPrecision = 0 : i32, isAsync = true} // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 1 : i32} %1 = tt.load %arg7 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<256x32x!tt.ptr, #blocked> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 0cce727d4888..f373b495b75a 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -21,35 +21,35 @@ // CHECK: %[[BBUFFER:.*]] = ttg.local_alloc // CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 2x128x32> +// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_index %[[ABUFFER]], %[[CONSTANT_0]] : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 2x128x32> // CHECK: %[[T_A0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, 
#shared, #smem, mutable, 2x128x32> // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[BSUB:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: %[[BSUB:.*]] = ttg.memdesc_index %[[BBUFFER]], %[[CONSTANT_0]] // CHECK: %[[T_B0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #smem, mutable, 2x32x128> // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] // CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] -// CHECK-DAG: %[[ASUB1:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: %[[ASUB1:.*]] = ttg.memdesc_index %[[ABUFFER]], %[[CONSTANT_1]] // CHECK: %[[T_A1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB1]] mask %[[LOOP_COND_1_SPLAT_A]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] -// CHECK-DAG: %[[BSUB1:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: %[[BSUB1:.*]] = ttg.memdesc_index %[[BBUFFER]], %[[CONSTANT_1]] // CHECK: %[[T_B1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB1]] mask %[[LOOP_COND_1_SPLAT_B]] // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_NEG1]] // CHECK: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK: %[[CMP_EXT:.*]] = arith.cmpi sge, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[CONSTANT_0]], %[[EXT_IDX_2]] // CHECK: ttg.async_wait {{.*}} {num = 2 : i32} -// CHECK: %[[A:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[A:.*]] = ttg.memdesc_index %[[ABUFFER]], %[[EXT_IDX_3]] // CHECK: %[[arg_a0_dot_op:.*]] = 
ttg.local_load %[[A]] -// CHECK: %[[B:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[B:.*]] = ttg.memdesc_index %[[BBUFFER]], %[[EXT_IDX_3]] // CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[B]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi sge, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[CONSTANT_0]], %[[INS_IDX_2]] -// CHECK: %[[ASUB3:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[ASUB3:.*]] = ttg.memdesc_index %[[ABUFFER]], %[[INS_IDX_3]] // CHECK: %[[NEXT_A_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[ASUB3]] -// CHECK: %[[BSUB3:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[BSUB3:.*]] = ttg.memdesc_index %[[BBUFFER]], %[[INS_IDX_3]] // CHECK: %[[NEXT_B_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[BSUB3]] // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]] module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 7a53eea6fb12..e8ecdb3a2062 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -23,36 +23,36 @@ // CHECK: %[[BBUFFER:.*]] = ttg.local_alloc // CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_index %[[ABUFFER]], %[[CONSTANT_0]] // CHECK: %[[T_A0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = 
tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[BSUB:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: %[[BSUB:.*]] = ttg.memdesc_index %[[BBUFFER]], %[[CONSTANT_0]] // CHECK: %[[T_B0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] // CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] -// CHECK-DAG: %[[ASUB1:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: %[[ASUB1:.*]] = ttg.memdesc_index %[[ABUFFER]], %[[CONSTANT_1]] // CHECK: %[[T_A1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB1]] mask %[[LOOP_COND_1_SPLAT_A]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] -// CHECK-DAG: %[[BSUB1:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: %[[BSUB1:.*]] = ttg.memdesc_index %[[BBUFFER]], %[[CONSTANT_1]] // CHECK: %[[T_B1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB1]] mask %[[LOOP_COND_1_SPLAT_B]] // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_NEG1]] // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi sge, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[CONSTANT_0]], %[[EXT_IDX_2]] // CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} -// CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: %[[A0:.*]] = ttg.memdesc_index %[[ABUFFER]], %[[EXT_IDX_3]] // CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[A0]] -// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: 
%[[B0:.*]] = ttg.memdesc_index %[[BBUFFER]], %[[EXT_IDX_3]] // CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[B0]] // CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi sge, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[CONSTANT_0]], %[[INS_IDX_2]] -// CHECK: %[[ASUB3:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[ASUB3:.*]] = ttg.memdesc_index %[[ABUFFER]], %[[INS_IDX_3]] // CHECK: %[[NEXT_A_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[ASUB3]] -// CHECK: %[[BSUB3:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[BSUB3:.*]] = ttg.memdesc_index %[[BBUFFER]], %[[INS_IDX_3]] // CHECK: %[[NEXT_B_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[BSUB3]] // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]] @@ -73,9 +73,9 @@ // AMD: %[[ADDI_42:.*]] = arith.addi %[[ARG9]], %{{.*}} // AMD: %[[CMPI_43:.*]] = arith.cmpi slt, %[[ADDI_42]], %{{.*}} // AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_43]], %[[ADDI_42]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = ttg.memdesc_index %{{.*}}, %[[SELECT_44]] // AMD: ttg.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_45]] -// AMD: %[[MEMDESC_SUBVIEW_46:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_46:.*]] = ttg.memdesc_index %{{.*}}, %[[SELECT_44]] // AMD: ttg.local_store %[[LOAD_38]], %[[MEMDESC_SUBVIEW_46]] // AMD: scf.yield %[[ADDPTR_34]], %[[ADDPTR_35]], %[[DOT_41]], %[[SELECT_44]], %[[MEMDESC_SUBVIEW_45]], %[[MEMDESC_SUBVIEW_46]] // AMD: } @@ -180,30 +180,30 @@ tt.func @matmul_loop(%lb : index, %ub 
: index, %step : index, // CHECK: scf.for // CHECK: %[[ABUFFER:.*]] = ttg.local_alloc // CHECK: %[[BBUFFER:.*]] = ttg.local_alloc -// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.memdesc_index %[[ABUFFER]], %[[CONSTANT_0]] // CHECK: ttg.async_copy_global_to_local -// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.memdesc_index %[[BBUFFER]], %[[CONSTANT_0]] // CHECK: ttg.async_copy_global_to_local -// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.memdesc_index %[[ABUFFER]], %[[CONSTANT_1]] // CHECK: ttg.async_copy_global_to_local -// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.memdesc_index %[[BBUFFER]], %[[CONSTANT_1]] // CHECK: ttg.async_copy_global_to_local // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_NEG1]]{{.*}} // CHECK: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK: %[[CMP_EXT:.*]] = arith.cmpi sge, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[CONSTANT_0]], %[[EXT_IDX_2]] // CHECK: ttg.async_wait {{.*}} {num = 2 : i32} -// CHECK: %[[A:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]], +// CHECK: %[[A:.*]] = ttg.memdesc_index %[[ABUFFER]], %[[EXT_IDX_3]] // CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[A]] -// CHECK: %[[B:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]], +// CHECK: %[[B:.*]] = ttg.memdesc_index %[[BBUFFER]], %[[EXT_IDX_3]] // CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[B]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi sge, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select 
%[[CMP_INS]], %[[CONSTANT_0]], %[[INS_IDX_2]] -// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.memdesc_index %[[ABUFFER]], %[[INS_IDX_3]] // CHECK: ttg.async_copy_global_to_local -// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.memdesc_index %[[BBUFFER]], %[[INS_IDX_3]] // CHECK: ttg.async_copy_global_to_local // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]] // CHECK: ttg.async_wait {num = 0 : i32} @@ -213,9 +213,9 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // AMD: scf.for // AMD-COUNT-2: ttg.local_alloc // AMD-COUNT-2: tt.load -// AMD: %[[SUBVIEW0:.*]] = ttg.memdesc_subview +// AMD: %[[SUBVIEW0:.*]] = ttg.memdesc_index // AMD: ttg.local_store %{{.+}}, %[[SUBVIEW0]] -// AMD: %[[SUBVIEW1:.*]] = ttg.memdesc_subview +// AMD: %[[SUBVIEW1:.*]] = ttg.memdesc_index // AMD: ttg.local_store %{{.+}}, %[[SUBVIEW1]] // AMD: %[[FOR:.*]]:6 = scf.for // AMD-COUNT-2: tt.addptr @@ -224,9 +224,9 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // AMD: tt.load // AMD: ttg.local_load // AMD: tt.dot -// AMD: %[[SUBVIEW0:.*]] = ttg.memdesc_subview +// AMD: %[[SUBVIEW0:.*]] = ttg.memdesc_index // AMD: ttg.local_store %{{.+}}, %[[SUBVIEW0]] -// AMD: %[[SUBVIEW1:.*]] = ttg.memdesc_subview +// AMD: %[[SUBVIEW1:.*]] = ttg.memdesc_index // AMD: ttg.local_store %{{.+}}, %[[SUBVIEW1]] // AMD: scf.yield // AMD-COUNT-2: ttg.local_load @@ -290,22 +290,22 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 // CHECK: %[[BBUFFER:.*]] = ttg.local_alloc -// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.memdesc_index %[[BBUFFER]], %[[CONSTANT_0]] // CHECK: ttg.async_copy_global_to_local -// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], 
%[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.memdesc_index %[[BBUFFER]], %[[CONSTANT_1]] // CHECK: ttg.async_copy_global_to_local // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_NEG1]] // CHECK: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK: %[[CMP_EXT:.*]] = arith.cmpi sge, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[CONSTANT_0]], %[[EXT_IDX_2]] // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} -// CHECK: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]] +// CHECK: %[[B0:.*]] = ttg.memdesc_index %[[BBUFFER]], %[[EXT_IDX_3]] // CHECK: %[[arg_b0_dot_op:.*]] = ttg.local_load %[[B0]] // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi sge, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[CONSTANT_0]], %[[INS_IDX_2]] -// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.memdesc_index %[[BBUFFER]], %[[INS_IDX_3]] // CHECK: ttg.async_copy_global_to_local // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]] @@ -316,7 +316,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // AMD: %[[CMPI_13:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} // AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_13]] // AMD: %[[LOAD_15:.*]] = tt.load %{{.*}}, %[[SPLAT_14]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_16:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_12]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_16:.*]] = ttg.memdesc_index %[[LOCAL_ALLOC_12]], %{{.*}} // AMD: ttg.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] // AMD: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} // AMD: %{{.*}}:4 = scf.for %[[ARG5:.*]] = %{{.*}} to %[[SUBI_17]] step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = 
%{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[MEMDESC_SUBVIEW_16]]) @@ -327,7 +327,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // AMD: %[[ADDI_34:.*]] = arith.addi %[[ARG8]], %{{.*}} // AMD: %[[CMPI_35:.*]] = arith.cmpi slt, %[[ADDI_34]], %{{.*}} // AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_35]], %[[ADDI_34]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_12]][%[[SELECT_36]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = ttg.memdesc_index %[[LOCAL_ALLOC_12]], %[[SELECT_36]] // AMD: ttg.local_store %[[LOAD_33]], %[[MEMDESC_SUBVIEW_37]] // AMD: scf.yield %[[ADDPTR_32]], %[[DOT_31]], %[[SELECT_36]], %[[MEMDESC_SUBVIEW_37]] // AMD: ttg.local_dealloc %[[LOCAL_ALLOC_12]] @@ -416,9 +416,9 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // AMD: %[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]] // AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] // AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] {amd.pipeliner_part = "prologue"} -// AMD: %[[MEMDESC_SUBVIEW_11:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_11:.*]] = ttg.memdesc_index %[[LOCAL_ALLOC_0]], %{{.*}} // AMD: ttg.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_11]] -// AMD: %[[MEMDESC_SUBVIEW_12:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_12:.*]] = ttg.memdesc_index %[[LOCAL_ALLOC_1]], %{{.*}} // AMD: ttg.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_12]] // AMD: %[[SUBI_26:.*]] = arith.subi %{{.*}}, %{{.*}} // AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_26]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_11]], %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_12]]) @@ -436,9 +436,9 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : 
index, // AMD: %[[ADDI_49:.*]] = arith.addi %[[ARG10]], %{{.*}} // AMD: %[[CMPI_50:.*]] = arith.cmpi slt, %[[ADDI_49]], %{{.*}} // AMD: %[[SELECT_51:.*]] = arith.select %[[CMPI_50]], %[[ADDI_49]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_52:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_51]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_52:.*]] = ttg.memdesc_index %[[LOCAL_ALLOC_0]], %[[SELECT_51]] // AMD: ttg.local_store %[[LOAD_40]], %[[MEMDESC_SUBVIEW_52]] -// AMD: %[[MEMDESC_SUBVIEW_53:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_51]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_53:.*]] = ttg.memdesc_index %[[LOCAL_ALLOC_1]], %[[SELECT_51]] // AMD: ttg.local_store %[[LOAD_46]], %[[MEMDESC_SUBVIEW_53]] // AMD: scf.yield %[[DOT_48]], %[[ADDPTR_38]], %[[ADDPTR_39]], %[[SELECT_51]], %[[MEMDESC_SUBVIEW_52]], %[[LOAD_42]], %[[MEMDESC_SUBVIEW_53]] // AMD: } {tt.num_stages = 3 @@ -563,7 +563,7 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} // CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] // CHECK-DAG: %[[IND_BUFFER_WAIT_TOKEN:.*]] = ttg.async_wait {{.*}} {num = 1 : i32} -// CHECK-DAG: %[[IND_BUFFER_0:.*]] = ttg.memdesc_subview +// CHECK-DAG: %[[IND_BUFFER_0:.*]] = ttg.memdesc_index // CHECK: %[[IND_BUFFER_1:.*]] = ttg.local_load %[[IND_BUFFER_0]] token %[[IND_BUFFER_WAIT_TOKEN]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] @@ -590,9 +590,9 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, // AMD: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %[[MULI_13]] // AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_2]] // AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_14]], %[[SPLAT_15]] -// AMD: %[[MEMDESC_SUBVIEW_17:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_17:.*]] = ttg.memdesc_index 
%[[LOCAL_ALLOC_0]], %{{.*}} // AMD: ttg.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_17]] -// AMD: %[[MEMDESC_SUBVIEW_18:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_18:.*]] = ttg.memdesc_index %[[LOCAL_ALLOC_1]], %{{.*}} // AMD: ttg.local_store %[[LOAD_16]], %[[MEMDESC_SUBVIEW_18]] // AMD: %[[SUBI_19:.*]] = arith.subi %{{.*}}, %{{.*}} // AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_19]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[LOAD_10]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_18]]) @@ -611,9 +611,9 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, // AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} // AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} // AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = ttg.memdesc_index %[[LOCAL_ALLOC_0]], %[[SELECT_61]] // AMD: ttg.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] -// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = ttg.memdesc_index %[[LOCAL_ALLOC_1]], %[[SELECT_61]] // AMD: ttg.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] // AMD: scf.yield %[[DOT_58]], %[[ADDPTR_47]], %[[ADDPTR_48]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[LOAD_51]], %[[MEMDESC_SUBVIEW_63]] @@ -988,7 +988,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} // CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] -// CHECK: %[[IND_BUFFER_0:.*]] = ttg.memdesc_subview {{.*}} : 
!ttg.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #smem, mutable> -> !ttg.memdesc<16xi64, #[[$SHARED_LAYOUT]], #smem, mutable, 1x16> +// CHECK: %[[IND_BUFFER_0:.*]] = ttg.memdesc_index {{.*}} : !ttg.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #smem, mutable> -> !ttg.memdesc<16xi64, #[[$SHARED_LAYOUT]], #smem, mutable, 1x16> // CHECK: %[[IND_BUFFER_1:.*]] = ttg.local_load %[[IND_BUFFER_0]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] @@ -1016,9 +1016,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} // AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} // AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = ttg.memdesc_index %{{.*}}, %[[SELECT_61]] // AMD: ttg.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] -// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = ttg.memdesc_index %{{.*}}, %[[SELECT_61]] // AMD: ttg.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] // AMD: scf.yield %[[DOT_58]], %[[ADDPTR_47]], %[[ADDPTR_48]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[LOAD_51]], %[[MEMDESC_SUBVIEW_63]] // AMD: } @@ -1044,9 +1044,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // AMD: %[[ADDI_35:.*]] = arith.addi %{{.*}}#3, %{{.*}} // AMD: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} // AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = ttg.memdesc_index %{{.*}}, %[[SELECT_37]] // AMD: ttg.local_store %[[LOAD_25]], %[[MEMDESC_SUBVIEW_38]] -// 
AMD: %[[MEMDESC_SUBVIEW_39:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_39:.*]] = ttg.memdesc_index %{{.*}}, %[[SELECT_37]] // AMD: ttg.local_store %[[LOAD_32]], %[[MEMDESC_SUBVIEW_39]] // AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_21]], %[[IF_34]], %{{.*}}#0 // AMD: %[[LOCAL_LOAD_41:.*]] = ttg.local_load %[[MEMDESC_SUBVIEW_38]] @@ -1106,17 +1106,17 @@ tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibilit // CHECK-LABEL: @kernel_yield_constant // CHECK: ttg.async_copy_global_to_local // CHECK: scf.for -// CHECK: ttg.memdesc_subview +// CHECK: ttg.memdesc_index // CHECK: ttg.async_copy_global_to_local // CHECK: tt.return // AMD-LABEL: @kernel_yield_constant // AMD: tt.load -// AMD: ttg.memdesc_subview +// AMD: ttg.memdesc_index // AMD: ttg.local_store // AMD: scf.for // AMD: tt.load -// AMD: ttg.memdesc_subview +// AMD: ttg.memdesc_index // AMD: ttg.local_store // AMD: tt.return #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -1166,13 +1166,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK: %[[ABUFFER:.*]] = ttg.local_alloc // CHECK: %[[BBUFFER:.*]] = ttg.local_alloc -// CHECK: %[[A0BUFFER:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[A0BUFFER:.*]] = ttg.memdesc_index %[[ABUFFER]], %[[CONSTANT_0]] // CHECK: ttg.async_copy_global_to_local {{.*}}, %[[A0BUFFER]] -// CHECK: %[[B0BUFFER:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[B0BUFFER:.*]] = ttg.memdesc_index %[[BBUFFER]], %[[CONSTANT_0]] // CHECK: ttg.async_copy_global_to_local {{.*}}, %[[B0BUFFER]] -// CHECK: %[[A1BUFFER:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] +// CHECK: %[[A1BUFFER:.*]] = ttg.memdesc_index %[[ABUFFER]], %[[CONSTANT_1]] // CHECK: 
ttg.async_copy_global_to_local {{.*}}, %[[A1BUFFER]] -// CHECK: %[[B1BUFFER:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] +// CHECK: %[[B1BUFFER:.*]] = ttg.memdesc_index %[[BBUFFER]], %[[CONSTANT_1]] // CHECK: ttg.async_copy_global_to_local {{.*}}, %[[B1BUFFER]] // CHECK: scf.for @@ -1231,19 +1231,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: %[[TRANS:.*]] = ttg.memdesc_trans %[[BUFFER_2]] // CHECK: %[[LOCAL_LOAD_1:.*]] = ttg.local_load %[[TRANS]] // CHECK: %[[BUFFER_1:.*]] = ttg.local_alloc : () -// CHECK: %[[SUBVIEW_1:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[SUBVIEW_1:.*]] = ttg.memdesc_index %[[BUFFER_1]] // CHECK: %[[ASYNC_COPY_1:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_1]] // CHECK: ttg.async_commit_group %[[ASYNC_COPY_1]] -// CHECK: %[[SUBVIEW_2:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[SUBVIEW_2:.*]] = ttg.memdesc_index %[[BUFFER_1]] // CHECK: %[[ASYNC_COPY_2:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_2]] // CHECK: ttg.async_commit_group %[[ASYNC_COPY_2]] // CHECK: scf.for // CHECK: ttg.async_wait -// CHECK: ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: ttg.memdesc_index %[[BUFFER_1]] // CHECK: %[[LOCAL_LOAD_2:.*]] = ttg.local_load // CHECK: %[[DOT:.*]] = tt.dot %[[LOCAL_LOAD_2]], %[[LOCAL_LOAD_1]] // CHECK: %[[CONVERT_LAYOUT_3:.*]] = ttg.convert_layout %[[DOT]] -// CHECK: %[[SUBVIEW_4:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[SUBVIEW_4:.*]] = ttg.memdesc_index %[[BUFFER_1]] // CHECK: %[[ASYNC_COPY_3:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_4]] // CHECK: ttg.async_commit_group %[[ASYNC_COPY_3]] // CHECK: ttg.local_dealloc %[[BUFFER_1]] diff --git a/test/TritonGPU/memdesc-subview-split.mlir b/test/TritonGPU/memdesc-subview-split.mlir index 548d34f97fff..99553212e702 100644 --- a/test/TritonGPU/memdesc-subview-split.mlir +++ b/test/TritonGPU/memdesc-subview-split.mlir @@ 
-5,35 +5,35 @@ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: memdesc_subview_spliting - tt.func public @memdesc_subview_spliting() attributes {noinline = false} { + // CHECK-LABEL: memdesc_subslice_spliting + tt.func public @memdesc_subslice_spliting() attributes {noinline = false} { %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc : () -> !ttg.memdesc<1x256x128xf16, #shared, #smem, mutable> - %1 = ttg.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<256x128xf16, #shared, #smem, mutable> + %1 = ttg.memdesc_index %0, %c0_i32 : !ttg.memdesc<1x256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<256x128xf16, #shared, #smem, mutable> %c0_i32_0 = arith.constant 0 : i32 %c0_i32_1 = arith.constant 0 : i32 - %2 = ttg.memdesc_subview %1[%c0_i32_0, %c0_i32_1] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> + %2 = ttg.memdesc_subslice %1 [0, 0] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> %c0_i32_2 = arith.constant 0 : i32 %c32_i32 = arith.constant 32 : i32 - %3 = ttg.memdesc_subview %1[%c0_i32_2, %c32_i32] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> + %3 = ttg.memdesc_subslice %1 [0, 32] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> %c0_i32_3 = arith.constant 0 : i32 %c64_i32 = arith.constant 64 : i32 - %4 = ttg.memdesc_subview %1[%c0_i32_3, %c64_i32] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> + %4 = ttg.memdesc_subslice %1 [0, 64] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, 
mutable, 256x128> %c0_i32_4 = arith.constant 0 : i32 %c96_i32 = arith.constant 96 : i32 - %5 = ttg.memdesc_subview %1[%c0_i32_4, %c96_i32] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> + %5 = ttg.memdesc_subslice %1 [0, 96] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> %c128_i32 = arith.constant 128 : i32 %c0_i32_5 = arith.constant 0 : i32 - %6 = ttg.memdesc_subview %1[%c128_i32, %c0_i32_5] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> + %6 = ttg.memdesc_subslice %1 [128, 0] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> %c128_i32_6 = arith.constant 128 : i32 %c32_i32_7 = arith.constant 32 : i32 - %7 = ttg.memdesc_subview %1[%c128_i32_6, %c32_i32_7] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> + %7 = ttg.memdesc_subslice %1 [128, 32] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> %c128_i32_8 = arith.constant 128 : i32 %c64_i32_9 = arith.constant 64 : i32 - %8 = ttg.memdesc_subview %1[%c128_i32_8, %c64_i32_9] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> + %8 = ttg.memdesc_subslice %1 [128, 64] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> %c128_i32_10 = arith.constant 128 : i32 %c96_i32_11 = arith.constant 96 : i32 - %9 = ttg.memdesc_subview %1[%c128_i32_10, %c96_i32_11] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> + %9 = ttg.memdesc_subslice %1 [128, 96] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, 
mutable, 256x128> tt.return } } diff --git a/test/TritonGPU/partition-scheduling.mlir b/test/TritonGPU/partition-scheduling.mlir index 364c53bb2801..3a81db2ebac3 100644 --- a/test/TritonGPU/partition-scheduling.mlir +++ b/test/TritonGPU/partition-scheduling.mlir @@ -118,10 +118,10 @@ tt.func public @mma_operand_view( %K_shared = ttg.local_alloc %K : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem> // CHECK-DAG: [[TRANS_MMA:%.*]] = ttg.memdesc_trans [[K_SHARED]] {{.*}}partition = 1 - // CHECK-DAG: [[K_VIEW:%.*]] = ttg.memdesc_subview [[TRANS_MMA]]{{.*}}partition = 1 + // CHECK-DAG: [[K_VIEW:%.*]] = ttg.memdesc_subslice [[TRANS_MMA]]{{.*}}partition = 1 // CHECK-DAG: [[TRANS_USER:%.*]] = ttg.memdesc_trans [[K_SHARED]] {{.*}}partition = 0 %K_trans = ttg.memdesc_trans %K_shared {order = array} : !ttg.memdesc<64x64xf16, #shared, #smem> -> !ttg.memdesc<64x64xf16, #shared_T, #smem> - %K_view = ttg.memdesc_subview %K_trans[%c0_i32, %c0_i32] : !ttg.memdesc<64x64xf16, #shared_T, #smem> -> !ttg.memdesc<64x64xf16, #shared_T, #smem> + %K_view = ttg.memdesc_subslice %K_trans [0, 0] : !ttg.memdesc<64x64xf16, #shared_T, #smem> -> !ttg.memdesc<64x64xf16, #shared_T, #smem> // CHECK: ttng.tc_gen5_mma %arg0, [[K_VIEW]]{{.*}}partition = 1 %QK_mma_tok = ttng.tc_gen5_mma %Q_shared, %K_view, %QK_tmem[%QK_tok], %false, %true : !ttg.memdesc<256x64xf16, #shared, #smem>, !ttg.memdesc<64x64xf16, #shared_T, #smem>, !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable> diff --git a/test/TritonGPU/pipeline-loop-nest.mlir b/test/TritonGPU/pipeline-loop-nest.mlir index 007fdd566db0..c59ceb38650a 100644 --- a/test/TritonGPU/pipeline-loop-nest.mlir +++ b/test/TritonGPU/pipeline-loop-nest.mlir @@ -28,7 +28,7 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr, %arg1: !tt.p // BLACKWELL: [[ACC_BUFS:%.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem, // BLACKWELL: ttg.memdesc_trans - // BLACKWELL: [[ACC_BUF:%.*]] = ttg.memdesc_subview 
[[ACC_BUFS]] + // BLACKWELL: [[ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]] // BLACKWELL: ttng.tc_gen5_mma {{%[0-9]+}}, {{%[0-9]+}}, [[ACC_BUF]], %false // BLACKWELL: scf.for @@ -50,7 +50,7 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr, %arg1: !tt.p %38 = ttng.reinterpret_tensor_descriptor %arg1 : !tt.ptr to !tt.tensordesc> %39 = tt.descriptor_load %38[%20, %35] : !tt.tensordesc> -> tensor<128x64xf16> // BLACKWELL: ttg.memdesc_trans - // BLACKWELL: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]] + // BLACKWELL: [[ACC_BUF:%.*]] = ttg.memdesc_index [[ACC_BUFS]] // BLACKWELL: ttng.tc_gen5_mma {{%[0-9]+}}, {{%[0-9]+}}, [[ACC_BUF]] // HOPPER: [[RESULT:%.*]] = ttng.warp_group_dot {{.*}} isAsync = true diff --git a/test/TritonGPU/pipeline-lower-loop.mlir b/test/TritonGPU/pipeline-lower-loop.mlir index 05aab70a5aa0..d09a03b7fa8d 100644 --- a/test/TritonGPU/pipeline-lower-loop.mlir +++ b/test/TritonGPU/pipeline-lower-loop.mlir @@ -36,11 +36,11 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 // CHECK: %[[EXT_CMP:.*]] = arith.cmpi sge, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[ZERO]], %[[EXT_P1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 -// CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[A_INS:.*]] = ttg.memdesc_index %[[A]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} 
-// CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[A_EXT:.*]] = ttg.memdesc_index %[[A]], %[[EXT_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[A_VAL:.*]] = ttg.local_load %[[A_EXT]] token %[[A_TOK3]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: "use"(%[[A_VAL]]) {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: scf.yield %[[INS_NEXT]], %[[EXT_NEXT]] @@ -114,7 +114,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK: scf.for // CHECK: ttg.async_copy_global_to_local %{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: ttg.async_wait {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} -// CHECK: ttg.memdesc_subview {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: ttg.memdesc_index {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[A_VAL:.*]] = ttg.local_load {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: "use1"(%[[A_VAL]]) {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: "use2"(%[[A_VAL]]) {loop.cluster = 0 : i32, loop.stage = 3 : i32} @@ -215,11 +215,11 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 // CHECK: %[[EXT_CMP:.*]] = arith.cmpi sge, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[ZERO]], %[[EXT_P1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 -// CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[A_INS:.*]] = ttg.memdesc_index %[[A]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 2 : i32, 
loop.stage = 0 : i32} // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} -// CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[A_EXT:.*]] = ttg.memdesc_index %[[A]], %[[EXT_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[A_VAL:.*]] = ttg.local_load %[[A_EXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x32xf16, #[[SHARED]], # // CHECK: "use"(%[[A_VAL]]) {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: scf.yield %[[INS_NEXT]], %[[EXT_NEXT]] @@ -330,18 +330,18 @@ tt.func @dependent_loads(%lb : index, %ub : index, %step : index, // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: %[[EXT_CMP:.*]] = arith.cmpi sge, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[ZERO]], %[[EXT_P1]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} - // CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_INS:.*]] = ttg.memdesc_index %[[A]], %[[INS_NEXT]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 2 : i32, loop.stage = 2 : i32, num = 0 : i32} - // CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[A_EXT:.*]] = ttg.memdesc_index %[[A]], %[[EXT_NEXT]] {loop.cluster = 2 : i32, loop.stage = 2 : 
i32} // CHECK: %[[A_VAL:.*]] = ttg.local_load %[[A_EXT]] token %[[A_TOK3]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: %[[B:.*]] = "pointerize"(%[[A_VAL]]) {loop.cluster = 2 : i32, loop.stage = 2 : i32} - // CHECK: %[[C_INS:.*]] = ttg.memdesc_subview %[[C]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[C_INS:.*]] = ttg.memdesc_index %[[C]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: %[[C_TOK:.*]] = ttg.async_copy_global_to_local %[[B]], %[[C_INS]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: %[[C_TOK2:.*]] = ttg.async_commit_group %[[C_TOK]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: %[[C_TOK3:.*]] = ttg.async_wait %[[C_TOK2]] {loop.cluster = 0 : i32, loop.stage = 4 : i32, num = 0 : i32} - // CHECK: %[[C_EXT:.*]] = ttg.memdesc_subview %[[C]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 4 : i32} + // CHECK: %[[C_EXT:.*]] = ttg.memdesc_index %[[C]], %[[EXT_NEXT]] {loop.cluster = 0 : i32, loop.stage = 4 : i32} // CHECK: %[[C_VAL:.*]] = ttg.local_load %[[C_EXT]] token %[[C_TOK3]] {loop.cluster = 0 : i32, loop.stage = 4 : i32} // CHECK: "use1"(%[[C_VAL]]) {loop.cluster = 0 : i32, loop.stage = 4 : i32} // CHECK: scf.yield @@ -387,18 +387,18 @@ tt.func @dependent_loads_asymmetric(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[EXT2_P1:.*]] = arith.addi %[[EXT2]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK-DAG: %[[EXT2_CMP:.*]] = arith.cmpi sge, %[[EXT2_P1]], %[[NUM_BUFS2]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK-DAG: %[[EXT2_NEXT:.*]] = arith.select %[[EXT2_CMP]], %[[ZERO]], %[[EXT2_P1]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} - // CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS2_NEXT]]{{.*}} {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_INS:.*]] = ttg.memdesc_index %[[A]], %[[INS2_NEXT]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK:.*]] = 
ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 2 : i32, loop.stage = 2 : i32, num = 0 : i32} - // CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT2_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[A_EXT:.*]] = ttg.memdesc_index %[[A]], %[[EXT2_NEXT]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: %[[A_VAL:.*]] = ttg.local_load %[[A_EXT]] token %[[A_TOK3]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: %[[B:.*]] = "pointerize"(%[[A_VAL]]) {loop.cluster = 2 : i32, loop.stage = 2 : i32} - // CHECK: %[[C_INS:.*]] = ttg.memdesc_subview %[[C]][%[[INS3_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[C_INS:.*]] = ttg.memdesc_index // CHECK: %[[C_TOK:.*]] = ttg.async_copy_global_to_local %[[B]], %[[C_INS]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: %[[C_TOK2:.*]] = ttg.async_commit_group %[[C_TOK]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: %[[C_TOK3:.*]] = ttg.async_wait %[[C_TOK2]] {loop.cluster = 0 : i32, loop.stage = 5 : i32, num = 0 : i32} - // CHECK: %[[C_EXT:.*]] = ttg.memdesc_subview %[[C]][%[[EXT3_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 5 : i32} + // CHECK: %[[C_EXT:.*]] = ttg.memdesc_index %[[C]], %[[EXT3_NEXT]] {loop.cluster = 0 : i32, loop.stage = 5 : i32} // CHECK: %[[C_VAL:.*]] = ttg.local_load %[[C_EXT]] token %[[C_TOK3]] {loop.cluster = 0 : i32, loop.stage = 5 : i32} // CHECK: "use1"(%[[C_VAL]]) {loop.cluster = 0 : i32, loop.stage = 5 : i32} // CHECK: scf.yield @@ -455,16 +455,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 // CHECK: %[[EXT_CMP:.*]] = arith.cmpi sge, 
%[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[ZERO]], %[[EXT_P1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 - // CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_INS:.*]] = ttg.memdesc_index %[[A]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} - // CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} - // CHECK: %[[B_INS:.*]] = ttg.memdesc_subview %[[B]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_EXT:.*]] = ttg.memdesc_index %[[A]], %[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[B_INS:.*]] = ttg.memdesc_index %[[B]], %[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[B_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[B_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[B_TOK2:.*]] = ttg.async_commit_group %[[B_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[B_TOK3:.*]] = ttg.async_wait %[[B_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} - // CHECK: %[[B_EXT:.*]] = ttg.memdesc_subview %[[B]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[B_EXT:.*]] = ttg.memdesc_index %[[B]], %[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[A_EXT_TRANSP:.*]] = ttg.memdesc_trans %[[A_EXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, order = array} 
// CHECK: ttng.warp_group_dot %[[A_EXT_TRANSP]], %[[B_EXT]], %{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_NEXT]], %[[EXT_NEXT]] @@ -515,11 +515,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 // CHECK: %[[EXT_CMP:.*]] = arith.cmpi sge, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[ZERO]], %[[EXT_P1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 - // CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_INS:.*]] = ttg.memdesc_index %[[A]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} - // CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[A_EXT:.*]] = ttg.memdesc_index %[[A]], %[[EXT_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[B:.*]] = tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[B_SH:.*]] = ttg.local_alloc %[[B]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: ttng.warp_group_dot %[[A_EXT]], %[[B_SH]], %{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} @@ -569,18 +569,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 // CHECK: %[[EXT_CMP:.*]] = 
arith.cmpi sge, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[ZERO]], %[[EXT_P1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 - // CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_INS:.*]] = ttg.memdesc_index %[[A]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} - // CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[A_EXT:.*]] = ttg.memdesc_index %[[A]], %[[EXT_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[A_LOAD:.*]] = ttg.local_load %[[A_EXT]] {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[A_MASKED:.*]] = arith.select {{.*}}, %[[A_LOAD]], {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} - // CHECK: %[[B_INS:.*]] = ttg.memdesc_subview %[[B]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[B_INS:.*]] = ttg.memdesc_index %[[B]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[B_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[B_INS]] {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[B_TOK2:.*]] = ttg.async_commit_group %[[B_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[B_TOK3:.*]] = ttg.async_wait %[[B_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} - // CHECK: %[[B_EXT:.*]] = ttg.memdesc_subview %[[B]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // 
CHECK: %[[B_EXT:.*]] = ttg.memdesc_index %[[B]], %[[EXT_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[B_LOAD:.*]] = ttg.local_load %[[B_EXT]] {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[B_MASKED:.*]] = arith.select {{.*}}, %[[B_LOAD]], {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[A_SH:.*]] = ttg.local_alloc %[[A_MASKED]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} @@ -630,9 +630,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: %[[ACC_TM:.*]], %[[ACC_TOK:.*]] = ttng.tmem_alloc : () // CHECK: %[[INIT_TOK:.*]] = ttng.tmem_store %[[INIT]], %[[ACC_TM]][%[[ACC_TOK]]] // CHECK: %[[BAR:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 - // CHECK: %[[BAR_SUB1:.*]] = ttg.memdesc_subview %[[BAR]][%[[ZERO]]] + // CHECK: %[[BAR_SUB1:.*]] = ttg.memdesc_index %[[BAR]], %[[ZERO]] // CHECK: ttng.init_barrier %[[BAR_SUB1]], 1 - // CHECK: %[[BAR_SUB2:.*]] = ttg.memdesc_subview %[[BAR]][%[[ONE]]] + // CHECK: %[[BAR_SUB2:.*]] = ttg.memdesc_index %[[BAR]], %[[ONE]] // CHECK: ttng.init_barrier %[[BAR_SUB2]], 1 // CHECK: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x128 // CHECK: %[[B:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x128 @@ -643,17 +643,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 // CHECK: %[[EXT_CMP:.*]] = arith.cmpi sge, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[ZERO]], %[[EXT_P1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 - // CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_INS:.*]] = ttg.memdesc_index %[[A]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK:.*]] = 
ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} - // CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} - // CHECK: %[[B_INS:.*]] = ttg.memdesc_subview %[[B]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_EXT:.*]] = ttg.memdesc_index %[[A]], %[[EXT_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[B_INS:.*]] = ttg.memdesc_index %[[B]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[B_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[B_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[B_TOK2:.*]] = ttg.async_commit_group %[[B_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: %[[B_TOK3:.*]] = ttg.async_wait %[[B_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} - // CHECK: %[[B_EXT:.*]] = ttg.memdesc_subview %[[B]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} - // CHECK: %[[BAR_SUB:.*]] = ttg.memdesc_subview %[[BAR]][%[[BAR_IDX]]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[B_EXT:.*]] = ttg.memdesc_index %[[B]], %[[EXT_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[BAR_SUB:.*]] = ttg.memdesc_index %[[BAR]], %[[BAR_IDX]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[MMA_TOK:.*]] = ttng.tc_gen5_mma %[[A_EXT]], %[[B_EXT]], %{{.*}}[%[[TOK]]], {{.*}} {is_async, loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: ttng.wait_barrier %[[BAR_SUB]], %[[PHASE]] deps %[[A_EXT]], %[[B_EXT]] {loop.cluster = 0 : i32, loop.stage = 3 : i32} // CHECK: %[[PHASE_NEG:.*]] = arith.xori %[[PHASE]], %[[ONE]] {loop.cluster = 0 : i32, 
loop.stage = 2 : i32} @@ -664,9 +664,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: scf.yield %[[MMA_TOK]], %[[PHASE_NEXT]], %[[BAR_IDX_NEXT]], %[[INS_NEXT]], %[[EXT_NEXT]] // CHECK-DAG: ttg.local_dealloc %[[A]] // CHECK-DAG: ttg.local_dealloc %[[B]] - // CHECK-DAG: %[[BAR_SUB1:.*]] = ttg.memdesc_subview %[[BAR]][%[[ZERO]]] + // CHECK-DAG: %[[BAR_SUB1:.*]] = ttg.memdesc_index %[[BAR]], %[[ZERO]] // CHECK-DAG: ttng.inval_barrier %[[BAR_SUB1]] - // CHECK-DAG: %[[BAR_SUB2:.*]] = ttg.memdesc_subview %[[BAR]][%[[ONE]]] + // CHECK-DAG: %[[BAR_SUB2:.*]] = ttg.memdesc_index %[[BAR]], %[[ONE]] // CHECK-DAG: ttng.inval_barrier %[[BAR_SUB2]] // CHECK-DAG: ttg.local_dealloc %[[BAR]] // CHECK-DAG: ttg.async_wait {num = 0 : i32} @@ -706,9 +706,9 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-DAG: %[[NUM_BUFS:.*]] = arith.constant {{.*}} 2 : i32 // CHECK-DAG: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x32 // CHECK-DAG: %[[BARRIER:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 -// CHECK: %[[BAR1_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ZERO]]] +// CHECK: %[[BAR1_VIEW:.*]] = ttg.memdesc_index %[[BARRIER]], %[[ZERO]] // CHECK: ttng.init_barrier %[[BAR1_VIEW]], 1 -// CHECK: %[[BAR2_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ONE]]] +// CHECK: %[[BAR2_VIEW:.*]] = ttg.memdesc_index %[[BARRIER]], %[[ONE]] // CHECK: ttng.init_barrier %[[BAR2_VIEW]], 1 // CHECK: scf.for {{.*}} iter_args(%[[INS:.*]] = %[[MINUS_ONE]], %[[EXT:.*]] = %[[MINUS_ONE]], %[[PHASE:.*]] = %[[ZERO]]) // CHECK: %[[INS_P1:.*]] = arith.addi %[[INS]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} @@ -719,19 +719,19 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[ZERO]], %[[EXT_P1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[PHASE_XOR:.*]] = arith.xori %[[PHASE]], %[[ONE]] {loop.cluster = 0 : i32, 
loop.stage = 2 : i32} // CHECK: %[[PHASE_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[PHASE_XOR]], %[[PHASE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} -// CHECK: %[[BAR_INS:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[INS_NEXT]]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[BAR_INS:.*]] = ttg.memdesc_index %[[BARRIER]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: ttng.barrier_expect %[[BAR_INS]], 8192 {loop.cluster = 2 : i32, loop.stage = 0 : i32}, %[[TRUE]] -// CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]], %[[ZERO]], %[[ZERO]]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[A_INS:.*]] = ttg.memdesc_index %[[A]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: ttng.async_tma_copy_global_to_local {{.*}}[{{.*}}] %[[A_INS]], %[[BAR_INS]], %[[TRUE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} -// CHECK: %[[BAR_EXT:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[EXT_NEXT]]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[BAR_EXT:.*]] = ttg.memdesc_index %[[BARRIER]], %[[EXT_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: ttng.wait_barrier %[[BAR_EXT]], %[[PHASE_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} -// CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]], %[[ZERO]], %[[ZERO]]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[A_EXT:.*]] = ttg.memdesc_index %[[A]], %[[EXT_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[A_LOAD:.*]] = ttg.local_load %[[A_EXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: "use"(%[[A_LOAD]]) {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: scf.yield %[[INS_NEXT]], %[[EXT_NEXT]], %[[PHASE_NEXT]] : i32, i32, i32 -// CHECK: %[[BAR1_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ZERO]]] +// CHECK: %[[BAR1_VIEW:.*]] = ttg.memdesc_index %[[BARRIER]], %[[ZERO]] // CHECK: ttng.inval_barrier %[[BAR1_VIEW]] -// CHECK: %[[BAR2_VIEW:.*]] = 
ttg.memdesc_subview %[[BARRIER]][%[[ONE]]] +// CHECK: %[[BAR2_VIEW:.*]] = ttg.memdesc_index %[[BARRIER]], %[[ONE]] // CHECK: ttng.inval_barrier %[[BAR2_VIEW]] // CHECK: ttg.local_dealloc %[[BARRIER]] // CHECK: ttg.local_dealloc %[[A]] @@ -761,9 +761,9 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-DAG: %[[NUM_BUFS:.*]] = arith.constant {{.*}} 2 : i32 // CHECK-DAG: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x32x128 // CHECK-DAG: %[[BARRIER:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 -// CHECK: %[[BAR1_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ZERO]]] +// CHECK: %[[BAR1_VIEW:.*]] = ttg.memdesc_index %[[BARRIER]], %[[ZERO]] // CHECK: ttng.init_barrier %[[BAR1_VIEW]], 1 -// CHECK: %[[BAR2_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ONE]]] +// CHECK: %[[BAR2_VIEW:.*]] = ttg.memdesc_index %[[BARRIER]], %[[ONE]] // CHECK: ttng.init_barrier %[[BAR2_VIEW]], 1 // CHECK: scf.for {{.*}} iter_args(%[[INS:.*]] = %[[MINUS_ONE]], %[[EXT:.*]] = %[[MINUS_ONE]], %[[PHASE:.*]] = %[[ZERO]]) // CHECK: %[[INS_P1:.*]] = arith.addi %[[INS]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} @@ -774,19 +774,19 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[ZERO]], %[[EXT_P1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[PHASE_XOR:.*]] = arith.xori %[[PHASE]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[PHASE_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[PHASE_XOR]], %[[PHASE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} -// CHECK: %[[BAR_INS:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[INS_NEXT]]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[BAR_INS:.*]] = ttg.memdesc_index %[[BARRIER]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: ttng.barrier_expect %[[BAR_INS]], 16384 {loop.cluster = 2 : i32, loop.stage = 0 : i32}, %[[TRUE]] -// CHECK: %[[A_INS:.*]] = 
ttg.memdesc_subview %[[A]][%[[INS_NEXT]], %[[ZERO]], %[[ZERO]]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[A_INS:.*]] = ttg.memdesc_index %[[A]], %[[INS_NEXT]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} // CHECK: ttng.async_tma_gather {{.*}}[{{.*}}] %[[A_INS]], %[[BAR_INS]], %[[TRUE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} -// CHECK: %[[BAR_EXT:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[EXT_NEXT]]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[BAR_EXT:.*]] = ttg.memdesc_index %[[BARRIER]], %[[EXT_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: ttng.wait_barrier %[[BAR_EXT]], %[[PHASE_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} -// CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]], %[[ZERO]], %[[ZERO]]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[A_EXT:.*]] = ttg.memdesc_index %[[A]], %[[EXT_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[A_LOAD:.*]] = ttg.local_load %[[A_EXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: "use"(%[[A_LOAD]]) {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: scf.yield %[[INS_NEXT]], %[[EXT_NEXT]], %[[PHASE_NEXT]] : i32, i32, i32 -// CHECK: %[[BAR1_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ZERO]]] +// CHECK: %[[BAR1_VIEW:.*]] = ttg.memdesc_index %[[BARRIER]], %[[ZERO]] // CHECK: ttng.inval_barrier %[[BAR1_VIEW]] -// CHECK: %[[BAR2_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ONE]]] +// CHECK: %[[BAR2_VIEW:.*]] = ttg.memdesc_index %[[BARRIER]], %[[ONE]] // CHECK: ttng.inval_barrier %[[BAR2_VIEW]] // CHECK-DAG: ttg.local_dealloc %[[BARRIER]] // CHECK-DAG: ttg.local_dealloc %[[A]] @@ -1011,26 +1011,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : i32 // CHECK: %[[ACC_TM:.*]], %[[ACC_TOK:.*]] = ttng.tmem_alloc : () -> (!ttg.memdesc<2x128x128xf32 - // CHECK: 
%[[ACC_TM_SLICE:.*]] = ttg.memdesc_subview %[[ACC_TM]][%[[C_0]] + // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_index %[[ACC_TM]], %[[C_0]] // CHECK: %[[INIT_TOK:.*]] = ttng.tmem_store %[[INIT_ACC]], %[[ACC_TM_SLICE]][], %[[TRUE]] // CHECK: %[[BAR:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 - // CHECK: %[[BAR_SLICE:.*]] = ttg.memdesc_subview %[[BAR]][%[[C_0]] + // CHECK: %[[BAR_SLICE:.*]] = ttg.memdesc_index %[[BAR]], %[[C_0]] // CHECK: ttng.init_barrier %[[BAR_SLICE]], 1 - // CHECK: %[[BAR_SLICE_2:.*]] = ttg.memdesc_subview %[[BAR]][%[[C_1]] + // CHECK: %[[BAR_SLICE_2:.*]] = ttg.memdesc_index %[[BAR]], %[[C_1]] // CHECK: ttng.init_barrier %[[BAR_SLICE_2]], 1 // CHECK: %[[FOR_RES:.*]]:5 = scf.for {{.*}} iter_args(%[[PHASE:.*]] = %[[C_0]], %[[BAR_IDX:.*]] = %[[C_0]], %[[BUF_IDX:.*]] = %[[C_N1]], %[[INSERT_IDX:.*]] = %[[C_N1]], %[[EXTRACT_IDX:.*]] = %[[C_N1]] // CHECK: %[[BUF_IDX_P1:.*]] = arith.addi %[[BUF_IDX]], %[[C_1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[BUF_IDX_CND:.*]] = arith.cmpi sge, %[[BUF_IDX_P1]], %[[C_2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[BUF_IDX_NEXT:.*]] = arith.select %[[BUF_IDX_CND]], %[[C_0]], %[[BUF_IDX_P1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[BUF_IDX_NEXT_CND:.*]] = arith.select %[[CND]], %[[BUF_IDX]], %[[BUF_IDX_NEXT]] - // CHECK: %[[TM_SLICE:.*]] = ttg.memdesc_subview %[[ACC_TM]][%[[BUF_IDX_NEXT_CND]], {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[TM_SLICE:.*]] = ttg.memdesc_index %[[ACC_TM]], %[[BUF_IDX_NEXT_CND]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[STORE_TOK:.*]] = ttng.tmem_store %[[OVERRIDE_ACC]], %[[TM_SLICE]][], {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} - // CHECK: %[[BAR_SLICE:.*]] = ttg.memdesc_subview %[[BAR]][%[[BAR_IDX]]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} - // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_subview %[[ACC_TM]][%[[BUF_IDX_NEXT_CND]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 
2 : i32} + // CHECK: %[[BAR_SLICE:.*]] = ttg.memdesc_index %[[BAR]], %[[BAR_IDX]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_index %[[ACC_TM]], %[[BUF_IDX_NEXT_CND]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[MMA_TOK:.*]] = ttng.tc_gen5_mma %{{.*}}, %{{.*}}, %[[ACC_TM_SLICE]][], %[[TRUE]], %[[TRUE]], %[[BAR_SLICE]][%true] {is_async, loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] deps %{{.*}}, %{{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32} // CHECK: scf.if - // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_subview %[[ACC_TM]][%[[BUF_IDX_NEXT_CND]] + // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_index %[[ACC_TM]], %[[BUF_IDX_NEXT_CND]] // CHECK: %[[LOAD_ACC:.*]], %[[USER_TOK:.*]] = ttng.tmem_load %[[ACC_TM_SLICE]][] // CHECK: "use"(%[[LOAD_ACC]]) // CHECK: } @@ -1041,7 +1041,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: %[[PHASE_NEXT:.*]] = arith.select %[[BAR_IDX_CND]], %[[PHASE_NEG]], %[[PHASE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: scf.yield %[[PHASE_NEXT]], %[[BAR_IDX_NEXT]], %[[BUF_IDX_NEXT_CND]] // CHECK: } {tt.scheduled_max_stage = 3 : i32} - // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_subview %[[ACC_TM]][%[[FOR_RES]]#2, + // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_index %[[ACC_TM]], %[[FOR_RES]]#2 // CHECK: %[[LOAD_ACC:.*]], %[[RES_TOK:.*]] = ttng.tmem_load %[[ACC_TM_SLICE]][] tt.func public @simple_persistent_mmav5(%arg0: tensor<128x128x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg1: tensor<128x128x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg2: i32) -> tensor<128x128xf16, #blocked1> attributes {noinline = false} { %true = arith.constant true @@ -1093,24 +1093,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : i32 // 
CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : i32 // CHECK: %[[ACC_TM:.*]], %[[ACC_TOK:.*]] = ttng.tmem_alloc : () -> (!ttg.memdesc<2x128x128xf32 - // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_subview %[[ACC_TM]][%[[C_0]] + // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_index %[[ACC_TM]], %[[C_0]] // CHECK: %[[INIT_TOK:.*]] = ttng.tmem_store %[[INIT_ACC]], %[[ACC_TM_SLICE]][], %[[TRUE]] // CHECK: %[[BAR:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 - // CHECK: %[[BAR_SLICE:.*]] = ttg.memdesc_subview %[[BAR]][%[[C_0]] + // CHECK: %[[BAR_SLICE:.*]] = ttg.memdesc_index %[[BAR]], %[[C_0]] // CHECK: ttng.init_barrier %[[BAR_SLICE]], 1 - // CHECK: %[[BAR_SLICE_2:.*]] = ttg.memdesc_subview %[[BAR]][%[[C_1]] + // CHECK: %[[BAR_SLICE_2:.*]] = ttg.memdesc_index %[[BAR]], %[[C_1]] // CHECK: ttng.init_barrier %[[BAR_SLICE_2]], 1 // CHECK: %[[FOR_RES:.*]]:5 = scf.for {{.*}} iter_args(%[[PHASE:.*]] = %[[C_0]], %[[BAR_IDX:.*]] = %[[C_0]], %[[BUF_IDX:.*]] = %[[C_N1]], %[[INSERT_IDX:.*]] = %[[C_N1]], %[[EXTRACT_IDX:.*]] = %[[C_N1]] - // CHECK: %[[BAR_SLICE:.*]] = ttg.memdesc_subview %[[BAR]][%[[BAR_IDX]]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[BAR_SLICE:.*]] = ttg.memdesc_index %[[BAR]], %[[BAR_IDX]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[BUF_IDX_P1:.*]] = arith.addi %[[BUF_IDX]], %[[C_1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[BUF_IDX_CND:.*]] = arith.cmpi sge, %[[BUF_IDX_P1]], %[[C_2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[BUF_IDX_NEXT:.*]] = arith.select %[[BUF_IDX_CND]], %[[C_0]], %[[BUF_IDX_P1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[BUF_IDX_NEXT_CND:.*]] = arith.select %[[CND]], %[[BUF_IDX]], %[[BUF_IDX_NEXT]] - // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_subview %[[ACC_TM]][%[[BUF_IDX_NEXT_CND]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_index %[[ACC_TM]], %[[BUF_IDX_NEXT_CND]] {loop.cluster = 0 : i32, 
loop.stage = 2 : i32} // CHECK: %[[MMA_TOK:.*]] = ttng.tc_gen5_mma %{{.*}}, %{{.*}}, %[[ACC_TM_SLICE]][], %[[CND]], %[[TRUE]], %[[BAR_SLICE]][%true] {is_async, loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] deps %{{.*}}, %{{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32} // CHECK: scf.if - // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_subview %[[ACC_TM]][%[[BUF_IDX_NEXT_CND]] + // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_index %[[ACC_TM]], %[[BUF_IDX_NEXT_CND]] // CHECK: %[[LOAD_ACC:.*]], %[[USER_TOK:.*]] = ttng.tmem_load %[[ACC_TM_SLICE]][] // CHECK: "use"(%[[LOAD_ACC]]) // CHECK: } @@ -1121,7 +1121,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: %[[PHASE_NEXT:.*]] = arith.select %[[BAR_IDX_CND]], %[[PHASE_NEG]], %[[PHASE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: scf.yield %[[PHASE_NEXT]], %[[BAR_IDX_NEXT]], %[[BUF_IDX_NEXT_CND]] // CHECK: } {tt.scheduled_max_stage = 3 : i32} - // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_subview %[[ACC_TM]][%[[FOR_RES]]#2, + // CHECK: %[[ACC_TM_SLICE:.*]] = ttg.memdesc_index %[[ACC_TM]], %[[FOR_RES]]#2 // CHECK: %[[LOAD_ACC:.*]], %[[RES_TOK:.*]] = ttng.tmem_load %[[ACC_TM_SLICE]][] tt.func public @simple_persistent_mmav5_acc_flag(%arg0: tensor<128x128x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg1: tensor<128x128x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg2: i32) -> tensor<128x128xf16, #blocked1> attributes {noinline = false} { %true = arith.constant true @@ -1208,18 +1208,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: %[[TMEM_BUF:.+]], %[[ACC_TOK:.+]] = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32 // CHECK: %[[INIT_TOK:.+]] = ttng.tmem_store %[[C0_F]], %[[TMEM_BUF]][%[[ACC_TOK]]] // CHECK: %[[BAR_BUF:.+]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 - // CHECK: %[[BAR_SLICE0:.+]] = 
ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK: %[[BAR_SLICE0:.+]] = ttg.memdesc_index %[[BAR_BUF]], %[[C0]] // CHECK: ttng.init_barrier %[[BAR_SLICE0]], 1 - // CHECK: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK: %[[BAR_SLICE1:.+]] = ttg.memdesc_index %[[BAR_BUF]], %[[C1]] // CHECK: ttng.init_barrier %[[BAR_SLICE1]], 1 // CHECK: %[[LHS_BUFS:.+]] = ttg.local_alloc // CHECK: %[[RHS_BUFS:.+]] = ttg.local_alloc // CHECK: %[[FOR_RES:.+]]:5 = scf.for {{.*}} iter_args(%[[TOK:[^,]+]] = %[[INIT_TOK]], %[[PHASE:[^,]+]] = %[[C0]], %[[BAR_IDX:[^,]+]] = %[[C0]], // CHECK: %[[IDX0:.+]] = arith.select // CHECK: %[[IDX1:.+]] = arith.select - // CHECK: %[[LHS_DEP:.+]] = ttg.memdesc_subview %[[LHS_BUFS]][%[[IDX1]], - // CHECK: %[[RHS_DEP:.+]] = ttg.memdesc_subview %[[RHS_BUFS]][%[[IDX1]], - // CHECK: %[[BAR_SLICE:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[BAR_IDX]]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[LHS_DEP:.+]] = ttg.memdesc_index %[[LHS_BUFS]], %[[IDX1]] + // CHECK: %[[RHS_DEP:.+]] = ttg.memdesc_index %[[RHS_BUFS]], %[[IDX1]] + // CHECK: %[[BAR_SLICE:.+]] = ttg.memdesc_index %[[BAR_BUF]], %[[BAR_IDX]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[MMA_TOK:.+]] = ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM_BUF]][%[[TOK]]], %[[TRUE]], %[[TRUE]], %[[BAR_SLICE]][%true] {is_async, loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] deps %[[LHS_DEP]], %[[RHS_DEP]] {loop.cluster = 0 : i32, loop.stage = 3 : i32} // CHECK: %[[CND_TOK:.+]] = scf.if @@ -1236,9 +1236,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: %[[BAR_IDX_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[C0]], %[[BAR_IDX_P1]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: %[[PHASE_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[PHASE_XOR]], %[[PHASE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} // CHECK: yield %[[CND_TOK]] - // CHECK: %[[BAR_SLICE0:.+]] 
= ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK: %[[BAR_SLICE0:.+]] = ttg.memdesc_index %[[BAR_BUF]], %[[C0]] // CHECK: ttng.inval_barrier %[[BAR_SLICE0]] - // CHECK: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK: %[[BAR_SLICE1:.+]] = ttg.memdesc_index %[[BAR_BUF]], %[[C1]] // CHECK: ttng.inval_barrier %[[BAR_SLICE1]] // CHECK: ttg.local_dealloc %[[BAR_BUF]] // CHECK: ttng.tmem_load %[[TMEM_BUF]][%[[FOR_RES]]#0] @@ -1462,7 +1462,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: %[[ACC1:.*]], %[[LOAD_TOK:.+]] = ttng.tmem_load %[[TMEM_BUF]][%{{.*}}] {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: %[[MUL:.*]] = arith.mulf %[[ACC1]], {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: %[[STORE_TOK:.+]] = ttng.tmem_store %[[MUL]], %[[TMEM_BUF]][%[[LOAD_TOK]]], {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32} - // CHECK: %[[BAR_SLICE:.*]] = ttg.memdesc_subview %[[BAR_BUF]] + // CHECK: %[[BAR_SLICE:.*]] = ttg.memdesc_index %[[BAR_BUF]] // CHECK: %[[MMA_TOK:.+]] = ttng.tc_gen5_mma %[[A_SLICE:.*]], %[[B_SLICE:.*]], %[[TMEM_BUF]][%[[STORE_TOK]]], {{.*}}, {{.*}}, %[[BAR_SLICE]][%true] {is_async, loop.cluster = 2 : i32, loop.stage = 2 : i32} // CHECK: ttng.wait_barrier %[[BAR_SLICE]], {{.*}} {loop.cluster = 1 : i32, loop.stage = 3 : i32} // CHECK: scf.yield %[[MMA_TOK]] diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index a0a87a5763ae..eed34ace9f51 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-prefetch -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-prefetch -canonicalize | FileCheck %s --dump-input-context=50 // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 @@ -12,24 +12,22 @@ #smem = #ttg.shared_memory // CHECK: tt.func @matmul_loop_mixed -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 -// CHECK-DAG: %[[C16:.+]] 
= arith.constant 16 : i32 -// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice %[[A0:.*]][0, 0] // CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] // CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] -// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice %[[B0:.*]][0, 0] // CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] -// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] +// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subslice %[[arg_a0]][0, 16] // CHECK-DAG: %[[A_REM:.*]] = ttg.local_load %[[A_REM_SMEM]] // CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]] -// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] +// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subslice %[[arg_b0]][16, 0] // CHECK-DAG: %[[B_REM:.*]] = ttg.local_load %[[B_REM_SMEM]] // CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} -// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice {{.*}}[0, 0] // CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] // CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] -// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice {{.*}}[0, 0] // CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] // CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] // CHECK: 
scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] @@ -73,23 +71,22 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr // 4 warps // matmul: 128x16 @ 16x128 -> 128x128 -// CHECK: tt.func @matmul_loop_mixed -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 -// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK: tt.func @matmul_loop_mixed_4warps +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice %[[A0:.*]][0, 0] // CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] // CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] -// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice %[[B0:.*]][0, 0] // CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] -// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice {{.*}}[0, 0] // CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] // CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] -// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice {{.*}}[0, 0] // CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] // CHECK: tt.dot %[[a0_prefetch]], %[[b0_prefetch]], {{.*}} // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] module attributes { "ttg.num-warps" = 4 : i32 } { -tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ +tt.func 
@matmul_loop_mixed_4warps(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x16x!tt.ptr, #AL> %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<16x128x!tt.ptr, #BL> @@ -136,17 +133,16 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr // matmul: 8x128x16 @ 8x16x128 -> 8x128x128 // CHECK: tt.func @matmul_3D_loop_mixed -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 -// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice %[[A0:.*]][0, 0, 0] // CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] // CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] -// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice %[[B0:.*]][0, 0, 0] // CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] -// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice {{.*}}[0, 0, 0] // CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] // CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] -// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice {{.*}}[0, 0, 0] // CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] // CHECK: tt.dot %[[a0_prefetch]], %[[b0_prefetch]], {{.*}} // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] @@ -190,24 +186,22 @@ tt.func 
@matmul_3D_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt. // matmul: 8x128x32 @ 8x32x128 -> 8x128x128 // CHECK: tt.func @matmul_3D_loop_mixed2 -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32 -// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice %[[A0:.*]][0, 0, 0] // CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] // CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] -// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice %[[B0:.*]][0, 0, 0] // CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] -// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_a0]][%[[C0]], %[[C0]], %[[C16]]] +// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subslice %[[arg_a0]][0, 0, 16] // CHECK-DAG: %[[A_REM:.*]] = ttg.local_load %[[A_REM_SMEM]] // CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]] -// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_b0]][%[[C0]], %[[C16]], %[[C0]]] +// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subslice %[[arg_b0]][0, 16, 0] // CHECK-DAG: %[[B_REM:.*]] = ttg.local_load %[[B_REM_SMEM]] // CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} -// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice {{.*}}[0, 0, 0] // CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] // CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] -// CHECK-DAG: 
%[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice {{.*}}[0, 0, 0] // CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] // CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] @@ -308,24 +302,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory // CHECK: tt.func @matmul_loop_mixed_amd -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32 -// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice %[[A0:.*]][0, 0] // CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] // CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] -// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice %[[B0:.*]][0, 0] // CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] -// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] +// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subslice %[[arg_a0]][0, 16] // CHECK-DAG: %[[A_REM:.*]] = ttg.local_load %[[A_REM_SMEM]] // CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]] -// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] +// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subslice %[[arg_b0]][16, 0] // CHECK-DAG: %[[B_REM:.*]] = ttg.local_load %[[B_REM_SMEM]] // CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], 
%[[b0_prefetch:.*]], {{.*}} -// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice {{.*}}[0, 0] // CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] // CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] -// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subslice {{.*}}[0, 0] // CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] // CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] diff --git a/test/TritonGPU/promote-lhs-to-tmem.mlir b/test/TritonGPU/promote-lhs-to-tmem.mlir index 448e950b4142..4dce112ede18 100644 --- a/test/TritonGPU/promote-lhs-to-tmem.mlir +++ b/test/TritonGPU/promote-lhs-to-tmem.mlir @@ -22,7 +22,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked1>) : i32 { %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> - %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_sh = ttg.memdesc_index %B_multibuf, %c0_i32 : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : 
!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> @@ -47,7 +47,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked1>) : i32 { %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> - %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_sh = ttg.memdesc_index %B_multibuf, %c0_i32 : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> %a_scale_tm = ttng.tmem_alloc %a_scale : (tensor<128x1xi8, #blocked2>) -> !ttg.memdesc<128x1xi8, #tmem_scales, #ttng.tensor_memory> %b_scale_tm = ttng.tmem_alloc %b_scale : (tensor<64x1xi8, #blocked2>) -> !ttg.memdesc<64x1xi8, #tmem_scales, #ttng.tensor_memory> @@ -74,7 +74,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked1>) : i32 { %B = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> - %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : 
!ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %A_sh = ttg.memdesc_index %A_multibuf, %c0_i32 : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> @@ -99,7 +99,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked1>) : i32 { - %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_sh = ttg.memdesc_index %B_multibuf, %c0_i32 : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> 
%acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> @@ -124,7 +124,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked1>) : i32 { %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked2> %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> - %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_sh = ttg.memdesc_index %B_multibuf, %c0_i32 : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> diff --git a/test/TritonGPU/rewrite-partition-dependencies.mlir b/test/TritonGPU/rewrite-partition-dependencies.mlir index 7be4bd9f1a58..a33699de7eb3 100644 --- a/test/TritonGPU/rewrite-partition-dependencies.mlir +++ b/test/TritonGPU/rewrite-partition-dependencies.mlir @@ -11,14 +11,14 @@ tt.func @two_consumers(%lb: i32, %ub: i32, %step: i32) { // CHECK-NEXT: [[READY_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64, // CHECK-NEXT: [[EMPTY_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64, - // CHECK-NEXT: [[READY0:%.*]] = ttg.memdesc_subview 
[[READY_BARS]][%c0_i32] - // CHECK-NEXT: [[EMPTY0:%.*]] = ttg.memdesc_subview [[EMPTY_BARS]][%c0_i32] + // CHECK-NEXT: [[READY0:%.*]] = ttg.memdesc_index [[READY_BARS]], %c0_i32 + // CHECK-NEXT: [[EMPTY0:%.*]] = ttg.memdesc_index [[EMPTY_BARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[READY0]], 1 // CHECK-NEXT: ttng.init_barrier [[EMPTY0]], 2 // CHECK-NEXT: ttng.arrive_barrier [[EMPTY0]], 2 - // CHECK-NEXT: [[READY1:%.*]] = ttg.memdesc_subview [[READY_BARS]][%c1_i32] - // CHECK-NEXT: [[EMPTY1:%.*]] = ttg.memdesc_subview [[EMPTY_BARS]][%c1_i32] + // CHECK-NEXT: [[READY1:%.*]] = ttg.memdesc_index [[READY_BARS]], %c1_i32 + // CHECK-NEXT: [[EMPTY1:%.*]] = ttg.memdesc_index [[EMPTY_BARS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[READY1]], 1 // CHECK-NEXT: ttng.init_barrier [[EMPTY1]], 2 // CHECK-NEXT: ttng.arrive_barrier [[EMPTY1]], 2 @@ -38,9 +38,9 @@ tt.func @two_consumers(%lb: i32, %ub: i32, %step: i32) { // CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, [[NEXT_IDX]], %c2_i32 // CHECK-NEXT: [[PHASE0:%.*]] = arith.select [[ROLLOVER]], [[NEXT_PHASE]], [[PRODUCER_PHASE]] // CHECK-NEXT: [[IDX0:%.*]] = arith.select [[ROLLOVER]], %c0_i32, [[NEXT_IDX]] - // CHECK-NEXT: [[VIEW:%.*]] = ttg.memdesc_subview [[BUFFERS]][[[IDX0]], %c0_i32] - // CHECK-NEXT: [[READY:%.*]] = ttg.memdesc_subview [[READY_BARS]][[[IDX0]]] - // CHECK-NEXT: [[EMPTY:%.*]] = ttg.memdesc_subview [[EMPTY_BARS]][[[IDX0]]] + // CHECK-NEXT: [[VIEW:%.*]] = ttg.memdesc_index [[BUFFERS]], [[IDX0]] + // CHECK-NEXT: [[READY:%.*]] = ttg.memdesc_index [[READY_BARS]], [[IDX0]] + // CHECK-NEXT: [[EMPTY:%.*]] = ttg.memdesc_index [[EMPTY_BARS]], [[IDX0]] // CHECK-NEXT: ttng.wait_barrier [[EMPTY]], [[PHASE0]] {ttg.partition = 0 : i32} // CHECK-NEXT: ttg.local_store [[OUTPUT]], [[VIEW]] {ttg.partition = 0 : i32} // CHECK-NEXT: ttng.arrive_barrier [[READY]], 1 {ttg.partition = 0 : i32} @@ -50,9 +50,9 @@ tt.func @two_consumers(%lb: i32, %ub: i32, %step: i32) { // CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, 
[[NEXT_IDX]], %c2_i32 // CHECK-NEXT: [[PHASE1:%.*]] = arith.select [[ROLLOVER]], [[NEXT_PHASE]], [[CONSUMER_PHASE0]] // CHECK-NEXT: [[IDX1:%.*]] = arith.select [[ROLLOVER]], %c0_i32, [[NEXT_IDX]] - // CHECK-NEXT: [[VIEW:%.*]] = ttg.memdesc_subview [[BUFFERS]][[[IDX1]], %c0_i32] - // CHECK-NEXT: [[READY:%.*]] = ttg.memdesc_subview [[READY_BARS]][[[IDX1]]] - // CHECK-NEXT: [[EMPTY:%.*]] = ttg.memdesc_subview [[EMPTY_BARS]][[[IDX1]]] + // CHECK-NEXT: [[VIEW:%.*]] = ttg.memdesc_index [[BUFFERS]], [[IDX1]] + // CHECK-NEXT: [[READY:%.*]] = ttg.memdesc_index [[READY_BARS]], [[IDX1]] + // CHECK-NEXT: [[EMPTY:%.*]] = ttg.memdesc_index [[EMPTY_BARS]], [[IDX1]] // CHECK-NEXT: ttng.wait_barrier [[READY]], [[PHASE1]] {ttg.partition = 1 : i32} // CHECK-NEXT: [[VALUE:%.*]] = ttg.local_load [[VIEW]] {ttg.partition = 1 : i32} // CHECK-NEXT: ttng.arrive_barrier [[EMPTY]], 1 {ttg.partition = 1 : i32} @@ -64,9 +64,9 @@ tt.func @two_consumers(%lb: i32, %ub: i32, %step: i32) { // CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, [[NEXT_IDX]], %c2_i32 // CHECK-NEXT: [[PHASE2:%.*]] = arith.select [[ROLLOVER]], [[NEXT_PHASE]], [[CONSUMER_PHASE1]] // CHECK-NEXT: [[IDX2:%.*]] = arith.select [[ROLLOVER]], %c0_i32, [[NEXT_IDX]] - // CHECK-NEXT: [[VIEW:%.*]] = ttg.memdesc_subview [[BUFFERS]][[[IDX2]], %c0_i32] - // CHECK-NEXT: [[READY:%.*]] = ttg.memdesc_subview [[READY_BARS]][[[IDX2]]] - // CHECK-NEXT: [[EMPTY:%.*]] = ttg.memdesc_subview [[EMPTY_BARS]][[[IDX2]]] + // CHECK-NEXT: [[VIEW:%.*]] = ttg.memdesc_index [[BUFFERS]], [[IDX2]] + // CHECK-NEXT: [[READY:%.*]] = ttg.memdesc_index [[READY_BARS]], [[IDX2]] + // CHECK-NEXT: [[EMPTY:%.*]] = ttg.memdesc_index [[EMPTY_BARS]], [[IDX2]] // CHECK-NEXT: ttng.wait_barrier [[READY]], [[PHASE2]] {ttg.partition = 2 : i32} // CHECK-NEXT: [[VALUE:%.*]] = ttg.local_load [[VIEW]] {ttg.partition = 2 : i32} // CHECK-NEXT: ttng.arrive_barrier [[EMPTY]], 1 {ttg.partition = 2 : i32} @@ -95,16 +95,16 @@ tt.func @distance_one(%lb: i32, %ub: i32, %step: i32) { // 
CHECK-NEXT: [[READY_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64, // CHECK-NEXT: [[EMPTY_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64, - // CHECK-NEXT: [[INIT:%.*]] = ttg.memdesc_subview [[BUFFERS]][%c0_i32, %c0_i32] + // CHECK-NEXT: [[INIT:%.*]] = ttg.memdesc_index [[BUFFERS]], %c0_i32 // CHECK-NEXT: ttg.local_store %cst, [[INIT]] - // CHECK-NEXT: [[READY0:%.*]] = ttg.memdesc_subview [[READY_BARS]][%c0_i32] - // CHECK-NEXT: [[EMPTY0:%.*]] = ttg.memdesc_subview [[EMPTY_BARS]][%c0_i32] + // CHECK-NEXT: [[READY0:%.*]] = ttg.memdesc_index [[READY_BARS]], %c0_i32 + // CHECK-NEXT: [[EMPTY0:%.*]] = ttg.memdesc_index [[EMPTY_BARS]], %c0_i32 // CHECK-NEXT: ttng.init_barrier [[READY0]], 1 // CHECK-NEXT: ttng.init_barrier [[EMPTY0]], 1 // CHECK-NEXT: ttng.arrive_barrier [[READY0]], 1 - // CHECK-NEXT: [[READY1:%.*]] = ttg.memdesc_subview [[READY_BARS]][%c1_i32] - // CHECK-NEXT: [[EMPTY1:%.*]] = ttg.memdesc_subview [[EMPTY_BARS]][%c1_i32] + // CHECK-NEXT: [[READY1:%.*]] = ttg.memdesc_index [[READY_BARS]], %c1_i32 + // CHECK-NEXT: [[EMPTY1:%.*]] = ttg.memdesc_index [[EMPTY_BARS]], %c1_i32 // CHECK-NEXT: ttng.init_barrier [[READY1]], 1 // CHECK-NEXT: ttng.init_barrier [[EMPTY1]], 1 // CHECK-NEXT: ttng.arrive_barrier [[EMPTY1]], 1 @@ -122,9 +122,9 @@ tt.func @distance_one(%lb: i32, %ub: i32, %step: i32) { // CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, [[NEXT_IDX]], %c2_i32 // CHECK-NEXT: [[PHASE0:%.*]] = arith.select [[ROLLOVER]], [[NEXT_PHASE]], [[PRODUCER_PHASE]] // CHECK-NEXT: [[IDX0:%.*]] = arith.select [[ROLLOVER]], %c1_i32, [[NEXT_IDX]] - // CHECK-NEXT: [[VIEW:%.*]] = ttg.memdesc_subview [[BUFFERS]][[[IDX0]], %c0_i32] - // CHECK-NEXT: [[READY:%.*]] = ttg.memdesc_subview [[READY_BARS]][[[IDX0]]] - // CHECK-NEXT: [[EMPTY:%.*]] = ttg.memdesc_subview [[EMPTY_BARS]][[[IDX0]]] + // CHECK-NEXT: [[VIEW:%.*]] = ttg.memdesc_index [[BUFFERS]], [[IDX0]] + // CHECK-NEXT: [[READY:%.*]] = ttg.memdesc_index [[READY_BARS]], [[IDX0]] + // CHECK-NEXT: [[EMPTY:%.*]] = 
ttg.memdesc_index [[EMPTY_BARS]], [[IDX0]] // CHECK-NEXT: ttng.wait_barrier [[EMPTY]], [[PHASE0]] {ttg.partition = 0 : i32} // CHECK-NEXT: ttg.local_store [[OUTPUT]], [[VIEW]] {ttg.partition = 0 : i32} // CHECK-NEXT: ttng.arrive_barrier [[READY]], 1 {ttg.partition = 0 : i32} @@ -134,9 +134,9 @@ tt.func @distance_one(%lb: i32, %ub: i32, %step: i32) { // CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, [[NEXT_IDX]], %c2_i32 // CHECK-NEXT: [[PHASE1:%.*]] = arith.select [[ROLLOVER]], [[NEXT_PHASE]], [[CONSUMER_PHASE]] // CHECK-NEXT: [[IDX1:%.*]] = arith.select [[ROLLOVER]], %c1_i32, [[NEXT_IDX]] - // CHECK-NEXT: [[VIEW:%.*]] = ttg.memdesc_subview [[BUFFERS]][[[IDX1]], %c0_i32] - // CHECK-NEXT: [[READY:%.*]] = ttg.memdesc_subview [[READY_BARS]][[[IDX1]]] - // CHECK-NEXT: [[EMPTY:%.*]] = ttg.memdesc_subview [[EMPTY_BARS]][[[IDX1]]] + // CHECK-NEXT: [[VIEW:%.*]] = ttg.memdesc_index [[BUFFERS]], [[IDX1]] + // CHECK-NEXT: [[READY:%.*]] = ttg.memdesc_index [[READY_BARS]], [[IDX1]] + // CHECK-NEXT: [[EMPTY:%.*]] = ttg.memdesc_index [[EMPTY_BARS]], [[IDX1]] // CHECK-NEXT: ttng.wait_barrier [[READY]], [[PHASE1]] {ttg.partition = 1 : i32} // CHECK-NEXT: [[VALUE:%.*]] = ttg.local_load [[VIEW]] {ttg.partition = 1 : i32} // CHECK-NEXT: ttng.arrive_barrier [[EMPTY]], 1 {ttg.partition = 1 : i32} @@ -200,9 +200,9 @@ tt.func @reuse_argument(%lb: i32, %ub: i32, %step: i32) { %cst1 = arith.constant dense<1> : !ty // CHECK: [[BUFFERS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x1xi32 - // CHECK: [[VALUE:%.*]] = ttg.memdesc_subview [[BUFFERS]][%c0_i32, %c0_i32] + // CHECK: [[VALUE:%.*]] = ttg.memdesc_index [[BUFFERS]], %c0_i32 // CHECK: local_store [[CST0]], [[VALUE]] - // CHECK: [[VALUE:%.*]] = ttg.memdesc_subview [[BUFFERS]][%c1_i32, %c0_i32] + // CHECK: [[VALUE:%.*]] = ttg.memdesc_index [[BUFFERS]], %c1_i32 // CHECK: local_store [[CST1]], [[VALUE]] scf.for %i = %lb to %ub step %step iter_args(%k = %cst0, %l = %cst1) -> (!ty, !ty) : i32 { %0 = "op_a"() {ttg.partition = 0} : () -> !ty @@ 
-224,14 +224,14 @@ tt.func @multiplicity_branch(%lb: i32, %ub: i32, %step: i32) { // CHECK: [[BUFFERS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<6x1xi32 - // CHECK: [[VALUE:%.*]] = ttg.memdesc_subview [[BUFFERS]][%c0_i32, %c0_i32] + // CHECK: [[VALUE:%.*]] = ttg.memdesc_index [[BUFFERS]], %c0_i32 // CHECK: local_store [[CST0]], [[VALUE]] - // CHECK: [[VALUE:%.*]] = ttg.memdesc_subview [[BUFFERS]][%c1_i32, %c0_i32] + // CHECK: [[VALUE:%.*]] = ttg.memdesc_index [[BUFFERS]], %c1_i32 // CHECK: local_store [[CST2]], [[VALUE]] - // CHECK: [[VALUE:%.*]] = ttg.memdesc_subview [[BUFFERS]][%c2_i32, %c0_i32] + // CHECK: [[VALUE:%.*]] = ttg.memdesc_index [[BUFFERS]], %c2_i32 // CHECK: local_store [[CST0]], [[VALUE]] - // CHECK: [[VALUE:%.*]] = ttg.memdesc_subview [[BUFFERS]][%c3_i32, %c0_i32] + // CHECK: [[VALUE:%.*]] = ttg.memdesc_index [[BUFFERS]], %c3_i32 // CHECK: local_store [[CST1]], [[VALUE]] // CHECK: iter_args @@ -245,21 +245,21 @@ tt.func @multiplicity_branch(%lb: i32, %ub: i32, %step: i32) { // CHECK: [[NEXT_IDX:%.*]] = arith.addi [[PIDX0]], %c1_i32 // CHECK: [[ROLLOVER:%.*]] = arith.cmpi eq, [[NEXT_IDX]], %c6_i32 // CHECK: [[IDX:%.*]] = arith.select [[ROLLOVER]], %c4_i32, [[NEXT_IDX]] - // CHECK: memdesc_subview [[BUFFERS]][[[IDX]], %c0_i32] + // CHECK: memdesc_index [[BUFFERS]], [[IDX]] // CHECK: [[NEXT_IDX:%.*]] = arith.addi [[CIDX0]], %c1_i32 // CHECK: [[LAST:%.*]] = arith.cmpi eq, [[NEXT_IDX]], %c6_i32 // CHECK: [[AT_END:%.*]] = arith.cmpi eq, [[NEXT_IDX]], %c2_i32 // CHECK: [[ROLLOVER:%.*]] = arith.ori [[LAST]], [[AT_END]] // CHECK: [[IDX:%.*]] = arith.select [[ROLLOVER]], %c4_i32, [[NEXT_IDX]] - // CHECK: memdesc_subview [[BUFFERS]][[[IDX]], %c0_i32] + // CHECK: memdesc_index [[BUFFERS]], [[IDX]] // CHECK: op_b "op_b"(%a) {ttg.partition = 1}: (!ty) -> () // CHECK: [[NEXT_IDX:%.*]] = arith.addi [[CIDX1]], %c1_i32 // CHECK: [[ROLLOVER:%.*]] = arith.cmpi eq, [[NEXT_IDX]], %c6_i32 // CHECK: [[IDX:%.*]] = arith.select [[ROLLOVER]], %c4_i32, [[NEXT_IDX]] - // 
CHECK: memdesc_subview [[BUFFERS]][[[IDX]], %c0_i32] + // CHECK: memdesc_index [[BUFFERS]], [[IDX]] // CHECK: op_c "op_c"(%b) {ttg.partition = 2}: (!ty) -> () @@ -268,7 +268,7 @@ tt.func @multiplicity_branch(%lb: i32, %ub: i32, %step: i32) { // CHECK: [[AT_END:%.*]] = arith.cmpi eq, [[NEXT_IDX]], %c2_i32 // CHECK: [[ROLLOVER:%.*]] = arith.ori [[LAST]], [[AT_END]] // CHECK: [[IDX:%.*]] = arith.select [[ROLLOVER]], %c4_i32, [[NEXT_IDX]] - // CHECK: memdesc_subview [[BUFFERS]][[[IDX]], %c0_i32] + // CHECK: memdesc_index [[BUFFERS]], [[IDX]] // CHECK: op_d "op_d"(%c) {ttg.partition = 3}: (!ty) -> () diff --git a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir index 3b6edcfb3ae8..02c685a1d031 100644 --- a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir +++ b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir @@ -58,25 +58,25 @@ // CHECK: %[[VAL_43:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_44:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_45:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -// CHECK: %[[VAL_46:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_46:.*]] = ttg.memdesc_index %[[VAL_45]], %[[VAL_12]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: ttng.init_barrier %[[VAL_46]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -// CHECK: %[[VAL_47:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_47:.*]] = ttg.memdesc_index 
%[[VAL_45]], %[[VAL_15]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: ttng.init_barrier %[[VAL_47]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -// CHECK: %[[VAL_48:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_7]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_48:.*]] = ttg.memdesc_index %[[VAL_45]], %[[VAL_7]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: ttng.init_barrier %[[VAL_48]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: %[[VAL_49:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_12]] : i32 -// CHECK: %[[VAL_50:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_50:.*]] = ttg.memdesc_index %[[VAL_45]], %[[VAL_12]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: ttng.barrier_expect %[[VAL_50]], 49152, %[[VAL_49]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -// CHECK: %[[VAL_51:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_12]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_51:.*]] = ttg.memdesc_index %[[VAL_43]], %[[VAL_12]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> // CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_12]]] %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 
3> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> -// CHECK: %[[VAL_52:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_12]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_52:.*]] = ttg.memdesc_index %[[VAL_44]], %[[VAL_12]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> // CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_12]]] %[[VAL_52]], %[[VAL_50]], %[[VAL_49]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> // CHECK: %[[VAL_53:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_15]] : i32 -// CHECK: %[[VAL_54:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_54:.*]] = ttg.memdesc_index %[[VAL_45]], %[[VAL_15]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: ttng.barrier_expect %[[VAL_54]], 49152, %[[VAL_53]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -// CHECK: %[[VAL_55:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_15]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_55:.*]] = ttg.memdesc_index %[[VAL_43]], %[[VAL_15]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> // CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_13]]] %[[VAL_55]], %[[VAL_54]], %[[VAL_53]] : 
!tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> -// CHECK: %[[VAL_56:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_15]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_56:.*]] = ttg.memdesc_index %[[VAL_44]], %[[VAL_15]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> // CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_13]]] %[[VAL_56]], %[[VAL_54]], %[[VAL_53]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> // CHECK: %[[VAL_57:.*]]:5 = scf.for %[[VAL_58:.*]] = %[[VAL_12]] to %[[VAL_42]] step %[[VAL_15]] iter_args(%[[VAL_59:.*]] = %[[VAL_19]], %[[VAL_60:.*]] = %[[VAL_13]], %[[VAL_61:.*]] = %[[VAL_15]], %[[VAL_62:.*]] = %[[VAL_8]], %[[VAL_63:.*]] = %[[VAL_12]]) -> (tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32) : i32 { // CHECK: %[[VAL_64:.*]] = arith.subi %[[VAL_42]], %[[VAL_7]] : i32 @@ -86,10 +86,10 @@ // CHECK: %[[VAL_68:.*]] = arith.select %[[VAL_67]], %[[VAL_12]], %[[VAL_66]] : i32 // CHECK: %[[VAL_69:.*]] = arith.xori %[[VAL_63]], %[[VAL_15]] : i32 // CHECK: %[[VAL_70:.*]] = arith.select %[[VAL_67]], %[[VAL_69]], %[[VAL_63]] : i32 -// CHECK: %[[VAL_71:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_68]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_71:.*]] = ttg.memdesc_index %[[VAL_45]], %[[VAL_68]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: ttng.wait_barrier %[[VAL_71]], %[[VAL_70]] : 
!ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -// CHECK: %[[VAL_72:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_68]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> -// CHECK: %[[VAL_73:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_68]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_72:.*]] = ttg.memdesc_index %[[VAL_44]], %[[VAL_68]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_73:.*]] = ttg.memdesc_index %[[VAL_43]], %[[VAL_68]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> // CHECK: %[[VAL_74:.*]] = ttg.memdesc_trans %[[VAL_72]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> -> !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable, 3x64x256> // CHECK: %[[VAL_75:.*]] = ttng.warp_group_dot %[[VAL_73]], %[[VAL_74]], %[[VAL_59]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable, 3x64x256> -> tensor<128x256xf32, #[[$ATTR_1]]> // CHECK: %[[VAL_76:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_75]], %[[VAL_73]], %[[VAL_74]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_1]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable, 3x64x256> @@ -97,20 +97,20 @@ // CHECK: %[[VAL_78:.*]] = arith.addi %[[VAL_61]], %[[VAL_15]] : i32 // CHECK: %[[VAL_79:.*]] = arith.cmpi sge, %[[VAL_78]], %[[VAL_6]] : i32 // CHECK: %[[VAL_80:.*]] = 
arith.select %[[VAL_79]], %[[VAL_12]], %[[VAL_78]] : i32 -// CHECK: %[[VAL_81:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_80]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_81:.*]] = ttg.memdesc_index %[[VAL_45]], %[[VAL_80]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: ttng.barrier_expect %[[VAL_81]], 49152, %[[VAL_65]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -// CHECK: %[[VAL_82:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_80]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_82:.*]] = ttg.memdesc_index %[[VAL_43]], %[[VAL_80]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> // CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_77]]] %[[VAL_82]], %[[VAL_81]], %[[VAL_65]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> -// CHECK: %[[VAL_83:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_80]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_83:.*]] = ttg.memdesc_index %[[VAL_44]], %[[VAL_80]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> // CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_77]]] %[[VAL_83]], %[[VAL_81]], %[[VAL_65]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> 
!ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> // CHECK: scf.yield %[[VAL_76]]#0, %[[VAL_77]], %[[VAL_80]], %[[VAL_68]], %[[VAL_70]] : tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32 // CHECK: } // CHECK: %[[VAL_84:.*]] = ttng.warp_group_dot_wait %[[VAL_85:.*]]#0 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_1]]> -// CHECK: %[[VAL_86:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_86:.*]] = ttg.memdesc_index %[[VAL_45]], %[[VAL_12]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: ttng.inval_barrier %[[VAL_86]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -// CHECK: %[[VAL_87:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_87:.*]] = ttg.memdesc_index %[[VAL_45]], %[[VAL_15]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: ttng.inval_barrier %[[VAL_87]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -// CHECK: %[[VAL_88:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_7]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_88:.*]] = ttg.memdesc_index %[[VAL_45]], %[[VAL_7]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: ttng.inval_barrier %[[VAL_88]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: ttg.local_dealloc %[[VAL_45]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> // CHECK: ttg.local_dealloc %[[VAL_44]] : 
!ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> diff --git a/test/TritonNvidiaGPU/interleave_tmem.mlir b/test/TritonNvidiaGPU/interleave_tmem.mlir index d8461bc5cc70..339c4b5577c5 100644 --- a/test/TritonNvidiaGPU/interleave_tmem.mlir +++ b/test/TritonNvidiaGPU/interleave_tmem.mlir @@ -68,8 +68,8 @@ tt.func @interleave_load_store_ws() { // CHECK: scf.for scf.for %i = %c0 to %c32 step %c1 : i32 { - // CHECK: memdesc_subview - %cur_acc = ttg.memdesc_subview %arg0[%i, %c0, %c0] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + // CHECK: memdesc_index + %cur_acc = ttg.memdesc_index %arg0, %i : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK-NEXT: [[S0:%.+]] = ttng.tmem_subslice %{{.+}} {N = 0 : i32} // CHECK-NEXT: [[S1:%.+]] = ttng.tmem_subslice %{{.+}} {N = 64 : i32} @@ -121,15 +121,15 @@ tt.func @sink_alloc_op(%arg0: tensor<128x128xf32, #blocked>) { %true = arith.constant true %alloc0 = ttng.tmem_alloc : () -> !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> - %subview0 = ttg.memdesc_subview %alloc0[%c0, %c0, %c0] : !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %subview0 = ttg.memdesc_index %alloc0, %c0 : !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK: [[ALLOC1:%.+]] = ttng.tmem_alloc %alloc1 = ttng.tmem_alloc : () -> !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> - // CHECK-NEXT: [[SUBVIEW1:%.+]] = ttg.memdesc_subview [[ALLOC1]] - %subview1 = ttg.memdesc_subview %alloc1[%c0, %c0, %c0] : !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + // CHECK-NEXT: [[SUBVIEW1:%.+]] = 
ttg.memdesc_index [[ALLOC1]] + %subview1 = ttg.memdesc_index %alloc1, %c0 : !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK-NEXT: tmem_store %arg0, [[SUBVIEW1]] ttng.tmem_store %arg0, %subview1, %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK-NEXT: [[ALLOC0:%.+]] = ttng.tmem_alloc - // CHECK-NEXT: [[SUBVIEW0:%.+]] = ttg.memdesc_subview [[ALLOC0]] + // CHECK-NEXT: [[SUBVIEW0:%.+]] = ttg.memdesc_index [[ALLOC0]] // CHECK-NEXT: tmem_store %arg0, [[SUBVIEW0]] ttng.tmem_store %arg0, %subview0, %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> tt.return diff --git a/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir b/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir index 2f8575783061..fef27cc09d69 100644 --- a/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir +++ b/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir @@ -98,7 +98,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32} %6 = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> - %s = ttg.memdesc_subview %6[%c1, %c0, %c0] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %s = ttg.memdesc_index %6, %c1 : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} %7 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> @@ -313,10 +313,10 @@ tt.func 
@alloc_warp_specialize_explicit_capture_subview() { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 - %b = ttg.memdesc_subview %arg0[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xbf16, #shared1, #smem> - %a = ttg.memdesc_subview %arg1[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xbf16, #tmem1, #ttng.tensor_memory, mutable> -> !ttg.memdesc<64x128xbf16, #tmem1, #ttng.tensor_memory, mutable, 1x64x128> - %d = ttg.memdesc_subview %arg2[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<64x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x64x128> - %barrier = ttg.memdesc_subview %arg3[%c0_i32] : !ttg.memdesc<2xi64, #shared, #smem, mutable> -> !ttg.memdesc<1xi64, #shared, #smem, mutable> + %b = ttg.memdesc_index %arg0, %c0_i32 : !ttg.memdesc<2x128x128xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xbf16, #shared1, #smem> + %a = ttg.memdesc_index %arg1, %c0_i32 : !ttg.memdesc<1x64x128xbf16, #tmem1, #ttng.tensor_memory, mutable> -> !ttg.memdesc<64x128xbf16, #tmem1, #ttng.tensor_memory, mutable, 1x64x128> + %d = ttg.memdesc_index %arg2, %c0_i32 : !ttg.memdesc<1x64x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<64x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x64x128> + %barrier = ttg.memdesc_index %arg3, %c0_i32 : !ttg.memdesc<2xi64, #shared, #smem, mutable> -> !ttg.memdesc<1xi64, #shared, #smem, mutable> ttng.tc_gen5_mma %a, %b, %d, %true, %true, %barrier[%true] {is_async} : !ttg.memdesc<64x128xbf16, #tmem1, #ttng.tensor_memory, mutable, 1x64x128>, !ttg.memdesc<128x128xbf16, #shared1, #smem>, !ttg.memdesc<64x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x64x128>, !ttg.memdesc<1xi64, #shared, #smem, mutable> ttg.warp_return @@ -340,8 +340,8 @@ tt.func @alloc_warp_specialize_explicit_capture() { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 - %b = ttg.memdesc_subview %arg0[%c0_i32, %c0_i32, %c0_i32] : 
!ttg.memdesc<2x128x128xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xbf16, #shared1, #smem> - %barrier = ttg.memdesc_subview %arg3[%c0_i32] : !ttg.memdesc<2xi64, #shared, #smem, mutable> -> !ttg.memdesc<1xi64, #shared, #smem, mutable> + %b = ttg.memdesc_index %arg0, %c0_i32 : !ttg.memdesc<2x128x128xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xbf16, #shared1, #smem> + %barrier = ttg.memdesc_index %arg3, %c0_i32 : !ttg.memdesc<2xi64, #shared, #smem, mutable> -> !ttg.memdesc<1xi64, #shared, #smem, mutable> ttng.tc_gen5_mma %arg1, %b, %arg2, %true, %true, %barrier[%true] {is_async} : !ttg.memdesc<64x128xbf16, #tmem1, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x128xbf16, #shared1, #smem>, !ttg.memdesc<64x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared, #smem, mutable> ttg.warp_return diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/MembarUtility.h b/third_party/amd/include/TritonAMDGPUToLLVM/MembarUtility.h index 060a90179503..5c28b815a4d5 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/MembarUtility.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/MembarUtility.h @@ -16,7 +16,7 @@ namespace mlir::triton::AMD { // gpu.barrier between the LocalLoad and the prefetches. However the pipeliner // will always use at least 2 buffers so this IR cannot be produced. 
Example // membar input IR to produce incorrect results: -// %tile_a = ttg.memdesc_subview +// %tile_a = ttg.memdesc_index // %1 = AsyncCopyGlobalToLocal %ptr %tile_a // scf.for // %2 = AsyncWait %1 diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 734831b7ee7b..fd0b105d31b1 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -209,7 +209,7 @@ def TritonAMDGPUInThreadTranspose: Pass<"tritonamdgpu-in-thread-transpose", "mli local_load is vectorized, because shared memory order matches destination layout register order. This pass introduces two ttg.convert_layouts to properly cover cases when between ttg.load and ttg.local_alloc/ttg.local_store - exist more operations like scf or ttg.memdesc_subview. These convert_layouts ops are optimized out by later passes. + exist more operations like scf or ttg.memdesc_index. These convert_layouts ops are optimized out by later passes. }]; let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect", "mlir::triton::gpu::TritonGPUDialect"]; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index d9008c7b91df..bf12a692b9d4 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -48,7 +48,7 @@ class Pingponger { SmallVector> subViewOps; SmallVector> loadSliceOps; SmallVector dotSliceOps; - SmallVector constOffsets; + SmallVector constOffsets; Operation *lastInsertedOp; // rocdl.s.setprio will be mapped to `s_setprio` instruction which set the @@ -353,10 +353,10 @@ void Pingponger::determineDotMemoryOps( // Determine the local stores from the local loads. 
// With pipelining we expect this to be a single local // store within the loop based on a block argument after routing through - // a ttg.MemDescSubviewOp. - DenseSet subviews; + // a ttg.MemDescIndexOp. + DenseSet subviews; for (auto &&localLoad : dotLocalLoads) - findClosestPredOps(localLoad.getSrc(), subviews); + findClosestPredOps(localLoad.getSrc(), subviews); for (auto &&subview : subviews) for (auto &&user : subview->getUsers()) @@ -409,8 +409,7 @@ void Pingponger::genOffsetConstants(Location loc, OpBuilder &builder, unsigned numSlices, int64_t sliceWidth) { for (int i = 0; i < numSlices; i++) { int64_t offset = sliceWidth * i; - constOffsets.push_back( - builder.create(loc, offset, 32)); + constOffsets.push_back(offset); } } @@ -442,14 +441,14 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v, shape, elementType, type.getEncoding(), type.getMemorySpace(), type.getMutableMemory(), type.getAllocShape()); for (int i = 0; i < numSlices; i++) { - SmallVector offsetsVal; + SmallVector logicalOffsets; SmallVector offsets = {0, 0}; offsets[kIdx] = i; for (int64_t off : offsets) { - offsetsVal.push_back(constOffsets[off]); + logicalOffsets.push_back(constOffsets[off]); } - Value newSmem = builder.create( - v.getLoc(), subviewDescType, memDesc, offsetsVal); + Value newSmem = builder.create( + v.getLoc(), subviewDescType, memDesc, logicalOffsets); Value prefetchSlice = builder.create( v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), newSmem); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/InThreadTranspose.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/InThreadTranspose.cpp index a9235c8d82dc..c01fd094dbe9 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/InThreadTranspose.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/InThreadTranspose.cpp @@ -154,7 +154,7 @@ struct GlobalToSharedMemoryOpChain { SetVector globalLoads; // list of localAllocOp and localStoreOp operations SetVector localAllocStores; - // 
list of MemDescSubviewOp, control flow results and block operands + // list of MemDescIndexOp, control flow results and block operands SmallVector sharedMemVals; }; @@ -475,7 +475,7 @@ FailureOr> findAllDefiningOps(Value val) { /// /// ttg.local_alloc -----x-------------------------> ttg.local_dealloc /// V -/// tt.load -> ttg.local_store -> ttg.mem_subview -> ttg.local_load +/// tt.load -> ttg.local_store -> ttg.memdesc_index -> ttg.local_load /// /// \returns partially filled GlobalToSharedMemoryOpChain structure of failure. FailureOr @@ -508,7 +508,7 @@ findReachableSMemOps(ttg::LocalLoadOp root) { } else if (isa(candidate)) { foundNetwork.localAllocStores.insert(candidate); smemOperand = candidate->getOperand(1); - } else if (isa(candidate)) { + } else if (isa(candidate)) { smemOutput = candidate->getResult(0); smemOperand = candidate->getOperand(0); } else if (isa(candidate)) { @@ -537,7 +537,7 @@ findReachableSMemOps(ttg::LocalLoadOp root) { foundNetwork.sharedMemVals.push_back(def); if (Operation *op = def.getDefiningOp()) { // additional check, to ignore control flow operations - if (isa(op)) + if (isa(op)) nextTraversalStep.push_back(op); } } @@ -574,10 +574,10 @@ unsigned getMaxSizePerThread(RankedTensorType type, int dimIdx) { // ttg.local_alloc ---x // | // V -// tt.load --> ttg.local_store --> ttg.memdesc_subview --> ttg.local_load +// tt.load --> ttg.local_store --> ttg.memdesc_index --> ttg.local_load // // Actual network could vary, because of different control flow, -// optional ttg.memdesc_subview and ttg.local_store operations. +// optional ttg.memdesc_index and ttg.local_store operations. // // If data flow pattern match, check applicability // of inThreadTrasnpose optimization and return found pattern. 
@@ -597,7 +597,7 @@ matchInThreadTransposePattern(ttg::LocalLoadOp lLoad) { return failure(); } - // find local_alloc, local_store, local_load and ttg.memdesc_subview + // find local_alloc, local_store, local_load and ttg.memdesc_index // operations auto sharedMemSearch = findReachableSMemOps(lLoad); if (failed(sharedMemSearch)) { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 26006965f02d..78b6870dfbb5 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -113,7 +113,7 @@ using LoadToInfoMap = llvm::MapVector; struct StreamCopyChainOps { tt::LoadOp loadOp; - ttg::MemDescSubviewOp subviewOp; + ttg::MemDescIndexOp subviewOp; ttg::LocalStoreOp localStoreOp; ttg::LocalLoadOp maybeLocalLoadOp; }; @@ -135,7 +135,7 @@ AsyncCopyChainOps createAsyncCopy(tt::LoadOp loadOp, Value alloc, // Extract local subview from shared allocation auto viewLoad = triton::createSingleBufferView(builder, alloc, extractIdx) - .getDefiningOp(); + .getDefiningOp(); auto copyOp = builder.create( loc, loadOp.getPtr(), viewLoad, loadOp.getMask(), loadOp.getOther(), @@ -172,7 +172,7 @@ StreamCopyChainOps createStreamCopy(tt::LoadOp loadOp, Value alloc, // Extract local subview from shared allocation auto viewLoad = triton::createSingleBufferView(builder, alloc, extractIdx) - .getDefiningOp(); + .getDefiningOp(); tt::LoadOp newLoadOp = cast(builder.clone(*loadOp)); auto storeOp = builder.create(loc, newLoadOp, viewLoad); diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/CodePartitionUtility.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/CodePartitionUtility.cpp index a2a01717141a..5403561c66b9 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/CodePartitionUtility.cpp +++ 
b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/CodePartitionUtility.cpp @@ -166,9 +166,8 @@ Value getBarrierForPipelineStage(OpBuilderWithAsyncTaskIds &builder, /*mutableMemory=*/true); // Create barrierForTMA from barrierAlloc. - return builder.createWithAsyncTaskIds( - barrierAlloc.getLoc(), barrierTy, barrierAlloc, - ArrayRef({bufferIdx})); + return builder.createWithAsyncTaskIds( + barrierAlloc.getLoc(), barrierTy, barrierAlloc, bufferIdx); } } // namespace mlir diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSCodePartition.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSCodePartition.cpp index d02485a0e5e0..a8d0229a4703 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSCodePartition.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSCodePartition.cpp @@ -564,7 +564,7 @@ static Value createBarrierAlloc(triton::FuncOp funcOp, unsigned distance) { loc, barrierMemDescType, Value()); for (unsigned i = 0; i < distance; i++) { Value idx = builder.create(loc, i, 32); - Value barrierView = builder.create( + Value barrierView = builder.create( loc, singleBarrierMemDescType, barrierAlloc, idx); builder.create(funcOp->getLoc(), barrierView, 1); } diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp index 9b589970c123..ec64f3b3e99f 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp @@ -670,13 +670,11 @@ static void rewriteRematerializedOps(triton::FuncOp &funcOp, SmallVector shape = getShape(memdescType); int sliceSize = shape[dim] / partitionScheme.numPartitions; shape[dim] = sliceSize; - Value zero = builder.createWithAsyncTaskIds( - allocOp.getLoc(), 0, 32); auto slicedMemdescType = MemDescType::get( shape, 
memdescType.getElementType(), memdescType.getEncoding(), memdescType.getMemorySpace(), memdescType.getMutableMemory()); - SmallVector offsets(shape.size(), zero); - auto viewOp = builder.createWithAsyncTaskIds( + SmallVector offsets(shape.size(), 0); + auto viewOp = builder.createWithAsyncTaskIds( allocOp.getLoc(), slicedMemdescType, allocOp.getResult(), offsets); newOp = viewOp; } else if (isa(oldOp)) { diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp index 7589f824ab42..13018d0d68b0 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp @@ -67,14 +67,10 @@ createAsyncCopy(const DenseMap &bufferMap, Channel *c, ttg::MemDescType::get(sliceType.getShape(), sliceType.getElementType(), sliceType.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); - Value zero = builder.createWithAsyncTaskIds( - loadOp.getLoc(), 0, 32); - SmallVector copyOffsets(sliceType.getRank() + 1, zero); - copyOffsets[0] = bufferIdx; builder.setAsyncTaskIdsFromOp(loadOp); builder.setInsertionPointAfter(loadOp); - auto view = builder.createWithAsyncTaskIds( - loadOp.getLoc(), subviewTy, buffer, copyOffsets); + auto view = builder.createWithAsyncTaskIds( + loadOp.getLoc(), subviewTy, buffer, bufferIdx); // Create cp.async Operation *copy = builder.createWithAsyncTaskIds( @@ -85,10 +81,8 @@ createAsyncCopy(const DenseMap &bufferMap, Channel *c, // Extract part. 
builder.setAsyncTaskIdsFromValueUsers(loadResult); builder.setInsertionPoint(c->getDstOp()); - SmallVector loadOffsets(sliceType.getRank() + 1, zero); - loadOffsets[0] = bufferIdxExtract; - auto viewLoad = builder.createWithAsyncTaskIds( - loadOp.getLoc(), subviewTy, buffer, loadOffsets); + auto viewLoad = builder.createWithAsyncTaskIds( + loadOp.getLoc(), subviewTy, buffer, bufferIdxExtract); auto sharedLoad = builder.createWithAsyncTaskIds( loadOp.getLoc(), loadOp.getType(), viewLoad /*,wait->getResult(0)*/); // Replace all uses of loadResult @@ -133,12 +127,8 @@ createLocalCopy(const DenseMap &bufferMap, Channel *channel, OpBuilderWithAsyncTaskIds builder(dstOp); builder.setAsyncTaskIdsFromOp(dstOp); builder.setInsertionPoint(dstOp); - Value zero = builder.createWithAsyncTaskIds( - dstOp->getLoc(), 0, 32); - SmallVector loadOffsets(sliceType.getRank() + 1, zero); - loadOffsets[0] = dstBufferIdx; - auto dstView = builder.createWithAsyncTaskIds( - dstOp->getLoc(), subviewTy, buffer, loadOffsets); + auto dstView = builder.createWithAsyncTaskIds( + dstOp->getLoc(), subviewTy, buffer, dstBufferIdx); auto sharedLoad = builder.createWithAsyncTaskIds( dstOp->getLoc(), srcValue.getType(), dstView); srcValue.replaceAllUsesWith(sharedLoad.getResult()); @@ -146,13 +136,9 @@ createLocalCopy(const DenseMap &bufferMap, Channel *channel, // Producer part. Create local_store for new producers. 
builder.setAsynTaskIdsFromArray(channel->relation.first); builder.setInsertionPoint(srcOp->getParentOp()); - zero = builder.createWithAsyncTaskIds(srcOp->getLoc(), - 0, 32); - SmallVector storeOffsets(sliceType.getRank() + 1, zero); - storeOffsets[0] = srcBufferIdx; builder.setInsertionPointAfter(srcOp); - auto srcView = builder.createWithAsyncTaskIds( - srcOp->getLoc(), subviewTy, buffer, storeOffsets); + auto srcView = builder.createWithAsyncTaskIds( + srcOp->getLoc(), subviewTy, buffer, srcBufferIdx); // Create local_alloc Operation *copy = builder.createWithAsyncTaskIds( srcOp->getLoc(), srcValue, srcView); @@ -175,15 +161,8 @@ static Value createBufferView(OpBuilderWithAsyncTaskIds &builder, Value alloc, shape, allocDescType.getElementType(), allocDescType.getEncoding(), allocDescType.getMemorySpace(), allocDescType.getMutableMemory(), /*allocShape=*/allocDescType.getAllocShape()); - SmallVector idxs = {idx}; - if (allocDescType.getShape().size() > 1) { - Value zero = builder.create(alloc.getLoc(), 0, 32); - for (unsigned i = 1; i < allocDescType.getShape().size(); i++) { - idxs.push_back(zero); - } - } - return builder.create( - alloc.getLoc(), viewDescType, alloc, idxs); + return builder.create(alloc.getLoc(), + viewDescType, alloc, idx); } static int getTMALoadSize(tt::DescriptorLoadOp &tmaLoad) { @@ -216,13 +195,8 @@ Value getBufferForPipelineStage(OpBuilderWithAsyncTaskIds &builder, sliceType.getEncoding(), sharedMemorySpace, /*mutableMemOry=*/mutableMem); - Value zero = builder.createWithAsyncTaskIds( - buffer.getLoc(), 0, 32); - SmallVector copyOffsets(sliceType.getRank() + 1, zero); - copyOffsets[0] = bufferIdx; - - return builder.createWithAsyncTaskIds( - buffer.getLoc(), subviewTy, buffer, copyOffsets); + return builder.createWithAsyncTaskIds( + buffer.getLoc(), subviewTy, buffer, bufferIdx); } Operation *optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder, diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerToken.cpp 
b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerToken.cpp index a152914ae8c3..9b99beddbe46 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerToken.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerToken.cpp @@ -151,14 +151,14 @@ void lowerTokenOperations(Operation *parentOp, int numCTAs, unsigned bufferEmptyCount = THREADS_PER_TASK; for (unsigned i = 0; i < createTokenOp.getNumBuffers(); i++) { Value idx = builder.create(loc, i, 32); - Value barrierFullView = builder.create( + Value barrierFullView = builder.create( loc, singleBarrierMemDescType, bufferFullArray, idx); // EmptyView is used for ConsumerRelease and ProducerAcquire. // FullView is for ConsumerWait and ProducerCommit. builder.create(loc, barrierFullView, bufferFullCount); - Value barrierEmptyView = builder.create( + Value barrierEmptyView = builder.create( loc, singleBarrierMemDescType, bufferEmptyArray, idx); builder.create(loc, barrierEmptyView, bufferEmptyCount); @@ -169,14 +169,14 @@ void lowerTokenOperations(Operation *parentOp, int numCTAs, // Helper function for extracting one index from bufferFullArray. auto extractBufferFull = [&](Location loc, Value idx) -> Value { - return builder.create( - loc, singleBarrierMemDescType, bufferFullArray, idx); + return builder.create(loc, singleBarrierMemDescType, + bufferFullArray, idx); }; // Helper function for extracting one index from bufferEmptyArray. 
auto extractBufferEmpty = [&](Location loc, Value idx) -> Value { - return builder.create( - loc, singleBarrierMemDescType, bufferEmptyArray, idx); + return builder.create(loc, singleBarrierMemDescType, + bufferEmptyArray, idx); }; auto handleOneUser = [&](Operation *user) -> bool { // Here builder is at the user, make sure usage of values outside of diff --git a/third_party/nvidia/lib/Dialect/NVWS/IR/Ops.cpp b/third_party/nvidia/lib/Dialect/NVWS/IR/Ops.cpp index 5174ad1121c7..c0356ab02847 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/IR/Ops.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/IR/Ops.cpp @@ -60,7 +60,7 @@ std::optional static arefEnterVerify( auto typeArray = aref.getBaseType(); if (typeArray.size() != resultTypes.size()) return "Aref has different number of arguments than enter"; - // This should probably rely on the memdescSubviewOp verifier? + // This should probably rely on the memdescSubsliceOp verifier? for (auto [orig, arg] : llvm::zip(typeArray, resultTypes)) { if (auto origT = dyn_cast(orig)) { auto argT = dyn_cast(arg); diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp index 69b326e3f083..545d60d00797 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp @@ -193,21 +193,16 @@ SmallVector getSubViews(ArefValue arefVal, Value stage, Location loc, OpBuilder &rewriter) { SmallVector views; for (auto buffer : arefVal.buffers) { - SmallVector offsetsVal{stage}; auto memDescType = cast(buffer.getType()); auto shape = memDescType.getShape(); auto rank = shape.size() - 1; - for (int i = 0; i < rank; ++i) { - offsetsVal.push_back(rewriter.create( - loc, 0, rewriter.getIntegerType(32))); - } SmallVector tensorShape(shape.begin() + 1, shape.end()); auto memDescTypeNew = MemDescType::get( tensorShape, memDescType.getElementType(), memDescType.getEncoding(), 
memDescType.getMemorySpace(), true); - Value singleBuffer = rewriter.create(loc, memDescTypeNew, - buffer, offsetsVal); + Value singleBuffer = + rewriter.create(loc, memDescTypeNew, buffer, stage); views.push_back(singleBuffer); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp index 6560f6b19048..827fbc883000 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp @@ -858,13 +858,13 @@ struct TensorMemoryCopyOpConversion } }; -struct MemDescSubviewOpConversion - : public ConvertOpToLLVMPattern { +struct MemDescIndexOpConversion + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< - triton::gpu::MemDescSubviewOp>::ConvertOpToLLVMPattern; + triton::gpu::MemDescIndexOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor, + matchAndRewrite(triton::gpu::MemDescIndexOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -879,22 +879,14 @@ struct MemDescSubviewOpConversion // newBase = base + offset auto tmemBase = adaptor.getSrc(); - SmallVector opOffsetVals = op.getOffsets(); - size_t destRank = op.getResult().getType().getRank(); - SmallVector offsetVals; - int rankReduced = srcTy.getRank() - destRank; - for (int i = rankReduced; i < opOffsetVals.size(); i++) { - offsetVals.push_back(opOffsetVals[i]); - } - + auto idx = op.getIndex(); triton::nvidia_gpu::TMemAllocation tmemAlloc = triton::nvidia_gpu::getTmemAllocSizes(cast(dstTy)); int numColOffset = tmemAlloc.numCols; Value newBase = b.ptrtoint(rewriter.getI32Type(), tmemBase); newBase = rewriter.create( loc, newBase, - rewriter.create(loc, opOffsetVals[0], - b.i32_val(numColOffset))); + rewriter.create(loc, idx, b.i32_val(numColOffset))); auto 
elemPtrTy = ptr_ty(rewriter.getContext(), 3); rewriter.replaceOp(op, b.inttoptr(elemPtrTy, newBase)); return success(); @@ -977,7 +969,7 @@ void mlir::triton::NVIDIA::populateTensorMemoryOpToLLVMPattern( void mlir::triton::NVIDIA::populateTensorMemorySubviewOpToLLVMPattern( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); return; }