diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index 66984a802f862..9b3aa868dd8e8 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -19,10 +19,10 @@
 import ast
 import re
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Sequence, SupportsAbs
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, NoReturn, Sequence, SupportsAbs
 
 from airflow.compat.functools import cached_property
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowFailException
 from airflow.hooks.base import BaseHook
 from airflow.models import BaseOperator, SkipMixin
 from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
@@ -31,7 +31,14 @@
     from airflow.utils.context import Context
 
 
-def parse_boolean(val: str) -> str | bool:
+def _convert_to_float_if_possible(s: str) -> float | str:
+    try:
+        return float(s)
+    except (ValueError, TypeError):
+        return s
+
+
+def _parse_boolean(val: str) -> str | bool:
     """Try to parse a string into boolean.
 
     Raises ValueError if the input is not a valid true- or false-like string value.
@@ -44,20 +51,6 @@ def parse_boolean(val: str) -> str | bool:
     raise ValueError(f"{val!r} is not a boolean-like string value")
 
 
-def _get_failed_checks(checks, col=None):
-    if col:
-        return [
-            f"Column: {col}\nCheck: {check},\nCheck Values: {check_values}\n"
-            for check, check_values in checks.items()
-            if not check_values["success"]
-        ]
-    return [
-        f"\tCheck: {check},\n\tCheck Values: {check_values}\n"
-        for check, check_values in checks.items()
-        if not check_values["success"]
-    ]
-
-
 _PROVIDERS_MATCHER = re.compile(r"airflow\.providers\.(.*)\.hooks.*")
 
 _MIN_SUPPORTED_PROVIDERS_VERSION = {
@@ -103,12 +96,14 @@ def __init__(
         conn_id: str | None = None,
         database: str | None = None,
         hook_params: dict | None = None,
+        retry_on_failure: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.conn_id = conn_id
         self.database = database
         self.hook_params = {} if hook_params is None else hook_params
+        self.retry_on_failure = retry_on_failure
 
     @cached_property
     def _hook(self):
@@ -155,6 +150,11 @@ def get_db_hook(self) -> DbApiHook:
         """
         return self._hook
 
+    def _raise_exception(self, exception_string: str) -> NoReturn:
+        if self.retry_on_failure:
+            raise AirflowException(exception_string)
+        raise AirflowFailException(exception_string)
+
 
 class SQLExecuteQueryOperator(BaseSQLOperator):
     """
@@ -239,6 +239,7 @@ class SQLColumnCheckOperator(BaseSQLOperator):
         - geq_to: value that results should be greater than or equal to
         - leq_to: value that results should be less than or equal to
         - tolerance: the percentage that the result may be off from the expected value
+        - partition_clause: an extra clause passed into a WHERE statement to partition data
 
     :param table: the table to run checks on
     :param column_mapping: the dictionary of columns and their associated checks, e.g.
@@ -249,6 +250,7 @@ class SQLColumnCheckOperator(BaseSQLOperator):
             "col_name": {
                 "null_check": {
                     "equal_to": 0,
+                    "partition_clause": "foreign_key IS NOT NULL",
                 },
                 "min": {
                     "greater_than": 5,
@@ -268,6 +270,8 @@ class SQLColumnCheckOperator(BaseSQLOperator):
 
     :param conn_id: the connection ID used to connect to the database
     :param database: name of database which overwrite the defined one in connection
+    :param accept_none: whether or not to accept None values returned by the query. If true, converts None
+        to 0.
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
@@ -276,12 +280,17 @@ class SQLColumnCheckOperator(BaseSQLOperator):
 
     template_fields = ("partition_clause",)
 
+    sql_check_template = """
+        SELECT '{column}' AS col_name, '{check}' AS check_type, {column}_{check} AS check_result
+        FROM (SELECT {check_statement} AS {column}_{check} FROM {table} {partition_clause}) AS sq
+    """
+
     column_checks = {
-        "null_check": "SUM(CASE WHEN column IS NULL THEN 1 ELSE 0 END) AS column_null_check",
-        "distinct_check": "COUNT(DISTINCT(column)) AS column_distinct_check",
-        "unique_check": "COUNT(column) - COUNT(DISTINCT(column)) AS column_unique_check",
-        "min": "MIN(column) AS column_min",
-        "max": "MAX(column) AS column_max",
+        "null_check": "SUM(CASE WHEN {column} IS NULL THEN 1 ELSE 0 END)",
+        "distinct_check": "COUNT(DISTINCT({column}))",
+        "unique_check": "COUNT({column}) - COUNT(DISTINCT({column}))",
+        "min": "MIN({column})",
+        "max": "MAX({column})",
     }
 
     def __init__(
@@ -292,53 +301,84 @@ def __init__(
         partition_clause: str | None = None,
         conn_id: str | None = None,
         database: str | None = None,
+        accept_none: bool = True,
         **kwargs,
     ):
         super().__init__(conn_id=conn_id, database=database, **kwargs)
-        for checks in column_mapping.values():
-            for check, check_values in checks.items():
-                self._column_mapping_validation(check, check_values)
 
         self.table = table
         self.column_mapping = column_mapping
         self.partition_clause = partition_clause
-        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
-        self.sql = f"SELECT * FROM {self.table};"
+        self.accept_none = accept_none
+
+        def _build_checks_sql():
+            for column, checks in self.column_mapping.items():
+                for check, check_values in checks.items():
+                    self._column_mapping_validation(check, check_values)
+                yield self._generate_sql_query(column, checks)
+
+        checks_sql = "UNION ALL".join(_build_checks_sql())
+
+        self.sql = f"SELECT col_name, check_type, check_result FROM ({checks_sql}) AS check_columns"
 
     def execute(self, context: Context):
         hook = self.get_db_hook()
-        failed_tests = []
-        for column in self.column_mapping:
-            checks = [*self.column_mapping[column]]
-            checks_sql = ",".join([self.column_checks[check].replace("column", column) for check in checks])
-            partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
-            self.sql = f"SELECT {checks_sql} FROM {self.table} {partition_clause_statement};"
-            records = hook.get_first(self.sql)
+        records = hook.get_records(self.sql)
 
-            if not records:
-                raise AirflowException(f"The following query returned zero rows: {self.sql}")
+        if not records:
+            self._raise_exception(f"The following query returned zero rows: {self.sql}")
 
-            self.log.info("Record: %s", records)
+        self.log.info("Record: %s", records)
 
-            for idx, result in enumerate(records):
-                tolerance = self.column_mapping[column][checks[idx]].get("tolerance")
+        for column, check, result in records:
+            tolerance = self.column_mapping[column][check].get("tolerance")
 
-                self.column_mapping[column][checks[idx]]["result"] = result
-                self.column_mapping[column][checks[idx]]["success"] = self._get_match(
-                    self.column_mapping[column][checks[idx]], result, tolerance
-                )
+            self.column_mapping[column][check]["result"] = result
+            self.column_mapping[column][check]["success"] = self._get_match(
+                self.column_mapping[column][check], result, tolerance
+            )
 
-            failed_tests.extend(_get_failed_checks(self.column_mapping[column], column))
+        failed_tests = [
+            f"Column: {col}\n\tCheck: {check},\n\tCheck Values: {check_values}\n"
+            for col, checks in self.column_mapping.items()
+            for check, check_values in checks.items()
+            if not check_values["success"]
+        ]
         if failed_tests:
-            raise AirflowException(
+            exception_string = (
                 f"Test failed.\nResults:\n{records!s}\n"
-                "The following tests have failed:"
-                f"\n{''.join(failed_tests)}"
+                f"The following tests have failed:\n{''.join(failed_tests)}"
             )
+            self._raise_exception(exception_string)
 
         self.log.info("All tests have passed")
 
+    def _generate_sql_query(self, column, checks):
+        def _generate_partition_clause(check):
+            if self.partition_clause and "partition_clause" not in checks[check]:
+                return f"WHERE {self.partition_clause}"
+            elif not self.partition_clause and "partition_clause" in checks[check]:
+                return f"WHERE {checks[check]['partition_clause']}"
+            elif self.partition_clause and "partition_clause" in checks[check]:
+                return f"WHERE {self.partition_clause} AND {checks[check]['partition_clause']}"
+            else:
+                return ""
+
+        checks_sql = "UNION ALL".join(
+            self.sql_check_template.format(
+                check_statement=self.column_checks[check].format(column=column),
+                check=check,
+                table=self.table,
+                column=column,
+                partition_clause=_generate_partition_clause(check),
+            )
+            for check in checks
+        )
+        return checks_sql
+
     def _get_match(self, check_values, record, tolerance=None) -> bool:
+        if record is None and self.accept_none:
+            record = 0
         match_boolean = True
         if "geq_to" in check_values:
             if tolerance is not None:
@@ -437,13 +477,15 @@ class SQLTableCheckOperator(BaseSQLOperator):
     Checks should be written to return a boolean result.
 
     :param table: the table to run checks on
-    :param checks: the dictionary of checks, e.g.:
+    :param checks: the dictionary of checks, where check names are followed by a dictionary containing at
+        least a check statement, and optionally a partition clause, e.g.:
 
     .. code-block:: python
 
         {
             "row_count_check": {"check_statement": "COUNT(*) = 1000"},
             "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
+            "third_check": {"check_statement": "MIN(col) = 1", "partition_clause": "col IS NOT NULL"},
         }
 
@@ -465,8 +507,9 @@ class SQLTableCheckOperator(BaseSQLOperator):
     template_fields = ("partition_clause",)
 
     sql_check_template = """
-    SELECT '_check_name' AS check_name, MIN(_check_name) AS check_result
-    FROM (SELECT CASE WHEN check_statement THEN 1 ELSE 0 END AS _check_name FROM table) AS sq
+        SELECT '{check_name}' AS check_name, MIN({check_name}) AS check_result
+        FROM (SELECT CASE WHEN {check_statement} THEN 1 ELSE 0 END AS {check_name}
+        FROM {table} {partition_clause}) AS sq
     """
 
     def __init__(
@@ -484,46 +527,56 @@ def __init__(
         self.table = table
         self.checks = checks
         self.partition_clause = partition_clause
-        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
-        self.sql = f"SELECT * FROM {self.table};"
+        self.sql = f"SELECT check_name, check_result FROM ({self._generate_sql_query()}) AS check_table"
 
     def execute(self, context: Context):
         hook = self.get_db_hook()
-        checks_sql = " UNION ALL ".join(
-            [
-                self.sql_check_template.replace("check_statement", value["check_statement"])
-                .replace("_check_name", check_name)
-                .replace("table", self.table)
-                for check_name, value in self.checks.items()
-            ]
-        )
-        partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
-        self.sql = f"""
-        SELECT check_name, check_result FROM ({checks_sql})
-        AS check_table {partition_clause_statement}
-        """
-
         records = hook.get_records(self.sql)
 
         if not records:
-            raise AirflowException(f"The following query returned zero rows: {self.sql}")
+            self._raise_exception(f"The following query returned zero rows: {self.sql}")
 
         self.log.info("Record:\n%s", records)
 
         for row in records:
             check, result = row
-            self.checks[check]["success"] = parse_boolean(str(result))
+            self.checks[check]["success"] = _parse_boolean(str(result))
 
-        failed_tests = _get_failed_checks(self.checks)
+        failed_tests = [
+            f"\tCheck: {check},\n\tCheck Values: {check_values}\n"
+            for check, check_values in self.checks.items()
+            if not check_values["success"]
+        ]
         if failed_tests:
-            raise AirflowException(
+            exception_string = (
                 f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
-                "The following tests have failed:"
-                f"\n{', '.join(failed_tests)}"
+                f"The following tests have failed:\n{', '.join(failed_tests)}"
            )
+            self._raise_exception(exception_string)
 
         self.log.info("All tests have passed")
 
+    def _generate_sql_query(self):
+        def _generate_partition_clause(check_name):
+            if self.partition_clause and "partition_clause" not in self.checks[check_name]:
+                return f"WHERE {self.partition_clause}"
+            elif not self.partition_clause and "partition_clause" in self.checks[check_name]:
+                return f"WHERE {self.checks[check_name]['partition_clause']}"
+            elif self.partition_clause and "partition_clause" in self.checks[check_name]:
+                return f"WHERE {self.partition_clause} AND {self.checks[check_name]['partition_clause']}"
+            else:
+                return ""
+
+        return "UNION ALL".join(
+            self.sql_check_template.format(
+                check_statement=value["check_statement"],
+                check_name=check_name,
+                table=self.table,
+                partition_clause=_generate_partition_clause(check_name),
+            )
+            for check_name, value in self.checks.items()
+        )
+
 
 class SQLCheckOperator(BaseSQLOperator):
     """
@@ -578,9 +631,9 @@ def execute(self, context: Context):
         self.log.info("Record: %s", records)
 
         if not records:
raise AirflowException("The query returned None") + self._raise_exception(f"The following query returned zero rows: {self.sql}") elif not all(bool(r) for r in records): - raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}") + self._raise_exception(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}") self.log.info("Success.") @@ -628,7 +681,7 @@ def execute(self, context: Context): records = self.get_db_hook().get_first(self.sql) if not records: - raise AirflowException("The query returned None") + self._raise_exception(f"The following query returned zero rows: {self.sql}") pass_value_conv = _convert_to_float_if_possible(self.pass_value) is_numeric_value_check = isinstance(pass_value_conv, float) @@ -657,7 +710,7 @@ def execute(self, context: Context): tests = [] if not all(tests): - raise AirflowException(error_msg) + self._raise_exception(error_msg) def _to_float(self, records): return [float(record) for record in records] @@ -729,7 +782,7 @@ def __init__( if ratio_formula not in self.ratio_formulas: msg_template = "Invalid diff_method: {diff_method}. Supported diff methods are: {diff_methods}" - raise AirflowException( + raise AirflowFailException( msg_template.format(diff_method=ratio_formula, diff_methods=self.ratio_formulas) ) self.ratio_formula = ratio_formula @@ -754,9 +807,9 @@ def execute(self, context: Context): row1 = hook.get_first(self.sql1) if not row2: - raise AirflowException(f"The query {self.sql2} returned None") + self._raise_exception(f"The following query returned zero rows: {self.sql2}") if not row1: - raise AirflowException(f"The query {self.sql1} returned None") + self._raise_exception(f"The following query returned zero rows: {self.sql1}") current = dict(zip(self.metrics_sorted, row1)) reference = dict(zip(self.metrics_sorted, row2)) @@ -809,7 +862,7 @@ def execute(self, context: Context): ratios[k], self.metrics_thresholds[k], ) - raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}") + self._raise_exception(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}") self.log.info("All tests have passed") @@ -852,6 +905,8 @@ def __init__( def execute(self, context: Context): hook = self.get_db_hook() result = hook.get_first(self.sql)[0] + if not result: + self._raise_exception(f"The following query returned zero rows: {self.sql}") if isinstance(self.min_threshold, float): lower_bound = self.min_threshold @@ -886,7 +941,7 @@ def execute(self, context: Context): f"Result: {result} is not within thresholds " f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}' ) - raise AirflowException(error_msg) + self._raise_exception(error_msg) self.log.info("Test %s Successful.", self.task_id) @@ -969,7 +1024,7 @@ def execute(self, context: Context): follow_branch = self.follow_task_ids_if_true elif isinstance(query_result, str): # return result is not Boolean, try to convert from String to Boolean - if parse_boolean(query_result): + if _parse_boolean(query_result): follow_branch = self.follow_task_ids_if_true elif isinstance(query_result, int): if bool(query_result): @@ -987,17 +1042,3 @@ def execute(self, context: Context): ) self.skip_all_except(context["ti"], follow_branch) - - -def _convert_to_float_if_possible(s): - """ - A small helper function to convert a string to a numeric value - if appropriate - - :param s: the string to be converted - """ - try: - ret = float(s) - except (ValueError, TypeError): - ret = s - return ret diff --git 
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 8db80d993ac28..cec55844b6244 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -37,8 +37,7 @@
     SQLIntervalCheckOperator,
     SQLTableCheckOperator,
     SQLValueCheckOperator,
-    _get_failed_checks,
-    parse_boolean,
+    _parse_boolean,
 )
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
 from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
@@ -248,7 +247,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
             if not records:
                 raise AirflowException("The query returned empty results")
             elif not all(bool(r) for r in records):
-                raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")
+                self._raise_exception(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")
             self.log.info("Record: %s", event["records"])
             self.log.info("Success.")
@@ -544,6 +543,8 @@ def __init__(
         table: str,
         column_mapping: dict,
         partition_clause: str | None = None,
+        database: str | None = None,
+        accept_none: bool = True,
         gcp_conn_id: str = "google_cloud_default",
         use_legacy_sql: bool = True,
         location: str | None = None,
@@ -552,18 +553,23 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(
-            table=table, column_mapping=column_mapping, partition_clause=partition_clause, **kwargs
+            table=table,
+            column_mapping=column_mapping,
+            partition_clause=partition_clause,
+            database=database,
+            accept_none=accept_none,
+            **kwargs,
         )
         self.table = table
         self.column_mapping = column_mapping
         self.partition_clause = partition_clause
+        self.database = database
+        self.accept_none = accept_none
         self.gcp_conn_id = gcp_conn_id
         self.use_legacy_sql = use_legacy_sql
         self.location = location
         self.impersonation_chain = impersonation_chain
         self.labels = labels
-        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
-        self.sql = ""
 
     def _submit_job(
         self,
@@ -585,42 +591,41 @@ def execute(self, context=None):
         """Perform checks on the given columns."""
         hook = self.get_db_hook()
         failed_tests = []
-        for column in self.column_mapping:
-            checks = [*self.column_mapping[column]]
-            checks_sql = ",".join([self.column_checks[check].replace("column", column) for check in checks])
-            partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
-            self.sql = f"SELECT {checks_sql} FROM {self.table} {partition_clause_statement};"
-
-            job_id = hook.generate_job_id(
-                dag_id=self.dag_id,
-                task_id=self.task_id,
-                logical_date=context["logical_date"],
-                configuration=self.configuration,
-            )
-            job = self._submit_job(hook, job_id=job_id)
-            context["ti"].xcom_push(key="job_id", value=job.job_id)
-            records = list(job.result().to_dataframe().values.flatten())
-
-            if not records:
-                raise AirflowException(f"The following query returned zero rows: {self.sql}")
+        job = self._submit_job(hook, job_id="")
+        context["ti"].xcom_push(key="job_id", value=job.job_id)
+        records = job.result().to_dataframe()
+
+        if records.empty:
+            raise AirflowException(f"The following query returned zero rows: {self.sql}")
 
-            self.log.info("Record: %s", records)
+        records.columns = records.columns.str.lower()
+        self.log.info("Record: %s", records)
 
-            for idx, result in enumerate(records):
-                tolerance = self.column_mapping[column][checks[idx]].get("tolerance")
+        for row in records.iterrows():
+            column = row[1].get("col_name")
row[1].get("check_type") + result = row[1].get("check_result") + tolerance = self.column_mapping[column][check].get("tolerance") - self.column_mapping[column][checks[idx]]["result"] = result - self.column_mapping[column][checks[idx]]["success"] = self._get_match( - self.column_mapping[column][checks[idx]], result, tolerance - ) + self.column_mapping[column][check]["result"] = result + self.column_mapping[column][check]["success"] = self._get_match( + self.column_mapping[column][check], result, tolerance + ) - failed_tests.extend(_get_failed_checks(self.column_mapping[column], column)) + failed_tests( + f"Column: {col}\n\tCheck: {check},\n\tCheck Values: {check_values}\n" + for col, checks in self.column_mapping.items() + for check, check_values in checks.items() + if not check_values["success"] + ) if failed_tests: - raise AirflowException( + exception_string = ( f"Test failed.\nResults:\n{records!s}\n" - "The following tests have failed:" + f"The following tests have failed:" f"\n{''.join(failed_tests)}" ) + self._raise_exception(exception_string) self.log.info("All tests have passed") @@ -677,8 +682,6 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain self.labels = labels - # OpenLineage needs a valid SQL query with the input/output table(s) to parse - self.sql = "" def _submit_job( self, @@ -699,25 +702,7 @@ def _submit_job( def execute(self, context=None): """Execute the given checks on the table.""" hook = self.get_db_hook() - checks_sql = " UNION ALL ".join( - [ - self.sql_check_template.replace("check_statement", value["check_statement"]) - .replace("_check_name", check_name) - .replace("table", self.table) - for check_name, value in self.checks.items() - ] - ) - partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else "" - self.sql = f"SELECT check_name, check_result FROM ({checks_sql}) " - f"AS check_table {partition_clause_statement};" - - job_id = hook.generate_job_id( - dag_id=self.dag_id, - task_id=self.task_id, - logical_date=context["logical_date"], - configuration=self.configuration, - ) - job = self._submit_job(hook, job_id=job_id) + job = self._submit_job(hook, job_id="") context["ti"].xcom_push(key="job_id", value=job.job_id) records = job.result().to_dataframe() @@ -730,15 +715,19 @@ def execute(self, context=None): for row in records.iterrows(): check = row[1].get("check_name") result = row[1].get("check_result") - self.checks[check]["success"] = parse_boolean(str(result)) + self.checks[check]["success"] = _parse_boolean(str(result)) - failed_tests = _get_failed_checks(self.checks) + failed_tests = [ + f"\tCheck: {check},\n\tCheck Values: {check_values}\n" + for check, check_values in self.checks.items() + if not check_values["success"] + ] if failed_tests: - raise AirflowException( + exception_string = ( f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n" - "The following tests have failed:" - f"\n{', '.join(failed_tests)}" + f"The following tests have failed:\n{', '.join(failed_tests)}" ) + self._raise_exception(exception_string) self.log.info("All tests have passed") diff --git a/docs/apache-airflow-providers-common-sql/operators.rst b/docs/apache-airflow-providers-common-sql/operators.rst index e10759117e0b1..bc725be418c03 100644 --- a/docs/apache-airflow-providers-common-sql/operators.rst +++ b/docs/apache-airflow-providers-common-sql/operators.rst @@ -51,16 +51,14 @@ Check SQL Table Columns Use the :class:`~airflow.providers.common.sql.operators.sql.SQLColumnCheckOperator` to run 
 data quality checks against columns of a given table. As well as a connection ID and table, a column_mapping
-describing the relationship between columns and tests to run must be supplied. An example column
-mapping is a set of three nested dictionaries and looks like:
+describing the relationship between columns and tests to run must be supplied. An example column mapping
+is a set of three nested dictionaries and looks like:
 
 .. code-block:: python
 
     column_mapping = {
         "col_name": {
-            "null_check": {
-                "equal_to": 0,
-            },
+            "null_check": {"equal_to": 0, "partition_clause": "other_col LIKE 'this'"},
             "min": {
                 "greater_than": 5,
                 "leq_to": 10,
@@ -79,8 +77,8 @@ The valid checks are:
 - min: checks the minimum value in the column
 - max: checks the maximum value in the column
 
-Each entry in the check's dictionary is either a condition for success of the check or the tolerance. The
-conditions for success are:
+Each entry in the check's dictionary is either a condition for success of the check, the tolerance,
+or a partition clause. The conditions for success are:
 
 - greater_than
 - geq_to
@@ -92,7 +90,14 @@ When specifying conditions, equal_to is not compatible with other conditions. Both a lower- and upper-
 bound condition may be specified in the same check. The tolerance is a percentage that the result may
 be out of bounds but still considered successful.
 
+A partition clause may be given at the operator level as a parameter, where it partitions all checks;
+at the column level in the column mapping, where it partitions all checks for that column; or at the
+check level for a column, where it partitions just that check.
+
+A database may also be specified if not using the database from the supplied connection.
+
+The accept_none argument, True by default, will convert None values returned by the query to 0s, allowing
+empty tables to return valid integers.
 
 The below example demonstrates how to instantiate the SQLColumnCheckOperator task.
 
@@ -119,14 +124,20 @@ checks argument is a set of two nested dictionaries and looks like:
 
         "row_count_check": {
             "check_statement": "COUNT(*) = 1000",
         },
-        "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
+        "column_sum_check": {
+            "check_statement": "col_a + col_b < col_c",
+            "partition_clause": "col_a IS NOT NULL",
+        },
     },
 )
 
 The first set of keys are the check names, which are referenced in the templated query the operator builds.
-The dictionary key under the check name must be check_statement, with the value a SQL statement that
+The dictionary under each check name must include the key check_statement, whose value is a SQL statement that
 resolves to a boolean (this can be any string or int that resolves to a boolean in
-airflow.operators.sql.parse_boolean).
+airflow.providers.common.sql.operators.sql._parse_boolean). The other possible key to supply is
+partition_clause, a check-level statement that will partition the data in the table using a WHERE clause
+for that check. This statement combines with the operator-level parameter partition_clause, where the
+latter filters across all checks.
 
 The below example demonstrates how to instantiate the SQLTableCheckOperator task.
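
To make the templating above concrete, a sketch of what the table-check operator now builds at
construction time (table, column, and clause names are hypothetical; the commented query mirrors
``sql_check_template``).

.. code-block:: python

    from airflow.providers.common.sql.operators.sql import SQLTableCheckOperator

    table_check = SQLTableCheckOperator(
        task_id="table_check",
        table="orders",  # hypothetical table
        checks={
            "row_count_check": {"check_statement": "COUNT(*) = 1000"},
            "column_sum_check": {
                "check_statement": "col_a + col_b < col_c",
                "partition_clause": "col_a IS NOT NULL",  # per-check clause
            },
        },
        partition_clause="col_c > 0",  # operator-level clause, ANDed with per-check ones
    )

    # table_check.sql is roughly:
    # SELECT check_name, check_result FROM (
    #     SELECT 'row_count_check' AS check_name, MIN(row_count_check) AS check_result
    #     FROM (SELECT CASE WHEN COUNT(*) = 1000 THEN 1 ELSE 0 END AS row_count_check
    #     FROM orders WHERE col_c > 0) AS sq
    #     UNION ALL
    #     SELECT 'column_sum_check' AS check_name, MIN(column_sum_check) AS check_result
    #     FROM (SELECT CASE WHEN col_a + col_b < col_c THEN 1 ELSE 0 END AS column_sum_check
    #     FROM orders WHERE col_c > 0 AND col_a IS NOT NULL) AS sq
    # ) AS check_table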
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index 46681d468ef17..2980326602935 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -47,9 +47,6 @@
 
 
 class MockHook:
-    def get_first(self):
-        return
-
     def get_records(self):
         return
 
@@ -108,17 +105,56 @@ class TestColumnCheckOperator:
         }
     }
 
+    short_valid_column_mapping = {
+        "X": {
+            "null_check": {"equal_to": 0},
+            "distinct_check": {"equal_to": 10, "tolerance": 0.1},
+        }
+    }
+
     invalid_column_mapping = {"Y": {"invalid_check_name": {"expectation": 5}}}
 
-    def _construct_operator(self, monkeypatch, column_mapping, return_vals):
-        def get_first_return(*arg):
-            return return_vals
+    correct_generate_sql_query_no_partitions = """
+    SELECT 'X' AS col_name, 'null_check' AS check_type, X_null_check AS check_result
+    FROM (SELECT SUM(CASE WHEN X IS NULL THEN 1 ELSE 0 END) AS X_null_check FROM test_table ) AS sq
+    UNION ALL
+    SELECT 'X' AS col_name, 'distinct_check' AS check_type, X_distinct_check AS check_result
+    FROM (SELECT COUNT(DISTINCT(X)) AS X_distinct_check FROM test_table ) AS sq
+    """
+
+    correct_generate_sql_query_with_partition = """
+    SELECT 'X' AS col_name, 'null_check' AS check_type, X_null_check AS check_result
+    FROM (SELECT SUM(CASE WHEN X IS NULL THEN 1 ELSE 0 END) AS X_null_check FROM test_table WHERE Y > 1) AS sq
+    UNION ALL
+    SELECT 'X' AS col_name, 'distinct_check' AS check_type, X_distinct_check AS check_result
+    FROM (SELECT COUNT(DISTINCT(X)) AS X_distinct_check FROM test_table WHERE Y > 1) AS sq
+    """  # noqa 501
+
+    correct_generate_sql_query_with_partition_and_where = """
+    SELECT 'X' AS col_name, 'null_check' AS check_type, X_null_check AS check_result
+    FROM (SELECT SUM(CASE WHEN X IS NULL THEN 1 ELSE 0 END) AS X_null_check FROM test_table WHERE Y > 1 AND Z < 100) AS sq
+    UNION ALL
+    SELECT 'X' AS col_name, 'distinct_check' AS check_type, X_distinct_check AS check_result
+    FROM (SELECT COUNT(DISTINCT(X)) AS X_distinct_check FROM test_table WHERE Y > 1) AS sq
+    """  # noqa 501
+
+    correct_generate_sql_query_with_where = """
+    SELECT 'X' AS col_name, 'null_check' AS check_type, X_null_check AS check_result
+    FROM (SELECT SUM(CASE WHEN X IS NULL THEN 1 ELSE 0 END) AS X_null_check FROM test_table ) AS sq
+    UNION ALL
+    SELECT 'X' AS col_name, 'distinct_check' AS check_type, X_distinct_check AS check_result
+    FROM (SELECT COUNT(DISTINCT(X)) AS X_distinct_check FROM test_table WHERE Z < 100) AS sq
+    """  # noqa 501
+
+    def _construct_operator(self, monkeypatch, column_mapping, records):
+        def get_records(*arg):
+            return records
 
         operator = SQLColumnCheckOperator(
             task_id="test_task", table="test_table", column_mapping=column_mapping
         )
         monkeypatch.setattr(operator, "get_db_hook", _get_mock_db_hook)
-        monkeypatch.setattr(MockHook, "get_first", get_first_return)
+        monkeypatch.setattr(MockHook, "get_records", get_records)
         return operator
 
     def test_check_not_in_column_checks(self, monkeypatch):
@@ -126,43 +162,164 @@ def test_check_not_in_column_checks(self, monkeypatch):
         self._construct_operator(monkeypatch, self.invalid_column_mapping, ())
 
     def test_pass_all_checks_exact_check(self, monkeypatch):
-        operator = self._construct_operator(monkeypatch, self.valid_column_mapping, (0, 10, 10, 1, 19))
+        records = [
+            ("X", "null_check", 0),
+            ("X", "distinct_check", 10),
+            ("X", "unique_check", 10),
+            ("X", "min", 1),
+            ("X", "max", 19),
+        ]
+        operator = self._construct_operator(monkeypatch, self.valid_column_mapping, records)
         operator.execute(context=MagicMock())
+        assert all(
+            operator.column_mapping["X"][check]["success"] is True
+            for check in [*operator.column_mapping["X"]]
+        )
 
     def test_max_less_than_fails_check(self, monkeypatch):
         with pytest.raises(AirflowException):
-            operator = self._construct_operator(monkeypatch, self.valid_column_mapping, (0, 10, 10, 1, 21))
+            records = [
+                ("X", "null_check", 1),
+                ("X", "distinct_check", 10),
+                ("X", "unique_check", 10),
+                ("X", "min", 1),
+                ("X", "max", 21),
+            ]
+            operator = self._construct_operator(monkeypatch, self.valid_column_mapping, records)
             operator.execute(context=MagicMock())
             assert operator.column_mapping["X"]["max"]["success"] is False
 
     def test_max_greater_than_fails_check(self, monkeypatch):
         with pytest.raises(AirflowException):
-            operator = self._construct_operator(monkeypatch, self.valid_column_mapping, (0, 10, 10, 1, 9))
+            records = [
+                ("X", "null_check", 1),
+                ("X", "distinct_check", 10),
+                ("X", "unique_check", 10),
+                ("X", "min", 1),
+                ("X", "max", 9),
+            ]
+            operator = self._construct_operator(monkeypatch, self.valid_column_mapping, records)
             operator.execute(context=MagicMock())
             assert operator.column_mapping["X"]["max"]["success"] is False
 
     def test_pass_all_checks_inexact_check(self, monkeypatch):
-        operator = self._construct_operator(monkeypatch, self.valid_column_mapping, (0, 9, 12, 0, 15))
+        records = [
+            ("X", "null_check", 0),
+            ("X", "distinct_check", 9),
+            ("X", "unique_check", 12),
+            ("X", "min", 0),
+            ("X", "max", 15),
+        ]
+        operator = self._construct_operator(monkeypatch, self.valid_column_mapping, records)
         operator.execute(context=MagicMock())
+        assert all(
+            operator.column_mapping["X"][check]["success"] is True
+            for check in [*operator.column_mapping["X"]]
+        )
 
     def test_fail_all_checks_check(self, monkeypatch):
-        operator = operator = self._construct_operator(
-            monkeypatch, self.valid_column_mapping, (1, 12, 11, -1, 20)
-        )
+        records = [
+            ("X", "null_check", 1),
+            ("X", "distinct_check", 12),
+            ("X", "unique_check", 11),
+            ("X", "min", -1),
+            ("X", "max", 20),
+        ]
+        operator = self._construct_operator(monkeypatch, self.valid_column_mapping, records)
         with pytest.raises(AirflowException):
             operator.execute(context=MagicMock())
 
+    def test_generate_sql_query_no_partitions(self, monkeypatch):
+        checks = self.short_valid_column_mapping["X"]
+        operator = self._construct_operator(monkeypatch, self.short_valid_column_mapping, ())
+        assert (
+            operator._generate_sql_query("X", checks).lstrip()
+            == self.correct_generate_sql_query_no_partitions.lstrip()
+        )
+
+    def test_generate_sql_query_with_partitions(self, monkeypatch):
+        checks = self.short_valid_column_mapping["X"]
+        operator = self._construct_operator(monkeypatch, self.short_valid_column_mapping, ())
+        operator.partition_clause = "Y > 1"
+        assert (
+            operator._generate_sql_query("X", checks).lstrip()
+            == self.correct_generate_sql_query_with_partition.lstrip()
+        )
+
+    def test_generate_sql_query_with_partitions_and_check_partition(self, monkeypatch):
+        self.short_valid_column_mapping["X"]["null_check"]["partition_clause"] = "Z < 100"
+        checks = self.short_valid_column_mapping["X"]
+        operator = self._construct_operator(monkeypatch, self.short_valid_column_mapping, ())
+        operator.partition_clause = "Y > 1"
+        assert (
+            operator._generate_sql_query("X", checks).lstrip()
+            == self.correct_generate_sql_query_with_partition_and_where.lstrip()
+        )
+        del self.short_valid_column_mapping["X"]["null_check"]["partition_clause"]
+    def test_generate_sql_query_with_check_partition(self, monkeypatch):
+        self.short_valid_column_mapping["X"]["distinct_check"]["partition_clause"] = "Z < 100"
+        checks = self.short_valid_column_mapping["X"]
+        operator = self._construct_operator(monkeypatch, self.short_valid_column_mapping, ())
+        assert (
+            operator._generate_sql_query("X", checks).lstrip()
+            == self.correct_generate_sql_query_with_where.lstrip()
+        )
+        del self.short_valid_column_mapping["X"]["distinct_check"]["partition_clause"]
+
 
 class TestTableCheckOperator:
 
+    count_check = "COUNT(*) == 1000"
+    sum_check = "col_a + col_b < col_c"
     checks = {
-        "row_count_check": {"check_statement": "COUNT(*) == 1000"},
-        "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
+        "row_count_check": {"check_statement": f"{count_check}"},
+        "column_sum_check": {"check_statement": f"{sum_check}"},
     }
 
-    def _construct_operator(self, monkeypatch, checks, return_df):
+    correct_generate_sql_query_no_partitions = f"""
+    SELECT 'row_count_check' AS check_name, MIN(row_count_check) AS check_result
+    FROM (SELECT CASE WHEN {count_check} THEN 1 ELSE 0 END AS row_count_check
+    FROM test_table ) AS sq
+    UNION ALL
+    SELECT 'column_sum_check' AS check_name, MIN(column_sum_check) AS check_result
+    FROM (SELECT CASE WHEN {sum_check} THEN 1 ELSE 0 END AS column_sum_check
+    FROM test_table ) AS sq
+    """
+
+    correct_generate_sql_query_with_partition = f"""
+    SELECT 'row_count_check' AS check_name, MIN(row_count_check) AS check_result
+    FROM (SELECT CASE WHEN {count_check} THEN 1 ELSE 0 END AS row_count_check
+    FROM test_table WHERE col_a > 10) AS sq
+    UNION ALL
+    SELECT 'column_sum_check' AS check_name, MIN(column_sum_check) AS check_result
+    FROM (SELECT CASE WHEN {sum_check} THEN 1 ELSE 0 END AS column_sum_check
+    FROM test_table WHERE col_a > 10) AS sq
+    """
+
+    correct_generate_sql_query_with_partition_and_where = f"""
+    SELECT 'row_count_check' AS check_name, MIN(row_count_check) AS check_result
+    FROM (SELECT CASE WHEN {count_check} THEN 1 ELSE 0 END AS row_count_check
+    FROM test_table WHERE col_a > 10 AND id = 100) AS sq
+    UNION ALL
+    SELECT 'column_sum_check' AS check_name, MIN(column_sum_check) AS check_result
+    FROM (SELECT CASE WHEN {sum_check} THEN 1 ELSE 0 END AS column_sum_check
+    FROM test_table WHERE col_a > 10) AS sq
+    """
+
+    correct_generate_sql_query_with_where = f"""
+    SELECT 'row_count_check' AS check_name, MIN(row_count_check) AS check_result
+    FROM (SELECT CASE WHEN {count_check} THEN 1 ELSE 0 END AS row_count_check
+    FROM test_table ) AS sq
+    UNION ALL
+    SELECT 'column_sum_check' AS check_name, MIN(column_sum_check) AS check_result
+    FROM (SELECT CASE WHEN {sum_check} THEN 1 ELSE 0 END AS column_sum_check
+    FROM test_table WHERE id = 100) AS sq
+    """
+
+    def _construct_operator(self, monkeypatch, checks, records):
         def get_records(*arg):
-            return return_df
+            return records
 
         operator = SQLTableCheckOperator(task_id="test_task", table="test_table", checks=checks)
         monkeypatch.setattr(operator, "get_db_hook", _get_mock_db_hook)
@@ -210,6 +367,7 @@ def test_pass_all_checks_check(self, monkeypatch):
         records = [("row_count_check", 1), ("column_sum_check", "y")]
         operator = self._construct_operator(monkeypatch, self.checks, records)
         operator.execute(context=MagicMock())
+        assert all(operator.checks[check]["success"] is True for check in operator.checks.keys())
 
     def test_fail_all_checks_check(self, monkeypatch):
         records = [("row_count_check", 0), ("column_sum_check", "n")]
         operator = self._construct_operator(monkeypatch, self.checks, records)
@@ -217,6 +375,35 @@ def test_fail_all_checks_check(self, monkeypatch):
         with pytest.raises(AirflowException):
             operator.execute(context=MagicMock())
 
+    def test_generate_sql_query_no_partitions(self, monkeypatch):
+        operator = self._construct_operator(monkeypatch, self.checks, ())
+        assert (
+            operator._generate_sql_query().lstrip() == self.correct_generate_sql_query_no_partitions.lstrip()
+        )
+
+    def test_generate_sql_query_with_partitions(self, monkeypatch):
+        operator = self._construct_operator(monkeypatch, self.checks, ())
+        operator.partition_clause = "col_a > 10"
+        assert (
+            operator._generate_sql_query().lstrip() == self.correct_generate_sql_query_with_partition.lstrip()
+        )
+
+    def test_generate_sql_query_with_partitions_and_check_partition(self, monkeypatch):
+        self.checks["row_count_check"]["partition_clause"] = "id = 100"
+        operator = self._construct_operator(monkeypatch, self.checks, ())
+        operator.partition_clause = "col_a > 10"
+        assert (
+            operator._generate_sql_query().lstrip()
+            == self.correct_generate_sql_query_with_partition_and_where.lstrip()
+        )
+        del self.checks["row_count_check"]["partition_clause"]
+
+    def test_generate_sql_query_with_check_partition(self, monkeypatch):
+        self.checks["column_sum_check"]["partition_clause"] = "id = 100"
+        operator = self._construct_operator(monkeypatch, self.checks, ())
+        assert operator._generate_sql_query().lstrip() == self.correct_generate_sql_query_with_where.lstrip()
+        del self.checks["column_sum_check"]["partition_clause"]
+
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
 INTERVAL = datetime.timedelta(hours=12)
@@ -303,7 +490,7 @@ def setUp(self):
     def test_execute_no_records(self, mock_get_db_hook):
         mock_get_db_hook.return_value.get_first.return_value = []
 
-        with pytest.raises(AirflowException, match=r"The query returned None"):
+        with pytest.raises(AirflowException, match=r"The following query returned zero rows: sql"):
             self._operator.execute({})
 
     @mock.patch.object(SQLCheckOperator, "get_db_hook")
diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
index 789b5fb921229..93ef73fb44687 100644
--- a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
+++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
@@ -223,7 +223,7 @@
     table_check = BigQueryTableCheckOperator(
         task_id="table_check",
        table=f"{DATASET}.{TABLE_1}",
-        checks={"row_count_check": {"check_statement": {"COUNT(*) = 4"}}},
+        checks={"row_count_check": {"check_statement": "COUNT(*) = 4"}},
    )
    # [END howto_operator_bigquery_table_check]
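
For reference, the observable behavior of the two private helpers the patch renames and consolidates
(illustrative values; the accepted true/false spellings follow the strtobool-style convention used by
the helper's implementation).

.. code-block:: python

    from airflow.providers.common.sql.operators.sql import (
        _convert_to_float_if_possible,
        _parse_boolean,
    )

    # _parse_boolean maps true/false-like strings to booleans and raises
    # ValueError for anything else.
    assert _parse_boolean("true") is True
    assert _parse_boolean("0") is False
    # _parse_boolean("maybe")  # -> ValueError

    # _convert_to_float_if_possible returns a float when the string parses as
    # a number, and otherwise hands the original value back unchanged.
    assert _convert_to_float_if_possible("1.5") == 1.5
    assert _convert_to_float_if_possible("abc") == "abc"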