Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions sdmetrics/single_table/detection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ def compute(cls, real_data, synthetic_data, metadata=None):
transformed_real_data = real_data
transformed_synthetic_data = synthetic_data

if metadata is not None and 'columns' in metadata:
drop_columns = []
for column in metadata['columns']:
if 'primary_key' in metadata and column == metadata['primary_key']:
continue
for field in metadata['columns'][column]:
if field == 'sdtype':
sdtype = metadata['columns'][column][field]
if sdtype == 'id' or sdtype == 'text':
drop_columns.append(column)
if field == 'pii':
if metadata['columns'][column][field]:
drop_columns.append(column)
if len(drop_columns) > 0:
transformed_real_data = transformed_real_data.drop(drop_columns, axis=1)
transformed_synthetic_data = transformed_synthetic_data.drop(drop_columns, axis=1)

ht = HyperTransformer()
transformed_real_data = ht.fit_transform(transformed_real_data).to_numpy()
transformed_synthetic_data = ht.transform(transformed_synthetic_data).to_numpy()
Expand Down
73 changes: 73 additions & 0 deletions tests/unit/single_table/detection/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,76 @@ def test_primary_key_detection_metrics(self, fit_transform_mock, transform_mock)
transform_mock.assert_called_with(expected_return_synthetic)
assert expected_return_real == call_1
assert expected_return_synthetic == call_2

@patch('sdmetrics.utils.HyperTransformer.transform')
@patch('sdmetrics.utils.HyperTransformer.fit_transform')
def test_ignore_keys_detection_metrics(self, fit_transform_mock, transform_mock):
"""This test checks that ``primary_key`` columns of dataset are ignored.

Ensure that ``primary_keys`` are ignored for Detection metrics expect that the match
is made correctly.
"""

# Setup
real_data = pd.DataFrame({
'ID_1': [1, 2, 1, 3, 4],
'col1': [43.0, 47.5, 34.2, 30.3, 39.1],
'col2': [1.0, 2.0, 3.0, 4.0, 5.0],
'ID_2': ['aa', 'bb', 'cc', 'dd', 'bb'],
'col3': [5, 6, 7, 8, 9],
'ID_3': ['a', 'b', 'c', 'd', 'e'],
'blob': ['Hello world!', 'Hello world!', 'This is SDV', 'This is SDV', 'Hello world!'],
'col4': [1, 3, 9, 2, 1]
})
synthetic_data = pd.DataFrame({
'ID_1': [1, 3, 4, 2, 2],
'col1': [23.0, 47.1, 44.9, 31.3, 9.7],
'col2': [11.0, 22.0, 33.0, 44.0, 55.0],
'ID_2': ['aa', 'bb', 'cc', 'dd', 'ee'],
'col3': [55, 66, 77, 88, 99],
'ID_3': ['a', 'b', 'e', 'd', 'c'],
'blob': ['Hello world!', 'Hello world!', 'This is SDV', 'This is SDV', 'Hello world!'],
'col4': [4, 1, 3, 1, 9]
})
metadata = {
'columns': {
'ID_1': {'sdtype': 'numerical'},
'col1': {'sdtype': 'numerical', 'pii': True},
'col2': {'sdtype': 'numerical'},
'ID_2': {'sdtype': 'categorical'},
'col3': {'sdtype': 'numerical'},
'ID_3': {'sdtype': 'id'},
'blob': {'sdtype': 'text'},
'col4': {'sdtype': 'numeric', 'pii': False}
},
'primary_key': {'ID_1', 'ID_2'}
}

expected_real_dataframe = pd.DataFrame({
'col2': [1.0, 2.0, 3.0, 4.0, 5.0],
'col3': [5, 6, 7, 8, 9],
'col4': [1, 3, 9, 2, 1]
})
expected_synthetic_dataframe = pd.DataFrame({
'col2': [11.0, 22.0, 33.0, 44.0, 55.0],
'col3': [55, 66, 77, 88, 99],
'col4': [4, 1, 3, 1, 9]
})

expected_return_real = DataFrameMatcher(expected_real_dataframe)
expected_return_synthetic = DataFrameMatcher(expected_synthetic_dataframe)
fit_transform_mock.return_value = expected_real_dataframe
transform_mock.return_value = expected_synthetic_dataframe

# Run
LogisticDetection().compute(real_data, synthetic_data, metadata)

# Assert

# check that ``fit_transform`` and ``transform`` received the good argument.
call_1 = pd.DataFrame(fit_transform_mock.call_args_list[0][0][0])
call_2 = pd.DataFrame(transform_mock.call_args_list[0][0][0])

transform_mock.assert_called_with(expected_return_synthetic)
assert expected_return_real == call_1
assert expected_return_synthetic == call_2