Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions sdmetrics/single_table/detection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sdmetrics.errors import IncomputableMetricError
from sdmetrics.goal import Goal
from sdmetrics.single_table.base import SingleTableMetric
from sdmetrics.utils import HyperTransformer
from sdmetrics.utils import HyperTransformer, get_alternate_keys

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,6 +43,32 @@ def _fit_predict(X_train, y_train, X_test):
"""Fit a classifier and then use it to predict."""
raise NotImplementedError()

@staticmethod
def _drop_non_compute_columns(real_data, synthetic_data, metadata):
    """Drop all columns that cannot be statistically modeled.

    A column is dropped when it is returned by ``get_alternate_keys``, when it
    is (part of) the ``primary_key``, when its metadata sets a truthy ``pii``
    flag, or when its ``sdtype`` is not one of ``numerical``, ``datetime`` or
    ``categorical``.

    Args:
        real_data (pandas.DataFrame):
            The table of real data.
        synthetic_data (pandas.DataFrame):
            The table of synthetic data.
        metadata (dict or None):
            Table metadata. If ``None``, nothing is dropped.

    Returns:
        tuple:
            The real and the synthetic tables without the dropped columns.
            The input objects are returned unchanged when there is nothing
            to drop.
    """
    if metadata is None:
        return real_data, synthetic_data

    # Normalize the primary key to a set of column names. Testing
    # ``column in primary_key`` directly against a string performs substring
    # matching (``'id' in 'user_id'``) and would drop unrelated columns, and
    # it raises ``TypeError`` for ``None`` or other non-iterable keys.
    primary_key = metadata.get('primary_key')
    if primary_key is None:
        primary_key_columns = set()
    elif isinstance(primary_key, (list, tuple, set, frozenset)):
        primary_key_columns = set(primary_key)
    else:
        primary_key_columns = {primary_key}

    modelable_sdtypes = ('numerical', 'datetime', 'categorical')
    drop_columns = list(get_alternate_keys(metadata))
    for column, column_info in metadata.get('columns', {}).items():
        is_key = column in primary_key_columns
        not_modelable = (
            column_info.get('sdtype') not in modelable_sdtypes or column_info.get('pii')
        )
        # A column may qualify for several reasons; list it only once.
        if (is_key or not_modelable) and column not in drop_columns:
            drop_columns.append(column)

    if not drop_columns:
        return real_data, synthetic_data

    return real_data.drop(drop_columns, axis=1), synthetic_data.drop(drop_columns, axis=1)

@classmethod
def compute(cls, real_data, synthetic_data, metadata=None):
"""Compute this metric.
Expand All @@ -68,13 +94,8 @@ def compute(cls, real_data, synthetic_data, metadata=None):
real_data, synthetic_data, metadata = cls._validate_inputs(
real_data, synthetic_data, metadata)

if metadata is not None and 'primary_key' in metadata:
transformed_real_data = real_data.drop(metadata['primary_key'], axis=1)
transformed_synthetic_data = synthetic_data.drop(metadata['primary_key'], axis=1)

else:
transformed_real_data = real_data
transformed_synthetic_data = synthetic_data
transformed_real_data, transformed_synthetic_data = cls._drop_non_compute_columns(
real_data, synthetic_data, metadata)

ht = HyperTransformer()
transformed_real_data = ht.fit_transform(transformed_real_data).to_numpy()
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/single_table/detection/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,80 @@ def test_primary_key_detection_metrics(self, fit_transform_mock, transform_mock)
transform_mock.assert_called_with(expected_return_synthetic)
assert expected_return_real == call_1
assert expected_return_synthetic == call_2

@patch('sdmetrics.utils.HyperTransformer.transform')
@patch('sdmetrics.utils.HyperTransformer.fit_transform')
def test_ignore_keys_detection_metrics(self, fit_transform_mock, transform_mock):
    """Test that detection metrics ignore the columns that cannot be modeled.

    The metadata declares a composite ``primary_key`` (``ID_1``, ``ID_2``), an
    alternate key (``col5``), a ``pii`` column (``col1``) and two columns whose
    sdtype is ``id`` or ``text`` (``ID_3``, ``blob``). Only ``col2``, ``col3``
    and ``col4`` are expected to reach the ``HyperTransformer``.
    """
    # Setup
    blob_values = [
        'Hello world!', 'Hello world!', 'This is SDV', 'This is SDV', 'Hello world!'
    ]
    real_table = pd.DataFrame({
        'ID_1': [1, 2, 1, 3, 4],
        'col1': [43.0, 47.5, 34.2, 30.3, 39.1],
        'col2': [1.0, 2.0, 3.0, 4.0, 5.0],
        'ID_2': ['aa', 'bb', 'cc', 'dd', 'bb'],
        'col3': [5, 6, 7, 8, 9],
        'ID_3': ['a', 'b', 'c', 'd', 'e'],
        'blob': list(blob_values),
        'col4': [1, 3, 9, 2, 1],
        'col5': [10, 20, 30, 40, 50],
    })
    synthetic_table = pd.DataFrame({
        'ID_1': [1, 3, 4, 2, 2],
        'col1': [23.0, 47.1, 44.9, 31.3, 9.7],
        'col2': [11.0, 22.0, 33.0, 44.0, 55.0],
        'ID_2': ['aa', 'bb', 'cc', 'dd', 'ee'],
        'col3': [55, 66, 77, 88, 99],
        'ID_3': ['a', 'b', 'e', 'd', 'c'],
        'blob': list(blob_values),
        'col4': [4, 1, 3, 1, 9],
        'col5': [10, 20, 30, 40, 50],
    })
    table_metadata = {
        'columns': {
            'ID_1': {'sdtype': 'numerical'},
            'col1': {'sdtype': 'numerical', 'pii': True},
            'col2': {'sdtype': 'numerical'},
            'ID_2': {'sdtype': 'categorical'},
            'col3': {'sdtype': 'numerical'},
            'ID_3': {'sdtype': 'id'},
            'blob': {'sdtype': 'text'},
            'col4': {'sdtype': 'numerical', 'pii': False},
            'col5': {'sdtype': 'numerical'},
        },
        'primary_key': {'ID_1', 'ID_2'},
        'alternate_keys': ['col5'],
    }

    # Only the plain numerical, non-key, non-pii columns should survive.
    kept_real = pd.DataFrame({
        'col2': [1.0, 2.0, 3.0, 4.0, 5.0],
        'col3': [5, 6, 7, 8, 9],
        'col4': [1, 3, 9, 2, 1],
    })
    kept_synthetic = pd.DataFrame({
        'col2': [11.0, 22.0, 33.0, 44.0, 55.0],
        'col3': [55, 66, 77, 88, 99],
        'col4': [4, 1, 3, 1, 9],
    })
    fit_transform_mock.return_value = kept_real
    transform_mock.return_value = kept_synthetic
    real_matcher = DataFrameMatcher(kept_real)
    synthetic_matcher = DataFrameMatcher(kept_synthetic)

    # Run
    LogisticDetection().compute(real_table, synthetic_table, table_metadata)

    # Assert
    # Check that ``fit_transform`` and ``transform`` received the expected argument.
    first_fit_args = fit_transform_mock.call_args_list[0][0]
    first_transform_args = transform_mock.call_args_list[0][0]
    fitted_frame = pd.DataFrame(first_fit_args[0])
    transformed_frame = pd.DataFrame(first_transform_args[0])

    transform_mock.assert_called_with(synthetic_matcher)
    assert real_matcher == fitted_frame
    assert synthetic_matcher == transformed_frame