Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ SmallVector<unsigned> getSizePerThread(Attribute layout);

SmallVector<unsigned> getContigPerThread(Attribute layout);

SmallVector<unsigned> getUniqueContigPerThread(Type type);

SmallVector<unsigned> getThreadsPerCTA(Attribute layout);

SmallVector<unsigned>
Expand Down
10 changes: 6 additions & 4 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -919,11 +919,13 @@ unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) {
auto order = triton::gpu::getOrder(layout);
unsigned align = getPtrAlignment(ptr);

unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
contigPerThread = std::min(align, contigPerThread);
contigPerThread = std::min<unsigned>(shape[order[0]], contigPerThread);
auto uniqueContigPerThread = triton::gpu::getUniqueContigPerThread(tensorTy);
assert(order[0] < uniqueContigPerThread.size() &&
"Unxpected uniqueContigPerThread size");
unsigned contiguity = uniqueContigPerThread[order[0]];
contiguity = std::min(align, contiguity);

return contigPerThread;
return contiguity;
}

unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
Expand Down
12 changes: 8 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,17 @@ struct ConvertLayoutOpConversion
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parentEncoding = sliceLayout.getParent();
auto parentSizePerThread = getSizePerThread(parentEncoding);
unsigned stride = 1;
if (getOrder(parentEncoding)[0] == dim)
stride = parentSizePerThread[dim];
auto parentShape = sliceLayout.paddedShape(shape);
auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
parentEncoding);
auto multiDimOffsetParent =
getMultiDimOffset(parentEncoding, loc, rewriter, elemId, parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
auto multiDimOffsetParent = getMultiDimOffset(
parentEncoding, loc, rewriter, elemId * stride, parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d == dim)
Expand Down
42 changes: 28 additions & 14 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "Utility.h"
#include "mlir/IR/TypeUtilities.h"
#include "triton/Analysis/AxisInfo.h"

#include <set>
using namespace mlir;
using namespace mlir::triton;

Expand Down Expand Up @@ -521,6 +521,13 @@ class ConvertTritonGPUOpToLLVMPatternBase {
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, type);
if (mmaLayout.isAmpere())
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, type);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy = RankedTensorType::get(
parentShape, type.getElementType(), parentLayout);
result = emitBaseIndexForLayout(loc, rewriter, parentLayout, parentTy);
result.erase(result.begin() + sliceLayout.getDim());
} else {
llvm_unreachable("unsupported emitBaseIndexForLayout");
}
Expand All @@ -540,6 +547,8 @@ class ConvertTritonGPUOpToLLVMPatternBase {
if (mmaLayout.isAmpere())
return emitOffsetForMmaLayoutV2(mmaLayout, type);
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>())
return emitOffsetForSliceLayout(sliceLayout, type);
llvm_unreachable("unsupported emitOffsetForLayout");
}

Expand All @@ -565,7 +574,7 @@ class ConvertTritonGPUOpToLLVMPatternBase {
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, mma, type);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
result = emitIndicesForSliceLayout(loc, b, slice, type);
result = emitIndicesForDistributedLayout(loc, b, slice, type);
} else {
llvm_unreachable(
"emitIndices for layouts other than blocked & slice not "
Expand Down Expand Up @@ -879,24 +888,29 @@ class ConvertTritonGPUOpToLLVMPatternBase {
return multiDimIdx;
}

SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
const SliceEncodingAttr &sliceLayout,
RankedTensorType type) const {
SmallVector<SmallVector<unsigned>>
emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout,
RankedTensorType type) const {
auto parentEncoding = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy = RankedTensorType::get(
parentShape, type.getElementType(), parentEncoding);
auto parentIndices = emitIndices(loc, rewriter, parentEncoding, parentTy);
unsigned numIndices = parentIndices.size();
SmallVector<SmallVector<Value>> resultIndices;
for (unsigned i = 0; i < numIndices; ++i) {
SmallVector<Value> indices = parentIndices[i];
indices.erase(indices.begin() + dim);
resultIndices.push_back(indices);
auto parentOffsets = emitOffsetForLayout(parentEncoding, parentTy);

unsigned numOffsets = parentOffsets.size();
SmallVector<SmallVector<unsigned>> resultOffsets;
std::set<SmallVector<unsigned>> uniqueOffsets;

for (unsigned i = 0; i < numOffsets; ++i) {
SmallVector<unsigned> offsets = parentOffsets[i];
offsets.erase(offsets.begin() + dim);
if (uniqueOffsets.find(offsets) == uniqueOffsets.end()) {
resultOffsets.push_back(offsets);
uniqueOffsets.insert(offsets);
}
}
return resultIndices;
return resultOffsets;
}

protected:
Expand Down
64 changes: 52 additions & 12 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,23 +116,64 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
}
};

