Skip to content

Commit 3b5d135

Browse files
committed
[AOT] Avoid call_extern() with incorrect argument count
Prior to this commit, if device initialization is required, the AOT main function produced a `call_extern()` that included the device context as input. This commit updates the AOT main function to provide the device context only if the function being called accepts a device context as input. If an extra device context argument is included at the call site, the C codegen would produce a function signature that includes the device context for the caller's compilation unit, but a signature without the device context for the callee's compilation unit. While this can compile and run in some cases, it is undefined behavior for the signature to vary between compilation units, and should be avoided. This was initially discovered while debugging apache#14985, in which changes to the lowering flow resulted in the caller and callee being within the same compilation unit.
1 parent 3c23865 commit 3b5d135

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

src/relay/backend/aot_executor_codegen.cc

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,32 @@ class AOTExecutorCodegen : public MixedModeVisitor {
454454
// call_extern calling convention with optional context
455455
if (has_c_device_api_context) {
456456
device_context = device_contexts_.Get(global_var).value();
457-
args.push_back(device_context);
457+
458+
// call_extern has no further legalization steps, and
459+
// requires the number of arguments to match exactly. For
460+
// internal calls, conditionally append the device context.
461+
bool requires_device_context = [&]() -> bool {
462+
Optional<Integer> opt = num_arguments_.Get(global_var);
463+
if (!opt.defined()) {
464+
// For external calls, we must trust that the user has
465+
// supplied a kernel that accepts a device_context
466+
// argument.
467+
return true;
468+
}
469+
int num_callee_params = opt.value()->value;
470+
int num_args = call_lowered_props.arguments.size();
471+
if (num_callee_params == num_args) {
472+
return false;
473+
} else if (num_callee_params == num_args + 1) {
474+
return true;
475+
} else {
476+
LOG(FATAL) << "Callee " << global_var << " requires " << num_callee_params
477+
<< ", but is called with " << num_args << " arguments.";
478+
}
479+
}();
480+
if (requires_device_context) {
481+
args.push_back(device_context);
482+
}
458483
}
459484
func_call = tir::Evaluate(AddCheckReturn(
460485
tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args)));
@@ -1007,6 +1032,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
10071032
Map<String, tir::Var> devices_;
10081033
/*! \brief map of GlobalVars to C Device API contexts */
10091034
Map<GlobalVar, tir::Var> device_contexts_;
1035+
/*! \brief map of GlobalVars to the number of arguments they require */
1036+
Map<GlobalVar, Integer> num_arguments_;
10101037
/*! \brief input and output variables belonging to the main function signature */
10111038
Array<tir::Var> main_signature_;
10121039
/*! \brief input and output variables belonging to the main function signature */
@@ -1183,6 +1210,15 @@ class AOTExecutorCodegen : public MixedModeVisitor {
11831210
}
11841211

11851212
CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts").value());
1213+
num_arguments_ = [&]() -> Map<GlobalVar, Integer> {
1214+
Map<GlobalVar, Integer> arg_count;
1215+
for (const auto& [gvar, func] : lowered_mod->functions) {
1216+
if (const auto* prim_func = func.as<tir::PrimFuncNode>()) {
1217+
arg_count.Set(gvar, prim_func->params.size());
1218+
}
1219+
}
1220+
return arg_count;
1221+
}();
11861222
VisitExpr(lowered_main_func->body);
11871223

11881224
// Create the runner function. Please note that the function is not legal yet

0 commit comments

Comments
 (0)