Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
return 0;
}
};
// blocked -> blocked
if (srcLayout.isa<BlockedEncodingAttr>() &&
dstLayout.isa<BlockedEncodingAttr>()) {
auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
Expand All @@ -66,14 +65,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
}
paddedRepShape[outOrd[0]] += pad;
}
// blocked -> shared
if (srcLayout.isa<BlockedEncodingAttr>() &&
dstLayout.isa<SharedEncodingAttr>()) {
auto sharedLayout = dstLayout.cast<SharedEncodingAttr>();
for (int v : dstTy.getShape())
paddedRepShape.push_back(v);
}

return paddedRepShape;
}

Expand Down Expand Up @@ -140,8 +131,9 @@ class AllocationAnalysis {
auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
auto srcEncoding = srcTy.getEncoding();
auto dstEncoding = dstTy.getEncoding();
if (srcEncoding.isa<SharedEncodingAttr>()) {
// only block->block and block->shared is supported now
if (srcEncoding.isa<SharedEncodingAttr>() ||
dstEncoding.isa<SharedEncodingAttr>()) {
// Only blocked -> blocked conversion requires scratch allocation
return;
}
// ConvertLayoutOp with both input/output non-shared_layout
Expand Down
125 changes: 88 additions & 37 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,13 @@ class ConvertTritonGPUOpToLLVMPattern
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}

// Constructor overload that additionally takes `allocation` and `smem`.
// `typeConverter` and `benefit` are forwarded to the ConvertOpToLLVMPattern
// base; `allocation` and `smem` are stored in the members of the same name.
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
allocation(allocation), smem(smem) {}

Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
auto cast = rewriter.create<UnrealizedConversionCastOp>(
Expand Down Expand Up @@ -585,19 +592,23 @@ class ConvertTritonGPUOpToLLVMPattern
return multiDimIdx;
}

template <typename T>
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
Value smem, const Allocation *allocation,
Operation *op) const {
T value) const {
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getIntegerType(8)), 3);
auto bufferId = allocation->getBufferId(op);
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
auto bufferId = allocation->getBufferId(value);
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
size_t offset = allocation->getOffset(bufferId);
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
Value offVal = createIndexAttrConstant(rewriter, loc, llvmIndexTy, offset);
Value base = gep(ptrTy, smem, offVal);
return base;
}

protected:
const Allocation *allocation;
Value smem;
};

// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
Expand Down Expand Up @@ -1332,6 +1343,65 @@ struct AddPtrOpConversion
}
};

/// Lowers triton_gpu.alloc_tensor to a typed pointer: the base address
/// returned by getSharedMemoryBase() for the op's result is bitcast to a
/// pointer (address space 3) to the converted element type, and the op is
/// replaced by that pointer.
struct AllocTensorOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AllocTensorOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::gpu::AllocTensorOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::gpu::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Base address of the buffer associated with this op's result.
    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult());
    // The result type is dereferenced unconditionally below, so use cast<>
    // (which asserts on a mismatch) instead of an unchecked dyn_cast<>.
    auto resultTy = op.getType().cast<RankedTensorType>();
    auto llvmElemTy =
        getTypeConverter()->convertType(resultTy.getElementType());
    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
    // Reinterpret the base address as a pointer to the element type.
    Value resultVal =
        rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);
    rewriter.replaceOp(op, resultVal);
    return success();
  }
};

/// Lowers triton_gpu.extract_slice on a shared-encoded source tensor to a
/// GEP off the converted source pointer. Only axis == 0 is handled: the
/// result then aliases a contiguous range of the source, starting at
/// `index * product(result shape)` elements from the source base.
///
/// Returns failure() (after asserting in debug builds) when the source does
/// not carry a SharedEncodingAttr or when the "axis" attribute is missing or
/// non-zero, so release builds reject unsupported input instead of silently
/// emitting a wrong address.
struct ExtractSliceOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto srcTy = op.src().getType().cast<RankedTensorType>();
    // dyn_cast_or_null: a tensor without an encoding yields a null attribute.
    auto srcLayout =
        srcTy.getEncoding().dyn_cast_or_null<SharedEncodingAttr>();
    if (!srcLayout) {
      assert(false && "Unexpected srcLayout in ExtractSliceOpConversion");
      return failure();
    }

    // axis > 0 will result in non-contiguous memory access if the result
    // tensor is an alias of the source tensor.
    auto axisAttr = op->getAttrOfType<IntegerAttr>("axis");
    if (!axisAttr || axisAttr.getInt() != 0) {
      assert(false && "Only axis=0 is supported for now");
      return failure();
    }

    // Example:
    // %dst = extract_slice %src, %index {axis = 0}
    // src.shape = [11, 2, 3, 4, 1]
    // offset = %index * 2 * 3 * 4 * 1
    auto dstTy = op.getType().cast<RankedTensorType>();
    // Number of elements in one slice, i.e. the stride along axis 0.
    auto base = product<int64_t>(dstTy.getShape());
    auto baseVal = createIndexAttrConstant(
        rewriter, loc, getTypeConverter()->getIndexType(), base);
    Value offset = rewriter.create<LLVM::MulOp>(loc, adaptor.index(), baseVal);

    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
    // The slice is addressed relative to the already-converted source pointer.
    Value resultVal =
        rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, adaptor.src(), offset);
    rewriter.replaceOp(op, resultVal);
    return success();
  }
};

template <typename SourceOp, typename DestOp>
class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
Expand Down Expand Up @@ -1379,13 +1449,6 @@ struct ConvertLayoutOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern;

ConvertLayoutOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(converter,
benefit),
allocation(allocation), smem(smem) {}

LogicalResult
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Expand All @@ -1399,13 +1462,10 @@ struct ConvertLayoutOpConversion
if ((!srcLayout.isa<BlockedEncodingAttr>()) ||
(!dstLayout.isa<BlockedEncodingAttr>())) {
// TODO: not implemented
llvm::errs()
<< "convert_layout except for blocked -> blocked is not implemented";
return failure();
}
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase =
getSharedMemoryBase(loc, rewriter, smem, allocation, op.getOperation());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
smemBase = bit_cast(elemPtrTy, smemBase);

Expand Down Expand Up @@ -1587,9 +1647,6 @@ struct ConvertLayoutOpConversion
}
}
}

const Allocation *allocation;
Value smem;
};

/// ====================== dot codegen begin ==========================
Expand Down Expand Up @@ -1926,11 +1983,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
NOT_APPLICABLE,
};

explicit DotOpConversion(LLVMTypeConverter &typeConverter,
const Allocation *allocation, Value smem,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit),
allocation(allocation), smem(smem) {}
using ConvertTritonGPUOpToLLVMPattern<
triton::DotOp>::ConvertTritonGPUOpToLLVMPattern;

LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -1995,15 +2049,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
assert(false && "Not implemented yet.");
return failure();
}

Value getSmemAddr(Value value, Location loc,
ConversionPatternRewriter &rewriter) const {
return getSharedMemoryBase(loc, rewriter, smem, allocation,
value.getDefiningOp());
}

const Allocation *allocation;
Value smem;
};

struct DotOpConversionHelper {
Expand Down Expand Up @@ -2340,7 +2385,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
SmallVector<Value> ptrs(numPtrs);

Type smemPtrTy = helper.getShemPtrTy();
auto smemBase = getSmemAddr(tensor, loc, rewriter);
auto smemBase = getSharedMemoryBase(loc, rewriter, tensor);
for (int i = 0; i < numPtrs; i++) {
ptrs[i] = bit_cast(
smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]})));
Expand Down Expand Up @@ -2479,10 +2524,12 @@ class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
} else if (auto mma_layout = layout.dyn_cast<MmaEncodingAttr>()) {
return type;
} else if (auto shared_layout = layout.dyn_cast<SharedEncodingAttr>()) {
} else if (auto mma_layout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
// TODO: Not implemented
return type;
} else if (auto shared_layout =
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
}
return llvm::None;
}
Expand All @@ -2493,6 +2540,9 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit = 1) {
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
benefit);
Expand All @@ -2503,9 +2553,10 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
Expand Down
7 changes: 4 additions & 3 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,10 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
if (axis < 0 || axis > srcShape.size())
return failure();
// Since we only extract a slice from a certain index on the axis,
// the dims before the axis can be dropped.
auto dstShape = srcShape.drop_front(axis + 1);
SmallVector<int64_t, 4> dstShape;
for (int i = 0; i < srcShape.size(); i++)
if (i != axis)
dstShape.push_back(srcShape[i]);
auto returnType =
RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
inferredReturnTypes.assign({returnType});
Expand Down
8 changes: 1 addition & 7 deletions test/Analysis/test-allocation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B

scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
// CHECK: scratch offset = 8192, size = 0
// CHECK-NEXT: offset = 0, size = 8192
// CHECK: offset = 0, size = 8192
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
// CHECK-NEXT: scratch offset = 16384, size = 0
// CHECK-NEXT: offset = 8192, size = 8192
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

Expand All @@ -52,20 +50,16 @@ func @reusable(%A : !tt.ptr<f16>) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
// CHECK: scratch offset = 8192, size = 0
// CHECK-NEXT: offset = 0, size = 8192
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
// CHECK-NEXT: scratch offset = 16384, size = 0
// CHECK-NEXT: offset = 8192, size = 8192
%a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
%a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
// CHECK-NEXT: scratch offset = 24576, size = 0
// CHECK-NEXT: offset = 16384, size = 8192
%a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
// CHECK-NEXT: scratch offset = 8192, size = 0
// CHECK-NEXT: offset = 0, size = 8192
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
Expand Down
38 changes: 38 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,44 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {

// -----

// Test: triton_gpu.alloc_tensor with a shared encoding is expected to lower to
// an address-of the global shared-memory symbol, a constant, a getelementptr
// and a bitcast, emitted back to back.
#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.mlir.global internal @global_smem
// CHECK-LABEL: basic_alloc_tensor
func @basic_alloc_tensor() {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK-NEXT: llvm.mlir.constant
// CHECK-NEXT: llvm.getelementptr
// CHECK-NEXT: llvm.bitcast
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #shared0>
return
}
}

// -----

// Test: triton_gpu.extract_slice along axis 0 of an alloc_tensor result is
// expected to lower to a multiply of one constant by another followed by a
// getelementptr off the bitcast base pointer. The constants' values are not
// checked here; presumably OFFSET0 is the slice index and OFFSET2 the
// per-slice element count (confirm against the lowering).
#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.mlir.global internal @global_smem
// CHECK-LABEL: basic_extract_slice
func @basic_extract_slice() {
// CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
// CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant
// CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant
// CHECK-NEXT: llvm.getelementptr %[[BASE0]][%[[OFFSET1]]]
// CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast
// CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant
// CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]]
// CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET3]]]
%index = arith.constant 1 : i32
%0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
%1 = triton_gpu.extract_slice %0, %index {axis = 0: i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>
return
}
}

// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: basic_splat
Expand Down