From f673f78fd958c40d995177415bfa07a619a327de Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 17 Nov 2023 11:19:47 -0600 Subject: [PATCH 1/6] Filter out keys that cannot be statistically modeled --- sdmetrics/single_table/detection/base.py | 17 +++++ .../single_table/detection/test_detection.py | 73 +++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/sdmetrics/single_table/detection/base.py b/sdmetrics/single_table/detection/base.py index 2da0a735..cb155a94 100644 --- a/sdmetrics/single_table/detection/base.py +++ b/sdmetrics/single_table/detection/base.py @@ -76,6 +76,23 @@ def compute(cls, real_data, synthetic_data, metadata=None): transformed_real_data = real_data transformed_synthetic_data = synthetic_data + if metadata is not None and 'columns' in metadata: + drop_columns = [] + for column in metadata['columns']: + if 'primary_key' in metadata and column == metadata['primary_key']: + continue + for field in metadata['columns'][column]: + if field == 'sdtype': + sdtype = metadata['columns'][column][field] + if sdtype == 'id' or sdtype == 'text': + drop_columns.append(column) + if field == 'pii': + if metadata['columns'][column][field]: + drop_columns.append(column) + if len(drop_columns) > 0: + transformed_real_data = transformed_real_data.drop(drop_columns, axis=1) + transformed_synthetic_data = transformed_synthetic_data.drop(drop_columns, axis=1) + ht = HyperTransformer() transformed_real_data = ht.fit_transform(transformed_real_data).to_numpy() transformed_synthetic_data = ht.transform(transformed_synthetic_data).to_numpy() diff --git a/tests/unit/single_table/detection/test_detection.py b/tests/unit/single_table/detection/test_detection.py index dd230ddf..4dcaf3c9 100644 --- a/tests/unit/single_table/detection/test_detection.py +++ b/tests/unit/single_table/detection/test_detection.py @@ -65,3 +65,76 @@ def test_primary_key_detection_metrics(self, fit_transform_mock, transform_mock) transform_mock.assert_called_with(expected_return_synthetic) assert expected_return_real == call_1 assert expected_return_synthetic == call_2 + + @patch('sdmetrics.utils.HyperTransformer.transform') + @patch('sdmetrics.utils.HyperTransformer.fit_transform') + def test_ignore_keys_detection_metrics(self, fit_transform_mock, transform_mock): + """This test checks that ``primary_key`` columns of dataset are ignored. + + Ensure that ``primary_keys`` are ignored for Detection metrics expect that the match + is made correctly. + """ + + # Setup + real_data = pd.DataFrame({ + 'ID_1': [1, 2, 1, 3, 4], + 'col1': [43.0, 47.5, 34.2, 30.3, 39.1], + 'col2': [1.0, 2.0, 3.0, 4.0, 5.0], + 'ID_2': ['aa', 'bb', 'cc', 'dd', 'bb'], + 'col3': [5, 6, 7, 8, 9], + 'ID_3': ['a', 'b', 'c', 'd', 'e'], + 'blob': ['Hello world!', 'Hello world!', 'This is SDV', 'This is SDV', 'Hello world!'], + 'col4': [1, 3, 9, 2, 1] + }) + synthetic_data = pd.DataFrame({ + 'ID_1': [1, 3, 4, 2, 2], + 'col1': [23.0, 47.1, 44.9, 31.3, 9.7], + 'col2': [11.0, 22.0, 33.0, 44.0, 55.0], + 'ID_2': ['aa', 'bb', 'cc', 'dd', 'ee'], + 'col3': [55, 66, 77, 88, 99], + 'ID_3': ['a', 'b', 'e', 'd', 'c'], + 'blob': ['Hello world!', 'Hello world!', 'This is SDV', 'This is SDV', 'Hello world!'], + 'col4': [4, 1, 3, 1, 9] + }) + metadata = { + 'columns': { + 'ID_1': {'sdtype': 'numerical'}, + 'col1': {'sdtype': 'numerical', 'pii': True}, + 'col2': {'sdtype': 'numerical'}, + 'ID_2': {'sdtype': 'categorical'}, + 'col3': {'sdtype': 'numerical'}, + 'ID_3': {'sdtype': 'id'}, + 'blob': {'sdtype': 'text'}, + 'col4': {'sdtype': 'numeric', 'pii': False} + }, + 'primary_key': {'ID_1', 'ID_2'} + } + + expected_real_dataframe = pd.DataFrame({ + 'col2': [1.0, 2.0, 3.0, 4.0, 5.0], + 'col3': [5, 6, 7, 8, 9], + 'col4': [1, 3, 9, 2, 1] + }) + expected_synthetic_dataframe = pd.DataFrame({ + 'col2': [11.0, 22.0, 33.0, 44.0, 55.0], + 'col3': [55, 66, 77, 88, 99], + 'col4': [4, 1, 3, 1, 9] + }) + + expected_return_real = DataFrameMatcher(expected_real_dataframe) + expected_return_synthetic = DataFrameMatcher(expected_synthetic_dataframe) + fit_transform_mock.return_value = expected_real_dataframe + transform_mock.return_value = expected_synthetic_dataframe + + # Run + LogisticDetection().compute(real_data, synthetic_data, metadata) + + # Assert + + # check that ``fit_transform`` and ``transform`` received the good argument. + call_1 = pd.DataFrame(fit_transform_mock.call_args_list[0][0][0]) + call_2 = pd.DataFrame(transform_mock.call_args_list[0][0][0]) + + transform_mock.assert_called_with(expected_return_synthetic) + assert expected_return_real == call_1 + assert expected_return_synthetic == call_2 From 7c46515b463b49f5b3b9980c966f4360062c6ac3 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Mon, 20 Nov 2023 13:31:52 -0600 Subject: [PATCH 2/6] Clean up logic --- sdmetrics/single_table/detection/base.py | 41 ++++++++++++------------ 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/sdmetrics/single_table/detection/base.py b/sdmetrics/single_table/detection/base.py index cb155a94..6a37694f 100644 --- a/sdmetrics/single_table/detection/base.py +++ b/sdmetrics/single_table/detection/base.py @@ -68,28 +68,29 @@ def compute(cls, real_data, synthetic_data, metadata=None): real_data, synthetic_data, metadata = cls._validate_inputs( real_data, synthetic_data, metadata) - if metadata is not None and 'primary_key' in metadata: - transformed_real_data = real_data.drop(metadata['primary_key'], axis=1) - transformed_synthetic_data = synthetic_data.drop(metadata['primary_key'], axis=1) + transformed_real_data = real_data + transformed_synthetic_data = synthetic_data - else: - transformed_real_data = real_data - transformed_synthetic_data = synthetic_data - - if metadata is not None and 'columns' in metadata: + if metadata is not None: drop_columns = [] - for column in metadata['columns']: - if 'primary_key' in metadata and column == metadata['primary_key']: - continue - for field in metadata['columns'][column]: - if field == 'sdtype': - sdtype = metadata['columns'][column][field] - if sdtype == 'id' or sdtype == 'text': - drop_columns.append(column) - if field == 'pii': - if metadata['columns'][column][field]: - drop_columns.append(column) - if len(drop_columns) > 0: + if 'columns' in metadata: + for column in metadata['columns']: + if ('primary_key' in metadata and + (column == metadata['primary_key'] or + column in metadata['primary_key'])): + drop_columns.append(column) + + for field in metadata['columns'][column]: + if field == 'sdtype': + sdtype = metadata['columns'][column][field] + if sdtype == 'id' or sdtype == 'text': + drop_columns.append(column) + + if field == 'pii': + if metadata['columns'][column][field]: + drop_columns.append(column) + + if drop_columns: transformed_real_data = transformed_real_data.drop(drop_columns, axis=1) transformed_synthetic_data = transformed_synthetic_data.drop(drop_columns, axis=1) From 150f3a51cf458ba35a5d38fef45eb00bb285f912 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Tue, 21 Nov 2023 12:11:12 -0600 Subject: [PATCH 3/6] Update logic --- sdmetrics/single_table/detection/base.py | 2 +- tests/unit/single_table/detection/test_detection.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdmetrics/single_table/detection/base.py b/sdmetrics/single_table/detection/base.py index 6a37694f..0b5497b7 100644 --- a/sdmetrics/single_table/detection/base.py +++ b/sdmetrics/single_table/detection/base.py @@ -83,7 +83,7 @@ def compute(cls, real_data, synthetic_data, metadata=None): for field in metadata['columns'][column]: if field == 'sdtype': sdtype = metadata['columns'][column][field] - if sdtype == 'id' or sdtype == 'text': + if sdtype not in ['numerical', 'datetime', 'categorical']: drop_columns.append(column) if field == 'pii': diff --git a/tests/unit/single_table/detection/test_detection.py b/tests/unit/single_table/detection/test_detection.py index 4dcaf3c9..32f03794 100644 --- a/tests/unit/single_table/detection/test_detection.py +++ b/tests/unit/single_table/detection/test_detection.py @@ -105,7 +105,7 @@ def test_ignore_keys_detection_metrics(self, fit_transform_mock, transform_mock) 'col3': {'sdtype': 'numerical'}, 'ID_3': {'sdtype': 'id'}, 'blob': {'sdtype': 'text'}, - 'col4': {'sdtype': 'numeric', 'pii': False} + 'col4': {'sdtype': 'numerical', 'pii': False} }, 'primary_key': {'ID_1', 'ID_2'} } From 33ee1c3b7d92b14bbdc531728d4b7b32529bc990 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Tue, 21 Nov 2023 18:01:39 -0600 Subject: [PATCH 4/6] Add alternate keys to dropped columns --- sdmetrics/single_table/detection/base.py | 3 ++- tests/unit/single_table/detection/test_detection.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sdmetrics/single_table/detection/base.py b/sdmetrics/single_table/detection/base.py index 0b5497b7..cb9b5541 100644 --- a/sdmetrics/single_table/detection/base.py +++ b/sdmetrics/single_table/detection/base.py @@ -9,7 +9,7 @@ from sdmetrics.errors import IncomputableMetricError from sdmetrics.goal import Goal from sdmetrics.single_table.base import SingleTableMetric -from sdmetrics.utils import HyperTransformer +from sdmetrics.utils import HyperTransformer, get_alternate_keys LOGGER = logging.getLogger(__name__) @@ -73,6 +73,7 @@ def compute(cls, real_data, synthetic_data, metadata=None): if metadata is not None: drop_columns = [] + drop_columns.extend(get_alternate_keys(metadata)) if 'columns' in metadata: for column in metadata['columns']: if ('primary_key' in metadata and diff --git a/tests/unit/single_table/detection/test_detection.py b/tests/unit/single_table/detection/test_detection.py index 32f03794..2e04b700 100644 --- a/tests/unit/single_table/detection/test_detection.py +++ b/tests/unit/single_table/detection/test_detection.py @@ -84,7 +84,8 @@ def test_ignore_keys_detection_metrics(self, fit_transform_mock, transform_mock) 'col3': [5, 6, 7, 8, 9], 'ID_3': ['a', 'b', 'c', 'd', 'e'], 'blob': ['Hello world!', 'Hello world!', 'This is SDV', 'This is SDV', 'Hello world!'], - 'col4': [1, 3, 9, 2, 1] + 'col4': [1, 3, 9, 2, 1], + 'col5': [10, 20, 30, 40, 50] }) synthetic_data = pd.DataFrame({ 'ID_1': [1, 3, 4, 2, 2], @@ -94,7 +95,8 @@ def test_ignore_keys_detection_metrics(self, fit_transform_mock, transform_mock) 'col3': [55, 66, 77, 88, 99], 'ID_3': ['a', 'b', 'e', 'd', 'c'], 'blob': ['Hello world!', 'Hello world!', 'This is SDV', 'This is SDV', 'Hello world!'], - 'col4': [4, 1, 3, 1, 9] + 'col4': [4, 1, 3, 1, 9], + 'col5': [10, 20, 30, 40, 50] }) metadata = { 'columns': { @@ -105,9 +107,11 @@ def test_ignore_keys_detection_metrics(self, fit_transform_mock, transform_mock) 'col3': {'sdtype': 'numerical'}, 'ID_3': {'sdtype': 'id'}, 'blob': {'sdtype': 'text'}, - 'col4': {'sdtype': 'numerical', 'pii': False} + 'col4': {'sdtype': 'numerical', 'pii': False}, + 'col5': {'sdtype': 'numerical'} }, - 'primary_key': {'ID_1', 'ID_2'} + 'primary_key': {'ID_1', 'ID_2'}, + 'alternate_keys': ['col5'] } expected_real_dataframe = pd.DataFrame({ From 3dd9750897ab53d110f4505a0ee4cdaeeccdd3ed Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 22 Nov 2023 10:07:34 -0600 Subject: [PATCH 5/6] Move to private function --- sdmetrics/single_table/detection/base.py | 59 +++++++++++++----------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/sdmetrics/single_table/detection/base.py b/sdmetrics/single_table/detection/base.py index cb9b5541..da8e0eb8 100644 --- a/sdmetrics/single_table/detection/base.py +++ b/sdmetrics/single_table/detection/base.py @@ -43,6 +43,37 @@ def _fit_predict(X_train, y_train, X_test): """Fit a classifier and then use it to predict.""" raise NotImplementedError() + @staticmethod + def _drop_non_compute_columns(real_data, synthetic_data, metadata): + """Drop all columns that cannot be statistically modeled.""" + transformed_real_data = real_data + transformed_synthetic_data = synthetic_data + + if metadata is not None: + drop_columns = [] + drop_columns.extend(get_alternate_keys(metadata)) + if 'columns' in metadata: + for column in metadata['columns']: + if ('primary_key' in metadata and + (column == metadata['primary_key'] or + column in metadata['primary_key'])): + drop_columns.append(column) + + for field in metadata['columns'][column]: + if field == 'sdtype': + sdtype = metadata['columns'][column][field] + if sdtype not in ['numerical', 'datetime', 'categorical']: + drop_columns.append(column) + + if field == 'pii': + if metadata['columns'][column][field]: + drop_columns.append(column) + + if drop_columns: + transformed_real_data = real_data.drop(drop_columns, axis=1) + transformed_synthetic_data = synthetic_data.drop(drop_columns, axis=1) + return transformed_real_data, transformed_synthetic_data + @classmethod def compute(cls, real_data, synthetic_data, metadata=None): """Compute this metric. @@ -68,32 +99,8 @@ def compute(cls, real_data, synthetic_data, metadata=None): real_data, synthetic_data, metadata = cls._validate_inputs( real_data, synthetic_data, metadata) - transformed_real_data = real_data - transformed_synthetic_data = synthetic_data - - if metadata is not None: - drop_columns = [] - drop_columns.extend(get_alternate_keys(metadata)) - if 'columns' in metadata: - for column in metadata['columns']: - if ('primary_key' in metadata and - (column == metadata['primary_key'] or - column in metadata['primary_key'])): - drop_columns.append(column) - - for field in metadata['columns'][column]: - if field == 'sdtype': - sdtype = metadata['columns'][column][field] - if sdtype not in ['numerical', 'datetime', 'categorical']: - drop_columns.append(column) - - if field == 'pii': - if metadata['columns'][column][field]: - drop_columns.append(column) - - if drop_columns: - transformed_real_data = transformed_real_data.drop(drop_columns, axis=1) - transformed_synthetic_data = transformed_synthetic_data.drop(drop_columns, axis=1) + transformed_real_data, transformed_synthetic_data = cls._drop_non_compute_columns( + real_data, synthetic_data, metadata) ht = HyperTransformer() transformed_real_data = ht.fit_transform(transformed_real_data).to_numpy() From 0b4b42c6cb4a43d86d82a3e2912f3c64ff2fbe7e Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 22 Nov 2023 10:46:49 -0600 Subject: [PATCH 6/6] Use get to avoid nesting ifs --- sdmetrics/single_table/detection/base.py | 27 ++++++++++-------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/sdmetrics/single_table/detection/base.py b/sdmetrics/single_table/detection/base.py index da8e0eb8..ac4d9506 100644 --- a/sdmetrics/single_table/detection/base.py +++ b/sdmetrics/single_table/detection/base.py @@ -52,22 +52,17 @@ def _drop_non_compute_columns(real_data, synthetic_data, metadata): if metadata is not None: drop_columns = [] drop_columns.extend(get_alternate_keys(metadata)) - if 'columns' in metadata: - for column in metadata['columns']: - if ('primary_key' in metadata and - (column == metadata['primary_key'] or - column in metadata['primary_key'])): - drop_columns.append(column) - - for field in metadata['columns'][column]: - if field == 'sdtype': - sdtype = metadata['columns'][column][field] - if sdtype not in ['numerical', 'datetime', 'categorical']: - drop_columns.append(column) - - if field == 'pii': - if metadata['columns'][column][field]: - drop_columns.append(column) + for column in metadata.get('columns', []): + if ('primary_key' in metadata and + (column == metadata['primary_key'] or + column in metadata['primary_key'])): + drop_columns.append(column) + + column_info = metadata['columns'].get(column, {}) + sdtype = column_info.get('sdtype') + pii = column_info.get('pii') + if sdtype not in ['numerical', 'datetime', 'categorical'] or pii: + drop_columns.append(column) if drop_columns: transformed_real_data = real_data.drop(drop_columns, axis=1)