Fix CV.
trivialfis committed May 11, 2021
1 parent 277674a commit 8f09468
Showing 2 changed files with 23 additions and 11 deletions.
python-package/xgboost/callback.py: 26 changes (18 additions, 8 deletions)
@@ -487,6 +487,8 @@ class EarlyStopping(TrainingCallback):
         Whether to maximize evaluation metric.  None means auto (discouraged).
     save_best
         Whether training should return the best model or the last model.
+    tolerance
+        Tolerance for early stopping condition.
     """
     def __init__(self,
                  rounds: int,
@@ -515,22 +517,30 @@ def before_training(self, model):
         return model
 
     def _update_rounds(self, score, name, metric, model, epoch) -> bool:
-        # Just to be compatibility with old behavior before 1.3. We should let
-        # user to decide.
+        def get_s(x):
+            """get score if it's cross validation history."""
+            return x[0] if isinstance(x, tuple) else x
+
+        def maximize(new, last):
+            return get_s(new) - get_s(last) > -self._tol
+
+        def minimize(new, last):
+            return get_s(last) - get_s(new) > -self._tol
+
         if self.maximize is None:
+            # Just to be compatibility with old behavior before 1.3. We should let
+            # user to decide.
             maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@',
                                 'aucpr@', 'map@', 'ndcg@')
             if any(metric.startswith(x) for x in maximize_metrics):
-                self.improve_op = lambda x, y: x - y > -self._tol
                 self.maximize = True
             else:
-                self.improve_op = lambda x, y: y - x > -self._tol
                 self.maximize = False
+
+        if self.maximize:
+            self.improve_op = maximize
         else:
-            if self.maximize:
-                self.improve_op = lambda x, y: x - y > -self._tol
-            else:
-                self.improve_op = lambda x, y: y - x > -self._tol
+            self.improve_op = minimize
 
         assert self.improve_op
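A note on what this hunk fixes: when EarlyStopping is driven by xgb.cv, each score recorded in the evaluation history is a (mean, std) tuple rather than a plain float, so the old lambdas, which subtracted the scores directly, raised a TypeError under cross validation. The new get_s helper unpacks the mean first. A minimal sketch of the callback under xgb.cv; the synthetic dataset and parameter values are illustrative assumptions, not part of the commit:

import numpy as np
import xgboost as xgb

# Illustrative toy data; any binary classification set would do.
rng = np.random.RandomState(1994)
X = rng.randn(256, 8)
y = rng.randint(0, 2, size=256)
dtrain = xgb.DMatrix(X, label=y)

# Under xgb.cv every history entry is a (mean, std) tuple, which is the
# case get_s unpacks; under xgb.train the entries are plain floats.
early_stop = xgb.callback.EarlyStopping(rounds=5)
history = xgb.cv(
    {"objective": "binary:logistic", "eval_metric": "logloss"},
    dtrain,
    num_boost_round=64,
    nfold=3,
    callbacks=[early_stop],
)
print(history.tail())  # rows correspond to the rounds actually run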
tests/python/test_callback.py: 8 changes (5 additions, 3 deletions)
@@ -126,7 +126,9 @@ def test_early_stopping_customize(self):
         assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
         assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump)
 
+        # test tolerance, early stop won't occur with high tolerance.
         tol = 10
+        rounds = 100
         early_stop = xgb.callback.EarlyStopping(
             rounds=early_stopping_rounds,
             metric_name='CustomErr',
@@ -139,11 +141,11 @@ def test_early_stopping_customize(self):
                              'tree_method': 'hist'}, D_train,
                             evals=[(D_train, 'Train'), (D_valid, 'Valid')],
                             feval=tm.eval_error_metric,
-                            num_boost_round=100,
+                            num_boost_round=rounds,
                             callbacks=[early_stop],
                             verbose_eval=False)
 
-        assert booster.best_iteration == 100 - tol
+        # 0 based index
+        assert booster.best_iteration == rounds - 1
 
     def test_early_stopping_skl(self):
         from sklearn.datasets import load_breast_cancer
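The rewritten assertion matches the tolerance semantics in callback.py: with tol = 10 every round counts as an improvement, so early stopping never fires and best_iteration lands on the last round (0-based), not at 100 - tol. A standalone sketch of that comparison; the helper mirrors the commit's minimize, with self._tol passed as a plain argument for illustration:

def get_s(x):
    """Unpack the mean from a cross validation (mean, std) entry."""
    return x[0] if isinstance(x, tuple) else x

def minimize(new, last, tol):
    # Counts as an improvement unless `new` is worse than `last` by more than `tol`.
    return get_s(last) - get_s(new) > -tol

# A worse error still passes with a huge tolerance ...
assert minimize(new=(0.95, 0.01), last=(0.90, 0.01), tol=10)
# ... but fails once the tolerance is tight.
assert not minimize(new=(0.95, 0.01), last=(0.90, 0.01), tol=0.01)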
