Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
return false;
});
addDynamicallyLegalOp<triton::FuncOp>([](triton::FuncOp funcOp) -> bool {
for (auto arg : funcOp.getArguments()) {
if (auto tensor = dyn_cast<RankedTensorType>(arg.getType())) {
if (!tensor.getEncoding())
return false;
}
}
return true;
auto check = [](auto types) {
return llvm::all_of(types, [](auto type) {
auto tensor = dyn_cast<RankedTensorType>(type);
return !tensor || tensor.getEncoding();
});
};
return check(funcOp.getArgumentTypes()) && check(funcOp.getResultTypes());
});
}

Expand Down
57 changes: 4 additions & 53 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/Triton/Transforms/FunctionTypeConversion.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
Expand Down Expand Up @@ -488,55 +489,6 @@ struct TritonMapElementwisePattern
}
};

/// Conversion pattern that rebuilds a `triton::FuncOp` and applies a signature
/// conversion to its entry block.
///
/// The replacement function is created with the original op's name and
/// function type, receives the adaptor's attributes through `addNamedAttrs`,
/// and takes over the original body region.
class TritonFuncOpPattern : public OpConversionPattern<triton::FuncOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto *entryConverter = getTypeConverter();
    // Sized by the argument count; no input remappings are registered on it.
    TypeConverter::SignatureConversion signature(op.getNumArguments());

    auto rebuiltFunc = rewriter.replaceOpWithNewOp<triton::FuncOp>(
        op, op.getName(), op.getFunctionType());
    addNamedAttrs(rebuiltFunc, adaptor.getAttributes());

    // Move the original body to the end of the new function's region.
    auto &rebuiltBody = rebuiltFunc.getBody();
    rewriter.inlineRegionBefore(op.getBody(), rebuiltBody, rebuiltBody.end());

    // A function without a body has no entry block to convert.
    if (rebuiltBody.empty())
      return success();

    // Convert just the entry block. The remaining unstructured control flow is
    // converted by br patterns.
    rewriter.applySignatureConversion(&rebuiltBody.front(), signature,
                                      entryConverter);
    return success();
  }
};

/// Conversion pattern that replaces a `triton::CallOp` with a new call to the
/// same callee, using the adaptor's operands and the original op's result
/// types, and then attaches the adaptor's attributes via `addNamedAttrs`.
class TritonCallOpPattern : public OpConversionPattern<triton::CallOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Callee and result types are taken from the original op unchanged.
    auto calleeName = op.getCallee();
    auto resultTypes = op.getResultTypes();

    auto rebuiltCall = rewriter.replaceOpWithNewOp<triton::CallOp>(
        op, calleeName, resultTypes, adaptor.getOperands());
    addNamedAttrs(rebuiltCall, adaptor.getAttributes());
    return success();
  }
};

/// Conversion pattern that replaces a `ReturnOp` with a new `ReturnOp` built
/// from the operands supplied by the adaptor.
class TritonReturnOpPattern : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto returnedValues = adaptor.getOperands();
    rewriter.replaceOpWithNewOp<ReturnOp>(op, returnedValues);
    return success();
  }
};

void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns, unsigned numCTAs) {
MLIRContext *context = patterns.getContext();
Expand Down Expand Up @@ -584,10 +536,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::DescriptorStoreOp>,
GenericOpPattern<triton::DescriptorReduceOp>,
// this assumes the right layout will be set later for dot scaled.
GenericOpPattern<triton::DotScaledOp>,
GenericOpPattern<triton::CallOp>,
GenericOpPattern<ReturnOp>,
TritonFuncOpPattern
GenericOpPattern<triton::DotScaledOp>
// clang-format on
>(typeConverter, context);
}
Expand Down Expand Up @@ -803,6 +752,8 @@ class ConvertTritonToTritonGPU
// add rules
populateArithPatternsAndLegality(typeConverter, patterns, target);
populateMathPatternsAndLegality(typeConverter, patterns, target);
FuncArgRenamer renamer;
populateFunctionTypeConversions(typeConverter, renamer, patterns);
populateTritonPatterns(typeConverter, patterns, numCTAs);
// TODO: can we use
// mlir::scf::populateSCFStructuralTypeConversionsAndLegality(...) here?
Expand Down
24 changes: 24 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,30 @@ def kernel(X, Y, Z):
assert torch.equal(z, ref + x + y)


