Skip to content

Commit

Permalink
Add tolerance to early stopping.
Browse files: browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 6, 2021
1 parent ec6ce08 commit 90f5868
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
21 changes: 13 additions & 8 deletions python-package/xgboost/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,19 +493,17 @@ def __init__(self,
metric_name: Optional[str] = None,
data_name: Optional[str] = None,
maximize: Optional[bool] = None,
save_best: Optional[bool] = False) -> None:
save_best: Optional[bool] = False,
tolerance: float = 0) -> None:
self.data = data_name
self.metric_name = metric_name
self.rounds = rounds
self.save_best = save_best
self.maximize = maximize
self.stopping_history: CallbackContainer.EvalsLog = {}
self._tol = tolerance

if self.maximize is not None:
if self.maximize:
self.improve_op = lambda x, y: x > y
else:
self.improve_op = lambda x, y: x < y
self.improve_op = None

self.current_rounds: int = 0
self.best_scores: dict = {}
Expand All @@ -523,11 +521,18 @@ def _update_rounds(self, score, name, metric, model, epoch) -> bool:
maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@',
'aucpr@', 'map@', 'ndcg@')
if any(metric.startswith(x) for x in maximize_metrics):
self.improve_op = lambda x, y: x > y
self.improve_op = lambda x, y: x - y > -self._tol
self.maximize = True
else:
self.improve_op = lambda x, y: x < y
self.improve_op = lambda x, y: y - x > -self._tol
self.maximize = False
else:
if self.maximize:
self.improve_op = lambda x, y: x - y > -self._tol
else:
self.improve_op = lambda x, y: y - x > -self._tol

assert self.improve_op

if not self.stopping_history: # First round
self.current_rounds = 0
Expand Down
19 changes: 19 additions & 0 deletions tests/python/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,25 @@ def test_early_stopping_customize(self):
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump)

tol = 10
early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds,
metric_name='CustomErr',
data_name='Train',
tolerance=tol
)
booster = xgb.train(
{'objective': 'binary:logistic',
'eval_metric': ['error', 'rmse'],
'tree_method': 'hist'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
feval=tm.eval_error_metric,
num_boost_round=100,
callbacks=[early_stop],
verbose_eval=False)

assert booster.best_iteration == 100 - tol

def test_early_stopping_skl(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
Expand Down

0 comments on commit 90f5868

Please sign in to comment.