Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,14 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args):
The call expression.
"""
assert isinstance(global_var, tvm.ir.GlobalVar)
return Call(dtype="void", op=global_var, args=args)

dtype = "void"
if global_var.checked_type is not None:
ret_type = global_var.checked_type.ret_type
if hasattr(ret_type, "dtype"):
dtype = ret_type.dtype

return Call(dtype=dtype, op=global_var, args=args)


def start_profile_intrinsic(id):
Expand Down
14 changes: 13 additions & 1 deletion src/script/ir_builder/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <tvm/ir/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/ir_builder/ir/ir.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>

#include "./utils.h"

Expand All @@ -38,7 +40,17 @@ GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature)
IRModuleFrame frame = FindModuleFrame("I.DeclFunction");
CHECK(!frame->global_var_map.count(func_name))
<< "ValueError: function " << func_name << " already exists";
GlobalVar gv = GlobalVar(func_name);

auto gvar_type = [&]() -> Type {
if (auto prim_func = func_signature.as<tir::PrimFuncNode>()) {
Array<Type> arg_types = prim_func->params.Map([](const auto& var) { return GetType(var); });
return FuncType(arg_types, prim_func->ret_type, {}, {});
}

return {};
}();

GlobalVar gv = GlobalVar(func_name, gvar_type);
CHECK(frame->functions.find(gv) == frame->functions.end())
<< "ValueError: function " << func_name << " has already been defined.";
frame->global_var_map.Set(func_name, gv);
Expand Down
41 changes: 40 additions & 1 deletion src/tir/transforms/lower_device_kernel_launch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,36 @@ class DeviceInfoCollector : public StmtVisitor {
// The amount of dynamic shared memory used
Optional<PrimExpr> dyn_shmem_size{NullOpt};
};

/*! \brief Strips successful returns (`T.ret(0)`) out of a device kernel body.
 *
 * Device kernels report status through the launch mechanism rather than
 * through an explicit return, so a successful `T.ret(0)` becomes a no-op
 * `Evaluate(0)`.  Any other use of `builtin::ret()` is rejected.
 */
class ReturnRemover : public StmtExprMutator {
 public:
  /*! \brief Entry point: rewrite `stmt` and return the cleaned body. */
  static Stmt Apply(const Stmt& stmt) {
    ReturnRemover mutator;
    return mutator(stmt);
  }

 private:
  using Parent = StmtExprMutator;

  Stmt VisitStmt_(const EvaluateNode* op) override {
    const CallNode* call = op->value.as<CallNode>();
    bool is_return = call != nullptr && call->op.same_as(builtin::ret());
    if (!is_return) {
      return Parent::VisitStmt_(op);
    }
    // A bare `Evaluate(ret(...))` is the only legal placement of ret();
    // it must carry exactly one argument, the constant status code 0.
    ICHECK_EQ(call->args.size(), 1);
    const IntImmNode* status = call->args[0].as<IntImmNode>();
    ICHECK(status && status->value == 0)
        << "Device kernel may only contain successful return, T.ret(0)";
    return Evaluate(0);
  }

  PrimExpr VisitExpr_(const CallNode* op) override {
    // ret() used as a sub-expression has no Evaluate parent to rewrite.
    if (op->op.same_as(builtin::ret())) {
      LOG(FATAL) << "Call to builtin::ret() should only appear within an Evaluate node";
    }
    return Parent::VisitExpr_(op);
  }
};
} // namespace

class DeviceKernelMutator : public StmtExprMutator {
Expand Down Expand Up @@ -185,10 +215,19 @@ class DeviceKernelMutator : public StmtExprMutator {
if (is_kernel_launch) {
const auto& info = device_info_map_.at(gvar.get());

// Kernel launches provide an int32 error code to the caller,
// but do not accept any return type from the callee.
{
auto write_ptr = func.CopyOnWrite();
write_ptr->ret_type = VoidType();
write_ptr->body = ReturnRemover::Apply(write_ptr->body);
}

func = WithAttrs(std::move(func),
{{tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDeviceKernelLaunch)},
{tvm::tir::attr::kKernelLaunchParams, info.launch_params},
{tvm::attr::kGlobalSymbol, info.global_symbol}});

} else if (is_call_extern && !func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
}
Expand All @@ -197,7 +236,7 @@ class DeviceKernelMutator : public StmtExprMutator {
}

private:
PrimExpr VisitExpr_(const CallNode* op) {
PrimExpr VisitExpr_(const CallNode* op) override {
auto node = Downcast<Call>(Parent::VisitExpr_(op));

auto* gvar = op->op.as<GlobalVarNode>();
Expand Down
11 changes: 7 additions & 4 deletions src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,21 @@ class SubroutineCallRewriter : public StmtExprMutator {

if (auto gvar = node->op.as<GlobalVarNode>()) {
if (external_methods_.count(gvar)) {
Array<PrimExpr> args = node->args.Map([this](const PrimExpr& arg) -> PrimExpr {
Array<PrimExpr> args = node->args.Map([](const PrimExpr& arg) -> PrimExpr {
if (auto* as_call = arg.as<CallNode>()) {
if (as_call->op.same_as(builtin::tvm_stack_make_array())) {
PrimExpr data_ptr = as_call->args[0];
made_change_ = true;
return data_ptr;
}
}
return arg;
});
if (!args.same_as(node->args)) {
node.CopyOnWrite()->args = args;

if (!args.same_as(node->args) || node->dtype != DataType::Int(32)) {
auto write_ptr = node.CopyOnWrite();
write_ptr->dtype = DataType::Int(32);
write_ptr->args = args;
made_change_ = true;
}
}
}
Expand Down
33 changes: 30 additions & 3 deletions src/tir/transforms/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class HostDeviceSplitter : public StmtMutator {
VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false);
use_def(body);

// Sort first by variable typ, then by variable name
// Sort first by variable type, then by variable name
std::vector<Var> params{use_def.undefined_.begin(), use_def.undefined_.end()};
std::sort(params.begin(), params.end(), [](const Var& a, const Var& b) {
auto sort_key = [](const Var& var) {
Expand All @@ -74,16 +74,43 @@ class HostDeviceSplitter : public StmtMutator {
return params;
}();

// CodeGenCPU is used for some device-side targets, such as
// "ext_dev", and expects to be able to return an int32_t status
// code.

bool can_propagate_errors = [&]() {
auto kind = device_target->GetTargetDeviceType();
return kind == kDLCPU || kind == kDLExtDev || kind == kDLHexagon;
}();
IntImm success(DataType::Int(32), 0);
Type kernel_ret_type;
if (can_propagate_errors) {
kernel_ret_type = PrimType(DataType::Int(32));
body = SeqStmt::Flatten(body, Evaluate(ret(success)));
} else {
kernel_ret_type = VoidType();
}

GlobalVar kernel_symbol_global = var_supply_();
PrimFunc device_func(params, body);
PrimFunc device_func(params, body, kernel_ret_type);
device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
{tir::attr::kNoAlias, Bool(true)},
{tir::attr::kIsGlobalFunc, Bool(true)}});

(*device_mod_)->Add(kernel_symbol_global, device_func);
Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });

return Evaluate(Call(DataType::Void(), kernel_symbol_global, args));
if (can_propagate_errors) {
Var kernel_error_code("kernel_error_code", success->dtype);
Call kernel_call(success->dtype, kernel_symbol_global, args);
AssertStmt assert_success(kernel_error_code == success,
StringImm("Error executing compute kernel"), Evaluate(0));
LetStmt let_check(kernel_error_code, kernel_call, assert_success);

return std::move(let_check);
} else {
return Evaluate(Call(DataType::Void(), kernel_symbol_global, args));
}
}

// target ir module
Expand Down
38 changes: 38 additions & 0 deletions tests/python/unittest/test_tir_transform_split_host_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,44 @@ def main_kernel(n: T.int32):
return mod


class TestSplitHostDeviceOnCPU(BaseCompare):
    """A kernel running on the CPU may return an error code"""

    def before(self):
        # Host function targeting CUDA whose body contains a region
        # annotated with a CPU ("llvm") target; SplitHostDevice should
        # extract that region into a separate kernel function.
        @I.ir_module
        class mod:
            @T.prim_func
            def main(n: T.int32):
                T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")})
                T.attr(T.target("llvm"), "target", 0)
                T.evaluate(n)

        return mod

    def expected(self):
        # Because the extracted kernel targets a CPU, it can propagate an
        # int32 status code: the host calls the kernel, binds its return
        # value, and asserts that it signalled success (zero).
        @I.ir_module
        class mod:
            @T.prim_func
            def main(n: T.int32):
                T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")})
                err = mod.main_kernel(n)
                assert err == 0, "Error executing compute kernel"

            # Extracted kernel: returns T.int32, ends with T.ret(0) to
            # report success back to the host-side assert.
            @T.prim_func
            def main_kernel(n: T.int32) -> T.int32:
                T.func_attr(
                    {
                        "target": T.target("llvm"),
                        "tir.noalias": T.bool(True),
                        "tir.is_global_func": True,
                    }
                )
                T.evaluate(n)
                T.ret(0)

        return mod


class TestSplitHostDeviceWithoutFuncHostAttribute(BaseCompare):
"""Like TestSplitHostDevice, but no host specified in the host's target

Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3817,6 +3817,22 @@ def subroutine(A_data: T.handle("float32"), n: T.int32):
return mod


def subroutine_call_returning_int():
    """An internal function call may return non-void"""

    @I.ir_module
    class mod:
        @T.prim_func
        def main(A: T.Buffer(2, "float32")):
            # Bare expression statement whose value is discarded; it
            # exists to exercise calls to a subroutine with a non-void
            # return type inside an expression.
            mod.subroutine(A[0]) + mod.subroutine(A[1])

        @T.prim_func
        def subroutine(x: T.float32) -> T.float32:
            # T.ret supplies the function's float32 return value.
            T.ret(x * x)

    return mod


def undefined_data_ptr_in_decl_buffer():
"""The T.decl_buffer syntax should not introduce an Allocate

Expand Down Expand Up @@ -4009,6 +4025,7 @@ def func():
ir_module_with_attrs,
nested_seqstmt,
subroutine_call,
subroutine_call_returning_int,
undefined_data_ptr_in_decl_buffer,
undefined_shape_in_decl_buffer,
undefined_stride_in_decl_buffer,
Expand Down