Skip to content

Commit a264e81

Browse files
committed
[Relax] Support callback as argument
Prior to this commit, calls from Relax to external PackedFuncs could only be done through the TVM global registry. While Relax functions accepting a callback could be written as `callback_arg: R.Callable(arg_struct_info, ret_struct_info)`, attempting to compile these functions would raise an error during the `CodeGenVM` step of `relax.build`. In addition, the global registry is only queried when initializing the `relax.VirtualMachine`, and so later changes requires restarting the VM. This commit updates both the `CodeGenVM` lowering pass and the relax VM to support callbacks. The is primarily intended for use with the `LazyTransformParams` pass, to improve flexibility by avoiding use of the global registry.
1 parent 268d15c commit a264e81

File tree

7 files changed

+219
-23
lines changed

7 files changed

+219
-23
lines changed

include/tvm/runtime/relax_vm/bytecode.h

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ enum class Opcode {
5858
Ret = 2U,
5959
Goto = 3U,
6060
If = 4U,
61+
CallFromRegister = 5U,
6162
};
6263

6364
/*! \brief A single virtual machine instruction.
@@ -183,10 +184,15 @@ struct Instruction {
183184
/*! \brief The instruction opcode. */
184185
Opcode op;
185186
union {
186-
struct /* Call */ {
187+
struct /* Call, CallFromRegister */ {
187188
/*! \brief The destination register. */
188189
RegName dst;
189-
/*! \brief The index into the packed function table. */
190+
/*! \brief The index of the function.
191+
*
192+
* For `OpCode::Call`, this is an index into the table of static
193+
* functions. For `OpCode::CallFromRegister`, this is an index
194+
* of a register.
195+
*/
190196
Index func_idx;
191197
/*! \brief The number of arguments to the packed function. */
192198
Index num_args;
@@ -208,27 +214,43 @@ struct Instruction {
208214
Index false_offset;
209215
};
210216
};
217+
211218
/*!
212219
* \brief Construct a Call instruction.
213-
* \param func_idx The index of the function to call.
220+
* \param func_idx The index of the function to call within the
221+
* static function table
214222
* \param num_args The number of arguments.
215223
* \param args The input arguments.
216224
* \param dst The destination register.
217225
* \return The call instruction.
218226
*/
219227
static Instruction Call(Index func_idx, Index num_args, Arg* args, RegName dst);
228+
229+
/*!
230+
* \brief Construct a Call instruction.
231+
* \param func_idx The index of the function to call within the
232+
* current stack frame's registers.
233+
* \param num_args The number of arguments.
234+
* \param args The input arguments.
235+
* \param dst The destination register.
236+
* \return The call instruction.
237+
*/
238+
static Instruction CallFromRegister(Index func_idx, Index num_args, Arg* args, RegName dst);
239+
220240
/*!
221241
* \brief Construct a return instruction.
222242
* \param result The register containing the return value.
223243
* \return The return instruction.
224244
*/
225245
static Instruction Ret(RegName result);
246+
226247
/*!
227248
* \brief Construct a goto instruction.
228249
* \param pc_offset The register containing the jump offset.
229250
* \return The goto instruction.
230251
*/
231252
static Instruction Goto(RegName pc_offset);
253+
232254
/*!
233255
* \brief Construct an If instruction.
234256
* \param cond The register containing the cond value.

src/relax/backend/vm/exec_builder.cc

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,20 @@ void ExecBuilderNode::EndFunction(const std::string& func_name) {
138138

139139
void ExecBuilderNode::EmitCall(vm::Instruction::Arg func, std::vector<vm::Instruction::Arg> args,
140140
vm::RegName dst) {
141-
ICHECK(func.kind() == vm::Instruction::ArgKind::kFuncIdx);
141+
Opcode op_code;
142+
if (func.kind() == vm::Instruction::ArgKind::kFuncIdx) {
143+
op_code = Opcode::Call;
144+
} else if (func.kind() == vm::Instruction::ArgKind::kRegister) {
145+
op_code = Opcode::CallFromRegister;
146+
} else {
147+
LOG(FATAL) << "VM instruction for a function must be either "
148+
<< "kFuncIdx (static function ) "
149+
<< "or kRegister (function passed as parameter), "
150+
<< "but instead found " << func.kind();
151+
}
142152
// store instruction
143153
exec_->instr_offset.push_back(exec_->instr_data.size());
144-
exec_->instr_data.push_back(static_cast<ExecWord>(Opcode::Call));
154+
exec_->instr_data.push_back(static_cast<ExecWord>(op_code));
145155
exec_->instr_data.push_back(dst);
146156
exec_->instr_data.push_back(func.value());
147157
exec_->instr_data.push_back(args.size());
@@ -228,7 +238,8 @@ void ExecBuilderNode::CheckExecutable() {
228238
for (size_t idx = start_instr; idx < end_instr; ++idx) {
229239
Instruction instr = exec_->GetInstruction(idx);
230240
switch (instr.op) {
231-
case Opcode::Call: {
241+
case Opcode::Call:
242+
case Opcode::CallFromRegister: {
232243
check_func_defined(Instruction::Arg::FuncIdx(instr.func_idx));
233244
for (int i = 0; i < instr.num_args; ++i) {
234245
check_reg_defined(instr.args[i]);
@@ -280,7 +291,8 @@ void ExecBuilderNode::Formalize() {
280291
for (size_t idx = start_instr; idx < end_instr; ++idx) {
281292
Instruction instr = this->exec_->GetInstruction(idx);
282293
switch (instr.op) {
283-
case Opcode::Call: {
294+
case Opcode::Call:
295+
case Opcode::CallFromRegister: {
284296
// rewrite args
285297
for (int i = 0; i < instr.num_args; ++i) {
286298
if (instr.args[i].kind() == Instruction::ArgKind::kRegister &&

src/runtime/library_module.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,15 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>&
7171
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
7272
TVMValue ret_value;
7373
int ret_type_code = kTVMNullptr;
74-
int ret = (*faddr)(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes),
75-
args.num_args, &ret_value, &ret_type_code, nullptr);
76-
// NOTE: important to keep the original error message.
74+
auto arg_values = const_cast<TVMValue*>(args.values);
75+
auto arg_type_codes = const_cast<int*>(args.type_codes);
76+
int ret =
77+
(*faddr)(arg_values, arg_type_codes, args.num_args, &ret_value, &ret_type_code, nullptr);
78+
// NOTE: It is important to keep the original error message.
79+
// Using the `TVMThrowLastError()` function will also preserve the
80+
// full stack trace for debugging in pdb.
7781
if (ret != 0) {
78-
LOG(FATAL) << TVMGetLastError();
82+
TVMThrowLastError();
7983
}
8084
if (ret_type_code != kTVMNullptr) {
8185
*rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code);

src/runtime/relax_vm/bytecode.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,17 @@ Instruction Instruction::Call(Index func_idx, Index num_args, Instruction::Arg*
4242
return instr;
4343
}
4444

45+
Instruction Instruction::CallFromRegister(Index func_idx, Index num_args, Instruction::Arg* args,
46+
RegName dst) {
47+
Instruction instr;
48+
instr.op = Opcode::CallFromRegister;
49+
instr.dst = dst;
50+
instr.func_idx = func_idx;
51+
instr.num_args = num_args;
52+
instr.args = args;
53+
return instr;
54+
}
55+
4556
Instruction Instruction::Ret(RegName result) {
4657
Instruction instr;
4758
instr.op = Opcode::Ret;

src/runtime/relax_vm/executable.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,14 @@ Instruction Executable::GetInstruction(Index i) const {
134134
ExecWord* args = const_cast<ExecWord*>(&instr_data[offset + 4]);
135135
return Instruction::Call(func_idx, num_args, reinterpret_cast<Instruction::Arg*>(args), dst);
136136
}
137+
case Opcode::CallFromRegister: {
138+
RegName dst = instr_data[offset + 1];
139+
Index func_idx = instr_data[offset + 2];
140+
Index num_args = instr_data[offset + 3];
141+
ExecWord* args = const_cast<ExecWord*>(&instr_data[offset + 4]);
142+
return Instruction::CallFromRegister(func_idx, num_args,
143+
reinterpret_cast<Instruction::Arg*>(args), dst);
144+
}
137145
case Opcode::Ret: {
138146
RegName result = instr_data[offset + 1];
139147
return Instruction::Ret(result);

src/runtime/relax_vm/vm.cc

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,10 @@ class VirtualMachineImpl : public VirtualMachine {
372372
/*!
373373
* \brief Run call instruction.
374374
* \param curr_frame The current frame.
375+
* \param callable The callable object, either PackedFunc or closure
375376
* \param inst The call instruction.
376377
*/
377-
virtual void RunInstrCall(VMFrame* curr_frame, Instruction inst);
378+
virtual void RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, Instruction inst);
378379

379380
/*! \brief Run VM dispatch loop. */
380381
void RunLoop();
@@ -506,14 +507,18 @@ void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module,
506507
//------------------------------------------
507508
void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args,
508509
TVMRetValue* rv) {
510+
ICHECK(closure_or_packedfunc.defined())
511+
<< "InvokeClosurePacked requires the callable object to be defined";
512+
509513
// run packed call if it is a packed func.
510514
if (auto* packed = closure_or_packedfunc.as<PackedFunc::ContainerType>()) {
511515
packed->CallPacked(args, rv);
512516
return;
513517
}
514518
// run closure call.
515519
auto* clo = closure_or_packedfunc.as<VMClosureObj>();
516-
ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc ";
520+
ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc, "
521+
<< "but received " << closure_or_packedfunc->GetTypeKey();
517522

518523
std::vector<TVMValue> values(args.size() + 1);
519524
std::vector<int> tcodes(args.size() + 1);
@@ -595,6 +600,8 @@ Optional<VMClosure> VirtualMachineImpl::GetClosureInternal(const String& func_na
595600
auto impl = PackedFunc([gf_idx](TVMArgs args, TVMRetValue* rv) {
596601
// Per convention, ctx ptr is a VirtualMachine*
597602
VirtualMachine* ctx_ptr = static_cast<VirtualMachine*>(args[0].operator void*());
603+
ICHECK(ctx_ptr) << "Context pointer for relax VM closure should be a VirtualMachine*, "
604+
<< "but was NULL";
598605

599606
std::vector<RegType> inputs(args.size() - 1);
600607
for (size_t i = 0; i < inputs.size(); ++i) {
@@ -644,7 +651,7 @@ RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vector<RegTy
644651
auto guard = PushFrame(this->pc_, gfunc);
645652
// Get new frame and set the caller info.
646653
VMFrame* curr_frame = frames_.back().get();
647-
if (curr_instr.op == Opcode::Call) {
654+
if (curr_instr.op == Opcode::Call || curr_instr.op == Opcode::CallFromRegister) {
648655
curr_frame->caller_return_register = curr_instr.dst;
649656
}
650657

@@ -688,8 +695,12 @@ void VirtualMachineImpl::InitFuncPool() {
688695
}
689696
}
690697

691-
void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) {
692-
DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << GetFuncName(instr.func_idx);
698+
void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable,
699+
Instruction instr) {
700+
ICHECK(callable.defined()) << "RunInstrCall requires the callable object to be defined";
701+
auto func_name = instr.op == Opcode::Call ? GetFuncName(instr.func_idx) : "<dynamic>";
702+
703+
DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << func_name;
693704
int args_begin_offset = instrument_ != nullptr ? 4 : 0;
694705
// Use the call arg stack from the current frame to increase reuse
695706
// and avoid re-allocation
@@ -735,11 +746,11 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) {
735746
ICHECK_LT(static_cast<size_t>(instr.func_idx), this->func_pool_.size());
736747

737748
if (instrument_ == nullptr) {
738-
this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
749+
this->InvokeClosurePacked(callable, args, &ret);
739750
} else {
740751
// insert light-weight instrument callback
741-
setter(0, func_pool_[instr.func_idx]);
742-
setter(1, GetFuncName(instr.func_idx));
752+
setter(0, callable);
753+
setter(1, func_name);
743754
setter(2, true);
744755
setter(3, nullptr);
745756
TVMRetValue rv;
@@ -758,7 +769,7 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) {
758769
ret_kind = rv;
759770
}
760771
if (ret_kind != static_cast<int>(VMInstrumentReturnKind::kSkipRun)) {
761-
this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret);
772+
this->InvokeClosurePacked(callable, args, &ret);
762773
setter(2, false);
763774
setter(3, ret);
764775
instrument_.CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), &rv);
@@ -782,7 +793,11 @@ void VirtualMachineImpl::RunLoop() {
782793
Instruction instr = exec_->GetInstruction(pc_);
783794
switch (instr.op) {
784795
case Opcode::Call: {
785-
this->RunInstrCall(curr_frame, instr);
796+
this->RunInstrCall(curr_frame, func_pool_[instr.func_idx], instr);
797+
break;
798+
}
799+
case Opcode::CallFromRegister: {
800+
this->RunInstrCall(curr_frame, ReadRegister(curr_frame, instr.func_idx), instr);
786801
break;
787802
}
788803
case Opcode::Ret: {
@@ -1000,7 +1015,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
10001015
}
10011016

10021017
protected:
1003-
void RunInstrCall(VMFrame* curr_frame, Instruction inst) override {
1018+
void RunInstrCall(VMFrame* curr_frame, const ObjectRef& callable, Instruction inst) override {
10041019
bool profiling = false;
10051020
if (prof_ && prof_->IsRunning()) {
10061021
auto f_name = GetFuncName(inst.func_idx);
@@ -1036,7 +1051,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
10361051
}
10371052
}
10381053

1039-
VirtualMachineImpl::RunInstrCall(curr_frame, inst);
1054+
VirtualMachineImpl::RunInstrCall(curr_frame, callable, inst);
10401055

10411056
if (profiling) {
10421057
prof_->StopCall();

0 commit comments

Comments
 (0)