@@ -60,7 +60,7 @@ class HostDeviceSplitter : public StmtMutator {
6060 VarUseDefAnalyzer use_def (/* defined_vars=*/ {}, /* visit_thread_extent=*/ false );
6161 use_def (body);
6262
63- // Sort first by variable typ , then by variable name
63+ // Sort first by variable type , then by variable name
6464 std::vector<Var> params{use_def.undefined_ .begin (), use_def.undefined_ .end ()};
6565 std::sort (params.begin (), params.end (), [](const Var& a, const Var& b) {
6666 auto sort_key = [](const Var& var) {
@@ -77,27 +77,40 @@ class HostDeviceSplitter : public StmtMutator {
7777 // CodeGenCPU is used for some device-side targets, such as
7878 // "ext_dev", and expects to be able to return an int32_t status
7979 // code.
80- auto error_code_dtype = DataType::Int (32 );
81- IntImm success (error_code_dtype, 0 );
82- body = SeqStmt::Flatten (body, Evaluate (ret (success)));
80+
81+ bool can_propagate_errors = [&]() {
82+ auto kind = device_target->GetTargetDeviceType ();
83+ return kind == kDLCPU || kind == kDLExtDev || kind == kDLHexagon ;
84+ }();
85+ IntImm success (DataType::Int (32 ), 0 );
86+ Type kernel_ret_type;
87+ if (can_propagate_errors) {
88+ kernel_ret_type = PrimType (DataType::Int (32 ));
89+ body = SeqStmt::Flatten (body, Evaluate (ret (success)));
90+ } else {
91+ kernel_ret_type = VoidType ();
92+ }
8393
8494 GlobalVar kernel_symbol_global = var_supply_ ();
85- PrimFunc device_func (params, body, PrimType (error_code_dtype) );
95+ PrimFunc device_func (params, body, kernel_ret_type );
8696 device_func = WithAttrs (std::move (device_func), {{tvm::attr::kTarget , device_target},
8797 {tir::attr::kNoAlias , Bool (true )},
8898 {tir::attr::kIsGlobalFunc , Bool (true )}});
8999
90100 (*device_mod_)->Add (kernel_symbol_global, device_func);
91101 Array<PrimExpr> args = params.Map ([](const Var& var) -> PrimExpr { return var; });
92102
93- Var kernel_error_code (" kernel_error_code" , error_code_dtype);
103+ if (can_propagate_errors) {
104+ Var kernel_error_code (" kernel_error_code" , success->dtype );
105+ Call kernel_call (success->dtype , kernel_symbol_global, args);
106+ AssertStmt assert_success (kernel_error_code == success,
107+ StringImm (" Error executing compute kernel" ), Evaluate (0 ));
108+ LetStmt let_check (kernel_error_code, kernel_call, assert_success);
94109
95- AssertStmt assert_success (kernel_error_code == success,
96- StringImm (" Error executing compute kernel" ), Evaluate (0 ));
97- LetStmt kernel_call (kernel_error_code, Call (error_code_dtype, kernel_symbol_global, args),
98- assert_success);
99-
100- return std::move (kernel_call);
110+ return std::move (let_check);
111+ } else {
112+ return Evaluate (Call (DataType::Void (), kernel_symbol_global, args));
113+ }
101114 }
102115
103116 // target ir module
0 commit comments