
Commit c32bdc6

YuchenJin authored and yongwww committed
[VM] Add set_input interface; Fix e2e tuning script. (tlc-pack#166)
* Add set_input interface.
* Address comments.
1 parent 87508aa commit c32bdc6

File tree

11 files changed: +266 / -28 lines


apps/relax_examples/e2e_auto_tir.py

Lines changed: 19 additions & 12 deletions
@@ -129,15 +129,18 @@ def apply_opt_before_tuning(
     return relax_mod


-def f_measurement(rt_mod: runtime.Module, device: runtime.ndarray.Device, *input_data):
+def f_measurement(
+    rt_mod: runtime.Module, device: runtime.ndarray.Device, input_data: Dict[str, runtime.NDArray]
+):
     vm = relax.vm.VirtualMachine(exec=rt_mod, device=device)
+    vm.set_input("main", **input_data)
     evaluator = vm.module.time_evaluator(
         func_name="main",
         dev=device,
         repeat=5,
         min_repeat_ms=500,
     )
-    print(evaluator(*input_data))
+    print(evaluator())


 def get_runner():
@@ -166,10 +169,12 @@ def main():
         ARGS.input_shape,
         cache_dir=ARGS.cache_dir,
     )
-    print(f"Workload: {ARGS.workload}")
-    print(f" input_name: {input_name}")
-    print(f" input_shape: {input_shape}")
-    print(f" input_dtype: {input_dtype}")
+    input_info = {input_name: input_shape}
+    input_data = {}
+    for input_name, input_shape in input_info.items():
+        print(f" input_name: {input_name}")
+        print(f" input_shape: {input_shape}")
+        print(f" input_dtype: {input_dtype}")

     # translate the ResNet model from Relay to Relax
     relax_mod = apply_opt_before_tuning(relay_mod, params, target=ARGS.target)
@@ -189,10 +194,13 @@ def main():
         num_threads=os.cpu_count(),
     )

-    if input_dtype.startswith("float"):
-        input_data = [np.random.uniform(size=input_shape).astype(input_dtype)]
-    else:
-        input_data = [np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)]
+    for input_name, input_shape in input_info.items():
+        if input_dtype.startswith("float"):
+            input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
+        else:
+            input_data[input_name] = np.random.randint(
+                low=0, high=10000, size=input_shape, dtype=input_dtype
+            )

     if ARGS.rpc_config:
         run_module_via_rpc(
@@ -204,8 +212,7 @@ def main():
         )
     else:
         dev = tvm.device(ARGS.target.kind.name)
-        input_data = [runtime.ndarray.array(arg, dev) for arg in input_data]
-        f_measurement(executable.mod, dev, *input_data)
+        f_measurement(executable.mod, dev, input_data)


 if __name__ == "__main__":
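The measurement path now binds inputs by name once via set_input instead of forwarding them positionally to the evaluator. A minimal sketch of the new calling convention, assuming `executable` is the Relax VM executable built earlier in the script and "data" is a placeholder input name:

import numpy as np
import tvm
from tvm import relax, runtime

dev = tvm.cpu(0)
# "data" and the shape are placeholders; the script derives them from the workload.
input_data = {
    "data": runtime.ndarray.array(np.random.uniform(size=(1, 3, 224, 224)).astype("float32"), dev)
}

vm = relax.vm.VirtualMachine(exec=executable.mod, device=dev)
vm.set_input("main", **input_data)  # bind inputs once, by name
evaluator = vm.module.time_evaluator(func_name="main", dev=dev, repeat=5, min_repeat_ms=500)
print(evaluator())  # the evaluator is called without per-call positional inputs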

include/tvm/relax/exec_builder.h

Lines changed: 2 additions & 1 deletion
@@ -52,8 +52,9 @@ class ExecBuilderNode : public Object {
    * \brief To annotate the start of a vm function.
    * \param func The function name.
    * \param num_inputs The number of inputs.
+   * \param param_names The function parameter names.
    */
-  void EmitFunction(std::string func, int64_t num_inputs);
+  void EmitFunction(std::string func, int64_t num_inputs, Array<String> param_names);
   /*!
    * \brief Emit a call instruction for a packed function.
    * \param func The packed function name.

include/tvm/runtime/relax_vm/executable.h

Lines changed: 2 additions & 0 deletions
@@ -78,6 +78,8 @@ struct VMFunction {
   Index num_args;
   /*! \brief The register file size of the function. */
   Index register_file_size;
+  /*! \brief The function parameter names.*/
+  std::vector<std::string> param_names;
 };

 /*!

include/tvm/runtime/relax_vm/vm.h

Lines changed: 25 additions & 0 deletions
@@ -167,6 +167,29 @@ class VirtualMachine : public runtime::ModuleNode {
   */
  inline void RunInstrCall(VMFrame* curr_frame, Instruction inst);

+  /*!
+   * \brief Set inputs to a function.
+   * \param func_name The function name.
+   * \param args args[offset:] are arguments to the function. If the arguments are not of the
+   * correct device for the function, they will be copied to the device.
+   * \param offset Starting offset of the arguments in \p args.
+   * \note This interface works when using VM over RPC by internally converting NDArray in
+   * the arguments to DLTensor, which is supported in RPC where remote could only have a minimal C
+   * runtime.
+   */
+  void SetInput(std::string func_name, TVMArgs args, int offset);
+
+  /*!
+   * \brief Set a function argument with a given index to an input tensor.
+   * \param func_args the function arguments.
+   * \param inp_tensor some input tensor (not necessarily DLTensor). When it's an NDArray or a list
+   * of NDArray, they will be converted.
+   * \param index The input tensor index in the function arguments.
+   * \param dev device to copy to if needed.
+   */
+  void SetInputTensorWithIndex(std::vector<RegType>& func_args, const TVMArgValue& inp_tensor,
+                               int index, Device dev);
+
 private:
  /*! \brief The loaded executable. */
  ObjectPtr<Executable> exec_;
@@ -189,6 +212,8 @@ class VirtualMachine : public runtime::ModuleNode {
   RegType return_value_;
   /*! \brief The global constant pool */
   std::vector<TVMRetValue> constants;
+  /*! \brief The function name to input register mapping. */
+  std::unordered_map<std::string, std::vector<RegType>> inputs_;
 };

 } // namespace relax_vm
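On the Python side these hooks are exposed through VirtualMachine.set_input (see python/tvm/relax/vm.py below). A rough sketch of the device-copy behavior described in the SetInput / SetInputTensorWithIndex comments, assuming a CUDA-enabled build and a placeholder executable module `exec_mod`:

import numpy as np
import tvm
from tvm import relax

# `exec_mod` is a placeholder for a compiled Relax VM executable's runtime module.
dev = tvm.cuda(0)
vm = relax.vm.VirtualMachine(exec=exec_mod, device=dev)

# The numpy array is first turned into a CPU NDArray on the Python side; the VM
# then copies it to `dev` if needed and stores it in the per-function input
# registers (the `inputs_` map keyed by the function name).
host_data = np.zeros((1, 3, 224, 224), dtype="float32")
vm.set_input("main", host_data)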

python/tvm/relax/exec_builder.py

Lines changed: 4 additions & 2 deletions
@@ -71,9 +71,11 @@ def void_arg(self) -> int:
     def vm_state(self) -> int:
         return self.r(SpecialReg.VM_STATE)

-    def function(self, func_name: str, num_inputs: Optional[int] = 0) -> VMFuncScope:
+    def function(
+        self, func_name: str, num_inputs: Optional[int] = 0, param_names: List[str] = None
+    ) -> VMFuncScope:
         """annotate a VM function."""
-        _ffi_api.ExecBuilderFunction(self, func_name, num_inputs)
+        _ffi_api.ExecBuilderFunction(self, func_name, num_inputs, param_names)
         return VMFuncScope()

     def _check_scope(self) -> None:
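For hand-assembled executables, param_names lets the builder record argument names next to the input count, so set_input can later match keyword arguments. A hedged sketch; the relax.ExecBuilder constructor is assumed from the surrounding module, and the body emission is elided:

from tvm import relax

eb = relax.ExecBuilder()  # assumed wrapper around ExecBuilderNode
with eb.function("main", num_inputs=2, param_names=["data", "weight"]):
    # ... emit the call/ret instructions for the function body here ...
    ...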

python/tvm/relax/vm.py

Lines changed: 79 additions & 2 deletions
@@ -14,15 +14,17 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, redefined-builtin
+# pylint: disable=invalid-name, redefined-builtin, no-else-return
 """The Relax virtual machine"""
 from typing import List, Optional, Union, Dict, Tuple
+from tvm._ffi import base as _base
+import numpy as np

 import tvm
 from tvm import relax
 from tvm.ir.module import IRModule
 from tvm.relay import Any
-from tvm.runtime import Device, Module, PackedFunc
+from tvm.runtime import Device, Module, PackedFunc, container
 from tvm.runtime.object import Object
 from tvm.tir.function import PrimFunc
 from . import _ffi_api
@@ -97,6 +99,8 @@ def __init__(
             else exec["vm_load_executable"]()
         )
         self._invoke_closure = self.module["invoke_closure"]
+        self._set_input = self.module["set_input"]
+        self._get_func_param_names = self.module["get_func_param_names"]
         self._setup_device(device, memory_cfg)

     def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) -> None:
@@ -161,6 +165,79 @@ def invoke_closure(self, closure: Object, *args: Any) -> Object:
         """
         return self._invoke_closure(closure, *args)

+    def _convert(self, arg: Any, cargs: List) -> None:
+        """helper function to convert arguments to vm function."""
+
+        def _gettype(arg):
+            if isinstance(arg, np.float16):
+                return "float16"
+            elif isinstance(arg, (_base.integer_types, bool)):
+                return "int32"
+            else:
+                return "float32"
+
+        if isinstance(arg, Object):
+            cargs.append(arg)
+        elif isinstance(arg, np.ndarray):
+            nd_arr = tvm.nd.array(arg, device=tvm.cpu(0))
+            cargs.append(nd_arr)
+        elif isinstance(arg, tvm.runtime.NDArray):
+            cargs.append(arg)
+        elif isinstance(arg, (tuple, list)):
+            field_args = []
+            for field in arg:
+                self._convert(field, field_args)
+            cargs.append(container.tuple_object(field_args))
+        elif isinstance(arg, (_base.numeric_types, bool)):
+            dtype = _gettype(arg)
+            value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0))
+            cargs.append(value)
+        elif isinstance(arg, str):
+            cargs.append(arg)
+        else:
+            raise TypeError("Unsupported type: %s" % (type(arg)))
+
+    def set_input(self, func_name: str, *args: Any, **kwargs: Any) -> None:
+        """Set the inputs to a function.
+        This interface works when using VM over RPC by internally converting NDArray in
+        the arguments to DLTensor, which is supported in RPC where remote could only
+        have a minimal C runtime.
+
+        Parameters
+        ----------
+        func_name : str
+            The name of the function.
+        args: List[tvm.runtime.NDArray] or List[np.ndarray]
+            The arguments to the function.
+        kwargs: dict of str to tvm.runtime.NDArray or np.ndarray
+            Named arguments to the function.
+        """
+        cargs = []
+
+        if kwargs:
+            # kwargs can be a super set of the required function parameters.
+            # We only find the ones that are needed.
+            func_params = list(self._get_func_param_names(func_name))
+            new_args = [None] * len(func_params)
+            cnt = 0
+            for k in kwargs:
+                if k in func_params:
+                    idx = func_params.index(k)
+                    new_args[idx] = kwargs[k]
+                    cnt += 1
+            assert len(args) + cnt == len(func_params)
+            idx = 0
+            for i, arg in enumerate(new_args):
+                if arg is None:
+                    new_args[i] = args[idx]
+                    idx += 1
+            args = new_args
+
+        for arg in args:
+            self._convert(arg, cargs)

+        self._set_input(func_name, *cargs)
+

 def build(mod: tvm.IRModule, target: tvm.target.Target) -> Executable:
     """

src/relax/backend/vm/codegen_vm.cc

Lines changed: 6 additions & 2 deletions
@@ -66,15 +66,19 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
   size_t NewRegister() { return registers_num_++; }
   Instruction::Arg VisitExpr_(const FunctionNode* func_node) {
     Optional<String> gsymbol = func_node->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    Array<String> param_names;
+    for (Var param : func_node->params) {
+      param_names.push_back(param->name_hint());
+    }
     if (gsymbol.defined()) {
-      builder_->EmitFunction(gsymbol.value(), func_node->params.size());
+      builder_->EmitFunction(gsymbol.value(), func_node->params.size(), param_names);
     } else {
       // TODO(@yuchen): handle local functions that capture local vars outside the func
       // TODO(@yuchen): a renaming pass to resolve name conflicts, e.g. the input module has a
       // function named "local_funcN"
       // lift the local func to a global func and process it normally
       builder_->EmitFunction("local_func" + std::to_string(local_func_counter_++),
-                             func_node->params.size());
+                             func_node->params.size(), param_names);
     }

     for (Var param : func_node->params) {

src/relax/backend/vm/exec_builder.cc

Lines changed: 7 additions & 1 deletion
@@ -43,13 +43,19 @@ vm::Index ExecBuilderNode::EmitConstant(TVMRetValue obj) {
   return vm::Instruction::Arg(vm::Instruction::kConstIdx, idx).data;
 }

-void ExecBuilderNode::EmitFunction(std::string func_name, int64_t num_inputs) {
+void ExecBuilderNode::EmitFunction(std::string func_name, int64_t num_inputs,
+                                   Array<String> param_names) {
   const auto& m = exec->global_map;
   ICHECK(m.find(func_name) == m.end());
   VMFunction vmfunc;
   vmfunc.name = func_name;
   vmfunc.start_instr = exec->instr_offset.size();
   vmfunc.num_args = num_inputs;
+  std::vector<std::string> names;
+  for (size_t i = 0; i < param_names.size(); ++i) {
+    names.push_back(param_names[i]);
+  }
+  vmfunc.param_names = names;
   exec->global_map[func_name] = exec->global_funcs.size();
   exec->global_funcs.push_back(vmfunc);
 }

src/runtime/relax_vm/executable.cc

Lines changed: 2 additions & 0 deletions
@@ -290,6 +290,7 @@ void SerializeVMFunc(const VMFunction& func, dmlc::Stream* strm) {
   strm->Write(func.start_instr);
   strm->Write(func.num_args);
   strm->Write(func.register_file_size);
+  strm->Write(func.param_names);
 }

 VMFunction DeserializeVMFunc(dmlc::Stream* strm) {
@@ -298,6 +299,7 @@ VMFunction DeserializeVMFunc(dmlc::Stream* strm) {
   STREAM_CHECK(strm->Read(&func.start_instr), "vmfunc start_instr");
   STREAM_CHECK(strm->Read(&func.num_args), "vmfunc num_args");
   STREAM_CHECK(strm->Read(&func.register_file_size), "vmfunc register_file_size");
+  STREAM_CHECK(strm->Read(&func.param_names), "vmfunc params");
   return func;
 }
