From 7eae66a7bd0cf5e4bb081f23b61ad11b4f66d706 Mon Sep 17 00:00:00 2001
From: vnherdeiro
Date: Thu, 10 Oct 2024 00:46:26 +0100
Subject: [PATCH] [python-package] require `scikit-learn>=0.24.2`, make
 scikit-learn estimators compatible with `scikit-learn>=1.6.0dev` (#6651)

Co-authored-by: James Lamb
Co-authored-by: Nikita Titov
---
 .ci/test-python-latest.sh                 |   2 +-
 .ci/test-python-oldest.sh                 |   2 +-
 .ci/test.sh                               |   1 +
 python-package/lightgbm/compat.py         |  88 +++++++++++-
 python-package/lightgbm/sklearn.py        | 163 +++++++++++++++++++---
 python-package/pyproject.toml             |   2 +-
 tests/python_package_test/test_sklearn.py |  80 ++++++++++-
 7 files changed, 309 insertions(+), 29 deletions(-)

diff --git a/.ci/test-python-latest.sh b/.ci/test-python-latest.sh
index f98f29f2641a..08fc8558ef3e 100755
--- a/.ci/test-python-latest.sh
+++ b/.ci/test-python-latest.sh
@@ -22,7 +22,7 @@ python -m pip install \
     'numpy>=2.0.0.dev0' \
     'matplotlib>=3.10.0.dev0' \
     'pandas>=3.0.0.dev0' \
-    'scikit-learn==1.5.*' \
+    'scikit-learn>=1.6.dev0' \
     'scipy>=1.15.0.dev0'
 
 python -m pip install \
diff --git a/.ci/test-python-oldest.sh b/.ci/test-python-oldest.sh
index 002a1c95833c..b33690324c79 100644
--- a/.ci/test-python-oldest.sh
+++ b/.ci/test-python-oldest.sh
@@ -15,7 +15,7 @@ pip install \
     'numpy==1.19.0' \
     'pandas==1.1.3' \
     'pyarrow==6.0.1' \
-    'scikit-learn==0.24.0' \
+    'scikit-learn==0.24.2' \
     'scipy==1.6.0' \
     || exit 1
 echo "done installing lightgbm's dependencies"
diff --git a/.ci/test.sh b/.ci/test.sh
index 9671e4e078e8..78ba3020c1d1 100755
--- a/.ci/test.sh
+++ b/.ci/test.sh
@@ -103,6 +103,7 @@ if [[ $TASK == "lint" ]]; then
         'mypy>=1.11.1' \
         'pre-commit>=3.8.0' \
         'pyarrow-core>=17.0' \
+        'scikit-learn>=1.5.2' \
         'r-lintr>=3.1.2'
     source activate $CONDA_ENV
     echo "Linting Python code"
diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py
index f916ca6be723..96dee6522572 100644
--- a/python-package/lightgbm/compat.py
+++ b/python-package/lightgbm/compat.py
@@ -1,12 +1,13 @@
 # coding: utf-8
 """Compatibility library."""
 
-from typing import Any, List
+from typing import TYPE_CHECKING, Any, List
 
 # scikit-learn is intentionally imported first here,
 # see https://github.com/microsoft/LightGBM/issues/6509
 """sklearn"""
 try:
+    from sklearn import __version__ as _sklearn_version
     from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
     from sklearn.preprocessing import LabelEncoder
     from sklearn.utils.class_weight import compute_sample_weight
@@ -29,6 +30,74 @@ def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
             check_consistent_length(sample_weight, X)
             return sample_weight
 
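For context, a rough sketch of what the `_check_sample_weight()` shim whose tail appears in the context lines above guarantees (the behavior is roughly the same whether the real scikit-learn function or the fallback is picked; `lightgbm` re-exports it further down in this file as `_LGBMCheckSampleWeight`, and this sketch assumes scikit-learn is installed):

```python
import numpy as np

from lightgbm.compat import _LGBMCheckSampleWeight

X = np.random.rand(5, 2)
print(_LGBMCheckSampleWeight(None, X))  # [1. 1. 1. 1. 1.] -- one weight per row
print(_LGBMCheckSampleWeight(2.0, X))   # [2. 2. 2. 2. 2.] -- scalars are broadcast
# an array of the wrong length raises a ValueError via check_consistent_length()
```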
+    try:
+        from sklearn.utils.validation import validate_data
+    except ImportError:
+        # validate_data() was added in scikit-learn 1.6; this function roughly imitates it for older versions.
+        # It can be removed when lightgbm's minimum scikit-learn version is at least 1.6.
+        def validate_data(
+            _estimator,
+            X,
+            y="no_validation",
+            accept_sparse: bool = True,
+            # 'force_all_finite' was renamed to 'ensure_all_finite' in scikit-learn 1.6
+            ensure_all_finite: bool = False,
+            ensure_min_samples: int = 1,
+            # trap other keyword arguments that only work on scikit-learn>=1.6, like 'reset'
+            **ignored_kwargs,
+        ):
+            # it's safe to import _num_features unconditionally because:
+            #
+            #   * it was first added in scikit-learn 0.24.2
+            #   * lightgbm cannot be used with scikit-learn versions older than that
+            #   * this validate_data() re-implementation will not be called in scikit-learn>=1.6
+            #
+            from sklearn.utils.validation import _num_features
+
+            # _num_features() raises a TypeError on 1-dimensional input. That's a problem
+            # because scikit-learn's 'check_fit1d' estimator check sets the expectation that
+            # estimators must raise a ValueError when a 1-dimensional input is passed to fit().
+            #
+            # So here, lightgbm avoids calling _num_features() on 1-dimensional inputs.
+            if hasattr(X, "shape") and len(X.shape) == 1:
+                n_features_in_ = 1
+            else:
+                n_features_in_ = _num_features(X)
+
+            no_val_y = isinstance(y, str) and y == "no_validation"
+
+            # NOTE: check_X_y() calls check_array() internally, so only one or the other of them needs to be called here
+            if no_val_y:
+                X = check_array(
+                    X,
+                    accept_sparse=accept_sparse,
+                    force_all_finite=ensure_all_finite,
+                    ensure_min_samples=ensure_min_samples,
+                )
+            else:
+                X, y = check_X_y(
+                    X,
+                    y,
+                    accept_sparse=accept_sparse,
+                    force_all_finite=ensure_all_finite,
+                    ensure_min_samples=ensure_min_samples,
+                )
+
+            # this only needs to be updated at fit() time
+            _estimator.n_features_in_ = n_features_in_
+
+            # raise the same error that scikit-learn's `validate_data()` does on scikit-learn>=1.6
+            if _estimator.__sklearn_is_fitted__() and _estimator._n_features != n_features_in_:
+                raise ValueError(
+                    f"X has {n_features_in_} features, but {_estimator.__class__.__name__} "
+                    f"is expecting {_estimator._n_features} features as input."
+                )
+
+            if no_val_y:
+                return X
+            else:
+                return X, y
+
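To make the shim's behavior concrete: the feature-count check above reproduces, on older scikit-learn versions, the same error that scikit-learn>=1.6 raises itself. A small sketch with made-up data (the message is asserted verbatim in the tests added later in this patch):

```python
import numpy as np

from lightgbm import LGBMRegressor

model = LGBMRegressor(n_estimators=2).fit(np.random.rand(30, 3), np.random.rand(30))
try:
    model.predict(np.random.rand(5, 4))  # 4 columns, but the model was fit on 3
except ValueError as err:
    print(err)  # X has 4 features, but LGBMRegressor is expecting 3 features as input.
```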
     SKLEARN_INSTALLED = True
     _LGBMBaseCrossValidator = BaseCrossValidator
     _LGBMModelBase = BaseEstimator
@@ -38,12 +107,11 @@ def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
     LGBMNotFittedError = NotFittedError
     _LGBMStratifiedKFold = StratifiedKFold
     _LGBMGroupKFold = GroupKFold
-    _LGBMCheckXY = check_X_y
-    _LGBMCheckArray = check_array
     _LGBMCheckSampleWeight = _check_sample_weight
     _LGBMAssertAllFinite = assert_all_finite
     _LGBMCheckClassificationTargets = check_classification_targets
     _LGBMComputeSampleWeight = compute_sample_weight
+    _LGBMValidateData = validate_data
 
 except ImportError:
     SKLEARN_INSTALLED = False
@@ -67,12 +135,22 @@ class _LGBMRegressorBase:  # type: ignore
     LGBMNotFittedError = ValueError
     _LGBMStratifiedKFold = None
     _LGBMGroupKFold = None
-    _LGBMCheckXY = None
-    _LGBMCheckArray = None
     _LGBMCheckSampleWeight = None
     _LGBMAssertAllFinite = None
     _LGBMCheckClassificationTargets = None
     _LGBMComputeSampleWeight = None
+    _LGBMValidateData = None
+    _sklearn_version = None
+
+# additional scikit-learn imports only for type hints
+if TYPE_CHECKING:
+    # sklearn.utils.Tags can be imported unconditionally once
+    # lightgbm's minimum scikit-learn version is 1.6 or higher
+    try:
+        from sklearn.utils import Tags as _sklearn_Tags
+    except ImportError:
+        _sklearn_Tags = None
+
 
 """pandas"""
 try:
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index ad805eef7332..e8e46eb42b86 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -4,7 +4,7 @@
 import copy
 from inspect import signature
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import scipy.sparse
@@ -31,21 +31,25 @@
     SKLEARN_INSTALLED,
     LGBMNotFittedError,
     _LGBMAssertAllFinite,
-    _LGBMCheckArray,
     _LGBMCheckClassificationTargets,
     _LGBMCheckSampleWeight,
-    _LGBMCheckXY,
     _LGBMClassifierBase,
     _LGBMComputeSampleWeight,
     _LGBMCpuCount,
     _LGBMLabelEncoder,
     _LGBMModelBase,
     _LGBMRegressorBase,
+    _LGBMValidateData,
+    _sklearn_version,
     dt_DataTable,
     pd_DataFrame,
 )
 from .engine import train
 
+if TYPE_CHECKING:
+    from .compat import _sklearn_Tags
+
+
 __all__ = [
     "LGBMClassifier",
     "LGBMModel",
@@ -662,6 +666,10 @@ def __init__(
         self._n_classes: int = -1
         self.set_params(**kwargs)
 
+    # scikit-learn 1.6 introduced a __sklearn_tags__() method intended to replace _more_tags().
+    # _more_tags() can be removed whenever lightgbm's minimum supported scikit-learn version
+    # is >=1.6.
+    # ref: https://github.com/microsoft/LightGBM/pull/6651
     def _more_tags(self) -> Dict[str, Any]:
         return {
             "allow_nan": True,
@@ -669,10 +677,46 @@ def _more_tags(self) -> Dict[str, Any]:
             "_xfail_checks": {
                 "check_no_attributes_set_in_init": "scikit-learn incorrectly asserts that private attributes "
                 "cannot be set in __init__: "
-                "(see https://github.com/microsoft/LightGBM/issues/2628)"
+                "(see https://github.com/microsoft/LightGBM/issues/2628)",
             },
         }
 
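A sketch of how these tags reach callers. On scikit-learn>=1.6, `sklearn.utils.get_tags()` consults the `__sklearn_tags__()` method added below; on older versions, scikit-learn aggregates the `_more_tags()` dicts instead. This assumes scikit-learn>=1.6 is installed:

```python
import lightgbm as lgb
from sklearn.utils import get_tags  # only available on scikit-learn>=1.6

tags = get_tags(lgb.LGBMClassifier())
print(tags.input_tags.allow_nan)  # True, mirrors _more_tags()["allow_nan"]
print(tags.input_tags.sparse)     # True, mirrors "sparse" in _more_tags()["X_types"]
```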
+    @staticmethod
+    def _update_sklearn_tags_from_dict(
+        *,
+        tags: "_sklearn_Tags",
+        tags_dict: Dict[str, Any],
+    ) -> "_sklearn_Tags":
+        """Update ``sklearn.utils.Tags`` inherited from ``scikit-learn`` base classes.
+
+        ``scikit-learn`` 1.6 introduced a dataclass-based interface for estimator tags.
+        ref: https://github.com/scikit-learn/scikit-learn/pull/29677
+
+        This method handles updating that instance based on the values in ``self._more_tags()``.
+        """
+        tags.input_tags.allow_nan = tags_dict["allow_nan"]
+        tags.input_tags.sparse = "sparse" in tags_dict["X_types"]
+        tags.target_tags.one_d_labels = "1dlabels" in tags_dict["X_types"]
+        tags._xfail_checks = tags_dict["_xfail_checks"]
+        return tags
+
+    def __sklearn_tags__(self) -> Optional["_sklearn_Tags"]:
+        # _LGBMModelBase.__sklearn_tags__() cannot be called unconditionally,
+        # because that method isn't defined for scikit-learn<1.6
+        if not hasattr(_LGBMModelBase, "__sklearn_tags__"):
+            err_msg = (
+                "__sklearn_tags__() should not be called when using scikit-learn<1.6. "
+                f"Detected version: {_sklearn_version}"
+            )
+            raise AttributeError(err_msg)
+
+        # take whatever tags are provided by BaseEstimator, then modify
+        # them with LightGBM-specific values
+        return self._update_sklearn_tags_from_dict(
+            tags=_LGBMModelBase.__sklearn_tags__(self),
+            tags_dict=self._more_tags(),
+        )
+
     def __sklearn_is_fitted__(self) -> bool:
         return getattr(self, "fitted_", False)
 
@@ -862,12 +906,26 @@ def fit(
             params["metric"] = [metric for metric in params["metric"] if metric is not None]
 
         if not isinstance(X, (pd_DataFrame, dt_DataTable)):
-            _X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
+            _X, _y = _LGBMValidateData(
+                self,
+                X,
+                y,
+                reset=True,
+                # allow any input type (this validation is done further down, in lgb.Dataset())
+                accept_sparse=True,
+                # do not raise an error if Inf or NaN values are found (LightGBM handles these internally)
+                ensure_all_finite=False,
+                # raise an error on 0-row and 1-row inputs
+                ensure_min_samples=2,
+            )
             if sample_weight is not None:
                 sample_weight = _LGBMCheckSampleWeight(sample_weight, _X)
         else:
             _X, _y = X, y
 
+            # for other data types, setting n_features_in_ is handled by _LGBMValidateData() in the branch above
+            self.n_features_in_ = _X.shape[1]
+
         if self._class_weight is None:
             self._class_weight = self.class_weight
         if self._class_weight is not None:
@@ -877,10 +935,6 @@ def fit(
             else:
                 sample_weight = np.multiply(sample_weight, class_sample_weight)
 
-        self._n_features = _X.shape[1]
-        # copy for consistency
-        self._n_features_in = self._n_features
-
         train_set = Dataset(
             data=_X,
             label=_y,
@@ -963,6 +1017,13 @@ def fit(
             callbacks=callbacks,
         )
 
+        # This populates the property self.n_features_, the number of features in the fitted model.
+        # It is derived from the trained Booster and so should only be set after fitting.
+        #
+        # The related property self._n_features_in, which populates self.n_features_in_,
+        # is set BEFORE fitting.
+        self._n_features = self._Booster.num_feature()
+
        self._evals_result = evals_result
         self._best_iteration = self._Booster.best_iteration
         self._best_score = self._Booster.best_score
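The net effect of the fit-time validation above, sketched with synthetic data: non-finite values pass through for LightGBM to handle internally, while zero- and one-row inputs are rejected by scikit-learn's checks:

```python
import numpy as np

from lightgbm import LGBMRegressor

X = np.random.rand(50, 2)
X[0, 0] = np.nan  # tolerated, because ensure_all_finite=False
y = np.random.rand(50)

model = LGBMRegressor(n_estimators=2).fit(X, y)
print(model.n_features_in_)  # 2, recorded by _LGBMValidateData(..., reset=True)

try:
    LGBMRegressor().fit(X[:1], y[:1])  # rejected, because ensure_min_samples=2
except ValueError as err:
    print(err)
```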
@@ -1004,13 +1065,20 @@ def predict(
         if not self.__sklearn_is_fitted__():
             raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
         if not isinstance(X, (pd_DataFrame, dt_DataTable)):
-            X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
-            n_features = X.shape[1]
-            if self._n_features != n_features:
-                raise ValueError(
-                    "Number of features of the model must "
-                    f"match the input. Model n_features_ is {self._n_features} and "
-                    f"input n_features is {n_features}"
-                )
+            X = _LGBMValidateData(
+                self,
+                X,
+                # omitting 'y' means scikit-learn's check_array() is run instead of check_X_y()
+                #
+                # Prevent scikit-learn from deleting or modifying attributes like 'feature_names_in_' and 'n_features_in_'.
+                # These shouldn't be changed at predict() time.
+                reset=False,
+                # allow any input type (this validation is done further down, in lgb.Dataset())
+                accept_sparse=True,
+                # do not raise an error if Inf or NaN values are found (LightGBM handles these internally)
+                ensure_all_finite=False,
+                # raise an error on 0-row inputs
+                ensure_min_samples=1,
+            )
         # retrieve original params that possibly can be used in both training and prediction
         # and then overwrite them (considering aliases) with params that were passed directly in prediction
@@ -1067,6 +1135,21 @@ def n_features_in_(self) -> int:
             raise LGBMNotFittedError("No n_features_in found. Need to call fit beforehand.")
         return self._n_features_in
 
+    @n_features_in_.setter
+    def n_features_in_(self, value: int) -> None:
+        """Set number of features found in passed-in dataset.
+
+        Starting with ``scikit-learn`` 1.6, ``scikit-learn`` expects to be able to directly
+        set this property in functions like ``validate_data()``.
+
+        .. note::
+
+            Do not call ``estimator.n_features_in_ = some_int`` or anything else that invokes
+            this method. It is only here for compatibility with ``scikit-learn`` validation
+            functions used internally in ``lightgbm``.
+        """
+        self._n_features_in = value
+
     @property
     def best_score_(self) -> _LGBM_BoosterBestScoreType:
         """:obj:`dict`: The best score of fitted model."""
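A brief sketch of the contract these accessors uphold: `fit()` records the attributes, and `predict()` (which calls `_LGBMValidateData(..., reset=False)` above) leaves them untouched:

```python
import numpy as np

from lightgbm import LGBMRegressor

X, y = np.random.rand(40, 3), np.random.rand(40)
model = LGBMRegressor(n_estimators=2).fit(X, y)
print(model.n_features_in_)     # 3
print(model.feature_names_in_)  # ['Column_0' 'Column_1' 'Column_2'] (auto-assigned)

model.predict(X)                # reset=False: fitted attributes are not modified
print(model.n_features_in_)     # still 3
```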
@@ -1165,10 +1248,45 @@ def feature_names_in_(self) -> np.ndarray:
             raise LGBMNotFittedError("No feature_names_in_ found. Need to call fit beforehand.")
         return np.array(self.feature_name_)
 
+    @feature_names_in_.deleter
+    def feature_names_in_(self) -> None:
+        """Intercept calls to delete ``feature_names_in_``.
+
+        Some code paths in ``scikit-learn`` try to delete the ``feature_names_in_`` attribute
+        on estimators when a new training dataset that doesn't have feature names is passed.
+        LightGBM automatically assigns feature names to such datasets
+        (like ``Column_0``, ``Column_1``, etc.) and so does not want that behavior.
+
+        However, that behavior is coupled to ``scikit-learn`` automatically updating
+        ``n_features_in_`` in those same code paths, which is necessary for compliance
+        with its API (via argument ``reset`` to functions like ``validate_data()`` and
+        ``check_array()``).
+
+        .. note::
+
+            Do not call ``del estimator.feature_names_in_`` or anything else that invokes
+            this method. It is only here for compatibility with ``scikit-learn`` validation
+            functions used internally in ``lightgbm``.
+        """
+        pass
+
 
 class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
     """LightGBM regressor."""
 
+    def _more_tags(self) -> Dict[str, Any]:
+        # handle the case where RegressorMixin possibly provides _more_tags()
+        if callable(getattr(_LGBMRegressorBase, "_more_tags", None)):
+            tags = _LGBMRegressorBase._more_tags(self)
+        else:
+            tags = {}
+        # override those with LightGBM-specific preferences
+        tags.update(LGBMModel._more_tags(self))
+        return tags
+
+    def __sklearn_tags__(self) -> "_sklearn_Tags":
+        return LGBMModel.__sklearn_tags__(self)
+
     def fit(  # type: ignore[override]
         self,
         X: _LGBM_ScikitMatrixLike,
@@ -1215,6 +1333,19 @@ def fit(  # type: ignore[override]
 
 class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
     """LightGBM classifier."""
 
+    def _more_tags(self) -> Dict[str, Any]:
+        # handle the case where ClassifierMixin possibly provides _more_tags()
+        if callable(getattr(_LGBMClassifierBase, "_more_tags", None)):
+            tags = _LGBMClassifierBase._more_tags(self)
+        else:
+            tags = {}
+        # override those with LightGBM-specific preferences
+        tags.update(LGBMModel._more_tags(self))
+        return tags
+
+    def __sklearn_tags__(self) -> "_sklearn_Tags":
+        return LGBMModel.__sklearn_tags__(self)
+
     def fit(  # type: ignore[override]
         self,
         X: _LGBM_ScikitMatrixLike,
diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml
index 2212347637e6..19866e01202b 100644
--- a/python-package/pyproject.toml
+++ b/python-package/pyproject.toml
@@ -45,7 +45,7 @@ pandas = [
     "pandas>=0.24.0"
 ]
 scikit-learn = [
-    "scikit-learn!=0.22.0"
+    "scikit-learn>=0.24.2"
 ]
 
 [project.urls]
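The `_more_tags()` overrides in `LGBMRegressor` and `LGBMClassifier` above follow one simple merge rule, sketched here with illustrative values (the actual mixin tags depend on the installed scikit-learn version): mixin-provided tags are taken first, and LightGBM-specific entries win on conflict:

```python
mixin_tags = {"multioutput": True, "allow_nan": False}  # e.g. what a mixin might provide
lgbm_tags = {"allow_nan": True, "X_types": ["2darray", "sparse", "1dlabels"]}

tags = dict(mixin_tags)
tags.update(lgbm_tags)      # mirrors tags.update(LGBMModel._more_tags(self))
print(tags["allow_nan"])    # True -- the LightGBM-specific value wins
print(tags["multioutput"])  # True -- mixin-only tags are preserved
```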
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index 01ab057cf3e2..6eca66ff20d3 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -36,12 +36,14 @@
 )
 
 decreasing_generator = itertools.count(0, -1)
+estimator_classes = (lgb.LGBMModel, lgb.LGBMClassifier, lgb.LGBMRegressor, lgb.LGBMRanker)
 task_to_model_factory = {
     "ranking": lgb.LGBMRanker,
     "binary-classification": lgb.LGBMClassifier,
     "multiclass-classification": lgb.LGBMClassifier,
     "regression": lgb.LGBMRegressor,
 }
+all_tasks = tuple(task_to_model_factory.keys())
 
 
 def _create_data(task, n_samples=100, n_features=4):
@@ -1311,7 +1313,7 @@ def test_check_is_fitted():
         check_is_fitted(model)
 
 
-@pytest.mark.parametrize("estimator_class", [lgb.LGBMModel, lgb.LGBMClassifier, lgb.LGBMRegressor, lgb.LGBMRanker])
+@pytest.mark.parametrize("estimator_class", estimator_classes)
 @pytest.mark.parametrize("max_depth", [3, 4, 5, 8])
 def test_max_depth_warning_is_never_raised(capsys, estimator_class, max_depth):
     X, y = make_blobs(n_samples=1_000, n_features=1, centers=2)
@@ -1390,7 +1392,7 @@ def test_fit_only_raises_num_rounds_warning_when_expected(capsys):
     assert_silent(capsys)
 
 
-@pytest.mark.parametrize("estimator_class", [lgb.LGBMModel, lgb.LGBMClassifier, lgb.LGBMRegressor, lgb.LGBMRanker])
+@pytest.mark.parametrize("estimator_class", estimator_classes)
 def test_getting_feature_names_in_np_input(estimator_class):
     # input is a numpy array, which doesn't have feature names. LightGBM adds
     # feature names to the fitted model, which is inconsistent with sklearn's behavior
@@ -1409,7 +1411,7 @@ def test_getting_feature_names_in_np_input(estimator_class):
     np.testing.assert_array_equal(model.feature_names_in_, np.array([f"Column_{i}" for i in range(X.shape[1])]))
 
 
-@pytest.mark.parametrize("estimator_class", [lgb.LGBMModel, lgb.LGBMClassifier, lgb.LGBMRegressor, lgb.LGBMRanker])
+@pytest.mark.parametrize("estimator_class", estimator_classes)
 def test_getting_feature_names_in_pd_input(estimator_class):
     X, y = load_digits(n_class=2, return_X_y=True, as_frame=True)
     col_names = X.columns.to_list()
@@ -1436,7 +1438,29 @@ def test_sklearn_integration(estimator, check):
     check(estimator)
 
 
-@pytest.mark.parametrize("task", ["binary-classification", "multiclass-classification", "ranking", "regression"])
+@pytest.mark.parametrize("estimator_class", estimator_classes)
+def test_sklearn_tags_should_correctly_reflect_lightgbm_specific_values(estimator_class):
+    est = estimator_class()
+    more_tags = est._more_tags()
+    err_msg = "List of supported X_types has changed. Update LGBMModel.__sklearn_tags__() to match."
+    assert more_tags["X_types"] == ["2darray", "sparse", "1dlabels"], err_msg
+    # the try-except part of this should be removed once lightgbm's
+    # minimum supported scikit-learn version is at least 1.6
+    try:
+        sklearn_tags = est.__sklearn_tags__()
+    except AttributeError as err:
+        # only the exact error we expected to be raised should be raised
+        assert bool(re.search(r"__sklearn_tags__.* should not be called", str(err)))
+    else:
+        # if no AttributeError was thrown, we must be using scikit-learn>=1.6,
+        # and so the actual effects of __sklearn_tags__() should be tested
+        assert sklearn_tags.input_tags.allow_nan is True
+        assert sklearn_tags.input_tags.sparse is True
+        assert sklearn_tags.target_tags.one_d_labels is True
+        assert sklearn_tags._xfail_checks == more_tags["_xfail_checks"]
+
+
+@pytest.mark.parametrize("task", all_tasks)
 def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task):
     pd = pytest.importorskip("pandas")
     X, y, g = _create_data(task)
@@ -1540,7 +1564,7 @@ def test_default_n_jobs(tmp_path):
 
 
 @pytest.mark.skipif(not PANDAS_INSTALLED, reason="pandas is not installed")
-@pytest.mark.parametrize("task", ["binary-classification", "multiclass-classification", "ranking", "regression"])
+@pytest.mark.parametrize("task", all_tasks)
 def test_validate_features(task):
     X, y, g = _create_data(task, n_features=4)
     features = ["x1", "x2", "x3", "x4"]
@@ -1561,6 +1585,52 @@ def test_validate_features(task):
     model.predict(df2, validate_features=False)
 
 
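For readers unfamiliar with `validate_features` (exercised by `test_validate_features()` above), a minimal sketch of the behavior under test; the printed message is paraphrased:

```python
import numpy as np
import pandas as pd

import lightgbm as lgb

df = pd.DataFrame(np.random.rand(100, 4), columns=["x1", "x2", "x3", "x4"])
y = np.random.rand(100)
model = lgb.LGBMRegressor(n_estimators=2).fit(df, y)

renamed = df.rename(columns={"x3": "z3"})
model.predict(renamed, validate_features=False)  # accepted: names are not checked
try:
    model.predict(renamed, validate_features=True)
except lgb.basic.LightGBMError as err:
    print(err)  # complains that the model expected a feature named 'x3'
```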
+# LightGBM's 'predict_disable_shape_check' mechanism is intentionally not respected by
+# its scikit-learn estimators, for consistency with scikit-learn's own behavior
+@pytest.mark.parametrize("task", all_tasks)
+@pytest.mark.parametrize("predict_disable_shape_check", [True, False])
+def test_predict_rejects_inputs_with_incorrect_number_of_features(predict_disable_shape_check, task):
+    X, y, g = _create_data(task, n_features=4)
+    model_factory = task_to_model_factory[task]
+    fit_kwargs = {"X": X[:, :-1], "y": y}
+    if task == "ranking":
+        estimator_name = "LGBMRanker"
+        fit_kwargs.update({"group": g})
+    elif task == "regression":
+        estimator_name = "LGBMRegressor"
+    else:
+        estimator_name = "LGBMClassifier"
+
+    # train on the first 3 features
+    model = model_factory(n_estimators=5, num_leaves=7, verbose=-1).fit(**fit_kwargs)
+
+    # more cols in X than features: error
+    err_msg = f"X has 4 features, but {estimator_name} is expecting 3 features as input"
+    with pytest.raises(ValueError, match=err_msg):
+        model.predict(X, predict_disable_shape_check=predict_disable_shape_check)
+
+    if estimator_name == "LGBMClassifier":
+        with pytest.raises(ValueError, match=err_msg):
+            model.predict_proba(X, predict_disable_shape_check=predict_disable_shape_check)
+
+    # fewer cols in X than features: error
+    err_msg = f"X has 2 features, but {estimator_name} is expecting 3 features as input"
+    with pytest.raises(ValueError, match=err_msg):
+        model.predict(X[:, :-2], predict_disable_shape_check=predict_disable_shape_check)
+
+    if estimator_name == "LGBMClassifier":
+        with pytest.raises(ValueError, match=err_msg):
+            model.predict_proba(X[:, :-2], predict_disable_shape_check=predict_disable_shape_check)
+
+    # same number of columns in both: no error
+    preds = model.predict(X[:, :-1], predict_disable_shape_check=predict_disable_shape_check)
+    assert preds.shape == y.shape
+
+    if estimator_name == "LGBMClassifier":
+        preds = model.predict_proba(X[:, :-1], predict_disable_shape_check=predict_disable_shape_check)
+        assert preds.shape[0] == y.shape[0]
+
+
 @pytest.mark.parametrize("X_type", ["dt_DataTable", "list2d", "numpy", "scipy_csc", "scipy_csr", "pd_DataFrame"])
 @pytest.mark.parametrize("y_type", ["list1d", "numpy", "pd_Series", "pd_DataFrame"])
 @pytest.mark.parametrize("task", ["binary-classification", "multiclass-classification", "regression"])