Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions python/tvm/meta_schedule/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,12 @@ def create(
if kind == "xgb":
return XGBModel(*args, **kwargs) # type: ignore

if "num_tuning_cores" in kwargs:
# num_tuning_cores is only relevant for XGBModel.
kwargs.pop("num_tuning_cores")
# params only relevant to XGBModel
_xgb_params = ["num_tuning_cores", "tree_method"]

for param in _xgb_params:
if param in kwargs:
kwargs.pop(param)

if kind == "random":
return RandomModel(*args, **kwargs) # type: ignore
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/meta_schedule/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from itertools import chain as itertools_chain
from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple

from typing_extensions import Literal

import numpy as np # type: ignore

from ...contrib.tar import tar, untar
Expand Down Expand Up @@ -202,6 +204,8 @@ def average_peak_score(
class XGBConfig(NamedTuple):
"""XGBoost model configuration

Reference: https://xgboost.readthedocs.io/en/stable/parameter.html

Parameters
----------
max_depth : int
Expand All @@ -217,6 +221,8 @@ class XGBConfig(NamedTuple):
nthread : Optional[int],
The number of threads to use.
Default is None, which means to use physical number of cores.
tree_method : Literal["auto", "exact", "approx", "hist", "gpu_hist"]
The tree construction algorithm used in XGBoost.
"""

max_depth: int = 10
Expand All @@ -225,15 +231,19 @@ class XGBConfig(NamedTuple):
eta: float = 0.2
seed: int = 43
nthread: Optional[int] = None
tree_method: Literal["auto", "exact", "approx", "hist", "gpu_hist"] = "auto"

def to_dict(self):
"""Convert to dict"""

return {
"max_depth": self.max_depth,
"gamma": self.gamma,
"min_child_weight": self.min_child_weight,
"eta": self.eta,
"seed": self.seed,
"nthread": self.nthread,
"tree_method": self.tree_method,
}


Expand Down Expand Up @@ -334,6 +344,7 @@ def __init__(
average_peak_n: int = 32,
adaptive_training: bool = True,
num_tuning_cores: Optional[int] = None,
tree_method: Optional[Literal["auto", "exact", "approx", "hist", "gpu_hist"]] = None,
):
super().__init__()
if not isinstance(extractor, FeatureExtractor):
Expand All @@ -348,6 +359,9 @@ def __init__(
else:
config = config._replace(nthread=num_tuning_cores)

if tree_method is not None:
config._replace(tree_method=tree_method)

self.config = config
# behavior of randomness
self.num_warmup_samples = num_warmup_samples
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def tune_tasks(
elif not isinstance(database, Database):
database = Database.create(database, module_equality=module_equality)
if not isinstance(cost_model, CostModel):
cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores)
cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores, tree_method="auto")
if isinstance(measure_callbacks, MeasureCallback):
measure_callbacks = [measure_callbacks]
elif measure_callbacks == "default":
Expand Down