diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 084205f4ed728..33e51833e59a6 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -34,8 +34,12 @@
 from airflow.models.xcom import XCom
 from airflow.providers.common.sql.operators.sql import (
     SQLCheckOperator,
+    SQLColumnCheckOperator,
     SQLIntervalCheckOperator,
+    SQLTableCheckOperator,
     SQLValueCheckOperator,
+    _get_failed_checks,
+    parse_boolean,
 )
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
 from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
@@ -520,6 +524,241 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
         )
 
 
+class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator):
+    """
+    BigQueryColumnCheckOperator subclasses the SQLColumnCheckOperator
+    in order to provide a job id for OpenLineage to parse. See base class
+    docstring for usage.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryColumnCheckOperator`
+
+    :param table: the table name
+    :param column_mapping: a dictionary relating columns to their checks
+    :param partition_clause: a string SQL statement added to a WHERE clause
+        to partition data
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :param use_legacy_sql: Whether to use legacy SQL (true)
+        or standard SQL (false).
+    :param location: The geographic location of the job. See details at:
+        https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param labels: a dictionary containing labels for the table, passed to BigQuery
+    """
+
+    def __init__(
+        self,
+        *,
+        table: str,
+        column_mapping: dict,
+        partition_clause: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        use_legacy_sql: bool = True,
+        location: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
+        labels: dict | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            table=table, column_mapping=column_mapping, partition_clause=partition_clause, **kwargs
+        )
+        self.table = table
+        self.column_mapping = column_mapping
+        self.partition_clause = partition_clause
+        self.gcp_conn_id = gcp_conn_id
+        self.use_legacy_sql = use_legacy_sql
+        self.location = location
+        self.impersonation_chain = impersonation_chain
+        self.labels = labels
+        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
+        self.sql = ""
+
+    def _submit_job(
+        self,
+        hook: BigQueryHook,
+        job_id: str,
+    ) -> BigQueryJob:
+        """Submit a new job and get the job id for polling the status using Trigger."""
+        configuration = {"query": {"query": self.sql}}
+
+        return hook.insert_job(
+            configuration=configuration,
+            project_id=hook.project_id,
+            location=self.location,
+            job_id=job_id,
+            nowait=False,
+        )
+
+    def execute(self, context=None):
+        """Perform checks on the given columns."""
+        hook = self.get_db_hook()
+        failed_tests = []
+        for column in self.column_mapping:
+            checks = [*self.column_mapping[column]]
+            checks_sql = ",".join([self.column_checks[check].replace("column", column) for check in checks])
+            partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
+            self.sql = f"SELECT {checks_sql} FROM {self.table} {partition_clause_statement};"
+
+            job_id = hook.generate_job_id(
+                dag_id=self.dag_id,
+                task_id=self.task_id,
+                logical_date=context["logical_date"],
+                # hash the same query configuration that _submit_job sends; the operator
+                # has no separate `configuration` attribute
+                configuration={"query": {"query": self.sql}},
+            )
+            job = self._submit_job(hook, job_id=job_id)
+            context["ti"].xcom_push(key="job_id", value=job.job_id)
+            records = list(job.result().to_dataframe().values.flatten())
+
+            if not records:
+                raise AirflowException(f"The following query returned zero rows: {self.sql}")
+
+            self.log.info("Record: %s", records)
+
+            for idx, result in enumerate(records):
+                tolerance = self.column_mapping[column][checks[idx]].get("tolerance")
+
+                self.column_mapping[column][checks[idx]]["result"] = result
+                self.column_mapping[column][checks[idx]]["success"] = self._get_match(
+                    self.column_mapping[column][checks[idx]], result, tolerance
+                )
+
+            failed_tests.extend(_get_failed_checks(self.column_mapping[column], column))
+        if failed_tests:
+            raise AirflowException(
+                f"Test failed.\nResults:\n{records!s}\n"
+                "The following tests have failed:"
+                f"\n{''.join(failed_tests)}"
+            )
+
+        self.log.info("All tests have passed")
+
+
+class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
+    """
+    BigQueryTableCheckOperator subclasses the SQLTableCheckOperator
+    in order to provide a job id for OpenLineage to parse. See base class
+    docstring for usage.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryTableCheckOperator`
+
+    :param table: the table name
+    :param checks: a dictionary of check names and boolean SQL statements
+    :param partition_clause: a string SQL statement added to a WHERE clause
+        to partition data
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :param use_legacy_sql: Whether to use legacy SQL (true)
+        or standard SQL (false).
+    :param location: The geographic location of the job. See details at:
+        https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param labels: a dictionary containing labels for the table, passed to BigQuery
+    """
+
+    def __init__(
+        self,
+        *,
+        table: str,
+        checks: dict,
+        partition_clause: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        use_legacy_sql: bool = True,
+        location: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
+        labels: dict | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(table=table, checks=checks, partition_clause=partition_clause, **kwargs)
+        self.table = table
+        self.checks = checks
+        self.partition_clause = partition_clause
+        self.gcp_conn_id = gcp_conn_id
+        self.use_legacy_sql = use_legacy_sql
+        self.location = location
+        self.impersonation_chain = impersonation_chain
+        self.labels = labels
+        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
+        self.sql = ""
+
+    def _submit_job(
+        self,
+        hook: BigQueryHook,
+        job_id: str,
+    ) -> BigQueryJob:
+        """Submit a new job and get the job id for polling the status using Trigger."""
+        configuration = {"query": {"query": self.sql}}
+
+        return hook.insert_job(
+            configuration=configuration,
+            project_id=hook.project_id,
+            location=self.location,
+            job_id=job_id,
+            nowait=False,
+        )
+
+    def execute(self, context=None):
+        """Execute the given checks on the table."""
+        hook = self.get_db_hook()
+        checks_sql = " UNION ALL ".join(
+            [
+                self.sql_check_template.replace("check_statement", value["check_statement"])
+                .replace("_check_name", check_name)
+                .replace("table", self.table)
+                for check_name, value in self.checks.items()
+            ]
+        )
+        partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
+        self.sql = (
+            f"SELECT check_name, check_result FROM ({checks_sql}) "
+            f"AS check_table {partition_clause_statement};"
+        )
+
+        job_id = hook.generate_job_id(
+            dag_id=self.dag_id,
+            task_id=self.task_id,
+            logical_date=context["logical_date"],
+            # hash the same query configuration that _submit_job sends; the operator
+            # has no separate `configuration` attribute
+            configuration={"query": {"query": self.sql}},
+        )
+        job = self._submit_job(hook, job_id=job_id)
+        context["ti"].xcom_push(key="job_id", value=job.job_id)
+        records = job.result().to_dataframe()
+
+        if records.empty:
+            raise AirflowException(f"The following query returned zero rows: {self.sql}")
+
+        records.columns = records.columns.str.lower()
+        self.log.info("Record:\n%s", records)
+
+        for row in records.iterrows():
+            check = row[1].get("check_name")
+            result = row[1].get("check_result")
+            self.checks[check]["success"] = parse_boolean(str(result))
+
+        failed_tests = _get_failed_checks(self.checks)
+        if failed_tests:
+            raise AirflowException(
+                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
+                "The following tests have failed:"
+                f"\n{', '.join(failed_tests)}"
+            )
+
+        self.log.info("All tests have passed")
+
+
 class BigQueryGetDataOperator(BaseOperator):
     """
     Fetches the data from a BigQuery table (alternatively fetch data for selected columns)
diff --git a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
index 548c37ace21f1..919c3dd898bfa 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
@@ -438,6 +438,34 @@ Also you can use deferrable mode in this operator
     :start-after: [START howto_operator_bigquery_interval_check_async]
     :end-before: [END howto_operator_bigquery_interval_check_async]
 
+.. _howto/operator:BigQueryColumnCheckOperator:
+
+Check columns with predefined tests
+"""""""""""""""""""""""""""""""""""
+
+To check that columns pass user-configurable tests you can use
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryColumnCheckOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_bigquery_column_check]
+    :end-before: [END howto_operator_bigquery_column_check]
+
+.. _howto/operator:BigQueryTableCheckOperator:
+
+Check table level data quality
+""""""""""""""""""""""""""""""
+
+To check that tables pass user-defined tests you can use
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryTableCheckOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_bigquery_table_check]
+    :end-before: [END howto_operator_bigquery_table_check]
+
 Sensors
 ^^^^^^^
 
diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
index eeb54545415c2..52637d37ce22d 100644
--- a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
+++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
@@ -27,12 +27,14 @@
 from airflow.operators.bash import BashOperator
 from airflow.providers.google.cloud.operators.bigquery import (
     BigQueryCheckOperator,
+    BigQueryColumnCheckOperator,
     BigQueryCreateEmptyDatasetOperator,
     BigQueryCreateEmptyTableOperator,
     BigQueryDeleteDatasetOperator,
     BigQueryGetDataOperator,
     BigQueryInsertJobOperator,
     BigQueryIntervalCheckOperator,
+    BigQueryTableCheckOperator,
     BigQueryValueCheckOperator,
 )
 from airflow.utils.trigger_rule import TriggerRule
@@ -209,6 +211,22 @@
     )
     # [END howto_operator_bigquery_interval_check]
 
+    # [START howto_operator_bigquery_column_check]
+    column_check = BigQueryColumnCheckOperator(
+        task_id="column_check",
+        table=f"{DATASET}.{TABLE_1}",
+        column_mapping={"value": {"null_check": {"equal_to": 0}}},
+    )
+    # [END howto_operator_bigquery_column_check]
+
+    # [START howto_operator_bigquery_table_check]
+    table_check = BigQueryTableCheckOperator(
+        task_id="table_check",
+        table=f"{DATASET}.{TABLE_1}",
+        checks={"row_count_check": {"check_statement": "COUNT(*) = 4"}},
+    )
+    # [END howto_operator_bigquery_table_check]
+
     delete_dataset = BigQueryDeleteDatasetOperator(
         task_id="delete_dataset",
         dataset_id=DATASET,