Skip to content

Commit 1783c42

Browse files
sungggjunrushao
authored and committed
Recover: [Pass] Separate ApplyHistoryBest from tuning passes (apache#226)
1 parent 9e6acd0 commit 1783c42

File tree

7 files changed

+422
-71
lines changed

7 files changed

+422
-71
lines changed

python/tvm/meta_schedule/relax_integration.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222

2323
# isort: on
2424

25-
from tvm._ffi import get_global_func
25+
from tvm._ffi import get_global_func, register_func
2626
from tvm.ir import IRModule
2727
from tvm.ir.transform import PassContext
2828
from tvm.runtime import NDArray
2929
from tvm.target import Target
30+
from tvm.tir.expr import IntImm
3031

3132
from .builder import Builder
3233
from .cost_model import CostModel
@@ -223,6 +224,94 @@ def tune_relax(
223224
)
224225

225226

227+
@register_func("tvm.meta_schedule.tune_relax")
228+
def _tune_relax(
229+
mod: Union[IRModule, "relax.Function"],
230+
params: Dict[str, NDArray],
231+
target: Union[str, Target],
232+
work_dir: str,
233+
max_trials_global: int,
234+
*,
235+
max_trials_per_task: Optional[int] = None,
236+
num_trials_per_iter: int = 64,
237+
builder: Builder.BuilderType = "local",
238+
runner: Runner.RunnerType = "local",
239+
database: Database.DatabaseType = "json",
240+
cost_model: CostModel.CostModelType = "xgb",
241+
measure_callbacks: MeasureCallback.CallbackListType = "default",
242+
task_scheduler: TaskScheduler.TaskSchedulerType = "gradient",
243+
space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
244+
strategy: SearchStrategy.SearchStrategyType = "evolutionary",
245+
seed: Optional[int] = None,
246+
) -> Database:
247+
"""Interface with tuning api to tune a Relax program.
248+
249+
Parameters
250+
----------
251+
mod : Union[IRModule, relax.Function]
252+
The module or function to tune
253+
params : Optional[Dict[str, tvm.runtime.NDArray]]
254+
The associated parameters of the program
255+
target : Union[Target, str]
256+
The compilation target
257+
work_dir : str
258+
The working directory to store the tuning records
259+
max_trials_global : int
260+
The maximum number of trials to run
261+
max_trials_per_task : Optional[int]
262+
The maximum number of trials to run for each task
263+
num_trials_per_iter : int
264+
The number of trials to run per iteration
265+
builder : BuilderType
266+
The builder to use
267+
runner : RunnerType
268+
The runner to use
269+
database : DatabaseType
270+
The database to use
271+
cost_model : CostModelType
272+
The cost model to use
273+
measure_callbacks : CallbackListType
274+
The measure callbacks to use
275+
task_scheduler : TaskSchedulerType
276+
The task scheduler to use
277+
space : SpaceGeneratorType
278+
The space generator to use
279+
strategy : SearchStrategyType
280+
The search strategy to use
281+
seed : Optional[int]
282+
The random seed
283+
284+
Returns
285+
-------
286+
ret_mod : IRModule
287+
IRModule
288+
"""
289+
if isinstance(max_trials_global, IntImm):
290+
max_trials_global = int(max_trials_global)
291+
292+
tune_relax(
293+
mod,
294+
params,
295+
target,
296+
work_dir,
297+
max_trials_global,
298+
max_trials_per_task=max_trials_per_task,
299+
num_trials_per_iter=num_trials_per_iter,
300+
builder=builder,
301+
runner=runner,
302+
database=database,
303+
cost_model=cost_model,
304+
measure_callbacks=measure_callbacks,
305+
task_scheduler=task_scheduler,
306+
space=space,
307+
strategy=strategy,
308+
seed=seed,
309+
)
310+
# Return original IRModule
311+
# This pass only makes optimization decision
312+
return mod
313+
314+
226315
def compile_relax(
227316
database: Database,
228317
mod: IRModule,

python/tvm/meta_schedule/tir_integration.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222

2323
# isort: on
2424
from tvm import ir, tir
25+
from tvm._ffi import register_func
2526
from tvm.target import Target
27+
from tvm.tir.expr import IntImm
2628

2729
from .builder import Builder
2830
from .cost_model import CostModel
@@ -128,6 +130,93 @@ def tune_tir(
128130
)
129131

130132

133+
@register_func("tvm.meta_schedule.tune_tir")
134+
def _tune_tir(
135+
mod: Union[ir.IRModule, tir.PrimFunc],
136+
target: Union[str, Target],
137+
work_dir: str,
138+
max_trials_global: int,
139+
*,
140+
num_trials_per_iter: int = 64,
141+
builder: Builder.BuilderType = "local",
142+
runner: Runner.RunnerType = "local",
143+
database: Database.DatabaseType = "json",
144+
cost_model: CostModel.CostModelType = "xgb",
145+
measure_callbacks: MeasureCallback.CallbackListType = "default",
146+
task_scheduler: TaskScheduler.TaskSchedulerType = "round-robin",
147+
space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
148+
strategy: SearchStrategy.SearchStrategyType = "evolutionary",
149+
task_name: str = "main",
150+
num_threads: Union[Literal["physical", "logical"], int] = "physical",
151+
seed: Optional[int] = None,
152+
) -> Database:
153+
"""Interface with tuning api to tune a TIR program.
154+
155+
Parameters
156+
----------
157+
mod : Union[ir.IRModule, tir.PrimFunc]
158+
The TIR function to tune.
159+
target : Union[str, Target]
160+
The target to tune for.
161+
work_dir : str
162+
The working directory.
163+
max_trials_global : int
164+
The maximum number of trials to run globally.
165+
num_trials_per_iter : int
166+
The number of trials to run per iteration
167+
builder : Builder.BuilderType
168+
The builder.
169+
runner : Runner.RunnerType
170+
The runner.
171+
database : Database.DatabaseType
172+
The database.
173+
cost_model : CostModel.CostModelType
174+
The cost model.
175+
measure_callbacks : MeasureCallback.CallbackListType
176+
The measure callbacks.
177+
task_scheduler : TaskScheduler.TaskSchedulerType
178+
The task scheduler.
179+
space : SpaceGenerator.SpaceGeneratorType
180+
The space generator.
181+
strategy : SearchStrategy.SearchStrategyType
182+
The search strategy.
183+
task_name : str
184+
The name of the task.
185+
num_threads : Union[Literal["physical", "logical"], int]
186+
The number of threads to use.
187+
seed : Optional[int]
188+
The seed for the random number generator.
189+
190+
Returns
191+
-------
192+
ret_mod : IRModule
193+
IRModule
194+
"""
195+
if isinstance(max_trials_global, IntImm):
196+
max_trials_global = int(max_trials_global)
197+
tune_tir(
198+
mod,
199+
target,
200+
work_dir,
201+
max_trials_global,
202+
num_trials_per_iter=num_trials_per_iter,
203+
builder=builder,
204+
runner=runner,
205+
database=database,
206+
cost_model=cost_model,
207+
measure_callbacks=measure_callbacks,
208+
task_scheduler=task_scheduler,
209+
space=space,
210+
strategy=strategy,
211+
task_name=task_name,
212+
num_threads=num_threads,
213+
seed=seed,
214+
)
215+
# Return original IRModule
216+
# This pass only makes optimization decision
217+
return mod
218+
219+
131220
def compile_tir(
132221
database: Database,
133222
mod: Union[ir.IRModule, tir.PrimFunc],

python/tvm/meta_schedule/tune_context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
# isort: on
2525

2626
from tvm import IRModule
27-
from tvm._ffi import register_object
27+
from tvm._ffi import register_object, register_func
2828
from tvm.runtime import Object
2929
from tvm.target import Target
3030
from tvm.tir import PrimFunc, Schedule
@@ -41,6 +41,7 @@
4141
from .space_generator import SpaceGenerator
4242

4343

44+
@register_func("tvm.meta_schedule.normalize_mod")
4445
def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule:
4546
"""Normalize the input to an IRModule"""
4647
if isinstance(mod, PrimFunc):

python/tvm/relax/transform/transform.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323

2424
import numpy as np
2525
import tvm.ir
26+
from tvm.runtime import NDArray
27+
from tvm.target import Target
28+
from tvm.meta_schedule.database import PyDatabase
2629

2730
from . import _ffi_api
2831

@@ -283,15 +286,58 @@ def FuseTIR() -> tvm.ir.transform.Pass:
283286
return _ffi_api.FuseTIR()
284287

285288

286-
def MetaScheduleApplyDatabase() -> tvm.ir.transform.Pass:
289+
def MetaScheduleApplyDatabase(
290+
work_dir: Optional[str] = None,
291+
) -> tvm.ir.transform.Pass:
287292
"""Apply the best schedule from tuning database.
288-
293+
work_dir : Optional[str]
294+
work directory to deduce default database if database is not provided
295+
(it will be ignored when an user passes database)
289296
Returns
290297
-------
291298
ret : tvm.transform.Pass
292-
The registered pass for tir fusion.
299+
The registered pass
300+
"""
301+
return _ffi_api.MetaScheduleApplyDatabase(work_dir)
302+
303+
304+
def MetaScheduleTuneTIR(
305+
work_dir: str,
306+
max_trials_global: int,
307+
) -> tvm.ir.transform.Pass:
308+
"""Tune TIR with MetaSchedule.
309+
Parameters
310+
----------
311+
work_dir: str
312+
work directory
313+
max_trials_gloabl: int
314+
maximum number of total trials allowed for tuning
315+
Returns
316+
-------
317+
ret: tvm.ir.transform.Pass
318+
"""
319+
return _ffi_api.MetaScheduleTuneTIR(work_dir, max_trials_global)
320+
321+
322+
def MetaScheduleTuneIRMod(
323+
params: Dict[str, NDArray],
324+
work_dir: str,
325+
max_trials_global: int,
326+
) -> tvm.ir.transform.Pass:
327+
"""Tune Relax IRModule with MetaSchedule.
328+
Parameters
329+
----------
330+
params: Dict[str, NDArray]
331+
model params
332+
work_dir: str
333+
work directory
334+
max_trials_gloabl: int
335+
maximum number of total trials allowed for tuning
336+
Returns
337+
-------
338+
ret: tvm.ir.transform.Pass
293339
"""
294-
return _ffi_api.MetaScheduleApplyDatabase()
340+
return _ffi_api.MetaScheduleTuneIRMod(params, work_dir, max_trials_global)
295341

296342

297343
def _wrap_class_function_pass(pass_cls, pass_info):

src/relax/backend/task_extraction.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ class TaskExtractor : public ExprVisitor {
5757

5858
private:
5959
explicit TaskExtractor(IRModule mod, Target target)
60-
: mod_(std::move(mod)), target_(std::move(target)) {}
60+
: mod_(std::move(mod)), target_(std::move(target)) {
61+
normalize_mod_func_ = runtime::Registry::Get("tvm.meta_schedule.normalize_mod");
62+
ICHECK(normalize_mod_func_) << "Normalization function is not found.";
63+
}
6164

6265
void VisitExpr_(const CallNode* call) final {
6366
static const Op& call_tir_op = Op::Get("relax.call_tir");
@@ -76,7 +79,7 @@ class TaskExtractor : public ExprVisitor {
7679
return;
7780
}
7881

79-
IRModule tir_mod({{global_var, func}});
82+
IRModule tir_mod = (*normalize_mod_func_)(func);
8083
ExtractedTask task(/*task_name=*/global_var->name_hint, //
8184
/*mod=*/tir_mod, //
8285
/*target=*/target_, //
@@ -90,6 +93,7 @@ class TaskExtractor : public ExprVisitor {
9093
Target target_;
9194
Array<ExtractedTask> tasks_;
9295
std::unordered_map<tir::PrimFunc, ExtractedTask, StructuralHash, StructuralEqual> func2task_;
96+
const runtime::PackedFunc* normalize_mod_func_;
9397
};
9498

9599
TVM_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask")

0 commit comments

Comments
 (0)