|
14 | 14 | # KIND, either express or implied. See the License for the |
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | | -# pylint: disable=invalid-name, redefined-builtin |
| 17 | +# pylint: disable=invalid-name, redefined-builtin, no-else-return |
18 | 18 | """The Relax virtual machine""" |
19 | 19 | from typing import List, Optional, Union, Dict, Tuple |
| 20 | +from tvm._ffi import base as _base |
| 21 | +import numpy as np |
20 | 22 |
|
21 | 23 | import tvm |
22 | 24 | from tvm import relax |
23 | 25 | from tvm.ir.module import IRModule |
24 | 26 | from tvm.relay import Any |
25 | | -from tvm.runtime import Device, Module, PackedFunc |
| 27 | +from tvm.runtime import Device, Module, PackedFunc, container |
26 | 28 | from tvm.runtime.object import Object |
27 | 29 | from tvm.tir.function import PrimFunc |
28 | 30 | from . import _ffi_api |
@@ -97,6 +99,8 @@ def __init__( |
97 | 99 | else exec["vm_load_executable"]() |
98 | 100 | ) |
99 | 101 | self._invoke_closure = self.module["invoke_closure"] |
| 102 | + self._set_input = self.module["set_input"] |
| 103 | + self._get_func_param_names = self.module["get_func_param_names"] |
100 | 104 | self._setup_device(device, memory_cfg) |
101 | 105 |
|
102 | 106 | def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) -> None: |
@@ -161,6 +165,79 @@ def invoke_closure(self, closure: Object, *args: Any) -> Object: |
161 | 165 | """ |
162 | 166 | return self._invoke_closure(closure, *args) |
163 | 167 |
|
| 168 | + def _convert(self, arg: Any, cargs: List) -> None: |
| 169 | + """helper function to convert arguments to vm function.""" |
| 170 | + |
| 171 | + def _gettype(arg): |
| 172 | + if isinstance(arg, np.float16): |
| 173 | + return "float16" |
| 174 | + elif isinstance(arg, (_base.integer_types, bool)): |
| 175 | + return "int32" |
| 176 | + else: |
| 177 | + return "float32" |
| 178 | + |
| 179 | + if isinstance(arg, Object): |
| 180 | + cargs.append(arg) |
| 181 | + elif isinstance(arg, np.ndarray): |
| 182 | + nd_arr = tvm.nd.array(arg, device=tvm.cpu(0)) |
| 183 | + cargs.append(nd_arr) |
| 184 | + elif isinstance(arg, tvm.runtime.NDArray): |
| 185 | + cargs.append(arg) |
| 186 | + elif isinstance(arg, (tuple, list)): |
| 187 | + field_args = [] |
| 188 | + for field in arg: |
| 189 | + self._convert(field, field_args) |
| 190 | + cargs.append(container.tuple_object(field_args)) |
| 191 | + elif isinstance(arg, (_base.numeric_types, bool)): |
| 192 | + dtype = _gettype(arg) |
| 193 | + value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0)) |
| 194 | + cargs.append(value) |
| 195 | + elif isinstance(arg, str): |
| 196 | + cargs.append(arg) |
| 197 | + else: |
| 198 | + raise TypeError("Unsupported type: %s" % (type(arg))) |
| 199 | + |
| 200 | + def set_input(self, func_name: str, *args: Any, **kwargs: Any) -> None: |
| 201 | + """Set the inputs to a function. |
| 202 | + This interface works when using VM over RPC by internally converting NDArray in |
| 203 | + the arguments to DLTensor, which is supported in RPC where remote could only |
| 204 | + have a minimal C runtime. |
| 205 | +
|
| 206 | + Parameters |
| 207 | + ---------- |
| 208 | + func_name : str |
| 209 | + The name of the function. |
| 210 | + args: List[tvm.runtime.NDArray] or List[np.ndarray] |
| 211 | + The arguments to the function. |
| 212 | + kwargs: dict of str to tvm.runtime.NDArray or np.ndarray |
| 213 | + Named arguments to the function. |
| 214 | + """ |
| 215 | + cargs = [] |
| 216 | + |
| 217 | + if kwargs: |
| 218 | + # kwargs can be a super set of the required function parameters. |
| 219 | + # We only find the ones that are needed. |
| 220 | + func_params = list(self._get_func_param_names(func_name)) |
| 221 | + new_args = [None] * len(func_params) |
| 222 | + cnt = 0 |
| 223 | + for k in kwargs: |
| 224 | + if k in func_params: |
| 225 | + idx = func_params.index(k) |
| 226 | + new_args[idx] = kwargs[k] |
| 227 | + cnt += 1 |
| 228 | + assert len(args) + cnt == len(func_params) |
| 229 | + idx = 0 |
| 230 | + for i, arg in enumerate(new_args): |
| 231 | + if arg is None: |
| 232 | + new_args[i] = args[idx] |
| 233 | + idx += 1 |
| 234 | + args = new_args |
| 235 | + |
| 236 | + for arg in args: |
| 237 | + self._convert(arg, cargs) |
| 238 | + |
| 239 | + self._set_input(func_name, *cargs) |
| 240 | + |
164 | 241 |
|
165 | 242 | def build(mod: tvm.IRModule, target: tvm.target.Target) -> Executable: |
166 | 243 | """ |
|
0 commit comments