2222
2323if TYPE_CHECKING :
2424 import numpy as np
25-
2625 from tvm .meta_schedule .runner import EvaluatorConfig , RPCConfig
2726 from tvm .runtime import Device , Module , NDArray
2827
2928# pylint: disable=import-outside-toplevel,protected-access
3029
3130
32- def _args_to_remote (args , device ):
31+ def _args_to_device (args , device ):
3332 import numpy as np
34-
3533 from tvm .runtime .ndarray import NDArray , empty
3634
3735 uploaded_args = []
@@ -45,7 +43,7 @@ def _args_to_remote(args, device):
4543 return uploaded_args
4644
4745
48- def _args_to_local (args ):
46+ def _args_to_numpy (args ):
4947 from tvm .runtime .ndarray import NDArray
5048
5149 downloaded_args = []
@@ -77,6 +75,77 @@ def export_with(func):
7775 return export_func , output_format
7876
7977
78+ def local_run ( # pylint: disable=too-many-arguments,too-many-locals
79+ mod : "Module" ,
80+ device_type : str ,
81+ args : List [Union ["np.ndarray" , "NDArray" , int , float ]],
82+ evaluator_config : Optional ["EvaluatorConfig" ] = None ,
83+ export_func : Union [Callable [["Module" , str ], None ], Literal ["tar" , "ndk" ]] = "tar" ,
84+ output_format : Optional [str ] = None ,
85+ ):
86+ """Run a TVM module on a local device.
87+
88+ Parameters
89+ ----------
90+ mod : Module
91+ The TVM module to run.
92+ device_type : str
93+ The device type to run the module on.
94+ args : List[Union[np.ndarray, NDArray, int, float]]
95+ The arguments to be fed to the module.
96+ evaluator_config : Optional[EvaluatorConfig]
97+ The evaluator configuration to use.
98+ export_func : Union[Callable[Module, str], Literal["tar", "ndk"]]
99+ The function to export the module to a file.
100+ If callable, it must be a function that takes two arguments: the module to export and the
101+ path to export to.
102+ If "tar", the module will be exported to a tar file.
103+ If "ndk", the module will be exported to a shared library.
104+ output_format : Optional[str]
105+ The format of the exported module.
106+ If not specified, it will be inferred from the `export_func` argument.
107+
108+ Returns
109+ -------
110+ args : List[Union[np.ndarray, NDArray, int, float]]
111+ The results of running the module.
112+ """
113+ import os .path as osp
114+ import tempfile
115+
116+ from tvm .meta_schedule .runner import EvaluatorConfig
117+ from tvm .runtime import device , load_module
118+
119+ evaluator_config = EvaluatorConfig ._normalized (evaluator_config )
120+ export_func , output_format = _normalize_export_func (export_func , output_format )
121+
122+ with tempfile .TemporaryDirectory () as tmp_dir :
123+ artifact_path = osp .join (tmp_dir , "tvm_tmp_mod." + output_format )
124+ export_func (mod , artifact_path )
125+ device : Device = device (device_type , 0 )
126+
127+ try :
128+ args = _args_to_device (args , device )
129+ remote_mod = load_module (artifact_path )
130+ profile_result = remote_mod .time_evaluator (
131+ func_name = remote_mod .entry_name ,
132+ dev = device ,
133+ number = evaluator_config .number ,
134+ repeat = evaluator_config .repeat ,
135+ min_repeat_ms = evaluator_config .min_repeat_ms ,
136+ f_preproc = "cache_flush_cpu_non_first_arg"
137+ if evaluator_config .enable_cpu_cache_flush
138+ else "" ,
139+ )(* args )
140+ print (profile_result )
141+ remote_mod (* args )
142+ args = _args_to_numpy (args )
143+ finally :
144+ pass
145+
146+ return args
147+
148+
80149def rpc_run ( # pylint: disable=too-many-arguments,too-many-locals
81150 mod : "Module" ,
82151 device_type : str ,
@@ -103,7 +172,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
103172 If not specified, the default RPC configuration will be used, which reads the following
104173 environment variables:
105174 - TVM_TRACKER_HOST
106- - TVM_TRACKER_PORmod
175+ - TVM_TRACKER_PORT
107176 - TVM_TRACKER_KEY
108177 export_func : Union[Callable[Module, str], Literal["tar", "ndk"]]
109178 The function to export the module to a file.
@@ -134,12 +203,12 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
134203 artifact_path = osp .join (tmp_dir , "tvm_tmp_mod." + output_format )
135204 _ , remote_path = osp .split (artifact_path )
136205 session = rpc_config .connect_server ()
137- device : Device = session .device (dev_type = device_type , dev_id = 0 )
206+ device : Device = session .device (device_type , 0 )
138207
139208 export_func (mod , artifact_path )
140209 try :
141210 session .upload (artifact_path , remote_path )
142- args = _args_to_remote (args , device )
211+ args = _args_to_device (args , device )
143212 remote_mod = session .load_module (remote_path )
144213 profile_result = remote_mod .time_evaluator (
145214 func_name = remote_mod .entry_name ,
@@ -153,7 +222,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
153222 )(* args )
154223 print (profile_result )
155224 remote_mod (* args )
156- args = _args_to_local (args )
225+ args = _args_to_numpy (args )
157226 finally :
158227 session .remove (remote_path )
159228 session .remove (remote_path + "." + output_format )
0 commit comments