Skip to content

Commit beea0d2

Browse files
authored
[MetaSchedule] Fix Task Extraction (#11954)
1 parent 26ad703 commit beea0d2

File tree

7 files changed

+40
-11
lines changed

7 files changed

+40
-11
lines changed

python/tvm/meta_schedule/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@
3333
from .apply_history_best import ApplyHistoryBest
3434
from .extracted_task import ExtractedTask
3535
from .profiler import Profiler
36-
from .relay_integration import extract_task_from_relay, is_meta_schedule_enabled
36+
from .relay_integration import (
37+
extract_task_from_relay,
38+
is_meta_schedule_dispatch_enabled,
39+
is_meta_schedule_enabled,
40+
)
3741
from .search_strategy import MeasureCandidate
3842
from .tune import TuneConfig, tune_extracted_tasks, tune_relay, tune_te, tune_tir
3943
from .tune_context import TuneContext

python/tvm/meta_schedule/relay_integration.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def extract_task_from_relay(
7070
The tasks extracted from this network
7171
"""
7272
# pylint: disable=import-outside-toplevel
73+
from tvm import autotvm
7374
from tvm.relay import Function as RelayFunc
7475

7576
# pylint: enable=import-outside-toplevel
@@ -102,7 +103,14 @@ def extract_task_from_relay(
102103
config=pass_config,
103104
disabled_pass=disabled_pass,
104105
):
105-
return list(extract_task_func(mod, target, relay_params, te_filter_func))
106+
if target.kind.name != "cuda" and isinstance(
107+
autotvm.DispatchContext.current, autotvm.FallbackContext
108+
):
109+
tophub_context = autotvm.tophub.context(target)
110+
else:
111+
tophub_context = autotvm.utils.EmptyContext()
112+
with tophub_context:
113+
return list(extract_task_func(mod, target, relay_params, te_filter_func))
106114

107115

108116
def is_meta_schedule_enabled() -> bool:
@@ -117,3 +125,17 @@ def is_meta_schedule_enabled() -> bool:
117125
"relay.backend.use_meta_schedule",
118126
False,
119127
)
128+
129+
130+
def is_meta_schedule_dispatch_enabled() -> bool:
131+
"""Return whether the meta-schedule dispatch is enabled.
132+
133+
Returns
134+
-------
135+
enabled: bool
136+
Whether the meta schedule is enabled
137+
"""
138+
return transform.PassContext.current().config.get(
139+
"relay.backend.use_meta_schedule_dispatch",
140+
False,
141+
)

python/tvm/meta_schedule/tune.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,9 @@ def tune_relay(
592592
with target, autotvm_silencer(), ApplyHistoryBest(database):
593593
with PassContext(
594594
opt_level=3,
595-
config={"relay.backend.use_meta_schedule": True},
595+
config={
596+
"relay.backend.use_meta_schedule": True,
597+
"relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda",
598+
},
596599
):
597600
return relay_build(mod, target=target, params=params)

python/tvm/relay/backend/te_compiler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
import numpy as np
2424
import tvm
2525
from tvm import autotvm, te
26-
from tvm.ir.transform import PassContext
26+
from tvm.auto_scheduler import is_auto_scheduler_enabled
27+
from tvm.meta_schedule import is_meta_schedule_dispatch_enabled
2728
from tvm.runtime import Object
2829
from tvm.support import libinfo
2930
from tvm.target import Target
@@ -180,7 +181,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
180181

181182
# Disable autotvm if auto_scheduler is enabled.
182183
# (i.e., always return the implementation with the highest priority for auto-scheduler).
183-
if PassContext.current().config.get("relay.backend.use_auto_scheduler", False):
184+
if is_auto_scheduler_enabled() or is_meta_schedule_dispatch_enabled():
184185
use_autotvm = False
185186

186187
# If not use autotvm, always return the implementation with the highest priority

python/tvm/relay/op/strategy/cuda.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
252252
)
253253

254254
# register auto-scheduler implementations
255-
if (
256-
is_auto_scheduler_enabled() or is_meta_schedule_enabled()
257-
) and judge_winograd_auto_scheduler:
255+
if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
258256
strategy.add_implementation(
259257
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
260258
naive_schedule, # this implementation should never be picked by autotvm
@@ -545,7 +543,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
545543
name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda",
546544
)
547545

548-
if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
546+
if is_auto_scheduler_enabled():
549547
strategy.add_implementation(
550548
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
551549
naive_schedule, # this implementation should never be picked by autotvm
@@ -823,7 +821,7 @@ def matmul_strategy_cuda(attrs, inputs, out_type, target):
823821
"""Matmul cuda strategy."""
824822
strategy = _op.OpStrategy()
825823

826-
if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
824+
if is_auto_scheduler_enabled():
827825
strategy.add_implementation(
828826
wrap_compute_matmul(topi.nn.matmul),
829827
naive_schedule,

src/meta_schedule/database/json_database.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record,
204204
LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1)
205205
<< " of file " << path_tuning_record << ". The workload is:\n"
206206
<< (workload.defined() ? tir::AsTVMScript(workload) : "(null)")
207-
<< "\nThe JSONObject of TuningRecrod is:\n"
207+
<< "\nThe JSONObject of TuningRecord is:\n"
208208
<< json_obj << "\nThe error message is:\n"
209209
<< e.what();
210210
}

src/relay/backend/te_compiler.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,7 @@ TECompiler& TECompiler::Global() {
552552
}
553553
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
554554
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool);
555+
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule_dispatch", Bool);
555556

556557
TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() {
557558
return TECompiler::Global();

0 commit comments

Comments
 (0)