diff --git a/src/syngen/VERSION b/src/syngen/VERSION index 78bc1abd..57121573 100644 --- a/src/syngen/VERSION +++ b/src/syngen/VERSION @@ -1 +1 @@ -0.10.0 +0.10.1 diff --git a/src/syngen/ml/metrics/metrics_classes/metrics.py b/src/syngen/ml/metrics/metrics_classes/metrics.py index 8e57315d..2014180f 100644 --- a/src/syngen/ml/metrics/metrics_classes/metrics.py +++ b/src/syngen/ml/metrics/metrics_classes/metrics.py @@ -1356,12 +1356,8 @@ def __model_process(self, model_object, targets, task_type): original, model_y, self.sample_size ) - if len(set(model_y)) < 2: - logger.info( - f"Column {col} has less than 2 classes as target. " - f"It will not be used in metric " - f"that measures regression results." - ) + # validate number of unique values and instances in each class + if not self._valid_target(col, model_y, task_type): continue ( @@ -1463,3 +1459,39 @@ def __create_regression_models(self, cont_targets): cont_targets, "regression" ) return best_target, score, synthetic_score + + @staticmethod + def _valid_target(col, model_y, task_type): + """ + Validate the target column in utility metric calculation. + + Args: + col (str): The name of the column being checked. + model_y (array-like): The target values. + task_type (str): type task + + Returns: + bool: True if the column is valid + """ + unique_values, unique_counts = np.unique(model_y, return_counts=True) + + # check if there is only one unique value in the column + if len(unique_values) < 2: + logger.info( + f"Column '{col}' has only one unique value. " + "It will not be used as target column " + f"in utility metric calculation for {task_type}." + ) + return False + + # check if there is more than 1 sample in each class + if task_type in ["binary classification", "multiclass classification"]: + if np.min(unique_counts) < 2: + logger.info( + f"Column '{col}' has a class with less than 2 samples. " + "It will not be used as target column " + f"in utility metric calculation for {task_type}." + ) + return False + + return True diff --git a/src/tests/unit/metrics/test_metrics.py b/src/tests/unit/metrics/test_metrics.py index fd419234..5419deeb 100644 --- a/src/tests/unit/metrics/test_metrics.py +++ b/src/tests/unit/metrics/test_metrics.py @@ -1,6 +1,11 @@ import pandas as pd +import numpy as np +import pytest + +from unittest.mock import patch + +from syngen.ml.metrics.metrics_classes.metrics import Clustering, Utility -from syngen.ml.metrics.metrics_classes.metrics import Clustering from tests.conftest import SUCCESSFUL_MESSAGE, DIR_NAME @@ -30,3 +35,42 @@ def test_clustering_calculate_all(rp_logger): assert mean_score >= threshold, f"Mean score shouldn't be less than {threshold}" rp_logger.info(SUCCESSFUL_MESSAGE) + + +@pytest.mark.parametrize( + "col, model_y, task_type, expected_result, expected_log_message", + [ + ("column_1", np.array([1, 1, 1, 1]), "binary classification", False, True), + ("column_2", np.array([0, 0, 0, 1]), "binary classification", False, True), + ("column_3", np.array([0, 1, 0, 1]), "binary classification", True, False), + ("column_4", np.array([0, 0, 1, 1, 2, 2]), "multiclass classification", True, False), + ("column_5", np.array([0, 0, 0, 1, 2, 2]), "multiclass classification", False, True), + ("column_6", np.array([1, 1, 1]), "multiclass classification", False, True), + ("column_7", np.array([1, 1, 1, 1]), "regression", False, True), + ("column_8", np.array([1, 2, 3, 4, 4]), "regression", True, False), + ] +) +def test_utility_valid_target( + rp_logger, col, model_y, task_type, + expected_result, expected_log_message): + """ + Testing the _valid_target function in the Utility class + """ + rp_logger.info( + "Testing the _valid_target function in the Utility class" + ) + + with patch( + 'syngen.ml.metrics.metrics_classes.metrics.logger' + ) as mock_logger: + result = Utility._valid_target(col, model_y, task_type) + + assert result == expected_result, \ + f"Expected result is {expected_result}, got {result}" + + if expected_log_message: + mock_logger.info.assert_called_once() + else: + mock_logger.info.assert_not_called() + + rp_logger.info(SUCCESSFUL_MESSAGE)