Skip to content

Commit 6b11cd2

Browse files
author
Giuseppe Rossini
committed
AOT] Remove lookup parameter function in AOT
This PR aims at removing the function call to extract the parameters within the AOT main function by introducing a tir::lookup_param builtin. This has different benefits: - In AOT we now only use the v_handle field - We save cycles by not calling an intermediate function to extract local parameters - We reduce code size, since we don't need to pack a call to extract parameters and we don't need to produce the lookup_param function anymore within the compilation unit Change-Id: I36c2f0724a79606424a4374f4f5cd669bb2a8a55
1 parent dbd076a commit 6b11cd2

File tree

6 files changed

+44
-36
lines changed

6 files changed

+44
-36
lines changed

include/tvm/tir/builtin.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,14 @@ TVM_DLL const Op& tvm_struct_get();
278278
*/
279279
TVM_DLL const Op& tvm_struct_set();
280280

281+
/*!
282+
* \brief See pseudo code
283+
* Type lookup_param(String param_name) {
284+
* return __tvm_param__param_name;
285+
* }
286+
*/
287+
TVM_DLL const Op& lookup_param();
288+
281289
/*!
282290
* \brief See pesudo code
283291
*

src/relay/backend/aot_executor_codegen.cc

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -152,40 +152,23 @@ class AOTExecutorCodegen : public ExprVisitor {
152152
* \return Variable that represents the DLTensor associated with the parameters
153153
*/
154154
tir::Var PackParam(Expr expr) {
155-
// TODO(giuseros): Using call_extern to call into lookup_linked_param. This is because the
156-
// builtin::ret is not supported yet in the c target. Once return is supported we can use
157-
// tvm_call_packed_lowered().
158155
int param_sid = param_storage_ids_[params_by_expr_[expr]];
159-
auto lookup_linked_param_fn = tir::StringImm(::tvm::runtime::symbol::tvm_lookup_linked_param);
160156
auto param_array = te::Var(MakeString("param_", param_sid, "_array"), DataType::Handle());
161157

162158
// Compose the lookup_call using a local stack
163159
Array<tir::Stmt> lookup_call;
164-
auto param_var = te::Var(MakeString("param_", param_sid, "_value"), DataType::Handle());
165-
auto ret_var = te::Var("ret_value", DataType::Handle());
166-
auto ret_code = te::Var("ret_value", DataType::Handle());
167-
168-
lookup_call.push_back(tir::Evaluate(
169-
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
170-
{param_var, 0, tir::builtin::kTVMValueContent, ConstInt32(param_sid)})));
171-
lookup_call.push_back(tir::Evaluate(
172-
tvm::tir::Call(DataType::Handle(), tir::builtin::call_extern(),
173-
{lookup_linked_param_fn, param_var, 0, 0, ret_var, ret_code, 0})));
174-
auto ret_var_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
175-
{ret_var, 0, tir::builtin::kTVMValueContent});
176-
177160
// Set the param to the value returned by lookup_call
161+
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
162+
{tir::StringImm(params_by_expr_[expr])});
163+
178164
tvm::PrimExpr set_param_array =
179165
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
180-
{param_array, 0, tir::builtin::kArrData, ret_var_handle});
166+
{param_array, 0, tir::builtin::kArrData, param_handle});
181167
lookup_call.push_back(tir::Evaluate(set_param_array));
182168

183169
tir::Stmt lookup_body = tir::SeqStmt(lookup_call);
184170

185171
// Allocate the DLTensors on the stack
186-
lookup_body = tir::LetStmt(param_var, StackAlloca("arg_value", 1), lookup_body);
187-
lookup_body = tir::LetStmt(ret_var, StackAlloca("arg_value", 1), lookup_body);
188-
lookup_body = tir::LetStmt(ret_code, StackAlloca("arg_value", 1), lookup_body);
189172
lookup_body = tir::LetStmt(param_array, StackAlloca("arg_value", 1), lookup_body);
190173
stmts_.push_back(lookup_body);
191174
return param_array;

src/target/source/codegen_c.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,11 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
662662
os << " != ";
663663
this->PrintExpr(op->args[0], os);
664664
os << ")";
665+
} else if (op->op.same_as(builtin::lookup_param())) {
666+
ICHECK_EQ(op->args.size(), 1);
667+
const StringImmNode* str = op->args[0].as<StringImmNode>();
668+
ICHECK(str != nullptr);
669+
os << "__tvm_param__" << str->value;
665670
} else {
666671
LOG(FATAL) << "Unresolved call " << op->op;
667672
}

