44# Christos Aridas
55# License: MIT
66
7+ import inspect
78import numbers
9+ import warnings
810
911import numpy as np
1012
@@ -41,10 +43,12 @@ class BalancedBaggingClassifier(BaggingClassifier):
4143
4244 Parameters
4345 ----------
44- base_estimator : estimator object, default=None
46+ estimator : estimator object, default=None
4547 The base estimator to fit on random subsets of the dataset.
4648 If None, then the base estimator is a decision tree.
4749
50+ .. versionadded:: 0.10
51+
4852 n_estimators : int, default=10
4953 The number of base estimators in the ensemble.
5054
@@ -100,18 +104,37 @@ class BalancedBaggingClassifier(BaggingClassifier):
100104
101105 .. versionadded:: 0.8
102106
107+ base_estimator : estimator object, default=None
108+ The base estimator to fit on random subsets of the dataset.
109+ If None, then the base estimator is a decision tree.
110+
111+ .. deprecated:: 0.10
112+ `base_estimator` was renamed to `estimator` in version 0.10 and
113+ will be removed in 0.12.
114+
103115 Attributes
104116 ----------
117+ estimator_ : estimator
118+ The base estimator from which the ensemble is grown.
119+
120+ .. versionadded:: 0.10
121+
105122 base_estimator_ : estimator
106123 The base estimator from which the ensemble is grown.
107124
125+ .. deprecated:: 1.2
126+ `base_estimator_` is deprecated in `scikit-learn` 1.2 and will be
127+ removed in 1.4. Use `estimator_` instead. When the minimum version
128+ of `scikit-learn` supported by `imbalanced-learn` will reach 1.4,
129+ this attribute will be removed.
130+
108131 n_features_ : int
109132 The number of features when `fit` is performed.
110133
111134 .. deprecated:: 1.0
112135 `n_features_` is deprecated in `scikit-learn` 1.0 and will be removed
113- in version 1.2. Depending of the version of `scikit-learn` installed,
114- you will get be warned or not .
136+ in version 1.2. When the minimum version of `scikit-learn` supported
137+ by `imbalanced-learn` will reach 1.2, this attribute will be removed .
115138
116139 estimators_ : list of estimators
117140 The collection of fitted base estimators.
@@ -209,7 +232,7 @@ class BalancedBaggingClassifier(BaggingClassifier):
209232 >>> from sklearn.model_selection import train_test_split
210233 >>> from sklearn.metrics import confusion_matrix
211234 >>> from imblearn.ensemble import \
212- BalancedBaggingClassifier # doctest: +NORMALIZE_WHITESPACE
235+ BalancedBaggingClassifier # doctest:
213236 >>> X, y = make_classification(n_classes=2, class_sep=2,
214237 ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
215238 ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
@@ -218,7 +241,7 @@ class BalancedBaggingClassifier(BaggingClassifier):
218241 >>> X_train, X_test, y_train, y_test = train_test_split(X, y,
219242 ... random_state=0)
220243 >>> bbc = BalancedBaggingClassifier(random_state=42)
221- >>> bbc.fit(X_train, y_train) # doctest: +ELLIPSIS
244+ >>> bbc.fit(X_train, y_train) # doctest:
222245 BalancedBaggingClassifier(...)
223246 >>> y_pred = bbc.predict(X_test)
224247 >>> print(confusion_matrix(y_test, y_pred))
@@ -229,7 +252,7 @@ class BalancedBaggingClassifier(BaggingClassifier):
229252 @_deprecate_positional_args
230253 def __init__ (
231254 self ,
232- base_estimator = None ,
255+ estimator = None ,
233256 n_estimators = 10 ,
234257 * ,
235258 max_samples = 1.0 ,
@@ -244,10 +267,18 @@ def __init__(
244267 random_state = None ,
245268 verbose = 0 ,
246269 sampler = None ,
270+ base_estimator = "deprecated" ,
247271 ):
272+ # TODO: remove when supporting scikit-learn>=1.2
273+ bagging_classifier_signature = inspect .signature (super ().__init__ )
274+ estimator_params = {"base_estimator" : base_estimator }
275+ if "estimator" in bagging_classifier_signature .parameters :
276+ estimator_params ["estimator" ] = estimator
277+ else :
278+ self .estimator = estimator
248279
249280 super ().__init__ (
250- base_estimator ,
281+ ** estimator_params ,
251282 n_estimators = n_estimators ,
252283 max_samples = max_samples ,
253284 max_features = max_features ,
@@ -294,20 +325,54 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):
294325 f"n_estimators must be greater than zero, " f"got { self .n_estimators } ."
295326 )
296327
297- if self .base_estimator is not None :
328+ if self .estimator is not None and (
329+ self .base_estimator not in [None , "deprecated" ]
330+ ):
331+ raise ValueError (
332+ "Both `estimator` and `base_estimator` were set. Only set `estimator`."
333+ )
334+
335+ if self .estimator is not None :
336+ base_estimator = clone (self .estimator )
337+ elif self .base_estimator not in [None , "deprecated" ]:
338+ warnings .warn (
339+ "`base_estimator` was renamed to `estimator` in version 0.10 and "
340+ "will be removed in 0.12." ,
341+ FutureWarning ,
342+ )
298343 base_estimator = clone (self .base_estimator )
299344 else :
300345 base_estimator = clone (default )
301346
302347 if self .sampler_ ._sampling_type != "bypass" :
303348 self .sampler_ .set_params (sampling_strategy = self ._sampling_strategy )
304349
305- self .base_estimator_ = Pipeline (
306- [
307- ("sampler" , self .sampler_ ),
308- ("classifier" , base_estimator ),
309- ]
350+ self ._estimator = Pipeline (
351+ [("sampler" , self .sampler_ ), ("classifier" , base_estimator )]
352+ )
353+ try :
354+ # scikit-learn < 1.2
355+ self .base_estimator_ = self ._estimator
356+ except AttributeError :
357+ pass
358+
359+ # TODO: remove when supporting scikit-learn>=1.4
360+ @property
361+ def estimator_ (self ):
362+ """Estimator used to grow the ensemble."""
363+ return self ._estimator
364+
365+ # TODO: remove when supporting scikit-learn>=1.2
366+ @property
367+ def n_features_ (self ):
368+ """Number of features when ``fit`` is performed."""
369+ warnings .warn (
370+ "`n_features_` was deprecated in scikit-learn 1.0. This attribute will "
371+ "not be accessible when the minimum supported version of scikit-learn "
372+ "is 1.2." ,
373+ FutureWarning ,
310374 )
375+ return self .n_features_in_
311376
312377 def fit (self , X , y ):
313378 """Build a Bagging ensemble of estimators from the training set (X, y).
0 commit comments