Commit df7ad33

More fixes.

1 parent fb08108 commit df7ad33

3 files changed: +67 -50 lines changed

python/tvm/meta_schedule/cost_model/metric.py

Lines changed: 1 addition & 16 deletions
@@ -18,7 +18,7 @@
 import numpy as np
 
 
-def max_curve(trial_scores):
+def max_curve(trial_scores: np.ndarray):
     """f(n) = max([s[i] for i < n])
 
     Parameters
@@ -37,18 +37,3 @@ def max_curve(trial_scores):
         keep = max(keep, score)
         ret[i] = keep
     return ret
-
-
-def make_metric_sorter(focused_metric):
-    """ Make sure the focused metric is the first one. """
-
-    def metric_name_for_sort(name):
-        if focused_metric == name:
-            return "!" + name
-        return name
-
-    def sort_key(key):
-        key, _ = key
-        return metric_name_for_sort(key)
-
-    return sort_key
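
For reference, `max_curve` computes a running maximum: f(n) = max(s[i] for i < n), i.e. the best score seen among the first n trials. Below is a minimal standalone sketch of the function's behavior, reconstructed from the context lines above; the `-1e9` initializer is an assumption, since the hunk does not show how `keep` is initialized.

    import numpy as np

    def max_curve(trial_scores: np.ndarray) -> np.ndarray:
        """f(n) = max([s[i] for i < n])"""
        ret = np.empty(len(trial_scores))
        keep = -1e9  # assumed sentinel, lower than any realistic score
        for i, score in enumerate(trial_scores):
            keep = max(keep, score)
            ret[i] = keep
        return ret

    print(max_curve(np.array([0.2, 0.5, 0.3, 0.7])))  # [0.2 0.5 0.5 0.7]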

python/tvm/meta_schedule/cost_model/xgb_model.py

Lines changed: 65 additions & 34 deletions
@@ -17,7 +17,7 @@
 """
 XGBoost-based cost model
 """
-from typing import Union, Optional, Tuple, Callable, List, TYPE_CHECKING
+from typing import Optional, Tuple, Callable, List, TYPE_CHECKING
 
 
 import logging
@@ -29,7 +29,7 @@
 from ..feature_extractor import FeatureExtractor
 from ..cost_model import PyCostModel
 from ..utils import cpu_count
-from .metric import max_curve, make_metric_sorter
+from .metric import max_curve
 
 if TYPE_CHECKING:
     from ..tune_context import TuneContext
@@ -39,6 +39,21 @@
 logger = logging.getLogger(__name__)
 
 
+def make_metric_sorter(focused_metric):
+    """ Make sure the focused metric is the first one. """
+
+    def metric_name_for_sort(name):
+        if focused_metric == name:
+            return "!" + name
+        return name
+
+    def sort_key(key):
+        key, _ = key
+        return metric_name_for_sort(key)
+
+    return sort_key
+
+
 class PackSum:
     """The pack-sum format
@@ -185,14 +200,9 @@ class XGBModel(PyCostModel):
         XGBoost model param, the eta, learning rate.
     xgb_seed : int
         XGBoost model param, the random seed.
-    xgb_nthread : int
+    xgb_nthread : Optional[int]
         XGBoost model param, the number of threads to use.
-    xgb_n_gpus : int
-        XGBoost model param, the number of gpus.
-    xgb_verbosity_train : int
-        XGBoost model param, the verbose level for training.
-    xgb_disable_default_eval_metric : Union[int, bool]
-        XGBoost model param, flag to disable default metric. Set to 1 or true to disable.
+        Default is None, which means using the physical number of cores.
     path : Optional[str]
         The path to save the model.
     num_warmup_samples : int
@@ -208,6 +218,23 @@ class XGBModel(PyCostModel):
 
     # model-related params
     _xgb_params: dict
+    """The parameters for the XGBoost model
+
+    Parameters
+    ----------
+    max_depth : int
+        XGBoost model param, the maximum depth.
+    gamma : float
+        XGBoost model param, the gamma.
+    min_child_weight : float
+        XGBoost model param, the minimum child weight.
+    eta : float
+        XGBoost model param, the eta, learning rate.
+    seed : int
+        XGBoost model param, the random seed.
+    nthread : int
+        XGBoost model param, the number of threads to use.
+    """
     # serialization-related
     path: Optional[str]
     # feature extractor
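
The `_xgb_params` dict is ultimately handed to XGBoost's trainer. A hedged sketch of how such a params dict is consumed, with illustrative values (only `min_child_weight=0`, `eta=0.2`, and `seed=43` match defaults visible in this diff; the rest are made up):

    import numpy as np
    import xgboost as xgb

    params = {"max_depth": 10, "gamma": 0.001, "min_child_weight": 0,
              "eta": 0.2, "seed": 43, "nthread": 8}
    dtrain = xgb.DMatrix(np.random.rand(64, 4), label=np.random.rand(64))
    booster = xgb.train(params, dtrain, num_boost_round=10)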
@@ -235,10 +262,7 @@ def __init__(
         xgb_min_child_weight: float = 0,
         xgb_eta: float = 0.2,
         xgb_seed: int = 43,
-        xgb_nthread: int = cpu_count(),
-        xgb_n_gpus: int = 0,
-        xgb_verbosity_train: int = 0,
-        xgb_disable_default_eval_metric: Union[int, bool] = 1,
+        xgb_nthread: Optional[int] = None,
         # load from disk
         path: Optional[str] = None,
         # behavior of randomness
@@ -252,16 +276,16 @@ def __init__(
         # feature extractor
         self.extractor = extractor
         # model-related
+        if xgb_nthread is None:
+            # use physical core number
+            xgb_nthread = cpu_count(False)
         self._xgb_params = {
             "max_depth": xgb_max_depth,
             "gamma": xgb_gamma,
             "min_child_weight": xgb_min_child_weight,
             "eta": xgb_eta,
             "seed": xgb_seed,
             "nthread": xgb_nthread,
-            "n_gpus": xgb_n_gpus,
-            "verbosity": xgb_verbosity_train,
-            "disable_default_eval_metric": xgb_disable_default_eval_metric,
         }
         # serialization-related
         self.path = path
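
Replacing `xgb_nthread: int = cpu_count()` with a `None` default is the usual sentinel-default pattern: a default in the signature is evaluated once, at import time, whereas the sentinel defers the lookup to each call. Below is a standalone sketch of the idea; `cpu_count` here is a stand-in for `tvm.meta_schedule.utils.cpu_count`, whose `False` argument requests physical rather than logical cores.

    import os
    from typing import Optional

    def cpu_count(logical: bool = True) -> int:
        # Stand-in only: os.cpu_count() reports logical cores, while TVM's
        # helper can also report the physical core count.
        return os.cpu_count() or 1

    def make_xgb_params(nthread: Optional[int] = None) -> dict:
        if nthread is None:
            nthread = cpu_count(False)  # resolved at call time, not import time
        return {"nthread": nthread}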
@@ -277,7 +301,7 @@ def __init__(
         self.cached_normalizer = None
         self.booster = None
 
-    def load(self, path: str) -> None:
+    def load(self, path: Optional[str] = None) -> None:
         """Load the cost model from given file location.
 
         Parameters
@@ -290,14 +314,18 @@ def load(self, path: str) -> None:
         Since XGBoost model trains from scratch, each time we can only load the model without the
         previous cached features / results so any call of update won't use previous training data.
         """
+        if path is None:
+            path = self.path
+
         import xgboost as xgb  # pylint: disable=import-outside-toplevel
 
         if self.booster is None:
             # save all the parameters
             self.booster = xgb.Booster(self._xgb_params)
+        # throws an error when path is still None
         self.booster.load_model(path)
 
-    def save(self, path: str) -> None:
+    def save(self, path: Optional[str] = None) -> None:
         """Save the cost model to given file location.
 
         Parameters
@@ -310,6 +338,9 @@ def save(self, path: str) -> None:
         Since XGBoost model trains from scratch, each time we can only save the model without the
         previous cached features / results so any call of update won't use previous training data.
         """
+        if path is None:
+            path = self.path
+        # throws an error when path is still None
         self.booster.save_model(path)
 
     def update(
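
The `load`/`save` changes make both methods fall back to the path supplied at construction time. A condensed sketch of the pattern, assuming `xgboost` is installed; the `Model` class below is illustrative, not TVM's.

    from typing import Optional
    import xgboost as xgb

    class Model:
        def __init__(self, path: Optional[str] = None):
            self.path = path
            self.booster: Optional[xgb.Booster] = None

        def save(self, path: Optional[str] = None) -> None:
            if path is None:
                path = self.path  # fall back to the construction-time path
            self.booster.save_model(path)  # raises if path is still None

        def load(self, path: Optional[str] = None) -> None:
            if path is None:
                path = self.path
            if self.booster is None:
                self.booster = xgb.Booster()
            self.booster.load_model(path)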
@@ -322,7 +353,7 @@ def update(
 
         Parameters
         ----------
-        tune_context : TuneContext,
+        tune_context : TuneContext
             The tuning context.
         candidates : List[MeasureCandidate]
             The measure candidates.
@@ -339,7 +370,7 @@ def update(
         logger.debug(
             "XGB validation: %s",
             "\t".join(
-                "%s: %.6f" % (key, score)  # pylint: disable=consider-using-f-string
+                f"{key}: {score:.6f}"
                 for key, score in self._validate(
                     xs=new_features,
                     ys=new_mean_costs,
@@ -356,14 +387,11 @@ def update(
             xs=self.cached_features,
             ys=self.cached_mean_costs,
         )
-        # Update the model file if it has been set
-        if self.path:
-            self.save(self.path)
 
     def predict(
         self, tune_context: "TuneContext", candidates: List[MeasureCandidate]
     ) -> np.ndarray:
-        """Update the cost model given running results.
+        """Predict the normalized score using the cost model.
 
         Parameters
         ----------
@@ -375,7 +403,7 @@ def predict(
         Return
         ------
         result : np.ndarray
-            The predicted running results.
+            The predicted normalized score.
         """
         n_measured = len(self.cached_features)
         if self.booster is not None and n_measured >= self.num_warmup_samples:
@@ -407,8 +435,8 @@ def rmse(ys_pred: np.ndarray, d_train: "xgb.DMatrix"):  # pylint: disable = unused-argument
             return self.d_train.rmse(ys_pred)
 
         def average_peak_score(
-            ys_pred: np.ndarray, d_train: "xgb.DMatrix"
-        ):  # pylint: disable = unused-argument
+            ys_pred: np.ndarray, d_train: "xgb.DMatrix"  # pylint: disable = unused-argument
+        ):
             return self.d_train.average_peak_score(ys_pred, self.average_peak_n)
 
         self.booster = xgb.train(
@@ -429,6 +457,12 @@ def average_peak_score(
             ],
         )
 
+        del self.d_train
+
+        # Update the model file if it has been set
+        if self.path:
+            self.save(self.path)
+
     def _predict(  # pylint: disable=invalid-name
         self,
         xs: List[np.ndarray],
@@ -565,9 +599,8 @@ def callback(env: "xgb.core.CallbackEnv"):
             if verbose_eval and iteration % verbose_eval == 0:
                 info = []
                 for key, score in eval_result:
-                    if "null" in key:
-                        continue
-                    info.append("%s: %.6f" % (key, score))  # pylint: disable=consider-using-f-string
+                    if "null" not in key:
+                        info.append(f"{key}: {score:.6f}")
                 logger.debug("XGB iter %3d: %s", iteration, "\t".join(info))
 
             ##### Choose score and do early stopping #####
@@ -581,10 +614,8 @@ def callback(env: "xgb.core.CallbackEnv"):
             best_score = state["best_score"]
             best_iteration = state["best_iteration"]
             if score < best_score:
-                msg = "[%d] %s" % (  # pylint: disable=consider-using-f-string
-                    env.iteration,
-                    "\t".join([_fmt_metric(x) for x in eval_result]),
-                )
+                tab = "\t"  # to work with f-string
+                msg = f"[{env.iteration}] {tab.join([_fmt_metric(x) for x in eval_result])}"
                 state["best_msg"] = msg
                 state["best_score"] = score
                 state["best_iteration"] = env.iteration

python/tvm/meta_schedule/utils.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Utilities for meta schedule"""
+from logging import log
 from typing import Any, Callable, List, Optional, Union
 
 import ctypes
