Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] Correctly recognize LGBMClassifier(num_class=2, objective="multiclass") as multiclass classification #6524

Merged
merged 15 commits into from
Jul 16, 2024
6 changes: 3 additions & 3 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def _extract_evaluation_meta_data(
# It's possible, for example, to pass 3 eval sets through `eval_set`,
# but only 1 init_score through `eval_init_score`.
#
# This if-else accounts for that possiblity.
# This if-else accounts for that possibility.
if len(collection) > i:
return collection[i]
else:
Expand Down Expand Up @@ -1011,7 +1011,7 @@ def predict(
f"match the input. Model n_features_ is {self._n_features} and "
f"input n_features is {n_features}"
)
# retrive original params that possibly can be used in both training and prediction
# retrieve original params that possibly can be used in both training and prediction
# and then overwrite them (considering aliases) with params that were passed directly in prediction
predict_params = self._process_params(stage="predict")
for alias in _ConfigAliases.get_by_alias(
Expand Down Expand Up @@ -1361,7 +1361,7 @@ def predict_proba(
"Returning raw scores instead."
)
return result
elif self._n_classes > 2 or raw_score or pred_leaf or pred_contrib: # type: ignore [operator]
elif self._n_classes > 2 or len(result.shape) > 1 or raw_score or pred_leaf or pred_contrib: # type: ignore [operator]
return result
else:
return np.vstack((1.0 - result, result)).transpose()
Expand Down
15 changes: 15 additions & 0 deletions tests/python_package_test/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,21 @@ def test_predict():
with pytest.raises(AssertionError):
np.testing.assert_allclose(res_engine, res_sklearn_params)

# Test multiclass binary classification
num_samples, num_classes = 20, 2
rng = np.random.Generator(np.random.PCG64())
X_train = rng.uniform(low=0, high=1, size=[num_samples, 3])
y_train = rng.choice([0, 1], size=num_samples)
jameslamb marked this conversation as resolved.
Show resolved Hide resolved

gbm = lgb.train({"objective": "multiclass", "num_class": num_classes, "verbose": -1}, lgb.Dataset(X_train, y_train))
clf = lgb.LGBMClassifier(objective="multiclass", num_classes=num_classes).fit(X_train, y_train)

res_engine = gbm.predict(X_train)
res_sklearn = clf.predict_proba(X_train)
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
assert res_engine.shape == (num_samples, num_classes)
assert res_sklearn.shape == (num_samples, num_classes)
np.testing.assert_allclose(res_engine, res_sklearn)


def test_predict_with_params_from_init():
X, y = load_iris(return_X_y=True)
Expand Down