diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 5ffbf0d7a7fd..52f06ea45c7c 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -170,14 +170,27 @@ class DeviceKernelMutator : public StmtExprMutator { } PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const { - if (device_kernel_launch_.count(gvar.get())) { + bool is_kernel_launch = device_kernel_launch_.count(gvar.get()); + bool is_call_extern = extern_function_call_.count(gvar.get()); + CHECK(!is_kernel_launch || !is_call_extern) + << "Function " << gvar << " has multiple callees, " + << "and would need to be lowered into a call_extern at some call sites, " + << "and a device kernel launch at others. " + << "This case is not yet supported."; + + if (is_kernel_launch || is_call_extern) { + func = WithAttr(std::move(func), tvm::tir::attr::kIsGlobalFunc, Bool(true)); + } + + if (is_kernel_launch) { const auto& info = device_info_map_.at(gvar.get()); func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDeviceKernelLaunch)}, {tvm::tir::attr::kKernelLaunchParams, info.launch_params}, - {tvm::attr::kGlobalSymbol, info.global_symbol}, - {tvm::tir::attr::kIsGlobalFunc, Bool(true)}}); + {tvm::attr::kGlobalSymbol, info.global_symbol}}); + } else if (is_call_extern && !func->GetAttr(tvm::attr::kGlobalSymbol)) { + func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); } return func; @@ -196,12 +209,31 @@ class DeviceKernelMutator : public StmtExprMutator { << gvar->name_hint << " did not appear within the IRModule"; const KernelInfo& dev_info = it->second; - auto caller_device_type = current_target_.value()->GetTargetDeviceType(); - auto callee_device_type = dev_info.target->GetTargetDeviceType(); - if (caller_device_type == callee_device_type) { + auto caller_target = current_target_.value(); + auto callee_target = dev_info.target; + + bool same_target = caller_target->str() == callee_target->str(); + if (same_target) { + // Calls within the same target may be handled at codegen time + // as internal subroutine calls. return std::move(node); } + bool same_device_type = + caller_target->GetTargetDeviceType() == callee_target->GetTargetDeviceType(); + if (same_device_type) { + // Calls to another target using the same device (e.g. LLVM + // calling a custom TIRToRuntime target) do not require a kernel + // launch, but need to be replaced with call_extern. + extern_function_call_.insert(gvar); + Array args; + args.push_back(StringImm(gvar->name_hint)); + for (const auto& arg : node->args) { + args.push_back(arg); + } + return Call(node->dtype, builtin::call_extern(), args); + } + ICHECK(dev_info.launch_params.defined()) << "CallNode attempted kernel launch to " << gvar->name_hint << " on target " << dev_info.target << ", but subroutine " << gvar->name_hint @@ -243,6 +275,7 @@ class DeviceKernelMutator : public StmtExprMutator { Optional current_target_; std::unordered_map device_info_map_; std::unordered_set device_kernel_launch_; + std::unordered_set extern_function_call_; }; namespace transform { diff --git a/tests/python/unittest/test_tir_transform_device_kernel_launch.py b/tests/python/unittest/test_tir_transform_device_kernel_launch.py index a0f77da3766b..34cde4e4b6ce 100644 --- a/tests/python/unittest/test_tir_transform_device_kernel_launch.py +++ b/tests/python/unittest/test_tir_transform_device_kernel_launch.py @@ -189,5 +189,54 @@ def kernel(A_data: T.handle("float32")): return mod +class TestSameDeviceDifferentTarget(BaseCompare): + """Handle subroutine calls to same device, different codegen + + The device kernel launch is only required when the caller and + callee are on different devices. However, if the caller and + callee use different codegen, then the call cannot be handled as + an internal call by a single codegen. Instead, it should be + lowered to a `T.call_extern`. + """ + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("llvm")}) + mod.kernel(A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr({"target": T.target("c")}) + A = T.decl_buffer(16, dtype="float32", data=A_data) + A[0] = 0.0 + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("llvm")}) + T.call_extern("kernel", A.data, dtype="void") + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("c"), + "global_symbol": "kernel", + "tir.is_global_func": True, + } + ) + A = T.decl_buffer(16, dtype="float32", data=A_data) + A[0] = 0.0 + + return mod + + if __name__ == "__main__": tvm.testing.main()