Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions src/tir/transforms/lower_device_kernel_launch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,27 @@ class DeviceKernelMutator : public StmtExprMutator {
}

PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const {
if (device_kernel_launch_.count(gvar.get())) {
bool is_kernel_launch = device_kernel_launch_.count(gvar.get());
bool is_call_extern = extern_function_call_.count(gvar.get());
CHECK(!is_kernel_launch || !is_call_extern)
<< "Function " << gvar << " has multiple callees, "
<< "and would need to be lowered into a call_extern at some call sites, "
<< "and a device kernel launch at others. "
<< "This case is not yet supported.";

if (is_kernel_launch || is_call_extern) {
func = WithAttr(std::move(func), tvm::tir::attr::kIsGlobalFunc, Bool(true));
}

if (is_kernel_launch) {
const auto& info = device_info_map_.at(gvar.get());

func = WithAttrs(std::move(func),
{{tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDeviceKernelLaunch)},
{tvm::tir::attr::kKernelLaunchParams, info.launch_params},
{tvm::attr::kGlobalSymbol, info.global_symbol},
{tvm::tir::attr::kIsGlobalFunc, Bool(true)}});
{tvm::attr::kGlobalSymbol, info.global_symbol}});
} else if (is_call_extern && !func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
}

return func;
Expand All @@ -196,12 +209,31 @@ class DeviceKernelMutator : public StmtExprMutator {
<< gvar->name_hint << " did not appear within the IRModule";
const KernelInfo& dev_info = it->second;

auto caller_device_type = current_target_.value()->GetTargetDeviceType();
auto callee_device_type = dev_info.target->GetTargetDeviceType();
if (caller_device_type == callee_device_type) {
auto caller_target = current_target_.value();
auto callee_target = dev_info.target;

bool same_target = caller_target->str() == callee_target->str();
if (same_target) {
// Calls within the same target may be handled at codegen time
// as internal subroutine calls.
return std::move(node);
}

bool same_device_type =
caller_target->GetTargetDeviceType() == callee_target->GetTargetDeviceType();
if (same_device_type) {
// Calls to another target using the same device (e.g. LLVM
// calling a custom TIRToRuntime target) do not require a kernel
// launch, but need to be replaced with call_extern.
extern_function_call_.insert(gvar);
Array<PrimExpr> args;
args.push_back(StringImm(gvar->name_hint));
for (const auto& arg : node->args) {
args.push_back(arg);
}
return Call(node->dtype, builtin::call_extern(), args);
}

ICHECK(dev_info.launch_params.defined())
<< "CallNode attempted kernel launch to " << gvar->name_hint << " on target "
<< dev_info.target << ", but subroutine " << gvar->name_hint
Expand Down Expand Up @@ -243,6 +275,7 @@ class DeviceKernelMutator : public StmtExprMutator {
Optional<Target> current_target_;
std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map_;
std::unordered_set<const GlobalVarNode*> device_kernel_launch_;
std::unordered_set<const GlobalVarNode*> extern_function_call_;
};

namespace transform {
Expand Down
49 changes: 49 additions & 0 deletions tests/python/unittest/test_tir_transform_device_kernel_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,5 +189,54 @@ def kernel(A_data: T.handle("float32")):
return mod


class TestSameDeviceDifferentTarget(BaseCompare):
"""Handle subroutine calls to same device, different codegen

The device kernel launch is only required when the caller and
callee are on different devices. However, if the caller and
callee use different codegen, then the call cannot be handled as
an internal call by a single codegen. Instead, it should be
lowered to a `T.call_extern`.
"""

def before(self):
@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"target": T.target("llvm")})
mod.kernel(A.data)

@T.prim_func
def kernel(A_data: T.handle("float32")):
T.func_attr({"target": T.target("c")})
A = T.decl_buffer(16, dtype="float32", data=A_data)
A[0] = 0.0

return mod

def expected(self):
@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"target": T.target("llvm")})
T.call_extern("kernel", A.data, dtype="void")

@T.prim_func
def kernel(A_data: T.handle("float32")):
T.func_attr(
{
"target": T.target("c"),
"global_symbol": "kernel",
"tir.is_global_func": True,
}
)
A = T.decl_buffer(16, dtype="float32", data=A_data)
A[0] = 0.0

return mod


if __name__ == "__main__":
tvm.testing.main()