template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
using OpAdaptor = typename SourceOp::Adaptor;
explicit ViewLikeOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern<ViewOp> {
using OpAdaptor = typename ViewOp::Adaptor;
explicit ViewOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<ViewOp>(typeConverter, benefit) {}

LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
matchAndRewrite(ViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
auto vals = this->getTypeConverter()->unpackLLElements(
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
Value view =
Value ret =
this->getTypeConverter()->packLLElements(loc, vals, rewriter, resultTy);
rewriter.replaceOp(op, view);
rewriter.replaceOp(op, ret);
return success();
}
};

/// Lowers `ExpandDimsOp` to LLVM by re-packing the per-thread source elements
/// in the element order of the result layout.
///
/// The source tensor is required to carry a `SliceEncodingAttr`. For every
/// offset that `emitOffsetForLayout` reports for the result type, the
/// coordinate at the sliced dimension is dropped and the remaining offset is
/// used to look up the matching source element.
struct ExpandDimsOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp> {
  using OpAdaptor = typename ExpandDimsOp::Adaptor;
  explicit ExpandDimsOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
                                  PatternBenefit benefit = 1)
      : ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp>(typeConverter, benefit) {}

  /// Replaces `op` with a packed LLVM value holding the source elements
  /// arranged according to the result layout. Always returns success().
  LogicalResult
  matchAndRewrite(ExpandDimsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto srcVals = this->getTypeConverter()->unpackLLElements(
        loc, adaptor.getSrc(), rewriter, op.getOperand().getType());

    auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
    auto resultTy = op.getType().cast<RankedTensorType>();

    assert(srcTy.getEncoding().isa<SliceEncodingAttr>() &&
           "ExpandDimsOp only support SliceEncodingAttr");
    // The encoding kind was just asserted, so use cast<> rather than
    // dyn_cast<>, which would suggest a null result is handled below.
    auto srcLayout = srcTy.getEncoding().cast<SliceEncodingAttr>();
    auto resultLayout = resultTy.getEncoding();

    auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
    auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);

    // Index the unpacked source elements by their offset in the source type.
    DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
    for (size_t i = 0; i < srcOffsets.size(); i++) {
      srcValues[srcOffsets[i]] = srcVals[i];
    }

    // Loop-invariant: the dimension removed by the slice encoding.
    const unsigned slicedDim = srcLayout.getDim();

    SmallVector<Value> resultVals;
    resultVals.reserve(resultOffsets.size());
    for (size_t i = 0; i < resultOffsets.size(); i++) {
      auto offset = resultOffsets[i];
      offset.erase(offset.begin() + slicedDim);
      // lookup() yields a null Value for an unknown key; catch that here
      // instead of handing a null element to packLLElements.
      Value srcVal = srcValues.lookup(offset);
      assert(srcVal && "ExpandDimsOp result offset has no source element");
      resultVals.push_back(srcVal);
    }
    Value ret = this->getTypeConverter()->packLLElements(loc, resultVals,
                                                         rewriter, resultTy);
    rewriter.replaceOp(op, ret);
    return success();
  }
};
Expand Down Expand Up @@ -165,9 +206,8 @@ void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit) {
patterns.add<ViewLikeOpConversion<triton::ViewOp>>(typeConverter, benefit);
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
benefit);
patterns.add<ViewOpConversion>(typeConverter, benefit);
patterns.add<ExpandDimsOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<CatOpConversion>(typeConverter, benefit);
Expand Down
40 changes: 36 additions & 4 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,9 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
blockedLayout.getSizePerThread().end());
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto ret = getSizePerThread(sliceLayout.getParent());
return ret;
// ret.erase(ret.begin() + sliceLayout.getDim());
return ret;
auto sizePerThread = getSizePerThread(sliceLayout.getParent());
sizePerThread.erase(sizePerThread.begin() + sliceLayout.getDim());
return sizePerThread;
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isAmpere()) {
return {2, 2};
Expand Down Expand Up @@ -146,11 +145,43 @@ SmallVector<unsigned> getContigPerThread(Attribute layout) {
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.isVolta() || mmaLayout.isAmpere());
return {1, 2};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
return getContigPerThread(parentLayout);
} else {
return getSizePerThread(layout);
}
}

/// Returns, per dimension of `type`, the per-thread contiguity reported by
/// `getContigPerThread` clamped to the tensor's extent in that dimension.
///
/// - Integer, index, float and Triton pointer types yield the one-element
///   vector {1}.
/// - Any other type is cast to `RankedTensorType`.
/// - For a slice-encoded tensor the value is computed on the parent type
///   (parent encoding, shape from `paddedShape`) and the entry of the sliced
///   dimension is then removed.
SmallVector<unsigned> getUniqueContigPerThread(Type type) {
  if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
    return SmallVector<unsigned>(1, 1);
  auto tensorType = type.cast<RankedTensorType>();
  auto shape = tensorType.getShape();
  // If slice layout, call recursively on parent layout, and drop
  // sliced dim
  if (auto sliceLayout =
          tensorType.getEncoding().dyn_cast<SliceEncodingAttr>()) {
    auto parentLayout = sliceLayout.getParent();
    auto parentShape = sliceLayout.paddedShape(shape);
    auto parentTy = RankedTensorType::get(
        parentShape, tensorType.getElementType(), parentLayout);
    auto parentUniqueContigPerThread = getUniqueContigPerThread(parentTy);
    parentUniqueContigPerThread.erase(parentUniqueContigPerThread.begin() +
                                      sliceLayout.getDim());
    return parentUniqueContigPerThread;
  }
  // Base case
  const size_t rank = shape.size();
  SmallVector<unsigned> ret(rank);
  auto contigPerThread = getContigPerThread(tensorType.getEncoding());
  assert(contigPerThread.size() == rank && "Unexpected contigPerThread size");
  // Unsigned index: `rank` is a size_t, so a signed counter would mix
  // signedness in the loop condition.
  for (size_t d = 0; d < rank; ++d) {
    ret[d] = std::min<unsigned>(shape[d], contigPerThread[d]);
  }
  return ret;
}

SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
SmallVector<unsigned> threads;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
Expand Down Expand Up @@ -375,6 +406,7 @@ SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
auto parent = getParent();
auto parentElemsPerThread =
::getElemsPerThread(parent, paddedShape(shape), eltTy);
parentElemsPerThread.erase(parentElemsPerThread.begin() + getDim());
return parentElemsPerThread;
}
unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice1
tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
// CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return
}
Expand Down