Skip to content

Commit 94a44d7

Browse files
authored
[QoL][Relax] Return well-formed IR from relax::Function::CreateEmpty (#16861)
Prior to this commit, the static method `relax::Function::CreateEmpty` returned a function with a nullptr as the body. While only intended for use in bookkeeping for TVMScript, allowing nullptr in this location can cause unexpected segfaults while debugging. For example, adding a print statement This commit updates the `relax::Function::CreateEmpty` function to contain a placeholder body, consistent with the `ret_struct_info` argument provided.
1 parent 4cb4605 commit 94a44d7

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

include/tvm/relax/expr.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,8 @@ class ExternFuncNode : public BaseFuncNode {
10451045
class ExternFunc : public BaseFunc {
10461046
public:
10471047
TVM_DLL ExternFunc(String global_symbol, Span span = Span());
1048+
TVM_DLL ExternFunc(String global_symbol, StructInfo struct_info, Span span = Span());
1049+
10481050
TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, BaseFunc, ExternFuncNode);
10491051
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode);
10501052
};

src/relax/ir/expr.cc

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -559,10 +559,18 @@ Function Function::CreateEmpty(Array<Var> params, StructInfo ret_struct_info, bo
559559

560560
FuncStructInfo finfo(param_sinfo, ret_struct_info, is_pure);
561561

562+
// A dummy body, to ensure that the empty function is still well-formed.
563+
Expr body = [&]() -> Expr {
564+
Var output("output", ret_struct_info);
565+
Call expr(ExternFunc("_dummy_function", FuncStructInfo({}, ret_struct_info)), {});
566+
567+
return SeqExpr({BindingBlock({VarBinding(output, expr)})}, output);
568+
}();
569+
562570
// set the fields
563571
ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
564572
n->params = std::move(params);
565-
n->body = Expr();
573+
n->body = std::move(body);
566574
n->is_pure = is_pure;
567575
n->checked_type_ = GetStaticType(finfo);
568576
n->struct_info_ = std::move(finfo);
@@ -602,13 +610,19 @@ FuncStructInfo GetExternFuncStructInfo() {
602610

603611
TVM_REGISTER_NODE_TYPE(ExternFuncNode);
604612

605-
ExternFunc::ExternFunc(String global_symbol, Span span) {
613+
ExternFunc::ExternFunc(String global_symbol, Span span)
614+
: ExternFunc(global_symbol, GetExternFuncStructInfo(), span) {}
615+
616+
ExternFunc::ExternFunc(String global_symbol, StructInfo struct_info, Span span) {
617+
CHECK(struct_info.as<FuncStructInfoNode>())
618+
<< "ExternFunc must have FuncStructInfo, "
619+
<< "but declaration of '" << global_symbol << "' received " << struct_info;
620+
606621
ObjectPtr<ExternFuncNode> n = make_object<ExternFuncNode>();
607622
n->global_symbol = std::move(global_symbol);
608623
n->span = span;
609-
static auto sinfo = GetExternFuncStructInfo();
610-
n->struct_info_ = sinfo;
611-
n->checked_type_ = GetStaticType(sinfo);
624+
n->struct_info_ = struct_info;
625+
n->checked_type_ = GetStaticType(struct_info);
612626
data_ = std::move(n);
613627
}
614628

0 commit comments

Comments
 (0)