Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@
slow_summation,
timeout_job,
)
from .rpc_run import rpc_run
from .runner import local_run, rpc_run
from .utils import *
85 changes: 77 additions & 8 deletions python/tvm/testing/rpc_run.py → python/tvm/testing/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,14 @@

if TYPE_CHECKING:
import numpy as np

from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
from tvm.runtime import Device, Module, NDArray

# pylint: disable=import-outside-toplevel,protected-access


def _args_to_remote(args, device):
def _args_to_device(args, device):
import numpy as np

from tvm.runtime.ndarray import NDArray, empty

uploaded_args = []
Expand All @@ -45,7 +43,7 @@ def _args_to_remote(args, device):
return uploaded_args


def _args_to_local(args):
def _args_to_numpy(args):
from tvm.runtime.ndarray import NDArray

downloaded_args = []
Expand Down Expand Up @@ -77,6 +75,77 @@ def export_with(func):
return export_func, output_format


def local_run( # pylint: disable=too-many-arguments,too-many-locals
mod: "Module",
device_type: str,
args: List[Union["np.ndarray", "NDArray", int, float]],
evaluator_config: Optional["EvaluatorConfig"] = None,
export_func: Union[Callable[["Module", str], None], Literal["tar", "ndk"]] = "tar",
output_format: Optional[str] = None,
):
"""Run a TVM module on a local device.

Parameters
----------
mod : Module
The TVM module to run.
device_type : str
The device type to run the module on.
args : List[Union[np.ndarray, NDArray, int, float]]
The arguments to be fed to the module.
evaluator_config : Optional[EvaluatorConfig]
The evaluator configuration to use.
export_func : Union[Callable[Module, str], Literal["tar", "ndk"]]
The function to export the module to a file.
If callable, it must be a function that takes two arguments: the module to export and the
path to export to.
If "tar", the module will be exported to a tar file.
If "ndk", the module will be exported to a shared library.
output_format : Optional[str]
The format of the exported module.
If not specified, it will be inferred from the `export_func` argument.

Returns
-------
args : List[Union[np.ndarray, NDArray, int, float]]
The results of running the module.
"""
import os.path as osp
import tempfile

from tvm.meta_schedule.runner import EvaluatorConfig
from tvm.runtime import device, load_module

evaluator_config = EvaluatorConfig._normalized(evaluator_config)
export_func, output_format = _normalize_export_func(export_func, output_format)

with tempfile.TemporaryDirectory() as tmp_dir:
artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format)
export_func(mod, artifact_path)
device: Device = device(device_type, 0)

try:
args = _args_to_device(args, device)
remote_mod = load_module(artifact_path)
profile_result = remote_mod.time_evaluator(
func_name=remote_mod.entry_name,
dev=device,
number=evaluator_config.number,
repeat=evaluator_config.repeat,
min_repeat_ms=evaluator_config.min_repeat_ms,
f_preproc="cache_flush_cpu_non_first_arg"
if evaluator_config.enable_cpu_cache_flush
else "",
)(*args)
print(profile_result)
remote_mod(*args)
args = _args_to_numpy(args)
finally:
pass

return args


def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
mod: "Module",
device_type: str,
Expand All @@ -103,7 +172,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
If not specified, the default RPC configuration will be used, which reads the following
environment variables:
- TVM_TRACKER_HOST
- TVM_TRACKER_PORmod
- TVM_TRACKER_PORT
- TVM_TRACKER_KEY
export_func : Union[Callable[Module, str], Literal["tar", "ndk"]]
The function to export the module to a file.
Expand Down Expand Up @@ -134,12 +203,12 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format)
_, remote_path = osp.split(artifact_path)
session = rpc_config.connect_server()
device: Device = session.device(dev_type=device_type, dev_id=0)
device: Device = session.device(device_type, 0)

export_func(mod, artifact_path)
try:
session.upload(artifact_path, remote_path)
args = _args_to_remote(args, device)
args = _args_to_device(args, device)
remote_mod = session.load_module(remote_path)
profile_result = remote_mod.time_evaluator(
func_name=remote_mod.entry_name,
Expand All @@ -153,7 +222,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
)(*args)
print(profile_result)
remote_mod(*args)
args = _args_to_local(args)
args = _args_to_numpy(args)
finally:
session.remove(remote_path)
session.remove(remote_path + "." + output_format)
Expand Down