diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index eb40b32e7c29..f60d0a5490f5 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -33,7 +33,11 @@ from .apply_history_best import ApplyHistoryBest from .extracted_task import ExtractedTask from .profiler import Profiler -from .relay_integration import extract_task_from_relay, is_meta_schedule_enabled +from .relay_integration import ( + extract_task_from_relay, + is_meta_schedule_dispatch_enabled, + is_meta_schedule_enabled, +) from .search_strategy import MeasureCandidate from .tune import TuneConfig, tune_extracted_tasks, tune_relay, tune_te, tune_tir from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 707b469aa456..bd12ac350a61 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -70,6 +70,7 @@ def extract_task_from_relay( The tasks extracted from this network """ # pylint: disable=import-outside-toplevel + from tvm import autotvm from tvm.relay import Function as RelayFunc # pylint: enable=import-outside-toplevel @@ -102,7 +103,14 @@ def extract_task_from_relay( config=pass_config, disabled_pass=disabled_pass, ): - return list(extract_task_func(mod, target, relay_params, te_filter_func)) + if target.kind.name != "cuda" and isinstance( + autotvm.DispatchContext.current, autotvm.FallbackContext + ): + tophub_context = autotvm.tophub.context(target) + else: + tophub_context = autotvm.utils.EmptyContext() + with tophub_context: + return list(extract_task_func(mod, target, relay_params, te_filter_func)) def is_meta_schedule_enabled() -> bool: @@ -117,3 +125,17 @@ def is_meta_schedule_enabled() -> bool: "relay.backend.use_meta_schedule", False, ) + + +def is_meta_schedule_dispatch_enabled() -> bool: + """Return whether the meta-schedule dispatch is enabled. + + Returns + ------- + enabled: bool + Whether the meta schedule is enabled + """ + return transform.PassContext.current().config.get( + "relay.backend.use_meta_schedule_dispatch", + False, + ) diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index fabf14ab23c7..fe4395d332ce 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -588,6 +588,9 @@ def tune_relay( with target, autotvm_silencer(), ApplyHistoryBest(database): with PassContext( opt_level=3, - config={"relay.backend.use_meta_schedule": True}, + config={ + "relay.backend.use_meta_schedule": True, + "relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda", + }, ): return relay_build(mod, target=target, params=params) diff --git a/python/tvm/relay/backend/te_compiler.py b/python/tvm/relay/backend/te_compiler.py index 3c87f45b8f7d..a2fbf555e12b 100644 --- a/python/tvm/relay/backend/te_compiler.py +++ b/python/tvm/relay/backend/te_compiler.py @@ -23,7 +23,8 @@ import numpy as np import tvm from tvm import autotvm, te -from tvm.ir.transform import PassContext +from tvm.auto_scheduler import is_auto_scheduler_enabled +from tvm.meta_schedule import is_meta_schedule_dispatch_enabled from tvm.runtime import Object from tvm.support import libinfo from tvm.target import Target @@ -180,7 +181,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) # Disable autotvm if auto_scheduler is enabled. # (i.e., always return the implementation with the highest priority for auto-scheduler). - if PassContext.current().config.get("relay.backend.use_auto_scheduler", False): + if is_auto_scheduler_enabled() or is_meta_schedule_dispatch_enabled(): use_autotvm = False # If not use autotvm, always return the implementation with the highest priority diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 072b958da213..9c4a896d572d 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -252,9 +252,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): ) # register auto-scheduler implementations - if ( - is_auto_scheduler_enabled() or is_meta_schedule_enabled() - ) and judge_winograd_auto_scheduler: + if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler: strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc), naive_schedule, # this implementation should never be picked by autotvm @@ -545,7 +543,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda", ) - if is_auto_scheduler_enabled() or is_meta_schedule_enabled(): + if is_auto_scheduler_enabled(): strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform), naive_schedule, # this implementation should never be picked by autotvm @@ -823,7 +821,7 @@ def matmul_strategy_cuda(attrs, inputs, out_type, target): """Matmul cuda strategy.""" strategy = _op.OpStrategy() - if is_auto_scheduler_enabled() or is_meta_schedule_enabled(): + if is_auto_scheduler_enabled(): strategy.add_implementation( wrap_compute_matmul(topi.nn.matmul), naive_schedule, diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 23ecb121f499..5e7c9119c95a 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -204,7 +204,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1) << " of file " << path_tuning_record << ". The workload is:\n" << (workload.defined() ? tir::AsTVMScript(workload) : "(null)") - << "\nThe JSONObject of TuningRecrod is:\n" + << "\nThe JSONObject of TuningRecord is:\n" << json_obj << "\nThe error message is:\n" << e.what(); } diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 08fa18b61e16..097b9153929b 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -552,6 +552,7 @@ TECompiler& TECompiler::Global() { } TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule_dispatch", Bool); TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() { return TECompiler::Global();