diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index c0f6ea5fb9e1..541154d4cc59 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -127,9 +127,12 @@ def create( if kind == "xgb": return XGBModel(*args, **kwargs) # type: ignore - if "num_tuning_cores" in kwargs: - # num_tuning_cores is only relevant for XGBModel. - kwargs.pop("num_tuning_cores") + # params only relevant to XGBModel + _xgb_params = ["num_tuning_cores", "tree_method"] + + for param in _xgb_params: + if param in kwargs: + kwargs.pop(param) if kind == "random": return RandomModel(*args, **kwargs) # type: ignore diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index fde2f2f60529..6b6b7a2dc1ed 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -21,6 +21,8 @@ from itertools import chain as itertools_chain from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple +from typing_extensions import Literal + import numpy as np # type: ignore from ...contrib.tar import tar, untar @@ -202,6 +204,8 @@ def average_peak_score( class XGBConfig(NamedTuple): """XGBoost model configuration + Reference: https://xgboost.readthedocs.io/en/stable/parameter.html + Parameters ---------- max_depth : int @@ -217,6 +221,8 @@ class XGBConfig(NamedTuple): nthread : Optional[int], The number of threads to use. Default is None, which means to use physical number of cores. + tree_method : Literal["auto", "exact", "approx", "hist", "gpu_hist"] + The tree construction algorithm used in XGBoost. """ max_depth: int = 10 @@ -225,8 +231,11 @@ class XGBConfig(NamedTuple): eta: float = 0.2 seed: int = 43 nthread: Optional[int] = None + tree_method: Literal["auto", "exact", "approx", "hist", "gpu_hist"] = "auto" def to_dict(self): + """Convert to dict""" + return { "max_depth": self.max_depth, "gamma": self.gamma, @@ -234,6 +243,7 @@ def to_dict(self): "eta": self.eta, "seed": self.seed, "nthread": self.nthread, + "tree_method": self.tree_method, } @@ -334,6 +344,7 @@ def __init__( average_peak_n: int = 32, adaptive_training: bool = True, num_tuning_cores: Optional[int] = None, + tree_method: Optional[Literal["auto", "exact", "approx", "hist", "gpu_hist"]] = None, ): super().__init__() if not isinstance(extractor, FeatureExtractor): @@ -348,6 +359,9 @@ def __init__( else: config = config._replace(nthread=num_tuning_cores) + if tree_method is not None: + config._replace(tree_method=tree_method) + self.config = config # behavior of randomness self.num_warmup_samples = num_warmup_samples diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 132f446a5252..887941ada0d2 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -108,7 +108,7 @@ def tune_tasks( elif not isinstance(database, Database): database = Database.create(database, module_equality=module_equality) if not isinstance(cost_model, CostModel): - cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores) + cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores, tree_method="auto") if isinstance(measure_callbacks, MeasureCallback): measure_callbacks = [measure_callbacks] elif measure_callbacks == "default":