|
22 | 22 |
|
23 | 23 | # isort: on |
24 | 24 |
|
25 | | -from tvm._ffi import get_global_func |
| 25 | +from tvm._ffi import get_global_func, register_func |
26 | 26 | from tvm.ir import IRModule |
27 | 27 | from tvm.ir.transform import PassContext |
28 | 28 | from tvm.runtime import NDArray |
29 | 29 | from tvm.target import Target |
| 30 | +from tvm.tir.expr import IntImm |
30 | 31 |
|
31 | 32 | from .builder import Builder |
32 | 33 | from .cost_model import CostModel |
@@ -223,6 +224,94 @@ def tune_relax( |
223 | 224 | ) |
224 | 225 |
|
225 | 226 |
|
| 227 | +@register_func("tvm.meta_schedule.tune_relax") |
| 228 | +def _tune_relax( |
| 229 | + mod: Union[IRModule, "relax.Function"], |
| 230 | + params: Dict[str, NDArray], |
| 231 | + target: Union[str, Target], |
| 232 | + work_dir: str, |
| 233 | + max_trials_global: int, |
| 234 | + *, |
| 235 | + max_trials_per_task: Optional[int] = None, |
| 236 | + num_trials_per_iter: int = 64, |
| 237 | + builder: Builder.BuilderType = "local", |
| 238 | + runner: Runner.RunnerType = "local", |
| 239 | + database: Database.DatabaseType = "json", |
| 240 | + cost_model: CostModel.CostModelType = "xgb", |
| 241 | + measure_callbacks: MeasureCallback.CallbackListType = "default", |
| 242 | + task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", |
| 243 | + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", |
| 244 | + strategy: SearchStrategy.SearchStrategyType = "evolutionary", |
| 245 | + seed: Optional[int] = None, |
| 246 | +) -> Database: |
| 247 | + """Interface with tuning api to tune a Relax program. |
| 248 | +
|
| 249 | + Parameters |
| 250 | + ---------- |
| 251 | + mod : Union[IRModule, relax.Function] |
| 252 | + The module or function to tune |
| 253 | + params : Optional[Dict[str, tvm.runtime.NDArray]] |
| 254 | + The associated parameters of the program |
| 255 | + target : Union[Target, str] |
| 256 | + The compilation target |
| 257 | + work_dir : str |
| 258 | + The working directory to store the tuning records |
| 259 | + max_trials_global : int |
| 260 | + The maximum number of trials to run |
| 261 | + max_trials_per_task : Optional[int] |
| 262 | + The maximum number of trials to run for each task |
| 263 | + num_trials_per_iter : int |
| 264 | + The number of trials to run per iteration |
| 265 | + builder : BuilderType |
| 266 | + The builder to use |
| 267 | + runner : RunnerType |
| 268 | + The runner to use |
| 269 | + database : DatabaseType |
| 270 | + The database to use |
| 271 | + cost_model : CostModelType |
| 272 | + The cost model to use |
| 273 | + measure_callbacks : CallbackListType |
| 274 | + The measure callbacks to use |
| 275 | + task_scheduler : TaskSchedulerType |
| 276 | + The task scheduler to use |
| 277 | + space : SpaceGeneratorType |
| 278 | + The space generator to use |
| 279 | + strategy : SearchStrategyType |
| 280 | + The search strategy to use |
| 281 | + seed : Optional[int] |
| 282 | + The random seed |
| 283 | +
|
| 284 | + Returns |
| 285 | + ------- |
| 286 | + ret_mod : IRModule |
| 287 | + IRModule |
| 288 | + """ |
| 289 | + if isinstance(max_trials_global, IntImm): |
| 290 | + max_trials_global = int(max_trials_global) |
| 291 | + |
| 292 | + tune_relax( |
| 293 | + mod, |
| 294 | + params, |
| 295 | + target, |
| 296 | + work_dir, |
| 297 | + max_trials_global, |
| 298 | + max_trials_per_task=max_trials_per_task, |
| 299 | + num_trials_per_iter=num_trials_per_iter, |
| 300 | + builder=builder, |
| 301 | + runner=runner, |
| 302 | + database=database, |
| 303 | + cost_model=cost_model, |
| 304 | + measure_callbacks=measure_callbacks, |
| 305 | + task_scheduler=task_scheduler, |
| 306 | + space=space, |
| 307 | + strategy=strategy, |
| 308 | + seed=seed, |
| 309 | + ) |
| 310 | + # Return original IRModule |
| 311 | + # This pass only makes optimization decision |
| 312 | + return mod |
| 313 | + |
| 314 | + |
226 | 315 | def compile_relax( |
227 | 316 | database: Database, |
228 | 317 | mod: IRModule, |
|
0 commit comments