Skip to content

Commit

Permalink
Merge pull request #483 from tdspora/EPMCTDM-7179_target_validation_f…
Browse files Browse the repository at this point in the history
…or_utility_metric

Epmctdm 7179 target validation for utility metric
  • Loading branch information
Ijka authored Dec 18, 2024
2 parents 0436a15 + 2a490e8 commit 47ced85
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/syngen/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.10.0
0.10.1
44 changes: 38 additions & 6 deletions src/syngen/ml/metrics/metrics_classes/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,12 +1356,8 @@ def __model_process(self, model_object, targets, task_type):
original, model_y, self.sample_size
)

if len(set(model_y)) < 2:
logger.info(
f"Column {col} has less than 2 classes as target. "
f"It will not be used in metric "
f"that measures regression results."
)
# validate number of unique values and instances in each class
if not self._valid_target(col, model_y, task_type):
continue

(
Expand Down Expand Up @@ -1463,3 +1459,39 @@ def __create_regression_models(self, cont_targets):
cont_targets, "regression"
)
return best_target, score, synthetic_score

@staticmethod
def _valid_target(col, model_y, task_type):
"""
Validate the target column in utility metric calculation.
Args:
col (str): The name of the column being checked.
model_y (array-like): The target values.
task_type (str): type task
Returns:
bool: True if the column is valid
"""
unique_values, unique_counts = np.unique(model_y, return_counts=True)

# check if there is only one unique value in the column
if len(unique_values) < 2:
logger.info(
f"Column '{col}' has only one unique value. "
"It will not be used as target column "
f"in utility metric calculation for {task_type}."
)
return False

# check if there is more than 1 sample in each class
if task_type in ["binary classification", "multiclass classification"]:
if np.min(unique_counts) < 2:
logger.info(
f"Column '{col}' has a class with less than 2 samples. "
"It will not be used as target column "
f"in utility metric calculation for {task_type}."
)
return False

return True
46 changes: 45 additions & 1 deletion src/tests/unit/metrics/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import pandas as pd
import numpy as np
import pytest

from unittest.mock import patch

from syngen.ml.metrics.metrics_classes.metrics import Clustering, Utility

from syngen.ml.metrics.metrics_classes.metrics import Clustering
from tests.conftest import SUCCESSFUL_MESSAGE, DIR_NAME


Expand Down Expand Up @@ -30,3 +35,42 @@ def test_clustering_calculate_all(rp_logger):
assert mean_score >= threshold, f"Mean score shouldn't be less than {threshold}"

rp_logger.info(SUCCESSFUL_MESSAGE)


@pytest.mark.parametrize(
"col, model_y, task_type, expected_result, expected_log_message",
[
("column_1", np.array([1, 1, 1, 1]), "binary classification", False, True),
("column_2", np.array([0, 0, 0, 1]), "binary classification", False, True),
("column_3", np.array([0, 1, 0, 1]), "binary classification", True, False),
("column_4", np.array([0, 0, 1, 1, 2, 2]), "multiclass classification", True, False),
("column_5", np.array([0, 0, 0, 1, 2, 2]), "multiclass classification", False, True),
("column_6", np.array([1, 1, 1]), "multiclass classification", False, True),
("column_7", np.array([1, 1, 1, 1]), "regression", False, True),
("column_8", np.array([1, 2, 3, 4, 4]), "regression", True, False),
]
)
def test_utility_valid_target(
rp_logger, col, model_y, task_type,
expected_result, expected_log_message):
"""
Testing the _valid_target function in the Utility class
"""
rp_logger.info(
"Testing the _valid_target function in the Utility class"
)

with patch(
'syngen.ml.metrics.metrics_classes.metrics.logger'
) as mock_logger:
result = Utility._valid_target(col, model_y, task_type)

assert result == expected_result, \
f"Expected result is {expected_result}, got {result}"

if expected_log_message:
mock_logger.info.assert_called_once()
else:
mock_logger.info.assert_not_called()

rp_logger.info(SUCCESSFUL_MESSAGE)

0 comments on commit 47ced85

Please sign in to comment.