From f8ec57b8eba69b7371324bba1d158f850802c782 Mon Sep 17 00:00:00 2001 From: RektPunk <110188257+RektPunk@users.noreply.github.com> Date: Wed, 17 Jul 2024 06:05:00 +0900 Subject: [PATCH] [python-package] Correctly recognize LGBMClassifier(num_class=2, objective="multiclass") as multiclass classification (#6524) --- python-package/lightgbm/basic.py | 2 ++ python-package/lightgbm/sklearn.py | 14 +++++++--- tests/python_package_test/test_sklearn.py | 33 +++++++++++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 373c9911303a..194d9ca6c5b0 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -157,6 +157,8 @@ ZERO_THRESHOLD = 1e-35 +_MULTICLASS_OBJECTIVES = {"multiclass", "multiclassova", "multiclass_ova", "ova", "ovr", "softmax"} + def _is_zero(x: float) -> bool: return -ZERO_THRESHOLD <= x <= ZERO_THRESHOLD diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 7f3e91a064c4..3c8d970e7428 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -10,6 +10,7 @@ import scipy.sparse from .basic import ( + _MULTICLASS_OBJECTIVES, Booster, Dataset, LightGBMError, @@ -467,7 +468,7 @@ def _extract_evaluation_meta_data( # It's possible, for example, to pass 3 eval sets through `eval_set`, # but only 1 init_score through `eval_init_score`. # - # This if-else accounts for that possiblity. + # This if-else accounts for that possibility. if len(collection) > i: return collection[i] else: @@ -1011,7 +1012,7 @@ def predict( f"match the input. 
Model n_features_ is {self._n_features} and " f"input n_features is {n_features}" ) - # retrive original params that possibly can be used in both training and prediction + # retrieve original params that possibly can be used in both training and prediction # and then overwrite them (considering aliases) with params that were passed directly in prediction predict_params = self._process_params(stage="predict") for alias in _ConfigAliases.get_by_alias( @@ -1251,7 +1252,7 @@ def fit( # type: ignore[override] eval_metric_list = [eval_metric] else: eval_metric_list = [] - if self._n_classes > 2: + if self.__is_multiclass: for index, metric in enumerate(eval_metric_list): if metric in {"logloss", "binary_logloss"}: eval_metric_list[index] = "multi_logloss" @@ -1361,7 +1362,7 @@ def predict_proba( "Returning raw scores instead." ) return result - elif self._n_classes > 2 or raw_score or pred_leaf or pred_contrib: # type: ignore [operator] + elif self.__is_multiclass or raw_score or pred_leaf or pred_contrib: # type: ignore [operator] return result else: return np.vstack((1.0 - result, result)).transpose() @@ -1389,6 +1390,11 @@ def n_classes_(self) -> int: raise LGBMNotFittedError("No classes found. Need to call fit beforehand.") return self._n_classes + @property + def __is_multiclass(self) -> bool: + """:obj:`bool`: Indicator of whether the classifier is used for multiclass.""" + return self._n_classes > 2 or (isinstance(self._objective, str) and self._objective in _MULTICLASS_OBJECTIVES) + class LGBMRanker(LGBMModel): """LightGBM ranker. 
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 478b66035837..6f0f7cb2ff3a 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -719,6 +719,25 @@ def test_predict(): with pytest.raises(AssertionError): np.testing.assert_allclose(res_engine, res_sklearn_params) + # Test the multiclass objective on a binary (2-class) classification problem + num_samples = 100 + num_classes = 2 + X_train = np.linspace(start=0, stop=10, num=num_samples * 3).reshape(num_samples, 3) + y_train = np.concatenate([np.zeros(int(num_samples / 2 - 10)), np.ones(int(num_samples / 2 + 10))]) + + gbm = lgb.train({"objective": "multiclass", "num_class": num_classes, "verbose": -1}, lgb.Dataset(X_train, y_train)) + clf = lgb.LGBMClassifier(objective="multiclass", num_classes=num_classes).fit(X_train, y_train) + + res_engine = gbm.predict(X_train) + res_sklearn = clf.predict_proba(X_train) + + assert res_engine.shape == (num_samples, num_classes) + assert res_sklearn.shape == (num_samples, num_classes) + np.testing.assert_allclose(res_engine, res_sklearn) + + res_class_sklearn = clf.predict(X_train) + np.testing.assert_allclose(res_class_sklearn, y_train) + def test_predict_with_params_from_init(): X, y = load_iris(return_X_y=True) @@ -1035,6 +1054,20 @@ def test_metrics(): assert len(gbm.evals_result_["training"]) == 1 assert "binary_logloss" in gbm.evals_result_["training"] + # the evaluation metric changes to the multiclass metric even if the number of classes is 2 for the multiclass objective + gbm = lgb.LGBMClassifier(objective="multiclass", num_classes=2, **params).fit( + eval_metric="binary_logloss", **params_fit + ) + assert len(gbm._evals_result["training"]) == 1 + assert "multi_logloss" in gbm.evals_result_["training"] + + # the evaluation metric changes to the multiclass metric even if the number of classes is 2 for the ovr objective + gbm = lgb.LGBMClassifier(objective="ovr", num_classes=2, **params).fit(eval_metric="binary_error", **params_fit) + assert 
gbm.objective_ == "ovr" + assert len(gbm.evals_result_["training"]) == 2 + assert "multi_logloss" in gbm.evals_result_["training"] + assert "multi_error" in gbm.evals_result_["training"] + def test_multiple_eval_metrics(): X, y = load_breast_cancer(return_X_y=True)