Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,12 @@ class SharedMemoryObject {
RewriterBase &rewriter) const;

// Returns a mask representing all the bits of the memdesc offsets that
// may be modified by an affine offset coming from a memdesc_subview.
// may be modified by an affine offset coming from a memdesc_subslice.
// The offsets are considered to be in the type of the memdesc.
// For padded layouts, we return the offsets without padding.
static uint64_t getMaskSpanOffsets(triton::gpu::MemDescType srcTy);

// Returns whether the shared memory access had a memdesc_subview
// Returns whether the shared memory access had a memdesc_subslice
// that is rank-preserving (soon to be called memdesc_slice)
static bool isAffineSharedMemoryAccess(triton::gpu::MemDescType srcTy) {
return getMaskSpanOffsets(srcTy) != 0;
Expand Down
57 changes: 38 additions & 19 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -200,38 +200,57 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc"> {
// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}];
}

def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure, MemDescViewTrait]> {
def TTG_MemDescIndexOp : TTG_Op<"memdesc_index", [Pure, MemDescViewTrait]> {
let summary = "take a subview of the descriptor.";

let description = [{
This operation returns a new descriptor representing a subview of the buffer.
This operation returns a new descriptor pointing to the `i`-th element of the
input descriptor along the 0-th dimension.

It doesn't affect the underlying memory.

For example, suppose that
- the input shape is 2x4x16xf16,
- the output shape is 4x16xf16, and
- offsets = [1, 0, 0].
- index = 1.
Then the output descriptor is equivalent to input[1], where input is the logical tensor.

Then in Python syntax, the subview covers input[1].
When the input is of rank 1 (i.e., shape=[k]), the output will have shape=[1].
}];

Just one dimension may be split (at most one non-zero offset).
let arguments = (ins TTG_MemDescType:$src, I32:$index);

When the input shape and the output shape have different rank:
Or the output shape is a tensor of 1D tensor of 1 element:
- The rank of the output must be 1D smaller than the input.
- We assume the input is split along the 0th dimension.
- The offset along the 0th dimension may be a runtime value.
When the input and the output have the same rank:
- The offset must be a compile-time constant
- Larger or equal to the tile of the tensor (or zero)
- That does not split the input along the swizzling pattern (if any)
}];
let arguments = (
ins TTG_MemDescType:$src, Variadic<I32>:$offsets);
let results = (outs TTG_MemDescType:$result);

let assemblyFormat = [{$src `,` $index attr-dict `:` qualified(type($src)) `->` qualified(type($result))}];

let hasVerifier = 1;
}

