Skip to content

Commit af722ad

Browse files
committed
Restrict the int32 return type to targets that need to propagate errors
1 parent 1851d60 commit af722ad

File tree

1 file changed

+25
-12
lines changed

1 file changed

+25
-12
lines changed

src/tir/transforms/split_host_device.cc

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class HostDeviceSplitter : public StmtMutator {
6060
VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false);
6161
use_def(body);
6262

63-
// Sort first by variable typ, then by variable name
63+
// Sort first by variable type, then by variable name
6464
std::vector<Var> params{use_def.undefined_.begin(), use_def.undefined_.end()};
6565
std::sort(params.begin(), params.end(), [](const Var& a, const Var& b) {
6666
auto sort_key = [](const Var& var) {
@@ -77,27 +77,40 @@ class HostDeviceSplitter : public StmtMutator {
7777
// CodeGenCPU is used for some device-side targets, such as
7878
// "ext_dev", and expects to be able to return a int32_t status
7979
// code.
80-
auto error_code_dtype = DataType::Int(32);
81-
IntImm success(error_code_dtype, 0);
82-
body = SeqStmt::Flatten(body, Evaluate(ret(success)));
80+
81+
bool can_propagate_errors = [&]() {
82+
auto kind = device_target->GetTargetDeviceType();
83+
return kind == kDLCPU || kind == kDLExtDev || kind == kDLHexagon;
84+
}();
85+
IntImm success(DataType::Int(32), 0);
86+
Type kernel_ret_type;
87+
if (can_propagate_errors) {
88+
kernel_ret_type = PrimType(DataType::Int(32));
89+
body = SeqStmt::Flatten(body, Evaluate(ret(success)));
90+
} else {
91+
kernel_ret_type = VoidType();
92+
}
8393

8494
GlobalVar kernel_symbol_global = var_supply_();
85-
PrimFunc device_func(params, body, PrimType(error_code_dtype));
95+
PrimFunc device_func(params, body, kernel_ret_type);
8696
device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
8797
{tir::attr::kNoAlias, Bool(true)},
8898
{tir::attr::kIsGlobalFunc, Bool(true)}});
8999

90100
(*device_mod_)->Add(kernel_symbol_global, device_func);
91101
Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });
92102

93-
Var kernel_error_code("kernel_error_code", error_code_dtype);
103+
if (can_propagate_errors) {
104+
Var kernel_error_code("kernel_error_code", success->dtype);
105+
Call kernel_call(success->dtype, kernel_symbol_global, args);
106+
AssertStmt assert_success(kernel_error_code == success,
107+
StringImm("Error executing compute kernel"), Evaluate(0));
108+
LetStmt let_check(kernel_error_code, kernel_call, assert_success);
94109

95-
AssertStmt assert_success(kernel_error_code == success,
96-
StringImm("Error executing compute kernel"), Evaluate(0));
97-
LetStmt kernel_call(kernel_error_code, Call(error_code_dtype, kernel_symbol_global, args),
98-
assert_success);
99-
100-
return std::move(kernel_call);
110+
return std::move(let_check);
111+
} else {
112+
return Evaluate(Call(DataType::Void(), kernel_symbol_global, args));
113+
}
101114
}
102115

103116
// target ir module

0 commit comments

Comments
 (0)