src/target/source/codegen_c_host.cc

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,7 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) {
6161
CodeGenC::AddFunction(f);
6262
}
6363

64-
void CodeGenCHost::LinkParameters(Map<String, LinkedParam> params) {
65-
PrintFuncPrefix();
66-
stream << " " << tvm::runtime::symbol::tvm_lookup_linked_param
67-
<< "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
68-
<< "int* out_ret_tcode, void* resource_handle) {\n";
69-
ICHECK_EQ(GetUniqueName(tvm::runtime::symbol::tvm_lookup_linked_param),
70-
tvm::runtime::symbol::tvm_lookup_linked_param)
71-
<< "builtin PackedFunc name already taken: " << tvm::runtime::symbol::tvm_lookup_linked_param;
72-
stream << " switch (((int64_t*) args)[0]) {\n"
73-
<< " default:\n"
74-
<< " out_ret_tcode[0] = " << kTVMNullptr << ";\n"
75-
<< " return 0;\n";
76-
77-
function_names_.push_back(tvm::runtime::symbol::tvm_lookup_linked_param);
64+
void CodeGenCHost::DeclareParameters(Map<String, LinkedParam> params) {
7865
for (auto kv : params) {
7966
decl_stream << "\n"
8067
<< "#ifdef __cplusplus\n"
@@ -93,6 +80,24 @@ void CodeGenCHost::LinkParameters(Map<String, LinkedParam> params) {
9380
<< "#ifdef __cplusplus\n"
9481
<< "} // extern \"C\"\n"
9582
<< "#endif\n";
83+
}
84+
}
85+
86+
void CodeGenCHost::LinkParameters(Map<String, LinkedParam> params) {
87+
PrintFuncPrefix();
88+
stream << " " << tvm::runtime::symbol::tvm_lookup_linked_param
89+
<< "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
90+
<< "int* out_ret_tcode, void* resource_handle) {\n";
91+
ICHECK_EQ(GetUniqueName(tvm::runtime::symbol::tvm_lookup_linked_param),
92+
tvm::runtime::symbol::tvm_lookup_linked_param)
93+
<< "builtin PackedFunc name already taken: " << tvm::runtime::symbol::tvm_lookup_linked_param;
94+
stream << " switch (((int64_t*) args)[0]) {\n"
95+
<< " default:\n"
96+
<< " out_ret_tcode[0] = " << kTVMNullptr << ";\n"
97+
<< " return 0;\n";
98+
99+
function_names_.push_back(tvm::runtime::symbol::tvm_lookup_linked_param);
100+
for (auto kv : params) {
96101
stream << " case " << kv.second->id << ":\n"
97102
<< " ((uint64_t*)out_ret_value)[0] = (uint64_t) (uintptr_t) "
98103
<< ::tvm::runtime::symbol::tvm_param_prefix << kv.first << ";\n"
@@ -398,12 +403,14 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
398403
cg.AddFunction(f);
399404
}
400405

401-
if (could_have_linked_params) {
406+
if (could_have_linked_params && !aot_executor_fn.defined()) {
402407
ICHECK(found_linked_params) << "-link-params given but none found";
408+
cg.DeclareParameters(linked_params);
403409
cg.LinkParameters(linked_params);
404410
}
405411

406412
if (aot_executor_fn.defined()) {
413+
cg.DeclareParameters(linked_params);
407414
cg.AddFunction(aot_executor_fn);
408415
}
409416

src/target/source/codegen_c_host.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class CodeGenCHost final : public CodeGenC {
4343
void AddFunction(const PrimFunc& f);
4444

4545
/*! \brief Add linked parameters, if they are present. */
46+
void DeclareParameters(Map<String, LinkedParam> params);
4647
void LinkParameters(Map<String, LinkedParam> params);
4748

4849
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)

src/tir/op/builtin.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set)
155155
.set_num_inputs(4)
156156
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kUpdateState));
157157

158+
TIR_DEFINE_BUILTIN_FUNC(lookup_param)
159+
.set_num_inputs(4)
160+
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kUpdateState));
161+
158162
TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error)
159163
.set_num_inputs(0)
160164
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

0 commit comments

Comments
 (0)