Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
Release history
===============

.. include:: whats_new/v0.0.4.rst
.. include:: whats_new/v0.4.rst

.. include:: whats_new/v0.0.3.rst
.. include:: whats_new/v0.3.rst

.. include:: whats_new/v0.0.2.rst
.. include:: whats_new/v0.2.rst

.. include:: whats_new/v0.0.1.rst
.. include:: whats_new/v0.1.rst
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 4 additions & 0 deletions doc/whats_new/v0.0.4.rst → doc/whats_new/v0.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ Bug fixes
deviation.
By :user:`Guillaume Lemaitre <glemaitre>` in :issue:`491`.

- Raise an error when passing target which is not supported, i.e. regression
target or multilabel targets. Imbalanced-learn does not support this case.
By :user:`Guillaume Lemaitre <glemaitre>` in :issue:`490`.

Version 0.4
===========

Expand Down
2 changes: 2 additions & 0 deletions imblearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sklearn.externals import six
from sklearn.preprocessing import label_binarize
from sklearn.utils import check_X_y
from sklearn.utils.multiclass import check_classification_targets

from .utils import check_sampling_strategy, check_target_type
from .utils.deprecation import deprecate_parameter
Expand Down Expand Up @@ -77,6 +78,7 @@ def fit_resample(self, X, y):
"""
self._deprecate_ratio()

check_classification_targets(y)
X, y, binarize_y = self._check_X_y(X, y)

self.sampling_strategy_ = check_sampling_strategy(
Expand Down
7 changes: 5 additions & 2 deletions imblearn/over_sampling/tests/test_smote_nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,12 @@ def test_smotenc_check_target_type():
y = np.linspace(0, 1, 30)
smote = SMOTENC(categorical_features=categorical_features,
random_state=0)
with pytest.warns(UserWarning, match='should be of types'):
with pytest.raises(ValueError, match="Unknown label type: 'continuous'"):
smote.fit_resample(X, y)
rng = np.random.RandomState(42)
y = rng.randint(2, size=(20, 3))
with pytest.raises(ValueError, match="'y' should encode the multiclass"):
smote.fit_resample(X, y)


def test_smotenc_samplers_one_label():
X, _, categorical_features = data_heterogneous_unordered()
Expand Down
19 changes: 8 additions & 11 deletions imblearn/utils/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,14 @@ def check_target_type(y, indicate_one_vs_all=False):

"""
type_y = type_of_target(y)
if type_y not in TARGET_KIND:
# FIXME: perfectly we should raise an error but the sklearn API does
# not allow for it
warnings.warn("'y' should be of types {} only. Got {} instead.".format(
TARGET_KIND, type_of_target(y)))

if indicate_one_vs_all:
return (y.argmax(axis=1) if type_y == 'multilabel-indicator' else y,
type_y == 'multilabel-indicator')
else:
return y.argmax(axis=1) if type_y == 'multilabel-indicator' else y
if type_y == 'multilabel-indicator':
if np.any(y.sum(axis=1) > 1):
raise ValueError(
"When 'y' corresponds to '{}', 'y' should encode the "
"multiclass (a single 1 by row).".format(type_y))
y = y.argmax(axis=1)

return (y, type_y == 'multilabel-indicator') if indicate_one_vs_all else y


def _sampling_strategy_all(y, sampling_type):
Expand Down
19 changes: 11 additions & 8 deletions imblearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import NearMiss, ClusterCentroids

from imblearn.utils.testing import warns

