|
22 | 22 |
|
23 | 23 | if TYPE_CHECKING: |
24 | 24 | import numpy as np |
| 25 | + |
25 | 26 | from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig |
26 | 27 | from tvm.runtime import Device, Module, NDArray |
27 | 28 |
|
|
30 | 31 |
|
31 | 32 | def _args_to_device(args, device): |
32 | 33 | import numpy as np |
| 34 | + |
33 | 35 | from tvm.runtime.ndarray import NDArray, empty |
34 | 36 |
|
35 | 37 | uploaded_args = [] |
@@ -109,6 +111,8 @@ def local_run( # pylint: disable=too-many-arguments,too-many-locals |
109 | 111 | ------- |
110 | 112 | args : List[Union[np.ndarray, NDArray, int, float]] |
111 | 113 | The results of running the module. |
| 114 | + profile_result : tvm.runtime.BenchmarkResult |
| 115 | + The profiling result of running the module. |
112 | 116 | """ |
113 | 117 | import os.path as osp |
114 | 118 | import tempfile |
@@ -137,13 +141,12 @@ def local_run( # pylint: disable=too-many-arguments,too-many-locals |
137 | 141 | if evaluator_config.enable_cpu_cache_flush |
138 | 142 | else "", |
139 | 143 | )(*args) |
140 | | - print(profile_result) |
141 | 144 | remote_mod(*args) |
142 | 145 | args = _args_to_numpy(args) |
143 | 146 | finally: |
144 | 147 | pass |
145 | 148 |
|
146 | | - return args |
| 149 | + return args, profile_result |
147 | 150 |
|
148 | 151 |
|
149 | 152 | def rpc_run( # pylint: disable=too-many-arguments,too-many-locals |
@@ -188,6 +191,8 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals |
188 | 191 | ------- |
189 | 192 | args : List[Union[np.ndarray, NDArray, int, float]] |
190 | 193 | The results of running the module. |
| 194 | + profile_result : tvm.runtime.BenchmarkResult |
| 195 | + The profiling result of running the module. |
191 | 196 | """ |
192 | 197 |
|
193 | 198 | import os.path as osp |
@@ -220,12 +225,11 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals |
220 | 225 | if evaluator_config.enable_cpu_cache_flush |
221 | 226 | else "", |
222 | 227 | )(*args) |
223 | | - print(profile_result) |
224 | 228 | remote_mod(*args) |
225 | 229 | args = _args_to_numpy(args) |
226 | 230 | finally: |
227 | 231 | session.remove(remote_path) |
228 | 232 | session.remove(remote_path + "." + output_format) |
229 | 233 | session.remove("") |
230 | 234 |
|
231 | | - return args |
| 235 | + return args, profile_result |
0 commit comments