def TTG_MemDescSubsliceOp : TTG_Op<"memdesc_subslice", [Pure, MemDescViewTrait]> {
let summary = "take a subview of the descriptor.";

let description = [{
This operation returns a new descriptor representing a subview of the logical tensor.
It doesn't affect the underlying memory.

For example, suppose that
- the input shape is 32x16xf16,
- the output shape is 8x16xf16, and
- offsets = [2, 1].
Then in Python syntax, the subview covers input[2:8+2, 1:16+1] where input is
the logical tensor.

The offsets must be larger or equal to the tile of the tensor (or zero).
}];
let arguments = (ins TTG_MemDescType:$src, DenseI32ArrayAttr:$offsets);
// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}];
// Render offsets inline as %src[0, 0] via a custom directive, but keep
// the overall parse/print generated from this assemblyFormat.
let assemblyFormat = [{
$src `[` custom<Offsets>($offsets) `]` attr-dict `:` qualified(type($src))
`->` qualified(type($result))
}];

let results = (outs TTG_MemDescType:$result);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ gpu::SharedEncodingTrait getSharedEncoding(Operation *loadOp);
// specified.
int getNumStagesOrDefault(scf::ForOp forOp, int defaultNumStages);

// Given a result of MemDescSubview, or Alloca, create a MemDescSubview with a
// Given a result of MemDescIndex, or Alloca, create a MemDescIndex with a
// single buffer slice (leading dimension equal to 1), at the given index.
TypedValue<triton::gpu::MemDescType>
createSingleBufferView(OpBuilder &builder, Value alloc, Value idx);
// Given a result of MemDescSubview, or Alloca, create a MemDescSubview with a
// Given a result of MemDescIndex, or Alloca, create a MemDescIndex with a
// single buffer slice (leading dimension equal to 1), at the given index.
TypedValue<triton::gpu::MemDescType>
createSingleBufferView(OpBuilder &builder, Value alloc, int idx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,8 +674,8 @@ def TTNG_TMEMSubSliceOp : TTNG_Op<"tmem_subslice", [Pure]> {
let description = [{
This operation takes a subslice of a tensor memory allocation and returns a new descriptor
containing the address and a view of the subslice.
This is similar to ttg.memdesc_subview except the offset needs to be static and we can only
slice alog the inner dimension of a 2D memdesc as this is the only one we can do for TMem.
This is similar to ttg.memdesc_subslice except we can only slice along the inner dimension
of a 2D memdesc as this is the only one we can do for TMem.
}];
let arguments = (ins TTG_MemDescType:$src, I32Attr:$N);

Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ bool emitTransferBetweenRegistersAndShared(
{kBlock, blockId}},
regIds);

// Compute affine offset given by memdesc_subview
// Compute affine offset given by memdesc_subslice
auto offset = smemObj.getShmemOffset(loc, rewriter, sharedTy);
SmallVector<Value> vecAddrVec;
for (auto &indices : indicesVec) {
Expand Down Expand Up @@ -1153,7 +1153,7 @@ Value SharedMemoryObject::getShmemOffset(Location loc, RewriterBase &rewriter,
auto ctx = srcTy.getContext();
auto b = TritonLLVMOpBuilder(loc, rewriter);

// If it did not have a memdesc_subview, we don't need to compute the offset
// If it did not have a memdesc_subslice, we don't need to compute the offset
// as it is zero
if (!isAffineSharedMemoryAccess(srcTy)) {
return b.i32_val(0);
Expand Down
81 changes: 46 additions & 35 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,46 @@ struct BroadcastOpConversion
}
};

struct MemDescSubviewOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::MemDescSubviewOp> {
// Conversion pattern for triton::gpu::MemDescIndexOp. It rebuilds the source
// shared-memory object with its base pointer advanced by `index * stride`,
// where `stride` is the first stride reported for the source type, and with
// only the trailing `destTy.getRank()` offsets of the source retained.
struct MemDescIndexOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::MemDescIndexOp> {
using ConvertOpToLLVMPattern<
triton::gpu::MemDescSubviewOp>::ConvertOpToLLVMPattern;
triton::gpu::MemDescIndexOp>::ConvertOpToLLVMPattern;

// Replaces `op` with the struct form of the re-based shared-memory object.
// Every path through the visible body returns success().
LogicalResult
matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor,
matchAndRewrite(triton::gpu::MemDescIndexOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
// NOTE(review): `ctx` is not referenced anywhere else in this body.
auto *ctx = op->getContext();
auto b = TritonLLVMOpBuilder(loc, rewriter);
auto srcTy = op.getSrc().getType();
auto destTy = op.getResult().getType();
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());

// Unpack the adaptor's source operand into a SharedMemoryObject.
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
llvmElemTy, rewriter);
auto base = smemObj.getBase();
auto elemPtrTy = base.getType();
// Offset passed to the GEP below (indexed with element type `llvmElemTy`):
// the op's index multiplied by the first stride of the source.
Value stride = smemObj.getStrides(srcTy, loc, rewriter).front();
Value offset = b.mul(op.getIndex(), stride);
// Keep only the last `destTy.getRank()` offsets of the source object.
auto prevOffsets = smemObj.getOffsets();
SmallVector<Value> offsetVals(prevOffsets.end() - destTy.getRank(),
prevOffsets.end());
// Advance the base pointer by `offset` and carry the retained offsets over
// to the new shared-memory object.
smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
llvmElemTy, offsetVals);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};

struct MemDescSubsliceOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::MemDescSubsliceOp> {
using ConvertOpToLLVMPattern<
triton::gpu::MemDescSubsliceOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(triton::gpu::MemDescSubsliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto *ctx = op->getContext();
Expand All @@ -484,40 +517,17 @@ struct MemDescSubviewOpConversion

auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
llvmElemTy, rewriter);
SmallVector<Value> opOffsetVals = op.getOffsets();
// We assume we always create a subview of the last dimensions
// Compute total offset
auto rankReduced = srcTy.getRank() - destTy.getRank();
auto opOffsetVals = op.getOffsets();

auto base = smemObj.getBase();
auto elemPtrTy = base.getType();
auto is1d = srcTy.getRank() == 1 && destTy.getRank() == 1 &&
destTy.getDimSize(0) == 1;
if (rankReduced || is1d) {
auto smemStrides = smemObj.getStrides(srcTy, loc, rewriter);
SmallVector<Value> opSmemStrides(smemStrides.end() - opOffsetVals.size(),
smemStrides.end());
// We are splitting the pipelining dimension which may not be a power of 2
// so we can't use LinearLayouts
auto offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
// Remove the first offsets
SmallVector<Value> offsetVals;
for (int i = rankReduced; i < opOffsetVals.size(); i++) {
offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
}
// Advance the pointer and keep the opOffsets as the new shape
smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
llvmElemTy, offsetVals);
} else {
// Accumulate the logical offsets
SmallVector<Value> offsetVals;
for (auto [oldOff, newOff] :
llvm::zip(smemObj.getOffsets(), opOffsetVals)) {
offsetVals.push_back(b.add(oldOff, newOff));
}
smemObj = SharedMemoryObject(base, llvmElemTy, offsetVals);
// Accumulate the logical offsets
SmallVector<Value> offsetVals;
for (auto [oldOffVal, opOff] :
llvm::zip(smemObj.getOffsets(), opOffsetVals)) {
offsetVals.push_back(b.add(oldOffVal, b.i32_val(opOff)));
}

smemObj = SharedMemoryObject(base, llvmElemTy, offsetVals);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
Expand Down Expand Up @@ -563,6 +573,7 @@ void mlir::triton::populateViewOpToLLVMPatterns(
typeConverter, benefit);
patterns.add<TransOpConversion>(typeConverter, benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<MemDescSubviewOpConversion>(typeConverter, benefit);
patterns.add<MemDescSubsliceOpConversion, MemDescIndexOpConversion>(
typeConverter, benefit);
patterns.add<MemDescReinterpretOpConversion>(typeConverter, benefit);
}
Loading
Loading