Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/equality transformer #245

Merged
merged 4 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Changed
- Added generic init tests to base tests for transformers that take two columns as an input.
- Refactored EqualityChecker tests in new format.
- Bugfix to MeanResponseTransformer to ignore unobserved categorical levels

- Added test_BaseTwoColumnTransformer base class for columns that require a list of two columns for input

Removed
^^^^^^^
Expand Down
40 changes: 40 additions & 0 deletions tests/base/test_BaseTwoColumnTransformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from tests.base_tests import (
GenericFitTests,
GenericTransformTests,
OtherBaseBehaviourTests,
TwoColumnListInitTests,
)


class TestInit(TwoColumnListInitTests):
"""Generic tests for transformer.init()."""

@classmethod
def setup_class(cls):
cls.transformer_name = "BaseTwoColumnTransformer"


class TestFit(GenericFitTests):
"""Generic tests for transformer.fit()"""

@classmethod
def setup_class(cls):
cls.transformer_name = "BaseTwoColumnTransformer"


class TestTransform(GenericTransformTests):
@classmethod
def setup_class(cls):
cls.transformer_name = "BaseTwoColumnTransformer"


class TestOtherBaseBehaviour(OtherBaseBehaviourTests):
"""
Class to run tests for BaseTransformerBehaviour outside the three standard methods.

May need to overwite specific tests in this class if the tested transformer modifies this behaviour.
"""

@classmethod
def setup_class(cls):
cls.transformer_name = "BaseTwoColumnTransformer"
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def minimal_attribute_dict():
"BaseTransformer": {
"columns": ["a"],
},
"BaseTwoColumnTransformer": {
"columns": ["a", "b"],
"new_col_name": "c",
},
"DataFrameMethodTransformer": {
"columns": ["a", "c"],
"new_column_names": "f",
Expand Down
43 changes: 43 additions & 0 deletions tubular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,49 @@ def check_weights_column(X: pd.DataFrame, weights_column: str) -> None:
raise ValueError(msg)


class BaseTwoColumnTransformer(BaseTransformer):
"""Transformer that takes a list of two columns as an argument, as well as new_column_name

Inherits from BaseTransformer, all current transformers that use this argument also output a new column
Inherits fit and transform methods from BaseTransformer (required by sklearn transformers), simple input checking
and functionality to copy X prior to transform.

Parameters
----------
columns : list
Column pair to apply the transformer to, must be list, cannot be None

new_col_name : str
Name of new column being created, must be str, cannot be None

**kwargs
Arbitrary keyword arguments passed onto BaseTransformer.__init__().

"""

def __init__(
self,
columns: list[str],
new_col_name: str,
**kwargs: dict[str, bool],
) -> None:
super().__init__(columns=columns, **kwargs)

if not (isinstance(columns, list)):
msg = f"{self.classname()}: columns should be list"
raise TypeError(msg)

if len(columns) != 2:
msg = f"{self.classname()}: This transformer works with two columns only"
raise ValueError(msg)

if not (isinstance(new_col_name, str)):
msg = f"{self.classname()}: new_col_name should be str"
raise TypeError(msg)

self.new_col_name = new_col_name


class DataFrameMethodTransformer(BaseTransformer):
"""Tranformer that applies a pandas.DataFrame method.

Expand Down
19 changes: 3 additions & 16 deletions tubular/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import pandas as pd # noqa: TCH002

from tubular.base import BaseTransformer
from tubular.base import BaseTwoColumnTransformer


class EqualityChecker(BaseTransformer):
class EqualityChecker(BaseTwoColumnTransformer):
"""Transformer to check if two columns are equal.

Parameters
Expand All @@ -31,25 +31,12 @@ def __init__(
drop_original: bool = False,
**kwargs: dict[str, bool],
) -> None:
super().__init__(columns=columns, **kwargs)

if not (isinstance(columns, list)):
msg = f"{self.classname()}: columns should be list"
raise TypeError(msg)

if len(columns) != 2:
msg = f"{self.classname()}: This transformer works with two columns only"
raise ValueError(msg)

if not (isinstance(new_col_name, str)):
msg = f"{self.classname()}: new_col_name should be str"
raise TypeError(msg)
super().__init__(columns=columns, new_col_name=new_col_name, **kwargs)

if not (isinstance(drop_original, bool)):
msg = f"{self.classname()}: drop_original should be bool"
raise TypeError(msg)

self.new_col_name = new_col_name
self.drop_original = drop_original

def transform(self, X: pd.DataFrame) -> pd.DataFrame:
Expand Down
Loading