From dd8a3ad873e1c80e9f2692671fd96fab0c64db7f Mon Sep 17 00:00:00 2001 From: Adrian Galvan Date: Tue, 10 Dec 2024 17:05:19 -0800 Subject: [PATCH] Setting requires_primary_keys for select connectors + updating tests --- .../api/service/connectors/base_connector.py | 8 +- .../service/connectors/bigquery_connector.py | 1 + .../service/connectors/postgres_connector.py | 5 + .../connectors/query_configs/query_config.py | 4 +- .../api/service/connectors/saas_connector.py | 2 + .../service/connectors/scylla_connector.py | 5 + .../service/connectors/scylla_query_config.py | 11 +- .../postgres_example_test_dataset.yml | 18 ++ .../service/connectors/test_query_config.py | 205 +++++++++++++++++- .../connectors/test_snowflake_query_config.py | 4 +- tests/ops/task/test_create_request_tasks.py | 2 +- tests/ops/test_helpers/dataset_utils.py | 30 ++- 12 files changed, 281 insertions(+), 14 deletions(-) diff --git a/src/fides/api/service/connectors/base_connector.py b/src/fides/api/service/connectors/base_connector.py index 4bf46e5eca..e1f735df1c 100644 --- a/src/fides/api/service/connectors/base_connector.py +++ b/src/fides/api/service/connectors/base_connector.py @@ -135,5 +135,11 @@ def execute_standalone_retrieval_query( @property def requires_primary_keys(self) -> bool: - """Indicates if datasets linked to this connector require primary keys for erasures. Defaults to True.""" + """ + Indicates if datasets linked to this connector require primary keys for erasures. + Defaults to True. + """ + + # Defaulting to true for now so we can keep the default behavior and + # incrementally determine the need for primary keys across all connectors return True diff --git a/src/fides/api/service/connectors/bigquery_connector.py b/src/fides/api/service/connectors/bigquery_connector.py index 4c52b3b3f6..ae6fe4b909 100644 --- a/src/fides/api/service/connectors/bigquery_connector.py +++ b/src/fides/api/service/connectors/bigquery_connector.py @@ -35,6 +35,7 @@ class BigQueryConnector(SQLConnector): @property def requires_primary_keys(self) -> bool: + """BigQuery does not have the concept of primary keys so they're not required for erasures.""" return False # Overrides BaseConnector.build_uri diff --git a/src/fides/api/service/connectors/postgres_connector.py b/src/fides/api/service/connectors/postgres_connector.py index 5354d4ec13..2abafc01c8 100644 --- a/src/fides/api/service/connectors/postgres_connector.py +++ b/src/fides/api/service/connectors/postgres_connector.py @@ -19,6 +19,11 @@ class PostgreSQLConnector(SQLConnector): secrets_schema = PostgreSQLSchema + @property + def requires_primary_keys(self) -> bool: + """Postgres allows arbitrary columns in the WHERE clause for updates so primary keys are not required.""" + return False + def build_uri(self) -> str: """Build URI of format postgresql://[user[:password]@][netloc][:port][/dbname]""" config = self.secrets_schema(**self.configuration.secrets or {}) diff --git a/src/fides/api/service/connectors/query_configs/query_config.py b/src/fides/api/service/connectors/query_configs/query_config.py index c54eecff85..9f5ddb0251 100644 --- a/src/fides/api/service/connectors/query_configs/query_config.py +++ b/src/fides/api/service/connectors/query_configs/query_config.py @@ -430,7 +430,7 @@ def get_update_stmt( def get_update_clauses( self, update_value_map: Dict[str, Any], - non_empty_reference_fields: Dict[str, Field], + where_clause_fields: Dict[str, Field], ) -> List[str]: """Returns a list of update clauses for the update statement.""" @@ -567,7 +567,7 @@ def format_key_map_for_update_stmt(self, param_map: Dict[str, Any]) -> List[str] def get_update_clauses( self, update_value_map: Dict[str, Any], - non_empty_reference_fields: Dict[str, Field], + where_clause_fields: Dict[str, Field], ) -> List[str]: """Returns a list of update clauses for the update statement.""" return self.format_key_map_for_update_stmt(update_value_map) diff --git a/src/fides/api/service/connectors/saas_connector.py b/src/fides/api/service/connectors/saas_connector.py index 40a4d8a7eb..b1101467bf 100644 --- a/src/fides/api/service/connectors/saas_connector.py +++ b/src/fides/api/service/connectors/saas_connector.py @@ -72,7 +72,9 @@ class SaaSConnector(BaseConnector[AuthenticatedClient], Contextualizable): """A connector type to integrate with third-party SaaS APIs""" + @property def requires_primary_keys(self) -> bool: + """SaaS connectors work with HTTP requests, so the database concept of primary keys does not apply.""" return False def get_log_context(self) -> Dict[LoggerContextKeys, Any]: diff --git a/src/fides/api/service/connectors/scylla_connector.py b/src/fides/api/service/connectors/scylla_connector.py index 43a821930c..ff17674b88 100644 --- a/src/fides/api/service/connectors/scylla_connector.py +++ b/src/fides/api/service/connectors/scylla_connector.py @@ -28,6 +28,11 @@ class ScyllaConnectorMissingKeyspace(Exception): class ScyllaConnector(BaseConnector[Cluster]): """Scylla Connector""" + @property + def requires_primary_keys(self) -> bool: + """ScyllaDB requires primary keys for erasures.""" + return True + def build_uri(self) -> str: """ Builds URI - Not yet implemented diff --git a/src/fides/api/service/connectors/scylla_query_config.py b/src/fides/api/service/connectors/scylla_query_config.py index 5e93668459..1fa52d573d 100644 --- a/src/fides/api/service/connectors/scylla_query_config.py +++ b/src/fides/api/service/connectors/scylla_query_config.py @@ -77,14 +77,19 @@ def format_key_map_for_update_stmt(self, param_map: Dict[str, Any]) -> List[str] def get_update_clauses( self, update_value_map: Dict[str, Any], - non_empty_reference_fields: Dict[str, Field], + where_clause_fields: Dict[str, Field], ) -> List[str]: - """Returns a list of update clauses for the update statement.""" + """Returns a list of update clauses for the update statement. + + Omits primary key fields from updates since ScyllaDB prohibits + updating primary key fields. + """ + return self.format_key_map_for_update_stmt( { key: value for key, value in update_value_map.items() - if key not in non_empty_reference_fields + if key not in where_clause_fields } ) diff --git a/src/fides/data/sample_project/sample_resources/postgres_example_test_dataset.yml b/src/fides/data/sample_project/sample_resources/postgres_example_test_dataset.yml index 768c972d99..e519a75008 100644 --- a/src/fides/data/sample_project/sample_resources/postgres_example_test_dataset.yml +++ b/src/fides/data/sample_project/sample_resources/postgres_example_test_dataset.yml @@ -11,6 +11,8 @@ dataset: data_categories: [user.contact.address.street] - name: id data_categories: [system.operations] + fides_meta: + primary_key: True - name: state data_categories: [user.contact.address.state] - name: street @@ -36,6 +38,8 @@ dataset: data_type: string - name: id data_categories: [user.unique_id] + fides_meta: + primary_key: True - name: name data_categories: [user.name] fides_meta: @@ -58,6 +62,8 @@ dataset: data_type: string - name: id data_categories: [user.unique_id] + fides_meta: + primary_key: True - name: name data_categories: [user.name] fides_meta: @@ -74,6 +80,8 @@ dataset: direction: from - name: id data_categories: [system.operations] + fides_meta: + primary_key: True - name: time data_categories: [user.sensor] @@ -88,6 +96,8 @@ dataset: direction: from - name: id data_categories: [system.operations] + fides_meta: + primary_key: True - name: shipping_address_id data_categories: [system.operations] fides_meta: @@ -138,6 +148,8 @@ dataset: direction: from - name: id data_categories: [system.operations] + fides_meta: + primary_key: True - name: name data_categories: [user.financial] - name: preferred @@ -147,6 +159,8 @@ dataset: fields: - name: id data_categories: [system.operations] + fides_meta: + primary_key: True - name: name data_categories: [system.operations] - name: price @@ -161,6 +175,8 @@ dataset: data_type: string - name: id data_categories: [system.operations] + fides_meta: + primary_key: True - name: month data_categories: [system.operations] - name: name @@ -193,6 +209,8 @@ dataset: direction: from - name: id data_categories: [system.operations] + fides_meta: + primary_key: True - name: opened data_categories: [system.operations] diff --git a/tests/ops/service/connectors/test_query_config.py b/tests/ops/service/connectors/test_query_config.py index 2aa0871255..eac650d587 100644 --- a/tests/ops/service/connectors/test_query_config.py +++ b/tests/ops/service/connectors/test_query_config.py @@ -21,6 +21,7 @@ from fides.api.service.masking.strategy.masking_strategy_hash import HashMaskingStrategy from fides.api.util.data_category import DataCategory from tests.fixtures.application_fixtures import load_dataset +from tests.ops.test_helpers.dataset_utils import remove_primary_keys from ...task.traversal_data import integration_db_graph from ...test_helpers.cache_secrets_helper import cache_secret, clear_cache_secrets @@ -273,7 +274,7 @@ def test_generate_update_stmt_one_field( text_clause = config.generate_update_stmt(row, erasure_policy, privacy_request) assert ( text_clause.text - == """UPDATE customer SET name = :masked_name WHERE email = :email""" + == """UPDATE customer SET name = :masked_name WHERE id = :id""" ) assert text_clause._bindparams["masked_name"].key == "masked_name" assert ( @@ -341,7 +342,7 @@ def test_generate_update_stmt_length_truncation( ) assert ( text_clause.text - == """UPDATE customer SET name = :masked_name WHERE email = :email""" + == """UPDATE customer SET name = :masked_name WHERE id = :id""" ) assert text_clause._bindparams["masked_name"].key == "masked_name" # length truncation on name field @@ -391,7 +392,7 @@ def test_generate_update_stmt_multiple_fields_same_rule( text_clause = config.generate_update_stmt(row, erasure_policy, privacy_request) assert ( text_clause.text - == "UPDATE customer SET email = :masked_email, name = :masked_name WHERE email = :email" + == "UPDATE customer SET email = :masked_email, name = :masked_name WHERE id = :id" ) assert text_clause._bindparams["masked_name"].key == "masked_name" # since length is set to 40 in dataset.yml, we expect only first 40 chars of masked val @@ -407,7 +408,7 @@ def test_generate_update_stmt_multiple_fields_same_rule( ["customer-1@example.com"], request_id=privacy_request.id )[0] ) - assert text_clause._bindparams["email"].value == "customer-1@example.com" + assert text_clause._bindparams["id"].value == 1 clear_cache_secrets(privacy_request.id) def test_generate_update_stmts_from_multiple_rules( @@ -434,6 +435,201 @@ def test_generate_update_stmts_from_multiple_rules( row, erasure_policy_two_rules, privacy_request ) + assert ( + text_clause.text + == "UPDATE customer SET email = :masked_email, name = :masked_name WHERE id = :id" + ) + # Two different masking strategies used for name and email + assert ( + text_clause._bindparams["masked_name"].value is None + ) # Null masking strategy + assert ( + text_clause._bindparams["masked_email"].value == "*****" + ) # String rewrite masking strategy + + def test_generate_update_stmt_one_field_without_primary_keys( + self, erasure_policy, example_datasets, connection_config + ): + dataset = remove_primary_keys(Dataset(**example_datasets[0])) + graph = convert_dataset_to_graph(dataset, connection_config.key) + dataset_graph = DatasetGraph(*[graph]) + traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"}) + + customer_node = traversal.traversal_node_dict[ + CollectionAddress("postgres_example_test_dataset", "customer") + ].to_mock_execution_node() + + config = SQLQueryConfig(customer_node) + row = { + "email": "customer-1@example.com", + "name": "John Customer", + "address_id": 1, + "id": 1, + } + text_clause = config.generate_update_stmt(row, erasure_policy, privacy_request) + assert ( + text_clause.text + == """UPDATE customer SET name = :masked_name WHERE email = :email""" + ) + assert text_clause._bindparams["masked_name"].key == "masked_name" + assert ( + text_clause._bindparams["masked_name"].value is None + ) # Null masking strategy + + def test_generate_update_stmt_one_field_inbound_reference_without_primary_keys( + self, erasure_policy_address_city, example_datasets, connection_config + ): + dataset = remove_primary_keys(Dataset(**example_datasets[0])) + graph = convert_dataset_to_graph(dataset, connection_config.key) + dataset_graph = DatasetGraph(*[graph]) + traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"}) + + address_node = traversal.traversal_node_dict[ + CollectionAddress("postgres_example_test_dataset", "address") + ].to_mock_execution_node() + + config = SQLQueryConfig(address_node) + row = { + "id": 1, + "house": "123", + "street": "Main St", + "city": "San Francisco", + "state": "CA", + "zip": "94105", + } + text_clause = config.generate_update_stmt( + row, erasure_policy_address_city, privacy_request + ) + assert ( + text_clause.text + == """UPDATE address SET city = :masked_city WHERE id = :id""" + ) + assert text_clause._bindparams["masked_city"].key == "masked_city" + assert ( + text_clause._bindparams["masked_city"].value is None + ) # Null masking strategy + + def test_generate_update_stmt_length_truncation_without_primary_keys( + self, + erasure_policy_string_rewrite_long, + example_datasets, + connection_config, + ): + dataset = remove_primary_keys(Dataset(**example_datasets[0])) + graph = convert_dataset_to_graph(dataset, connection_config.key) + dataset_graph = DatasetGraph(*[graph]) + traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"}) + + customer_node = traversal.traversal_node_dict[ + CollectionAddress("postgres_example_test_dataset", "customer") + ].to_mock_execution_node() + + config = SQLQueryConfig(customer_node) + row = { + "email": "customer-1@example.com", + "name": "John Customer", + "address_id": 1, + "id": 1, + } + + text_clause = config.generate_update_stmt( + row, erasure_policy_string_rewrite_long, privacy_request + ) + assert ( + text_clause.text + == """UPDATE customer SET name = :masked_name WHERE email = :email""" + ) + assert text_clause._bindparams["masked_name"].key == "masked_name" + # length truncation on name field + assert ( + text_clause._bindparams["masked_name"].value + == "some rewrite value that is very long and" + ) + + def test_generate_update_stmt_multiple_fields_same_rule_without_primary_keys( + self, erasure_policy, example_datasets, connection_config + ): + dataset = remove_primary_keys(Dataset(**example_datasets[0])) + graph = convert_dataset_to_graph(dataset, connection_config.key) + dataset_graph = DatasetGraph(*[graph]) + traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"}) + + customer_node = traversal.traversal_node_dict[ + CollectionAddress("postgres_example_test_dataset", "customer") + ].to_mock_execution_node() + + config = SQLQueryConfig(customer_node) + row = { + "email": "customer-1@example.com", + "name": "John Customer", + "address_id": 1, + "id": 1, + } + + # Make target more broad + rule = erasure_policy.rules[0] + target = rule.targets[0] + target.data_category = DataCategory("user").value + + # Update rule masking strategy + rule.masking_strategy = { + "strategy": "hash", + "configuration": {"algorithm": "SHA-512"}, + } + # cache secrets for hash strategy + secret = MaskingSecretCache[str]( + secret="adobo", + masking_strategy=HashMaskingStrategy.name, + secret_type=SecretType.salt, + ) + cache_secret(secret, privacy_request.id) + + text_clause = config.generate_update_stmt(row, erasure_policy, privacy_request) + assert ( + text_clause.text + == "UPDATE customer SET email = :masked_email, name = :masked_name WHERE email = :email" + ) + assert text_clause._bindparams["masked_name"].key == "masked_name" + # since length is set to 40 in dataset.yml, we expect only first 40 chars of masked val + assert ( + text_clause._bindparams["masked_name"].value + == HashMaskingStrategy(HashMaskingConfiguration(algorithm="SHA-512")).mask( + ["John Customer"], request_id=privacy_request.id + )[0][0:40] + ) + assert ( + text_clause._bindparams["masked_email"].value + == HashMaskingStrategy(HashMaskingConfiguration(algorithm="SHA-512")).mask( + ["customer-1@example.com"], request_id=privacy_request.id + )[0] + ) + assert text_clause._bindparams["email"].value == "customer-1@example.com" + clear_cache_secrets(privacy_request.id) + + def test_generate_update_stmts_from_multiple_rules_without_primary_keys( + self, erasure_policy_two_rules, example_datasets, connection_config + ): + dataset = remove_primary_keys(Dataset(**example_datasets[0])) + graph = convert_dataset_to_graph(dataset, connection_config.key) + dataset_graph = DatasetGraph(*[graph]) + traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"}) + row = { + "email": "customer-1@example.com", + "name": "John Customer", + "address_id": 1, + "id": 1, + } + + customer_node = traversal.traversal_node_dict[ + CollectionAddress("postgres_example_test_dataset", "customer") + ].to_mock_execution_node() + + config = SQLQueryConfig(customer_node) + + text_clause = config.generate_update_stmt( + row, erasure_policy_two_rules, privacy_request + ) + assert ( text_clause.text == "UPDATE customer SET email = :masked_email, name = :masked_name WHERE email = :email" @@ -446,6 +642,7 @@ def test_generate_update_stmts_from_multiple_rules( text_clause._bindparams["masked_email"].value == "*****" ) # String rewrite masking strategy + class TestSQLLikeQueryConfig: def test_missing_namespace_meta_schema(self): diff --git a/tests/ops/service/connectors/test_snowflake_query_config.py b/tests/ops/service/connectors/test_snowflake_query_config.py index 5521a1a88a..4f4b23b8c4 100644 --- a/tests/ops/service/connectors/test_snowflake_query_config.py +++ b/tests/ops/service/connectors/test_snowflake_query_config.py @@ -150,7 +150,7 @@ def test_generate_update_stmt( ) assert ( str(update_stmt) - == 'UPDATE "address" SET "city" = :city, "house" = :house, "state" = :state, "street" = :street, "zip" = :zip WHERE "id" = :id' + == 'UPDATE "address" SET "city" = :masked_city, "house" = :masked_house, "state" = :masked_state, "street" = :masked_street, "zip" = :masked_zip WHERE "id" = :id' ) def test_generate_namespaced_update_stmt( @@ -191,5 +191,5 @@ def test_generate_namespaced_update_stmt( ) assert ( str(update_stmt) - == 'UPDATE "FIDESOPS_TEST"."TEST"."address" SET "city" = :city, "house" = :house, "state" = :state, "street" = :street, "zip" = :zip WHERE "id" = :id' + == 'UPDATE "FIDESOPS_TEST"."TEST"."address" SET "city" = :masked_city, "house" = :masked_house, "state" = :masked_state, "street" = :masked_street, "zip" = :masked_zip WHERE "id" = :id' ) diff --git a/tests/ops/task/test_create_request_tasks.py b/tests/ops/task/test_create_request_tasks.py index 290c2dc1be..ad118ee46c 100644 --- a/tests/ops/task/test_create_request_tasks.py +++ b/tests/ops/task/test_create_request_tasks.py @@ -927,7 +927,7 @@ def test_erase_after_saas_upstream_and_downstream_tasks( "is_array": False, "read_only": None, "references": [], - "primary_key": True, + "primary_key": False, "data_categories": ["system.operations"], "data_type_converter": "integer", "return_all_elements": None, diff --git a/tests/ops/test_helpers/dataset_utils.py b/tests/ops/test_helpers/dataset_utils.py index e60efb9892..d51e1f47ff 100644 --- a/tests/ops/test_helpers/dataset_utils.py +++ b/tests/ops/test_helpers/dataset_utils.py @@ -13,7 +13,11 @@ ) from fides.api.graph.data_type import DataType, get_data_type, to_data_type_string from fides.api.models.connectionconfig import ConnectionConfig -from fides.api.models.datasetconfig import DatasetConfig, convert_dataset_to_graph +from fides.api.models.datasetconfig import ( + DatasetConfig, + DatasetField, + convert_dataset_to_graph, +) from fides.api.util.collection_util import Row SAAS_DATASET_DIRECTORY = "data/saas/dataset/" @@ -231,3 +235,27 @@ def get_simple_fields(fields: Iterable[Field]) -> List[Dict[str, Any]]: object["fields"] = get_simple_fields(field.fields.values()) object_list.append(object) return object_list + + +def remove_primary_keys(dataset: Dataset) -> Dataset: + """Returns a copy of the dataset with primary key fields removed from fides_meta.""" + dataset_copy = dataset.model_copy(deep=True) + + for collection in dataset_copy.collections: + for field in collection.fields: + if field.fides_meta: + if field.fides_meta.primary_key: + field.fides_meta.primary_key = None + if field.fields: + _remove_nested_primary_keys(field.fields) + + return dataset_copy + + +def _remove_nested_primary_keys(fields: List[DatasetField]) -> None: + """Helper function to recursively remove primary keys from nested fields.""" + for field in fields: + if field.fides_meta and field.fides_meta.primary_key: + field.fides_meta.primary_key = None + if field.fields: + _remove_nested_primary_keys(field.fields)