DONT_SUPPORT_RATIO = ['SVMSMOTE', 'BorderlineSMOTE']
SUPPORT_STRING = ['RandomUnderSampler', 'RandomOverSampler']
HAVE_SAMPLE_INDICES = [
Expand All @@ -54,7 +52,6 @@ def monkey_patch_check_dtype_object(name, estimator_orig):
X = rng.rand(40, 10).astype(object)
y = np.array([0] * 10 + [1] * 30, dtype=np.int)
estimator = clone(estimator_orig)

estimator.fit(X, y)

try:
Expand Down Expand Up @@ -123,14 +120,20 @@ def check_estimator(Estimator, run_sampler_tests=True):


def check_target_type(name, Estimator):
# should raise warning if the target is continuous (we cannot raise error)
X = np.random.random((20, 2))
y = np.linspace(0, 1, 20)
estimator = Estimator()
# FIXME: in 0.6 set the random_state for all
if name not in DONT_HAVE_RANDOM_STATE:
set_random_state(estimator)
with warns(UserWarning, match='should be of types'):
estimator.fit(X, y)
with pytest.raises(ValueError, match="Unknown label type: 'continuous'"):
estimator.fit_resample(X, y)
# if the target is multilabel then we should raise an error
rng = np.random.RandomState(42)
y = rng.randint(2, size=(20, 3))
with pytest.raises(ValueError, match="'y' should encode the multiclass"):
estimator.fit_resample(X, y)


def check_samplers_one_label(name, Sampler):
Expand All @@ -139,7 +142,7 @@ def check_samplers_one_label(name, Sampler):
X = np.random.random((20, 2))
y = np.zeros(20)
try:
sampler.fit(X, y)
sampler.fit_resample(X, y)
except ValueError as e:
if 'class' not in repr(e):
print(error_string_fit, Sampler, e)
Expand All @@ -157,15 +160,15 @@ def check_samplers_fit(name, Sampler):
sampler = Sampler()
X = np.random.random((30, 2))
y = np.array([1] * 20 + [0] * 10)
sampler.fit(X, y)
sampler.fit_resample(X, y)
assert hasattr(sampler, 'sampling_strategy_'), \
"No fitted attribute sampling_strategy_"


def check_samplers_fit_resample(name, Sampler):
sampler = Sampler()
X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
weights=[0.2, 0.3, 0.5], random_state=0)
weights=[0.2, 0.3, 0.5], random_state=0)
target_stats = Counter(y)
X_res, y_res = sampler.fit_resample(X, y)
if isinstance(sampler, BaseOverSampler):
Expand Down
27 changes: 4 additions & 23 deletions imblearn/utils/tests/test_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from sklearn.base import BaseEstimator
from sklearn.utils import check_X_y
from sklearn.utils.multiclass import check_classification_targets

from imblearn.base import BaseSampler
from imblearn.utils.estimator_checks import check_estimator
Expand All @@ -17,6 +18,8 @@ def fit(self, X, y):
return self

def fit_resample(self, X, y):
check_classification_targets(y)
self.fit(X, y)
return X, y


Expand All @@ -27,36 +30,15 @@ def fit(self, X, y):
X, y = check_X_y(X, y, accept_sparse=True)
return self

def fit_resample(self, X, y):
self.fit(X, y)
return X, y


class NoAcceptingSparseSampler(BaseBadSampler):
"""Sampler which does not accept sparse matrix."""
def fit(self, X, y):
X, y = check_X_y(X, y, accept_sparse=False)
y, _ = check_target_type(y, indicate_one_vs_all=True)
self.sampling_strategy_ = 'sampling_strategy_'
return self

def fit_resample(self, X, y):
self.fit(X, y)
return X, y


class NotTransformingTargetOvR(BaseBadSampler):
"""Sampler which does not transform OvR enconding."""
def fit(self, X, y):
X, y = check_X_y(X, y, accept_sparse=True)
y, _ = check_target_type(y, indicate_one_vs_all=True)
X, y = check_X_y(X, y, accept_sparse=False)
self.sampling_strategy_ = 'sampling_strategy_'
return self

def fit_resample(self, X, y):
self.fit(X, y)
return X, y


class NotPreservingDtypeSampler(BaseSampler):
_sampling_type = 'bypass'
Expand All @@ -72,7 +54,6 @@ def _fit_resample(self, X, y):
[(BaseBadSampler, AssertionError, "TypeError not raised by fit"),
(NotFittedSampler, AssertionError, "No fitted attribute"),
(NoAcceptingSparseSampler, TypeError, "A sparse matrix was passed"),
(NotTransformingTargetOvR, ValueError, "bad input shape"),
(NotPreservingDtypeSampler, AssertionError, "X dytype is not preserved")]
)
def test_check_estimator(Estimator, err_type, err_msg):
Expand Down
6 changes: 0 additions & 6 deletions imblearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,6 @@ def test_check_target_type_ova(target, output_target, is_ova):
assert binarize_target == is_ova


def test_check_target_warning():
target = np.arange(4).reshape((2, 2))
with pytest.warns(UserWarning, match='should be of types'):
check_target_type(target)


def test_check_sampling_strategy_warning():
msg = 'dict for cleaning methods is deprecated'
with pytest.warns(DeprecationWarning, match=msg):
Expand Down