Skip to content

Commit 0c1aad7

Browse files
authored
[Testing] Add tvm.testing.local_run (#15268)
This PR introduces `tvm.testing.local_run`, which serves as a convenient numpy-in numpy-out interface to quickly run a `runtime.Module` in TVM and obtain its running time and outputs. Example: ```python @I.ir_module class Module: ... n = 128 np_a = np.random.uniform(-1, 1, [1, 32, 1, 128]).astype(np.float16) np_b = np.random.uniform(-1, 1, [1, 32, n, 128]).astype(np.float16) np_c = np.random.uniform(-1, 1, [1, 1, 1, n]).astype(np.float16) np_d = np.random.uniform(-1, 1, [1, 32, 1, n]).astype(np.float32) _, _, _, np_d = local_run( tvm.build(Module, target="llvm"), device_type="cpu", args=[np_a, np_b, np_c, np_d], ) ```
1 parent 24ae0d5 commit 0c1aad7

File tree

2 files changed

+78
-9
lines changed

2 files changed

+78
-9
lines changed

python/tvm/testing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,5 @@
4343
slow_summation,
4444
timeout_job,
4545
)
46-
from .rpc_run import rpc_run
46+
from .runner import local_run, rpc_run
4747
from .utils import *

python/tvm/testing/rpc_run.py renamed to python/tvm/testing/runner.py

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,14 @@
2222

2323
if TYPE_CHECKING:
2424
import numpy as np
25-
2625
from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
2726
from tvm.runtime import Device, Module, NDArray
2827

2928
# pylint: disable=import-outside-toplevel,protected-access
3029

3130

32-
def _args_to_remote(args, device):
31+
def _args_to_device(args, device):
3332
import numpy as np
34-
3533
from tvm.runtime.ndarray import NDArray, empty
3634

3735
uploaded_args = []
@@ -45,7 +43,7 @@ def _args_to_remote(args, device):
4543
return uploaded_args
4644

4745

48-
def _args_to_local(args):
46+
def _args_to_numpy(args):
4947
from tvm.runtime.ndarray import NDArray
5048

5149
downloaded_args = []
@@ -77,6 +75,77 @@ def export_with(func):
7775
return export_func, output_format
7876

7977

78+
def local_run( # pylint: disable=too-many-arguments,too-many-locals
79+
mod: "Module",
80+
device_type: str,
81+
args: List[Union["np.ndarray", "NDArray", int, float]],
82+
evaluator_config: Optional["EvaluatorConfig"] = None,
83+
export_func: Union[Callable[["Module", str], None], Literal["tar", "ndk"]] = "tar",
84+
output_format: Optional[str] = None,
85+
):
86+
"""Run a TVM module on a local device.
87+
88+
Parameters
89+
----------
90+
mod : Module
91+
The TVM module to run.
92+
device_type : str
93+
The device type to run the module on.
94+
args : List[Union[np.ndarray, NDArray, int, float]]
95+
The arguments to be fed to the module.
96+
evaluator_config : Optional[EvaluatorConfig]
97+
The evaluator configuration to use.
98+
export_func : Union[Callable[Module, str], Literal["tar", "ndk"]]
99+
The function to export the module to a file.
100+
If callable, it must be a function that takes two arguments: the module to export and the
101+
path to export to.
102+
If "tar", the module will be exported to a tar file.
103+
If "ndk", the module will be exported to a shared library.
104+
output_format : Optional[str]
105+
The format of the exported module.
106+
If not specified, it will be inferred from the `export_func` argument.
107+
108+
Returns
109+
-------
110+
args : List[Union[np.ndarray, NDArray, int, float]]
111+
The results of running the module.
112+
"""
113+
import os.path as osp
114+
import tempfile
115+
116+
from tvm.meta_schedule.runner import EvaluatorConfig
117+
from tvm.runtime import device, load_module
118+
119+
evaluator_config = EvaluatorConfig._normalized(evaluator_config)
120+
export_func, output_format = _normalize_export_func(export_func, output_format)
121+
122+
with tempfile.TemporaryDirectory() as tmp_dir:
123+
artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format)
124+
export_func(mod, artifact_path)
125+
device: Device = device(device_type, 0)
126+
127+
try:
128+
args = _args_to_device(args, device)
129+
remote_mod = load_module(artifact_path)
130+
profile_result = remote_mod.time_evaluator(
131+
func_name=remote_mod.entry_name,
132+
dev=device,
133+
number=evaluator_config.number,
134+
repeat=evaluator_config.repeat,
135+
min_repeat_ms=evaluator_config.min_repeat_ms,
136+
f_preproc="cache_flush_cpu_non_first_arg"
137+
if evaluator_config.enable_cpu_cache_flush
138+
else "",
139+
)(*args)
140+
print(profile_result)
141+
remote_mod(*args)
142+
args = _args_to_numpy(args)
143+
finally:
144+
pass
145+
146+
return args
147+
148+
80149
def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
81150
mod: "Module",
82151
device_type: str,
@@ -103,7 +172,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
103172
If not specified, the default RPC configuration will be used, which reads the following
104173
environment variables:
105174
- TVM_TRACKER_HOST
106-
- TVM_TRACKER_PORmod
175+
- TVM_TRACKER_PORT
107176
- TVM_TRACKER_KEY
108177
export_func : Union[Callable[Module, str], Literal["tar", "ndk"]]
109178
The function to export the module to a file.
@@ -134,12 +203,12 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
134203
artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format)
135204
_, remote_path = osp.split(artifact_path)
136205
session = rpc_config.connect_server()
137-
device: Device = session.device(dev_type=device_type, dev_id=0)
206+
device: Device = session.device(device_type, 0)
138207

139208
export_func(mod, artifact_path)
140209
try:
141210
session.upload(artifact_path, remote_path)
142-
args = _args_to_remote(args, device)
211+
args = _args_to_device(args, device)
143212
remote_mod = session.load_module(remote_path)
144213
profile_result = remote_mod.time_evaluator(
145214
func_name=remote_mod.entry_name,
@@ -153,7 +222,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
153222
)(*args)
154223
print(profile_result)
155224
remote_mod(*args)
156-
args = _args_to_local(args)
225+
args = _args_to_numpy(args)
157226
finally:
158227
session.remove(remote_path)
159228
session.remove(remote_path + "." + output_format)

0 commit comments

Comments
 (0)