@triton.jit(noinline=True)
def noinline_load_block_fn(ptr, BLOCK_SIZE: tl.constexpr):
    # Non-inlined helper: loads BLOCK_SIZE consecutive elements starting at
    # `ptr` and returns the loaded block to the caller.
    return tl.load(ptr + tl.arange(0, BLOCK_SIZE))


def test_noinline_returns_tensor(device):
    """Check that a noinline helper can return a tensor block to its caller.

    The kernel loads two blocks through `noinline_load_block_fn`, adds them,
    and stores the sum; the output must equal the element-wise sum computed
    by torch.
    """

    @triton.jit
    def kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
        # Both loads go through the noinline helper.
        lhs = noinline_load_block_fn(X, BLOCK_SIZE)
        rhs = noinline_load_block_fn(Y, BLOCK_SIZE)
        out_offsets = tl.arange(0, BLOCK_SIZE)
        tl.store(Z + out_offsets, lhs + rhs)

    block_size = 128
    torch.manual_seed(0)
    # Keep the generation order (x, then y) so the seeded values are unchanged.
    x = torch.randn(block_size, device=device, dtype=torch.float32)
    y = torch.randn(block_size, device=device, dtype=torch.float32)
    z = torch.empty_like(x)

    # A single program instance covers the whole block.
    grid = (1, )
    kernel[grid](x, y, z, BLOCK_SIZE=block_size, num_warps=1)

    expected = x + y
    assert torch.equal(z, expected)


# ---------------
# test atomics
# ---------------
Expand Down
43 changes: 43 additions & 0 deletions test/Conversion/triton_to_tritongpu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,46 @@ tt.func @split_op(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) {
tt.store %2, %res1 : tensor<64x!tt.ptr<f32>>
tt.return
}

// -----

// Private function that returns a tensor<128xi32> without a layout encoding.
// The expectations below require both its signature and its tt.return to
// carry an encoding attribute after the conversion.
// CHECK-LABEL: tt.func private @callee
// CHECK-SAME: (%{{.*}}: !tt.ptr<i32>) -> tensor<128xi32, #{{.*}}>
tt.func private @callee(%arg0: !tt.ptr<i32>) -> tensor<128xi32> {
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK: tt.return %{{.*}} : tensor<128xi32, #{{.*}}>
tt.return %0 : tensor<128xi32>
}

// Calls @callee and stores the returned tensor through a splatted pointer.
// The expectation on tt.call requires its result type to carry an encoding.
// CHECK-LABEL: tt.func @caller
tt.func @caller(%ptr: !tt.ptr<i32>) {
// CHECK: %{{.*}} = tt.call @callee(%{{.*}}) : (!tt.ptr<i32>) -> tensor<128xi32, #{{.*}}>
%v = tt.call @callee(%ptr) : (!tt.ptr<i32>) -> tensor<128xi32>
%ptrs = tt.splat %ptr : !tt.ptr<i32> -> tensor<128x!tt.ptr<i32>>
tt.store %ptrs, %v : tensor<128x!tt.ptr<i32>>
tt.return
}

// -----

// When a callee returns a tensor whose default encoding doesn't match what
// the caller's consumer wants, a ttg.convert_layout should be auto-inserted
// at the call boundary.

// Private function returning a constant 128x32 f16 tensor. The encoding seen
// on its tt.return is captured as BLOCKED so later expectations can refer to
// the same encoding.
// CHECK-LABEL: tt.func private @make_a
// CHECK: tt.return %{{.*}} : tensor<128x32xf16, #[[$BLOCKED:[^,>]+]]>
tt.func private @make_a() -> tensor<128x32xf16> {
%a = arith.constant dense<1.0> : tensor<128x32xf16>
tt.return %a : tensor<128x32xf16>
}

// Uses the result of @make_a as the first operand of tt.dot. The expectations
// require the call result to have the captured BLOCKED encoding and a
// ttg.convert_layout from that encoding to a #ttg.dot_op encoding before the
// tt.dot.
// CHECK-LABEL: tt.func @call_into_dot
// CHECK: %[[V:.*]] = tt.call @make_a() : () -> tensor<128x32xf16, #[[$BLOCKED]]>
// CHECK: ttg.convert_layout %[[V]] : tensor<128x32xf16, #[[$BLOCKED]]> -> tensor<128x32xf16, #ttg.dot_op<{{.*}}>>
// CHECK: tt.dot
tt.func @call_into_dot(%b: tensor<32x128xf16>) {
%a = tt.call @make_a() : () -> tensor<128x32xf16>
%c = arith.constant dense<0.0> : tensor<128x128xf32>
%0 = tt.dot %a, %b, %c : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
tt.return
}
Loading