Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] Correctly recognize LGBMClassifier(num_class=2, objective="multiclass") as multiclass classification #6524

Merged
merged 15 commits into from
Jul 16, 2024
14 changes: 10 additions & 4 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def _extract_evaluation_meta_data(
# It's possible, for example, to pass 3 eval sets through `eval_set`,
# but only 1 init_score through `eval_init_score`.
#
# This if-else accounts for that possiblity.
# This if-else accounts for that possibility.
if len(collection) > i:
return collection[i]
else:
Expand Down Expand Up @@ -1011,7 +1011,7 @@ def predict(
f"match the input. Model n_features_ is {self._n_features} and "
f"input n_features is {n_features}"
)
# retrive original params that possibly can be used in both training and prediction
# retrieve original params that possibly can be used in both training and prediction
# and then overwrite them (considering aliases) with params that were passed directly in prediction
predict_params = self._process_params(stage="predict")
for alias in _ConfigAliases.get_by_alias(
Expand Down Expand Up @@ -1251,7 +1251,7 @@ def fit( # type: ignore[override]
eval_metric_list = [eval_metric]
else:
eval_metric_list = []
if self._n_classes > 2:
if self.__is_multiclass:
for index, metric in enumerate(eval_metric_list):
if metric in {"logloss", "binary_logloss"}:
eval_metric_list[index] = "multi_logloss"
Expand Down Expand Up @@ -1361,7 +1361,7 @@ def predict_proba(
"Returning raw scores instead."
)
return result
elif self._n_classes > 2 or raw_score or pred_leaf or pred_contrib: # type: ignore [operator]
elif self.__is_multiclass or raw_score or pred_leaf or pred_contrib: # type: ignore [operator]
return result
else:
return np.vstack((1.0 - result, result)).transpose()
Expand Down Expand Up @@ -1389,6 +1389,12 @@ def n_classes_(self) -> int:
raise LGBMNotFittedError("No classes found. Need to call fit beforehand.")
return self._n_classes

@property
def __is_multiclass(self) -> bool:
""":obj:`bool`: Indicator of whether the classifier is used for multiclass."""
multiclass_objectives = ("multiclass", "softmax", "multiclassova", "multiclass_ova", "ova", "ovr")
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
return self._n_classes > 2 or (isinstance(self._objective, str) and self._objective in multiclass_objectives)


class LGBMRanker(LGBMModel):
"""LightGBM ranker.
Expand Down
32 changes: 32 additions & 0 deletions tests/python_package_test/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,24 @@ def test_predict():
with pytest.raises(AssertionError):
np.testing.assert_allclose(res_engine, res_sklearn_params)

# Test binary classification (2 classes) trained with the multiclass objective
num_samples, num_classes = 5, 2
X_train = np.ones((num_samples, 3))
y_train = np.zeros((num_samples,))
jameslamb marked this conversation as resolved.
Show resolved Hide resolved

gbm = lgb.train({"objective": "multiclass", "num_class": num_classes, "verbose": -1}, lgb.Dataset(X_train, y_train))
clf = lgb.LGBMClassifier(objective="multiclass", num_classes=num_classes).fit(X_train, y_train)

res_engine = gbm.predict(X_train)
res_sklearn = clf.predict_proba(X_train)
jameslamb marked this conversation as resolved.
Show resolved Hide resolved

assert res_engine.shape == (num_samples, num_classes)
assert res_sklearn.shape == (num_samples, num_classes)
np.testing.assert_allclose(res_engine, res_sklearn)

res_class_sklearn = clf.predict(X_train)
np.testing.assert_allclose(res_class_sklearn, np.zeros((num_samples,)))


def test_predict_with_params_from_init():
X, y = load_iris(return_X_y=True)
Expand Down Expand Up @@ -1035,6 +1053,20 @@ def test_metrics():
assert len(gbm.evals_result_["training"]) == 1
assert "binary_logloss" in gbm.evals_result_["training"]

# the evaluation metric changes to the multiclass metric even when the number of classes is 2, if the objective is multiclass
gbm = lgb.LGBMClassifier(objective="multiclass", num_classes=2, **params).fit(
eval_metric="binary_logloss", **params_fit
)
assert len(gbm._evals_result["training"]) == 1
assert "multi_logloss" in gbm.evals_result_["training"]

# the evaluation metric changes to the multiclass metric even when the number of classes is 2, if the objective is ovr
gbm = lgb.LGBMClassifier(objective="ovr", num_classes=2, **params).fit(eval_metric="binary_error", **params_fit)
assert gbm.objective_ == "ovr"
assert len(gbm.evals_result_["training"]) == 2
assert "multi_logloss" in gbm.evals_result_["training"]
assert "multi_error" in gbm.evals_result_["training"]


def test_multiple_eval_metrics():
X, y = load_breast_cancer(return_X_y=True)
Expand Down