Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
check for get_params() instead
Browse files Browse the repository at this point in the history
jameslamb committed Dec 3, 2024

Verified

This commit was signed with the committer’s verified signature.
jameslamb James Lamb
1 parent 27a1bcc commit 18c602f
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
@@ -949,16 +949,13 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
#
params = super().get_params(deep)
cp = copy.copy(self)
# if the immediate parent is a mixin, skip it (mixins don't define get_params())
if cp.__class__.__bases__[0] in (
XGBClassifierBase,
XGBRankerMixIn,
XGBRegressorBase,
):
cp.__class__ = cp.__class__.__bases__[1]
# otherwise, run get_params() from the immediate parent class
else:
# If the immediate parent defines get_params(), use that.
if callable(getattr(cp.__class__.__bases__[0], "get_params", None)):
cp.__class__ = cp.__class__.__bases__[0]
# Otherwise, skip it and assume the next class will have it.
# This is here primarily for cases where the first class in MRO is a scikit-learn mixin.
else:
cp.__class__ = cp.__class__.__bases__[1]
params.update(cp.__class__.get_params(cp, deep))
# if kwargs is a dict, update params accordingly
if hasattr(self, "kwargs") and isinstance(self.kwargs, dict):

0 comments on commit 18c602f

Please sign in to comment.