Skip to content

Commit

Permalink
Merge pull request #259 from lvgig/227-bring-grouprarelevelstransform…
Browse files Browse the repository at this point in the history
…er-in-line-with-new-testing-setup

Bring GroupRareLevelsTransformer in line with new testing setup
  • Loading branch information
davidhopkinson26 authored Jul 8, 2024
2 parents 63b9677 + d4ad521 commit 0725fb1
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 145 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ Subsections for each version can be one of the following;
- ``Security`` in case of vulnerabilities.

Each individual change should have a link to the pull request after the description of the change.

1.3.1 (unreleased)
------------------

Changed
^^^^^^^
- Refactored GroupRareLevelsTransformer tests in new format `#259 <https://github.com/lvgig/tubular/pull/259>`_
- DatetimeInfoExtractor.mappings_provided changed from a dict.keys() object to list so transformer is serialisable. `#258 <https://github.com/lvgig/tubular/pull/258>`_
- Created BaseNumericTransformer class to support test refactor of numeric file

Expand Down
26 changes: 3 additions & 23 deletions tests/nominal/test_BaseNominalTransformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import re

import pandas as pd
import pytest
from sklearn.exceptions import NotFittedError
Expand All @@ -16,7 +14,7 @@
# The first part of this file builds out the tests for BaseNominalTransformer so that they can be
# imported into other test files (by not starting the class name with Test)
# The second part actually calls these tests (along with all other require tests) for the BaseNominalTransformer
class GenericBaseNominalTransformerTests:
class GenericNominalTransformTests(GenericTransformTests):
"""
Tests for BaseNominalTransformer.transform().
Note this deliberately avoids starting with "Tests" so that the tests are not run on import.
Expand All @@ -29,7 +27,7 @@ def test_not_fitted_error_raised(self, initialized_transformers):
with pytest.raises(NotFittedError):
initialized_transformers[self.transformer_name].transform(df)

def test_exception_raised(self, initialized_transformers):
def test_non_mappable_rows_exception_raised(self, initialized_transformers):
"""Test an exception is raised if non-mappable rows are present in X."""
df = d.create_df_1()

Expand Down Expand Up @@ -63,24 +61,6 @@ def test_original_df_not_updated(self, initialized_transformers):

pd.testing.assert_frame_equal(df, d.create_df_1())

def test_no_rows_error(self, initialized_transformers):
"""Test an error is raised if X has no rows."""
df = d.create_df_1()

x = initialized_transformers[self.transformer_name]

x = x.fit(df)

x.mappings = {"b": {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6}}

df = pd.DataFrame(columns=["a", "b", "c"])

with pytest.raises(
ValueError,
match=re.escape(f"{self.transformer_name}: X has no rows; (0, 3)"),
):
x.transform(df)


class TestInit(ColumnStrListInitTests):
"""Generic tests for transformer.init()."""
Expand All @@ -98,7 +78,7 @@ def setup_class(cls):
cls.transformer_name = "BaseNominalTransformer"


class TestTransform(GenericBaseNominalTransformerTests, GenericTransformTests):
class TestTransform(GenericNominalTransformTests):
"""Tests for BaseImputer.transform."""

@classmethod
Expand Down
157 changes: 35 additions & 122 deletions tests/nominal/test_GroupRareLevelsTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,25 @@
import pandas as pd
import pytest
import test_aide as ta
from test_BaseNominalTransformer import GenericNominalTransformTests

import tests.test_data as d
import tubular
from tests.base_tests import (
ColumnStrListInitTests,
GenericFitTests,
OtherBaseBehaviourTests,
WeightColumnFitMixinTests,
WeightColumnInitMixinTests,
)
from tubular.nominal import GroupRareLevelsTransformer


class TestInit:
class TestInit(ColumnStrListInitTests, WeightColumnInitMixinTests):
"""Tests for GroupRareLevelsTransformer.init()."""

def test_super_init_called(self, mocker):
"""Test that init calls BaseTransformer.init."""
expected_call_args = {
0: {"args": (), "kwargs": {"columns": None, "verbose": True}},
}

with ta.functions.assert_function_call(
mocker,
tubular.base.BaseTransformer,
"__init__",
expected_call_args,
):
GroupRareLevelsTransformer(columns=None, verbose=True)
@classmethod
def setup_class(cls):
cls.transformer_name = "GroupRareLevelsTransformer"

def test_cut_off_percent_not_float_error(self):
"""Test that an exception is raised if cut_off_percent is not an float."""
Expand All @@ -49,14 +46,6 @@ def test_cut_off_percent_gt_one_error(self):
):
GroupRareLevelsTransformer(columns="a", cut_off_percent=2.0)

def test_weight_not_str_error(self):
"""Test that an exception is raised if weight is not a str, if supplied."""
with pytest.raises(
TypeError,
match="weights_column should be str or None",
):
GroupRareLevelsTransformer(columns="a", weights_column=2)

def test_record_rare_levels_not_bool_error(self):
"""Test that an exception is raised if record_rare_levels is not a bool."""
with pytest.raises(
Expand All @@ -74,62 +63,12 @@ def test_unseen_levels_to_rare_not_bool_error(self):
GroupRareLevelsTransformer(columns="a", unseen_levels_to_rare=2)


class TestFit:
class TestFit(GenericFitTests, WeightColumnFitMixinTests):
"""Tests for GroupRareLevelsTransformer.fit()."""

def test_super_fit_called(self, mocker):
"""Test that fit calls BaseTransformer.fit."""
df = d.create_df_5()

x = GroupRareLevelsTransformer(columns=["b", "c"])

expected_call_args = {0: {"args": (d.create_df_5(), None), "kwargs": {}}}

with ta.functions.assert_function_call(
mocker,
tubular.base.BaseTransformer,
"fit",
expected_call_args,
):
x.fit(df)

def test_weight_column_not_in_X_error(self):
"""Test that an exception is raised if weight is not in X."""
df = d.create_df_5()

x = GroupRareLevelsTransformer(columns=["b", "c"], weights_column="aaaa")

with pytest.raises(
ValueError,
match=r"weight col \(aaaa\) is not present in columns of data",
):
x.fit(df)

def test_fit_returns_self(self):
"""Test fit returns self?."""
df = d.create_df_5()

x = GroupRareLevelsTransformer(columns=["b", "c"])

x_fitted = x.fit(df)

assert (
x_fitted is x
), "Returned value from GroupRareLevelsTransformer.fit not as expected."

def test_fit_not_changing_data(self):
"""Test fit does not change X."""
df = d.create_df_5()

x = GroupRareLevelsTransformer(columns=["b", "c"])

x.fit(df)

ta.equality.assert_equal_dispatch(
expected=d.create_df_5(),
actual=df,
msg="Check X not changing during fit",
)
@classmethod
def setup_class(cls):
cls.transformer_name = "GroupRareLevelsTransformer"

def test_learnt_values_no_weight(self):
"""Test that the impute values learnt during fit, without using a weight, are expected."""
Expand Down Expand Up @@ -221,9 +160,13 @@ def test_training_data_levels_stored(self):
)


class TestTransform:
class TestTransform(GenericNominalTransformTests):
"""Tests for GroupRareLevelsTransformer.transform()."""

@classmethod
def setup_class(cls):
cls.transformer_name = "GroupRareLevelsTransformer"

def expected_df_1():
"""Expected output for test_expected_output_no_weight."""
df = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, np.nan]})
Expand Down Expand Up @@ -260,50 +203,8 @@ def expected_df_2():

return df

def test_check_is_fitted_called(self, mocker):
"""Test that BaseTransformer check_is_fitted called."""
df = d.create_df_5()

x = GroupRareLevelsTransformer(columns=["b", "c"])

x.fit(df)

expected_call_args = {0: {"args": (["non_rare_levels"],), "kwargs": {}}}

with ta.functions.assert_function_call(
mocker,
tubular.base.BaseTransformer,
"check_is_fitted",
expected_call_args,
):
x.transform(df)

def test_super_transform_called(self, mocker):
"""Test that BaseTransformer.transform called."""
df = d.create_df_5()

x = GroupRareLevelsTransformer(columns=["b", "c"])

x.fit(df)

expected_call_args = {
0: {
"args": (
x,
d.create_df_5(),
),
"kwargs": {},
},
}

with ta.functions.assert_function_call(
mocker,
tubular.base.BaseTransformer,
"transform",
expected_call_args,
return_value=d.create_df_5(),
):
x.transform(df)
def test_non_mappable_rows_exception_raised(self):
"""override test in GenericNominalTransformTests as not relevant to this transformer."""

def test_learnt_values_not_modified(self):
"""Test that the non_rare_levels from fit are not changed in transform."""
Expand Down Expand Up @@ -471,3 +372,15 @@ def test_rare_categories_forgotten(self):
assert (
cat not in output_categories
), f"{x.classname} output columns should forget rare encoded categories, expected {cat} to be forgotten from column {column}"


class TestOtherBaseBehaviour(OtherBaseBehaviourTests):
"""
Class to run tests for BaseTransformerBehaviour outside the three standard methods.
May need to overwite specific tests in this class if the tested transformer modifies this behaviour.
"""

@classmethod
def setup_class(cls):
cls.transformer_name = "BaseNominalTransformer"

0 comments on commit 0725fb1

Please sign in to comment.