Skip to content

Commit

Permalink
rely on mixins to set estimator_type, removed unused imports
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Dec 1, 2024
1 parent 8364e92 commit a511848
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 24 deletions.
13 changes: 0 additions & 13 deletions python-package/xgboost/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,17 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
from sklearn.base import BaseEstimator as XGBModelBase
from sklearn.base import ClassifierMixin as XGBClassifierBase
from sklearn.base import RegressorMixin as XGBRegressorBase
from sklearn.preprocessing import LabelEncoder

try:
from sklearn.model_selection import KFold as XGBKFold
from sklearn.model_selection import StratifiedKFold as XGBStratifiedKFold
except ImportError:
from sklearn.cross_validation import KFold as XGBKFold
from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold

# sklearn.utils Tags types can be imported unconditionally once
# xgboost's minimum scikit-learn version is 1.6 or higher
try:
from sklearn.utils import ClassifierTags as _sklearn_ClassifierTags
from sklearn.utils import RegressorTags as _sklearn_RegressorTags
from sklearn.utils import Tags as _sklearn_Tags
except ImportError:
_sklearn_ClassifierTags = object
_sklearn_RegressorTags = object
_sklearn_Tags = object

SKLEARN_INSTALLED = True
Expand All @@ -82,14 +75,8 @@ class XGBClassifierBase: # type: ignore[no-redef]
class XGBRegressorBase: # type: ignore[no-redef]
"""Dummy class for sklearn.base.RegressorMixin."""

class LabelEncoder: # type: ignore[no-redef]
"""Dummy class for sklearn.preprocessing.LabelEncoder."""

XGBKFold = None
XGBStratifiedKFold = None

_sklearn_ClassifierTags = object
_sklearn_RegressorTags = object
_sklearn_Tags = object
_sklearn_version = object

Expand Down
15 changes: 4 additions & 11 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
XGBClassifierBase,
XGBModelBase,
XGBRegressorBase,
_sklearn_ClassifierTags,
_sklearn_RegressorTags,
_sklearn_Tags,
_sklearn_version,
import_cupy,
Expand Down Expand Up @@ -840,7 +838,7 @@ def __sklearn_tags__(self) -> _sklearn_Tags:
# take whatever tags are provided by BaseEstimator, then modify
# them with XGBoost-specific values
return self._update_sklearn_tags_from_dict(
tags=XGBModelBase.__sklearn_tags__(self), # pylint: disable=no-member
tags=super().__sklearn_tags__(), # pylint: disable=no-member
tags_dict=self._more_tags(),
)

Expand Down Expand Up @@ -1554,12 +1552,9 @@ def _more_tags(self) -> Dict[str, bool]:
return tags

def __sklearn_tags__(self) -> _sklearn_Tags:
tags = XGBModel.__sklearn_tags__(self)
tags.estimator_type = "classifier"
tags = super().__sklearn_tags__()
tags_dict = self._more_tags()
tags.classifier_tags = _sklearn_ClassifierTags(
multi_label=tags_dict["multilabel"]
)
tags.classifier_tags.multi_label = tags_dict["multilabel"]
return tags

@_deprecate_positional_args
Expand Down Expand Up @@ -1849,10 +1844,8 @@ def _more_tags(self) -> Dict[str, bool]:
return tags

def __sklearn_tags__(self) -> _sklearn_Tags:
tags = XGBModel.__sklearn_tags__(self)
tags.estimator_type = "regressor"
tags = super().__sklearn_tags__()
tags_dict = self._more_tags()
tags.regressor_tags = _sklearn_RegressorTags()
tags.target_tags.multi_output = tags_dict["multioutput"]
tags.target_tags.single_output = not tags_dict["multioutput_only"]
return tags
Expand Down
26 changes: 26 additions & 0 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,32 @@ def get_tm(clf: xgb.XGBClassifier) -> str:
assert clf.get_params()["tree_method"] is None


def test_get_params_works_as_expected():
# XGBModel -> BaseEstimator
params = xgb.XGBModel(max_depth=2).get_params()
assert params["max_depth"] == 2
# 'objective' defaults to None in the signature of XGBModel
assert params["objective"] is None

# XGBRegressor -> XGBModel -> BaseEstimator
params = xgb.XGBRegressor(max_depth=3).get_params()
assert params["max_depth"] == 3
# 'objective' defaults to 'reg:squarederror' in the signature of XGBRegressor
assert params["objective"] == "reg:squarederror"
# 'colsample_bynode' defaults to 'None' for XGBModel (which XGBRegressor inherits from), so it
# should be in get_params() output
assert params["colsample_bynode"] is None

# XGBRFRegressor -> XGBRegressor -> XGBModel -> BaseEstimator
params = xgb.XGBRFRegressor(max_depth=4, objective="reg:tweedie").get_params()
assert params["max_depth"] == 4
# 'objective' is a keyword argument for XGBRegressor, so it should be in get_params() output
# ... but values passed through kwargs should override the default from the signature of XGBRegressor
assert params["objective"] == "reg:tweedie"
# 'colsample_bynode' defaults to 0.8 for XGBRFRegressor...that should be preferred to the None from XGBRegressor
assert params["colsample_bynode"] == 0.8


def test_kwargs_error():
params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
with pytest.raises(TypeError):
Expand Down

0 comments on commit a511848

Please sign in to comment.