Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] allow use of early_stopping_round<=0 to turn off early stopping (fixes #6401) #6406

Merged
merged 10 commits into from
Apr 20, 2024
19 changes: 15 additions & 4 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,18 +280,17 @@ def __init__(
verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0,
) -> None:
if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")
self.stopping_rounds = stopping_rounds

self.enabled = _should_enable_early_stopping(stopping_rounds)

self.order = 30
self.before_iteration = False

self.stopping_rounds = stopping_rounds
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
self.first_metric_only = first_metric_only
self.verbose = verbose
self.min_delta = min_delta

self.enabled = True
self._reset_storages()

def _reset_storages(self) -> None:
Expand Down Expand Up @@ -438,6 +437,18 @@ def __call__(self, env: CallbackEnv) -> None:
self._final_iteration_check(env, eval_name_splitted, i)


def _should_enable_early_stopping(stopping_rounds: Any) -> bool:
"""Check if early stopping should be activated.

This function will evaluate to True if the early stopping callback should be
activated (i.e. stopping_rounds > 0). It also provides an informative error if the
type is not int.
"""
if not isinstance(stopping_rounds, int):
raise TypeError(f"early_stopping_round should be an integer. Got {type(stopping_rounds)}")
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
return stopping_rounds > 0


def early_stopping(
stopping_rounds: int,
first_metric_only: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def train(
cb.__dict__.setdefault("order", i - len(callbacks))
callbacks_set = set(callbacks)

if "early_stopping_round" in params:
if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)):
callbacks_set.add(
callback.early_stopping(
stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type]
Expand Down Expand Up @@ -760,7 +760,7 @@ def cv(
cb.__dict__.setdefault("order", i - len(callbacks))
callbacks_set = set(callbacks)

if "early_stopping_round" in params:
if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)):
callbacks_set.add(
callback.early_stopping(
stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type]
Expand Down
8 changes: 1 addition & 7 deletions tests/python_package_test/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,7 @@ def test_early_stopping_callback_is_picklable(serializer):


def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"):
lgb.early_stopping(stopping_rounds=0)

with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"):
lgb.early_stopping(stopping_rounds=-1)

with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"):
with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got " "<class 'str'>"):
lgb.early_stopping(stopping_rounds="neverrrr")
jameslamb marked this conversation as resolved.
Show resolved Hide resolved


Expand Down
46 changes: 46 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,52 @@ def test_early_stopping_via_global_params(first_metric_only):
assert "error" in gbm.best_score[valid_set_name]


@pytest.mark.parametrize("early_stopping_round", [-10, -1, 0, None, "None"])
def test_early_stopping_non_positive_values(early_stopping_round):
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
X, y = load_breast_cancer(return_X_y=True)
num_trees = 5
params = {
"num_trees": num_trees,
"objective": "binary",
"metric": "None",
"verbose": -1,
"early_stopping_round": early_stopping_round,
"first_metric_only": True,
}
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
valid_set_name = "valid_set"

if early_stopping_round is None:
gbm = lgb.train(
params,
lgb_train,
feval=[decreasing_metric, constant_metric],
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
valid_sets=lgb_eval,
valid_names=valid_set_name,
)
assert "early_stopping_round" not in gbm.params
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
elif early_stopping_round == "None":
with pytest.raises(TypeError):
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
gbm = lgb.train(
params,
lgb_train,
feval=[decreasing_metric, constant_metric],
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
valid_sets=lgb_eval,
valid_names=valid_set_name,
)
elif early_stopping_round <= 0:
gbm = lgb.train(
params,
lgb_train,
feval=[decreasing_metric, constant_metric],
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
valid_sets=lgb_eval,
valid_names=valid_set_name,
)
assert gbm.params["early_stopping_round"] == early_stopping_round
jameslamb marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize("first_only", [True, False])
@pytest.mark.parametrize("single_metric", [True, False])
@pytest.mark.parametrize("greater_is_better", [True, False])
Expand Down
Loading