diff --git a/include/tvm/meta_schedule/apply_history_best.h b/include/tvm/meta_schedule/apply_history_best.h new file mode 100644 index 000000000000..9d6f46dd6c43 --- /dev/null +++ b/include/tvm/meta_schedule/apply_history_best.h @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_ +#define TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_ + +#include +#include + +namespace tvm { +namespace meta_schedule { + +/*! + * \brief An integration context that allows application of historically best records from a + * database + */ +class ApplyHistoryBestNode : public runtime::Object { + public: + /*! \brief The database to be queried from */ + Database database{nullptr}; + + void VisitAttrs(AttrVisitor* v) { v->Visit("database", &database); } + /*! + * \brief Query the best entry from the database + * \param task_name The name of the task to be queried + * \param mod The module to be queried + * \param target The target to be queried + * \param dispatched The IRs after dispatch + */ + Optional Query(runtime::String task_name, IRModule mod, Target target, + Optional> dispatched); + + static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest"; + TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, runtime::Object); +}; + +/*! + * \brief Managed reference to ApplyHistoryBestNode + * \sa ApplyHistoryBestNode + */ +class ApplyHistoryBest : public runtime::ObjectRef { + public: + /*! + * \brief Constructor + * \param database The database to be queried from + */ + explicit ApplyHistoryBest(Database database); + /*! + * \brief The current ApplyHistoryBest in the context + * \return The ApplyHistoryBest in the current scope. + */ + static Optional Current(); + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ApplyHistoryBest, runtime::ObjectRef, + ApplyHistoryBestNode); + + protected: + friend class ApplyHistoryBestInternal; + /*! \brief Entering the scope of the context manager */ + void EnterWithScope(); + /*! \brief Exiting the scope of the context manager */ + void ExitWithScope(); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_ diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h new file mode 100644 index 000000000000..c6613427fd5b --- /dev/null +++ b/include/tvm/meta_schedule/extracted_task.h @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_EXTRACTED_TASK_H_ +#define TVM_META_SCHEDULE_EXTRACTED_TASK_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +/*! \brief A tuning task extracted from the high-level IR */ +class ExtractedTaskNode : public runtime::Object { + public: + /*! \brief The name of the task extracted */ + String task_name; + /*! \brief The high-level IR */ + IRModule mod; + /*! \brief Target */ + Target target; + /*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */ + Array dispatched; + /*! \brief Weight of the task */ + int weight; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("task_name", &task_name); + v->Visit("mod", &mod); + v->Visit("target", &target); + v->Visit("dispatched", &dispatched); + v->Visit("weight", &weight); + } + + static constexpr const char* _type_key = "meta_schedule.ExtractedTask"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExtractedTaskNode, runtime::Object); +}; + +/*! + * \brief Managed reference to ExtractedTaskNode + * \sa ExtractedTaskNode + */ +class ExtractedTask : public runtime::ObjectRef { + public: + explicit ExtractedTask(String task_name, IRModule mod, Target target, Array dispatched, + int weight); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, + ExtractedTaskNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_EXTRACTED_TASK_H_ diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h deleted file mode 100644 index b231913f2f9b..000000000000 --- a/include/tvm/meta_schedule/integration.h +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_META_SCHEDULE_INTEGRATION_H_ -#define TVM_META_SCHEDULE_INTEGRATION_H_ - -#include -#include -#include - -#include - -namespace tvm { -namespace meta_schedule { - -/**************** ExtractedTask ****************/ - -/*! - * \brief A tuning task extracted from the high-level IR - */ -class ExtractedTaskNode : public runtime::Object { - public: - /*! \brief The name of the task extracted */ - String task_name; - /*! \brief The high-level IR */ - IRModule mod; - /*! \brief Target */ - Target target; - /*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */ - Array dispatched; - /*! \brief Weight of the task */ - int weight; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("task_name", &task_name); - v->Visit("mod", &mod); - v->Visit("target", &target); - v->Visit("dispatched", &dispatched); - v->Visit("weight", &weight); - } - - static constexpr const char* _type_key = "meta_schedule.ExtractedTask"; - TVM_DECLARE_FINAL_OBJECT_INFO(ExtractedTaskNode, runtime::Object); -}; - -/*! - * \brief Managed reference to ExtractedTaskNode - * \sa ExtractedTaskNode - */ -class ExtractedTask : public runtime::ObjectRef { - public: - /*! - * \brief Constructor. The name of the task extracted - * \brief The high-level IR - * \brief A list of low-level IRs that the high-level IR could potentially dispatch to - */ - explicit ExtractedTask(String task_name, IRModule mod, Target target, Array dispatched, - int weight); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, - ExtractedTaskNode); -}; - -/**************** MetaScheduleContext ****************/ - -/*! - * \brief A context manager interface for the integration - */ -class MetaScheduleContextNode : public runtime::Object { - public: - /*! \brief Default destructor */ - virtual ~MetaScheduleContextNode() = default; - /*! - * \brief The entry point of the integration - * \param task_name The name of the task - * \param mod The high-level IR - * \param target Target info - * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to. - * NullOpt means the dispatch needs to be done in the context. - * \return IRModule or NullOpt Currently we only have to return tir::PrimFunc, but we wrap it - * under IRModule for more general future use. NullOpt is returned - * if there is no feedback hint. - */ - virtual Optional Query(runtime::String task_name, IRModule mod, Target target, - Optional> dispatched) = 0; - - static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext"; - TVM_DECLARE_BASE_OBJECT_INFO(MetaScheduleContextNode, runtime::Object); -}; - -/*! - * \brief Managed reference to MetaScheduleContextNode - * \sa MetaScheduleContextNode - */ -class MetaScheduleContext : public runtime::ObjectRef { - friend class MetaScheduleContextInternal; - friend class With; - - public: - /*! \brief Default destructor */ - virtual ~MetaScheduleContext() = default; - /*! - * \brief The context manager in the current scope - * \return The MetaScheduleContext in the current scope. NullOpt if it's currently not under any - * MetaScheduleContext. - */ - static Optional Current(); - /*! - * \brief The entry point of the integration workflow. The compilation process of the high-level - * IR should call this method for task extraction and for feedback hints - * \param task_name The name of the task - * \param mod The high-level IR - * \param target Target info - * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to - * \return IRModule or NullOpt Currently we only have to return tir::PrimFunc, but we wrap it - * under IRModule for more general future use. NullOpt is returned - * if there is no feedback hint - */ - static Optional QueryInsideWithScope(runtime::String task_name, IRModule mod, - Target target, - Optional> dispatched); - - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetaScheduleContext, runtime::ObjectRef, - MetaScheduleContextNode); - - protected: - /*! \brief Default constructor */ - MetaScheduleContext() = default; - /*! \brief Entering the scope of the context manager */ - void EnterWithScope(); - /*! \brief Exiting the scope of the context manager */ - void ExitWithScope(); -}; - -/**************** ApplyHistoryBest ****************/ - -/*! - * \brief An integration context that allows application of historically best records from a - * database - */ -class ApplyHistoryBestNode : public MetaScheduleContextNode { - public: - /*! \brief The database to be queried from */ - Database database{nullptr}; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("database", &database); // - } - - // Inherited from base class - Optional Query(runtime::String task_name, IRModule mod, Target target, - Optional> dispatched) final; - - static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest"; - TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, MetaScheduleContextNode); -}; - -/*! - * \brief Managed reference to ApplyHistoryBestNode - * \sa ApplyHistoryBestNode - */ -class ApplyHistoryBest : public MetaScheduleContext { - public: - /*! - * \brief Constructor - * \param database The database to be queried from - */ - explicit ApplyHistoryBest(Database database); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ApplyHistoryBest, MetaScheduleContext, - ApplyHistoryBestNode); -}; - -} // namespace meta_schedule -} // namespace tvm - -#endif // TVM_META_SCHEDULE_INTEGRATION_H_ diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 3612bb81a6bc..466c5e3e6699 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -21,7 +21,6 @@ cost_model, database, feature_extractor, - integration, mutator, postproc, runner, @@ -29,6 +28,9 @@ search_strategy, space_generator, ) +from .apply_history_best import ApplyHistoryBest +from .extracted_task import ExtractedTask +from .relay_integration import extract_task_from_relay from .search_strategy import MeasureCandidate from .tune import ( EvolutionarySearchConfig, diff --git a/python/tvm/meta_schedule/apply_history_best.py b/python/tvm/meta_schedule/apply_history_best.py new file mode 100644 index 000000000000..5e1e40bd154b --- /dev/null +++ b/python/tvm/meta_schedule/apply_history_best.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A context manager that injects the best tuning record in the database into compilation""" +from typing import List, Optional, Union + +from tvm._ffi import register_object +from tvm.ir import IRModule +from tvm.runtime import Object +from tvm.target import Target + +from . import _ffi_api +from .database import Database + + +@register_object("meta_schedule.ApplyHistoryBest") +class ApplyHistoryBest(Object): + """An integration context that allows application of historically best records from a database + + Parameters + ---------- + database : Database + The database to be queried from + """ + + database: Database + + def __init__( + self, + database: Database, + ) -> None: + self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member + + def query( + self, + task_name: str, + mod: IRModule, + target: Target, + dispatched: Optional[List[IRModule]], + ) -> Union[IRModule, None]: + """The entry point of the integration + + Parameters + ---------- + task_name : str + The name of the task extracted + mod : IRModule + The high-level IR + target: Target + Target Info + dispatched : Optional[List[IRModule]] + A list of low-level IRs that the high-level IR could potentially dispatch to + + Returns + ------- + result : IRModule or None + Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for + more general future use. None is returned if there is no feedback hint. + """ + return _ffi_api.ApplyHistoryBestQuery( # type: ignore # pylint: disable=no-member + self, + task_name, + mod, + target, + dispatched, + ) + + @staticmethod + def current() -> Optional["ApplyHistoryBest"]: + """The context manager in the current scope + + Returns + ------- + ctx : Optional[ApplyHistoryBest] + The ApplyHistoryBest context manager in the current scope. + None if it's currently not under any ApplyHistoryBest context. + """ + return _ffi_api.ApplyHistoryBestCurrent() # type: ignore # pylint: disable=no-member + + def __enter__(self) -> "ApplyHistoryBest": + """Entering the scope of the context manager""" + _ffi_api.ApplyHistoryBestEnterScope(self) # type: ignore # pylint: disable=no-member + return self + + def __exit__(self, ptype, value, trace) -> None: + """Exiting the scope of the context manager""" + _ffi_api.ApplyHistoryBestExitScope(self) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/extracted_task.py b/python/tvm/meta_schedule/extracted_task.py new file mode 100644 index 000000000000..b69a38ef6dc0 --- /dev/null +++ b/python/tvm/meta_schedule/extracted_task.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Extracted tasks from high-level IR.""" +from typing import List + +from tvm._ffi import register_object +from tvm.ir import IRModule +from tvm.runtime import Object +from tvm.target import Target + +from . import _ffi_api + + +@register_object("meta_schedule.ExtractedTask") +class ExtractedTask(Object): + """A tuning task extracted from the high-level IR + + Parameters + ---------- + task_name : str + The name of the task extracted + mod : IRModule + The high-level IR + target: Target + Target information + dispatched : List[IRModule] + A list of low-level IRs that the high-level IR could potentially dispatch to + weight : int + The weight of the task + """ + + task_name: str + mod: IRModule + dispatched: List[IRModule] + weight: int + + def __init__( + self, + task_name: str, + mod: IRModule, + target: Target, + dispatched: List[IRModule], + weight: int, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ExtractedTask, # type: ignore # pylint: disable=no-member + task_name, + mod, + target, + dispatched, + weight, + ) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py deleted file mode 100644 index db6771fecafc..000000000000 --- a/python/tvm/meta_schedule/integration.py +++ /dev/null @@ -1,247 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Meta schedule integration with high-level IR""" -from typing import Dict, List, Optional, Union - -import numpy as np # type: ignore -import tvm.runtime.ndarray as nd -from tvm._ffi import get_global_func, register_object -from tvm.ir import IRModule, transform -from tvm.relay import Any -from tvm.relay import Function as RelayFunc -from tvm.runtime import NDArray, Object -from tvm.target import Target - -from . import _ffi_api -from .database import Database -from .utils import autotvm_silencer - - -@register_object("meta_schedule.ExtractedTask") -class ExtractedTask(Object): - """A tuning task extracted from the high-level IR - - Parameters - ---------- - task_name : str - The name of the task extracted - mod : IRModule - The high-level IR - target: Target - Target information - dispatched : List[IRModule] - A list of low-level IRs that the high-level IR could potentially dispatch to - weight : int - The weight of the task - """ - - task_name: str - mod: IRModule - dispatched: List[IRModule] - weight: int - - def __init__( - self, - task_name: str, - mod: IRModule, - target: Target, - dispatched: List[IRModule], - weight: int, - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.ExtractedTask, # type: ignore # pylint: disable=no-member - task_name, - mod, - target, - dispatched, - weight, - ) - - -@register_object("meta_schedule.MetaScheduleContext") -class MetaScheduleContext(Object): - """A context manager interface for the integration""" - - def query( - self, - task_name: str, - mod: IRModule, - target: Target, - dispatched: Optional[List[IRModule]], - ) -> Union[IRModule, None]: - """The entry point of the integration - - Parameters - ---------- - task_name : str - The name of the task extracted - mod : IRModule - The high-level IR - target: Target - Target Info - dispatched : Optional[List[IRModule]] - A list of low-level IRs that the high-level IR could potentially dispatch to - - Returns - ------- - result : IRModule or None - Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for - more general future use. None is returned if there is no feedback hint. - """ - return _ffi_api.MetaScheduleContextQuery( # type: ignore # pylint: disable=no-member - self, - task_name, - mod, - target, - dispatched, - ) - - @staticmethod - def current() -> Optional["MetaScheduleContext"]: - """The context manager in the current scope - - Returns - ------- - ctx : Optional[MetaScheduleContext] - The MetaScheduleContext in the current scope. - NullOpt if it's currently not under any MetaScheduleContext. - """ - return _ffi_api.MetaScheduleContextCurrent() # type: ignore # pylint: disable=no-member - - @staticmethod - def query_inside_with_scope( - task_name: str, - mod: IRModule, - target: Target, - dispatched: Optional[List[IRModule]], - ) -> Union[IRModule, None]: - """The entry point of the integration workflow. The compilation process of the high-level - IR should call this method for task extraction and for feedback hints - - Basically, this method is equivalent to: - - .. code-block:: python - - def query_inside_with_scope(task_name, mod, dispatched): - ctx = MetaScheduleContext.current() - assert ctx is not None - mod = ctx.query(task_name, mod, target, dispatched) - - Parameters - ---------- - task_name : str - The name of the task - mod : IRModule - The high-level IR - target: Target - Target - dispatched : Optional[List[IRModule]] - A list of low-level IRs that the high-level IR could potentially dispatch to - - Returns - ------- - result : IRModule or None - Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for - more general future use. None is returned if there is no feedback hint. - """ - return _ffi_api.MetaScheduleContextQueryInsideWithScope( # type: ignore # pylint: disable=no-member - task_name, - mod, - target, - dispatched, - ) - - def __enter__(self) -> "MetaScheduleContext": - """Entering the scope of the context manager""" - _ffi_api.MetaScheduleContextEnterScope(self) # type: ignore # pylint: disable=no-member - return self - - def __exit__(self, ptype, value, trace) -> None: - """Exiting the scope of the context manager""" - _ffi_api.MetaScheduleContextExitScope(self) # type: ignore # pylint: disable=no-member - - -@register_object("meta_schedule.ApplyHistoryBest") -class ApplyHistoryBest(MetaScheduleContext): - """An integration context that allows application of historically best record from database""" - - database: Database - """ The database to be queried from""" - - def __init__(self, database) -> None: - self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member - - -def extract_task_from_relay( - mod: Union[IRModule, RelayFunc], - target: Target, - params: Optional[Dict[str, NDArray]] = None, - *, - opt_level: int = 3, - pass_config: Optional[Dict[str, Any]] = None, - disabled_pass: Optional[List[str]] = None, -) -> List[ExtractedTask]: - """Extract tuning tasks from a relay program. - - Parameters - ---------- - mod : Union[tvm.IRModule, tvm.relay.Function] - The module or function to tune - target : tvm.target.Target - The compilation target - params : Optional[Dict[str, tvm.runtime.NDArray]] - The associated parameters of the program - opt_level : int - The optimization level of the compiler - pass_config : Optional[Dict[str, Any]] - The pass config of the compiler - disabled_pass : Optional[List[str]] - The list of disabled passes of the compiler - - Returns - ------- - tasks: List[ExtractedTask] - The tasks extracted from this network - """ - - extract_task_func = get_global_func("relay.backend.MetaScheduleExtractTask") - assert extract_task_func - - target = Target(target) if isinstance(target, str) else target - - relay_params = {} - for name, param in params.items(): - if isinstance(param, np.ndarray): - param = nd.array(param) - relay_params[name] = param - - if disabled_pass is None: - disabled_pass = [] - if pass_config is None: - pass_config = {"relay.backend.use_meta_schedule": True} - - if isinstance(mod, RelayFunc): - mod = IRModule.from_expr(mod) - if not isinstance(target, Target): - target = Target(target) - - with autotvm_silencer(), target, transform.PassContext( - opt_level=opt_level, - config=pass_config, - disabled_pass=disabled_pass, - ): - return list(extract_task_func(mod, target, relay_params)) diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py new file mode 100644 index 000000000000..4478ffc76b47 --- /dev/null +++ b/python/tvm/meta_schedule/relay_integration.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""MetaSchedule-Relay integration""" +from typing import Any, Dict, List, Optional + +import numpy as np # type: ignore +from tvm import nd +from tvm._ffi import get_global_func +from tvm.ir import IRModule, transform +from tvm.runtime import NDArray +from tvm.target import Target + +from .extracted_task import ExtractedTask +from .utils import autotvm_silencer + + +def extract_task_from_relay( + mod: IRModule, + target: Target, + params: Optional[Dict[str, NDArray]] = None, + *, + opt_level: int = 3, + pass_config: Optional[Dict[str, Any]] = None, + disabled_pass: Optional[List[str]] = None, +) -> List[ExtractedTask]: + """Extract tuning tasks from a relay program. + + Parameters + ---------- + mod : IRModule + The module or function to tune + target : tvm.target.Target + The compilation target + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + opt_level : int + The optimization level of the compiler + pass_config : Optional[Dict[str, Any]] + The pass config of the compiler + disabled_pass : Optional[List[str]] + The list of disabled passes of the compiler + + Returns + ------- + tasks: List[ExtractedTask] + The tasks extracted from this network + """ + # pylint: disable=import-outside-toplevel + from tvm.relay import Function as RelayFunc + + # pylint: enable=import-outside-toplevel + + extract_task_func = get_global_func( + "relay.backend.MetaScheduleExtractTask", + allow_missing=False, + ) + + if isinstance(mod, RelayFunc): + mod = IRModule.from_expr(mod) + if not isinstance(target, Target): + target = Target(target) + if disabled_pass is None: + disabled_pass = [] + if pass_config is None: + pass_config = {"relay.backend.use_meta_schedule": True} + relay_params = {} + for name, param in params.items(): + if isinstance(param, np.ndarray): + param = nd.array(param) + relay_params[name] = param + + with autotvm_silencer(), target, transform.PassContext( + opt_level=opt_level, + config=pass_config, + disabled_pass=disabled_pass, + ): + return list(extract_task_func(mod, target, relay_params)) diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py index 83a70abb7fc9..2dbd290a28eb 100644 --- a/python/tvm/meta_schedule/testing/relay_workload.py +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -26,7 +26,7 @@ import tvm.relay.testing from tvm import relay from tvm.ir import IRModule -from tvm.meta_schedule.integration import ExtractedTask, extract_task_from_relay +from tvm.meta_schedule import ExtractedTask, extract_task_from_relay from tvm.runtime import NDArray, load_param_dict, save_param_dict from tvm.target import Target diff --git a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py index 5859412ebbf0..0973c9b91bff 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py @@ -24,7 +24,6 @@ import tvm from tvm import meta_schedule as ms from tvm.ir.transform import PassContext -from tvm.meta_schedule.integration import extract_task_from_relay from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.meta_schedule.testing.relay_workload import get_network from tvm.relay import build as relay_build @@ -107,7 +106,7 @@ def tune_each_task( work_dir, params, ): - extracted_tasks = extract_task_from_relay(mod, target, params) + extracted_tasks = ms.extract_task_from_relay(mod, target, params) database = ms.database.JSONDatabase( path_workload=os.path.join(work_dir, "default_database_workload.json"), path_tuning_record=os.path.join(work_dir, "default_database_tuning_record.json"), @@ -139,7 +138,7 @@ def tune_each_task( ) # pylint: enable=protected-access task_scheduler.tune() - with target, ms.integration.ApplyHistoryBest(database): + with target, ms.ApplyHistoryBest(database): with PassContext( opt_level=3, config={"relay.backend.use_meta_schedule": True}, diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py index e22677a3b918..a832dfc6bcc4 100644 --- a/python/tvm/meta_schedule/testing/utils.py +++ b/python/tvm/meta_schedule/testing/utils.py @@ -14,31 +14,31 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Testing utilitiy functions in meta schedule""" +"""Testing utility functions in meta schedule""" import random -from typing import List, Optional, Callable, Dict, Union +from typing import Callable, Dict, List, Optional, Union import tvm -from tvm.relay import Function as RelayFunc -from tvm.tir import Schedule -from tvm.target import Target -from tvm.runtime import NDArray +from tvm.ir import IRModule from tvm.meta_schedule import TuneContext # pylint: disable=unused-import -from tvm.meta_schedule.utils import derived_object +from tvm.meta_schedule.builder import BuilderInput, BuilderResult, PyBuilder +from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload +from tvm.meta_schedule.extracted_task import ExtractedTask from tvm.meta_schedule.mutator.mutator import PyMutator -from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord -from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult +from tvm.meta_schedule.relay_integration import extract_task_from_relay from tvm.meta_schedule.runner import ( + PyRunner, + PyRunnerFuture, + RunnerFuture, RunnerInput, RunnerResult, - RunnerFuture, - PyRunnerFuture, - PyRunner, ) -from tvm.meta_schedule.tune import Parse, extract_task_from_relay -from tvm.meta_schedule.integration import ExtractedTask - -from tvm.ir import IRModule +from tvm.meta_schedule.tune import Parse +from tvm.meta_schedule.utils import derived_object +from tvm.relay import Function as RelayFunc +from tvm.runtime import NDArray +from tvm.target import Target +from tvm.tir import Schedule from tvm.tir.schedule import Trace diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 86157e0fb32e..31130f67af34 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -23,18 +23,17 @@ from tvm._ffi.registry import register_func from tvm.ir import IRModule, structural_hash from tvm.ir.transform import PassContext -from tvm.relay import Function as RelayFunc -from tvm.relay import build as relay_build from tvm.runtime import Module, NDArray from tvm.target import Target from tvm.te import Tensor, create_prim_func from tvm.tir import PrimFunc, Schedule +from .apply_history_best import ApplyHistoryBest from .builder import Builder, LocalBuilder from .cost_model import CostModel, XGBModel from .database import Database, JSONDatabase, TuningRecord +from .extracted_task import ExtractedTask from .feature_extractor import PerStoreFeature -from .integration import ApplyHistoryBest, ExtractedTask, extract_task_from_relay from .measure_callback import MeasureCallback from .mutator import Mutator from .postproc import Postproc @@ -822,7 +821,7 @@ def tune_extracted_tasks( def tune_relay( - mod: Union[RelayFunc, IRModule], + mod: IRModule, target: Union[str, Target], config: SearchStrategyConfig, work_dir: str, @@ -844,7 +843,7 @@ def tune_relay( Parameters ---------- - mod : Union[RelayFunc, IRModule] + mod : IRModule The module to tune. target : Union[str, Target] The target to tune for. @@ -874,6 +873,12 @@ def tune_relay( lib : Module The built runtime module for the given relay workload. """ + # pylint: disable=import-outside-toplevel + from tvm.relay import build as relay_build + + from .relay_integration import extract_task_from_relay + + # pylint: enable=import-outside-toplevel logger.info("Working directory: %s", work_dir) # pylint: disable=protected-access diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/apply_history_best.cc similarity index 58% rename from src/meta_schedule/integration.cc rename to src/meta_schedule/apply_history_best.cc index 35c3baf237a4..41714cf7b0ce 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/apply_history_best.cc @@ -16,17 +16,13 @@ * specific language governing permissions and limitations * under the License. */ -#include -#include -#include - #include "./utils.h" -#include "tvm/runtime/container/optional.h" namespace tvm { namespace meta_schedule { /**************** Utility functions ****************/ + template Optional GetOnlyOneFunctionCommon(const IRModule& mod, Callback on_found) { if (mod->functions.size() != 1) { @@ -59,54 +55,36 @@ bool HasOnlyOneFunction(const IRModule& mod) { return GetOnlyOneFunction(mod).defined(); } -/**************** ExtractedTask ****************/ - -ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, - Array dispatched, int weight) { - ObjectPtr n = make_object(); - n->task_name = task_name; - n->mod = mod; - n->target = target; - n->dispatched = dispatched; - n->weight = weight; - data_ = n; -} +/**************** Context Manager ****************/ -/**************** MetaScheduleContext ****************/ +class ApplyHistoryBestInternal { + public: + static void EnterScope(ApplyHistoryBest ctx) { ctx.EnterWithScope(); } + static void ExitScope(ApplyHistoryBest ctx) { ctx.ExitWithScope(); } +}; -struct MetaScheduleContextThreadLocalEntry { - Optional ctx; +struct ApplyHistoryBestThreadLocalEntry { + Optional ctx; }; -using MetaScheduleContextThreadLocalStore = - dmlc::ThreadLocalStore; +using ApplyHistoryBestThreadLocalStore = dmlc::ThreadLocalStore; -Optional MetaScheduleContext::Current() { - return MetaScheduleContextThreadLocalStore::Get()->ctx; +Optional ApplyHistoryBest::Current() { + return ApplyHistoryBestThreadLocalStore::Get()->ctx; } -void MetaScheduleContext::EnterWithScope() { - Optional& ctx = MetaScheduleContextThreadLocalStore::Get()->ctx; - CHECK(!ctx.defined()) - << "ValueError: Nested MetaScheduleContext context managers are not allowed"; +void ApplyHistoryBest::EnterWithScope() { + Optional& ctx = ApplyHistoryBestThreadLocalStore::Get()->ctx; + CHECK(!ctx.defined()) << "ValueError: Nested ApplyHistoryBest context managers are not allowed"; ctx = *this; } -void MetaScheduleContext::ExitWithScope() { - Optional& ctx = MetaScheduleContextThreadLocalStore::Get()->ctx; +void ApplyHistoryBest::ExitWithScope() { + Optional& ctx = ApplyHistoryBestThreadLocalStore::Get()->ctx; ICHECK(ctx.defined()); ctx = NullOpt; } -Optional MetaScheduleContext::QueryInsideWithScope(runtime::String task_name, - IRModule mod, Target target, - Optional> dispatched) { - if (Optional ctx = MetaScheduleContext::Current()) { - return ctx.value()->Query(task_name, mod, target, dispatched); - } - return NullOpt; -} - /**************** ApplyHistoryBest ****************/ ApplyHistoryBest::ApplyHistoryBest(Database database) { @@ -149,37 +127,19 @@ Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModu return NullOpt; } -/**************** FFI ****************/ - -class MetaScheduleContextInternal { - public: - static void EnterScope(MetaScheduleContext ctx) { ctx.EnterWithScope(); } - static void ExitScope(MetaScheduleContext ctx) { ctx.ExitWithScope(); } -}; - -TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); -TVM_REGISTER_OBJECT_TYPE(MetaScheduleContextNode); TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode); - -TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") - .set_body_typed([](String task_name, IRModule mod, Target target, Array dispatched, - int weight) -> ExtractedTask { - return ExtractedTask(task_name, mod, target, dispatched, weight); - }); -TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextEnterScope") - .set_body_typed(MetaScheduleContextInternal::EnterScope); -TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextExitScope") - .set_body_typed(MetaScheduleContextInternal::ExitScope); -TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextCurrent") - .set_body_typed(MetaScheduleContext::Current); -TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQueryInsideWithScope") - .set_body_typed(MetaScheduleContext::QueryInsideWithScope); -TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery") - .set_body_method(&MetaScheduleContextNode::Query); TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest") .set_body_typed([](Database database) -> ApplyHistoryBest { return ApplyHistoryBest(database); }); +TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestEnterScope") + .set_body_typed(ApplyHistoryBestInternal::EnterScope); +TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestExitScope") + .set_body_typed(ApplyHistoryBestInternal::ExitScope); +TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestCurrent") + .set_body_typed(ApplyHistoryBest::Current); +TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestQuery") + .set_body_method(&ApplyHistoryBestNode::Query); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc new file mode 100644 index 000000000000..b1044fc87d0f --- /dev/null +++ b/src/meta_schedule/extracted_task.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +namespace tvm { +namespace meta_schedule { + +ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, + Array dispatched, int weight) { + ObjectPtr n = make_object(); + n->task_name = task_name; + n->mod = mod; + n->target = target; + n->dispatched = dispatched; + n->weight = weight; + data_ = n; +} + +TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); +TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") + .set_body_typed([](String task_name, IRModule mod, Target target, Array dispatched, + int weight) -> ExtractedTask { + return ExtractedTask(task_name, mod, target, dispatched, weight); + }); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 2ee18a8668be..45a04958ade1 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index a787f1915099..0895fd42a307 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -17,16 +17,15 @@ * under the License. */ -#include +#include #include #include #include #include #include "../../te/operation/create_primfunc.h" -#include "te_compiler_cache.h" -#include "tvm/runtime/ndarray.h" -#include "utils.h" +#include "./te_compiler_cache.h" +#include "./utils.h" namespace tvm { namespace relay { diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index e0e7277676bc..a8edeff8626e 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -21,7 +21,7 @@ #include #include -#include +#include #include #include #include @@ -302,7 +302,13 @@ class ScheduleBuilder : public ExprVisitor { explicit ScheduleBuilder(Target target) : target_(target) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); - use_meta_scheduler_ = backend::IsMetaScheduleEnabled(); + if (backend::IsMetaScheduleEnabled()) { + meta_schedule_ctx_ = meta_schedule::ApplyHistoryBest::Current(); + CHECK(meta_schedule_ctx_.defined()) << "ValueError: `use_meta_schedule` is enabled in Relay " + "build, but no ApplyHistoryBest context is provided. "; + } else { + meta_schedule_ctx_ = NullOpt; + } } CachedFunc Create(const Function& relay_func, std::function renamer) { @@ -340,12 +346,11 @@ class ScheduleBuilder : public ExprVisitor { schedule = Downcast(obj); } } - if (use_meta_scheduler_) { + if (meta_schedule_ctx_) { IRModule relay_mod({{prim_fn_var, relay_func}}); IRModule tir_mod({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}}); - Optional scheduled_mod = meta_schedule::MetaScheduleContext::QueryInsideWithScope( - prim_fn_var->name_hint, relay_mod, target_, Array{tir_mod}); - if (scheduled_mod) { + if (Optional scheduled_mod = meta_schedule_ctx_.value()->Query( + prim_fn_var->name_hint, relay_mod, target_, Array{tir_mod})) { ICHECK_EQ(scheduled_mod.value()->functions.count(prim_fn_var), 1); prim_func = Downcast(scheduled_mod.value()->functions[prim_fn_var]); } @@ -381,7 +386,7 @@ class ScheduleBuilder : public ExprVisitor { } int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && !use_meta_scheduler_ && op_pattern >= kCommReduce) { + if (!use_auto_scheduler_ && !meta_schedule_ctx_.defined() && op_pattern >= kCommReduce) { ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) << "Cannot apply TOPI schedule to a primitive function with two complicated ops" << " anchor=" << anchor_op_ << " current=" << op; @@ -399,7 +404,7 @@ class ScheduleBuilder : public ExprVisitor { Attrs anchor_attrs_; int anchor_op_pattern_{0}; bool use_auto_scheduler_; - bool use_meta_scheduler_; + Optional meta_schedule_ctx_; }; /*! diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index 1bbaf35ad280..b17d6ffc6054 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import sys -from typing import List import numpy as np import pytest @@ -23,18 +22,13 @@ import tvm.testing from tvm import meta_schedule as ms from tvm import relay -from tvm.ir.module import IRModule -from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload -from tvm.meta_schedule.integration import ( - ApplyHistoryBest, - ExtractedTask, - MetaScheduleContext, -) +from tvm.meta_schedule import ApplyHistoryBest +from tvm.meta_schedule.database import TuningRecord +from tvm.meta_schedule.relay_integration import extract_task_from_relay from tvm.meta_schedule.testing import DummyDatabase from tvm.meta_schedule.testing.relay_workload import get_network from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base -from tvm.meta_schedule.tune import Parse, extract_task_from_relay -from tvm.meta_schedule.utils import derived_object +from tvm.meta_schedule.tune import Parse from tvm.script import tir as T from tvm.target import Target from tvm.tir import Schedule @@ -68,14 +62,14 @@ def _has_torch(): requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed") -def test_meta_schedule_integration_no_current(): - assert MetaScheduleContext.current() is None +def test_meta_schedule_apply_history_best_no_current(): + assert ApplyHistoryBest.current() is None @requires_torch def test_meta_schedule_integration_extract_from_resnet(): mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) - extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params) + extracted_tasks = ms.extract_task_from_relay(mod, target="llvm", params=params) expected_task_names = [ "fused_" + s for s in [ @@ -189,7 +183,7 @@ def test_meta_schedule_integration_extract_from_bert_base(): ), } mod, params, _ = get_network(name="bert_base", input_shape=[1, 64]) - extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params) + extracted_tasks = ms.extract_task_from_relay(mod, target="llvm", params=params) assert len(extracted_tasks) == len(expected) for t in extracted_tasks: prim_func = None diff --git a/tests/python/unittest/test_meta_schedule_multi_anchor.py b/tests/python/unittest/test_meta_schedule_multi_anchor.py index 78d0ddeda32f..0b8af9c14550 100644 --- a/tests/python/unittest/test_meta_schedule_multi_anchor.py +++ b/tests/python/unittest/test_meta_schedule_multi_anchor.py @@ -15,12 +15,11 @@ # specific language governing permissions and limitations # under the License. import numpy as np - import tvm import tvm.testing from tvm import relay +from tvm.meta_schedule import ApplyHistoryBest from tvm.meta_schedule.testing import apply_fixed_schedules -from tvm.meta_schedule.integration import ApplyHistoryBest def get_dense_dense(data_shape, weight_shape): diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index 76cd82920c35..af25d2a6f39e 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -27,16 +27,12 @@ from tvm._ffi import register_func from tvm.contrib import graph_executor from tvm.ir import IRModule -from tvm.meta_schedule import ReplayTraceConfig +from tvm.meta_schedule import ApplyHistoryBest, ReplayTraceConfig from tvm.meta_schedule.database import JSONDatabase, PyDatabase, TuningRecord, Workload -from tvm.meta_schedule.integration import ApplyHistoryBest -from tvm.meta_schedule.testing.relay_workload import get_network +from tvm.meta_schedule.relay_integration import extract_task_from_relay from tvm.meta_schedule.testing import apply_fixed_schedules -from tvm.meta_schedule.tune import ( - extract_task_from_relay, - tune_extracted_tasks, - tune_relay, -) +from tvm.meta_schedule.testing.relay_workload import get_network +from tvm.meta_schedule.tune import tune_extracted_tasks, tune_relay from tvm.meta_schedule.utils import derived_object from tvm.script import tir as T from tvm.target.target import Target @@ -528,13 +524,13 @@ def schedule_fn(task, sch): ): """ The log should say - meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_expand_dims - meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast - meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast_1 - meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul + Warning: Cannot find workload: tvmgen_default_fused_expand_dims + Warning: Cannot find workload: tvmgen_default_fused_cast + Warning: Cannot find workload: tvmgen_default_fused_cast_1 + Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul - This means batch matmul and others are scheduled by TE, and dense (the one not warned) is found in the - meta schedule tuning database during ApplyHistoryBest + This means batch matmul and others are scheduled by TE, and dense (the one not warned) + is found in the meta schedule tuning database during ApplyHistoryBest """ lib = relay.build(relay_mod, target=target, params=params)