diff --git a/docs/FAQ.rst b/docs/FAQ.rst index 14c7f7dd7265..aaaf4094ce3c 100644 --- a/docs/FAQ.rst +++ b/docs/FAQ.rst @@ -377,3 +377,42 @@ We strongly recommend installation from the ``conda-forge`` channel and not from For some specific examples, see `this comment <https://github.com/microsoft/LightGBM/issues/4948#issuecomment-1013766397>`__. In addition, as of ``lightgbm==4.4.0``, the ``conda-forge`` package automatically supports CUDA-based GPU acceleration. + +5. How do I subclass ``scikit-learn`` estimators? +------------------------------------------------- + +For ``lightgbm <= 4.5.0``, copy all of the constructor arguments from the corresponding +``lightgbm`` class into the constructor of your custom estimator. + +For later versions, just ensure that the constructor of your custom estimator calls ``super().__init__()``. + +Consider the example below, which implements a regressor that allows creation of truncated predictions. +This pattern will work with ``lightgbm > 4.5.0``. + +.. code-block:: python + + import numpy as np + from lightgbm import LGBMRegressor + from sklearn.datasets import make_regression + + class TruncatedRegressor(LGBMRegressor): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def predict(self, X, max_score: float = np.inf): + preds = super().predict(X) + np.clip(preds, a_min=None, a_max=max_score, out=preds) + return preds + + X, y = make_regression(n_samples=1_000, n_features=4) + + reg_trunc = TruncatedRegressor().fit(X, y) + + preds = reg_trunc.predict(X) + print(f"mean: {preds.mean():.2f}, max: {preds.max():.2f}") + # mean: -6.81, max: 345.10 + + preds_trunc = reg_trunc.predict(X, max_score=preds.mean()) + print(f"mean: {preds_trunc.mean():.2f}, max: {preds_trunc.max():.2f}") + # mean: -56.50, max: -6.81 diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 8b939a8cb49d..12e778f37075 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -1115,6 +1115,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): def __init__( self, + *, boosting_type: str = "gbdt", num_leaves: int = 31, max_depth: int = -1, @@ -1318,6 +1319,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): def __init__( self, + *, boosting_type: str = "gbdt", num_leaves: int = 31, max_depth: int = -1, @@ -1485,6 +1487,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): def __init__( self, + *, boosting_type: str = "gbdt", num_leaves: int = 31, max_depth: int = -1, diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 108ef1e14498..ab0686e216fa 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -488,6 +488,7 @@ class LGBMModel(_LGBMModelBase): def __init__( self, + *, boosting_type: str = "gbdt", num_leaves: int = 31, max_depth: int = -1, @@ -745,7 +746,35 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]: params : dict Parameter names mapped to their values. """ + # Based on: https://github.com/dmlc/xgboost/blob/bd92b1c9c0db3e75ec3dfa513e1435d518bb535d/python-package/xgboost/sklearn.py#L941 + # which was based on: https://stackoverflow.com/questions/59248211 + # + # `get_params()` flows like this: + # + # 0. Get parameters in subclass (self.__class__) first, by using inspect. + # 1. Get parameters in all parent classes (especially `LGBMModel`). + # 2. Get whatever was passed via `**kwargs`. + # 3. Merge them. 
+ # + # This needs to accommodate being called recursively in the following + # inheritance graphs (and similar for classification and ranking): + # + # DaskLGBMRegressor -> LGBMRegressor -> LGBMModel -> BaseEstimator + # (custom subclass) -> LGBMRegressor -> LGBMModel -> BaseEstimator + # LGBMRegressor -> LGBMModel -> BaseEstimator + # (custom subclass) -> LGBMModel -> BaseEstimator + # LGBMModel -> BaseEstimator + # params = super().get_params(deep=deep) + cp = copy.copy(self) + # If the immediate parent defines get_params(), use that. + if callable(getattr(cp.__class__.__bases__[0], "get_params", None)): + cp.__class__ = cp.__class__.__bases__[0] + # Otherwise, skip it and assume the next class will have it. + # This is here primarily for cases where the first class in MRO is a scikit-learn mixin. + else: + cp.__class__ = cp.__class__.__bases__[1] + params.update(cp.__class__.get_params(cp, deep)) params.update(self._other_params) return params @@ -1285,6 +1314,57 @@ def feature_names_in_(self) -> None: class LGBMRegressor(_LGBMRegressorBase, LGBMModel): """LightGBM regressor.""" + # NOTE: all args from LGBMModel.__init__() are intentionally repeated here for + # docs, help(), and tab completion. + def __init__( + self, + *, + boosting_type: str = "gbdt", + num_leaves: int = 31, + max_depth: int = -1, + learning_rate: float = 0.1, + n_estimators: int = 100, + subsample_for_bin: int = 200000, + objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None, + class_weight: Optional[Union[Dict, str]] = None, + min_split_gain: float = 0.0, + min_child_weight: float = 1e-3, + min_child_samples: int = 20, + subsample: float = 1.0, + subsample_freq: int = 0, + colsample_bytree: float = 1.0, + reg_alpha: float = 0.0, + reg_lambda: float = 0.0, + random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None, + n_jobs: Optional[int] = None, + importance_type: str = "split", + **kwargs: Any, + ) -> None: + super().__init__( + boosting_type=boosting_type, + num_leaves=num_leaves, + max_depth=max_depth, + learning_rate=learning_rate, + n_estimators=n_estimators, + subsample_for_bin=subsample_for_bin, + objective=objective, + class_weight=class_weight, + min_split_gain=min_split_gain, + min_child_weight=min_child_weight, + min_child_samples=min_child_samples, + subsample=subsample, + subsample_freq=subsample_freq, + colsample_bytree=colsample_bytree, + reg_alpha=reg_alpha, + reg_lambda=reg_lambda, + random_state=random_state, + n_jobs=n_jobs, + importance_type=importance_type, + **kwargs, + ) + + __init__.__doc__ = LGBMModel.__init__.__doc__ + def _more_tags(self) -> Dict[str, Any]: # handle the case where RegressorMixin possibly provides _more_tags() if callable(getattr(_LGBMRegressorBase, "_more_tags", None)): @@ -1344,6 +1424,57 @@ def fit( # type: ignore[override] class LGBMClassifier(_LGBMClassifierBase, LGBMModel): """LightGBM classifier.""" + # NOTE: all args from LGBMModel.__init__() are intentionally repeated here for + # docs, help(), and tab completion. 
+ def __init__( + self, + *, + boosting_type: str = "gbdt", + num_leaves: int = 31, + max_depth: int = -1, + learning_rate: float = 0.1, + n_estimators: int = 100, + subsample_for_bin: int = 200000, + objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None, + class_weight: Optional[Union[Dict, str]] = None, + min_split_gain: float = 0.0, + min_child_weight: float = 1e-3, + min_child_samples: int = 20, + subsample: float = 1.0, + subsample_freq: int = 0, + colsample_bytree: float = 1.0, + reg_alpha: float = 0.0, + reg_lambda: float = 0.0, + random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None, + n_jobs: Optional[int] = None, + importance_type: str = "split", + **kwargs: Any, + ) -> None: + super().__init__( + boosting_type=boosting_type, + num_leaves=num_leaves, + max_depth=max_depth, + learning_rate=learning_rate, + n_estimators=n_estimators, + subsample_for_bin=subsample_for_bin, + objective=objective, + class_weight=class_weight, + min_split_gain=min_split_gain, + min_child_weight=min_child_weight, + min_child_samples=min_child_samples, + subsample=subsample, + subsample_freq=subsample_freq, + colsample_bytree=colsample_bytree, + reg_alpha=reg_alpha, + reg_lambda=reg_lambda, + random_state=random_state, + n_jobs=n_jobs, + importance_type=importance_type, + **kwargs, + ) + + __init__.__doc__ = LGBMModel.__init__.__doc__ + def _more_tags(self) -> Dict[str, Any]: # handle the case where ClassifierMixin possibly provides _more_tags() if callable(getattr(_LGBMClassifierBase, "_more_tags", None)): @@ -1554,6 +1685,57 @@ class LGBMRanker(LGBMModel): Please use this class mainly for training and applying ranking models in common sklearnish way. """ + # NOTE: all args from LGBMModel.__init__() are intentionally repeated here for + # docs, help(), and tab completion. 
+    def __init__(
+        self,
+        *,
+        boosting_type: str = "gbdt",
+        num_leaves: int = 31,
+        max_depth: int = -1,
+        learning_rate: float = 0.1,
+        n_estimators: int = 100,
+        subsample_for_bin: int = 200000,
+        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
+        class_weight: Optional[Union[Dict, str]] = None,
+        min_split_gain: float = 0.0,
+        min_child_weight: float = 1e-3,
+        min_child_samples: int = 20,
+        subsample: float = 1.0,
+        subsample_freq: int = 0,
+        colsample_bytree: float = 1.0,
+        reg_alpha: float = 0.0,
+        reg_lambda: float = 0.0,
+        random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
+        n_jobs: Optional[int] = None,
+        importance_type: str = "split",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            boosting_type=boosting_type,
+            num_leaves=num_leaves,
+            max_depth=max_depth,
+            learning_rate=learning_rate,
+            n_estimators=n_estimators,
+            subsample_for_bin=subsample_for_bin,
+            objective=objective,
+            class_weight=class_weight,
+            min_split_gain=min_split_gain,
+            min_child_weight=min_child_weight,
+            min_child_samples=min_child_samples,
+            subsample=subsample,
+            subsample_freq=subsample_freq,
+            colsample_bytree=colsample_bytree,
+            reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda,
+            random_state=random_state,
+            n_jobs=n_jobs,
+            importance_type=importance_type,
+            **kwargs,
+        )
+
+    __init__.__doc__ = LGBMModel.__init__.__doc__
+
     def fit(  # type: ignore[override]
         self,
         X: _LGBM_ScikitMatrixLike,
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index b5e17991f63d..ad13734187d8 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -1373,26 +1373,42 @@ def test_machines_should_be_used_if_provided(task, cluster):


 @pytest.mark.parametrize(
-    "classes",
+    "dask_est,sklearn_est",
     [
         (lgb.DaskLGBMClassifier, lgb.LGBMClassifier),
         (lgb.DaskLGBMRegressor, lgb.LGBMRegressor),
         (lgb.DaskLGBMRanker, lgb.LGBMRanker),
     ],
 )
-def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(classes):
-    dask_spec = inspect.getfullargspec(classes[0])
-    sklearn_spec = inspect.getfullargspec(classes[1])
+def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(dask_est, sklearn_est):
+    dask_spec = inspect.getfullargspec(dask_est)
+    sklearn_spec = inspect.getfullargspec(sklearn_est)
+
+    # should not allow for any varargs
     assert dask_spec.varargs == sklearn_spec.varargs
+    assert dask_spec.varargs is None
+
+    # the only varkw should be **kwargs,
+    # for pass-through to parent classes' __init__()
     assert dask_spec.varkw == sklearn_spec.varkw
-    assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs
-    assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults
+    assert dask_spec.varkw == "kwargs"

     # "client" should be the only different, and the final argument
-    assert dask_spec.args[:-1] == sklearn_spec.args
-    assert dask_spec.defaults[:-1] == sklearn_spec.defaults
-    assert dask_spec.args[-1] == "client"
-    assert dask_spec.defaults[-1] is None
+    assert dask_spec.kwonlyargs == [*sklearn_spec.kwonlyargs, "client"]
+
+    # default values for all constructor arguments should be identical
+    #
+    # NOTE: if LGBMClassifier / LGBMRanker / LGBMRegressor ever override
+    #       any of LGBMModel's constructor arguments, this will need to be updated
+    assert dask_spec.kwonlydefaults == {**sklearn_spec.kwonlydefaults, "client": None}
+
+    # only positional argument should be 'self'
+    assert dask_spec.args == sklearn_spec.args
+    assert dask_spec.args == ["self"]
+    assert dask_spec.defaults is None
+
+    # get_params() should be identical, except for "client"
+    assert dask_est().get_params() == {**sklearn_est().get_params(), "client": None}


 @pytest.mark.parametrize(
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index da6c94d41183..e26e14c24ec6 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -1,4 +1,5 @@
 # coding: utf-8
+import inspect
 import itertools
 import math
 import re
@@ -22,6 +23,7 @@
 import lightgbm as lgb
 from lightgbm.compat import (
+    DASK_INSTALLED,
     DATATABLE_INSTALLED,
     PANDAS_INSTALLED,
     _sklearn_version,
@@ -83,6 +85,30 @@ def __call__(self, env):
         env.model.attr_set_inside_callback = env.iteration * 10


+class ExtendedLGBMClassifier(lgb.LGBMClassifier):
+    """Class for testing that inheriting from LGBMClassifier works"""
+
+    def __init__(self, *, some_other_param: str = "lgbm-classifier", **kwargs):
+        self.some_other_param = some_other_param
+        super().__init__(**kwargs)
+
+
+class ExtendedLGBMRanker(lgb.LGBMRanker):
+    """Class for testing that inheriting from LGBMRanker works"""
+
+    def __init__(self, *, some_other_param: str = "lgbm-ranker", **kwargs):
+        self.some_other_param = some_other_param
+        super().__init__(**kwargs)
+
+
+class ExtendedLGBMRegressor(lgb.LGBMRegressor):
+    """Class for testing that inheriting from LGBMRegressor works"""
+
+    def __init__(self, *, some_other_param: str = "lgbm-regressor", **kwargs):
+        self.some_other_param = some_other_param
+        super().__init__(**kwargs)
+
+
 def custom_asymmetric_obj(y_true, y_pred):
     residual = (y_true - y_pred).astype(np.float64)
     grad = np.where(residual < 0, -2 * 10.0 * residual, -2 * residual)
@@ -475,6 +501,193 @@ def test_clone_and_property():
     assert isinstance(clf.feature_importances_, np.ndarray)


+@pytest.mark.parametrize("estimator", (lgb.LGBMClassifier, lgb.LGBMRegressor, lgb.LGBMRanker))
+def test_estimators_all_have_the_same_kwargs_and_defaults(estimator):
+    base_spec = inspect.getfullargspec(lgb.LGBMModel)
+    subclass_spec = inspect.getfullargspec(estimator)
+
+    # should not allow for any varargs
+    assert subclass_spec.varargs == base_spec.varargs
+    assert subclass_spec.varargs is None
+
+    # the only varkw should be **kwargs,
+    # for pass-through to parent classes' __init__()
+    assert subclass_spec.varkw == base_spec.varkw
+    assert subclass_spec.varkw == "kwargs"
+
+    # default values for all constructor arguments should be identical
+    #
+    # NOTE: if LGBMClassifier / LGBMRanker / LGBMRegressor ever override
+    #       any of LGBMModel's constructor arguments, this will need to be updated
+    assert subclass_spec.kwonlydefaults == base_spec.kwonlydefaults
+
+    # only positional argument should be 'self'
+    assert subclass_spec.args == base_spec.args
+    assert subclass_spec.args == ["self"]
+    assert subclass_spec.defaults is None
+
+    # get_params() should be identical
+    assert estimator().get_params() == lgb.LGBMModel().get_params()
+
+
+def test_subclassing_get_params_works():
+    expected_params = {
+        "boosting_type": "gbdt",
+        "class_weight": None,
+        "colsample_bytree": 1.0,
+        "importance_type": "split",
+        "learning_rate": 0.1,
+        "max_depth": -1,
+        "min_child_samples": 20,
+        "min_child_weight": 0.001,
+        "min_split_gain": 0.0,
+        "n_estimators": 100,
+        "n_jobs": None,
+        "num_leaves": 31,
+        "objective": None,
+        "random_state": None,
+        "reg_alpha": 0.0,
+        "reg_lambda": 0.0,
+        "subsample": 1.0,
+        "subsample_for_bin": 200000,
+        "subsample_freq": 0,
+    }
+
+    # Overrides, used to test that passing through **kwargs works as expected.
+    #
+    # why these?
+    #
+    #   - 'n_estimators' directly matches a keyword arg for the scikit-learn estimators
+    #   - 'eta' is a parameter alias for 'learning_rate'
+    overrides = {"n_estimators": 13, "eta": 0.07}
+
+    # lightgbm-official classes
+    for est in [lgb.LGBMModel, lgb.LGBMClassifier, lgb.LGBMRanker, lgb.LGBMRegressor]:
+        assert est().get_params() == expected_params
+        assert est(**overrides).get_params() == {
+            **expected_params,
+            "eta": 0.07,
+            "n_estimators": 13,
+            "learning_rate": 0.1,
+        }
+
+    if DASK_INSTALLED:
+        for est in [lgb.DaskLGBMClassifier, lgb.DaskLGBMRanker, lgb.DaskLGBMRegressor]:
+            assert est().get_params() == {
+                **expected_params,
+                "client": None,
+            }
+            assert est(**overrides).get_params() == {
+                **expected_params,
+                "eta": 0.07,
+                "n_estimators": 13,
+                "learning_rate": 0.1,
+                "client": None,
+            }
+
+    # custom sub-classes
+    assert ExtendedLGBMClassifier().get_params() == {**expected_params, "some_other_param": "lgbm-classifier"}
+    assert ExtendedLGBMClassifier(**overrides).get_params() == {
+        **expected_params,
+        "eta": 0.07,
+        "n_estimators": 13,
+        "learning_rate": 0.1,
+        "some_other_param": "lgbm-classifier",
+    }
+    assert ExtendedLGBMRanker().get_params() == {
+        **expected_params,
+        "some_other_param": "lgbm-ranker",
+    }
+    assert ExtendedLGBMRanker(**overrides).get_params() == {
+        **expected_params,
+        "eta": 0.07,
+        "n_estimators": 13,
+        "learning_rate": 0.1,
+        "some_other_param": "lgbm-ranker",
+    }
+    assert ExtendedLGBMRegressor().get_params() == {
+        **expected_params,
+        "some_other_param": "lgbm-regressor",
+    }
+    assert ExtendedLGBMRegressor(**overrides).get_params() == {
+        **expected_params,
+        "eta": 0.07,
+        "n_estimators": 13,
+        "learning_rate": 0.1,
+        "some_other_param": "lgbm-regressor",
+    }
+
+
+@pytest.mark.parametrize("task", all_tasks)
+def test_subclassing_works(task):
+    # param values to make training deterministic and
+    # just train a small, cheap model
+    params = {
+        "deterministic": True,
+        "force_row_wise": True,
+        "n_jobs": 1,
+        "n_estimators": 5,
+        "num_leaves": 11,
+        "random_state": 708,
+    }
+
+    X, y, g = _create_data(task=task)
+    if task == "ranking":
+        est = lgb.LGBMRanker(**params).fit(X, y, group=g)
+        est_sub = ExtendedLGBMRanker(**params).fit(X, y, group=g)
+    elif task.endswith("classification"):
+        est = lgb.LGBMClassifier(**params).fit(X, y)
+        est_sub = ExtendedLGBMClassifier(**params).fit(X, y)
+    else:
+        est = lgb.LGBMRegressor(**params).fit(X, y)
+        est_sub = ExtendedLGBMRegressor(**params).fit(X, y)
+
+    np.testing.assert_allclose(est.predict(X), est_sub.predict(X))
+
+
+@pytest.mark.parametrize(
+    "estimator_to_task",
+    [
+        (lgb.LGBMClassifier, "binary-classification"),
+        (ExtendedLGBMClassifier, "binary-classification"),
+        (lgb.LGBMRanker, "ranking"),
+        (ExtendedLGBMRanker, "ranking"),
+        (lgb.LGBMRegressor, "regression"),
+        (ExtendedLGBMRegressor, "regression"),
+    ],
+)
+def test_parameter_aliases_are_handled_correctly(estimator_to_task):
+    estimator, task = estimator_to_task
+    # scikit-learn estimators should remember every parameter passed
+    # via keyword arguments in the estimator constructor, but then
+    # only pass the correct value down to LightGBM's C++ side
+    params = {
+        "eta": 0.08,
+        "num_iterations": 3,
+        "num_leaves": 5,
+    }
+    X, y, g = _create_data(task=task)
+    mod = estimator(**params)
+    if task == "ranking":
+        mod.fit(X, y, group=g)
+    else:
+        mod.fit(X, y)
+
+    # scikit-learn get_params()
+    p = mod.get_params()
+    assert p["eta"] == 0.08
+    assert p["learning_rate"] == 0.1
+
+    # lgb.Booster's 'params' attribute
+    p = mod.booster_.params
+    assert p["eta"] == 0.08
+    assert p["learning_rate"] == 0.1
+
+    # Config in the 'LightGBM::Booster' on the C++ side
+    p = mod.booster_._get_loaded_param()
+    assert p["learning_rate"] == 0.1
+    assert "eta" not in p
+
+
 def test_joblib(tmp_path):
     X, y = make_synthetic_regression()
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
@@ -1463,7 +1676,10 @@ def _get_expected_failed_tests(estimator):
     return estimator._more_tags()["_xfail_checks"]


-@parametrize_with_checks([lgb.LGBMClassifier(), lgb.LGBMRegressor()], expected_failed_checks=_get_expected_failed_tests)
+@parametrize_with_checks(
+    [ExtendedLGBMClassifier(), ExtendedLGBMRegressor(), lgb.LGBMClassifier(), lgb.LGBMRegressor()],
+    expected_failed_checks=_get_expected_failed_tests,
+)
 def test_sklearn_integration(estimator, check):
     estimator.set_params(min_child_samples=1, min_data_in_bin=1)
     check(estimator)
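
Not part of the changeset itself: the sketch below illustrates the end-user behaviour the diff is meant to enable, namely that a custom subclass only needs to declare its own parameters, forward everything else to ``super().__init__()`` via ``**kwargs``, and still round-trip through ``get_params()`` and ``sklearn.base.clone()``. The class name ``CappedLGBMRegressor`` and its ``cap`` parameter are invented for illustration, and the sketch assumes a ``lightgbm`` build that includes these changes plus an installed ``scikit-learn``.

.. code-block:: python

    import numpy as np
    from lightgbm import LGBMRegressor
    from sklearn.base import clone
    from sklearn.datasets import make_regression


    class CappedLGBMRegressor(LGBMRegressor):
        """LGBMRegressor whose predictions are clipped at an upper bound."""

        def __init__(self, *, cap: float = np.inf, **kwargs):
            # store the subclass-specific parameter, pass everything else through
            self.cap = cap
            super().__init__(**kwargs)

        def predict(self, X, **kwargs):
            return np.minimum(super().predict(X, **kwargs), self.cap)


    X, y = make_regression(n_samples=200, n_features=4, random_state=42)

    reg = CappedLGBMRegressor(cap=100.0, n_estimators=5).fit(X, y)

    # get_params() reports the inherited and the subclass-specific parameters ...
    params = reg.get_params()
    assert params["cap"] == 100.0
    assert params["n_estimators"] == 5

    # ... so scikit-learn utilities that rebuild estimators from get_params() work too
    reg_clone = clone(reg)
    assert reg_clone.get_params() == params

    assert reg.predict(X).max() <= 100.0

Without the changes in this diff, ``get_params()`` on such a subclass would only report the parameters named in the subclass's own ``__init__()`` (plus anything routed through ``**kwargs``), which is why the FAQ entry above recommends repeating every constructor argument for ``lightgbm <= 4.5.0``.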
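
The ``copy.copy(self)`` / ``cp.__class__ = cp.__class__.__bases__[0]`` dance in the new ``LGBMModel.get_params()`` can be hard to follow from the diff alone. The self-contained sketch below is independent of ``lightgbm``; the ``Base`` / ``Child`` classes and their parameters are made up for illustration, and the termination condition is simplified relative to the diff, which also handles the case where the first base class is a scikit-learn mixin. The underlying idea is the same: ``BaseEstimator.get_params()`` only introspects ``type(self).__init__``, so re-pointing ``__class__`` at the parent class on a shallow copy before delegating lets every level of the hierarchy contribute its own constructor arguments, with ``**kwargs`` merged in at the end.

.. code-block:: python

    import copy

    from sklearn.base import BaseEstimator


    class Base(BaseEstimator):
        def __init__(self, *, alpha=1.0, **kwargs):
            self.alpha = alpha
            self._extra = dict(kwargs)

        def get_params(self, deep=True):
            # BaseEstimator.get_params() only inspects type(self).__init__, so for a
            # subclass instance it would report the subclass's parameters and nothing else.
            params = super().get_params(deep=deep)
            # Walk one level up the hierarchy on a shallow copy and ask that level
            # for its parameters too, the same trick used in LGBMModel.get_params().
            cp = copy.copy(self)
            cp.__class__ = cp.__class__.__bases__[0]
            if cp.__class__ is not BaseEstimator:
                params.update(cp.__class__.get_params(cp, deep))
            params.update(self._extra)
            return params


    class Child(Base):
        def __init__(self, *, beta=2.0, **kwargs):
            self.beta = beta
            super().__init__(**kwargs)


    print(Child(gamma=3.0).get_params())
    # {'beta': 2.0, 'alpha': 1.0, 'gamma': 3.0}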
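
One more user-visible effect worth calling out: the bare ``*`` added to each ``__init__()`` above makes every constructor argument keyword-only. A small illustration of what that means for callers, again assuming a ``lightgbm`` build that includes this change:

.. code-block:: python

    import lightgbm as lgb

    # keyword arguments work exactly as before
    reg = lgb.LGBMRegressor(learning_rate=0.05, n_estimators=10)

    # positional arguments are rejected, so code can no longer silently depend
    # on the order in which parameters are declared in the signature
    try:
        lgb.LGBMRegressor(0.05)
    except TypeError as err:
        print(f"TypeError: {err}")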