Skip to content

Commit 4f08249

Browse files
committed
Address review comments
1 parent b9f47dd commit 4f08249

File tree

3 files changed

+43
-37
lines changed

3 files changed

+43
-37
lines changed

python/tvm/meta_schedule/testing/tune_onnx.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,19 +140,22 @@ def main():
140140
alloc_repeat=1,
141141
max_workers=ARGS.rpc_workers,
142142
)
143-
lib = ms.tune_relay(
144-
mod=mod,
145-
target=ARGS.target,
146-
config=ms.TuneConfig(
147-
strategy="evolutionary",
148-
num_trials_per_iter=64,
149-
max_trials_per_task=ARGS.num_trials,
150-
max_trials_global=ARGS.num_trials,
151-
),
152-
runner=runner, # type: ignore
153-
work_dir=ARGS.work_dir,
154-
params=params,
155-
)
143+
with ms.Profiler() as profiler:
144+
lib = ms.tune_relay(
145+
mod=mod,
146+
target=ARGS.target,
147+
config=ms.TuneConfig(
148+
strategy="evolutionary",
149+
num_trials_per_iter=64,
150+
max_trials_per_task=ARGS.num_trials,
151+
max_trials_global=ARGS.num_trials,
152+
),
153+
runner=runner, # type: ignore
154+
work_dir=ARGS.work_dir,
155+
params=params,
156+
)
157+
print("Tuning Time:")
158+
print(profiler.table())
156159
graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
157160
input_data = {}
158161
for item in ARGS.input_shape:

python/tvm/meta_schedule/testing/tune_te.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -118,20 +118,23 @@ def main():
118118
alloc_repeat=1,
119119
max_workers=ARGS.rpc_workers,
120120
)
121-
sch: Optional[tir.Schedule] = ms.tune_tir(
122-
mod=create_te_workload(ARGS.workload, 0),
123-
target=ARGS.target,
124-
config=ms.TuneConfig(
125-
strategy="evolutionary",
126-
num_trials_per_iter=64,
127-
max_trials_per_task=ARGS.num_trials,
128-
max_trials_global=ARGS.num_trials,
129-
),
130-
runner=runner, # type: ignore
131-
task_name=ARGS.workload,
132-
work_dir=ARGS.work_dir,
133-
num_threads=cpu_count(),
134-
)
121+
with ms.Profiler() as profiler:
122+
sch: Optional[tir.Schedule] = ms.tune_tir(
123+
mod=create_te_workload(ARGS.workload, 0),
124+
target=ARGS.target,
125+
config=ms.TuneConfig(
126+
strategy="evolutionary",
127+
num_trials_per_iter=64,
128+
max_trials_per_task=ARGS.num_trials,
129+
max_trials_global=ARGS.num_trials,
130+
),
131+
runner=runner, # type: ignore
132+
task_name=ARGS.workload,
133+
work_dir=ARGS.work_dir,
134+
num_threads=cpu_count(),
135+
)
136+
print("Tuning Time:")
137+
print(profiler.table())
135138
if sch is None:
136139
print("No valid schedule found!")
137140
else:

python/tvm/meta_schedule/tune.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -430,15 +430,13 @@ def tune_tir(
430430
mutator_probs=mutator_probs,
431431
num_threads=num_threads,
432432
)
433-
bests: List[TuningRecord] = database.get_top_k(
434-
database.commit_workload(mod),
435-
top_k=1,
436-
)
437-
if not bests:
438-
return None
439-
assert len(bests) == 1
440-
sch = Schedule(mod)
441-
bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
433+
with Profiler.timeit("ApplyHistoryBest"):
434+
bests: List[TuningRecord] = database.get_top_k(database.commit_workload(mod), top_k=1)
435+
if not bests:
436+
return None
437+
assert len(bests) == 1
438+
sch = Schedule(mod)
439+
bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
442440
return sch
443441

444442

@@ -488,8 +486,10 @@ def tune_te(
488486
sch : Optional[Schedule]
489487
The tuned schedule.
490488
"""
489+
with Profiler.timeit("CreatePrimFunc"):
490+
func = create_prim_func(tensors)
491491
return tune_tir(
492-
mod=create_prim_func(tensors),
492+
mod=func,
493493
target=target,
494494
config=config,
495495
work_dir=work_dir,

0 commit comments

Comments
 (0)