diff --git a/.github/workflows/backend_checks.yml b/.github/workflows/backend_checks.yml index a9236ea23b..e1cc215070 100644 --- a/.github/workflows/backend_checks.yml +++ b/.github/workflows/backend_checks.yml @@ -244,6 +244,7 @@ jobs: Safe-Tests: needs: Check-Container-Startup strategy: + fail-fast: false matrix: python_version: ["3.9.18", "3.10.13"] test_selection: @@ -255,7 +256,6 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 45 - continue-on-error: true steps: - name: Download container uses: actions/download-artifact@v4 @@ -397,8 +397,8 @@ jobs: # Secrets to pull from 1Password BIGQUERY_DATASET: op://github-actions/bigquery/BIGQUERY_DATASET BIGQUERY_KEYFILE_CREDS: op://github-actions/bigquery/BIGQUERY_KEYFILE_CREDS - BIGQUERY_ENTERPRISE_DATASET: op://github-actions/bigquery-enterprise/BIGQUERY_DATASET - BIGQUERY_ENTERPRISE_KEYFILE_CREDS: op://github-actions/bigquery-enterprise/BIGQUERY_KEYFILE_CREDS + BIGQUERY_ENTERPRISE_DATASET: op://github-actions/bigquery-enterprise/BIGQUERY_ENTERPRISE_DATASET + BIGQUERY_ENTERPRISE_KEYFILE_CREDS: op://github-actions/bigquery-enterprise/BIGQUERY_ENTERPRISE_KEYFILE_CREDS DYNAMODB_ACCESS_KEY_ID: op://github-actions/dynamodb/DYNAMODB_ACCESS_KEY_ID DYNAMODB_ACCESS_KEY: op://github-actions/dynamodb/DYNAMODB_ACCESS_KEY DYNAMODB_ASSUME_ROLE_ARN: op://github-actions/dynamodb/DYNAMODB_ASSUME_ROLE_ARN diff --git a/clients/fides-js/src/components/Overlay.tsx b/clients/fides-js/src/components/Overlay.tsx index 8d403f5e83..24b3d57b1a 100644 --- a/clients/fides-js/src/components/Overlay.tsx +++ b/clients/fides-js/src/components/Overlay.tsx @@ -153,7 +153,7 @@ const Overlay: FunctionComponent = ({ }, [showBanner, setBannerIsOpen]); useEffect(() => { - if (!experience || options.modalLinkId === "") { + if (options.fidesEmbed || !experience || options.modalLinkId === "") { // If empty string is explicitly set, do not attempt to bind the modal link to the click handler. // developers using `Fides.showModal();` can use this to prevent polling for the modal link. 
return () => {}; @@ -211,7 +211,13 @@ const Overlay: FunctionComponent = ({ } window.Fides.showModal = defaultShowModal; }; - }, [options.modalLinkId, options.debug, handleOpenModal, experience]); + }, [ + options.fidesEmbed, + options.modalLinkId, + options.debug, + handleOpenModal, + experience, + ]); const handleManagePreferencesClick = (): void => { handleOpenModal(); diff --git a/clients/fides-js/src/components/fides.css b/clients/fides-js/src/components/fides.css index 995af6e740..3eb261516b 100644 --- a/clients/fides-js/src/components/fides.css +++ b/clients/fides-js/src/components/fides.css @@ -485,7 +485,7 @@ div#fides-consent-content .fides-modal-description { justify-content: center; } -.fides-modal-container .fides-button-group-brand { +.fides-modal-footer .fides-button-group-brand { min-height: var(--fides-overlay-modal-secondary-button-group-height); } @@ -1023,13 +1023,13 @@ div#fides-overlay-wrapper .fides-toggle .fides-toggle-display { position: relative; } -.fides-modal-container .fides-i18n-menu { +.fides-modal-footer .fides-i18n-menu { position: absolute; left: var(--fides-overlay-padding); bottom: var(--fides-overlay-padding); } -.fides-modal-container .fides-button-group-i18n { +.fides-modal-footer .fides-button-group-i18n { min-height: var(--fides-overlay-modal-secondary-button-group-height); } diff --git a/clients/sample-app/README.md b/clients/sample-app/README.md index 9142c14617..73d161ae94 100644 --- a/clients/sample-app/README.md +++ b/clients/sample-app/README.md @@ -31,6 +31,13 @@ This will automatically bring up a Docker Compose project to create a sample app Once running successfully, open http://localhost:3000 to see the Cookie House! +Note: If you are already running a database on port 5432 locally, you can override the default port by setting the `FIDES_SAMPLE_APP__DATABASE_PORT` environment variable and ALSO changing the **host** port number in the `docker-compose.yml` file. For example: + +```yaml +ports: + - "5433:5432" +``` + ## Pre-commit Before committing any changes, run the following: diff --git a/clients/sample-app/src/pages/embedded-consent.tsx b/clients/sample-app/src/pages/embedded-consent.tsx new file mode 100644 index 0000000000..3b4334b031 --- /dev/null +++ b/clients/sample-app/src/pages/embedded-consent.tsx @@ -0,0 +1,109 @@ +import { GetServerSideProps } from "next"; +import Head from "next/head"; +import { useRouter } from "next/router"; +import Script from "next/script"; + +interface Props { + gtmContainerId: string | null; + privacyCenterUrl: string; +} + +// Regex to ensure the provided GTM container ID appears valid (e.g. "GTM-ABCD123") +// NOTE: this also protects against XSS since this ID is added to a script template +const VALID_GTM_REGEX = /^[0-9a-zA-Z-]+$/; + +/** + * Pass the following server-side ENV variables to the page: + * - FIDES_SAMPLE_APP__GOOGLE_TAG_MANAGER_CONTAINER_ID: configure a GTM container, e.g. "GTM-ABCD123" + * - FIDES_SAMPLE_APP__PRIVACY_CENTER_URL: configure Privacy Center URL, e.g. 
"http://localhost:3001" + */ +export const getServerSideProps: GetServerSideProps = async () => { + // Check for a valid FIDES_SAMPLE_APP__GOOGLE_TAG_MANAGER_CONTAINER_ID + let gtmContainerId = null; + if ( + process.env.FIDES_SAMPLE_APP__GOOGLE_TAG_MANAGER_CONTAINER_ID?.match( + VALID_GTM_REGEX, + ) + ) { + gtmContainerId = + process.env.FIDES_SAMPLE_APP__GOOGLE_TAG_MANAGER_CONTAINER_ID; + } + + // Check for a valid FIDES_SAMPLE_APP__PRIVACY_CENTER_URL + const privacyCenterUrl = + process.env.FIDES_SAMPLE_APP__PRIVACY_CENTER_URL || "http://localhost:3001"; + + // Pass the server-side props to the page + return { props: { gtmContainerId, privacyCenterUrl } }; +}; + +const IndexPage = ({ gtmContainerId, privacyCenterUrl }: Props) => { + // Load the fides.js script from the Fides Privacy Center, assumed to be + // running at http://localhost:3001 + const fidesScriptTagUrl = new URL(`${privacyCenterUrl}/fides.js`); + const router = useRouter(); + // eslint-disable-next-line @typescript-eslint/naming-convention + const { geolocation, property_id } = router.query; + + // If `geolocation=` or `property_id` query params exists, pass those along to the fides.js fetch + if (geolocation && typeof geolocation === "string") { + fidesScriptTagUrl.searchParams.append("geolocation", geolocation); + } + if (typeof property_id === "string") { + fidesScriptTagUrl.searchParams.append("property_id", property_id); + } + + return ( + <> + + Cookie House + {/* Require FidesJS to "embed" it's UI onto the page, instead of as an overlay over the itself. (see https://ethyca.com/docs/dev-docs/js/reference/interfaces/FidesOptions#fides_embed) */} + + {/* Allow the embedded consent modal to fill the screen */} + + + {/** + Insert the fides.js script and run the GTM integration once ready + DEFER: using "beforeInteractive" here triggers a lint warning from NextJS + as it should only be used in the _document.tsx file. This still works and + ensures that fides.js fires earlier than other scripts, but isn't a best + practice. + */} + + ) : null} +
+ + ); +}; + +export default IndexPage; diff --git a/data/dataset/bigquery_enterprise_test_dataset.yml b/data/dataset/bigquery_enterprise_test_dataset.yml index 10504d63a5..59d27e68a2 100644 --- a/data/dataset/bigquery_enterprise_test_dataset.yml +++ b/data/dataset/bigquery_enterprise_test_dataset.yml @@ -31,7 +31,7 @@ dataset: references: null identity: null primary_key: true - data_type: null + data_type: integer length: null return_all_elements: null read_only: null @@ -103,7 +103,7 @@ dataset: references: null identity: null primary_key: true - data_type: null + data_type: integer length: null return_all_elements: null read_only: null @@ -119,18 +119,7 @@ dataset: description: null data_categories: - system.operations - fides_meta: - references: - - dataset: enterprise_dsr_testing - field: stackoverflow_posts.id - direction: from - identity: null - primary_key: null - data_type: null - length: null - return_all_elements: null - read_only: null - custom_request_field: null + fides_meta: null fields: null - name: revision_guid description: null @@ -147,7 +136,7 @@ dataset: - name: user_id description: null data_categories: - - user.contact + - system.operations fides_meta: references: - dataset: enterprise_dsr_testing @@ -216,7 +205,7 @@ dataset: references: null identity: null primary_key: true - data_type: null + data_type: integer length: null return_all_elements: null read_only: null @@ -260,7 +249,7 @@ dataset: - name: owner_display_name description: null data_categories: - - system.operations + - user.contact fides_meta: null fields: null - name: owner_user_id @@ -274,7 +263,7 @@ dataset: direction: from identity: null primary_key: null - data_type: null + data_type: integer length: null return_all_elements: null read_only: null diff --git a/setup.py b/setup.py index d6ee2d1237..624c3161d1 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ def optional_requirements( ## Package Setup ## ################### setup( - name="ethyca-fides", + name="ethyca_fides", version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), description="Open-source ecosystem for data privacy as code.", diff --git a/src/fides/api/service/connectors/__init__.py b/src/fides/api/service/connectors/__init__.py index d81498c005..3a96545587 100644 --- a/src/fides/api/service/connectors/__init__.py +++ b/src/fides/api/service/connectors/__init__.py @@ -9,6 +9,9 @@ from fides.api.models.connectionconfig import ConnectionConfig as ConnectionConfig from fides.api.models.connectionconfig import ConnectionType as ConnectionType from fides.api.service.connectors.base_connector import BaseConnector as BaseConnector +from fides.api.service.connectors.bigquery_connector import ( + BigQueryConnector as BigQueryConnector, +) from fides.api.service.connectors.consent_email_connector import ( GenericConsentEmailConnector, ) @@ -29,47 +32,46 @@ from fides.api.service.connectors.fides_connector import ( FidesConnector as FidesConnector, ) +from fides.api.service.connectors.google_cloud_mysql_connector import ( + GoogleCloudSQLMySQLConnector as GoogleCloudSQLMySQLConnector, +) +from fides.api.service.connectors.google_cloud_postgres_connector import ( + GoogleCloudSQLPostgresConnector as GoogleCloudSQLPostgresConnector, +) from fides.api.service.connectors.http_connector import HTTPSConnector as HTTPSConnector from fides.api.service.connectors.manual_webhook_connector import ( ManualWebhookConnector as ManualWebhookConnector, ) +from fides.api.service.connectors.mariadb_connector import ( + MariaDBConnector as 
MariaDBConnector, +) +from fides.api.service.connectors.microsoft_sql_server_connector import ( + MicrosoftSQLServerConnector as MicrosoftSQLServerConnector, +) from fides.api.service.connectors.mongodb_connector import ( MongoDBConnector as MongoDBConnector, ) +from fides.api.service.connectors.mysql_connector import ( + MySQLConnector as MySQLConnector, +) +from fides.api.service.connectors.postgres_connector import ( + PostgreSQLConnector as PostgreSQLConnector, +) from fides.api.service.connectors.rds_mysql_connector import ( RDSMySQLConnector as RDSMySQLConnector, ) from fides.api.service.connectors.rds_postgres_connector import ( RDSPostgresConnector as RDSPostgresConnector, ) +from fides.api.service.connectors.redshift_connector import ( + RedshiftConnector as RedshiftConnector, +) from fides.api.service.connectors.s3_connector import S3Connector from fides.api.service.connectors.saas_connector import SaaSConnector as SaaSConnector from fides.api.service.connectors.scylla_connector import ( ScyllaConnector as ScyllaConnector, ) -from fides.api.service.connectors.sql_connector import ( - BigQueryConnector as BigQueryConnector, -) -from fides.api.service.connectors.sql_connector import ( - GoogleCloudSQLMySQLConnector as GoogleCloudSQLMySQLConnector, -) -from fides.api.service.connectors.sql_connector import ( - GoogleCloudSQLPostgresConnector as GoogleCloudSQLPostgresConnector, -) -from fides.api.service.connectors.sql_connector import ( - MariaDBConnector as MariaDBConnector, -) -from fides.api.service.connectors.sql_connector import ( - MicrosoftSQLServerConnector as MicrosoftSQLServerConnector, -) -from fides.api.service.connectors.sql_connector import MySQLConnector as MySQLConnector -from fides.api.service.connectors.sql_connector import ( - PostgreSQLConnector as PostgreSQLConnector, -) -from fides.api.service.connectors.sql_connector import ( - RedshiftConnector as RedshiftConnector, -) -from fides.api.service.connectors.sql_connector import ( +from fides.api.service.connectors.snowflake_connector import ( SnowflakeConnector as SnowflakeConnector, ) from fides.api.service.connectors.timescale_connector import ( diff --git a/src/fides/api/service/connectors/base_connector.py b/src/fides/api/service/connectors/base_connector.py index 2cd365d00c..ca3439f523 100644 --- a/src/fides/api/service/connectors/base_connector.py +++ b/src/fides/api/service/connectors/base_connector.py @@ -8,7 +8,7 @@ from fides.api.models.connectionconfig import ConnectionConfig, ConnectionTestStatus from fides.api.models.policy import Policy from fides.api.models.privacy_request import PrivacyRequest, RequestTask -from fides.api.service.connectors.query_config import QueryConfig +from fides.api.service.connectors.query_configs.query_config import QueryConfig from fides.api.util.collection_util import Row from fides.config import CONFIG diff --git a/src/fides/api/service/connectors/bigquery_connector.py b/src/fides/api/service/connectors/bigquery_connector.py new file mode 100644 index 0000000000..8b51f90842 --- /dev/null +++ b/src/fides/api/service/connectors/bigquery_connector.py @@ -0,0 +1,158 @@ +from typing import List, Optional + +from loguru import logger +from sqlalchemy import text +from sqlalchemy.engine import ( # type: ignore + Connection, + Engine, + LegacyCursorResult, + create_engine, +) +from sqlalchemy.orm import Session +from sqlalchemy.sql import Executable # type: ignore +from sqlalchemy.sql.elements import TextClause + +from fides.api.common_exceptions import ConnectionException +from 
fides.api.graph.execution import ExecutionNode +from fides.api.models.connectionconfig import ConnectionTestStatus +from fides.api.models.policy import Policy +from fides.api.models.privacy_request import PrivacyRequest, RequestTask +from fides.api.schemas.connection_configuration.connection_secrets_bigquery import ( + BigQuerySchema, +) +from fides.api.service.connectors.query_configs.bigquery_query_config import ( + BigQueryQueryConfig, +) +from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig +from fides.api.service.connectors.sql_connector import SQLConnector +from fides.api.util.collection_util import Row + + +class BigQueryConnector(SQLConnector): + """Connector specific to Google BigQuery""" + + secrets_schema = BigQuerySchema + + # Overrides BaseConnector.build_uri + def build_uri(self) -> str: + """Build URI of format""" + config = self.secrets_schema(**self.configuration.secrets or {}) + dataset = f"/{config.dataset}" if config.dataset else "" + return f"bigquery://{config.keyfile_creds.project_id}{dataset}" # pylint: disable=no-member + + # Overrides SQLConnector.create_client + def create_client(self) -> Engine: + """ + Returns a SQLAlchemy Engine that can be used to interact with Google BigQuery. + + Overrides to pass in credentials_info + """ + secrets = self.configuration.secrets or {} + uri = secrets.get("url") or self.build_uri() + + keyfile_creds = secrets.get("keyfile_creds", {}) + credentials_info = dict(keyfile_creds) if keyfile_creds else {} + + return create_engine( + uri, + credentials_info=credentials_info, + hide_parameters=self.hide_parameters, + echo=not self.hide_parameters, + ) + + # Overrides SQLConnector.query_config + def query_config(self, node: ExecutionNode) -> BigQueryQueryConfig: + """Query wrapper corresponding to the input execution_node.""" + + db: Session = Session.object_session(self.configuration) + return BigQueryQueryConfig( + node, SQLConnector.get_namespace_meta(db, node.address.dataset) + ) + + def partitioned_retrieval( + self, + query_config: SQLQueryConfig, + connection: Connection, + stmt: TextClause, + ) -> List[Row]: + """ + Retrieve data against a partitioned table using the partitioning spec configured for this node to execute + multiple queries against the partitioned table. + + This is only supported by the BigQueryConnector currently. + + NOTE: when we deprecate `where_clause` partitioning in favor of a more proper partitioning DSL, + we should be sure to still support the existing `where_clause` partition definition on + any in-progress DSRs so that they can run through to completion. 
+ """ + if not isinstance(query_config, BigQueryQueryConfig): + raise TypeError( + f"Unexpected query config of type '{type(query_config)}' passed to BigQueryConnector's `partitioned_retrieval`" + ) + + partition_clauses = query_config.get_partition_clauses() + logger.info( + f"Executing {len(partition_clauses)} partition queries for node '{query_config.node.address}' in DSR execution" + ) + rows = [] + for partition_clause in partition_clauses: + logger.debug( + f"Executing partition query with partition clause '{partition_clause}'" + ) + existing_bind_params = stmt.compile().params + partitioned_stmt = text(f"{stmt} AND ({text(partition_clause)})").params( + existing_bind_params + ) + results = connection.execute(partitioned_stmt) + rows.extend(self.cursor_result_to_rows(results)) + return rows + + # Overrides SQLConnector.test_connection + def test_connection(self) -> Optional[ConnectionTestStatus]: + """ + Overrides SQLConnector.test_connection with a BigQuery-specific connection test. + + The connection is tested using the native python client for BigQuery, since that is what's used + by the detection and discovery workflows/codepaths. + TODO: migrate the rest of this class, used for DSR execution, to also make use of the native bigquery client. + """ + try: + bq_schema = BigQuerySchema(**self.configuration.secrets or {}) + client = bq_schema.get_client() + all_projects = [project for project in client.list_projects()] + if all_projects: + return ConnectionTestStatus.succeeded + logger.error("No Bigquery Projects found with the provided credentials.") + raise ConnectionException( + "No Bigquery Projects found with the provided credentials." + ) + except Exception as e: + logger.exception(f"Error testing connection to remote BigQuery {str(e)}") + raise ConnectionException(f"Connection error: {e}") + + def mask_data( + self, + node: ExecutionNode, + policy: Policy, + privacy_request: PrivacyRequest, + request_task: RequestTask, + rows: List[Row], + ) -> int: + """Execute a masking request. 
Returns the number of records updated or deleted""" + query_config = self.query_config(node) + update_or_delete_ct = 0 + client = self.client() + for row in rows: + update_or_delete_stmts: List[Executable] = ( + query_config.generate_masking_stmt( + node, row, policy, privacy_request, client + ) + ) + if update_or_delete_stmts: + with client.connect() as connection: + for update_or_delete_stmt in update_or_delete_stmts: + results: LegacyCursorResult = connection.execute( + update_or_delete_stmt + ) + update_or_delete_ct = update_or_delete_ct + results.rowcount + return update_or_delete_ct diff --git a/src/fides/api/service/connectors/dynamodb_connector.py b/src/fides/api/service/connectors/dynamodb_connector.py index 1bb99a1028..1c1c8b6e6f 100644 --- a/src/fides/api/service/connectors/dynamodb_connector.py +++ b/src/fides/api/service/connectors/dynamodb_connector.py @@ -15,7 +15,10 @@ DynamoDBSchema, ) from fides.api.service.connectors.base_connector import BaseConnector -from fides.api.service.connectors.query_config import DynamoDBQueryConfig, QueryConfig +from fides.api.service.connectors.query_configs.dynamodb_query_config import ( + DynamoDBQueryConfig, +) +from fides.api.service.connectors.query_configs.query_config import QueryConfig from fides.api.util.aws_util import get_aws_session from fides.api.util.collection_util import Row from fides.api.util.logger import Pii diff --git a/src/fides/api/service/connectors/fides_connector.py b/src/fides/api/service/connectors/fides_connector.py index 86812cb53e..32f82023f6 100644 --- a/src/fides/api/service/connectors/fides_connector.py +++ b/src/fides/api/service/connectors/fides_connector.py @@ -17,7 +17,7 @@ from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors.base_connector import BaseConnector from fides.api.service.connectors.fides.fides_client import FidesClient -from fides.api.service.connectors.query_config import QueryConfig +from fides.api.service.connectors.query_configs.query_config import QueryConfig from fides.api.util.collection_util import Row from fides.api.util.errors import FidesError diff --git a/src/fides/api/service/connectors/google_cloud_mysql_connector.py b/src/fides/api/service/connectors/google_cloud_mysql_connector.py new file mode 100644 index 0000000000..ffe354bbd5 --- /dev/null +++ b/src/fides/api/service/connectors/google_cloud_mysql_connector.py @@ -0,0 +1,56 @@ +from typing import List + +import pymysql +from google.cloud.sql.connector import Connector +from google.oauth2 import service_account +from sqlalchemy.engine import Engine, LegacyCursorResult, create_engine # type: ignore + +from fides.api.schemas.connection_configuration.connection_secrets_google_cloud_sql_mysql import ( + GoogleCloudSQLMySQLSchema, +) +from fides.api.service.connectors.sql_connector import SQLConnector +from fides.api.util.collection_util import Row +from fides.config import get_config + +CONFIG = get_config() + + +class GoogleCloudSQLMySQLConnector(SQLConnector): + """Connector specific to Google Cloud SQL for MySQL""" + + secrets_schema = GoogleCloudSQLMySQLSchema + + # Overrides SQLConnector.create_client + def create_client(self) -> Engine: + """Returns a SQLAlchemy Engine that can be used to interact with a database""" + + config = self.secrets_schema(**self.configuration.secrets or {}) + + credentials = service_account.Credentials.from_service_account_info( + dict(config.keyfile_creds) + ) + + # initialize connector with the loaded credentials + connector = Connector(credentials=credentials) + 
+ def getconn() -> pymysql.connections.Connection: + conn: pymysql.connections.Connection = connector.connect( + config.instance_connection_name, + "pymysql", + user=config.db_iam_user, + db=config.dbname, + enable_iam_auth=True, + ) + return conn + + return create_engine("mysql+pymysql://", creator=getconn) + + @staticmethod + def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: + """results to a list of dictionaries""" + return SQLConnector.default_cursor_result_to_rows(results) + + def build_uri(self) -> None: + """ + We need to override this method so it is not abstract anymore, and GoogleCloudSQLMySQLConnector is instantiable. + """ diff --git a/src/fides/api/service/connectors/google_cloud_postgres_connector.py b/src/fides/api/service/connectors/google_cloud_postgres_connector.py new file mode 100644 index 0000000000..cb8037ab1f --- /dev/null +++ b/src/fides/api/service/connectors/google_cloud_postgres_connector.py @@ -0,0 +1,86 @@ +from typing import List + +import pg8000 +from google.cloud.sql.connector import Connector +from google.oauth2 import service_account +from loguru import logger +from sqlalchemy import text +from sqlalchemy.engine import ( # type: ignore + Connection, + Engine, + LegacyCursorResult, + create_engine, +) + +from fides.api.graph.execution import ExecutionNode +from fides.api.schemas.connection_configuration.connection_secrets_google_cloud_sql_postgres import ( + GoogleCloudSQLPostgresSchema, +) +from fides.api.service.connectors.query_configs.google_cloud_postgres_query_config import ( + GoogleCloudSQLPostgresQueryConfig, +) +from fides.api.service.connectors.sql_connector import SQLConnector +from fides.api.util.collection_util import Row +from fides.config import get_config + +CONFIG = get_config() + + +class GoogleCloudSQLPostgresConnector(SQLConnector): + """Connector specific to Google Cloud SQL for Postgres""" + + secrets_schema = GoogleCloudSQLPostgresSchema + + @property + def default_db_name(self) -> str: + """Default database name for Google Cloud SQL Postgres""" + return "postgres" + + # Overrides SQLConnector.create_client + def create_client(self) -> Engine: + """Returns a SQLAlchemy Engine that can be used to interact with a database""" + + config = self.secrets_schema(**self.configuration.secrets or {}) + + credentials = service_account.Credentials.from_service_account_info( + dict(config.keyfile_creds) + ) + + # initialize connector with the loaded credentials + connector = Connector(credentials=credentials) + + def getconn() -> pg8000.dbapi.Connection: + conn: pg8000.dbapi.Connection = connector.connect( + config.instance_connection_name, + "pg8000", + user=config.db_iam_user, + db=config.dbname or self.default_db_name, + enable_iam_auth=True, + ) + return conn + + return create_engine("postgresql+pg8000://", creator=getconn) + + @staticmethod + def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: + """results to a list of dictionaries""" + return SQLConnector.default_cursor_result_to_rows(results) + + def build_uri(self) -> None: + """ + We need to override this method so it is not abstract anymore, and GoogleCloudSQLPostgresConnector is instantiable. 
+ """ + + def set_schema(self, connection: Connection) -> None: + """Sets the schema for a postgres database if applicable""" + config = self.secrets_schema(**self.configuration.secrets or {}) + if config.db_schema: + logger.info("Setting PostgreSQL search_path before retrieving data") + stmt = text("SELECT set_config('search_path', :search_path, false)") + stmt = stmt.bindparams(search_path=config.db_schema) + connection.execute(stmt) + + # Overrides SQLConnector.query_config + def query_config(self, node: ExecutionNode) -> GoogleCloudSQLPostgresQueryConfig: + """Query wrapper corresponding to the input execution_node.""" + return GoogleCloudSQLPostgresQueryConfig(node) diff --git a/src/fides/api/service/connectors/http_connector.py b/src/fides/api/service/connectors/http_connector.py index 50d4716928..1b367def7d 100644 --- a/src/fides/api/service/connectors/http_connector.py +++ b/src/fides/api/service/connectors/http_connector.py @@ -12,7 +12,7 @@ from fides.api.models.privacy_request import PrivacyRequest, RequestTask from fides.api.schemas.connection_configuration import HttpsSchema from fides.api.service.connectors.base_connector import BaseConnector -from fides.api.service.connectors.query_config import QueryConfig +from fides.api.service.connectors.query_configs.query_config import QueryConfig from fides.api.util.collection_util import Row diff --git a/src/fides/api/service/connectors/mariadb_connector.py b/src/fides/api/service/connectors/mariadb_connector.py new file mode 100644 index 0000000000..f52e457cd2 --- /dev/null +++ b/src/fides/api/service/connectors/mariadb_connector.py @@ -0,0 +1,41 @@ +from typing import List + +from sqlalchemy.engine import LegacyCursorResult # type: ignore + +from fides.api.schemas.connection_configuration.connection_secrets_mariadb import ( + MariaDBSchema, +) +from fides.api.service.connectors.sql_connector import SQLConnector +from fides.api.util.collection_util import Row +from fides.config import get_config + +CONFIG = get_config() + + +class MariaDBConnector(SQLConnector): + """Connector specific to MariaDB""" + + secrets_schema = MariaDBSchema + + def build_uri(self) -> str: + """Build URI of format mariadb+pymysql://[user[:password]@][netloc][:port][/dbname]""" + config = self.secrets_schema(**self.configuration.secrets or {}) + + user_password = "" + if config.username: + user = config.username + password = f":{config.password}" if config.password else "" + user_password = f"{user}{password}@" + + netloc = config.host + port = f":{config.port}" if config.port else "" + dbname = f"/{config.dbname}" if config.dbname else "" + url = f"mariadb+pymysql://{user_password}{netloc}{port}{dbname}" + return url + + @staticmethod + def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: + """ + Convert SQLAlchemy results to a list of dictionaries + """ + return SQLConnector.default_cursor_result_to_rows(results) diff --git a/src/fides/api/service/connectors/microsoft_sql_server_connector.py b/src/fides/api/service/connectors/microsoft_sql_server_connector.py new file mode 100644 index 0000000000..710041f666 --- /dev/null +++ b/src/fides/api/service/connectors/microsoft_sql_server_connector.py @@ -0,0 +1,54 @@ +from typing import List + +from sqlalchemy.engine import URL, LegacyCursorResult # type: ignore + +from fides.api.graph.execution import ExecutionNode +from fides.api.schemas.connection_configuration import MicrosoftSQLServerSchema +from fides.api.service.connectors.query_configs.microsoft_sql_server_query_config import ( + 
MicrosoftSQLServerQueryConfig, +) +from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig +from fides.api.service.connectors.sql_connector import SQLConnector +from fides.api.util.collection_util import Row +from fides.config import get_config + +CONFIG = get_config() + + +class MicrosoftSQLServerConnector(SQLConnector): + """ + Connector specific to Microsoft SQL Server + """ + + secrets_schema = MicrosoftSQLServerSchema + + def build_uri(self) -> URL: + """ + Build URI of format + mssql+pymssql://[username]:[password]@[host]:[port]/[dbname] + Returns URL obj, since SQLAlchemy's create_engine method accepts either a URL obj or a string + """ + + config = self.secrets_schema(**self.configuration.secrets or {}) + + url = URL.create( + "mssql+pymssql", + username=config.username, + password=config.password, + host=config.host, + port=config.port, + database=config.dbname, + ) + + return url + + def query_config(self, node: ExecutionNode) -> SQLQueryConfig: + """Query wrapper corresponding to the input execution_node.""" + return MicrosoftSQLServerQueryConfig(node) + + @staticmethod + def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: + """ + Convert SQLAlchemy results to a list of dictionaries + """ + return SQLConnector.default_cursor_result_to_rows(results) diff --git a/src/fides/api/service/connectors/mongodb_connector.py b/src/fides/api/service/connectors/mongodb_connector.py index 1000724fda..3389d2f01d 100644 --- a/src/fides/api/service/connectors/mongodb_connector.py +++ b/src/fides/api/service/connectors/mongodb_connector.py @@ -13,7 +13,10 @@ MongoDBSchema, ) from fides.api.service.connectors.base_connector import BaseConnector -from fides.api.service.connectors.query_config import MongoQueryConfig, QueryConfig +from fides.api.service.connectors.query_configs.mongodb_query_config import ( + MongoQueryConfig, +) +from fides.api.service.connectors.query_configs.query_config import QueryConfig from fides.api.util.collection_util import Row from fides.api.util.logger import Pii diff --git a/src/fides/api/service/connectors/mysql_connector.py b/src/fides/api/service/connectors/mysql_connector.py new file mode 100644 index 0000000000..b912f8adbb --- /dev/null +++ b/src/fides/api/service/connectors/mysql_connector.py @@ -0,0 +1,87 @@ +from typing import List + +from sqlalchemy.engine import Engine, LegacyCursorResult, create_engine # type: ignore + +from fides.api.graph.execution import ExecutionNode +from fides.api.schemas.connection_configuration.connection_secrets_mysql import ( + MySQLSchema, +) +from fides.api.service.connectors.query_configs.mysql_query_config import ( + MySQLQueryConfig, +) +from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig +from fides.api.service.connectors.sql_connector import SQLConnector +from fides.api.util.collection_util import Row +from fides.config import get_config + +CONFIG = get_config() + + +class MySQLConnector(SQLConnector): + """Connector specific to MySQL""" + + secrets_schema = MySQLSchema + + def build_uri(self) -> str: + """Build URI of format mysql+pymysql://[user[:password]@][netloc][:port][/dbname]""" + config = self.secrets_schema(**self.configuration.secrets or {}) + + user_password = "" + if config.username: + user = config.username + password = f":{config.password}" if config.password else "" + user_password = f"{user}{password}@" + + netloc = config.host + port = f":{config.port}" if config.port else "" + dbname = f"/{config.dbname}" if config.dbname else "" + 
url = f"mysql+pymysql://{user_password}{netloc}{port}{dbname}" + return url + + def build_ssh_uri(self, local_address: tuple) -> str: + """Build URI of format mysql+pymysql://[user[:password]@][ssh_host][:ssh_port][/dbname]""" + config = self.secrets_schema(**self.configuration.secrets or {}) + + user_password = "" + if config.username: + user = config.username + password = f":{config.password}" if config.password else "" + user_password = f"{user}{password}@" + + local_host, local_port = local_address + netloc = local_host + port = f":{local_port}" if local_port else "" + dbname = f"/{config.dbname}" if config.dbname else "" + url = f"mysql+pymysql://{user_password}{netloc}{port}{dbname}" + return url + + # Overrides SQLConnector.create_client + def create_client(self) -> Engine: + """Returns a SQLAlchemy Engine that can be used to interact with a database""" + if ( + self.configuration.secrets + and self.configuration.secrets.get("ssh_required", False) + and CONFIG.security.bastion_server_ssh_private_key + ): + config = self.secrets_schema(**self.configuration.secrets or {}) + self.create_ssh_tunnel(host=config.host, port=config.port) + self.ssh_server.start() + uri = self.build_ssh_uri(local_address=self.ssh_server.local_bind_address) + else: + uri = (self.configuration.secrets or {}).get("url") or self.build_uri() + return create_engine( + uri, + hide_parameters=self.hide_parameters, + echo=not self.hide_parameters, + ) + + def query_config(self, node: ExecutionNode) -> SQLQueryConfig: + """Query wrapper corresponding to the input execution_node.""" + return MySQLQueryConfig(node) + + @staticmethod + def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: + """ + Convert SQLAlchemy results to a list of dictionaries + """ + return SQLConnector.default_cursor_result_to_rows(results) diff --git a/src/fides/api/service/connectors/postgres_connector.py b/src/fides/api/service/connectors/postgres_connector.py new file mode 100644 index 0000000000..5354d4ec13 --- /dev/null +++ b/src/fides/api/service/connectors/postgres_connector.py @@ -0,0 +1,84 @@ +from loguru import logger +from sqlalchemy import text +from sqlalchemy.engine import Connection, Engine, create_engine # type: ignore + +from fides.api.graph.execution import ExecutionNode +from fides.api.schemas.connection_configuration import PostgreSQLSchema +from fides.api.service.connectors.query_configs.postgres_query_config import ( + PostgresQueryConfig, +) +from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig +from fides.api.service.connectors.sql_connector import SQLConnector +from fides.config import get_config + +CONFIG = get_config() + + +class PostgreSQLConnector(SQLConnector): + """Connector specific to postgresql""" + + secrets_schema = PostgreSQLSchema + + def build_uri(self) -> str: + """Build URI of format postgresql://[user[:password]@][netloc][:port][/dbname]""" + config = self.secrets_schema(**self.configuration.secrets or {}) + + user_password = "" + if config.username: + user = config.username + password = f":{config.password}" if config.password else "" + user_password = f"{user}{password}@" + + netloc = config.host + port = f":{config.port}" if config.port else "" + dbname = f"/{config.dbname}" if config.dbname else "" + return f"postgresql://{user_password}{netloc}{port}{dbname}" + + def build_ssh_uri(self, local_address: tuple) -> str: + """Build URI of format postgresql://[user[:password]@][ssh_host][:ssh_port][/dbname]""" + config = 
self.secrets_schema(**self.configuration.secrets or {}) + + user_password = "" + if config.username: + user = config.username + password = f":{config.password}" if config.password else "" + user_password = f"{user}{password}@" + + local_host, local_port = local_address + netloc = local_host + port = f":{local_port}" if local_port else "" + dbname = f"/{config.dbname}" if config.dbname else "" + return f"postgresql://{user_password}{netloc}{port}{dbname}" + + # Overrides SQLConnector.create_client + def create_client(self) -> Engine: + """Returns a SQLAlchemy Engine that can be used to interact with a database""" + if ( + self.configuration.secrets + and self.configuration.secrets.get("ssh_required", False) + and CONFIG.security.bastion_server_ssh_private_key + ): + config = self.secrets_schema(**self.configuration.secrets or {}) + self.create_ssh_tunnel(host=config.host, port=config.port) + self.ssh_server.start() + uri = self.build_ssh_uri(local_address=self.ssh_server.local_bind_address) + else: + uri = (self.configuration.secrets or {}).get("url") or self.build_uri() + return create_engine( + uri, + hide_parameters=self.hide_parameters, + echo=not self.hide_parameters, + ) + + def set_schema(self, connection: Connection) -> None: + """Sets the schema for a postgres database if applicable""" + config = self.secrets_schema(**self.configuration.secrets or {}) + if config.db_schema: + logger.info("Setting PostgreSQL search_path before retrieving data") + stmt = text("SET search_path to :search_path") + stmt = stmt.bindparams(search_path=config.db_schema) + connection.execute(stmt) + + def query_config(self, node: ExecutionNode) -> SQLQueryConfig: + """Query wrapper corresponding to the input execution_node.""" + return PostgresQueryConfig(node) diff --git a/src/fides/api/service/connectors/query_configs/__init__.py b/src/fides/api/service/connectors/query_configs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/fides/api/service/connectors/query_configs/bigquery_query_config.py b/src/fides/api/service/connectors/query_configs/bigquery_query_config.py new file mode 100644 index 0000000000..681e2b9c60 --- /dev/null +++ b/src/fides/api/service/connectors/query_configs/bigquery_query_config.py @@ -0,0 +1,210 @@ +from typing import Any, Dict, List, Optional, Union, cast + +from fideslang.models import MaskingStrategies +from loguru import logger +from sqlalchemy import MetaData, Table, text +from sqlalchemy.engine import Engine +from sqlalchemy.sql import Delete, Update # type: ignore +from sqlalchemy.sql.elements import ColumnElement + +from fides.api.graph.config import Field +from fides.api.graph.execution import ExecutionNode +from fides.api.models.policy import Policy +from fides.api.models.privacy_request import PrivacyRequest +from fides.api.schemas.namespace_meta.bigquery_namespace_meta import ( + BigQueryNamespaceMeta, +) +from fides.api.service.connectors.query_configs.query_config import ( + QueryStringWithoutTuplesOverrideQueryConfig, +) +from fides.api.util.collection_util import Row, filter_nonempty_values + + +class BigQueryQueryConfig(QueryStringWithoutTuplesOverrideQueryConfig): + """ + Generates SQL valid for BigQuery + """ + + namespace_meta_schema = BigQueryNamespaceMeta + + @property + def partitioning(self) -> Optional[Dict]: + # Overriden from base implementation to allow for _only_ BQ partitioning, for now + return self.node.collection.partitioning + + def get_partition_clauses( + self, + ) -> List[str]: + """ + Returns the WHERE clauses 
specified in the partitioning spec + + Currently, only where-clause based partitioning is supported. + + TODO: derive partitions from a start/end/interval specification + + + NOTE: when we deprecate `where_clause` partitioning in favor of a more proper partitioning DSL, + we should be sure to still support the existing `where_clause` partition definition on + any in-progress DSRs so that they can run through to completion. + """ + partition_spec = self.partitioning + if not partition_spec: + logger.error( + "Partitioning clauses cannot be retrieved, no partitioning specification found" + ) + return [] + + if where_clauses := partition_spec.get("where_clauses"): + return where_clauses + + # TODO: implement more advanced partitioning support! + + raise ValueError( + "`where_clauses` must be specified in partitioning specification!" + ) + + def _generate_table_name(self) -> str: + """ + Prepends the dataset ID and project ID to the base table name + if the BigQuery namespace meta is provided. + """ + + table_name = self.node.collection.name + if self.namespace_meta: + bigquery_namespace_meta = cast(BigQueryNamespaceMeta, self.namespace_meta) + table_name = f"{bigquery_namespace_meta.dataset_id}.{table_name}" + if project_id := bigquery_namespace_meta.project_id: + table_name = f"{project_id}.{table_name}" + return table_name + + def get_formatted_query_string( + self, + field_list: str, + clauses: List[str], + ) -> str: + """ + Returns a query string with backtick formatting for tables that have the same names as + BigQuery reserved words. + """ + return f'SELECT {field_list} FROM `{self._generate_table_name()}` WHERE ({" OR ".join(clauses)})' + + def generate_masking_stmt( + self, + node: ExecutionNode, + row: Row, + policy: Policy, + request: PrivacyRequest, + client: Engine, + ) -> Union[List[Update], List[Delete]]: + """ + Generate a masking statement for BigQuery. + + If a masking override is present, it will take precedence over the policy masking strategy. + """ + + masking_override = node.collection.masking_strategy_override + if masking_override and masking_override.strategy == MaskingStrategies.DELETE: + logger.info( + f"Masking override detected for collection {node.address.value}: {masking_override.strategy.value}" + ) + return self.generate_delete(row, client) + return self.generate_update(row, policy, request, client) + + def generate_update( + self, row: Row, policy: Policy, request: PrivacyRequest, client: Engine + ) -> List[Update]: + """ + Using TextClause to insert 'None' values into BigQuery throws an exception, so we use update clause instead. + Returns a List of SQLAlchemy Update object. Does not actually execute the update object. + + A List of multiple Update objects are returned for partitioned tables; for a non-partitioned table, + a single Update object is returned in a List for consistent typing. 
+ + TODO: DRY up this method and `generate_delete` a bit + """ + update_value_map: Dict[str, Any] = self.update_value_map(row, policy, request) + non_empty_primary_keys: Dict[str, Field] = filter_nonempty_values( + { + fpath.string_path: fld.cast(row[fpath.string_path]) + for fpath, fld in self.primary_key_field_paths.items() + if fpath.string_path in row + } + ) + + valid = len(non_empty_primary_keys) > 0 and update_value_map + if not valid: + logger.warning( + "There is not enough data to generate a valid update statement for {}", + self.node.address, + ) + return [] + + table = Table(self._generate_table_name(), MetaData(bind=client), autoload=True) + pk_clauses: List[ColumnElement] = [ + getattr(table.c, k) == v for k, v in non_empty_primary_keys.items() + ] + + if self.partitioning: + partition_clauses = self.get_partition_clauses() + partitioned_queries = [] + logger.info( + f"Generating {len(partition_clauses)} partition queries for node '{self.node.address}' in DSR execution" + ) + for partition_clause in partition_clauses: + partitioned_queries.append( + table.update() + .where(*(pk_clauses + [text(partition_clause)])) + .values(**update_value_map) + ) + + return partitioned_queries + + return [table.update().where(*pk_clauses).values(**update_value_map)] + + def generate_delete(self, row: Row, client: Engine) -> List[Delete]: + """Returns a List of SQLAlchemy DELETE statements for BigQuery. Does not actually execute the delete statement. + + Used when a collection-level masking override is present and the masking strategy is DELETE. + + A List of multiple DELETE statements are returned for partitioned tables; for a non-partitioned table, + a single DELETE statement is returned in a List for consistent typing. + + TODO: DRY up this method and `generate_update` a bit + """ + + non_empty_primary_keys: Dict[str, Field] = filter_nonempty_values( + { + fpath.string_path: fld.cast(row[fpath.string_path]) + for fpath, fld in self.primary_key_field_paths.items() + if fpath.string_path in row + } + ) + + valid = len(non_empty_primary_keys) > 0 + if not valid: + logger.warning( + "There is not enough data to generate a valid DELETE statement for {}", + self.node.address, + ) + return [] + + table = Table(self._generate_table_name(), MetaData(bind=client), autoload=True) + pk_clauses: List[ColumnElement] = [ + getattr(table.c, k) == v for k, v in non_empty_primary_keys.items() + ] + + if self.partitioning: + partition_clauses = self.get_partition_clauses() + partitioned_queries = [] + logger.info( + f"Generating {len(partition_clauses)} partition queries for node '{self.node.address}' in DSR execution" + ) + + for partition_clause in partition_clauses: + partitioned_queries.append( + table.delete().where(*(pk_clauses + [text(partition_clause)])) + ) + + return partitioned_queries + + return [table.delete().where(*pk_clauses)] diff --git a/src/fides/api/service/connectors/query_configs/dynamodb_query_config.py b/src/fides/api/service/connectors/query_configs/dynamodb_query_config.py new file mode 100644 index 0000000000..4a7e1c3764 --- /dev/null +++ b/src/fides/api/service/connectors/query_configs/dynamodb_query_config.py @@ -0,0 +1,86 @@ +from typing import Any, Dict, List, Optional, TypeVar + +from boto3.dynamodb.types import TypeSerializer + +from fides.api.graph.execution import ExecutionNode +from fides.api.models.policy import Policy +from fides.api.models.privacy_request import PrivacyRequest +from fides.api.service.connectors.query_configs.query_config import QueryConfig +from 
fides.api.util.collection_util import Row + +T = TypeVar("T") + +DynamoDBStatement = Dict[str, Any] +"""A DynamoDB query is formed using the boto3 library. The required parameters are: + * a table/collection name (string) + * the key name to pass when accessing the table, along with type and value (dict) + * optionally, the sort key or secondary index (i.e. timestamp) + * optionally, the specified attributes can be provided. If None, all attributes + returned for item. + + # TODO finish these docs + + We can either represent these items as a model and then handle each of the values + accordingly in the connector or use this query config to return a dictionary that + can be appropriately unpacked when executing using the client. + + The get_item query has been left out of the query_config for now. + + Add an example for put_item + """ + + +class DynamoDBQueryConfig(QueryConfig[DynamoDBStatement]): + def __init__( + self, node: ExecutionNode, attribute_definitions: List[Dict[str, Any]] + ): + super().__init__(node) + self.attribute_definitions = attribute_definitions + + def generate_query( + self, + input_data: Dict[str, List[Any]], + policy: Optional[Policy], + ) -> Optional[DynamoDBStatement]: + """Generates a dictionary for the `query` method used for DynamoDB""" + query_param = {} + serializer = TypeSerializer() + for attribute_definition in self.attribute_definitions: + attribute_name = attribute_definition["AttributeName"] + attribute_value = input_data[attribute_name][0] + query_param["ExpressionAttributeValues"] = { + ":value": serializer.serialize(attribute_value) + } + key_condition_expression: str = f"{attribute_name} = :value" + query_param["KeyConditionExpression"] = key_condition_expression # type: ignore + return query_param + + def generate_update_stmt( + self, row: Row, policy: Policy, request: PrivacyRequest + ) -> Optional[DynamoDBStatement]: + """ + Generate a Dictionary that contains necessary items to + run a PUT operation against DynamoDB + """ + update_clauses = self.update_value_map(row, policy, request) + + if update_clauses: + serializer = TypeSerializer() + update_items = row + for key, value in update_items.items(): + if key in update_clauses: + update_items[key] = serializer.serialize(update_clauses[key]) + else: + update_items[key] = serializer.serialize(value) + else: + update_items = None + + return update_items + + def query_to_str(self, t: T, input_data: Dict[str, List[Any]]) -> None: + """Not used for this connector""" + return None + + def dry_run_query(self) -> None: + """Not used for this connector""" + return None diff --git a/src/fides/api/service/connectors/query_configs/google_cloud_postgres_query_config.py b/src/fides/api/service/connectors/query_configs/google_cloud_postgres_query_config.py new file mode 100644 index 0000000000..bd987fcf08 --- /dev/null +++ b/src/fides/api/service/connectors/query_configs/google_cloud_postgres_query_config.py @@ -0,0 +1,7 @@ +from fides.api.service.connectors.query_configs.query_config import ( + QueryStringWithoutTuplesOverrideQueryConfig, +) + + +class GoogleCloudSQLPostgresQueryConfig(QueryStringWithoutTuplesOverrideQueryConfig): + """Generates SQL in Google Cloud SQL for Postgres' custom dialect.""" diff --git a/src/fides/api/service/connectors/query_configs/manual_query_config.py b/src/fides/api/service/connectors/query_configs/manual_query_config.py new file mode 100644 index 0000000000..a63b1bd1fa --- /dev/null +++ b/src/fides/api/service/connectors/query_configs/manual_query_config.py @@ -0,0 +1,91 @@ +# 
pylint: disable=too-many-lines +from typing import Any, Dict, List, Optional, TypeVar + +from sqlalchemy.sql import Executable # type: ignore + +from fides.api.models.policy import Policy +from fides.api.models.privacy_request import ManualAction, PrivacyRequest +from fides.api.service.connectors.query_configs.query_config import QueryConfig +from fides.api.util.collection_util import Row, filter_nonempty_values + +T = TypeVar("T") + + +class ManualQueryConfig(QueryConfig[Executable]): + def generate_query( + self, input_data: Dict[str, List[Any]], policy: Optional[Policy] + ) -> Optional[ManualAction]: + """Describe the details needed to manually retrieve data from the + current collection. + + Example: + { + "step": "access", + "collection": "manual_dataset:manual_collection", + "action_needed": [ + { + "locators": {'email': "customer-1@example.com"}, + "get": ["id", "box_id"] + "update": {} + } + ] + } + + """ + + locators: Dict[str, Any] = self.node.typed_filtered_values(input_data) + get: List[str] = [ + field_path.string_path + for field_path in self.node.collection.top_level_field_dict + ] + + if get and locators: + return ManualAction(locators=locators, get=get, update=None) + return None + + def query_to_str(self, t: T, input_data: Dict[str, List[Any]]) -> None: + """Not used for ManualQueryConfig, we output the dry run query as a dictionary instead of a string""" + + def dry_run_query(self) -> Optional[ManualAction]: # type: ignore + """Displays the ManualAction needed with question marks instead of action data for the locators + as a dry run query""" + fake_data: Dict[str, Any] = self.display_query_data() + manual_query: Optional[ManualAction] = self.generate_query(fake_data, None) + if not manual_query: + return None + + for where_params in manual_query.locators.values(): + for i, _ in enumerate(where_params): + where_params[i] = "?" + return manual_query + + def generate_update_stmt( + self, row: Row, policy: Policy, request: PrivacyRequest + ) -> Optional[ManualAction]: + """Describe the details needed to manually mask data in the + current collection. + + Example: + { + "step": "erasure", + "collection": "manual_dataset:manual_collection", + "action_needed": [ + { + "locators": {'id': 1}, + "get": [] + "update": {'authorized_user': None} + } + ] + } + """ + locators: Dict[str, Any] = filter_nonempty_values( + { + field_path.string_path: field.cast(row[field_path.string_path]) + for field_path, field in self.primary_key_field_paths.items() + } + ) + update_stmt: Dict[str, Any] = self.update_value_map(row, policy, request) + + if update_stmt and locators: + return ManualAction(locators=locators, get=None, update=update_stmt) + return None diff --git a/src/fides/api/service/connectors/query_configs/microsoft_sql_server_query_config.py b/src/fides/api/service/connectors/query_configs/microsoft_sql_server_query_config.py new file mode 100644 index 0000000000..427f2b2e22 --- /dev/null +++ b/src/fides/api/service/connectors/query_configs/microsoft_sql_server_query_config.py @@ -0,0 +1,9 @@ +from fides.api.service.connectors.query_configs.query_config import ( + QueryStringWithoutTuplesOverrideQueryConfig, +) + + +class MicrosoftSQLServerQueryConfig(QueryStringWithoutTuplesOverrideQueryConfig): + """ + Generates SQL valid for SQLServer. 
+ """ diff --git a/src/fides/api/service/connectors/query_configs/mongodb_query_config.py b/src/fides/api/service/connectors/query_configs/mongodb_query_config.py new file mode 100644 index 0000000000..bd650723f4 --- /dev/null +++ b/src/fides/api/service/connectors/query_configs/mongodb_query_config.py @@ -0,0 +1,100 @@ +# pylint: disable=too-many-lines +from typing import Any, Dict, List, Optional, Tuple + +from loguru import logger + +from fides.api.models.policy import Policy +from fides.api.models.privacy_request import PrivacyRequest +from fides.api.service.connectors.query_configs.query_config import QueryConfig +from fides.api.util.collection_util import Row, filter_nonempty_values + +MongoStatement = Tuple[Dict[str, Any], Dict[str, Any]] +"""A mongo query is expressed in the form of 2 dicts, the first of which represents + the query object(s) and the second of which represents fields to return. + e.g. 'collection.find({k1:v1, k2:v2},{f1:1, f2:1 ... })'. This is returned as + a tuple ({k1:v1, k2:v2},{f1:1, f2:1 ... }). + + An update statement takes the form + collection.update_one({k1:v1},{k2:v2}...}, {$set: {f1:fv1, f2:fv2 ... }}, upsert=False). + This is returned as a tuple + ({k1:v1},{k2:v2}...}, {f1:fv1, f2: fv2 ... } + """ + + +class MongoQueryConfig(QueryConfig[MongoStatement]): + """Query config that translates parameters into mongo statements""" + + def generate_query( + self, input_data: Dict[str, List[Any]], policy: Optional[Policy] = None + ) -> Optional[MongoStatement]: + def transform_query_pairs(pairs: Dict[str, Any]) -> Dict[str, Any]: + """Since we want to do an 'OR' match in mongo, transform queries of the form + {A:1, B:2} => "{$or:[{A:1},{B:2}]}". + Don't bother to do this if the pairs size is 1 + """ + if len(pairs) < 2: + return pairs + return {"$or": [dict([(k, v)]) for k, v in pairs.items()]} + + if input_data: + filtered_data: Dict[str, Any] = self.node.typed_filtered_values(input_data) + if filtered_data: + query_pairs = {} + for string_field_path, data in filtered_data.items(): + if len(data) == 1: + query_pairs[string_field_path] = data[0] + + elif len(data) > 1: + query_pairs[string_field_path] = {"$in": data} + + field_list = { # Get top-level fields to avoid path collisions + field_path.string_path: 1 + for field_path, field in self.top_level_field_map().items() + } + query_fields, return_fields = ( + transform_query_pairs(query_pairs), + field_list, + ) + return query_fields, return_fields + + logger.warning( + "There is not enough data to generate a valid query for {}", + self.node.address, + ) + return None + + def generate_update_stmt( + self, row: Row, policy: Policy, request: PrivacyRequest + ) -> Optional[MongoStatement]: + """Generate a SQL update statement in the form of Mongo update statement components""" + update_clauses = self.update_value_map(row, policy, request) + + pk_clauses: Dict[str, Any] = filter_nonempty_values( + { + field_path.string_path: field.cast(row[field_path.string_path]) + for field_path, field in self.primary_key_field_paths.items() + } + ) + + valid = len(pk_clauses) > 0 and len(update_clauses) > 0 + if not valid: + logger.warning( + "There is not enough data to generate a valid update for {}", + self.node.address, + ) + return None + return pk_clauses, {"$set": update_clauses} + + def query_to_str(self, t: MongoStatement, input_data: Dict[str, List[Any]]) -> str: + """string representation of a query for logging/dry-run""" + query_data, field_list = t + db_name = self.node.address.dataset + collection_name = 
self.node.address.collection + return f"db.{db_name}.{collection_name}.find({query_data}, {field_list})" + + def dry_run_query(self) -> Optional[str]: + data = self.display_query_data() + mongo_query = self.generate_query(self.display_query_data(), None) + if mongo_query is not None: + return self.query_to_str(mongo_query, data) + return None diff --git a/src/fides/api/service/connectors/query_configs/mysql_query_config.py b/src/fides/api/service/connectors/query_configs/mysql_query_config.py new file mode 100644 index 0000000000..0149bf9617 --- /dev/null +++ b/src/fides/api/service/connectors/query_configs/mysql_query_config.py @@ -0,0 +1,18 @@ +from typing import List + +from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig + + +class MySQLQueryConfig(SQLQueryConfig): + """ + Generates SQL valid for MySQL + """ + + def get_formatted_query_string( + self, + field_list: str, + clauses: List[str], + ) -> str: + """Returns a query string with backtick formatting for tables that have the same names as + MySQL reserved words.""" + return f'SELECT {field_list} FROM `{self.node.collection.name}` WHERE ({" OR ".join(clauses)})' diff --git a/src/fides/api/service/connectors/query_configs/postgres_query_config.py b/src/fides/api/service/connectors/query_configs/postgres_query_config.py new file mode 100644 index 0000000000..96f0f2060d --- /dev/null +++ b/src/fides/api/service/connectors/query_configs/postgres_query_config.py @@ -0,0 +1,18 @@ +from typing import List + +from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig + + +class PostgresQueryConfig(SQLQueryConfig): + """ + Generates SQL valid for Postgres + """ + + def get_formatted_query_string( + self, + field_list: str, + clauses: List[str], + ) -> str: + """Returns a query string with double quotation mark formatting for tables that have the same names as + Postgres reserved words.""" + return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE ({" OR ".join(clauses)})' diff --git a/src/fides/api/service/connectors/query_config.py b/src/fides/api/service/connectors/query_configs/query_config.py similarity index 55% rename from src/fides/api/service/connectors/query_config.py rename to src/fides/api/service/connectors/query_configs/query_config.py index be95ebcaed..6e868964af 100644 --- a/src/fides/api/service/connectors/query_config.py +++ b/src/fides/api/service/connectors/query_configs/query_config.py @@ -1,17 +1,14 @@ # pylint: disable=too-many-lines import re from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, cast +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar import pydash -from boto3.dynamodb.types import TypeSerializer -from fideslang.models import MaskingStrategies from loguru import logger from pydantic import ValidationError -from sqlalchemy import MetaData, Table, text -from sqlalchemy.engine import Engine -from sqlalchemy.sql import Delete, Executable, Update # type: ignore -from sqlalchemy.sql.elements import ColumnElement, TextClause +from sqlalchemy import text +from sqlalchemy.sql import Executable # type: ignore +from sqlalchemy.sql.elements import TextClause from fides.api.common_exceptions import MissingNamespaceSchemaException from fides.api.graph.config import ( @@ -23,14 +20,8 @@ ) from fides.api.graph.execution import ExecutionNode from fides.api.models.policy import Policy, Rule -from fides.api.models.privacy_request import ManualAction, 
PrivacyRequest -from fides.api.schemas.namespace_meta.bigquery_namespace_meta import ( - BigQueryNamespaceMeta, -) +from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.namespace_meta.namespace_meta import NamespaceMeta -from fides.api.schemas.namespace_meta.snowflake_namespace_meta import ( - SnowflakeNamespaceMeta, -) from fides.api.schemas.policy import ActionType from fides.api.service.masking.strategy.masking_strategy import MaskingStrategy from fides.api.service.masking.strategy.masking_strategy_nullify import ( @@ -286,86 +277,6 @@ def generate_update_stmt( returns None""" -class ManualQueryConfig(QueryConfig[Executable]): - def generate_query( - self, input_data: Dict[str, List[Any]], policy: Optional[Policy] - ) -> Optional[ManualAction]: - """Describe the details needed to manually retrieve data from the - current collection. - - Example: - { - "step": "access", - "collection": "manual_dataset:manual_collection", - "action_needed": [ - { - "locators": {'email': "customer-1@example.com"}, - "get": ["id", "box_id"] - "update": {} - } - ] - } - - """ - - locators: Dict[str, Any] = self.node.typed_filtered_values(input_data) - get: List[str] = [ - field_path.string_path - for field_path in self.node.collection.top_level_field_dict - ] - - if get and locators: - return ManualAction(locators=locators, get=get, update=None) - return None - - def query_to_str(self, t: T, input_data: Dict[str, List[Any]]) -> None: - """Not used for ManualQueryConfig, we output the dry run query as a dictionary instead of a string""" - - def dry_run_query(self) -> Optional[ManualAction]: # type: ignore - """Displays the ManualAction needed with question marks instead of action data for the locators - as a dry run query""" - fake_data: Dict[str, Any] = self.display_query_data() - manual_query: Optional[ManualAction] = self.generate_query(fake_data, None) - if not manual_query: - return None - - for where_params in manual_query.locators.values(): - for i, _ in enumerate(where_params): - where_params[i] = "?" - return manual_query - - def generate_update_stmt( - self, row: Row, policy: Policy, request: PrivacyRequest - ) -> Optional[ManualAction]: - """Describe the details needed to manually mask data in the - current collection. - - Example: - { - "step": "erasure", - "collection": "manual_dataset:manual_collection", - "action_needed": [ - { - "locators": {'id': 1}, - "get": [] - "update": {'authorized_user': None} - } - ] - } - """ - locators: Dict[str, Any] = filter_nonempty_values( - { - field_path.string_path: field.cast(row[field_path.string_path]) - for field_path, field in self.primary_key_field_paths.items() - } - ) - update_stmt: Dict[str, Any] = self.update_value_map(row, policy, request) - - if update_stmt and locators: - return ManualAction(locators=locators, get=None, update=update_stmt) - return None - - class SQLLikeQueryConfig(QueryConfig[T], ABC): """ Abstract query config for SQL-like languages (that may not be strictly SQL). 
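The dialect-specific query configs extracted in this change (MySQL, Postgres, Redshift, Snowflake) all follow the same pattern: each lives in its own module under `query_configs/`, subclasses `SQLQueryConfig`, and overrides `get_formatted_query_string` with that dialect's identifier quoting. A minimal sketch of that pattern for a hypothetical extra dialect (the `ExampleDialectQueryConfig` name and backtick quoting are illustrative assumptions, not part of this diff):

```python
# Hypothetical sketch only -- not added by this PR. It mirrors the layout of the
# extracted modules under query_configs/ to show the extension pattern.
from typing import List

from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig


class ExampleDialectQueryConfig(SQLQueryConfig):
    """Generates SQL for a hypothetical dialect that quotes table names with backticks."""

    def get_formatted_query_string(
        self,
        field_list: str,
        clauses: List[str],
    ) -> str:
        # Same shape as the MySQL/Postgres configs above; only the quoting of
        # self.node.collection.name differs between dialects.
        return f'SELECT {field_list} FROM `{self.node.collection.name}` WHERE ({" OR ".join(clauses)})'
```

A connector would then return the matching config from its `query_config(node)` override, as the relocated Redshift and Snowflake connectors further down in this diff do.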
@@ -672,36 +583,6 @@ def format_query_data_name(self, query_data_name: str) -> str: return f":{query_data_name}" -class PostgresQueryConfig(SQLQueryConfig): - """ - Generates SQL valid for Postgres - """ - - def get_formatted_query_string( - self, - field_list: str, - clauses: List[str], - ) -> str: - """Returns a query string with double quotation mark formatting for tables that have the same names as - Postgres reserved words.""" - return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE ({" OR ".join(clauses)})' - - -class MySQLQueryConfig(SQLQueryConfig): - """ - Generates SQL valid for MySQL - """ - - def get_formatted_query_string( - self, - field_list: str, - clauses: List[str], - ) -> str: - """Returns a query string with backtick formatting for tables that have the same names as - MySQL reserved words.""" - return f'SELECT {field_list} FROM `{self.node.collection.name}` WHERE ({" OR ".join(clauses)})' - - class QueryStringWithoutTuplesOverrideQueryConfig(SQLQueryConfig): """ Generates SQL valid for connectors that require the query string to be built without tuples. @@ -775,448 +656,3 @@ def generate_query( # pylint: disable=R0914 SELECT order_id,product_id,quantity FROM order_item WHERE order_id IN (:_in_stmt_generated_0, :_in_stmt_generated_1, :_in_stmt_generated_2) """ return self.generate_query_without_tuples(input_data, policy) - - -class MicrosoftSQLServerQueryConfig(QueryStringWithoutTuplesOverrideQueryConfig): - """ - Generates SQL valid for SQLServer. - """ - - -class SnowflakeQueryConfig(SQLQueryConfig): - """Generates SQL in Snowflake's custom dialect.""" - - namespace_meta_schema = SnowflakeNamespaceMeta - - def generate_raw_query( - self, field_list: List[str], filters: Dict[str, List[Any]] - ) -> Optional[TextClause]: - formatted_field_list = [f'"{field}"' for field in field_list] - raw_query = super().generate_raw_query(formatted_field_list, filters) - return raw_query # type: ignore - - def format_clause_for_query( - self, - string_path: str, - operator: str, - operand: str, - ) -> str: - """Returns field names in clauses surrounded by quotation marks as required by Snowflake syntax.""" - return f'"{string_path}" {operator} (:{operand})' - - def _generate_table_name(self) -> str: - """ - Prepends the dataset name and schema to the base table name - if the Snowflake namespace meta is provided. 
- """ - - table_name = ( - f'"{self.node.collection.name}"' # Always quote the base table name - ) - - if not self.namespace_meta: - return table_name - - snowflake_meta = cast(SnowflakeNamespaceMeta, self.namespace_meta) - qualified_name = f'"{snowflake_meta.schema}".{table_name}' - - if database_name := snowflake_meta.database_name: - return f'"{database_name}".{qualified_name}' - - return qualified_name - - def get_formatted_query_string( - self, - field_list: str, - clauses: List[str], - ) -> str: - """Returns a query string with double quotation mark formatting as required by Snowflake syntax.""" - return f'SELECT {field_list} FROM {self._generate_table_name()} WHERE ({" OR ".join(clauses)})' - - def format_key_map_for_update_stmt(self, fields: List[str]) -> List[str]: - """Adds the appropriate formatting for update statements in this datastore.""" - fields.sort() - return [f'"{k}" = :{k}' for k in fields] - - def get_update_stmt( - self, - update_clauses: List[str], - pk_clauses: List[str], - ) -> str: - """Returns a parameterized update statement in Snowflake dialect.""" - return f'UPDATE {self._generate_table_name()} SET {", ".join(update_clauses)} WHERE {" AND ".join(pk_clauses)}' - - -class RedshiftQueryConfig(SQLQueryConfig): - """Generates SQL in Redshift's custom dialect.""" - - def get_formatted_query_string( - self, - field_list: str, - clauses: List[str], - ) -> str: - """Returns a query string with double quotation mark formatting for tables that have the same names as - Redshift reserved words.""" - return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE ({" OR ".join(clauses)})' - - -class GoogleCloudSQLPostgresQueryConfig(QueryStringWithoutTuplesOverrideQueryConfig): - """Generates SQL in Google Cloud SQL for Postgres' custom dialect.""" - - -class BigQueryQueryConfig(QueryStringWithoutTuplesOverrideQueryConfig): - """ - Generates SQL valid for BigQuery - """ - - namespace_meta_schema = BigQueryNamespaceMeta - - @property - def partitioning(self) -> Optional[Dict]: - # Overriden from base implementation to allow for _only_ BQ partitioning, for now - return self.node.collection.partitioning - - def get_partition_clauses( - self, - ) -> List[str]: - """ - Returns the WHERE clauses specified in the partitioning spec - - Currently, only where-clause based partitioning is supported. - - TODO: derive partitions from a start/end/interval specification - - - NOTE: when we deprecate `where_clause` partitioning in favor of a more proper partitioning DSL, - we should be sure to still support the existing `where_clause` partition definition on - any in-progress DSRs so that they can run through to completion. - """ - partition_spec = self.partitioning - if not partition_spec: - logger.error( - "Partitioning clauses cannot be retrieved, no partitioning specification found" - ) - return [] - - if where_clauses := partition_spec.get("where_clauses"): - return where_clauses - - # TODO: implement more advanced partitioning support! - - raise ValueError( - "`where_clauses` must be specified in partitioning specification!" - ) - - def _generate_table_name(self) -> str: - """ - Prepends the dataset ID and project ID to the base table name - if the BigQuery namespace meta is provided. 
- """ - - table_name = self.node.collection.name - if self.namespace_meta: - bigquery_namespace_meta = cast(BigQueryNamespaceMeta, self.namespace_meta) - table_name = f"{bigquery_namespace_meta.dataset_id}.{table_name}" - if project_id := bigquery_namespace_meta.project_id: - table_name = f"{project_id}.{table_name}" - return table_name - - def get_formatted_query_string( - self, - field_list: str, - clauses: List[str], - ) -> str: - """ - Returns a query string with backtick formatting for tables that have the same names as - BigQuery reserved words. - """ - return f'SELECT {field_list} FROM `{self._generate_table_name()}` WHERE ({" OR ".join(clauses)})' - - def generate_masking_stmt( - self, - node: ExecutionNode, - row: Row, - policy: Policy, - request: PrivacyRequest, - client: Engine, - ) -> Union[List[Update], List[Delete]]: - """ - Generate a masking statement for BigQuery. - - If a masking override is present, it will take precedence over the policy masking strategy. - """ - - masking_override = node.collection.masking_strategy_override - if masking_override and masking_override.strategy == MaskingStrategies.DELETE: - logger.info( - f"Masking override detected for collection {node.address.value}: {masking_override.strategy.value}" - ) - return self.generate_delete(row, client) - return self.generate_update(row, policy, request, client) - - def generate_update( - self, row: Row, policy: Policy, request: PrivacyRequest, client: Engine - ) -> List[Update]: - """ - Using TextClause to insert 'None' values into BigQuery throws an exception, so we use update clause instead. - Returns a List of SQLAlchemy Update object. Does not actually execute the update object. - - A List of multiple Update objects are returned for partitioned tables; for a non-partitioned table, - a single Update object is returned in a List for consistent typing. - - TODO: DRY up this method and `generate_delete` a bit - """ - update_value_map: Dict[str, Any] = self.update_value_map(row, policy, request) - non_empty_primary_keys: Dict[str, Field] = filter_nonempty_values( - { - fpath.string_path: fld.cast(row[fpath.string_path]) - for fpath, fld in self.primary_key_field_paths.items() - if fpath.string_path in row - } - ) - - valid = len(non_empty_primary_keys) > 0 and update_value_map - if not valid: - logger.warning( - "There is not enough data to generate a valid update statement for {}", - self.node.address, - ) - return [] - - table = Table(self._generate_table_name(), MetaData(bind=client), autoload=True) - pk_clauses: List[ColumnElement] = [ - getattr(table.c, k) == v for k, v in non_empty_primary_keys.items() - ] - - if self.partitioning: - partition_clauses = self.get_partition_clauses() - partitioned_queries = [] - logger.info( - f"Generating {len(partition_clauses)} partition queries for node '{self.node.address}' in DSR execution" - ) - for partition_clause in partition_clauses: - partitioned_queries.append( - table.update() - .where(*(pk_clauses + [text(partition_clause)])) - .values(**update_value_map) - ) - - return partitioned_queries - - return [table.update().where(*pk_clauses).values(**update_value_map)] - - def generate_delete(self, row: Row, client: Engine) -> List[Delete]: - """Returns a List of SQLAlchemy DELETE statements for BigQuery. Does not actually execute the delete statement. - - Used when a collection-level masking override is present and the masking strategy is DELETE. 
- - A List of multiple DELETE statements are returned for partitioned tables; for a non-partitioned table, - a single DELETE statement is returned in a List for consistent typing. - - TODO: DRY up this method and `generate_update` a bit - """ - - non_empty_primary_keys: Dict[str, Field] = filter_nonempty_values( - { - fpath.string_path: fld.cast(row[fpath.string_path]) - for fpath, fld in self.primary_key_field_paths.items() - if fpath.string_path in row - } - ) - - valid = len(non_empty_primary_keys) > 0 - if not valid: - logger.warning( - "There is not enough data to generate a valid DELETE statement for {}", - self.node.address, - ) - return [] - - table = Table(self._generate_table_name(), MetaData(bind=client), autoload=True) - pk_clauses: List[ColumnElement] = [ - getattr(table.c, k) == v for k, v in non_empty_primary_keys.items() - ] - - if self.partitioning: - partition_clauses = self.get_partition_clauses() - partitioned_queries = [] - logger.info( - f"Generating {len(partition_clauses)} partition queries for node '{self.node.address}' in DSR execution" - ) - - for partition_clause in partition_clauses: - partitioned_queries.append( - table.delete().where(*(pk_clauses + [text(partition_clause)])) - ) - - return partitioned_queries - - return [table.delete().where(*pk_clauses)] - - -MongoStatement = Tuple[Dict[str, Any], Dict[str, Any]] -"""A mongo query is expressed in the form of 2 dicts, the first of which represents - the query object(s) and the second of which represents fields to return. - e.g. 'collection.find({k1:v1, k2:v2},{f1:1, f2:1 ... })'. This is returned as - a tuple ({k1:v1, k2:v2},{f1:1, f2:1 ... }). - - An update statement takes the form - collection.update_one({k1:v1},{k2:v2}...}, {$set: {f1:fv1, f2:fv2 ... }}, upsert=False). - This is returned as a tuple - ({k1:v1},{k2:v2}...}, {f1:fv1, f2: fv2 ... } - """ - - -class MongoQueryConfig(QueryConfig[MongoStatement]): - """Query config that translates parameters into mongo statements""" - - def generate_query( - self, input_data: Dict[str, List[Any]], policy: Optional[Policy] = None - ) -> Optional[MongoStatement]: - def transform_query_pairs(pairs: Dict[str, Any]) -> Dict[str, Any]: - """Since we want to do an 'OR' match in mongo, transform queries of the form - {A:1, B:2} => "{$or:[{A:1},{B:2}]}". 
- Don't bother to do this if the pairs size is 1 - """ - if len(pairs) < 2: - return pairs - return {"$or": [dict([(k, v)]) for k, v in pairs.items()]} - - if input_data: - filtered_data: Dict[str, Any] = self.node.typed_filtered_values(input_data) - if filtered_data: - query_pairs = {} - for string_field_path, data in filtered_data.items(): - if len(data) == 1: - query_pairs[string_field_path] = data[0] - - elif len(data) > 1: - query_pairs[string_field_path] = {"$in": data} - - field_list = { # Get top-level fields to avoid path collisions - field_path.string_path: 1 - for field_path, field in self.top_level_field_map().items() - } - query_fields, return_fields = ( - transform_query_pairs(query_pairs), - field_list, - ) - return query_fields, return_fields - - logger.warning( - "There is not enough data to generate a valid query for {}", - self.node.address, - ) - return None - - def generate_update_stmt( - self, row: Row, policy: Policy, request: PrivacyRequest - ) -> Optional[MongoStatement]: - """Generate a SQL update statement in the form of Mongo update statement components""" - update_clauses = self.update_value_map(row, policy, request) - - pk_clauses: Dict[str, Any] = filter_nonempty_values( - { - field_path.string_path: field.cast(row[field_path.string_path]) - for field_path, field in self.primary_key_field_paths.items() - } - ) - - valid = len(pk_clauses) > 0 and len(update_clauses) > 0 - if not valid: - logger.warning( - "There is not enough data to generate a valid update for {}", - self.node.address, - ) - return None - return pk_clauses, {"$set": update_clauses} - - def query_to_str(self, t: MongoStatement, input_data: Dict[str, List[Any]]) -> str: - """string representation of a query for logging/dry-run""" - query_data, field_list = t - db_name = self.node.address.dataset - collection_name = self.node.address.collection - return f"db.{db_name}.{collection_name}.find({query_data}, {field_list})" - - def dry_run_query(self) -> Optional[str]: - data = self.display_query_data() - mongo_query = self.generate_query(self.display_query_data(), None) - if mongo_query is not None: - return self.query_to_str(mongo_query, data) - return None - - -DynamoDBStatement = Dict[str, Any] -"""A DynamoDB query is formed using the boto3 library. The required parameters are: - * a table/collection name (string) - * the key name to pass when accessing the table, along with type and value (dict) - * optionally, the sort key or secondary index (i.e. timestamp) - * optionally, the specified attributes can be provided. If None, all attributes - returned for item. - - # TODO finish these docs - - We can either represent these items as a model and then handle each of the values - accordingly in the connector or use this query config to return a dictionary that - can be appropriately unpacked when executing using the client. - - The get_item query has been left out of the query_config for now. 
- - Add an example for put_item - """ - - -class DynamoDBQueryConfig(QueryConfig[DynamoDBStatement]): - def __init__( - self, node: ExecutionNode, attribute_definitions: List[Dict[str, Any]] - ): - super().__init__(node) - self.attribute_definitions = attribute_definitions - - def generate_query( - self, - input_data: Dict[str, List[Any]], - policy: Optional[Policy], - ) -> Optional[DynamoDBStatement]: - """Generates a dictionary for the `query` method used for DynamoDB""" - query_param = {} - serializer = TypeSerializer() - for attribute_definition in self.attribute_definitions: - attribute_name = attribute_definition["AttributeName"] - attribute_value = input_data[attribute_name][0] - query_param["ExpressionAttributeValues"] = { - ":value": serializer.serialize(attribute_value) - } - key_condition_expression: str = f"{attribute_name} = :value" - query_param["KeyConditionExpression"] = key_condition_expression # type: ignore - return query_param - - def generate_update_stmt( - self, row: Row, policy: Policy, request: PrivacyRequest - ) -> Optional[DynamoDBStatement]: - """ - Generate a Dictionary that contains necessary items to - run a PUT operation against DynamoDB - """ - update_clauses = self.update_value_map(row, policy, request) - - if update_clauses: - serializer = TypeSerializer() - update_items = row - for key, value in update_items.items(): - if key in update_clauses: - update_items[key] = serializer.serialize(update_clauses[key]) - else: - update_items[key] = serializer.serialize(value) - else: - update_items = None - - return update_items - - def query_to_str(self, t: T, input_data: Dict[str, List[Any]]) -> None: - """Not used for this connector""" - return None - - def dry_run_query(self) -> None: - """Not used for this connector""" - return None diff --git a/src/fides/api/service/connectors/query_configs/redshift_query_config.py b/src/fides/api/service/connectors/query_configs/redshift_query_config.py new file mode 100644 index 0000000000..df645abfd4 --- /dev/null +++ b/src/fides/api/service/connectors/query_configs/redshift_query_config.py @@ -0,0 +1,16 @@ +from typing import List + +from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig + + +class RedshiftQueryConfig(SQLQueryConfig): + """Generates SQL in Redshift's custom dialect.""" + + def get_formatted_query_string( + self, + field_list: str, + clauses: List[str], + ) -> str: + """Returns a query string with double quotation mark formatting for tables that have the same names as + Redshift reserved words.""" + return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE ({" OR ".join(clauses)})' diff --git a/src/fides/api/service/connectors/saas_query_config.py b/src/fides/api/service/connectors/query_configs/saas_query_config.py similarity index 99% rename from src/fides/api/service/connectors/saas_query_config.py rename to src/fides/api/service/connectors/query_configs/saas_query_config.py index e72313a756..a85f0cab61 100644 --- a/src/fides/api/service/connectors/saas_query_config.py +++ b/src/fides/api/service/connectors/query_configs/saas_query_config.py @@ -27,7 +27,7 @@ SaaSRequest, ) from fides.api.schemas.saas.shared_schemas import SaaSRequestParams -from fides.api.service.connectors.query_config import QueryConfig +from fides.api.service.connectors.query_configs.query_config import QueryConfig from fides.api.util import saas_util from fides.api.util.collection_util import Row, merge_dicts from fides.api.util.saas_util import ( diff --git 
a/src/fides/api/service/connectors/query_configs/snowflake_query_config.py b/src/fides/api/service/connectors/query_configs/snowflake_query_config.py new file mode 100644 index 0000000000..574e1ea1b1 --- /dev/null +++ b/src/fides/api/service/connectors/query_configs/snowflake_query_config.py @@ -0,0 +1,73 @@ +# pylint: disable=too-many-lines +from typing import Any, Dict, List, Optional, cast + +from sqlalchemy.sql.elements import TextClause + +from fides.api.schemas.namespace_meta.snowflake_namespace_meta import ( + SnowflakeNamespaceMeta, +) +from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig + + +class SnowflakeQueryConfig(SQLQueryConfig): + """Generates SQL in Snowflake's custom dialect.""" + + namespace_meta_schema = SnowflakeNamespaceMeta + + def generate_raw_query( + self, field_list: List[str], filters: Dict[str, List[Any]] + ) -> Optional[TextClause]: + formatted_field_list = [f'"{field}"' for field in field_list] + raw_query = super().generate_raw_query(formatted_field_list, filters) + return raw_query # type: ignore + + def format_clause_for_query( + self, + string_path: str, + operator: str, + operand: str, + ) -> str: + """Returns field names in clauses surrounded by quotation marks as required by Snowflake syntax.""" + return f'"{string_path}" {operator} (:{operand})' + + def _generate_table_name(self) -> str: + """ + Prepends the dataset name and schema to the base table name + if the Snowflake namespace meta is provided. + """ + + table_name = ( + f'"{self.node.collection.name}"' # Always quote the base table name + ) + + if not self.namespace_meta: + return table_name + + snowflake_meta = cast(SnowflakeNamespaceMeta, self.namespace_meta) + qualified_name = f'"{snowflake_meta.schema}".{table_name}' + + if database_name := snowflake_meta.database_name: + return f'"{database_name}".{qualified_name}' + + return qualified_name + + def get_formatted_query_string( + self, + field_list: str, + clauses: List[str], + ) -> str: + """Returns a query string with double quotation mark formatting as required by Snowflake syntax.""" + return f'SELECT {field_list} FROM {self._generate_table_name()} WHERE ({" OR ".join(clauses)})' + + def format_key_map_for_update_stmt(self, fields: List[str]) -> List[str]: + """Adds the appropriate formatting for update statements in this datastore.""" + fields.sort() + return [f'"{k}" = :{k}' for k in fields] + + def get_update_stmt( + self, + update_clauses: List[str], + pk_clauses: List[str], + ) -> str: + """Returns a parameterized update statement in Snowflake dialect.""" + return f'UPDATE {self._generate_table_name()} SET {", ".join(update_clauses)} WHERE {" AND ".join(pk_clauses)}' diff --git a/src/fides/api/service/connectors/rds_mysql_connector.py b/src/fides/api/service/connectors/rds_mysql_connector.py index 01f8a10bf1..a9ee7890a0 100644 --- a/src/fides/api/service/connectors/rds_mysql_connector.py +++ b/src/fides/api/service/connectors/rds_mysql_connector.py @@ -14,7 +14,10 @@ from fides.api.schemas.connection_configuration.connection_secrets_rds_mysql import ( RDSMySQLSchema, ) -from fides.api.service.connectors.query_config import MySQLQueryConfig, SQLQueryConfig +from fides.api.service.connectors.query_configs.mysql_query_config import ( + MySQLQueryConfig, +) +from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig from fides.api.service.connectors.rds_connector_mixin import RDSConnectorMixin from fides.api.service.connectors.sql_connector import SQLConnector from 
fides.api.util.collection_util import Row diff --git a/src/fides/api/service/connectors/rds_postgres_connector.py b/src/fides/api/service/connectors/rds_postgres_connector.py index d04b90d841..2c1d1f51c5 100644 --- a/src/fides/api/service/connectors/rds_postgres_connector.py +++ b/src/fides/api/service/connectors/rds_postgres_connector.py @@ -14,10 +14,10 @@ from fides.api.schemas.connection_configuration.connection_secrets_rds_postgres import ( RDSPostgresSchema, ) -from fides.api.service.connectors.query_config import ( +from fides.api.service.connectors.query_configs.postgres_query_config import ( PostgresQueryConfig, - SQLQueryConfig, ) +from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig from fides.api.service.connectors.rds_connector_mixin import RDSConnectorMixin from fides.api.service.connectors.sql_connector import SQLConnector from fides.api.util.collection_util import Row diff --git a/src/fides/api/service/connectors/redshift_connector.py b/src/fides/api/service/connectors/redshift_connector.py new file mode 100644 index 0000000000..14e149770a --- /dev/null +++ b/src/fides/api/service/connectors/redshift_connector.py @@ -0,0 +1,88 @@ +from typing import Dict, Union +from urllib.parse import quote_plus + +from loguru import logger +from sqlalchemy import text +from sqlalchemy.engine import Connection, Engine, create_engine # type: ignore + +from fides.api.graph.execution import ExecutionNode +from fides.api.schemas.connection_configuration import RedshiftSchema +from fides.api.service.connectors.query_configs.redshift_query_config import ( + RedshiftQueryConfig, +) +from fides.api.service.connectors.sql_connector import SQLConnector +from fides.config import get_config + +CONFIG = get_config() + + +class RedshiftConnector(SQLConnector): + """Connector specific to Amazon Redshift""" + + secrets_schema = RedshiftSchema + + def build_ssh_uri(self, local_address: tuple) -> str: + """Build SSH URI of format redshift+psycopg2://[user[:password]@][ssh_host][:ssh_port][/dbname]""" + local_host, local_port = local_address + + config = self.secrets_schema(**self.configuration.secrets or {}) + + port = f":{local_port}" if local_port else "" + database = f"/{config.database}" if config.database else "" + url = f"redshift+psycopg2://{config.user}:{config.password}@{local_host}{port}{database}" + return url + + # Overrides BaseConnector.build_uri + def build_uri(self) -> str: + """Build URI of format redshift+psycopg2://user:password@[host][:port][/database]""" + config = self.secrets_schema(**self.configuration.secrets or {}) + + url_encoded_password = quote_plus(config.password) + port = f":{config.port}" if config.port else "" + database = f"/{config.database}" if config.database else "" + url = f"redshift+psycopg2://{config.user}:{url_encoded_password}@{config.host}{port}{database}" + return url + + # Overrides SQLConnector.create_client + def create_client(self) -> Engine: + """Returns a SQLAlchemy Engine that can be used to interact with a database""" + connect_args: Dict[str, Union[int, str]] = {} + connect_args["sslmode"] = "prefer" + + # keep alive settings to prevent long-running queries from causing a connection close + connect_args["keepalives"] = 1 + connect_args["keepalives_idle"] = 30 + connect_args["keepalives_interval"] = 5 + connect_args["keepalives_count"] = 5 + + if ( + self.configuration.secrets + and self.configuration.secrets.get("ssh_required", False) + and CONFIG.security.bastion_server_ssh_private_key + ): + config = 
self.secrets_schema(**self.configuration.secrets or {}) + self.create_ssh_tunnel(host=config.host, port=config.port) + self.ssh_server.start() + uri = self.build_ssh_uri(local_address=self.ssh_server.local_bind_address) + else: + uri = (self.configuration.secrets or {}).get("url") or self.build_uri() + return create_engine( + uri, + hide_parameters=self.hide_parameters, + echo=not self.hide_parameters, + connect_args=connect_args, + ) + + def set_schema(self, connection: Connection) -> None: + """Sets the search_path for the duration of the session""" + config = self.secrets_schema(**self.configuration.secrets or {}) + if config.db_schema: + logger.info("Setting Redshift search_path before retrieving data") + stmt = text("SET search_path to :search_path") + stmt = stmt.bindparams(search_path=config.db_schema) + connection.execute(stmt) + + # Overrides SQLConnector.query_config + def query_config(self, node: ExecutionNode) -> RedshiftQueryConfig: + """Query wrapper corresponding to the input execution node.""" + return RedshiftQueryConfig(node) diff --git a/src/fides/api/service/connectors/s3_connector.py b/src/fides/api/service/connectors/s3_connector.py index 35ed1b5157..c0e9ca3ee6 100644 --- a/src/fides/api/service/connectors/s3_connector.py +++ b/src/fides/api/service/connectors/s3_connector.py @@ -9,7 +9,7 @@ from fides.api.models.privacy_request import PrivacyRequest, RequestTask from fides.api.schemas.connection_configuration.connection_secrets_s3 import S3Schema from fides.api.service.connectors.base_connector import BaseConnector -from fides.api.service.connectors.query_config import QueryConfig +from fides.api.service.connectors.query_configs.query_config import QueryConfig from fides.api.util.aws_util import get_aws_session from fides.api.util.collection_util import Row diff --git a/src/fides/api/service/connectors/saas_connector.py b/src/fides/api/service/connectors/saas_connector.py index 588ac9265a..b917b6cfda 100644 --- a/src/fides/api/service/connectors/saas_connector.py +++ b/src/fides/api/service/connectors/saas_connector.py @@ -39,8 +39,8 @@ SaaSRequestParams, ) from fides.api.service.connectors.base_connector import BaseConnector +from fides.api.service.connectors.query_configs.saas_query_config import SaaSQueryConfig from fides.api.service.connectors.saas.authenticated_client import AuthenticatedClient -from fides.api.service.connectors.saas_query_config import SaaSQueryConfig from fides.api.service.pagination.pagination_strategy import PaginationStrategy from fides.api.service.processors.post_processor_strategy.post_processor_strategy import ( PostProcessorStrategy, diff --git a/src/fides/api/service/connectors/scylla_query_config.py b/src/fides/api/service/connectors/scylla_query_config.py index 7d304664e7..2a72270a40 100644 --- a/src/fides/api/service/connectors/scylla_query_config.py +++ b/src/fides/api/service/connectors/scylla_query_config.py @@ -3,7 +3,7 @@ from fides.api.graph.config import Field from fides.api.models.policy import Policy -from fides.api.service.connectors.query_config import SQLLikeQueryConfig +from fides.api.service.connectors.query_configs.query_config import SQLLikeQueryConfig ScyllaDBStatement = Tuple[str, Dict[str, Any]] """ diff --git a/src/fides/api/service/connectors/snowflake_connector.py b/src/fides/api/service/connectors/snowflake_connector.py new file mode 100644 index 0000000000..3a3fcbecdb --- /dev/null +++ b/src/fides/api/service/connectors/snowflake_connector.py @@ -0,0 +1,78 @@ +from typing import Any, Dict, Union + +from 
cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from snowflake.sqlalchemy import URL as Snowflake_URL +from sqlalchemy.orm import Session + +from fides.api.graph.execution import ExecutionNode +from fides.api.schemas.connection_configuration import SnowflakeSchema +from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig +from fides.api.service.connectors.query_configs.snowflake_query_config import ( + SnowflakeQueryConfig, +) +from fides.api.service.connectors.sql_connector import SQLConnector +from fides.config import get_config + +CONFIG = get_config() + + +class SnowflakeConnector(SQLConnector): + """Connector specific to Snowflake""" + + secrets_schema = SnowflakeSchema + + def build_uri(self) -> str: + """Build URI of format 'snowflake://:@// + ?warehouse=&role=' + """ + config = self.secrets_schema(**self.configuration.secrets or {}) + + kwargs = {} + + if config.account_identifier: + kwargs["account"] = config.account_identifier + if config.user_login_name: + kwargs["user"] = config.user_login_name + if config.password: + kwargs["password"] = config.password + if config.database_name: + kwargs["database"] = config.database_name + if config.schema_name: + kwargs["schema"] = config.schema_name + if config.warehouse_name: + kwargs["warehouse"] = config.warehouse_name + if config.role_name: + kwargs["role"] = config.role_name + + url: str = Snowflake_URL(**kwargs) + return url + + def get_connect_args(self) -> Dict[str, Any]: + """Get connection arguments for the engine""" + config = self.secrets_schema(**self.configuration.secrets or {}) + connect_args: Dict[str, Union[str, bytes]] = {} + if config.private_key: + config.private_key = config.private_key.replace("\\n", "\n") + connect_args["private_key"] = config.private_key + if config.private_key_passphrase: + private_key_encoded = serialization.load_pem_private_key( + config.private_key.encode(), + password=config.private_key_passphrase.encode(), # pylint: disable=no-member + backend=default_backend(), + ) + private_key = private_key_encoded.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + connect_args["private_key"] = private_key + return connect_args + + def query_config(self, node: ExecutionNode) -> SQLQueryConfig: + """Query wrapper corresponding to the input execution_node.""" + + db: Session = Session.object_session(self.configuration) + return SnowflakeQueryConfig( + node, SQLConnector.get_namespace_meta(db, node.address.dataset) + ) diff --git a/src/fides/api/service/connectors/sql_connector.py b/src/fides/api/service/connectors/sql_connector.py index 8a22cb0106..ff6c627631 100644 --- a/src/fides/api/service/connectors/sql_connector.py +++ b/src/fides/api/service/connectors/sql_connector.py @@ -1,23 +1,14 @@ import io from abc import abstractmethod -from typing import Any, Dict, List, Optional, Type, Union -from urllib.parse import quote_plus +from typing import Any, Dict, List, Optional, Type import paramiko -import pg8000 -import pymysql import sshtunnel # type: ignore from aiohttp.client_exceptions import ClientResponseError -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization -from google.cloud.sql.connector import Connector -from google.oauth2 import service_account from loguru import logger -from snowflake.sqlalchemy import URL as Snowflake_URL -from sqlalchemy 
import Column, select, text +from sqlalchemy import Column, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.engine import ( # type: ignore - URL, Connection, CursorResult, Engine, @@ -26,7 +17,6 @@ ) from sqlalchemy.exc import InternalError, OperationalError from sqlalchemy.orm import Session -from sqlalchemy.sql import Executable # type: ignore from sqlalchemy.sql.elements import TextClause from fides.api.common_exceptions import ( @@ -37,39 +27,9 @@ from fides.api.models.connectionconfig import ConnectionConfig, ConnectionTestStatus from fides.api.models.policy import Policy from fides.api.models.privacy_request import PrivacyRequest, RequestTask -from fides.api.schemas.connection_configuration import ( - ConnectionConfigSecretsSchema, - MicrosoftSQLServerSchema, - PostgreSQLSchema, - RedshiftSchema, - SnowflakeSchema, -) -from fides.api.schemas.connection_configuration.connection_secrets_bigquery import ( - BigQuerySchema, -) -from fides.api.schemas.connection_configuration.connection_secrets_google_cloud_sql_mysql import ( - GoogleCloudSQLMySQLSchema, -) -from fides.api.schemas.connection_configuration.connection_secrets_google_cloud_sql_postgres import ( - GoogleCloudSQLPostgresSchema, -) -from fides.api.schemas.connection_configuration.connection_secrets_mariadb import ( - MariaDBSchema, -) -from fides.api.schemas.connection_configuration.connection_secrets_mysql import ( - MySQLSchema, -) +from fides.api.schemas.connection_configuration import ConnectionConfigSecretsSchema from fides.api.service.connectors.base_connector import BaseConnector -from fides.api.service.connectors.query_config import ( - BigQueryQueryConfig, - GoogleCloudSQLPostgresQueryConfig, - MicrosoftSQLServerQueryConfig, - MySQLQueryConfig, - PostgresQueryConfig, - RedshiftQueryConfig, - SnowflakeQueryConfig, - SQLQueryConfig, -) +from fides.api.service.connectors.query_configs.query_config import SQLQueryConfig from fides.api.util.collection_util import Row from fides.config import get_config @@ -298,575 +258,3 @@ def partitioned_retrieval( raise NotImplementedError( "Partitioned retrieval is only supported for BigQuery currently!" 
) - - -class PostgreSQLConnector(SQLConnector): - """Connector specific to postgresql""" - - secrets_schema = PostgreSQLSchema - - def build_uri(self) -> str: - """Build URI of format postgresql://[user[:password]@][netloc][:port][/dbname]""" - config = self.secrets_schema(**self.configuration.secrets or {}) - - user_password = "" - if config.username: - user = config.username - password = f":{config.password}" if config.password else "" - user_password = f"{user}{password}@" - - netloc = config.host - port = f":{config.port}" if config.port else "" - dbname = f"/{config.dbname}" if config.dbname else "" - return f"postgresql://{user_password}{netloc}{port}{dbname}" - - def build_ssh_uri(self, local_address: tuple) -> str: - """Build URI of format postgresql://[user[:password]@][ssh_host][:ssh_port][/dbname]""" - config = self.secrets_schema(**self.configuration.secrets or {}) - - user_password = "" - if config.username: - user = config.username - password = f":{config.password}" if config.password else "" - user_password = f"{user}{password}@" - - local_host, local_port = local_address - netloc = local_host - port = f":{local_port}" if local_port else "" - dbname = f"/{config.dbname}" if config.dbname else "" - return f"postgresql://{user_password}{netloc}{port}{dbname}" - - # Overrides SQLConnector.create_client - def create_client(self) -> Engine: - """Returns a SQLAlchemy Engine that can be used to interact with a database""" - if ( - self.configuration.secrets - and self.configuration.secrets.get("ssh_required", False) - and CONFIG.security.bastion_server_ssh_private_key - ): - config = self.secrets_schema(**self.configuration.secrets or {}) - self.create_ssh_tunnel(host=config.host, port=config.port) - self.ssh_server.start() - uri = self.build_ssh_uri(local_address=self.ssh_server.local_bind_address) - else: - uri = (self.configuration.secrets or {}).get("url") or self.build_uri() - return create_engine( - uri, - hide_parameters=self.hide_parameters, - echo=not self.hide_parameters, - ) - - def set_schema(self, connection: Connection) -> None: - """Sets the schema for a postgres database if applicable""" - config = self.secrets_schema(**self.configuration.secrets or {}) - if config.db_schema: - logger.info("Setting PostgreSQL search_path before retrieving data") - stmt = text("SET search_path to :search_path") - stmt = stmt.bindparams(search_path=config.db_schema) - connection.execute(stmt) - - def query_config(self, node: ExecutionNode) -> SQLQueryConfig: - """Query wrapper corresponding to the input execution_node.""" - return PostgresQueryConfig(node) - - -class MySQLConnector(SQLConnector): - """Connector specific to MySQL""" - - secrets_schema = MySQLSchema - - def build_uri(self) -> str: - """Build URI of format mysql+pymysql://[user[:password]@][netloc][:port][/dbname]""" - config = self.secrets_schema(**self.configuration.secrets or {}) - - user_password = "" - if config.username: - user = config.username - password = f":{config.password}" if config.password else "" - user_password = f"{user}{password}@" - - netloc = config.host - port = f":{config.port}" if config.port else "" - dbname = f"/{config.dbname}" if config.dbname else "" - url = f"mysql+pymysql://{user_password}{netloc}{port}{dbname}" - return url - - def build_ssh_uri(self, local_address: tuple) -> str: - """Build URI of format mysql+pymysql://[user[:password]@][ssh_host][:ssh_port][/dbname]""" - config = self.secrets_schema(**self.configuration.secrets or {}) - - user_password = "" - if config.username: - user 
= config.username - password = f":{config.password}" if config.password else "" - user_password = f"{user}{password}@" - - local_host, local_port = local_address - netloc = local_host - port = f":{local_port}" if local_port else "" - dbname = f"/{config.dbname}" if config.dbname else "" - url = f"mysql+pymysql://{user_password}{netloc}{port}{dbname}" - return url - - # Overrides SQLConnector.create_client - def create_client(self) -> Engine: - """Returns a SQLAlchemy Engine that can be used to interact with a database""" - if ( - self.configuration.secrets - and self.configuration.secrets.get("ssh_required", False) - and CONFIG.security.bastion_server_ssh_private_key - ): - config = self.secrets_schema(**self.configuration.secrets or {}) - self.create_ssh_tunnel(host=config.host, port=config.port) - self.ssh_server.start() - uri = self.build_ssh_uri(local_address=self.ssh_server.local_bind_address) - else: - uri = (self.configuration.secrets or {}).get("url") or self.build_uri() - return create_engine( - uri, - hide_parameters=self.hide_parameters, - echo=not self.hide_parameters, - ) - - def query_config(self, node: ExecutionNode) -> SQLQueryConfig: - """Query wrapper corresponding to the input execution_node.""" - return MySQLQueryConfig(node) - - @staticmethod - def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: - """ - Convert SQLAlchemy results to a list of dictionaries - """ - return SQLConnector.default_cursor_result_to_rows(results) - - -class MariaDBConnector(SQLConnector): - """Connector specific to MariaDB""" - - secrets_schema = MariaDBSchema - - def build_uri(self) -> str: - """Build URI of format mariadb+pymysql://[user[:password]@][netloc][:port][/dbname]""" - config = self.secrets_schema(**self.configuration.secrets or {}) - - user_password = "" - if config.username: - user = config.username - password = f":{config.password}" if config.password else "" - user_password = f"{user}{password}@" - - netloc = config.host - port = f":{config.port}" if config.port else "" - dbname = f"/{config.dbname}" if config.dbname else "" - url = f"mariadb+pymysql://{user_password}{netloc}{port}{dbname}" - return url - - @staticmethod - def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: - """ - Convert SQLAlchemy results to a list of dictionaries - """ - return SQLConnector.default_cursor_result_to_rows(results) - - -class RedshiftConnector(SQLConnector): - """Connector specific to Amazon Redshift""" - - secrets_schema = RedshiftSchema - - def build_ssh_uri(self, local_address: tuple) -> str: - """Build SSH URI of format redshift+psycopg2://[user[:password]@][ssh_host][:ssh_port][/dbname]""" - local_host, local_port = local_address - - config = self.secrets_schema(**self.configuration.secrets or {}) - - port = f":{local_port}" if local_port else "" - database = f"/{config.database}" if config.database else "" - url = f"redshift+psycopg2://{config.user}:{config.password}@{local_host}{port}{database}" - return url - - # Overrides BaseConnector.build_uri - def build_uri(self) -> str: - """Build URI of format redshift+psycopg2://user:password@[host][:port][/database]""" - config = self.secrets_schema(**self.configuration.secrets or {}) - - url_encoded_password = quote_plus(config.password) - port = f":{config.port}" if config.port else "" - database = f"/{config.database}" if config.database else "" - url = f"redshift+psycopg2://{config.user}:{url_encoded_password}@{config.host}{port}{database}" - return url - - # Overrides SQLConnector.create_client - def 
create_client(self) -> Engine: - """Returns a SQLAlchemy Engine that can be used to interact with a database""" - connect_args: Dict[str, Union[int, str]] = {} - connect_args["sslmode"] = "prefer" - - # keep alive settings to prevent long-running queries from causing a connection close - connect_args["keepalives"] = 1 - connect_args["keepalives_idle"] = 30 - connect_args["keepalives_interval"] = 5 - connect_args["keepalives_count"] = 5 - - if ( - self.configuration.secrets - and self.configuration.secrets.get("ssh_required", False) - and CONFIG.security.bastion_server_ssh_private_key - ): - config = self.secrets_schema(**self.configuration.secrets or {}) - self.create_ssh_tunnel(host=config.host, port=config.port) - self.ssh_server.start() - uri = self.build_ssh_uri(local_address=self.ssh_server.local_bind_address) - else: - uri = (self.configuration.secrets or {}).get("url") or self.build_uri() - return create_engine( - uri, - hide_parameters=self.hide_parameters, - echo=not self.hide_parameters, - connect_args=connect_args, - ) - - def set_schema(self, connection: Connection) -> None: - """Sets the search_path for the duration of the session""" - config = self.secrets_schema(**self.configuration.secrets or {}) - if config.db_schema: - logger.info("Setting Redshift search_path before retrieving data") - stmt = text("SET search_path to :search_path") - stmt = stmt.bindparams(search_path=config.db_schema) - connection.execute(stmt) - - # Overrides SQLConnector.query_config - def query_config(self, node: ExecutionNode) -> RedshiftQueryConfig: - """Query wrapper corresponding to the input execution node.""" - return RedshiftQueryConfig(node) - - -class BigQueryConnector(SQLConnector): - """Connector specific to Google BigQuery""" - - secrets_schema = BigQuerySchema - - # Overrides BaseConnector.build_uri - def build_uri(self) -> str: - """Build URI of format""" - config = self.secrets_schema(**self.configuration.secrets or {}) - dataset = f"/{config.dataset}" if config.dataset else "" - return f"bigquery://{config.keyfile_creds.project_id}{dataset}" # pylint: disable=no-member - - # Overrides SQLConnector.create_client - def create_client(self) -> Engine: - """ - Returns a SQLAlchemy Engine that can be used to interact with Google BigQuery. - - Overrides to pass in credentials_info - """ - secrets = self.configuration.secrets or {} - uri = secrets.get("url") or self.build_uri() - - keyfile_creds = secrets.get("keyfile_creds", {}) - credentials_info = dict(keyfile_creds) if keyfile_creds else {} - - return create_engine( - uri, - credentials_info=credentials_info, - hide_parameters=self.hide_parameters, - echo=not self.hide_parameters, - ) - - # Overrides SQLConnector.query_config - def query_config(self, node: ExecutionNode) -> BigQueryQueryConfig: - """Query wrapper corresponding to the input execution_node.""" - - db: Session = Session.object_session(self.configuration) - return BigQueryQueryConfig( - node, SQLConnector.get_namespace_meta(db, node.address.dataset) - ) - - def partitioned_retrieval( - self, - query_config: SQLQueryConfig, - connection: Connection, - stmt: TextClause, - ) -> List[Row]: - """ - Retrieve data against a partitioned table using the partitioning spec configured for this node to execute - multiple queries against the partitioned table. - - This is only supported by the BigQueryConnector currently. 
- - NOTE: when we deprecate `where_clause` partitioning in favor of a more proper partitioning DSL, - we should be sure to still support the existing `where_clause` partition definition on - any in-progress DSRs so that they can run through to completion. - """ - if not isinstance(query_config, BigQueryQueryConfig): - raise TypeError( - f"Unexpected query config of type '{type(query_config)}' passed to BigQueryConnector's `partitioned_retrieval`" - ) - - partition_clauses = query_config.get_partition_clauses() - logger.info( - f"Executing {len(partition_clauses)} partition queries for node '{query_config.node.address}' in DSR execution" - ) - rows = [] - for partition_clause in partition_clauses: - logger.debug( - f"Executing partition query with partition clause '{partition_clause}'" - ) - existing_bind_params = stmt.compile().params - partitioned_stmt = text(f"{stmt} AND ({text(partition_clause)})").params( - existing_bind_params - ) - results = connection.execute(partitioned_stmt) - rows.extend(self.cursor_result_to_rows(results)) - return rows - - # Overrides SQLConnector.test_connection - def test_connection(self) -> Optional[ConnectionTestStatus]: - """ - Overrides SQLConnector.test_connection with a BigQuery-specific connection test. - - The connection is tested using the native python client for BigQuery, since that is what's used - by the detection and discovery workflows/codepaths. - TODO: migrate the rest of this class, used for DSR execution, to also make use of the native bigquery client. - """ - try: - bq_schema = BigQuerySchema(**self.configuration.secrets or {}) - client = bq_schema.get_client() - all_projects = [project for project in client.list_projects()] - if all_projects: - return ConnectionTestStatus.succeeded - logger.error("No Bigquery Projects found with the provided credentials.") - raise ConnectionException( - "No Bigquery Projects found with the provided credentials." - ) - except Exception as e: - logger.exception(f"Error testing connection to remote BigQuery {str(e)}") - raise ConnectionException(f"Connection error: {e}") - - def mask_data( - self, - node: ExecutionNode, - policy: Policy, - privacy_request: PrivacyRequest, - request_task: RequestTask, - rows: List[Row], - ) -> int: - """Execute a masking request. 
Returns the number of records updated or deleted""" - query_config = self.query_config(node) - update_or_delete_ct = 0 - client = self.client() - for row in rows: - update_or_delete_stmts: List[Executable] = ( - query_config.generate_masking_stmt( - node, row, policy, privacy_request, client - ) - ) - if update_or_delete_stmts: - with client.connect() as connection: - for update_or_delete_stmt in update_or_delete_stmts: - results: LegacyCursorResult = connection.execute( - update_or_delete_stmt - ) - update_or_delete_ct = update_or_delete_ct + results.rowcount - return update_or_delete_ct - - -class SnowflakeConnector(SQLConnector): - """Connector specific to Snowflake""" - - secrets_schema = SnowflakeSchema - - def build_uri(self) -> str: - """Build URI of format 'snowflake://:@// - ?warehouse=&role=' - """ - config = self.secrets_schema(**self.configuration.secrets or {}) - - kwargs = {} - - if config.account_identifier: - kwargs["account"] = config.account_identifier - if config.user_login_name: - kwargs["user"] = config.user_login_name - if config.password: - kwargs["password"] = config.password - if config.database_name: - kwargs["database"] = config.database_name - if config.schema_name: - kwargs["schema"] = config.schema_name - if config.warehouse_name: - kwargs["warehouse"] = config.warehouse_name - if config.role_name: - kwargs["role"] = config.role_name - - url: str = Snowflake_URL(**kwargs) - return url - - def get_connect_args(self) -> Dict[str, Any]: - """Get connection arguments for the engine""" - config = self.secrets_schema(**self.configuration.secrets or {}) - connect_args: Dict[str, Union[str, bytes]] = {} - if config.private_key: - config.private_key = config.private_key.replace("\\n", "\n") - connect_args["private_key"] = config.private_key - if config.private_key_passphrase: - private_key_encoded = serialization.load_pem_private_key( - config.private_key.encode(), - password=config.private_key_passphrase.encode(), # pylint: disable=no-member - backend=default_backend(), - ) - private_key = private_key_encoded.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - connect_args["private_key"] = private_key - return connect_args - - def query_config(self, node: ExecutionNode) -> SQLQueryConfig: - """Query wrapper corresponding to the input execution_node.""" - - db: Session = Session.object_session(self.configuration) - return SnowflakeQueryConfig( - node, SQLConnector.get_namespace_meta(db, node.address.dataset) - ) - - -class MicrosoftSQLServerConnector(SQLConnector): - """ - Connector specific to Microsoft SQL Server - """ - - secrets_schema = MicrosoftSQLServerSchema - - def build_uri(self) -> URL: - """ - Build URI of format - mssql+pymssql://[username]:[password]@[host]:[port]/[dbname] - Returns URL obj, since SQLAlchemy's create_engine method accepts either a URL obj or a string - """ - - config = self.secrets_schema(**self.configuration.secrets or {}) - - url = URL.create( - "mssql+pymssql", - username=config.username, - password=config.password, - host=config.host, - port=config.port, - database=config.dbname, - ) - - return url - - def query_config(self, node: ExecutionNode) -> SQLQueryConfig: - """Query wrapper corresponding to the input execution_node.""" - return MicrosoftSQLServerQueryConfig(node) - - @staticmethod - def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: - """ - Convert SQLAlchemy results to a list of dictionaries - """ - 
return SQLConnector.default_cursor_result_to_rows(results) - - -class GoogleCloudSQLMySQLConnector(SQLConnector): - """Connector specific to Google Cloud SQL for MySQL""" - - secrets_schema = GoogleCloudSQLMySQLSchema - - # Overrides SQLConnector.create_client - def create_client(self) -> Engine: - """Returns a SQLAlchemy Engine that can be used to interact with a database""" - - config = self.secrets_schema(**self.configuration.secrets or {}) - - credentials = service_account.Credentials.from_service_account_info( - dict(config.keyfile_creds) - ) - - # initialize connector with the loaded credentials - connector = Connector(credentials=credentials) - - def getconn() -> pymysql.connections.Connection: - conn: pymysql.connections.Connection = connector.connect( - config.instance_connection_name, - "pymysql", - user=config.db_iam_user, - db=config.dbname, - enable_iam_auth=True, - ) - return conn - - return create_engine("mysql+pymysql://", creator=getconn) - - @staticmethod - def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: - """results to a list of dictionaries""" - return SQLConnector.default_cursor_result_to_rows(results) - - def build_uri(self) -> None: - """ - We need to override this method so it is not abstract anymore, and GoogleCloudSQLMySQLConnector is instantiable. - """ - - -class GoogleCloudSQLPostgresConnector(SQLConnector): - """Connector specific to Google Cloud SQL for Postgres""" - - secrets_schema = GoogleCloudSQLPostgresSchema - - @property - def default_db_name(self) -> str: - """Default database name for Google Cloud SQL Postgres""" - return "postgres" - - # Overrides SQLConnector.create_client - def create_client(self) -> Engine: - """Returns a SQLAlchemy Engine that can be used to interact with a database""" - - config = self.secrets_schema(**self.configuration.secrets or {}) - - credentials = service_account.Credentials.from_service_account_info( - dict(config.keyfile_creds) - ) - - # initialize connector with the loaded credentials - connector = Connector(credentials=credentials) - - def getconn() -> pg8000.dbapi.Connection: - conn: pg8000.dbapi.Connection = connector.connect( - config.instance_connection_name, - "pg8000", - user=config.db_iam_user, - db=config.dbname or self.default_db_name, - enable_iam_auth=True, - ) - return conn - - return create_engine("postgresql+pg8000://", creator=getconn) - - @staticmethod - def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: - """results to a list of dictionaries""" - return SQLConnector.default_cursor_result_to_rows(results) - - def build_uri(self) -> None: - """ - We need to override this method so it is not abstract anymore, and GoogleCloudSQLPostgresConnector is instantiable. 
- """ - - def set_schema(self, connection: Connection) -> None: - """Sets the schema for a postgres database if applicable""" - config = self.secrets_schema(**self.configuration.secrets or {}) - if config.db_schema: - logger.info("Setting PostgreSQL search_path before retrieving data") - stmt = text("SELECT set_config('search_path', :search_path, false)") - stmt = stmt.bindparams(search_path=config.db_schema) - connection.execute(stmt) - - # Overrides SQLConnector.query_config - def query_config(self, node: ExecutionNode) -> GoogleCloudSQLPostgresQueryConfig: - """Query wrapper corresponding to the input execution_node.""" - return GoogleCloudSQLPostgresQueryConfig(node) diff --git a/src/fides/api/service/connectors/timescale_connector.py b/src/fides/api/service/connectors/timescale_connector.py index 633e5347d7..2f16bd4eb1 100644 --- a/src/fides/api/service/connectors/timescale_connector.py +++ b/src/fides/api/service/connectors/timescale_connector.py @@ -1,4 +1,4 @@ -from fides.api.service.connectors.sql_connector import PostgreSQLConnector +from fides.api.service.connectors.postgres_connector import PostgreSQLConnector class TimescaleConnector(PostgreSQLConnector): diff --git a/tests/fixtures/application_fixtures.py b/tests/fixtures/application_fixtures.py index 54fe6f3784..8356c42111 100644 --- a/tests/fixtures/application_fixtures.py +++ b/tests/fixtures/application_fixtures.py @@ -939,6 +939,57 @@ def biquery_erasure_policy( pass +@pytest.fixture(scope="function") +def bigquery_enterprise_erasure_policy( + db: Session, + oauth_client: ClientDetail, +) -> Generator: + erasure_policy = Policy.create( + db=db, + data={ + "name": "example enterprise erasure policy", + "key": "example_enterprise_erasure_policy", + "client_id": oauth_client.id, + }, + ) + + erasure_rule = Rule.create( + db=db, + data={ + "action_type": ActionType.erasure.value, + "client_id": oauth_client.id, + "name": "Erasure Rule Enterprise", + "policy_id": erasure_policy.id, + "masking_strategy": { + "strategy": "null_rewrite", + "configuration": {}, + }, + }, + ) + + user_target = RuleTarget.create( + db=db, + data={ + "client_id": oauth_client.id, + "data_category": DataCategory("user.contact").value, + "rule_id": erasure_rule.id, + }, + ) + yield erasure_policy + try: + user_target.delete(db) + except ObjectDeletedError: + pass + try: + erasure_rule.delete(db) + except ObjectDeletedError: + pass + try: + erasure_policy.delete(db) + except ObjectDeletedError: + pass + + @pytest.fixture(scope="function") def erasure_policy_aes( db: Session, diff --git a/tests/fixtures/bigquery_fixtures.py b/tests/fixtures/bigquery_fixtures.py index 9c7ef2f2bc..105910e466 100644 --- a/tests/fixtures/bigquery_fixtures.py +++ b/tests/fixtures/bigquery_fixtures.py @@ -1,5 +1,7 @@ import ast import os +import random +from datetime import datetime from typing import Dict, Generator, List from uuid import uuid4 @@ -449,6 +451,101 @@ def bigquery_resources_with_namespace_meta( connection.execute(stmt) +@pytest.fixture(scope="function") +def bigquery_enterprise_resources( + bigquery_enterprise_test_dataset_config, +): + bigquery_connection_config = ( + bigquery_enterprise_test_dataset_config.connection_config + ) + connector = BigQueryConnector(bigquery_connection_config) + bigquery_client = connector.client() + with bigquery_client.connect() as connection: + + # Real max id in the Stackoverflow dataset is 20081052, so we purposefully generate and id above this max + stmt = "select max(id) from enterprise_dsr_testing.users;" + res = 
connection.execute(stmt) + # Increment the id by a random number to avoid conflicts on concurrent test runs + random_increment = random.randint(0, 99999) + user_id = res.all()[0][0] + random_increment + display_name = ( + f"fides_testing_{user_id}" # prefix to do manual cleanup if needed + ) + last_access_date = datetime.now() + creation_date = datetime.now() + location = "Dream World" + + # Create test user data + stmt = f""" + insert into enterprise_dsr_testing.users (id, display_name, last_access_date, creation_date, location) + values ({user_id}, '{display_name}', '{last_access_date}', '{creation_date}', '{location}'); + """ + connection.execute(stmt) + + # Create test stackoverflow_posts data. Posts are responses to questions on Stackoverflow and do not include the original question. + post_body = "For me, the solution was to adopt 3 cats and dance with them under the full moon at midnight." + stmt = "select max(id) from enterprise_dsr_testing.stackoverflow_posts;" + res = connection.execute(stmt) + random_increment = random.randint(0, 99999) + post_id = res.all()[0][0] + random_increment + stmt = f""" + insert into enterprise_dsr_testing.stackoverflow_posts (body, creation_date, id, owner_user_id, owner_display_name) + values ('{post_body}', '{creation_date}', {post_id}, {user_id}, '{display_name}'); + """ + connection.execute(stmt) + + # Create test comments data. Comments are responses to posts or questions on Stackoverflow and do not include the original question or the post itself. + stmt = "select max(id) from enterprise_dsr_testing.comments;" + res = connection.execute(stmt) + random_increment = random.randint(0, 99999) + comment_id = res.all()[0][0] + random_increment + comment_text = "FYI this only works if you have pytest installed locally."
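+ # Like the post above, the comment row inserted below references both the generated user_id and post_id, so a DSR for this test user can reach the comments collection through the enterprise dataset graph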
+ stmt = f""" + insert into enterprise_dsr_testing.comments (id, text, creation_date, post_id, user_id, user_display_name) + values ({comment_id}, '{comment_text}', '{creation_date}', {post_id}, {user_id}, '{display_name}'); + """ + connection.execute(stmt) + + # Create test post_history data + stmt = "select max(id) from enterprise_dsr_testing.comments;" + res = connection.execute(stmt) + random_increment = random.randint(0, 99999) + post_history_id = res.all()[0][0] + random_increment + revision_text = "this works if you have pytest" + uuid = str(uuid4()) + stmt = f""" + insert into enterprise_dsr_testing.post_history (id, text, creation_date, post_id, user_id, post_history_type_id, revision_guid) + values ({post_history_id}, '{revision_text}', '{creation_date}', {post_id}, {user_id}, 1, '{uuid}'); + """ + connection.execute(stmt) + + yield { + "name": display_name, + "user_id": user_id, + "comment_id": comment_id, + "post_history_id": post_history_id, + "post_id": post_id, + "client": bigquery_client, + "connector": connector, + "first_comment_text": comment_text, + "first_post_body": post_body, + "revision_text": revision_text, + "display_name": display_name, + } + # Remove test data and close BigQuery connection in teardown + stmt = f"delete from enterprise_dsr_testing.post_history where id = {post_history_id};" + connection.execute(stmt) + + stmt = f"delete from enterprise_dsr_testing.comments where id = {comment_id};" + connection.execute(stmt) + + stmt = f"delete from enterprise_dsr_testing.stackoverflow_posts where id = {post_id};" + connection.execute(stmt) + + stmt = f"delete from enterprise_dsr_testing.users where id = {user_id};" + connection.execute(stmt) + + @pytest.fixture(scope="session") def bigquery_test_engine(bigquery_keyfile_creds) -> Generator: """Return a connection to a Google BigQuery Warehouse""" diff --git a/tests/ops/api/v1/endpoints/test_dataset_test_endpoints.py b/tests/ops/api/v1/endpoints/test_dataset_test_endpoints.py index 347ab9a4f1..cc5a2cd20c 100644 --- a/tests/ops/api/v1/endpoints/test_dataset_test_endpoints.py +++ b/tests/ops/api/v1/endpoints/test_dataset_test_endpoints.py @@ -194,7 +194,7 @@ def test_dataset_reachability( assert set(response.json().keys()) == {"reachable", "details"} -@pytest.mark.integration_external +@pytest.mark.integration @pytest.mark.integration_postgres class TestDatasetTest: @pytest.fixture(scope="function") diff --git a/tests/ops/integration_tests/setup_scripts/postgres_setup.py b/tests/ops/integration_tests/setup_scripts/postgres_setup.py index 233726706a..e0cf34ab35 100644 --- a/tests/ops/integration_tests/setup_scripts/postgres_setup.py +++ b/tests/ops/integration_tests/setup_scripts/postgres_setup.py @@ -15,6 +15,7 @@ # Need to manually import this model because it's used in src/fides/api/models/property.py # but that file only imports it conditionally if TYPE_CHECKING is true +from fides.api.models.detection_discovery import MonitorConfig from fides.api.models.experience_notices import ExperienceNotices from fides.api.models.privacy_experience import PrivacyExperienceConfig from fides.api.service.connectors.sql_connector import PostgreSQLConnector diff --git a/tests/ops/integration_tests/test_connection_configuration_integration.py b/tests/ops/integration_tests/test_connection_configuration_integration.py index 87e641a6bc..ce4677464b 100644 --- a/tests/ops/integration_tests/test_connection_configuration_integration.py +++ b/tests/ops/integration_tests/test_connection_configuration_integration.py @@ -20,11 +20,11 @@ 
ScyllaConnector, get_connector, ) -from fides.api.service.connectors.sql_connector import ( - MariaDBConnector, +from fides.api.service.connectors.mariadb_connector import MariaDBConnector +from fides.api.service.connectors.microsoft_sql_server_connector import ( MicrosoftSQLServerConnector, - MySQLConnector, ) +from fides.api.service.connectors.mysql_connector import MySQLConnector from fides.common.api.scope_registry import ( CONNECTION_CREATE_OR_UPDATE, CONNECTION_READ, diff --git a/tests/ops/service/connection_config/test_mariadb_connector.py b/tests/ops/service/connection_config/test_mariadb_connector.py index 8a594d62f4..adac0b416b 100644 --- a/tests/ops/service/connection_config/test_mariadb_connector.py +++ b/tests/ops/service/connection_config/test_mariadb_connector.py @@ -1,6 +1,6 @@ from sqlalchemy.orm import Session -from fides.api.service.connectors.sql_connector import MariaDBConnector +from fides.api.service.connectors.mariadb_connector import MariaDBConnector def test_mariadb_connector_build_uri(connection_config_mariadb, db: Session): diff --git a/tests/ops/service/connection_config/test_mysql_connector.py b/tests/ops/service/connection_config/test_mysql_connector.py index f5c1583f2c..7ee4dfc741 100644 --- a/tests/ops/service/connection_config/test_mysql_connector.py +++ b/tests/ops/service/connection_config/test_mysql_connector.py @@ -1,6 +1,6 @@ from sqlalchemy.orm import Session -from fides.api.service.connectors.sql_connector import MySQLConnector +from fides.api.service.connectors.mysql_connector import MySQLConnector def test_mysql_connector_build_uri(connection_config_mysql, db: Session): diff --git a/tests/ops/service/connectors/test_bigquery_connector.py b/tests/ops/service/connectors/test_bigquery_connector.py index 98ad524204..a9524777fe 100644 --- a/tests/ops/service/connectors/test_bigquery_connector.py +++ b/tests/ops/service/connectors/test_bigquery_connector.py @@ -12,7 +12,7 @@ from fides.api.schemas.namespace_meta.bigquery_namespace_meta import ( BigQueryNamespaceMeta, ) -from fides.api.service.connectors.sql_connector import BigQueryConnector +from fides.api.service.connectors.bigquery_connector import BigQueryConnector @pytest.mark.integration_external diff --git a/tests/ops/service/connectors/test_bigquery_queryconfig.py b/tests/ops/service/connectors/test_bigquery_queryconfig.py index 771689e983..06c51c5105 100644 --- a/tests/ops/service/connectors/test_bigquery_queryconfig.py +++ b/tests/ops/service/connectors/test_bigquery_queryconfig.py @@ -13,7 +13,9 @@ BigQueryNamespaceMeta, ) from fides.api.service.connectors import BigQueryConnector -from fides.api.service.connectors.query_config import BigQueryQueryConfig +from fides.api.service.connectors.query_configs.bigquery_query_config import ( + BigQueryQueryConfig, +) @pytest.mark.integration_external diff --git a/tests/ops/service/connectors/test_queryconfig.py b/tests/ops/service/connectors/test_query_config.py similarity index 99% rename from tests/ops/service/connectors/test_queryconfig.py rename to tests/ops/service/connectors/test_query_config.py index d6e591d2d7..01d7b9dbd2 100644 --- a/tests/ops/service/connectors/test_queryconfig.py +++ b/tests/ops/service/connectors/test_query_config.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import Any, Dict, Generator, Set +from typing import Any, Dict, Set from unittest import mock import pytest @@ -22,9 +22,13 @@ from fides.api.schemas.masking.masking_configuration import HashMaskingConfiguration from 
fides.api.schemas.masking.masking_secrets import MaskingSecretCache, SecretType from fides.api.schemas.namespace_meta.namespace_meta import NamespaceMeta -from fides.api.service.connectors.query_config import ( +from fides.api.service.connectors.query_configs.dynamodb_query_config import ( DynamoDBQueryConfig, +) +from fides.api.service.connectors.query_configs.mongodb_query_config import ( MongoQueryConfig, +) +from fides.api.service.connectors.query_configs.query_config import ( QueryConfig, SQLQueryConfig, ) diff --git a/tests/ops/service/connectors/test_saas_queryconfig.py b/tests/ops/service/connectors/test_saas_query_config.py similarity index 99% rename from tests/ops/service/connectors/test_saas_queryconfig.py rename to tests/ops/service/connectors/test_saas_query_config.py index 4a2e2b44e7..1e9b6d326a 100644 --- a/tests/ops/service/connectors/test_saas_queryconfig.py +++ b/tests/ops/service/connectors/test_saas_query_config.py @@ -14,8 +14,8 @@ from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.saas.saas_config import ParamValue, SaaSConfig, SaaSRequest from fides.api.schemas.saas.shared_schemas import HTTPMethod, SaaSRequestParams +from fides.api.service.connectors.query_configs.saas_query_config import SaaSQueryConfig from fides.api.service.connectors.saas_connector import SaaSConnector -from fides.api.service.connectors.saas_query_config import SaaSQueryConfig from fides.api.util.saas_util import ( CUSTOM_PRIVACY_REQUEST_FIELDS, FIDESOPS_GROUPED_INPUTS, diff --git a/tests/ops/service/connectors/test_snowflake_connector.py b/tests/ops/service/connectors/test_snowflake_connector.py index 6ce44df2bf..af88314c7e 100644 --- a/tests/ops/service/connectors/test_snowflake_connector.py +++ b/tests/ops/service/connectors/test_snowflake_connector.py @@ -10,7 +10,7 @@ from fides.api.schemas.namespace_meta.snowflake_namespace_meta import ( SnowflakeNamespaceMeta, ) -from fides.api.service.connectors.sql_connector import SnowflakeConnector +from fides.api.service.connectors.snowflake_connector import SnowflakeConnector @pytest.mark.integration_external diff --git a/tests/ops/service/connectors/test_snowflake_query_config.py b/tests/ops/service/connectors/test_snowflake_query_config.py index 656ad75c3c..5521a1a88a 100644 --- a/tests/ops/service/connectors/test_snowflake_query_config.py +++ b/tests/ops/service/connectors/test_snowflake_query_config.py @@ -13,7 +13,9 @@ SnowflakeNamespaceMeta, ) from fides.api.service.connectors import SnowflakeConnector -from fides.api.service.connectors.query_config import SnowflakeQueryConfig +from fides.api.service.connectors.query_configs.snowflake_query_config import ( + SnowflakeQueryConfig, +) @pytest.mark.integration_external diff --git a/tests/ops/service/dataset/test_dataset_service.py b/tests/ops/service/dataset/test_dataset_service.py index c1e56f8d01..ca433671cf 100644 --- a/tests/ops/service/dataset/test_dataset_service.py +++ b/tests/ops/service/dataset/test_dataset_service.py @@ -111,9 +111,13 @@ def test_get_identities_and_references( assert required_identities == expected_required_identities -@pytest.mark.integration_external +@pytest.mark.integration @pytest.mark.integration_postgres class TestRunTestAccessRequest: + """ + Run test requests against the postgres_example database + """ + @pytest.mark.usefixtures("postgres_integration_db") def test_run_test_access_request( self, diff --git a/tests/ops/service/privacy_request/test_bigquery_enterprise_privacy_request.py 
b/tests/ops/service/privacy_request/test_bigquery_enterprise_privacy_request.py index 05fc8742a3..8fb7e29729 100644 --- a/tests/ops/service/privacy_request/test_bigquery_enterprise_privacy_request.py +++ b/tests/ops/service/privacy_request/test_bigquery_enterprise_privacy_request.py @@ -28,7 +28,7 @@ PRIVACY_REQUEST_TASK_TIMEOUT = 5 # External services take much longer to return -PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL = 60 +PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL = 100 @pytest.mark.integration_bigquery @@ -101,7 +101,7 @@ def test_create_and_process_access_request_bigquery_enterprise( len( [post["user_id"] for post in results["enterprise_dsr_testing:post_history"]] ) - == 60 + == 39 ) assert ( len( @@ -139,3 +139,144 @@ def test_create_and_process_access_request_bigquery_enterprise( pr.delete(db=db) assert not pr in db # Check that `pr` has been expunged from the session assert ExecutionLog.get(db, object_id=log_id).privacy_request_id == pr_id + + +@pytest.mark.integration_external +@pytest.mark.integration_bigquery +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +@pytest.mark.parametrize( + "bigquery_fixtures", + [ + "bigquery_enterprise_resources" + ], # todo- add other resources to test, e.g. partitioned data +) +def test_create_and_process_erasure_request_bigquery( + db, + request, + policy, + cache, + dsr_version, + bigquery_fixtures, + bigquery_enterprise_test_dataset_config, + bigquery_enterprise_erasure_policy, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + bigquery_enterprise_resources = request.getfixturevalue(bigquery_fixtures) + bigquery_client = bigquery_enterprise_resources["client"] + + # first test access request against manually added data + user_id = bigquery_enterprise_resources["user_id"] + customer_email = "customer-1@example.com" + data = { + "requested_at": "2024-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": { + "email": customer_email, + "stackoverflow_user_id": { + "label": "Stackoverflow User Id", + "value": user_id, + }, + }, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + + results = pr.get_raw_access_results() + assert len(results.keys()) == 4 + + for key in results.keys(): + assert results[key] is not None + assert results[key] != {} + + users = results["enterprise_dsr_testing:users"] + assert len(users) == 1 + user_details = users[0] + assert user_details["id"] == user_id + + assert ( + len( + [ + comment["user_id"] + for comment in results["enterprise_dsr_testing:comments"] + ] + ) + == 1 + ) + assert ( + len( + [post["user_id"] for post in results["enterprise_dsr_testing:post_history"]] + ) + == 1 + ) + assert ( + len( + [ + post["title"] + for post in results["enterprise_dsr_testing:stackoverflow_posts"] + ] + ) + == 1 + ) + + data = { + "requested_at": "2024-08-30T16:09:37.359Z", + "policy_key": bigquery_enterprise_erasure_policy.key, + "identity": { + "email": customer_email, + "stackoverflow_user_id": { + "label": "Stackoverflow User Id", + "value": bigquery_enterprise_resources["user_id"], + }, + }, + } + + # Should erase all user data + pr = get_privacy_request_results( + db, + bigquery_enterprise_erasure_policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + pr.delete(db=db) + + bigquery_client = bigquery_enterprise_resources["client"] + post_history_id = 
bigquery_enterprise_resources["post_history_id"] + comment_id = bigquery_enterprise_resources["comment_id"] + post_id = bigquery_enterprise_resources["post_id"] + with bigquery_client.connect() as connection: + stmt = f"select text from enterprise_dsr_testing.post_history where id = {post_history_id};" + res = connection.execute(stmt).all() + for row in res: + assert row.text is None + + stmt = f"select user_display_name, text from enterprise_dsr_testing.comments where id = {comment_id};" + res = connection.execute(stmt).all() + for row in res: + assert row.user_display_name is None + assert row.text is None + + stmt = f"select owner_user_id, owner_display_name, body from enterprise_dsr_testing.stackoverflow_posts where id = {post_id};" + res = connection.execute(stmt).all() + for row in res: + assert ( + row.owner_user_id == bigquery_enterprise_resources["user_id"] + ) # not targeted by policy + assert row.owner_display_name is None + assert row.body is None + + stmt = f"select display_name, location from enterprise_dsr_testing.users where id = {user_id};" + res = connection.execute(stmt).all() + for row in res: + assert row.display_name is None + assert row.location is None diff --git a/tests/ops/service/privacy_request/test_bigquery_privacy_requests.py b/tests/ops/service/privacy_request/test_bigquery_privacy_requests.py new file mode 100644 index 0000000000..2d73e72034 --- /dev/null +++ b/tests/ops/service/privacy_request/test_bigquery_privacy_requests.py @@ -0,0 +1,189 @@ +import pytest + +from tests.ops.service.privacy_request.test_request_runner_service import ( + PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + get_privacy_request_results, +) + + +@pytest.mark.integration_external +@pytest.mark.integration_bigquery +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_2_0", "use_dsr_3_0"], +) +@pytest.mark.parametrize( + "bigquery_fixtures", + ["bigquery_resources", "bigquery_resources_with_namespace_meta"], +) +def test_create_and_process_access_request_bigquery( + db, + policy, + dsr_version, + request, + bigquery_fixtures, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + bigquery_resources = request.getfixturevalue(bigquery_fixtures) + + customer_email = bigquery_resources["email"] + customer_name = bigquery_resources["name"] + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + bigquery_client = bigquery_resources["client"] + with bigquery_client.connect() as connection: + stmt = f"select * from fidesopstest.employee where address_id = {bigquery_resources['address_id']};" + res = connection.execute(stmt).all() + for row in res: + assert row.address_id == bigquery_resources["address_id"] + assert row.id == bigquery_resources["employee_id"] + assert row.email == bigquery_resources["employee_email"] + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + results = pr.get_raw_access_results() + customer_table_key = "bigquery_example_test_dataset:customer" + assert len(results[customer_table_key]) == 1 + assert results[customer_table_key][0]["email"] == customer_email + assert results[customer_table_key][0]["name"] == customer_name + + address_table_key = "bigquery_example_test_dataset:address" + + city = bigquery_resources["city"] + state = bigquery_resources["state"] + assert len(results[address_table_key]) == 1 + assert 
results[address_table_key][0]["city"] == city + assert results[address_table_key][0]["state"] == state + + employee_table_key = "bigquery_example_test_dataset:employee" + assert len(results[employee_table_key]) == 1 + assert results["bigquery_example_test_dataset:employee"] != [] + assert ( + results[employee_table_key][0]["address_id"] == bigquery_resources["address_id"] + ) + assert ( + results[employee_table_key][0]["email"] == bigquery_resources["employee_email"] + ) + assert results[employee_table_key][0]["id"] == bigquery_resources["employee_id"] + + # this covers access requests against a partitioned table + visit_partitioned_table_key = "bigquery_example_test_dataset:visit_partitioned" + assert len(results[visit_partitioned_table_key]) == 1 + assert ( + results[visit_partitioned_table_key][0]["email"] == bigquery_resources["email"] + ) + + pr.delete(db=db) + + +@pytest.mark.integration_external +@pytest.mark.integration_bigquery +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_2_0", "use_dsr_3_0"], +) +@pytest.mark.parametrize( + "bigquery_fixtures", + ["bigquery_resources", "bigquery_resources_with_namespace_meta"], +) +def test_create_and_process_erasure_request_bigquery( + db, + dsr_version, + request, + bigquery_fixtures, + erasure_policy, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + bigquery_resources = request.getfixturevalue(bigquery_fixtures) + + bigquery_client = bigquery_resources["client"] + # Verifying that employee info exists in db + with bigquery_client.connect() as connection: + stmt = f"select * from fidesopstest.employee where address_id = {bigquery_resources['address_id']};" + res = connection.execute(stmt).all() + for row in res: + assert row.address_id == bigquery_resources["address_id"] + assert row.id == bigquery_resources["employee_id"] + assert row.email == bigquery_resources["employee_email"] + + customer_email = bigquery_resources["email"] + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy.key, + "identity": {"email": customer_email}, + } + + # Should erase customer name + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + pr.delete(db=db) + + bigquery_client = bigquery_resources["client"] + with bigquery_client.connect() as connection: + stmt = ( + f"select name from fidesopstest.customer where email = '{customer_email}';" + ) + res = connection.execute(stmt).all() + for row in res: + assert row.name is None + + address_id = bigquery_resources["address_id"] + stmt = f"select 'id', city, state from fidesopstest.address where id = {address_id};" + res = connection.execute(stmt).all() + for row in res: + # Not yet masked because these fields aren't targeted by erasure policy + assert row.city == bigquery_resources["city"] + assert row.state == bigquery_resources["state"] + + target = erasure_policy.rules[0].targets[0] + target.data_category = "user.contact.address.state" + target.save(db=db) + + # Should erase state fields on address table + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + + bigquery_client = bigquery_resources["client"] + with bigquery_client.connect() as connection: + address_id = bigquery_resources["address_id"] + stmt = f"select 'id', city, state from fidesopstest.address where id = {address_id};" + res = 
connection.execute(stmt).all() + for row in res: + # State field was targeted by erasure policy but city was not + assert row.city is not None + assert row.state is None + + stmt = f"select 'id', city, state from fidesopstest.address where id = {address_id};" + res = connection.execute(stmt).all() + for row in res: + # State field was targeted by erasure policy but city was not + assert row.city is not None + assert row.state is None + + stmt = f"select * from fidesopstest.employee where address_id = {bigquery_resources['address_id']};" + res = connection.execute(stmt).all() + + # Employee records deleted entirely due to collection-level masking strategy override + assert res == [] + + pr.delete(db=db) diff --git a/tests/ops/service/privacy_request/test_dynamodb_privacy_requests.py b/tests/ops/service/privacy_request/test_dynamodb_privacy_requests.py new file mode 100644 index 0000000000..08e6dfa39f --- /dev/null +++ b/tests/ops/service/privacy_request/test_dynamodb_privacy_requests.py @@ -0,0 +1,249 @@ +from datetime import datetime, timezone +from typing import Dict +from uuid import uuid4 + +import pytest +from boto3.dynamodb.types import TypeDeserializer + +from fides.api.service.connectors.dynamodb_connector import DynamoDBConnector +from tests.ops.service.privacy_request.test_request_runner_service import ( + PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + get_privacy_request_results, +) + + +@pytest.fixture(scope="function") +def dynamodb_resources( + dynamodb_example_test_dataset_config, +): + dynamodb_connection_config = dynamodb_example_test_dataset_config.connection_config + dynamodb_client = DynamoDBConnector(dynamodb_connection_config).client() + uuid = str(uuid4()) + customer_email = f"customer-{uuid}@example.com" + customer_name = f"{uuid}" + + ## document and remove remaining comments if we can't get the bigger test running + items = { + "customer_identifier": [ + { + "customer_id": {"S": customer_name}, + "email": {"S": customer_email}, + "name": {"S": customer_name}, + "created": {"S": datetime.now(timezone.utc).isoformat()}, + } + ], + "customer": [ + { + "id": {"S": customer_name}, + "name": {"S": customer_name}, + "email": {"S": customer_email}, + "address_id": {"L": [{"S": customer_name}, {"S": customer_name}]}, + "personal_info": {"M": {"gender": {"S": "male"}, "age": {"S": "99"}}}, + "created": {"S": datetime.now(timezone.utc).isoformat()}, + } + ], + "address": [ + { + "id": {"S": customer_name}, + "city": {"S": "city"}, + "house": {"S": "house"}, + "state": {"S": "state"}, + "street": {"S": "street"}, + "zip": {"S": "zip"}, + } + ], + "login": [ + { + "customer_id": {"S": customer_name}, + "login_date": {"S": "2023-01-01"}, + "name": {"S": customer_name}, + "email": {"S": customer_email}, + }, + { + "customer_id": {"S": customer_name}, + "login_date": {"S": "2023-01-02"}, + "name": {"S": customer_name}, + "email": {"S": customer_email}, + }, + ], + } + + for table_name, rows in items.items(): + for item in rows: + res = dynamodb_client.put_item( + TableName=table_name, + Item=item, + ) + assert res["ResponseMetadata"]["HTTPStatusCode"] == 200 + + yield { + "email": customer_email, + "formatted_email": customer_email, + "name": customer_name, + "customer_id": uuid, + "client": dynamodb_client, + } + # Remove test data and close Dynamodb connection in teardown + delete_items = { + "customer_identifier": [{"email": {"S": customer_email}}], + "customer": [{"id": {"S": customer_name}}], + "address": [{"id": {"S": customer_name}}], + "login": [ + { + "customer_id": {"S": 
customer_name}, + "login_date": {"S": "2023-01-01"}, + }, + { + "customer_id": {"S": customer_name}, + "login_date": {"S": "2023-01-02"}, + }, + ], + } + for table_name, rows in delete_items.items(): + for item in rows: + res = dynamodb_client.delete_item( + TableName=table_name, + Key=item, + ) + assert res["ResponseMetadata"]["HTTPStatusCode"] == 200 + + +@pytest.mark.integration_external +@pytest.mark.integration_dynamodb +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_empty_access_request_dynamodb( + db, + cache, + policy, + dsr_version, + request, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": "thiscustomerdoesnot@exist.com"}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + # Here the results should be empty as no data will be located for that identity + results = pr.get_raw_access_results() + pr.delete(db=db) + assert results == {} + + +@pytest.mark.integration_external +@pytest.mark.integration_dynamodb +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_dynamodb( + dynamodb_resources, + db, + cache, + policy, + run_privacy_request_task, + dsr_version, + request, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = dynamodb_resources["email"] + customer_name = dynamodb_resources["name"] + customer_id = dynamodb_resources["customer_id"] + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + results = pr.get_raw_access_results() + customer_table_key = f"dynamodb_example_test_dataset:customer" + address_table_key = f"dynamodb_example_test_dataset:address" + login_table_key = f"dynamodb_example_test_dataset:login" + assert len(results[customer_table_key]) == 1 + assert len(results[address_table_key]) == 1 + assert len(results[login_table_key]) == 2 + assert results[customer_table_key][0]["email"] == customer_email + assert results[customer_table_key][0]["name"] == customer_name + assert results[customer_table_key][0]["id"] == customer_id + assert results[address_table_key][0]["id"] == customer_id + assert results[login_table_key][0]["name"] == customer_name + + pr.delete(db=db) + + +@pytest.mark.integration_external +@pytest.mark.integration_dynamodb +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_dynamodb( + dynamodb_example_test_dataset_config, + dynamodb_resources, + integration_config: Dict[str, str], + db, + cache, + erasure_policy, + dsr_version, + request, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = dynamodb_resources["email"] + dynamodb_client = dynamodb_resources["client"] + customer_id = dynamodb_resources["customer_id"] + customer_name = dynamodb_resources["name"] + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy.key, + "identity": {"email": customer_email}, + } + pr = get_privacy_request_results( 
+ db, + erasure_policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + pr.delete(db=db) + deserializer = TypeDeserializer() + customer = dynamodb_client.get_item( + TableName="customer", + Key={"id": {"S": customer_id}}, + ) + customer_identifier = dynamodb_client.get_item( + TableName="customer_identifier", + Key={"email": {"S": customer_email}}, + ) + login = dynamodb_client.get_item( + TableName="login", + Key={ + "customer_id": {"S": customer_name}, + "login_date": {"S": "2023-01-01"}, + }, + ) + assert deserializer.deserialize(customer["Item"]["name"]) == None + assert deserializer.deserialize(customer_identifier["Item"]["name"]) == None + assert deserializer.deserialize(login["Item"]["name"]) == None diff --git a/tests/ops/service/privacy_request/test_google_cloud_mysql_privacy_requests.py b/tests/ops/service/privacy_request/test_google_cloud_mysql_privacy_requests.py new file mode 100644 index 0000000000..7a449defd7 --- /dev/null +++ b/tests/ops/service/privacy_request/test_google_cloud_mysql_privacy_requests.py @@ -0,0 +1,116 @@ +from unittest import mock + +import pytest +from sqlalchemy import column, select, table +from sqlalchemy.orm import Session + +from tests.ops.service.privacy_request.test_request_runner_service import ( + PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + get_privacy_request_results, +) + + +@pytest.mark.integration_external +@pytest.mark.integration_google_cloud_sql_mysql +@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_google_cloud_sql_mysql( + trigger_webhook_mock, + google_cloud_sql_mysql_example_test_dataset_config, + google_cloud_sql_mysql_integration_db, + db: Session, + cache, + policy, + dsr_version, + request, + policy_pre_execution_webhooks, + policy_post_execution_webhooks, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + + results = pr.get_raw_access_results() + assert len(results.keys()) == 11 + + for key in results.keys(): + assert results[key] is not None + assert results[key] != {} + + result_key_prefix = "google_cloud_sql_mysql_example_test_dataset:" + customer_key = result_key_prefix + "customer" + assert results[customer_key][0]["email"] == customer_email + + visit_key = result_key_prefix + "visit" + assert results[visit_key][0]["email"] == customer_email + # Both pre-execution webhooks and both post-execution webhooks were called + assert trigger_webhook_mock.call_count == 4 + pr.delete(db=db) + + +@pytest.mark.integration_external +@pytest.mark.integration_google_cloud_sql_mysql +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_google_cloud_sql_mysql( + google_cloud_sql_mysql_integration_db, + google_cloud_sql_mysql_example_test_dataset_config, + cache, + db, + dsr_version, + request, + generate_auth_header, + erasure_policy, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" 
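+ # customer-1@example.com is expected to map to customer id 1 in the example dataset; the query below checks that this row's `name` column was nulled by the erasure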
+ customer_id = 1 + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + pr.delete(db=db) + + stmt = select( + column("id"), + column("name"), + ).select_from(table("customer")) + res = google_cloud_sql_mysql_integration_db.execute(stmt).all() + + customer_found = False + for row in res: + if customer_id == row.id: + customer_found = True + # Check that the `name` field is `None` + assert row.name is None + assert customer_found diff --git a/tests/ops/service/privacy_request/test_google_cloud_postgres_privacy_requests.py b/tests/ops/service/privacy_request/test_google_cloud_postgres_privacy_requests.py new file mode 100644 index 0000000000..87b0ca13d8 --- /dev/null +++ b/tests/ops/service/privacy_request/test_google_cloud_postgres_privacy_requests.py @@ -0,0 +1,117 @@ +from unittest import mock + +import pytest +from sqlalchemy import column, select, table +from sqlalchemy.orm import Session + +from tests.ops.service.privacy_request.test_request_runner_service import ( + PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + get_privacy_request_results, +) + + +@pytest.mark.integration_external +@pytest.mark.integration_google_cloud_sql_postgres +@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_google_cloud_sql_postgres( + trigger_webhook_mock, + google_cloud_sql_postgres_example_test_dataset_config, + google_cloud_sql_postgres_integration_db, + db: Session, + cache, + policy, + dsr_version, + request, + policy_pre_execution_webhooks, + policy_post_execution_webhooks, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + + results = pr.get_raw_access_results() + assert len(results.keys()) == 11 + + for key in results.keys(): + assert results[key] is not None + assert results[key] != {} + + result_key_prefix = "google_cloud_sql_postgres_example_test_dataset:" + customer_key = result_key_prefix + "customer" + assert results[customer_key][0]["email"] == customer_email + + visit_key = result_key_prefix + "visit" + assert results[visit_key][0]["email"] == customer_email + # Both pre-execution webhooks and both post-execution webhooks were called + assert trigger_webhook_mock.call_count == 4 + pr.delete(db=db) + + +@pytest.mark.integration_external +@pytest.mark.integration_google_cloud_sql_postgres +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_google_cloud_sql_postgres( + google_cloud_sql_postgres_integration_db, + google_cloud_sql_postgres_example_test_dataset_config, + cache, + db, + dsr_version, + request, + generate_auth_header, + erasure_policy, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + customer_id = 1 + data = { + "requested_at": 
"2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + pr.delete(db=db) + + stmt = select( + column("id"), + column("name"), + ).select_from(table("customer")) + + res = google_cloud_sql_postgres_integration_db.execute(stmt).all() + + customer_found = False + for row in res: + if customer_id == row.id: + customer_found = True + # Check that the `name` field is `None` + assert row.name is None + assert customer_found diff --git a/tests/ops/service/privacy_request/test_mariadb_privacy_requests.py b/tests/ops/service/privacy_request/test_mariadb_privacy_requests.py new file mode 100644 index 0000000000..569d6d6cb2 --- /dev/null +++ b/tests/ops/service/privacy_request/test_mariadb_privacy_requests.py @@ -0,0 +1,112 @@ +from unittest import mock + +import pytest +from sqlalchemy import column, select, table + +from tests.ops.service.privacy_request.test_request_runner_service import ( + get_privacy_request_results, +) + + +@pytest.mark.integration_mariadb +@pytest.mark.integration +@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_mariadb( + trigger_webhook_mock, + mariadb_example_test_dataset_config, + mariadb_integration_db, + db, + cache, + policy, + dsr_version, + request, + policy_pre_execution_webhooks, + policy_post_execution_webhooks, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + ) + + results = pr.get_raw_access_results() + assert len(results.keys()) == 11 + + for key in results.keys(): + assert results[key] is not None + assert results[key] != {} + + result_key_prefix = "mariadb_example_test_dataset:" + customer_key = result_key_prefix + "customer" + assert results[customer_key][0]["email"] == customer_email + + visit_key = result_key_prefix + "visit" + assert results[visit_key][0]["email"] == customer_email + # Both pre-execution webhooks and both post-execution webhooks were called + assert trigger_webhook_mock.call_count == 4 + pr.delete(db=db) + + +@pytest.mark.integration_mariadb +@pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_specific_category_mariadb( + mariadb_example_test_dataset_config, + mariadb_integration_db, + cache, + db, + dsr_version, + request, + generate_auth_header, + erasure_policy, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + customer_id = 1 + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + ) + pr.delete(db=db) + + stmt = select( + column("id"), + column("name"), + ).select_from(table("customer")) + res = mariadb_integration_db.execute(stmt).all() + + customer_found = 
False + for row in res: + if customer_id == row.id: + customer_found = True + # Check that the `name` field is `None` + assert row.name is None + assert customer_found diff --git a/tests/ops/service/privacy_request/test_mssql_privacy_requests.py b/tests/ops/service/privacy_request/test_mssql_privacy_requests.py new file mode 100644 index 0000000000..a36f561ccd --- /dev/null +++ b/tests/ops/service/privacy_request/test_mssql_privacy_requests.py @@ -0,0 +1,112 @@ +from unittest import mock + +import pytest +from sqlalchemy import column, select, table + +from tests.ops.service.privacy_request.test_request_runner_service import ( + get_privacy_request_results, +) + + +@pytest.mark.integration_mssql +@pytest.mark.integration +@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_mssql( + trigger_webhook_mock, + mssql_example_test_dataset_config, + mssql_integration_db, + db, + cache, + policy, + dsr_version, + request, + policy_pre_execution_webhooks, + policy_post_execution_webhooks, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + ) + + results = pr.get_raw_access_results() + assert len(results.keys()) == 11 + + for key in results.keys(): + assert results[key] is not None + assert results[key] != {} + + result_key_prefix = "mssql_example_test_dataset:" + customer_key = result_key_prefix + "customer" + assert results[customer_key][0]["email"] == customer_email + + visit_key = result_key_prefix + "visit" + assert results[visit_key][0]["email"] == customer_email + # Both pre-execution webhooks and both post-execution webhooks were called + assert trigger_webhook_mock.call_count == 4 + pr.delete(db=db) + + +@pytest.mark.integration_mssql +@pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_specific_category_mssql( + mssql_integration_db, + mssql_example_test_dataset_config, + cache, + db, + dsr_version, + request, + generate_auth_header, + erasure_policy, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + customer_id = 1 + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + ) + pr.delete(db=db) + + stmt = select( + column("id"), + column("name"), + ).select_from(table("customer")) + res = mssql_integration_db.execute(stmt).all() + + customer_found = False + for row in res: + if customer_id == row.id: + customer_found = True + # Check that the `name` field is `None` + assert row.name is None + assert customer_found diff --git a/tests/ops/service/privacy_request/test_mysql_privacy_requests.py b/tests/ops/service/privacy_request/test_mysql_privacy_requests.py new file mode 100644 index 0000000000..3f1db392b2 --- /dev/null +++ b/tests/ops/service/privacy_request/test_mysql_privacy_requests.py @@ -0,0 +1,112 @@ +from unittest import mock + 
+import pytest +from sqlalchemy import column, select, table + +from tests.ops.service.privacy_request.test_request_runner_service import ( + get_privacy_request_results, +) + + +@pytest.mark.integration +@pytest.mark.integration_mysql +@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_mysql( + trigger_webhook_mock, + mysql_example_test_dataset_config, + mysql_integration_db, + db, + cache, + policy, + dsr_version, + request, + policy_pre_execution_webhooks, + policy_post_execution_webhooks, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + ) + + results = pr.get_raw_access_results() + assert len(results.keys()) == 12 + + for key in results.keys(): + assert results[key] is not None + assert results[key] != {} + + result_key_prefix = f"mysql_example_test_dataset:" + customer_key = result_key_prefix + "customer" + assert results[customer_key][0]["email"] == customer_email + + visit_key = result_key_prefix + "visit" + assert results[visit_key][0]["email"] == customer_email + # Both pre-execution webhooks and both post-execution webhooks were called + assert trigger_webhook_mock.call_count == 4 + pr.delete(db=db) + + +@pytest.mark.integration_mysql +@pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_specific_category_mysql( + mysql_integration_db, + mysql_example_test_dataset_config, + cache, + db, + dsr_version, + request, + generate_auth_header, + erasure_policy, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + customer_id = 1 + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + ) + pr.delete(db=db) + + stmt = select( + column("id"), + column("name"), + ).select_from(table("customer")) + res = mysql_integration_db.execute(stmt).all() + + customer_found = False + for row in res: + if customer_id == row.id: + customer_found = True + # Check that the `name` field is `None` + assert row.name is None + assert customer_found diff --git a/tests/ops/service/privacy_request/test_postgres_privacy_requests.py b/tests/ops/service/privacy_request/test_postgres_privacy_requests.py new file mode 100644 index 0000000000..2959efd463 --- /dev/null +++ b/tests/ops/service/privacy_request/test_postgres_privacy_requests.py @@ -0,0 +1,763 @@ +from unittest import mock +from unittest.mock import Mock + +import pytest +from sqlalchemy import column, select, table + +from fides.api.graph.config import CollectionAddress, FieldPath +from fides.api.models.audit_log import AuditLog, AuditLogAction +from fides.api.models.privacy_request import ( + ExecutionLog, + ExecutionLogStatus, + PrivacyRequestStatus, +) +from fides.api.util.data_category import DataCategory +from tests.ops.service.privacy_request.test_request_runner_service import ( + 
PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + get_privacy_request_results, +) + + +@pytest.mark.integration_postgres +@pytest.mark.integration +@mock.patch("fides.api.service.privacy_request.request_runner_service.upload") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_upload_access_results_has_data_category_field_mapping( + upload_mock: Mock, + postgres_example_test_dataset_config_read_access, + postgres_integration_db, + db, + policy, + dsr_version, + request, + run_privacy_request_task, +): + """ + Ensure we are passing along a correctly populated data_category_field_mapping to the 'upload' function + that publishes the access request output. + """ + upload_mock.return_value = "http://www.data-download-url" + + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + ) + + # sanity check that acccess results returned as expected + results = pr.get_raw_access_results() + assert len(results.keys()) == 11 + + # what we're really testing - ensure data_category_field_mapping arg is well-populated + args, kwargs = upload_mock.call_args + data_category_field_mapping = kwargs["data_category_field_mapping"] + + # make sure the category field mapping generally looks as we expect + address_mapping = data_category_field_mapping[ + CollectionAddress.from_string("postgres_example_test_dataset:address") + ] + assert len(address_mapping) >= 5 + assert address_mapping["user.contact.address.street"] == [ + FieldPath("house"), + FieldPath("street"), + ] + product_mapping = data_category_field_mapping[ + CollectionAddress.from_string("postgres_example_test_dataset:product") + ] + assert len(product_mapping) >= 1 + assert product_mapping["system.operations"] == [ + FieldPath( + "id", + ), + FieldPath( + "name", + ), + FieldPath( + "price", + ), + ] + + +@pytest.mark.integration_postgres +@pytest.mark.integration +@mock.patch("fides.api.service.privacy_request.request_runner_service.upload") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_upload_access_results_has_data_use_map( + upload_mock: Mock, + postgres_example_test_dataset_config_read_access, + postgres_integration_db, + db, + policy, + dsr_version, + request, + run_privacy_request_task, +): + """ + Ensure we are passing along a correctly populated data_use_map to the 'upload' function + that publishes the access request output. 
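+ Each collection address is expected to map to the stringified set of its data uses (e.g. "{'marketing.advertising'}"), as asserted below.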
+ """ + upload_mock.return_value = "http://www.data-download-url" + + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + ) + + # sanity check that access results returned as expected + results = pr.get_raw_access_results() + assert len(results.keys()) == 11 + + # what we're really testing - ensure data_use_map arg is well-populated + args, kwargs = upload_mock.call_args + data_use_map = kwargs["data_use_map"] + + assert data_use_map == { + "postgres_example_test_dataset:report": "{'marketing.advertising'}", + "postgres_example_test_dataset:employee": "{'marketing.advertising'}", + "postgres_example_test_dataset:customer": "{'marketing.advertising'}", + "postgres_example_test_dataset:service_request": "{'marketing.advertising'}", + "postgres_example_test_dataset:visit": "{'marketing.advertising'}", + "postgres_example_test_dataset:address": "{'marketing.advertising'}", + "postgres_example_test_dataset:login": "{'marketing.advertising'}", + "postgres_example_test_dataset:orders": "{'marketing.advertising'}", + "postgres_example_test_dataset:payment_card": "{'marketing.advertising'}", + "postgres_example_test_dataset:order_item": "{'marketing.advertising'}", + "postgres_example_test_dataset:product": "{'marketing.advertising'}", + } + + +@pytest.mark.integration_postgres +@pytest.mark.integration +@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_postgres( + trigger_webhook_mock, + postgres_example_test_dataset_config_read_access, + postgres_integration_db, + db, + cache, + dsr_version, + request, + policy, + policy_pre_execution_webhooks, + policy_post_execution_webhooks, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + ) + + results = pr.get_raw_access_results() + assert len(results.keys()) == 11 + + for key in results.keys(): + assert results[key] is not None + assert results[key] != {} + + result_key_prefix = f"postgres_example_test_dataset:" + customer_key = result_key_prefix + "customer" + assert results[customer_key][0]["email"] == customer_email + + visit_key = result_key_prefix + "visit" + assert results[visit_key][0]["email"] == customer_email + log_id = pr.execution_logs[0].id + pr_id = pr.id + + finished_audit_log: AuditLog = AuditLog.filter( + db=db, + conditions=( + (AuditLog.privacy_request_id == pr_id) + & (AuditLog.action == AuditLogAction.finished) + ), + ).first() + + assert finished_audit_log is not None + + # Both pre-execution webhooks and both post-execution webhooks were called + assert trigger_webhook_mock.call_count == 4 + + for webhook in policy_pre_execution_webhooks: + webhook.delete(db=db) + + for webhook in policy_post_execution_webhooks: + webhook.delete(db=db) + + policy.delete(db=db) + pr.delete(db=db) + assert not pr in db # Check that `pr` has been expunged from the session + assert 
ExecutionLog.get(db, object_id=log_id).privacy_request_id == pr_id + + +@pytest.mark.integration_postgres +@pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +def test_create_and_process_access_request_with_custom_identities_postgres( + trigger_webhook_mock, + postgres_example_test_dataset_config_read_access, + postgres_example_test_extended_dataset_config, + postgres_integration_db, + db, + cache, + policy, + dsr_version, + request, + policy_pre_execution_webhooks, + policy_post_execution_webhooks, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + loyalty_id = "CH-1" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": { + "email": customer_email, + "loyalty_id": {"label": "Loyalty ID", "value": loyalty_id}, + }, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + ) + + results = pr.get_raw_access_results() + assert len(results.keys()) == 12 + + for key in results.keys(): + assert results[key] is not None + assert results[key] != {} + + result_key_prefix = f"postgres_example_test_dataset:" + customer_key = result_key_prefix + "customer" + assert results[customer_key][0]["email"] == customer_email + + visit_key = result_key_prefix + "visit" + assert results[visit_key][0]["email"] == customer_email + + loyalty_key = f"postgres_example_test_extended_dataset:loyalty" + assert results[loyalty_key][0]["id"] == loyalty_id + + log_id = pr.execution_logs[0].id + pr_id = pr.id + + finished_audit_log: AuditLog = AuditLog.filter( + db=db, + conditions=( + (AuditLog.privacy_request_id == pr_id) + & (AuditLog.action == AuditLogAction.finished) + ), + ).first() + + assert finished_audit_log is not None + + # Both pre-execution webhooks and both post-execution webhooks were called + assert trigger_webhook_mock.call_count == 4 + + for webhook in policy_pre_execution_webhooks: + webhook.delete(db=db) + + for webhook in policy_post_execution_webhooks: + webhook.delete(db=db) + + policy.delete(db=db) + pr.delete(db=db) + assert not pr in db # Check that `pr` has been expunged from the session + assert ExecutionLog.get(db, object_id=log_id).privacy_request_id == pr_id + + +@pytest.mark.integration_postgres +@pytest.mark.integration +@pytest.mark.usefixtures( + "postgres_example_test_dataset_config_skipped_login_collection", + "postgres_integration_db", + "cache", +) +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_with_valid_skipped_collection( + db, + policy, + run_privacy_request_task, + dsr_version, + request, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + ) + + results = pr.get_raw_access_results() + assert len(results.keys()) == 10 + + assert "login" not in results.keys() + + result_key_prefix = f"postgres_example_test_dataset:" + customer_key = result_key_prefix + "customer" + assert results[customer_key][0]["email"] == customer_email + + assert AuditLog.filter( + db=db, + conditions=( + 
(AuditLog.privacy_request_id == pr.id) + & (AuditLog.action == AuditLogAction.finished) + ), + ).first() + + +@pytest.mark.integration_postgres +@pytest.mark.integration +@pytest.mark.usefixtures( + "postgres_example_test_dataset_config_skipped_address_collection", + "postgres_integration_db", + "cache", +) +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_with_invalid_skipped_collection( + db, + policy, + dsr_version, + request, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + ) + + results = pr.get_raw_access_results() + assert len(results.keys()) == 0 + + db.refresh(pr) + + assert pr.status == PrivacyRequestStatus.error + + +@pytest.mark.integration_postgres +@pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_2_0", "use_dsr_3_0"], +) +def test_create_and_process_access_request_postgres_with_disabled_integration( + postgres_integration_db, + postgres_example_test_dataset_config, + connection_config, + db, + dsr_version, + request, + policy, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"external_id": "ext-123"}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + + for execution_log in pr.execution_logs: + assert execution_log.dataset_name == "Dataset traversal" + assert execution_log.status == ExecutionLogStatus.error + + connection_config.disabled = True + connection_config.save(db=db) + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + + assert pr.execution_logs.count() == 1 + + execution_log = pr.execution_logs[0] + assert execution_log.dataset_name == "Dataset traversal" + assert execution_log.status == ExecutionLogStatus.complete + + +@pytest.mark.integration_postgres +@pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_specific_category_postgres( + postgres_integration_db, + postgres_example_test_dataset_config, + cache, + db, + generate_auth_header, + erasure_policy, + dsr_version, + request, + read_connection_config, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + customer_id = 1 + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy.key, + "identity": {"email": customer_email}, + } + + stmt = select("*").select_from(table("customer")) + res = postgres_integration_db.execute(stmt).all() + + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + ) + pr.delete(db=db) + + stmt = select( + column("id"), + column("name"), + ).select_from(table("customer")) + res = postgres_integration_db.execute(stmt).all() + + customer_found = False + for row in res: + if customer_id == row.id: + customer_found = True + # Check that 
the `name` field is `None` + assert row.name is None + assert customer_found + + +@pytest.mark.integration_postgres +@pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_generic_category( + postgres_integration_db, + postgres_example_test_dataset_config, + cache, + db, + dsr_version, + request, + generate_auth_header, + erasure_policy, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + # It's safe to change this here since the `erasure_policy` fixture is scoped + # at function level + target = erasure_policy.rules[0].targets[0] + target.data_category = DataCategory("user.contact").value + target.save(db=db) + + email = "customer-2@example.com" + customer_id = 2 + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy.key, + "identity": {"email": email}, + } + + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + ) + pr.delete(db=db) + + stmt = select( + column("id"), + column("email"), + column("name"), + ).select_from(table("customer")) + res = postgres_integration_db.execute(stmt).all() + + customer_found = False + for row in res: + if customer_id == row.id: + customer_found = True + # Check that the `email` field is `None` and that its data category + # ("user.contact.email") has been erased by the parent + # category ("user.contact") + assert row.email is None + assert row.name is not None + else: + # There are two rows other rows, and they should not have been erased + assert row.email in ["customer-1@example.com", "jane@example.com"] + assert customer_found + + +@pytest.mark.integration_postgres +@pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_aes_generic_category( + postgres_integration_db, + postgres_example_test_dataset_config, + cache, + db, + dsr_version, + request, + generate_auth_header, + erasure_policy_aes, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + # It's safe to change this here since the `erasure_policy` fixture is scoped + # at function level + target = erasure_policy_aes.rules[0].targets[0] + target.data_category = DataCategory("user.contact").value + target.save(db=db) + + email = "customer-2@example.com" + customer_id = 2 + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy_aes.key, + "identity": {"email": email}, + } + + pr = get_privacy_request_results( + db, + erasure_policy_aes, + run_privacy_request_task, + data, + ) + pr.delete(db=db) + + stmt = select( + column("id"), + column("email"), + column("name"), + ).select_from(table("customer")) + res = postgres_integration_db.execute(stmt).all() + + customer_found = False + for row in res: + if customer_id == row.id: + customer_found = True + # Check that the `email` field is not original val and that its data category + # ("user.contact.email") has been erased by the parent + # category ("user.contact"). 
+ # masked val for `email` field will change per new privacy request, so the best + # we can do here is test that the original val has been changed + assert row[1] != "customer-2@example.com" + assert row[2] is not None + else: + # There are two rows other rows, and they should not have been erased + assert row[1] in ["customer-1@example.com", "jane@example.com"] + assert customer_found + + +@pytest.mark.integration_postgres +@pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_with_table_joins( + postgres_integration_db, + postgres_example_test_dataset_config, + db, + cache, + dsr_version, + request, + erasure_policy, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + # It's safe to change this here since the `erasure_policy` fixture is scoped + # at function level + target = erasure_policy.rules[0].targets[0] + target.data_category = DataCategory("user.financial").value + target.save(db=db) + + customer_email = "customer-1@example.com" + customer_id = 1 + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + ) + pr.delete(db=db) + + stmt = select( + column("customer_id"), + column("id"), + column("ccn"), + column("code"), + column("name"), + ).select_from(table("payment_card")) + res = postgres_integration_db.execute(stmt).all() + + card_found = False + for row in res: + if row.customer_id == customer_id: + card_found = True + assert row.ccn is None + assert row.code is None + assert row.name is None + + assert card_found is True + + +@pytest.mark.integration_postgres +@pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_read_access( + postgres_integration_db, + postgres_example_test_dataset_config_read_access, + db, + cache, + erasure_policy, + dsr_version, + request, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-2@example.com" + customer_id = 2 + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + ) + errored_execution_logs = pr.execution_logs.filter_by(status="error") + assert errored_execution_logs.count() == 9 + assert ( + errored_execution_logs[0].message + == "No values were erased since this connection " + "my_postgres_db_1_read_config has not been given write access" + ) + pr.delete(db=db) + + stmt = select( + column("id"), + column("name"), + ).select_from(table("customer")) + res = postgres_integration_db.execute(stmt).all() + + customer_found = False + for row in res: + if customer_id == row.id: + customer_found = True + # Check that the `name` field is NOT `None`. 
We couldn't erase, because the ConnectionConfig only had + # "read" access + assert row.name is not None + assert customer_found diff --git a/tests/ops/service/privacy_request/test_redshift_privacy_requests.py b/tests/ops/service/privacy_request/test_redshift_privacy_requests.py new file mode 100644 index 0000000000..71043f57cf --- /dev/null +++ b/tests/ops/service/privacy_request/test_redshift_privacy_requests.py @@ -0,0 +1,192 @@ +from typing import Dict +from uuid import uuid4 + +import pytest + +from fides.api.service.connectors.redshift_connector import RedshiftConnector +from tests.ops.service.privacy_request.test_request_runner_service import ( + PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + get_privacy_request_results, +) + + +@pytest.fixture(scope="function") +def redshift_resources( + redshift_example_test_dataset_config, +): + redshift_connection_config = redshift_example_test_dataset_config.connection_config + connector = RedshiftConnector(redshift_connection_config) + redshift_client = connector.client() + with redshift_client.connect() as connection: + connector.set_schema(connection) + uuid = str(uuid4()) + customer_email = f"customer-{uuid}@example.com" + customer_name = f"{uuid}" + + stmt = "select max(id) from customer;" + res = connection.execute(stmt) + customer_id = res.all()[0][0] + 1 + + stmt = "select max(id) from address;" + res = connection.execute(stmt) + address_id = res.all()[0][0] + 1 + + city = "Test City" + state = "TX" + stmt = f""" + insert into address (id, house, street, city, state, zip) + values ({address_id}, '{111}', 'Test Street', '{city}', '{state}', '55555'); + """ + connection.execute(stmt) + + stmt = f""" + insert into customer (id, email, name, address_id) + values ({customer_id}, '{customer_email}', '{customer_name}', '{address_id}'); + """ + connection.execute(stmt) + + yield { + "email": customer_email, + "name": customer_name, + "id": customer_id, + "client": redshift_client, + "address_id": address_id, + "city": city, + "state": state, + "connector": connector, + } + # Remove test data and close Redshift connection in teardown + stmt = f"delete from customer where email = '{customer_email}';" + connection.execute(stmt) + + stmt = f'delete from address where "id" = {address_id};' + connection.execute(stmt) + + +@pytest.mark.integration_external +@pytest.mark.integration_redshift +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_redshift( + redshift_resources, + db, + cache, + policy, + run_privacy_request_task, + dsr_version, + request, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = redshift_resources["email"] + customer_name = redshift_resources["name"] + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + results = pr.get_raw_access_results() + customer_table_key = "redshift_example_test_dataset:customer" + assert len(results[customer_table_key]) == 1 + assert results[customer_table_key][0]["email"] == customer_email + assert results[customer_table_key][0]["name"] == customer_name + + address_table_key = "redshift_example_test_dataset:address" + + city = redshift_resources["city"] + state = redshift_resources["state"] + assert len(results[address_table_key]) == 1 + assert 
results[address_table_key][0]["city"] == city + assert results[address_table_key][0]["state"] == state + + pr.delete(db=db) + + +@pytest.mark.integration_external +@pytest.mark.integration_redshift +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_redshift( + redshift_example_test_dataset_config, + redshift_resources, + integration_config: Dict[str, str], + db, + cache, + erasure_policy, + dsr_version, + request, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = redshift_resources["email"] + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy.key, + "identity": {"email": customer_email}, + } + + # Should erase customer name + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + pr.delete(db=db) + + connector = redshift_resources["connector"] + redshift_client = redshift_resources["client"] + with redshift_client.connect() as connection: + connector.set_schema(connection) + stmt = f"select name from customer where email = '{customer_email}';" + res = connection.execute(stmt).all() + for row in res: + assert row.name is None + + address_id = redshift_resources["address_id"] + stmt = f"select 'id', city, state from address where id = {address_id};" + res = connection.execute(stmt).all() + for row in res: + # Not yet masked because these fields aren't targeted by erasure policy + assert row.city == redshift_resources["city"] + assert row.state == redshift_resources["state"] + + target = erasure_policy.rules[0].targets[0] + target.data_category = "user.contact.address.state" + target.save(db=db) + + # Should erase state fields on address table + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + pr.delete(db=db) + + connector = redshift_resources["connector"] + redshift_client = redshift_resources["client"] + with redshift_client.connect() as connection: + connector.set_schema(connection) + + address_id = redshift_resources["address_id"] + stmt = f"select 'id', city, state from address where id = {address_id};" + res = connection.execute(stmt).all() + for row in res: + # State field was targeted by erasure policy but city was not + assert row.city is not None + assert row.state is None diff --git a/tests/ops/service/privacy_request/test_request_runner_service.py b/tests/ops/service/privacy_request/test_request_runner_service.py index e644f06814..cda747792a 100644 --- a/tests/ops/service/privacy_request/test_request_runner_service.py +++ b/tests/ops/service/privacy_request/test_request_runner_service.py @@ -1,26 +1,20 @@ # pylint: disable=missing-docstring, redefined-outer-name import time -from datetime import datetime, timezone from typing import Any, Dict, List, Set from unittest import mock from unittest.mock import ANY, Mock, call -from uuid import uuid4 import pydash import pytest -from boto3.dynamodb.types import TypeDeserializer from pydantic import ValidationError -from sqlalchemy import column, select, table from sqlalchemy.orm import Session from fides.api.common_exceptions import ( ClientUnsuccessfulException, PrivacyRequestPaused, ) -from fides.api.graph.config import CollectionAddress, FieldPath from fides.api.graph.graph import DatasetGraph from fides.api.models.application_config import 
ApplicationConfig -from fides.api.models.audit_log import AuditLog, AuditLogAction from fides.api.models.policy import CurrentStep, PolicyPostWebhook from fides.api.models.privacy_request import ( ActionType, @@ -31,10 +25,7 @@ PrivacyRequest, PrivacyRequestStatus, ) -from fides.api.schemas.masking.masking_configuration import ( - HmacMaskingConfiguration, - MaskingConfiguration, -) +from fides.api.schemas.masking.masking_configuration import MaskingConfiguration from fides.api.schemas.masking.masking_secrets import MaskingSecretCache from fides.api.schemas.messaging.messaging import ( AccessRequestCompleteBodyParams, @@ -44,25 +35,18 @@ from fides.api.schemas.policy import Rule from fides.api.schemas.privacy_request import Consent from fides.api.schemas.redis_cache import Identity -from fides.api.schemas.saas.saas_config import SaaSRequest -from fides.api.schemas.saas.shared_schemas import HTTPMethod, SaaSRequestParams -from fides.api.service.connectors.dynamodb_connector import DynamoDBConnector -from fides.api.service.connectors.saas_connector import SaaSConnector -from fides.api.service.connectors.sql_connector import RedshiftConnector from fides.api.service.masking.strategy.masking_strategy import MaskingStrategy -from fides.api.service.masking.strategy.masking_strategy_hmac import HmacMaskingStrategy from fides.api.service.privacy_request.request_runner_service import ( build_consent_dataset_graph, needs_batch_email_send, run_webhooks_and_report_status, ) -from fides.api.util.data_category import DataCategory from fides.common.api.v1.urn_registry import REQUEST_TASK_CALLBACK, V1_URL_PREFIX from fides.config import CONFIG PRIVACY_REQUEST_TASK_TIMEOUT = 5 # External services take much longer to return -PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL = 60 +PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL = 100 @pytest.fixture(scope="function") @@ -298,1957 +282,78 @@ def test_resume_privacy_request_from_erasure( privacy_request.save(db) updated_at = privacy_request.updated_at - run_privacy_request_task.delay( - privacy_request_id=privacy_request.id, - from_step=CurrentStep.erasure.value, - ).get(timeout=PRIVACY_REQUEST_TASK_TIMEOUT) - - db.refresh(privacy_request) - assert privacy_request.started_processing_at is not None - assert privacy_request.updated_at > updated_at - - # Starting privacy request in the middle of the graph means we don't run pre-webhooks again - assert run_webhooks.call_count == 1 - assert run_webhooks.call_args[1]["webhook_cls"] == PolicyPostWebhook - - assert run_access.call_count == 0 # Access request skipped - assert run_erasure.call_count == 1 # Erasure request runs - - assert mock_email_dispatch.call_count == 1 - - -def get_privacy_request_results( - db, - policy, - run_privacy_request_task, - privacy_request_data: Dict[str, Any], - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT, -) -> PrivacyRequest: - """Utility method to run a privacy request and return results after waiting for - the returned future.""" - kwargs = { - "requested_at": pydash.get(privacy_request_data, "requested_at"), - "policy_id": policy.id, - "status": "pending", - } - optional_fields = ["started_processing_at", "finished_processing_at"] - for field in optional_fields: - try: - attr = getattr(privacy_request_data, field) - if attr is not None: - kwargs[field] = attr - except AttributeError: - pass - privacy_request = PrivacyRequest.create(db=db, data=kwargs) - privacy_request.cache_identity(privacy_request_data["identity"]) - privacy_request.cache_custom_privacy_request_fields( - 
privacy_request_data.get("custom_privacy_request_fields", None) - ) - if "encryption_key" in privacy_request_data: - privacy_request.cache_encryption(privacy_request_data["encryption_key"]) - - erasure_rules: List[Rule] = policy.get_rules_for_action( - action_type=ActionType.erasure - ) - unique_masking_strategies_by_name: Set[str] = set() - for rule in erasure_rules: - strategy_name: str = rule.masking_strategy["strategy"] - configuration: MaskingConfiguration = rule.masking_strategy["configuration"] - if strategy_name in unique_masking_strategies_by_name: - continue - unique_masking_strategies_by_name.add(strategy_name) - masking_strategy = MaskingStrategy.get_strategy(strategy_name, configuration) - if masking_strategy.secrets_required(): - masking_secrets: List[MaskingSecretCache] = ( - masking_strategy.generate_secrets_for_cache() - ) - for masking_secret in masking_secrets: - privacy_request.cache_masking_secret(masking_secret) - - run_privacy_request_task.delay(privacy_request.id).get( - timeout=task_timeout, - ) - - return PrivacyRequest.get(db=db, object_id=privacy_request.id) - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@mock.patch("fides.api.service.privacy_request.request_runner_service.upload") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_upload_access_results_has_data_category_field_mapping( - upload_mock: Mock, - postgres_example_test_dataset_config_read_access, - postgres_integration_db, - db, - policy, - dsr_version, - request, - run_privacy_request_task, -): - """ - Ensure we are passing along a correctly populated data_category_field_mapping to the 'upload' function - that publishes the access request output. - """ - upload_mock.return_value = "http://www.data-download-url" - - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - ) - - # sanity check that acccess results returned as expected - results = pr.get_raw_access_results() - assert len(results.keys()) == 11 - - # what we're really testing - ensure data_category_field_mapping arg is well-populated - args, kwargs = upload_mock.call_args - data_category_field_mapping = kwargs["data_category_field_mapping"] - - # make sure the category field mapping generally looks as we expect - address_mapping = data_category_field_mapping[ - CollectionAddress.from_string("postgres_example_test_dataset:address") - ] - assert len(address_mapping) >= 5 - assert address_mapping["user.contact.address.street"] == [ - FieldPath("house"), - FieldPath("street"), - ] - product_mapping = data_category_field_mapping[ - CollectionAddress.from_string("postgres_example_test_dataset:product") - ] - assert len(product_mapping) >= 1 - assert product_mapping["system.operations"] == [ - FieldPath( - "id", - ), - FieldPath( - "name", - ), - FieldPath( - "price", - ), - ] - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@mock.patch("fides.api.service.privacy_request.request_runner_service.upload") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_upload_access_results_has_data_use_map( - upload_mock: Mock, - postgres_example_test_dataset_config_read_access, - postgres_integration_db, - db, - policy, - dsr_version, - request, - 
run_privacy_request_task, -): - """ - Ensure we are passing along a correctly populated data_use_map to the 'upload' function - that publishes the access request output. - """ - upload_mock.return_value = "http://www.data-download-url" - - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - ) - - # sanity check that access results returned as expected - results = pr.get_raw_access_results() - assert len(results.keys()) == 11 - - # what we're really testing - ensure data_use_map arg is well-populated - args, kwargs = upload_mock.call_args - data_use_map = kwargs["data_use_map"] - - assert data_use_map == { - "postgres_example_test_dataset:report": "{'marketing.advertising'}", - "postgres_example_test_dataset:employee": "{'marketing.advertising'}", - "postgres_example_test_dataset:customer": "{'marketing.advertising'}", - "postgres_example_test_dataset:service_request": "{'marketing.advertising'}", - "postgres_example_test_dataset:visit": "{'marketing.advertising'}", - "postgres_example_test_dataset:address": "{'marketing.advertising'}", - "postgres_example_test_dataset:login": "{'marketing.advertising'}", - "postgres_example_test_dataset:orders": "{'marketing.advertising'}", - "postgres_example_test_dataset:payment_card": "{'marketing.advertising'}", - "postgres_example_test_dataset:order_item": "{'marketing.advertising'}", - "postgres_example_test_dataset:product": "{'marketing.advertising'}", - } - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_postgres( - trigger_webhook_mock, - postgres_example_test_dataset_config_read_access, - postgres_integration_db, - db, - cache, - dsr_version, - request, - policy, - policy_pre_execution_webhooks, - policy_post_execution_webhooks, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - ) - - results = pr.get_raw_access_results() - assert len(results.keys()) == 11 - - for key in results.keys(): - assert results[key] is not None - assert results[key] != {} - - result_key_prefix = f"postgres_example_test_dataset:" - customer_key = result_key_prefix + "customer" - assert results[customer_key][0]["email"] == customer_email - - visit_key = result_key_prefix + "visit" - assert results[visit_key][0]["email"] == customer_email - log_id = pr.execution_logs[0].id - pr_id = pr.id - - finished_audit_log: AuditLog = AuditLog.filter( - db=db, - conditions=( - (AuditLog.privacy_request_id == pr_id) - & (AuditLog.action == AuditLogAction.finished) - ), - ).first() - - assert finished_audit_log is not None - - # Both pre-execution webhooks and both post-execution webhooks were called - assert trigger_webhook_mock.call_count == 4 - - for webhook in policy_pre_execution_webhooks: - webhook.delete(db=db) - - for webhook in 
policy_post_execution_webhooks: - webhook.delete(db=db) - - policy.delete(db=db) - pr.delete(db=db) - assert not pr in db # Check that `pr` has been expunged from the session - assert ExecutionLog.get(db, object_id=log_id).privacy_request_id == pr_id - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -def test_create_and_process_access_request_with_custom_identities_postgres( - trigger_webhook_mock, - postgres_example_test_dataset_config_read_access, - postgres_example_test_extended_dataset_config, - postgres_integration_db, - db, - cache, - policy, - dsr_version, - request, - policy_pre_execution_webhooks, - policy_post_execution_webhooks, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - loyalty_id = "CH-1" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": { - "email": customer_email, - "loyalty_id": {"label": "Loyalty ID", "value": loyalty_id}, - }, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - ) - - results = pr.get_raw_access_results() - assert len(results.keys()) == 12 - - for key in results.keys(): - assert results[key] is not None - assert results[key] != {} - - result_key_prefix = f"postgres_example_test_dataset:" - customer_key = result_key_prefix + "customer" - assert results[customer_key][0]["email"] == customer_email - - visit_key = result_key_prefix + "visit" - assert results[visit_key][0]["email"] == customer_email - - loyalty_key = f"postgres_example_test_extended_dataset:loyalty" - assert results[loyalty_key][0]["id"] == loyalty_id - - log_id = pr.execution_logs[0].id - pr_id = pr.id - - finished_audit_log: AuditLog = AuditLog.filter( - db=db, - conditions=( - (AuditLog.privacy_request_id == pr_id) - & (AuditLog.action == AuditLogAction.finished) - ), - ).first() - - assert finished_audit_log is not None - - # Both pre-execution webhooks and both post-execution webhooks were called - assert trigger_webhook_mock.call_count == 4 - - for webhook in policy_pre_execution_webhooks: - webhook.delete(db=db) - - for webhook in policy_post_execution_webhooks: - webhook.delete(db=db) - - policy.delete(db=db) - pr.delete(db=db) - assert not pr in db # Check that `pr` has been expunged from the session - assert ExecutionLog.get(db, object_id=log_id).privacy_request_id == pr_id - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.usefixtures( - "postgres_example_test_dataset_config_skipped_login_collection", - "postgres_integration_db", - "cache", -) -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_with_valid_skipped_collection( - db, - policy, - run_privacy_request_task, - dsr_version, - request, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - ) - - results = pr.get_raw_access_results() - assert len(results.keys()) == 10 - - assert "login" not in results.keys() - - result_key_prefix = 
f"postgres_example_test_dataset:" - customer_key = result_key_prefix + "customer" - assert results[customer_key][0]["email"] == customer_email - - assert AuditLog.filter( - db=db, - conditions=( - (AuditLog.privacy_request_id == pr.id) - & (AuditLog.action == AuditLogAction.finished) - ), - ).first() - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.usefixtures( - "postgres_example_test_dataset_config_skipped_address_collection", - "postgres_integration_db", - "cache", -) -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_with_invalid_skipped_collection( - db, - policy, - dsr_version, - request, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - ) - - results = pr.get_raw_access_results() - assert len(results.keys()) == 0 - - db.refresh(pr) - - assert pr.status == PrivacyRequestStatus.error - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_2_0", "use_dsr_3_0"], -) -def test_create_and_process_access_request_postgres_with_disabled_integration( - postgres_integration_db, - postgres_example_test_dataset_config, - connection_config, - db, - dsr_version, - request, - policy, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"external_id": "ext-123"}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - - for execution_log in pr.execution_logs: - assert execution_log.dataset_name == "Dataset traversal" - assert execution_log.status == ExecutionLogStatus.error - - connection_config.disabled = True - connection_config.save(db=db) - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - - assert pr.execution_logs.count() == 1 - - execution_log = pr.execution_logs[0] - assert execution_log.dataset_name == "Dataset traversal" - assert execution_log.status == ExecutionLogStatus.complete - - -@pytest.mark.integration -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_mssql( - trigger_webhook_mock, - mssql_example_test_dataset_config, - mssql_integration_db, - db, - cache, - policy, - dsr_version, - request, - policy_pre_execution_webhooks, - policy_post_execution_webhooks, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - ) - - results = pr.get_raw_access_results() - assert len(results.keys()) == 11 - - for key in results.keys(): - assert results[key] is not None - assert 
results[key] != {} - - result_key_prefix = f"mssql_example_test_dataset:" - customer_key = result_key_prefix + "customer" - assert results[customer_key][0]["email"] == customer_email - - visit_key = result_key_prefix + "visit" - assert results[visit_key][0]["email"] == customer_email - # Both pre-execution webhooks and both post-execution webhooks were called - assert trigger_webhook_mock.call_count == 4 - pr.delete(db=db) - - -@pytest.mark.integration -@pytest.mark.integration_mysql -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_mysql( - trigger_webhook_mock, - mysql_example_test_dataset_config, - mysql_integration_db, - db, - cache, - policy, - dsr_version, - request, - policy_pre_execution_webhooks, - policy_post_execution_webhooks, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - ) - - results = pr.get_raw_access_results() - assert len(results.keys()) == 12 - - for key in results.keys(): - assert results[key] is not None - assert results[key] != {} - - result_key_prefix = f"mysql_example_test_dataset:" - customer_key = result_key_prefix + "customer" - assert results[customer_key][0]["email"] == customer_email - - visit_key = result_key_prefix + "visit" - assert results[visit_key][0]["email"] == customer_email - # Both pre-execution webhooks and both post-execution webhooks were called - assert trigger_webhook_mock.call_count == 4 - pr.delete(db=db) - - -@pytest.mark.integration -@pytest.mark.integration_scylladb -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_scylladb( - trigger_webhook_mock, - scylladb_test_dataset_config, - scylla_reset_db, - db, - cache, - policy, - dsr_version, - request, - policy_pre_execution_webhooks, - policy_post_execution_webhooks, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - ) - - results = pr.get_raw_access_results() - assert len(results.keys()) == 4 - - assert "scylladb_example_test_dataset:users" in results - assert len(results["scylladb_example_test_dataset:users"]) == 1 - assert results["scylladb_example_test_dataset:users"][0]["email"] == customer_email - assert results["scylladb_example_test_dataset:users"][0]["age"] == 41 - assert results["scylladb_example_test_dataset:users"][0][ - "alternative_contacts" - ] == {"phone": "+1 (531) 988-5905", "work_email": "customer-1@example.com"} - - assert "scylladb_example_test_dataset:user_activity" in results - assert len(results["scylladb_example_test_dataset:user_activity"]) == 3 - - for activity in results["scylladb_example_test_dataset:user_activity"]: - assert activity["user_id"] - assert activity["timestamp"] - assert 
activity["activity_type"] - assert activity["user_agent"] - - assert "scylladb_example_test_dataset:payment_methods" in results - assert len(results["scylladb_example_test_dataset:payment_methods"]) == 2 - for payment_method in results["scylladb_example_test_dataset:payment_methods"]: - assert payment_method["payment_method_id"] - assert payment_method["card_number"] - assert payment_method["expiration_date"] - - assert "scylladb_example_test_dataset:orders" in results - assert len(results["scylladb_example_test_dataset:orders"]) == 2 - for payment_method in results["scylladb_example_test_dataset:orders"]: - assert payment_method["order_amount"] - assert payment_method["order_date"] - assert payment_method["order_description"] - - # Both pre-execution webhooks and both post-execution webhooks were called - assert trigger_webhook_mock.call_count == 4 - pr.delete(db=db) - - -@pytest.mark.integration -@pytest.mark.integration_scylladb -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0"], -) -def test_create_and_process_access_request_scylladb_no_keyspace( - trigger_webhook_mock, - scylladb_test_dataset_config_no_keyspace, - scylla_reset_db, - db, - cache, - policy, - dsr_version, - request, - policy_pre_execution_webhooks, - policy_post_execution_webhooks, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - ) - - assert ( - pr.access_tasks.count() == 6 - ) # There's 4 tables plus the root and terminal "dummy" tasks - - # Root task should be completed - assert pr.access_tasks.first().collection_name == "__ROOT__" - assert pr.access_tasks.first().status == ExecutionLogStatus.complete - - # All other tasks should be error - for access_task in pr.access_tasks.offset(1): - assert access_task.status == ExecutionLogStatus.error - - results = pr.get_raw_access_results() - assert results == {} - - -@pytest.mark.integration_external -@pytest.mark.integration_google_cloud_sql_mysql -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_google_cloud_sql_mysql( - trigger_webhook_mock, - google_cloud_sql_mysql_example_test_dataset_config, - google_cloud_sql_mysql_integration_db, - db: Session, - cache, - policy, - dsr_version, - request, - policy_pre_execution_webhooks, - policy_post_execution_webhooks, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - - results = pr.get_raw_access_results() - assert len(results.keys()) == 11 - - for key in results.keys(): - assert results[key] is not None - assert results[key] != {} - - result_key_prefix = "google_cloud_sql_mysql_example_test_dataset:" - customer_key = result_key_prefix + "customer" - assert 
results[customer_key][0]["email"] == customer_email - - visit_key = result_key_prefix + "visit" - assert results[visit_key][0]["email"] == customer_email - # Both pre-execution webhooks and both post-execution webhooks were called - assert trigger_webhook_mock.call_count == 4 - pr.delete(db=db) - - -@pytest.mark.integration_mariadb -@pytest.mark.integration -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_mariadb( - trigger_webhook_mock, - mariadb_example_test_dataset_config, - mariadb_integration_db, - db, - cache, - policy, - dsr_version, - request, - policy_pre_execution_webhooks, - policy_post_execution_webhooks, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - ) - - results = pr.get_raw_access_results() - assert len(results.keys()) == 11 - - for key in results.keys(): - assert results[key] is not None - assert results[key] != {} - - result_key_prefix = "mariadb_example_test_dataset:" - customer_key = result_key_prefix + "customer" - assert results[customer_key][0]["email"] == customer_email - - visit_key = result_key_prefix + "visit" - assert results[visit_key][0]["email"] == customer_email - # Both pre-execution webhooks and both post-execution webhooks were called - assert trigger_webhook_mock.call_count == 4 - pr.delete(db=db) - - -@pytest.mark.integration_saas -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_saas_mailchimp( - trigger_webhook_mock, - mailchimp_connection_config, - mailchimp_dataset_config, - db, - cache, - policy, - policy_pre_execution_webhooks, - policy_post_execution_webhooks, - dsr_version, - request, - mailchimp_identity_email, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = mailchimp_identity_email - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - results = pr.get_raw_access_results() - assert len(results.keys()) == 3 - - for key in results.keys(): - assert results[key] is not None - assert results[key] != {} - - result_key_prefix = f"mailchimp_instance:" - member_key = result_key_prefix + "member" - assert results[member_key][0]["email_address"] == customer_email - - # Both pre-execution webhooks and both post-execution webhooks were called - assert trigger_webhook_mock.call_count == 4 - - pr.delete(db=db) - - -@pytest.mark.integration_saas -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_saas( - _, - mailchimp_connection_config, - mailchimp_dataset_config, - db, - cache, - erasure_policy_hmac, - generate_auth_header, - dsr_version, - request, - 
mailchimp_identity_email, - reset_mailchimp_data, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = mailchimp_identity_email - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy_hmac.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - erasure_policy_hmac, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - - connector = SaaSConnector(mailchimp_connection_config) - connector.set_saas_request_state( - SaaSRequest(path="test_path", method=HTTPMethod.GET) - ) # dummy request as connector requires it - request: SaaSRequestParams = SaaSRequestParams( - method=HTTPMethod.GET, - path="/3.0/search-members", - query_params={"query": mailchimp_identity_email}, - ) - resp = connector.create_client().send(request) - body = resp.json() - merge_fields = body["exact_matches"]["members"][0]["merge_fields"] - - masking_configuration = HmacMaskingConfiguration() - masking_strategy = HmacMaskingStrategy(masking_configuration) - - assert ( - merge_fields["FNAME"] - == masking_strategy.mask( - [reset_mailchimp_data["merge_fields"]["FNAME"]], pr.id - )[0] - ) - assert ( - merge_fields["LNAME"] - == masking_strategy.mask( - [reset_mailchimp_data["merge_fields"]["LNAME"]], pr.id - )[0] - ) - - pr.delete(db=db) - - -@pytest.mark.integration_saas -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_saas_hubspot( - trigger_webhook_mock, - connection_config_hubspot, - dataset_config_hubspot, - db, - cache, - policy, - policy_pre_execution_webhooks, - policy_post_execution_webhooks, - dsr_version, - request, - hubspot_identity_email, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = hubspot_identity_email - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - results = pr.get_raw_access_results() - assert len(results.keys()) == 4 - - for key in results.keys(): - assert results[key] is not None - assert results[key] != {} - - result_key_prefix = f"hubspot_instance:" - contacts_key = result_key_prefix + "contacts" - assert results[contacts_key][0]["properties"]["email"] == customer_email - - # Both pre-execution webhooks and both post-execution webhooks were called - assert trigger_webhook_mock.call_count == 4 - - pr.delete(db=db) - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_specific_category_postgres( - postgres_integration_db, - postgres_example_test_dataset_config, - cache, - db, - generate_auth_header, - erasure_policy, - dsr_version, - request, - read_connection_config, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - customer_id = 1 - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy.key, - "identity": {"email": customer_email}, - } - - stmt = 
select("*").select_from(table("customer")) - res = postgres_integration_db.execute(stmt).all() - - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - ) - pr.delete(db=db) - - stmt = select( - column("id"), - column("name"), - ).select_from(table("customer")) - res = postgres_integration_db.execute(stmt).all() - - customer_found = False - for row in res: - if customer_id == row.id: - customer_found = True - # Check that the `name` field is `None` - assert row.name is None - assert customer_found - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_with_masking_strategy_override( - postgres_integration_db, - postgres_example_test_dataset_config, - cache, - db, - generate_auth_header, - erasure_policy, - dsr_version, - request, - read_connection_config, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - employee_email = "employee-1@example.com" - employee_id = 1 - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy.key, - "identity": {"email": employee_email}, - } - - stmt = select("*").select_from(table("employee")) - res = postgres_integration_db.execute(stmt).all() - - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - ) - pr.delete(db=db) - - stmt = select( - column("id"), - column("name"), - ).select_from(table("employee")) - res = postgres_integration_db.execute(stmt).all() - - customer_found = False - for row in res: - if employee_id == row.id: - customer_found = True - # Check that the `name` field was masked with the override provided in the dataset - assert row.name == "testing-test" - assert customer_found - - -@pytest.mark.integration_mssql -@pytest.mark.integration -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_specific_category_mssql( - mssql_integration_db, - mssql_example_test_dataset_config, - cache, - db, - dsr_version, - request, - generate_auth_header, - erasure_policy, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - customer_id = 1 - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - ) - pr.delete(db=db) - - stmt = select( - column("id"), - column("name"), - ).select_from(table("customer")) - res = mssql_integration_db.execute(stmt).all() - - customer_found = False - for row in res: - if customer_id == row.id: - customer_found = True - # Check that the `name` field is `None` - assert row.name is None - assert customer_found - - -@pytest.mark.integration_mysql -@pytest.mark.integration -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_specific_category_mysql( - mysql_integration_db, - mysql_example_test_dataset_config, - cache, - db, - dsr_version, - request, - generate_auth_header, - erasure_policy, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - customer_id = 1 - data = { - "requested_at": 
"2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - ) - pr.delete(db=db) - - stmt = select( - column("id"), - column("name"), - ).select_from(table("customer")) - res = mysql_integration_db.execute(stmt).all() - - customer_found = False - for row in res: - if customer_id == row.id: - customer_found = True - # Check that the `name` field is `None` - assert row.name is None - assert customer_found - - -@pytest.mark.integration_external -@pytest.mark.integration_google_cloud_sql_mysql -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_google_cloud_sql_mysql( - google_cloud_sql_mysql_integration_db, - google_cloud_sql_mysql_example_test_dataset_config, - cache, - db, - dsr_version, - request, - generate_auth_header, - erasure_policy, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - customer_id = 1 - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - pr.delete(db=db) - - stmt = select( - column("id"), - column("name"), - ).select_from(table("customer")) - res = google_cloud_sql_mysql_integration_db.execute(stmt).all() - - customer_found = False - for row in res: - if customer_id == row.id: - customer_found = True - # Check that the `name` field is `None` - assert row.name is None - assert customer_found - - -@pytest.mark.integration_mariadb -@pytest.mark.integration -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_specific_category_mariadb( - mariadb_example_test_dataset_config, - mariadb_integration_db, - cache, - db, - dsr_version, - request, - generate_auth_header, - erasure_policy, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-1@example.com" - customer_id = 1 - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - ) - pr.delete(db=db) - - stmt = select( - column("id"), - column("name"), - ).select_from(table("customer")) - res = mariadb_integration_db.execute(stmt).all() - - customer_found = False - for row in res: - if customer_id == row.id: - customer_found = True - # Check that the `name` field is `None` - assert row.name is None - assert customer_found - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_generic_category( - postgres_integration_db, - postgres_example_test_dataset_config, - cache, - db, - dsr_version, - request, - generate_auth_header, - erasure_policy, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - # It's safe to change this here since the `erasure_policy` fixture is scoped - # at function level - target = erasure_policy.rules[0].targets[0] - 
target.data_category = DataCategory("user.contact").value - target.save(db=db) - - email = "customer-2@example.com" - customer_id = 2 - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy.key, - "identity": {"email": email}, - } - - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - ) - pr.delete(db=db) - - stmt = select( - column("id"), - column("email"), - column("name"), - ).select_from(table("customer")) - res = postgres_integration_db.execute(stmt).all() - - customer_found = False - for row in res: - if customer_id == row.id: - customer_found = True - # Check that the `email` field is `None` and that its data category - # ("user.contact.email") has been erased by the parent - # category ("user.contact") - assert row.email is None - assert row.name is not None - else: - # There are two rows other rows, and they should not have been erased - assert row.email in ["customer-1@example.com", "jane@example.com"] - assert customer_found - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_aes_generic_category( - postgres_integration_db, - postgres_example_test_dataset_config, - cache, - db, - dsr_version, - request, - generate_auth_header, - erasure_policy_aes, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - # It's safe to change this here since the `erasure_policy` fixture is scoped - # at function level - target = erasure_policy_aes.rules[0].targets[0] - target.data_category = DataCategory("user.contact").value - target.save(db=db) - - email = "customer-2@example.com" - customer_id = 2 - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy_aes.key, - "identity": {"email": email}, - } - - pr = get_privacy_request_results( - db, - erasure_policy_aes, - run_privacy_request_task, - data, - ) - pr.delete(db=db) - - stmt = select( - column("id"), - column("email"), - column("name"), - ).select_from(table("customer")) - res = postgres_integration_db.execute(stmt).all() - - customer_found = False - for row in res: - if customer_id == row.id: - customer_found = True - # Check that the `email` field is not original val and that its data category - # ("user.contact.email") has been erased by the parent - # category ("user.contact"). 
- # masked val for `email` field will change per new privacy request, so the best - # we can do here is test that the original val has been changed - assert row[1] != "customer-2@example.com" - assert row[2] is not None - else: - # There are two rows other rows, and they should not have been erased - assert row[1] in ["customer-1@example.com", "jane@example.com"] - assert customer_found - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_with_table_joins( - postgres_integration_db, - postgres_example_test_dataset_config, - db, - cache, - dsr_version, - request, - erasure_policy, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - # It's safe to change this here since the `erasure_policy` fixture is scoped - # at function level - target = erasure_policy.rules[0].targets[0] - target.data_category = DataCategory("user.financial").value - target.save(db=db) - - customer_email = "customer-1@example.com" - customer_id = 1 - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - ) - pr.delete(db=db) - - stmt = select( - column("customer_id"), - column("id"), - column("ccn"), - column("code"), - column("name"), - ).select_from(table("payment_card")) - res = postgres_integration_db.execute(stmt).all() - - card_found = False - for row in res: - if row.customer_id == customer_id: - card_found = True - assert row.ccn is None - assert row.code is None - assert row.name is None - - assert card_found is True - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_read_access( - postgres_integration_db, - postgres_example_test_dataset_config_read_access, - db, - cache, - erasure_policy, - dsr_version, - request, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = "customer-2@example.com" - customer_id = 2 - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - ) - errored_execution_logs = pr.execution_logs.filter_by(status="error") - assert errored_execution_logs.count() == 9 - assert ( - errored_execution_logs[0].message - == "No values were erased since this connection " - "my_postgres_db_1_read_config has not been given write access" - ) - pr.delete(db=db) - - stmt = select( - column("id"), - column("name"), - ).select_from(table("customer")) - res = postgres_integration_db.execute(stmt).all() - - customer_found = False - for row in res: - if customer_id == row.id: - customer_found = True - # Check that the `name` field is NOT `None`. 
We couldn't erase, because the ConnectionConfig only had - # "read" access - assert row.name is not None - assert customer_found - - -@pytest.fixture(scope="function") -def redshift_resources( - redshift_example_test_dataset_config, -): - redshift_connection_config = redshift_example_test_dataset_config.connection_config - connector = RedshiftConnector(redshift_connection_config) - redshift_client = connector.client() - with redshift_client.connect() as connection: - connector.set_schema(connection) - uuid = str(uuid4()) - customer_email = f"customer-{uuid}@example.com" - customer_name = f"{uuid}" - - stmt = "select max(id) from customer;" - res = connection.execute(stmt) - customer_id = res.all()[0][0] + 1 - - stmt = "select max(id) from address;" - res = connection.execute(stmt) - address_id = res.all()[0][0] + 1 - - city = "Test City" - state = "TX" - stmt = f""" - insert into address (id, house, street, city, state, zip) - values ({address_id}, '{111}', 'Test Street', '{city}', '{state}', '55555'); - """ - connection.execute(stmt) - - stmt = f""" - insert into customer (id, email, name, address_id) - values ({customer_id}, '{customer_email}', '{customer_name}', '{address_id}'); - """ - connection.execute(stmt) - - yield { - "email": customer_email, - "name": customer_name, - "id": customer_id, - "client": redshift_client, - "address_id": address_id, - "city": city, - "state": state, - "connector": connector, - } - # Remove test data and close Redshift connection in teardown - stmt = f"delete from customer where email = '{customer_email}';" - connection.execute(stmt) - - stmt = f'delete from address where "id" = {address_id};' - connection.execute(stmt) - - -@pytest.mark.integration_external -@pytest.mark.integration_redshift -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_redshift( - redshift_resources, - db, - cache, - policy, - run_privacy_request_task, - dsr_version, - request, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = redshift_resources["email"] - customer_name = redshift_resources["name"] - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - results = pr.get_raw_access_results() - customer_table_key = f"redshift_example_test_dataset:customer" - assert len(results[customer_table_key]) == 1 - assert results[customer_table_key][0]["email"] == customer_email - assert results[customer_table_key][0]["name"] == customer_name - - address_table_key = f"redshift_example_test_dataset:address" - - city = redshift_resources["city"] - state = redshift_resources["state"] - assert len(results[address_table_key]) == 1 - assert results[address_table_key][0]["city"] == city - assert results[address_table_key][0]["state"] == state - - pr.delete(db=db) - - -@pytest.mark.integration_external -@pytest.mark.integration_redshift -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_redshift( - redshift_example_test_dataset_config, - redshift_resources, - integration_config: Dict[str, str], - db, - cache, - erasure_policy, - dsr_version, - request, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = 
redshift_resources["email"] - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy.key, - "identity": {"email": customer_email}, - } - - # Should erase customer name - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - pr.delete(db=db) - - connector = redshift_resources["connector"] - redshift_client = redshift_resources["client"] - with redshift_client.connect() as connection: - connector.set_schema(connection) - stmt = f"select name from customer where email = '{customer_email}';" - res = connection.execute(stmt).all() - for row in res: - assert row.name is None - - address_id = redshift_resources["address_id"] - stmt = f"select 'id', city, state from address where id = {address_id};" - res = connection.execute(stmt).all() - for row in res: - # Not yet masked because these fields aren't targeted by erasure policy - assert row.city == redshift_resources["city"] - assert row.state == redshift_resources["state"] - - target = erasure_policy.rules[0].targets[0] - target.data_category = "user.contact.address.state" - target.save(db=db) - - # Should erase state fields on address table - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - pr.delete(db=db) - - connector = redshift_resources["connector"] - redshift_client = redshift_resources["client"] - with redshift_client.connect() as connection: - connector.set_schema(connection) - - address_id = redshift_resources["address_id"] - stmt = f"select 'id', city, state from address where id = {address_id};" - res = connection.execute(stmt).all() - for row in res: - # State field was targeted by erasure policy but city was not - assert row.city is not None - assert row.state is None - - -@pytest.mark.integration_external -@pytest.mark.integration_bigquery -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_2_0", "use_dsr_3_0"], -) -@pytest.mark.parametrize( - "bigquery_fixtures", - ["bigquery_resources", "bigquery_resources_with_namespace_meta"], -) -def test_create_and_process_access_request_bigquery( - db, - policy, - dsr_version, - request, - bigquery_fixtures, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - bigquery_resources = request.getfixturevalue(bigquery_fixtures) - - customer_email = bigquery_resources["email"] - customer_name = bigquery_resources["name"] - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - bigquery_client = bigquery_resources["client"] - with bigquery_client.connect() as connection: - stmt = f"select * from fidesopstest.employee where address_id = {bigquery_resources['address_id']};" - res = connection.execute(stmt).all() - for row in res: - assert row.address_id == bigquery_resources["address_id"] - assert row.id == bigquery_resources["employee_id"] - assert row.email == bigquery_resources["employee_email"] - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - results = pr.get_raw_access_results() - customer_table_key = "bigquery_example_test_dataset:customer" - assert len(results[customer_table_key]) == 1 - assert results[customer_table_key][0]["email"] == customer_email - assert results[customer_table_key][0]["name"] == customer_name - - 
address_table_key = "bigquery_example_test_dataset:address" - - city = bigquery_resources["city"] - state = bigquery_resources["state"] - assert len(results[address_table_key]) == 1 - assert results[address_table_key][0]["city"] == city - assert results[address_table_key][0]["state"] == state - - employee_table_key = "bigquery_example_test_dataset:employee" - assert len(results[employee_table_key]) == 1 - assert results["bigquery_example_test_dataset:employee"] != [] - assert ( - results[employee_table_key][0]["address_id"] == bigquery_resources["address_id"] - ) - assert ( - results[employee_table_key][0]["email"] == bigquery_resources["employee_email"] - ) - assert results[employee_table_key][0]["id"] == bigquery_resources["employee_id"] + run_privacy_request_task.delay( + privacy_request_id=privacy_request.id, + from_step=CurrentStep.erasure.value, + ).get(timeout=PRIVACY_REQUEST_TASK_TIMEOUT) - # this covers access requests against a partitioned table - visit_partitioned_table_key = "bigquery_example_test_dataset:visit_partitioned" - assert len(results[visit_partitioned_table_key]) == 1 - assert ( - results[visit_partitioned_table_key][0]["email"] == bigquery_resources["email"] - ) + db.refresh(privacy_request) + assert privacy_request.started_processing_at is not None + assert privacy_request.updated_at > updated_at - pr.delete(db=db) + # Starting privacy request in the middle of the graph means we don't run pre-webhooks again + assert run_webhooks.call_count == 1 + assert run_webhooks.call_args[1]["webhook_cls"] == PolicyPostWebhook + assert run_access.call_count == 0 # Access request skipped + assert run_erasure.call_count == 1 # Erasure request runs -@pytest.mark.integration_external -@pytest.mark.integration_bigquery -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_2_0", "use_dsr_3_0"], -) -@pytest.mark.parametrize( - "bigquery_fixtures", - ["bigquery_resources", "bigquery_resources_with_namespace_meta"], -) -def test_create_and_process_erasure_request_bigquery( + assert mock_email_dispatch.call_count == 1 + + +def get_privacy_request_results( db, - dsr_version, - request, - bigquery_fixtures, - biquery_erasure_policy, + policy, run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - bigquery_resources = request.getfixturevalue(bigquery_fixtures) - - bigquery_client = bigquery_resources["client"] - # Verifying that employee info exists in db - with bigquery_client.connect() as connection: - stmt = f"select * from fidesopstest.employee where address_id = {bigquery_resources['address_id']};" - res = connection.execute(stmt).all() - for row in res: - assert row.address_id == bigquery_resources["address_id"] - assert row.id == bigquery_resources["employee_id"] - assert row.email == bigquery_resources["employee_email"] - - customer_email = bigquery_resources["email"] - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": biquery_erasure_policy.key, - "identity": {"email": customer_email}, + privacy_request_data: Dict[str, Any], + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT, +) -> PrivacyRequest: + """Utility method to run a privacy request and return results after waiting for + the returned future.""" + kwargs = { + "requested_at": pydash.get(privacy_request_data, "requested_at"), + "policy_id": policy.id, + "status": "pending", } - - # Should erase customer name - pr = get_privacy_request_results( - db, - biquery_erasure_policy, - run_privacy_request_task, - data, - 
task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + optional_fields = ["started_processing_at", "finished_processing_at"] + for field in optional_fields: + try: + attr = getattr(privacy_request_data, field) + if attr is not None: + kwargs[field] = attr + except AttributeError: + pass + privacy_request = PrivacyRequest.create(db=db, data=kwargs) + privacy_request.cache_identity(privacy_request_data["identity"]) + privacy_request.cache_custom_privacy_request_fields( + privacy_request_data.get("custom_privacy_request_fields", None) ) - pr.delete(db=db) + if "encryption_key" in privacy_request_data: + privacy_request.cache_encryption(privacy_request_data["encryption_key"]) - bigquery_client = bigquery_resources["client"] - with bigquery_client.connect() as connection: - stmt = ( - f"select name from fidesopstest.customer where email = '{customer_email}';" - ) - res = connection.execute(stmt).all() - for row in res: - assert row.name is None - - address_id = bigquery_resources["address_id"] - stmt = f"select 'id', city, state from fidesopstest.address where id = {address_id};" - res = connection.execute(stmt).all() - for row in res: - # Not yet masked because these fields aren't targeted by erasure policy - assert row.city == bigquery_resources["city"] - assert row.state == bigquery_resources["state"] - - for target in biquery_erasure_policy.rules[0].targets: - if target.data_category == "user.name": - target.data_category = "user.contact.address.state" - target.save(db=db) - - # Should erase state fields on address table - pr = get_privacy_request_results( - db, - biquery_erasure_policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + erasure_rules: List[Rule] = policy.get_rules_for_action( + action_type=ActionType.erasure ) + unique_masking_strategies_by_name: Set[str] = set() + for rule in erasure_rules: + strategy_name: str = rule.masking_strategy["strategy"] + configuration: MaskingConfiguration = rule.masking_strategy["configuration"] + if strategy_name in unique_masking_strategies_by_name: + continue + unique_masking_strategies_by_name.add(strategy_name) + masking_strategy = MaskingStrategy.get_strategy(strategy_name, configuration) + if masking_strategy.secrets_required(): + masking_secrets: List[MaskingSecretCache] = ( + masking_strategy.generate_secrets_for_cache() + ) + for masking_secret in masking_secrets: + privacy_request.cache_masking_secret(masking_secret) - bigquery_client = bigquery_resources["client"] - with bigquery_client.connect() as connection: - address_id = bigquery_resources["address_id"] - stmt = f"select 'id', city, state, street from fidesopstest.address where id = {address_id};" - res = connection.execute(stmt).all() - for row in res: - # State field was targeted by erasure policy but city was not - assert row.city is not None - assert row.state is None - # Street field was targeted by erasure policy but overridden by field-level masking_strategy_override - assert row.street == "REDACTED" - - stmt = f"select * from fidesopstest.employee where address_id = {bigquery_resources['address_id']};" - res = connection.execute(stmt).all() - - # Employee records deleted entirely due to collection-level masking strategy override - assert res == [] + run_privacy_request_task.delay(privacy_request.id).get( + timeout=task_timeout, + ) - pr.delete(db=db) + return PrivacyRequest.get(db=db, object_id=privacy_request.id) class TestRunPrivacyRequestRunsWebhooks: @@ -3129,483 +1234,141 @@ def test_needs_batch_email_send_new_workflow( ) 
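# Usage sketch (illustrative, not a definitive implementation): the hunk above
# keeps `get_privacy_request_results` in test_request_runner_service.py as the
# shared helper that the relocated integration modules below import. The sketch
# assumes the `db`, `policy`, `request`, and `run_privacy_request_task` pytest
# fixtures used throughout the surrounding tests; the test function name and the
# identity email are hypothetical.

import pytest

from tests.ops.service.privacy_request.test_request_runner_service import (
    PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL,
    get_privacy_request_results,
)


@pytest.mark.parametrize("dsr_version", ["use_dsr_3_0", "use_dsr_2_0"])
def test_access_request_example(
    db, policy, run_privacy_request_task, dsr_version, request
):
    request.getfixturevalue(dsr_version)  # REQUIRED to test both DSR 3.0 and 2.0

    data = {
        "requested_at": "2021-08-30T16:09:37.359Z",
        "policy_key": policy.key,
        "identity": {"email": "customer-1@example.com"},  # hypothetical identity
    }

    # Creates the PrivacyRequest, caches the identity (plus any encryption key and
    # masking secrets), queues the Celery task, and blocks on the returned future
    # until `task_timeout` elapses.
    pr = get_privacy_request_results(
        db,
        policy,
        run_privacy_request_task,
        data,
        task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL,
    )

    # Raw access results are keyed "<dataset_key>:<collection_name>", as asserted
    # in the Redshift, BigQuery, and DynamoDB tests in this module.
    results = pr.get_raw_access_results()
    assert results is not None

    pr.delete(db=db)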
-@pytest.fixture(scope="function") -def dynamodb_resources( - dynamodb_example_test_dataset_config, -): - dynamodb_connection_config = dynamodb_example_test_dataset_config.connection_config - dynamodb_client = DynamoDBConnector(dynamodb_connection_config).client() - uuid = str(uuid4()) - customer_email = f"customer-{uuid}@example.com" - customer_name = f"{uuid}" - - ## document and remove remaining comments if we can't get the bigger test running - items = { - "customer_identifier": [ - { - "customer_id": {"S": customer_name}, - "email": {"S": customer_email}, - "name": {"S": customer_name}, - "created": {"S": datetime.now(timezone.utc).isoformat()}, - } - ], - "customer": [ - { - "id": {"S": customer_name}, - "name": {"S": customer_name}, - "email": {"S": customer_email}, - "address_id": {"L": [{"S": customer_name}, {"S": customer_name}]}, - "personal_info": {"M": {"gender": {"S": "male"}, "age": {"S": "99"}}}, - "created": {"S": datetime.now(timezone.utc).isoformat()}, - } - ], - "address": [ - { - "id": {"S": customer_name}, - "city": {"S": "city"}, - "house": {"S": "house"}, - "state": {"S": "state"}, - "street": {"S": "street"}, - "zip": {"S": "zip"}, - } - ], - "login": [ - { - "customer_id": {"S": customer_name}, - "login_date": {"S": "2023-01-01"}, - "name": {"S": customer_name}, - "email": {"S": customer_email}, - }, - { - "customer_id": {"S": customer_name}, - "login_date": {"S": "2023-01-02"}, - "name": {"S": customer_name}, - "email": {"S": customer_email}, - }, - ], - } - - for table_name, rows in items.items(): - for item in rows: - res = dynamodb_client.put_item( - TableName=table_name, - Item=item, - ) - assert res["ResponseMetadata"]["HTTPStatusCode"] == 200 - - yield { - "email": customer_email, - "formatted_email": customer_email, - "name": customer_name, - "customer_id": uuid, - "client": dynamodb_client, - } - # Remove test data and close Dynamodb connection in teardown - delete_items = { - "customer_identifier": [{"email": {"S": customer_email}}], - "customer": [{"id": {"S": customer_name}}], - "address": [{"id": {"S": customer_name}}], - "login": [ - { - "customer_id": {"S": customer_name}, - "login_date": {"S": "2023-01-01"}, - }, - { - "customer_id": {"S": customer_name}, - "login_date": {"S": "2023-01-02"}, - }, - ], - } - for table_name, rows in delete_items.items(): - for item in rows: - res = dynamodb_client.delete_item( - TableName=table_name, - Key=item, - ) - assert res["ResponseMetadata"]["HTTPStatusCode"] == 200 - - -@pytest.mark.integration_external -@pytest.mark.integration_dynamodb -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_empty_access_request_dynamodb( - db, - cache, - policy, - dsr_version, - request, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": "thiscustomerdoesnot@exist.com"}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - # Here the results should be empty as no data will be located for that identity - results = pr.get_raw_access_results() - pr.delete(db=db) - assert results == {} - - -@pytest.mark.integration_external -@pytest.mark.integration_dynamodb -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_dynamodb( - 
dynamodb_resources, - db, - cache, - policy, - run_privacy_request_task, - dsr_version, - request, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = dynamodb_resources["email"] - customer_name = dynamodb_resources["name"] - customer_id = dynamodb_resources["customer_id"] - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - results = pr.get_raw_access_results() - customer_table_key = f"dynamodb_example_test_dataset:customer" - address_table_key = f"dynamodb_example_test_dataset:address" - login_table_key = f"dynamodb_example_test_dataset:login" - assert len(results[customer_table_key]) == 1 - assert len(results[address_table_key]) == 1 - assert len(results[login_table_key]) == 2 - assert results[customer_table_key][0]["email"] == customer_email - assert results[customer_table_key][0]["name"] == customer_name - assert results[customer_table_key][0]["id"] == customer_id - assert results[address_table_key][0]["id"] == customer_id - assert results[login_table_key][0]["name"] == customer_name - - pr.delete(db=db) - - -@pytest.mark.integration_external -@pytest.mark.integration_dynamodb -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_dynamodb( - dynamodb_example_test_dataset_config, - dynamodb_resources, - integration_config: Dict[str, str], - db, - cache, - erasure_policy, - dsr_version, - request, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - customer_email = dynamodb_resources["email"] - dynamodb_client = dynamodb_resources["client"] - customer_id = dynamodb_resources["customer_id"] - customer_name = dynamodb_resources["name"] - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy.key, - "identity": {"email": customer_email}, - } - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - pr.delete(db=db) - deserializer = TypeDeserializer() - customer = dynamodb_client.get_item( - TableName="customer", - Key={"id": {"S": customer_id}}, - ) - customer_identifier = dynamodb_client.get_item( - TableName="customer_identifier", - Key={"email": {"S": customer_email}}, - ) - login = dynamodb_client.get_item( - TableName="login", - Key={ - "customer_id": {"S": customer_name}, - "login_date": {"S": "2023-01-01"}, - }, +class TestAsyncCallbacks: + @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], ) - assert deserializer.deserialize(customer["Item"]["name"]) == None - assert deserializer.deserialize(customer_identifier["Item"]["name"]) == None - assert deserializer.deserialize(login["Item"]["name"]) == None - - -@mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_async_callback_access_request( - mock_send, - api_client, - saas_example_async_dataset_config, - saas_async_example_connection_config: Dict[str, str], - db, - policy, - dsr_version, - request, - run_privacy_request_task, -): - """Demonstrate end-to-end support for 
tasks expecting async callbacks for DSR 3.0""" - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - mock_send().json.return_value = {"id": "123"} - - pr = get_privacy_request_results( + def test_async_callback_access_request( + self, + mock_send, + api_client, + saas_example_async_dataset_config, + saas_async_example_connection_config: Dict[str, str], db, policy, + dsr_version, + request, run_privacy_request_task, - {"identity": {"email": "customer-1@example.com"}}, - task_timeout=120, - ) - db.refresh(pr) - - if dsr_version == "use_dsr_3_0": - assert pr.status == PrivacyRequestStatus.in_processing - - request_tasks = pr.access_tasks - assert request_tasks[0].status == ExecutionLogStatus.complete - - # SaaS Request was marked as needing async results, so the Request - # Task was put in a paused state - assert request_tasks[1].status == ExecutionLogStatus.awaiting_processing - assert request_tasks[1].collection_address == "saas_async_config:user" - - # Terminator task is downstream so it is still in a pending state - assert request_tasks[2].status == ExecutionLogStatus.pending - - jwe_token = mock_send.call_args[0][0].headers["reply-to-token"] - auth_header = {"Authorization": "Bearer " + jwe_token} - # Post to callback URL to supply access results async - # This requeues task and proceeds downstream - api_client.post( - V1_URL_PREFIX + REQUEST_TASK_CALLBACK, - headers=auth_header, - json={"access_results": [{"id": 1, "user_id": "abcde", "state": "VA"}]}, - ) - db.refresh(pr) - assert pr.status == PrivacyRequestStatus.complete - assert pr.get_raw_access_results() == { - "saas_async_config:user": [{"id": 1, "user_id": "abcde", "state": "VA"}] - } - # User data supplied async was filtered before being returned to the end user - assert pr.get_filtered_final_upload() == { - "access_request_rule": { - "saas_async_config:user": [{"state": "VA", "id": 1}] - } - } - else: - # Async Access Requests not supported for DSR 2.0 - the given - # node cannot be paused - assert pr.status == PrivacyRequestStatus.complete - - -@mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_async_callback_erasure_request( - mock_send, - saas_example_async_dataset_config, - saas_async_example_connection_config: Dict[str, str], - db, - api_client, - erasure_policy, - dsr_version, - request, - run_privacy_request_task, -): - """Demonstrate end-to-end support for erasure tasks expecting async callbacks for DSR 3.0""" - mock_send().json.return_value = {"id": "123"} - - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - {"identity": {"email": "customer-1@example.com"}}, - task_timeout=120, - ) + ): + """Demonstrate end-to-end support for tasks expecting async callbacks for DSR 3.0""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + mock_send().json.return_value = {"id": "123"} - if dsr_version == "use_dsr_3_0": - # Access async task fired first - assert pr.access_tasks[1].status == ExecutionLogStatus.awaiting_processing - jwe_token = mock_send.call_args[0][0].headers["reply-to-token"] - auth_header = {"Authorization": "Bearer " + jwe_token} - # Post to callback URL to supply access results async - # This requeues task and proceeds downstream - response = api_client.post( - V1_URL_PREFIX + REQUEST_TASK_CALLBACK, - 
headers=auth_header, - json={"access_results": [{"id": 1, "user_id": "abcde", "state": "VA"}]}, - ) - assert response.status_code == 200 - - # Erasure task is also expected async results and is now paused - assert pr.erasure_tasks[1].status == ExecutionLogStatus.awaiting_processing - jwe_token = mock_send.call_args[0][0].headers["reply-to-token"] - auth_header = {"Authorization": "Bearer " + jwe_token} - # Post to callback URL to supply erasure results async - # This requeues task and proceeds downstream to complete privacy request - response = api_client.post( - V1_URL_PREFIX + REQUEST_TASK_CALLBACK, - headers=auth_header, - json={"rows_masked": 2}, + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + {"identity": {"email": "customer-1@example.com"}}, + task_timeout=120, ) - assert response.status_code == 200 - db.refresh(pr) - assert pr.status == PrivacyRequestStatus.complete - assert pr.erasure_tasks[1].rows_masked == 2 - assert pr.erasure_tasks[1].status == ExecutionLogStatus.complete + if dsr_version == "use_dsr_3_0": + assert pr.status == PrivacyRequestStatus.in_processing - else: - # Async Erasure Requests not supported for DSR 2.0 - the given - # node cannot be paused - db.refresh(pr) - assert pr.status == PrivacyRequestStatus.complete + request_tasks = pr.access_tasks + assert request_tasks[0].status == ExecutionLogStatus.complete + # SaaS Request was marked as needing async results, so the Request + # Task was put in a paused state + assert request_tasks[1].status == ExecutionLogStatus.awaiting_processing + assert request_tasks[1].collection_address == "saas_async_config:user" -@pytest.mark.integration_external -@pytest.mark.integration_google_cloud_sql_postgres -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_access_request_google_cloud_sql_postgres( - trigger_webhook_mock, - google_cloud_sql_postgres_example_test_dataset_config, - google_cloud_sql_postgres_integration_db, - db: Session, - cache, - policy, - dsr_version, - request, - policy_pre_execution_webhooks, - policy_post_execution_webhooks, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + # Terminator task is downstream so it is still in a pending state + assert request_tasks[2].status == ExecutionLogStatus.pending - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } + jwe_token = mock_send.call_args[0][0].headers["reply-to-token"] + auth_header = {"Authorization": "Bearer " + jwe_token} + # Post to callback URL to supply access results async + # This requeues task and proceeds downstream + api_client.post( + V1_URL_PREFIX + REQUEST_TASK_CALLBACK, + headers=auth_header, + json={"access_results": [{"id": 1, "user_id": "abcde", "state": "VA"}]}, + ) + db.refresh(pr) + assert pr.status == PrivacyRequestStatus.complete + assert pr.get_raw_access_results() == { + "saas_async_config:user": [{"id": 1, "user_id": "abcde", "state": "VA"}] + } + # User data supplied async was filtered before being returned to the end user + assert pr.get_filtered_final_upload() == { + "access_request_rule": { + "saas_async_config:user": [{"state": "VA", "id": 1}] + } + } + else: + # Async Access Requests not supported for DSR 2.0 - the given + # node cannot be paused + assert pr.status == 
PrivacyRequestStatus.complete - pr = get_privacy_request_results( + @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) + def test_async_callback_erasure_request( + self, + mock_send, + saas_example_async_dataset_config, + saas_async_example_connection_config: Dict[str, str], db, - policy, + api_client, + erasure_policy, + dsr_version, + request, run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - - results = pr.get_raw_access_results() - assert len(results.keys()) == 11 - - for key in results.keys(): - assert results[key] is not None - assert results[key] != {} + ): + """Demonstrate end-to-end support for erasure tasks expecting async callbacks for DSR 3.0""" + mock_send().json.return_value = {"id": "123"} - result_key_prefix = "google_cloud_sql_postgres_example_test_dataset:" - customer_key = result_key_prefix + "customer" - assert results[customer_key][0]["email"] == customer_email + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - visit_key = result_key_prefix + "visit" - assert results[visit_key][0]["email"] == customer_email - # Both pre-execution webhooks and both post-execution webhooks were called - assert trigger_webhook_mock.call_count == 4 - pr.delete(db=db) + pr = get_privacy_request_results( + db, + erasure_policy, + run_privacy_request_task, + {"identity": {"email": "customer-1@example.com"}}, + task_timeout=120, + ) + if dsr_version == "use_dsr_3_0": + # Access async task fired first + assert pr.access_tasks[1].status == ExecutionLogStatus.awaiting_processing + jwe_token = mock_send.call_args[0][0].headers["reply-to-token"] + auth_header = {"Authorization": "Bearer " + jwe_token} + # Post to callback URL to supply access results async + # This requeues task and proceeds downstream + response = api_client.post( + V1_URL_PREFIX + REQUEST_TASK_CALLBACK, + headers=auth_header, + json={"access_results": [{"id": 1, "user_id": "abcde", "state": "VA"}]}, + ) + assert response.status_code == 200 + + # Erasure task is also expected async results and is now paused + assert pr.erasure_tasks[1].status == ExecutionLogStatus.awaiting_processing + jwe_token = mock_send.call_args[0][0].headers["reply-to-token"] + auth_header = {"Authorization": "Bearer " + jwe_token} + # Post to callback URL to supply erasure results async + # This requeues task and proceeds downstream to complete privacy request + response = api_client.post( + V1_URL_PREFIX + REQUEST_TASK_CALLBACK, + headers=auth_header, + json={"rows_masked": 2}, + ) + assert response.status_code == 200 -@pytest.mark.integration_external -@pytest.mark.integration_google_cloud_sql_postgres -@pytest.mark.parametrize( - "dsr_version", - ["use_dsr_3_0", "use_dsr_2_0"], -) -def test_create_and_process_erasure_request_google_cloud_sql_postgres( - google_cloud_sql_postgres_integration_db, - google_cloud_sql_postgres_example_test_dataset_config, - cache, - db, - dsr_version, - request, - generate_auth_header, - erasure_policy, - run_privacy_request_task, -): - request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + db.refresh(pr) + assert pr.status == PrivacyRequestStatus.complete - customer_email = "customer-1@example.com" - customer_id = 1 - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": erasure_policy.key, - "identity": {"email": customer_email}, - } + assert pr.erasure_tasks[1].rows_masked == 2 + assert 
pr.erasure_tasks[1].status == ExecutionLogStatus.complete - pr = get_privacy_request_results( - db, - erasure_policy, - run_privacy_request_task, - data, - task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, - ) - pr.delete(db=db) - - stmt = select( - column("id"), - column("name"), - ).select_from(table("customer")) - - res = google_cloud_sql_postgres_integration_db.execute(stmt).all() - - customer_found = False - for row in res: - if customer_id == row.id: - customer_found = True - # Check that the `name` field is `None` - assert row.name is None - assert customer_found + else: + # Async Erasure Requests not supported for DSR 2.0 - the given + # node cannot be paused + db.refresh(pr) + assert pr.status == PrivacyRequestStatus.complete diff --git a/tests/ops/service/privacy_request/test_saas_privacy_requests.py b/tests/ops/service/privacy_request/test_saas_privacy_requests.py new file mode 100644 index 0000000000..8816037e2e --- /dev/null +++ b/tests/ops/service/privacy_request/test_saas_privacy_requests.py @@ -0,0 +1,188 @@ +from unittest import mock + +import pytest + +from fides.api.schemas.masking.masking_configuration import HmacMaskingConfiguration +from fides.api.schemas.saas.saas_config import SaaSRequest +from fides.api.schemas.saas.shared_schemas import HTTPMethod, SaaSRequestParams +from fides.api.service.connectors.saas_connector import SaaSConnector +from fides.api.service.masking.strategy.masking_strategy_hmac import HmacMaskingStrategy +from tests.ops.service.privacy_request.test_request_runner_service import ( + PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + get_privacy_request_results, +) + + +@pytest.mark.integration_saas +@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_saas_mailchimp( + trigger_webhook_mock, + mailchimp_connection_config, + mailchimp_dataset_config, + db, + cache, + policy, + policy_pre_execution_webhooks, + policy_post_execution_webhooks, + dsr_version, + request, + mailchimp_identity_email, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = mailchimp_identity_email + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + results = pr.get_raw_access_results() + assert len(results.keys()) == 3 + + for key in results.keys(): + assert results[key] is not None + assert results[key] != {} + + result_key_prefix = f"mailchimp_instance:" + member_key = result_key_prefix + "member" + assert results[member_key][0]["email_address"] == customer_email + + # Both pre-execution webhooks and both post-execution webhooks were called + assert trigger_webhook_mock.call_count == 4 + + pr.delete(db=db) + + +@pytest.mark.integration_saas +@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_erasure_request_saas( + _, + mailchimp_connection_config, + mailchimp_dataset_config, + db, + cache, + erasure_policy_hmac, + generate_auth_header, + dsr_version, + request, + mailchimp_identity_email, + reset_mailchimp_data, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) 
# REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = mailchimp_identity_email + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": erasure_policy_hmac.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + erasure_policy_hmac, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + + connector = SaaSConnector(mailchimp_connection_config) + connector.set_saas_request_state( + SaaSRequest(path="test_path", method=HTTPMethod.GET) + ) # dummy request as connector requires it + request: SaaSRequestParams = SaaSRequestParams( + method=HTTPMethod.GET, + path="/3.0/search-members", + query_params={"query": mailchimp_identity_email}, + ) + resp = connector.create_client().send(request) + body = resp.json() + merge_fields = body["exact_matches"]["members"][0]["merge_fields"] + + masking_configuration = HmacMaskingConfiguration() + masking_strategy = HmacMaskingStrategy(masking_configuration) + + assert ( + merge_fields["FNAME"] + == masking_strategy.mask( + [reset_mailchimp_data["merge_fields"]["FNAME"]], pr.id + )[0] + ) + assert ( + merge_fields["LNAME"] + == masking_strategy.mask( + [reset_mailchimp_data["merge_fields"]["LNAME"]], pr.id + )[0] + ) + + pr.delete(db=db) + + +@pytest.mark.integration_saas +@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_saas_hubspot( + trigger_webhook_mock, + connection_config_hubspot, + dataset_config_hubspot, + db, + cache, + policy, + policy_pre_execution_webhooks, + policy_post_execution_webhooks, + dsr_version, + request, + hubspot_identity_email, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = hubspot_identity_email + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, + ) + results = pr.get_raw_access_results() + assert len(results.keys()) == 4 + + for key in results.keys(): + assert results[key] is not None + assert results[key] != {} + + result_key_prefix = f"hubspot_instance:" + contacts_key = result_key_prefix + "contacts" + assert results[contacts_key][0]["properties"]["email"] == customer_email + + # Both pre-execution webhooks and both post-execution webhooks were called + assert trigger_webhook_mock.call_count == 4 + + pr.delete(db=db) diff --git a/tests/ops/service/privacy_request/test_scylladb_privacy_requests.py b/tests/ops/service/privacy_request/test_scylladb_privacy_requests.py new file mode 100644 index 0000000000..cb47b5893a --- /dev/null +++ b/tests/ops/service/privacy_request/test_scylladb_privacy_requests.py @@ -0,0 +1,135 @@ +from unittest import mock + +import pytest + +from fides.api.models.privacy_request import ExecutionLogStatus +from tests.ops.service.privacy_request.test_request_runner_service import ( + get_privacy_request_results, +) + + +@pytest.mark.integration +@pytest.mark.integration_scylladb +@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +def test_create_and_process_access_request_scylladb( + trigger_webhook_mock, + 
scylladb_test_dataset_config, + scylla_reset_db, + db, + cache, + policy, + dsr_version, + request, + policy_pre_execution_webhooks, + policy_post_execution_webhooks, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + ) + + results = pr.get_raw_access_results() + assert len(results.keys()) == 4 + + assert "scylladb_example_test_dataset:users" in results + assert len(results["scylladb_example_test_dataset:users"]) == 1 + assert results["scylladb_example_test_dataset:users"][0]["email"] == customer_email + assert results["scylladb_example_test_dataset:users"][0]["age"] == 41 + assert results["scylladb_example_test_dataset:users"][0][ + "alternative_contacts" + ] == {"phone": "+1 (531) 988-5905", "work_email": "customer-1@example.com"} + + assert "scylladb_example_test_dataset:user_activity" in results + assert len(results["scylladb_example_test_dataset:user_activity"]) == 3 + + for activity in results["scylladb_example_test_dataset:user_activity"]: + assert activity["user_id"] + assert activity["timestamp"] + assert activity["activity_type"] + assert activity["user_agent"] + + assert "scylladb_example_test_dataset:payment_methods" in results + assert len(results["scylladb_example_test_dataset:payment_methods"]) == 2 + for payment_method in results["scylladb_example_test_dataset:payment_methods"]: + assert payment_method["payment_method_id"] + assert payment_method["card_number"] + assert payment_method["expiration_date"] + + assert "scylladb_example_test_dataset:orders" in results + assert len(results["scylladb_example_test_dataset:orders"]) == 2 + for payment_method in results["scylladb_example_test_dataset:orders"]: + assert payment_method["order_amount"] + assert payment_method["order_date"] + assert payment_method["order_description"] + + # Both pre-execution webhooks and both post-execution webhooks were called + assert trigger_webhook_mock.call_count == 4 + pr.delete(db=db) + + +@pytest.mark.integration +@pytest.mark.integration_scylladb +@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0"], +) +def test_create_and_process_access_request_scylladb_no_keyspace( + trigger_webhook_mock, + scylladb_test_dataset_config_no_keyspace, + scylla_reset_db, + db, + cache, + policy, + dsr_version, + request, + policy_pre_execution_webhooks, + policy_post_execution_webhooks, + run_privacy_request_task, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-1@example.com" + data = { + "requested_at": "2021-08-30T16:09:37.359Z", + "policy_key": policy.key, + "identity": {"email": customer_email}, + } + + pr = get_privacy_request_results( + db, + policy, + run_privacy_request_task, + data, + ) + + assert ( + pr.access_tasks.count() == 6 + ) # There's 4 tables plus the root and terminal "dummy" tasks + + # Root task should be completed + assert pr.access_tasks.first().collection_name == "__ROOT__" + assert pr.access_tasks.first().status == ExecutionLogStatus.complete + + # All other tasks should be error + for access_task in pr.access_tasks.offset(1): + assert access_task.status == ExecutionLogStatus.error + + results = 
pr.get_raw_access_results() + assert results == {} diff --git a/tests/ops/service/privacy_request/test_snowflake_privacy_requests.py b/tests/ops/service/privacy_request/test_snowflake_privacy_requests.py index ba8b514eb3..530eb3d5c8 100644 --- a/tests/ops/service/privacy_request/test_snowflake_privacy_requests.py +++ b/tests/ops/service/privacy_request/test_snowflake_privacy_requests.py @@ -2,7 +2,7 @@ import pytest -from fides.api.service.connectors.sql_connector import SnowflakeConnector +from fides.api.service.connectors.snowflake_connector import SnowflakeConnector from tests.ops.service.privacy_request.test_request_runner_service import ( PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, get_privacy_request_results,