
[python-package] support sub-classing scikit-learn estimators #6783

Merged: 24 commits, Feb 12, 2025

Changes from 9 commits

Commits (24)
3b5f648
[python-package] make sub-classing scikit-learn estimators easier
jameslamb Jan 4, 2025
02c48c3
tests passing
jameslamb Jan 4, 2025
7b720cb
add docs
jameslamb Jan 7, 2025
51b5e64
Update tests/python_package_test/test_sklearn.py
jameslamb Jan 10, 2025
81178fd
remove docs links
jameslamb Jan 10, 2025
110b0e1
Merge branch 'python/sklearn-subclassing' of github.com:microsoft/Lig…
jameslamb Jan 10, 2025
104471a
Merge branch 'master' into python/sklearn-subclassing
jameslamb Jan 11, 2025
d80b0df
fix Dask tests
jameslamb Jan 13, 2025
b7e041a
Merge branch 'python/sklearn-subclassing' of github.com:microsoft/Lig…
jameslamb Jan 13, 2025
68177a7
Merge branch 'master' into python/sklearn-subclassing
StrikerRUS Jan 20, 2025
70f29a7
Merge branch 'master' into python/sklearn-subclassing
jameslamb Jan 26, 2025
6796ba9
Merge branch 'master' into python/sklearn-subclassing
jameslamb Jan 26, 2025
409733a
Update tests/python_package_test/test_dask.py
jameslamb Jan 30, 2025
0a40e9b
Update python-package/lightgbm/sklearn.py
jameslamb Jan 30, 2025
64850c6
Merge branch 'master' into python/sklearn-subclassing
jameslamb Jan 30, 2025
cd54639
Update docs/FAQ.rst
jameslamb Jan 30, 2025
e39d19f
Merge branch 'master' into python/sklearn-subclassing
jameslamb Jan 30, 2025
7077c24
Merge branch 'master' into python/sklearn-subclassing
jameslamb Feb 5, 2025
7c59cd9
Merge branch 'python/sklearn-subclassing' of github.com:microsoft/Lig…
jameslamb Feb 7, 2025
98eb476
restore Dask signatures
jameslamb Feb 7, 2025
734961c
Merge branch 'master' into python/sklearn-subclassing
jameslamb Feb 10, 2025
3d351a4
repeat all params
jameslamb Feb 11, 2025
9be0ec1
Merge branch 'python/sklearn-subclassing' of github.com:microsoft/Lig…
jameslamb Feb 11, 2025
51c18ad
update Dask tests
jameslamb Feb 12, 2025
39 changes: 39 additions & 0 deletions docs/FAQ.rst
@@ -377,3 +377,42 @@ We strongly recommend installation from the ``conda-forge`` channel and not from
For some specific examples, see `this comment <https://github.com/microsoft/LightGBM/issues/4948#issuecomment-1013766397>`__.

In addition, as of ``lightgbm==4.4.0``, the ``conda-forge`` package automatically supports CUDA-based GPU acceleration.

5. How do I subclass ``scikit-learn`` estimators?
-------------------------------------------------

For ``lightgbm <= 4.5.0``, copy all of the constructor arguments from the corresponding
``lightgbm`` class into the constructor of your custom estimator.

For later versions, just ensure that the constructor of your custom estimator calls ``super().__init__()``.

Consider the example below, which implements a regressor that allows creation of truncated predictions.
This pattern will work with ``lightgbm > 4.5.0``.

.. code-block:: python

    import numpy as np
    from lightgbm import LGBMRegressor
    from sklearn.datasets import make_regression

    class TruncatedRegressor(LGBMRegressor):

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def predict(self, X, max_score: float = np.inf):
            preds = super().predict(X)
            preds[np.where(preds > max_score)] = max_score
            return preds

    X, y = make_regression(n_samples=1_000, n_features=4)

    reg_trunc = TruncatedRegressor().fit(X, y)

    preds = reg_trunc.predict(X)
    print(f"mean: {preds.mean():.2f}, max: {preds.max():.2f}")
    # mean: -6.81, max: 345.10

    preds_trunc = reg_trunc.predict(X, max_score=preds.mean())
    print(f"mean: {preds_trunc.mean():.2f}, max: {preds_trunc.max():.2f}")
    # mean: -56.50, max: -6.81
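
For context, a rough sketch of the older pattern that the FAQ entry above describes for ``lightgbm <= 4.5.0``: every constructor argument of the parent estimator is repeated by hand. This sketch is illustrative only and not part of the diff; most arguments are elided.

```python
import numpy as np
from lightgbm import LGBMRegressor


class TruncatedRegressorLegacy(LGBMRegressor):
    """Sketch of the lightgbm <= 4.5.0 pattern: repeat the parent's constructor arguments."""

    def __init__(
        self,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        # ... every remaining LGBMRegressor constructor argument would be repeated here ...
        **kwargs,
    ):
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            **kwargs,
        )

    def predict(self, X, max_score: float = np.inf):
        preds = super().predict(X)
        preds[np.where(preds > max_score)] = max_score
        return preds
```
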
130 changes: 6 additions & 124 deletions python-package/lightgbm/dask.py
@@ -44,7 +44,6 @@
LGBMModel,
LGBMRanker,
LGBMRegressor,
_LGBM_ScikitCustomObjectiveFunction,
_LGBM_ScikitEvalMetricType,
_lgbmmodel_doc_custom_eval_note,
_lgbmmodel_doc_fit,
@@ -1115,52 +1114,13 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):

def __init__(
self,
boosting_type: str = "gbdt",
num_leaves: int = 31,
max_depth: int = -1,
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.0,
min_child_weight: float = 1e-3,
min_child_samples: int = 20,
subsample: float = 1.0,
subsample_freq: int = 0,
colsample_bytree: float = 1.0,
reg_alpha: float = 0.0,
reg_lambda: float = 0.0,
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
*,
client: Optional[Client] = None,
**kwargs: Any,
):
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
self.client = client
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
min_child_samples=min_child_samples,
subsample=subsample,
subsample_freq=subsample_freq,
colsample_bytree=colsample_bytree,
reg_alpha=reg_alpha,
reg_lambda=reg_lambda,
random_state=random_state,
n_jobs=n_jobs,
importance_type=importance_type,
**kwargs,
)
super().__init__(**kwargs)

_base_doc = LGBMClassifier.__init__.__doc__
@StrikerRUS (Collaborator) commented on Jan 27, 2025:

Do you think it's OK to have just one client argument in the signature, but describe all parent args in the docstring?..


Collaborator (PR author):

I think it's a little better for users to see all the parameters right here, instead of having to click over to another page.

This is what XGBoost is doing too: https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBRFRegressor

But I do also appreciate that it could look confusing.

If we don't do it this way, then I'd recommend we add a link in the docs for `**kwargs` in these estimators, like this:

**kwargs
    Other parameters for the model. These can be any of the keyword arguments
    for LGBMModel or any other LightGBM parameters documented at
    https://lightgbm.readthedocs.io/en/latest/Parameters.html.

I have a weak preference for keeping it as-is (the signature in docs not having all parameters, but docstring having all parameters), but happy to change it if you think that's confusing.

Collaborator:

Thanks for clarifying your opinion!
I love your suggestion for the **kwargs description. But my preference is also weak 🙂
I think we need a third opinion on this question.

Either way, I'm approving this PR!

Collaborator (PR author):

@jmoralez or @borchero could one of you comment on this thread and help us break the tie?

To make progress on the release, if we don't hear back in the next 2 days I'll merge this PR as-is and we can come back and change the docs later.

Collaborator:

Sorry, I only saw this now! My personal preference would actually be to keep all of the parameters (similar to the previous state) and simply make them keyword arguments. While this results in more code and some duplication of defaults, I think that this is the clearest interface for users. If you think this is undesirable @jameslamb, I'd at least opt for documenting all of the "transitive" parameters, just like in the XGBoost docs.

Collaborator:

Hmmm, I still think that

**kwargs
    Other parameters for the model. These can be any of the keyword arguments
    for LGBMModel or any other LightGBM parameters documented at
    https://lightgbm.readthedocs.io/en/latest/Parameters.html.

would be better... But OK.

What I'm definitely sure of is that the sklearn classes and the Dask ones should follow the same pattern.


Collaborator (PR author):

Sorry, I was so focused on the Dask estimators in the most recent round of changes that I forgot about the effect this would have on LGBM{Classifier,Ranker,Regressor}. I agree, I need to fix this inconsistency.

I do think that it'd be better to have all the arguments listed out in the signature explicitly. That's helpful for code completion in editors and help() in a REPL. And I strongly suspect that users use LGBM{Classifier,Ranker,Regressor} directly much more often than they use LGBMModel. It introduces duplication in the code, but I personally am OK with that in exchange for those benefits for users, for the reasons I mentioned in #6783 (comment)

Given that set of possible benefits, @StrikerRUS would you be ok with me duplicating all the defaults into the __init__() signature of LGBM{Classifier,Ranker,Regressor} too (as currently happens for the Dask estimators) and expanding the tests to confirm that the arguments are all consistent between LGBMModel, LGBM{Classifier,Ranker,Regressor}, and DaskLGBM{Classifier,Ranker,Regressor}?

Or would you still prefer having **kwargs and a docstring like this?

**kwargs
    Other parameters for the model. These can be any of the keyword arguments
    for LGBMModel or any other LightGBM parameters documented at
    https://lightgbm.readthedocs.io/en/latest/Parameters.html.

It seems from comments above that @borchero was also OK with either form... I think we are all struggling to choose a preferred form here. I don't have any other thoughts on this, so I'll happily defer to your decision.

Collaborator:

OK, I'll approve the consistent version with explicitly listed args.

Collaborator (PR author):

Thank you! I'm sorry about how much effort reviewing this PR has required.

I do think LightGBM's users will appreciate sub-classing being easier, and still having tab completion for constructor arguments for LGBM{Classifier,Ranker,Regressor}.

I just pushed 3d351a4 repeating all the arguments in the constructors.

Also added a test in test_sklearn.py similar to the Dask one, to ensure that all the default values and the set of parameters stay the same.

Updated docs: the sklearn and Dask estimator pages now look the same.

I also re-ran the sub-classing example being added to FAQ.rst here to be sure it works.
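
A rough sketch of that kind of parameter-consistency check (hypothetical; the actual test added in `test_sklearn.py` may differ):

```python
import inspect

from lightgbm import LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor


def test_estimators_repeat_lgbmmodel_constructor_params():
    base_spec = inspect.getfullargspec(LGBMModel.__init__)
    for cls in (LGBMClassifier, LGBMRanker, LGBMRegressor):
        sub_spec = inspect.getfullargspec(cls.__init__)
        # the subclass should repeat the same keyword-only parameters, in the same order...
        assert sub_spec.kwonlyargs == base_spec.kwonlyargs
        # ...with the same default values
        assert sub_spec.kwonlydefaults == base_spec.kwonlydefaults
```
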

Collaborator:

No need to apologize! Thank you for working on this very important change!

_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs") # type: ignore
@@ -1318,52 +1278,13 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):

def __init__(
self,
boosting_type: str = "gbdt",
num_leaves: int = 31,
max_depth: int = -1,
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.0,
min_child_weight: float = 1e-3,
min_child_samples: int = 20,
subsample: float = 1.0,
subsample_freq: int = 0,
colsample_bytree: float = 1.0,
reg_alpha: float = 0.0,
reg_lambda: float = 0.0,
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
*,
client: Optional[Client] = None,
**kwargs: Any,
):
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
self.client = client
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
min_child_samples=min_child_samples,
subsample=subsample,
subsample_freq=subsample_freq,
colsample_bytree=colsample_bytree,
reg_alpha=reg_alpha,
reg_lambda=reg_lambda,
random_state=random_state,
n_jobs=n_jobs,
importance_type=importance_type,
**kwargs,
)
super().__init__(**kwargs)

_base_doc = LGBMRegressor.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs") # type: ignore
@@ -1485,52 +1406,13 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):

def __init__(
self,
boosting_type: str = "gbdt",
num_leaves: int = 31,
max_depth: int = -1,
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.0,
min_child_weight: float = 1e-3,
min_child_samples: int = 20,
subsample: float = 1.0,
subsample_freq: int = 0,
colsample_bytree: float = 1.0,
reg_alpha: float = 0.0,
reg_lambda: float = 0.0,
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
*,
client: Optional[Client] = None,
**kwargs: Any,
):
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
self.client = client
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
min_child_samples=min_child_samples,
subsample=subsample,
subsample_freq=subsample_freq,
colsample_bytree=colsample_bytree,
reg_alpha=reg_alpha,
reg_lambda=reg_lambda,
random_state=random_state,
n_jobs=n_jobs,
importance_type=importance_type,
**kwargs,
)
super().__init__(**kwargs)

_base_doc = LGBMRanker.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs") # type: ignore
44 changes: 44 additions & 0 deletions python-package/lightgbm/sklearn.py
@@ -488,6 +488,7 @@ class LGBMModel(_LGBMModelBase):

def __init__(
self,
*,
boosting_type: str = "gbdt",
num_leaves: int = 31,
max_depth: int = -1,
@@ -745,7 +746,35 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
params : dict
Parameter names mapped to their values.
"""
# Based on: https://github.com/dmlc/xgboost/blob/bd92b1c9c0db3e75ec3dfa513e1435d518bb535d/python-package/xgboost/sklearn.py#L941
# which was based on: https://stackoverflow.com/questions/59248211
#
# `get_params()` flows like this:
#
# 0. Get parameters in subclass (self.__class__) first, by using inspect.
# 1. Get parameters in all parent classes (especially `LGBMModel`).
# 2. Get whatever was passed via `**kwargs`.
# 3. Merge them.
#
# This needs to accommodate being called recursively in the following
# inheritance graphs (and similar for classification and ranking):
#
# DaskLGBMRegressor -> LGBMRegressor -> LGBMModel -> BaseEstimator
# (custom subclass) -> LGBMRegressor -> LGBMModel -> BaseEstimator
# LGBMRegressor -> LGBMModel -> BaseEstimator
# (custom subclass) -> LGBMModel -> BaseEstimator
# LGBMModel -> BaseEstimator
#
params = super().get_params(deep=deep)
cp = copy.copy(self)
# If the immediate parent defines get_params(), use that.
if callable(getattr(cp.__class__.__bases__[0], "get_params", None)):
cp.__class__ = cp.__class__.__bases__[0]
# Otherwise, skip it and assume the next class will have it.
# This is here primarily for cases where the first class in MRO is a scikit-learn mixin.
else:
cp.__class__ = cp.__class__.__bases__[1]
params.update(cp.__class__.get_params(cp, deep))
params.update(self._other_params)
return params

@@ -1285,6 +1314,11 @@ def feature_names_in_(self) -> None:
class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
"""LightGBM regressor."""

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)

__init__.__doc__ = LGBMModel.__init__.__doc__

def _more_tags(self) -> Dict[str, Any]:
# handle the case where RegressorMixin possibly provides _more_tags()
if callable(getattr(_LGBMRegressorBase, "_more_tags", None)):
@@ -1344,6 +1378,11 @@ def fit( # type: ignore[override]
class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
"""LightGBM classifier."""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

__init__.__doc__ = LGBMModel.__init__.__doc__

def _more_tags(self) -> Dict[str, Any]:
# handle the case where ClassifierMixin possibly provides _more_tags()
if callable(getattr(_LGBMClassifierBase, "_more_tags", None)):
@@ -1554,6 +1593,11 @@ class LGBMRanker(LGBMModel):
Please use this class mainly for training and applying ranking models in common sklearnish way.
"""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

__init__.__doc__ = LGBMModel.__init__.__doc__

def fit( # type: ignore[override]
self,
X: _LGBM_ScikitMatrixLike,
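
A simplified, standalone sketch of the class-swapping idea used in `get_params()` above; this only illustrates the approach and is not lightgbm's actual implementation (which also merges `**kwargs` via `self._other_params` and scikit-learn's `BaseEstimator.get_params()`):

```python
import copy
import inspect


class Base:
    def __init__(self, *, learning_rate=0.1, **kwargs):
        self.learning_rate = learning_rate

    def get_params(self):
        # collect the parameters declared by the *current* class of the instance
        sig = inspect.signature(self.__class__.__init__)
        params = {
            name: getattr(self, name, None)
            for name, p in sig.parameters.items()
            if name != "self" and p.kind is not inspect.Parameter.VAR_KEYWORD
        }
        # re-type a copy of the instance to its parent class and recurse, so that
        # every level of the inheritance graph contributes its own constructor params
        parent = self.__class__.__bases__[0]
        if callable(getattr(parent, "get_params", None)):
            cp = copy.copy(self)
            cp.__class__ = parent
            params.update(parent.get_params(cp))
        return params


class CustomRegressor(Base):
    def __init__(self, *, alpha=1.0, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha


print(CustomRegressor(alpha=2.0).get_params())
# {'alpha': 2.0, 'learning_rate': 0.1}
```
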
13 changes: 7 additions & 6 deletions tests/python_package_test/test_dask.py
@@ -1385,13 +1385,14 @@ def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except
sklearn_spec = inspect.getfullargspec(classes[1])
assert dask_spec.varargs == sklearn_spec.varargs
assert dask_spec.varkw == sklearn_spec.varkw
assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs
assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults

# "client" should be the only different, and the final argument
assert dask_spec.args[:-1] == sklearn_spec.args
assert dask_spec.defaults[:-1] == sklearn_spec.defaults
assert dask_spec.args[-1] == "client"
assert dask_spec.kwonlyargs == [*sklearn_spec.kwonlyargs, "client"]
assert dask_spec.kwonlydefaults == {"client": None}
assert sklearn_spec.kwonlydefaults is None

# only positional argument should be 'self'
assert dask_spec.args == sklearn_spec.args
assert dask_spec.args == ["self"]
assert dask_spec.defaults[-1] is None
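
For reference, a small sketch of what `inspect.getfullargspec` reports for constructors shaped like the ones asserted on above (these classes are illustrative stand-ins, not lightgbm's):

```python
import inspect


class Estimator:
    def __init__(self, **kwargs):
        pass


class DaskEstimator(Estimator):
    def __init__(self, *, client=None, **kwargs):
        pass


sklearn_spec = inspect.getfullargspec(Estimator.__init__)
dask_spec = inspect.getfullargspec(DaskEstimator.__init__)

print(dask_spec.args, sklearn_spec.args)               # ['self'] ['self']
print(dask_spec.kwonlyargs, sklearn_spec.kwonlyargs)   # ['client'] []
print(dask_spec.kwonlydefaults)                        # {'client': None}
print(sklearn_spec.kwonlydefaults)                     # None
print(dask_spec.varkw, sklearn_spec.varkw)             # kwargs kwargs
```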

