Skip to content

Commit

Permalink
FIX Removes validation in __init__ for Pipeline (#21888)
Browse files Browse the repository at this point in the history
Co-authored-by: arisayosh <[email protected]>
Co-authored-by: iofall <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
  • Loading branch information
3 people authored Dec 10, 2021
1 parent d72bd02 commit 0110921
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 10 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ Changelog
Setting a transformer to "passthrough" will pass the features unchanged.
:pr:`20860` by :user:`Shubhraneel Pal <shubhraneel>`.

- |Fix| :class:`pipeline.Pipeline` no longer validates hyper-parameters in
`__init__`; validation now happens in `.fit()`.
:pr:`21888` by :user:`iofall <iofall>` and :user:`Arisa Y. <arisayosh>`.

:mod:`sklearn.preprocessing`
............................

Expand Down
1 change: 0 additions & 1 deletion sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def __init__(self, steps, *, memory=None, verbose=False):
self.steps = steps
self.memory = memory
self.verbose = verbose
self._validate_steps()

def get_params(self, deep=True):
"""Get parameters for this estimator.
Expand Down
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,6 @@ def test_transformers_get_feature_names_out(transformer):
"FeatureUnion",
"GridSearchCV",
"HalvingGridSearchCV",
"Pipeline",
"SGDOneClassSVM",
"TheilSenRegressor",
"TweedieRegressor",
Expand Down
20 changes: 12 additions & 8 deletions sklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,23 @@ def predict_log_proba(self, X, got_attribute=False):
return self


def test_pipeline_init():
# Test the various init parameters of the pipeline.
def test_pipeline_invalid_parameters():
# Test the various init parameters of the pipeline in fit
# method
pipeline = Pipeline([(1, 1)])
with pytest.raises(TypeError):
Pipeline()
pipeline.fit([[1]], [1])

# Check that we can't instantiate pipelines with objects without fit
# Check that we can't fit pipelines with objects without fit
# method
msg = (
"Last step of Pipeline should implement fit "
"or be the string 'passthrough'"
".*NoFit.*"
)
pipeline = Pipeline([("clf", NoFit())])
with pytest.raises(TypeError, match=msg):
Pipeline([("clf", NoFit())])
pipeline.fit([[1]], [1])

# Smoke test with only an estimator
clf = NoTrans()
Expand All @@ -203,11 +206,12 @@ def test_pipeline_init():
assert pipe.named_steps["anova"] is filter1
assert pipe.named_steps["svc"] is clf

# Check that we can't instantiate with non-transformers on the way
# Check that we can't fit with non-transformers on the way
# Note that NoTrans implements fit, but not transform
msg = "All intermediate steps should be transformers.*\\bNoTrans\\b.*"
pipeline = Pipeline([("t", NoTrans()), ("svc", clf)])
with pytest.raises(TypeError, match=msg):
Pipeline([("t", NoTrans()), ("svc", clf)])
pipeline.fit([[1]], [1])

# Check that params are set
pipe.set_params(svc__C=0.1)
Expand Down Expand Up @@ -1086,7 +1090,7 @@ def test_step_name_validation():
# three ways to make invalid:
# - construction followed by fit
with pytest.raises(ValueError, match=message):
cls(**{param: bad_steps})
cls(**{param: bad_steps}).fit([[1]], [1])

# - setattr
est = cls(**{param: [("a", Mult(1))]})
Expand Down

0 comments on commit 0110921

Please sign in to comment.