def _get_failed_tests(checks):
    """
    Build one report entry for every check that did not succeed.

    :param checks: mapping of check name to that check's dict; every dict must
        already contain the ``success`` key written by the operator's ``execute``
    :return: list of formatted strings, one per check whose ``success`` is falsy
    """
    return [
        f"\tCheck: {check},\n\tCheck Values: {check_values}\n"
        for check, check_values in checks.items()
        if not check_values["success"]
    ]


class SQLColumnCheckOperator(BaseSQLOperator):
    """
    Performs one or more of the templated checks in the column_checks dictionary.
    Checks are performed on a per-column basis specified by the column_mapping.
    Each check can take one or more of the following options:
    - equal_to: an exact value to equal, cannot be used with other comparison options
    - greater_than: value that result should be strictly greater than
    - less_than: value that results should be strictly less than
    - geq_than: value that results should be greater than or equal to
    - leq_than: value that results should be less than or equal to
    - tolerance: the percentage that the result may be off from the expected value

    When both a lower and an upper bound are given, the result must satisfy both.

    :param table: the table to run checks on
    :param column_mapping: the dictionary of columns and their associated checks, e.g.

    .. code-block:: python

        {
            "col_name": {
                "null_check": {
                    "equal_to": 0,
                },
                "min": {
                    "greater_than": 5,
                    "leq_than": 10,
                    "tolerance": 0.2,
                },
                "max": {"less_than": 1000, "geq_than": 10, "tolerance": 0.01},
            }
        }

    :param conn_id: the connection ID used to connect to the database
    :param database: name of database which overwrite the defined one in connection
    """

    # SQL templates keyed by check name; the literal word "column" is replaced
    # with the real column name (in both the expression and its alias).
    column_checks = {
        "null_check": "SUM(CASE WHEN column IS NULL THEN 1 ELSE 0 END) AS column_null_check",
        "distinct_check": "COUNT(DISTINCT(column)) AS column_distinct_check",
        "unique_check": "COUNT(DISTINCT(column)) = COUNT(column) AS column_unique_check",
        "min": "MIN(column) AS column_min",
        "max": "MAX(column) AS column_max",
    }

    def __init__(
        self,
        *,
        table: str,
        column_mapping: Dict[str, Dict[str, Any]],
        conn_id: Optional[str] = None,
        database: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(conn_id=conn_id, database=database, **kwargs)
        # Fail at construction time rather than at execution time on a bad mapping.
        for checks in column_mapping.values():
            for check, check_values in checks.items():
                self._column_mapping_validation(check, check_values)

        self.table = table
        self.column_mapping = column_mapping
        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
        self.sql = f"SELECT * FROM {self.table};"

    def execute(self, context=None):
        """
        Run one query per column and evaluate every configured check.

        The ``result`` and ``success`` keys are written into each check's dict
        inside ``column_mapping``.

        :raises AirflowException: if a query returns no row, or if any check fails
        """
        hook = self.get_db_hook()
        failed_tests = []
        for column, column_checks in self.column_mapping.items():
            checks = [*column_checks]
            checks_sql = ",".join(self.column_checks[check].replace("column", column) for check in checks)

            self.sql = f"SELECT {checks_sql} FROM {self.table};"
            records = hook.get_first(self.sql)

            if not records:
                raise AirflowException(f"The following query returned zero rows: {self.sql}")

            self.log.info("Record: %s", records)

            # The SELECT list was built in ``checks`` order, so results pair up positionally.
            for check, result in zip(checks, records):
                check_values = column_checks[check]
                check_values["result"] = result
                check_values["success"] = self._get_match(
                    check_values, result, check_values.get("tolerance")
                )

            failed_tests.extend(_get_failed_tests(column_checks))
        if failed_tests:
            raise AirflowException(
                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
                "The following tests have failed:"
                f"\n{''.join(failed_tests)}"
            )

        self.log.info("All tests have passed")

    def _get_match(self, check_values, record, tolerance=None) -> bool:
        """
        Decide whether ``record`` satisfies every comparison in ``check_values``.

        :param check_values: the check's dict holding any of ``geq_than``,
            ``greater_than``, ``leq_than``, ``less_than`` and ``equal_to``
        :param record: the value returned by the database for this check
        :param tolerance: optional fraction; lower bounds are scaled by
            ``1 - tolerance`` and upper bounds by ``1 + tolerance``
        :return: True only if at least one comparison is present and all of the
            present comparisons hold
        """
        # A NULL aggregate cannot satisfy any comparison; report it as a failed
        # check instead of raising TypeError from ``None >= bound``.
        if record is None:
            return False

        # Bounds are left untouched when no tolerance is given, so non-numeric
        # values are never multiplied.
        def lower(bound):
            return bound if tolerance is None else bound * (1 - tolerance)

        def upper(bound):
            return bound if tolerance is None else bound * (1 + tolerance)

        outcomes = []
        if "geq_than" in check_values:
            outcomes.append(record >= lower(check_values["geq_than"]))
        elif "greater_than" in check_values:
            outcomes.append(record > lower(check_values["greater_than"]))
        # Evaluated in addition to the lower bound, not instead of it.
        if "leq_than" in check_values:
            outcomes.append(record <= upper(check_values["leq_than"]))
        elif "less_than" in check_values:
            outcomes.append(record < upper(check_values["less_than"]))
        if "equal_to" in check_values:
            expected = check_values["equal_to"]
            if tolerance is None:
                outcomes.append(record == expected)
            else:
                outcomes.append(lower(expected) <= record <= upper(expected))
        return bool(outcomes) and all(outcomes)

    def _column_mapping_validation(self, check, check_values):
        """
        Validate a single check entry of the column mapping.

        :param check: the check name; must be a key of ``column_checks``
        :param check_values: the dict of comparison options for that check
        :raises AirflowException: if ``check`` is not a known check name
        :raises ValueError: if no comparison option is given, if the bounds
            contradict each other, or if mutually exclusive options are combined
        """
        if check not in self.column_checks:
            raise AirflowException(f"Invalid column check: {check}.")

        has_greater = "greater_than" in check_values
        has_geq = "geq_than" in check_values
        has_less = "less_than" in check_values
        has_leq = "leq_than" in check_values
        has_equal = "equal_to" in check_values
        has_bound = has_greater or has_geq or has_less or has_leq

        if not has_bound and not has_equal:
            raise ValueError(
                "Please provide one or more of: less_than, leq_than, "
                "greater_than, geq_than, or equal_to in the check's dict."
            )

        if has_greater and has_less and check_values["greater_than"] >= check_values["less_than"]:
            raise ValueError(
                "greater_than should be strictly less than "
                "less_than. Use geq_than or leq_than for "
                "overlapping equality."
            )

        if has_greater and has_leq and check_values["greater_than"] >= check_values["leq_than"]:
            raise ValueError(
                "greater_than must be strictly less than leq_than. "
                "Use geq_than with leq_than for overlapping equality."
            )

        if has_geq and has_less and check_values["geq_than"] >= check_values["less_than"]:
            raise ValueError(
                "geq_than should be strictly less than less_than. "
                "Use leq_than with geq_than for overlapping equality."
            )

        if has_geq and has_leq and check_values["geq_than"] > check_values["leq_than"]:
            raise ValueError("geq_than should be less than or equal to leq_than.")

        if has_greater and has_geq:
            raise ValueError("Only supply one of greater_than or geq_than.")

        if has_less and has_leq:
            raise ValueError("Only supply one of less_than or leq_than.")

        if has_bound and has_equal:
            raise ValueError(
                "equal_to cannot be passed with a greater or less than "
                "function. To specify 'greater than or equal to' or "
                "'less than or equal to', use geq_than or leq_than."
            )


class SQLTableCheckOperator(BaseSQLOperator):
    """
    Performs one or more of the checks provided in the checks dictionary.
    Checks should be written to return a boolean result.

    :param table: the table to run checks on.
    :param checks: the dictionary of checks, e.g.:

    .. code-block:: python

        {
            "row_count_check": {"check_statement": "COUNT(*) = 1000"},
            "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
        }

    :param conn_id: the connection ID used to connect to the database.
    :param database: name of database which overwrite the defined one in connection
    """

    # Inner query: turn each check statement into a 0/1 flag named after the check.
    sql_check_template = "CASE WHEN check_statement THEN 1 ELSE 0 END AS check_name"
    # Outer query: a check passes only if its flag is 1 on every inner row.
    sql_min_template = "MIN(check_name)"

    def __init__(
        self,
        *,
        table: str,
        checks: Dict[str, Dict[str, Any]],
        conn_id: Optional[str] = None,
        database: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(conn_id=conn_id, database=database, **kwargs)

        self.table = table
        self.checks = checks
        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
        self.sql = f"SELECT * FROM {self.table};"

    def execute(self, context=None):
        """
        Run all checks in a single query and record each check's outcome.

        The ``success`` key is written into each check's dict inside ``checks``.

        :raises AirflowException: if the query returns no row, or if any check fails
        """
        hook = self.get_db_hook()

        check_names = [*self.checks]
        check_mins_sql = ",".join(
            self.sql_min_template.replace("check_name", check_name) for check_name in check_names
        )
        checks_sql = ",".join(
            self.sql_check_template.replace(
                "check_statement", self.checks[check_name]["check_statement"]
            ).replace("check_name", check_name)
            for check_name in check_names
        )

        # The derived table is aliased because several databases reject an
        # unnamed subquery in FROM; "AS" is omitted since not every dialect
        # accepts it before a table alias.
        self.sql = f"SELECT {check_mins_sql} FROM (SELECT {checks_sql} FROM {self.table}) check_table;"
        records = hook.get_first(self.sql)

        if not records:
            raise AirflowException(f"The following query returned zero rows: {self.sql}")

        self.log.info("Record: %s", records)

        # Result columns come back in ``check_names`` order; pair them positionally
        # so each check gets its own outcome.
        for check_name, result in zip(check_names, records):
            self.checks[check_name]["success"] = bool(result)

        failed_tests = _get_failed_tests(self.checks)
        if failed_tests:
            raise AirflowException(
                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
                "The following tests have failed:"
                f"\n{', '.join(failed_tests)}"
            )

        self.log.info("All tests have passed")
class TestColumnCheckOperator(unittest.TestCase):
    """Tests for SQLColumnCheckOperator with the database hook mocked out."""

    def setUp(self):
        # Built per test: the operator writes "result"/"success" keys into the
        # mapping it is given, so a class-level dict would leak between tests.
        self.valid_column_mapping = {
            "X": {
                "null_check": {"equal_to": 0},
                "distinct_check": {"equal_to": 10, "tolerance": 0.1},
                "unique_check": {"geq_than": 10},
                "min": {"leq_than": 1},
                "max": {"less_than": 20, "greater_than": 10},
            }
        }
        self.invalid_column_mapping = {"Y": {"invalid_check_name": {"expectation": 5}}}

    def _construct_operator(self, column_mapping):
        return SQLColumnCheckOperator(task_id="test_task", table="test_table", column_mapping=column_mapping)

    @mock.patch.object(SQLColumnCheckOperator, "get_db_hook")
    def test_check_not_in_column_checks(self, mock_get_db_hook):
        # An unknown check name is rejected while the operator is constructed.
        with pytest.raises(AirflowException, match="Invalid column check: invalid_check_name."):
            self._construct_operator(self.invalid_column_mapping)

    @mock.patch.object(SQLColumnCheckOperator, "get_db_hook")
    def test_pass_all_checks_exact_check(self, mock_get_db_hook):
        # Row order follows the mapping: null, distinct, unique, min, max.
        mock_hook = mock.Mock()
        mock_hook.get_first.return_value = (0, 10, 10, 1, 19)
        mock_get_db_hook.return_value = mock_hook
        operator = self._construct_operator(self.valid_column_mapping)
        operator.execute()

    @mock.patch.object(SQLColumnCheckOperator, "get_db_hook")
    def test_pass_all_checks_inexact_check(self, mock_get_db_hook):
        # distinct_check returns 9, which is inside the 10% tolerance around 10.
        mock_hook = mock.Mock()
        mock_hook.get_first.return_value = (0, 9, 12, 0, 15)
        mock_get_db_hook.return_value = mock_hook
        operator = self._construct_operator(self.valid_column_mapping)
        operator.execute()

    @mock.patch.object(SQLColumnCheckOperator, "get_db_hook")
    def test_fail_all_checks_check(self, mock_get_db_hook):
        mock_hook = mock.Mock()
        mock_hook.get_first.return_value = (1, 12, 11, -1, 20)
        mock_get_db_hook.return_value = mock_hook
        operator = self._construct_operator(self.valid_column_mapping)
        with pytest.raises(AirflowException):
            operator.execute()


class TestTableCheckOperator(unittest.TestCase):
    """Tests for SQLTableCheckOperator with the database hook mocked out."""

    def setUp(self):
        # Built per test: the operator writes a "success" key into each check dict.
        self.checks = {
            "row_count_check": {"check_statement": "COUNT(*) = 1000"},
            "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
        }

    def _construct_operator(self, checks):
        return SQLTableCheckOperator(task_id="test_task", table="test_table", checks=checks)

    @mock.patch.object(SQLTableCheckOperator, "get_db_hook")
    def test_pass_all_checks_check(self, mock_get_db_hook):
        # The operator's query yields one 0/1 flag per check; 1 means the check held.
        mock_hook = mock.Mock()
        mock_hook.get_first.return_value = (1, 1)
        mock_get_db_hook.return_value = mock_hook
        operator = self._construct_operator(self.checks)
        operator.execute()

    @mock.patch.object(SQLTableCheckOperator, "get_db_hook")
    def test_fail_all_checks_check(self, mock_get_db_hook):
        mock_hook = mock.Mock()
        mock_hook.get_first.return_value = (0, 0)
        mock_get_db_hook.return_value = mock_hook
        operator = self._construct_operator(self.checks)
        with pytest.raises(AirflowException):
            operator.execute()