Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring GroupRareLevelsTransformer in line with new testing setup #259

Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ Subsections for each version can be one of the following;

Each individual change should have a link to the pull request after the description of the change.

1.3.1 (unreleased)
------------------

Changed
^^^^^^^
- Refactored GroupRareLevelsTransformer tests in new format `#259 <https://github.com/lvgig/tubular/pull/259>`_

1.3.0 (2024-06-13)
------------------
Expand Down
2 changes: 1 addition & 1 deletion tests/nominal/test_BaseNominalTransformer.py
limlam96 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_not_fitted_error_raised(self, initialized_transformers):
with pytest.raises(NotFittedError):
limlam96 marked this conversation as resolved.
Show resolved Hide resolved
initialized_transformers[self.transformer_name].transform(df)

def test_exception_raised(self, initialized_transformers):
def test_non_mappable_rows_exception_raised(self, initialized_transformers):
"""Test an exception is raised if non-mappable rows are present in X."""
df = d.create_df_1()

Expand Down
158 changes: 36 additions & 122 deletions tests/nominal/test_GroupRareLevelsTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,26 @@
import pandas as pd
import pytest
import test_aide as ta
from test_BaseNominalTransformer import GenericBaseNominalTransformerTests

import tests.test_data as d
import tubular
from tests.base_tests import (
ColumnStrListInitTests,
GenericFitTests,
GenericTransformTests,
OtherBaseBehaviourTests,
WeightColumnFitMixinTests,
WeightColumnInitMixinTests,
)
from tubular.nominal import GroupRareLevelsTransformer


class TestInit:
class TestInit(ColumnStrListInitTests, WeightColumnInitMixinTests):
limlam96 marked this conversation as resolved.
Show resolved Hide resolved
"""Tests for GroupRareLevelsTransformer.init()."""

def test_super_init_called(self, mocker):
"""Test that init calls BaseTransformer.init."""
expected_call_args = {
0: {"args": (), "kwargs": {"columns": None, "verbose": True}},
}

with ta.functions.assert_function_call(
mocker,
tubular.base.BaseTransformer,
"__init__",
expected_call_args,
):
GroupRareLevelsTransformer(columns=None, verbose=True)
@classmethod
def setup_class(cls):
cls.transformer_name = "GroupRareLevelsTransformer"

def test_cut_off_percent_not_float_error(self):
"""Test that an exception is raised if cut_off_percent is not an float."""
Expand All @@ -49,14 +47,6 @@ def test_cut_off_percent_gt_one_error(self):
):
GroupRareLevelsTransformer(columns="a", cut_off_percent=2.0)

def test_weight_not_str_error(self):
"""Test that an exception is raised if weight is not a str, if supplied."""
with pytest.raises(
TypeError,
match="weights_column should be str or None",
):
GroupRareLevelsTransformer(columns="a", weights_column=2)

def test_record_rare_levels_not_bool_error(self):
"""Test that an exception is raised if record_rare_levels is not a bool."""
with pytest.raises(
Expand All @@ -74,62 +64,12 @@ def test_unseen_levels_to_rare_not_bool_error(self):
GroupRareLevelsTransformer(columns="a", unseen_levels_to_rare=2)


class TestFit:
class TestFit(GenericFitTests, WeightColumnFitMixinTests):
"""Tests for GroupRareLevelsTransformer.fit()."""

def test_super_fit_called(self, mocker):
"""Test that fit calls BaseTransformer.fit."""
df = d.create_df_5()

x = GroupRareLevelsTransformer(columns=["b", "c"])

expected_call_args = {0: {"args": (d.create_df_5(), None), "kwargs": {}}}

with ta.functions.assert_function_call(
mocker,
tubular.base.BaseTransformer,
"fit",
expected_call_args,
):
x.fit(df)

def test_weight_column_not_in_X_error(self):
"""Test that an exception is raised if weight is not in X."""
df = d.create_df_5()

x = GroupRareLevelsTransformer(columns=["b", "c"], weights_column="aaaa")

with pytest.raises(
ValueError,
match=r"weight col \(aaaa\) is not present in columns of data",
):
x.fit(df)

def test_fit_returns_self(self):
"""Test fit returns self?."""
df = d.create_df_5()

x = GroupRareLevelsTransformer(columns=["b", "c"])

x_fitted = x.fit(df)

assert (
x_fitted is x
), "Returned value from GroupRareLevelsTransformer.fit not as expected."

def test_fit_not_changing_data(self):
"""Test fit does not change X."""
df = d.create_df_5()

x = GroupRareLevelsTransformer(columns=["b", "c"])

x.fit(df)

ta.equality.assert_equal_dispatch(
expected=d.create_df_5(),
actual=df,
msg="Check X not changing during fit",
)
@classmethod
def setup_class(cls):
cls.transformer_name = "GroupRareLevelsTransformer"

def test_learnt_values_no_weight(self):
"""Test that the impute values learnt during fit, without using a weight, are expected."""
Expand Down Expand Up @@ -221,9 +161,13 @@ def test_training_data_levels_stored(self):
)


class TestTransform:
class TestTransform(GenericBaseNominalTransformerTests, GenericTransformTests):
"""Tests for GroupRareLevelsTransformer.transform()."""

@classmethod
def setup_class(cls):
cls.transformer_name = "GroupRareLevelsTransformer"

def expected_df_1():
"""Expected output for test_expected_output_no_weight."""
df = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, np.nan]})
Expand Down Expand Up @@ -260,50 +204,8 @@ def expected_df_2():

return df

def test_check_is_fitted_called(self, mocker):
"""Test that BaseTransformer check_is_fitted called."""
df = d.create_df_5()

x = GroupRareLevelsTransformer(columns=["b", "c"])

x.fit(df)

expected_call_args = {0: {"args": (["non_rare_levels"],), "kwargs": {}}}

with ta.functions.assert_function_call(
mocker,
tubular.base.BaseTransformer,
"check_is_fitted",
expected_call_args,
):
x.transform(df)

def test_super_transform_called(self, mocker):
"""Test that BaseTransformer.transform called."""
df = d.create_df_5()

x = GroupRareLevelsTransformer(columns=["b", "c"])

x.fit(df)

expected_call_args = {
0: {
"args": (
x,
d.create_df_5(),
),
"kwargs": {},
},
}

with ta.functions.assert_function_call(
mocker,
tubular.base.BaseTransformer,
"transform",
expected_call_args,
return_value=d.create_df_5(),
):
x.transform(df)
def test_non_mappable_rows_exception_raised(self):
"""override test in GenericBaseNominalTransformerTests as not relevant to this transformer."""

def test_learnt_values_not_modified(self):
"""Test that the non_rare_levels from fit are not changed in transform."""
Expand Down Expand Up @@ -471,3 +373,15 @@ def test_rare_categories_forgotten(self):
assert (
cat not in output_categories
), f"{x.classname} output columns should forget rare encoded categories, expected {cat} to be forgotten from column {column}"


class TestOtherBaseBehaviour(OtherBaseBehaviourTests):
"""
Class to run tests for BaseTransformerBehaviour outside the three standard methods.

May need to overwite specific tests in this class if the tested transformer modifies this behaviour.
"""

@classmethod
def setup_class(cls):
cls.transformer_name = "BaseNominalTransformer"
Loading