diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 716226d9..35d6797c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -42,7 +42,7 @@ Changed - Added generic init tests to base tests for transformers that take two columns as an input. - Refactored EqualityChecker tests in new format. - Bugfix to MeanResponseTransformer to ignore unobserved categorical levels - +- Added test_BaseTwoColumnTransformer base class for columns that require a list of two columns for input Removed ^^^^^^^ diff --git a/tests/base/test_BaseTwoColumnTransformer.py b/tests/base/test_BaseTwoColumnTransformer.py new file mode 100644 index 00000000..2cf75888 --- /dev/null +++ b/tests/base/test_BaseTwoColumnTransformer.py @@ -0,0 +1,40 @@ +from tests.base_tests import ( + GenericFitTests, + GenericTransformTests, + OtherBaseBehaviourTests, + TwoColumnListInitTests, +) + + +class TestInit(TwoColumnListInitTests): + """Generic tests for transformer.init().""" + + @classmethod + def setup_class(cls): + cls.transformer_name = "BaseTwoColumnTransformer" + + +class TestFit(GenericFitTests): + """Generic tests for transformer.fit()""" + + @classmethod + def setup_class(cls): + cls.transformer_name = "BaseTwoColumnTransformer" + + +class TestTransform(GenericTransformTests): + @classmethod + def setup_class(cls): + cls.transformer_name = "BaseTwoColumnTransformer" + + +class TestOtherBaseBehaviour(OtherBaseBehaviourTests): + """ + Class to run tests for BaseTransformerBehaviour outside the three standard methods. + + May need to overwite specific tests in this class if the tested transformer modifies this behaviour. + """ + + @classmethod + def setup_class(cls): + cls.transformer_name = "BaseTwoColumnTransformer" diff --git a/tests/conftest.py b/tests/conftest.py index a3d846a0..45de65f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,6 +64,10 @@ def minimal_attribute_dict(): "BaseTransformer": { "columns": ["a"], }, + "BaseTwoColumnTransformer": { + "columns": ["a", "b"], + "new_col_name": "c", + }, "DataFrameMethodTransformer": { "columns": ["a", "c"], "new_column_names": "f", diff --git a/tubular/base.py b/tubular/base.py index 382be407..b5303516 100644 --- a/tubular/base.py +++ b/tubular/base.py @@ -273,6 +273,49 @@ def check_weights_column(X: pd.DataFrame, weights_column: str) -> None: raise ValueError(msg) +class BaseTwoColumnTransformer(BaseTransformer): + """Transformer that takes a list of two columns as an argument, as well as new_column_name + + Inherits from BaseTransformer, all current transformers that use this argument also output a new column + Inherits fit and transform methods from BaseTransformer (required by sklearn transformers), simple input checking + and functionality to copy X prior to transform. + + Parameters + ---------- + columns : list + Column pair to apply the transformer to, must be list, cannot be None + + new_col_name : str + Name of new column being created, must be str, cannot be None + + **kwargs + Arbitrary keyword arguments passed onto BaseTransformer.__init__(). + + """ + + def __init__( + self, + columns: list[str], + new_col_name: str, + **kwargs: dict[str, bool], + ) -> None: + super().__init__(columns=columns, **kwargs) + + if not (isinstance(columns, list)): + msg = f"{self.classname()}: columns should be list" + raise TypeError(msg) + + if len(columns) != 2: + msg = f"{self.classname()}: This transformer works with two columns only" + raise ValueError(msg) + + if not (isinstance(new_col_name, str)): + msg = f"{self.classname()}: new_col_name should be str" + raise TypeError(msg) + + self.new_col_name = new_col_name + + class DataFrameMethodTransformer(BaseTransformer): """Tranformer that applies a pandas.DataFrame method. diff --git a/tubular/comparison.py b/tubular/comparison.py index 4ea15884..59356a9a 100644 --- a/tubular/comparison.py +++ b/tubular/comparison.py @@ -2,10 +2,10 @@ import pandas as pd # noqa: TCH002 -from tubular.base import BaseTransformer +from tubular.base import BaseTwoColumnTransformer -class EqualityChecker(BaseTransformer): +class EqualityChecker(BaseTwoColumnTransformer): """Transformer to check if two columns are equal. Parameters @@ -31,25 +31,12 @@ def __init__( drop_original: bool = False, **kwargs: dict[str, bool], ) -> None: - super().__init__(columns=columns, **kwargs) - - if not (isinstance(columns, list)): - msg = f"{self.classname()}: columns should be list" - raise TypeError(msg) - - if len(columns) != 2: - msg = f"{self.classname()}: This transformer works with two columns only" - raise ValueError(msg) - - if not (isinstance(new_col_name, str)): - msg = f"{self.classname()}: new_col_name should be str" - raise TypeError(msg) + super().__init__(columns=columns, new_col_name=new_col_name, **kwargs) if not (isinstance(drop_original, bool)): msg = f"{self.classname()}: drop_original should be bool" raise TypeError(msg) - self.new_col_name = new_col_name self.drop_original = drop_original def transform(self, X: pd.DataFrame) -> pd.DataFrame: