[python-package] Correctly recognize LGBMClassifier(num_class=2, objective="multiclass") as multiclass classification (#6524)
RektPunk authored Jul 16, 2024
1 parent 3d02662 commit f8ec57b
Showing 3 changed files with 45 additions and 4 deletions.
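
The bug: with objective="multiclass" and num_class=2, LGBMClassifier saw only two classes and took the binary predict_proba path, stacking 1 - p against p instead of returning the per-class probabilities the multiclass booster already produces. A minimal sketch of the corrected behavior (synthetic data and variable names are illustrative only, not part of the commit):

    import lightgbm as lgb
    import numpy as np

    X = np.random.default_rng(0).normal(size=(100, 3))
    y = np.repeat([0, 1], 50)

    # Two labels, but an explicitly multiclass objective: the booster emits
    # one probability column per class, and predict_proba() should now
    # return that (100, 2) array unchanged.
    clf = lgb.LGBMClassifier(objective="multiclass", num_class=2, verbose=-1).fit(X, y)
    proba = clf.predict_proba(X)
    assert proba.shape == (100, 2)
    np.testing.assert_allclose(proba.sum(axis=1), np.ones(100))  # softmax rows sum to 1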
2 changes: 2 additions & 0 deletions python-package/lightgbm/basic.py
@@ -157,6 +157,8 @@
 
 ZERO_THRESHOLD = 1e-35
 
+_MULTICLASS_OBJECTIVES = {"multiclass", "multiclassova", "multiclass_ova", "ova", "ovr", "softmax"}
+
 
 def _is_zero(x: float) -> bool:
     return -ZERO_THRESHOLD <= x <= ZERO_THRESHOLD
14 changes: 10 additions & 4 deletions python-package/lightgbm/sklearn.py
@@ -10,6 +10,7 @@
 import scipy.sparse
 
 from .basic import (
+    _MULTICLASS_OBJECTIVES,
     Booster,
     Dataset,
     LightGBMError,
@@ -467,7 +468,7 @@ def _extract_evaluation_meta_data(
     # It's possible, for example, to pass 3 eval sets through `eval_set`,
     # but only 1 init_score through `eval_init_score`.
     #
-    # This if-else accounts for that possiblity.
+    # This if-else accounts for that possibility.
     if len(collection) > i:
         return collection[i]
     else:
@@ -1011,7 +1012,7 @@ def predict(
                 f"match the input. Model n_features_ is {self._n_features} and "
                 f"input n_features is {n_features}"
             )
-        # retrive original params that possibly can be used in both training and prediction
+        # retrieve original params that possibly can be used in both training and prediction
         # and then overwrite them (considering aliases) with params that were passed directly in prediction
         predict_params = self._process_params(stage="predict")
         for alias in _ConfigAliases.get_by_alias(
@@ -1251,7 +1252,7 @@ def fit(  # type: ignore[override]
                 eval_metric_list = [eval_metric]
             else:
                 eval_metric_list = []
-            if self._n_classes > 2:
+            if self.__is_multiclass:
                 for index, metric in enumerate(eval_metric_list):
                     if metric in {"logloss", "binary_logloss"}:
                         eval_metric_list[index] = "multi_logloss"
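
The fit() hunk above consults the new __is_multiclass predicate instead of the raw class count, so binary metric names are remapped to their multiclass counterparts even when only two classes are present. A hedged illustration of the user-visible effect (mirrors the new test_metrics cases below; evaluating on the training data itself makes LightGBM record results under the "training" key):

    import lightgbm as lgb
    import numpy as np

    X = np.random.default_rng(1).normal(size=(60, 4))
    y = np.repeat([0, 1], 30)

    clf = lgb.LGBMClassifier(objective="multiclass", num_class=2, verbose=-1)
    clf.fit(X, y, eval_set=[(X, y)], eval_metric="binary_logloss")

    # the requested binary metric was rewritten to its multiclass name
    assert "multi_logloss" in clf.evals_result_["training"]
    assert "binary_logloss" not in clf.evals_result_["training"]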
@@ -1361,7 +1362,7 @@ def predict_proba(
                 "Returning raw scores instead."
             )
             return result
-        elif self._n_classes > 2 or raw_score or pred_leaf or pred_contrib:  # type: ignore [operator]
+        elif self.__is_multiclass or raw_score or pred_leaf or pred_contrib:  # type: ignore [operator]
             return result
         else:
             return np.vstack((1.0 - result, result)).transpose()
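
The else branch above is the binary path: a binary booster returns a single column of positive-class probabilities, which np.vstack expands into the two-column layout scikit-learn expects. A multiclass booster already returns one column per class, which is why a num_class=2 multiclass model must take the elif branch instead. A quick numpy illustration of the binary expansion:

    import numpy as np

    result = np.array([0.9, 0.2, 0.6])  # P(class 1) from a binary model
    proba = np.vstack((1.0 - result, result)).transpose()
    # shape (3, 2): column 0 holds P(class 0), column 1 holds P(class 1)
    assert proba.shape == (3, 2)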
@@ -1389,6 +1390,11 @@ def n_classes_(self) -> int:
             raise LGBMNotFittedError("No classes found. Need to call fit beforehand.")
         return self._n_classes
 
+    @property
+    def __is_multiclass(self) -> bool:
+        """:obj:`bool`: Indicator of whether the classifier is used for multiclass."""
+        return self._n_classes > 2 or (isinstance(self._objective, str) and self._objective in _MULTICLASS_OBJECTIVES)
+
 
 class LGBMRanker(LGBMModel):
     """LightGBM ranker.
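
The new property uses a double leading underscore, so Python name-mangles it to _LGBMClassifier__is_multiclass and keeps it private to the class. Its logic reads standalone; here is a sketch with a hypothetical free function standing in for the instance attributes:

    _MULTICLASS_OBJECTIVES = {"multiclass", "multiclassova", "multiclass_ova", "ova", "ovr", "softmax"}

    def is_multiclass(n_classes: int, objective) -> bool:
        # mirrors LGBMClassifier.__is_multiclass from the hunk above
        return n_classes > 2 or (isinstance(objective, str) and objective in _MULTICLASS_OBJECTIVES)

    assert is_multiclass(2, "multiclass")        # the case this commit fixes
    assert is_multiclass(3, None)                # more than two classes always counts
    assert not is_multiclass(2, "binary")        # plain binary stays binary
    assert not is_multiclass(2, lambda y, p: 0)  # custom callables fail the isinstance check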
33 changes: 33 additions & 0 deletions tests/python_package_test/test_sklearn.py
@@ -719,6 +719,25 @@ def test_predict():
     with pytest.raises(AssertionError):
         np.testing.assert_allclose(res_engine, res_sklearn_params)
 
+    # Test multiclass binary classification
+    num_samples = 100
+    num_classes = 2
+    X_train = np.linspace(start=0, stop=10, num=num_samples * 3).reshape(num_samples, 3)
+    y_train = np.concatenate([np.zeros(int(num_samples / 2 - 10)), np.ones(int(num_samples / 2 + 10))])
+
+    gbm = lgb.train({"objective": "multiclass", "num_class": num_classes, "verbose": -1}, lgb.Dataset(X_train, y_train))
+    clf = lgb.LGBMClassifier(objective="multiclass", num_classes=num_classes).fit(X_train, y_train)
+
+    res_engine = gbm.predict(X_train)
+    res_sklearn = clf.predict_proba(X_train)
+
+    assert res_engine.shape == (num_samples, num_classes)
+    assert res_sklearn.shape == (num_samples, num_classes)
+    np.testing.assert_allclose(res_engine, res_sklearn)
+
+    res_class_sklearn = clf.predict(X_train)
+    np.testing.assert_allclose(res_class_sklearn, y_train)
+
 
 def test_predict_with_params_from_init():
     X, y = load_iris(return_X_y=True)
@@ -1035,6 +1054,20 @@ def test_metrics():
     assert len(gbm.evals_result_["training"]) == 1
     assert "binary_logloss" in gbm.evals_result_["training"]
 
+    # the evaluation metric changes to the multiclass metric even when num_classes is 2 for a multiclass objective
+    gbm = lgb.LGBMClassifier(objective="multiclass", num_classes=2, **params).fit(
+        eval_metric="binary_logloss", **params_fit
+    )
+    assert len(gbm.evals_result_["training"]) == 1
+    assert "multi_logloss" in gbm.evals_result_["training"]
+
+    # the evaluation metric changes to the multiclass metric even when num_classes is 2 for the ovr objective
+    gbm = lgb.LGBMClassifier(objective="ovr", num_classes=2, **params).fit(eval_metric="binary_error", **params_fit)
+    assert gbm.objective_ == "ovr"
+    assert len(gbm.evals_result_["training"]) == 2
+    assert "multi_logloss" in gbm.evals_result_["training"]
+    assert "multi_error" in gbm.evals_result_["training"]
+
 
 def test_multiple_eval_metrics():
     X, y = load_breast_cancer(return_X_y=True)
Expand Down
