From f9d3d3ec9c18ebc571b4cfb1a0f2e81673524b1b Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Wed, 25 May 2022 10:58:56 -0400 Subject: [PATCH 01/15] Add new SQLCheckOperators This commit adds two new SQL Check Operators to the core sql.py file, one called SQLColumnCheckOperator and one called SQLTableCheckOperator. Corresponding tests are added as well. The current functionality of these operators is minimal, but expected to grow over time. These operators, unlike current ones, save information about the results of the check queries that is then parseable by third party libraries. --- airflow/operators/sql.py | 216 ++++++++++++++++++++++++++++++++++++ tests/operators/test_sql.py | 88 +++++++++++++++ 2 files changed, 304 insertions(+) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index efa5d0d81a8f0..9c09ec20c5c64 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -467,6 +467,218 @@ def push(self, meta_data): self.log.info("Log from %s:\n%s", self.dag_id, info) +def _get_failed_tests(checks): + failed_tests = [] + for check, check_values in checks.items(): + if not check_values["success"]: + failed_tests.append( + f"\tCheck: {check}, " + f"Pass Value: {check_values['pass_value']}, " + f"Result: {check_values['result']}\n" + ) + return failed_tests + + +class SQLColumnCheckOperator(BaseSQLOperator): + """ + Performs one or more of the templated checks in the column_checks dictionary. + Checks are performed on a per-column basis specified by the column_mapping. + + :param table: the table to run checks on. + :param column_mapping: the dictionary of columns and their associated checks, e.g.: + { + "col_name": { + "null_check": { + "pass_value": 0, + }, + "min": { + "pass_value": 5, + "tolerance": 0.2, + } + } + } + :param conn_id: the connection ID used to connect to the database. 
+ :param database: name of database which overwrite the defined one in connection + """ + + column_checks = { + # pass value should be number of acceptable nulls + "null_check": "SUM(CASE WHEN 'column' IS NULL THEN 1 ELSE 0 END) AS column_null_check", + # pass value should be number of acceptable distinct values + "distinct_check": "COUNT(DISTINCT(column)) AS column_distinct_check", + # pass value is implicit in the query, it does not need to be passed + "unique_check": "COUNT(DISTINCT(column)) = COUNT(column)", + # pass value should be the minimum acceptable numeric value + "min": "MIN(column) AS column_min", + # pass value should be the maximum acceptable numeric value + "max": "MAX(column) AS column_max", + } + + def __init__( + self, + *, + table: str, + column_mapping: Dict[str, Dict[str, Any]], + conn_id: Optional[str] = None, + database: Optional[str] = None, + **kwargs, + ): + super().__init__(conn_id=conn_id, database=database, **kwargs) + for checks in column_mapping.values(): + for check in checks: + if check not in self.column_checks: + raise AirflowException(f"Invalid column check: {check}.") + + self.table = table + self.column_mapping = column_mapping + self.sql = f"SELECT * FROM {self.table};" + + def execute(self, context=None): + hook = self.get_db_hook() + failed_tests = [] + for column in self.column_mapping: + checks = [*self.column_mapping[column]] + checks_sql = ",".join([self.column_checks[check].replace("column", column) for check in checks]) + + self.sql = f"SELECT {checks_sql} FROM {self.table};" + records = hook.get_first(self.sql) + self.log.info(f"Record: {records}") + + if not records: + raise AirflowException("The query returned None") + + for idx, result in enumerate(records): + pass_value_conv = _convert_to_float_if_possible( + self.column_mapping[column][checks[idx]]["pass_value"] + ) + is_numeric_value_check = isinstance(pass_value_conv, float) + tolerance = ( + self.column_mapping[column][checks[idx]]["tolerance"] + if "tolerance" in self.column_mapping[column][checks[idx]] + else None + ) + + self.column_mapping[column][checks[idx]]["result"] = result + self.column_mapping[column][checks[idx]]["success"] = ( + self._get_numeric_match( + checks[idx], result, self.column_mapping[column][checks[idx]]["pass_value"], tolerance + ) + if is_numeric_value_check + else (result == self.column_mapping[column][checks[idx]]["pass_value"]) + ) + + failed_tests.extend(_get_failed_tests(self.column_mapping[column])) + if failed_tests: + raise AirflowException( + f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n" + "The following tests have failed:" + f"\n{''.join(failed_tests)}" + ) + + self.log.info("All tests have passed") + + def _get_numeric_match(self, check, numeric_record, numeric_pass_value, tolerance=None) -> bool: + if check in "min": + if tolerance is not None: + return numeric_record >= numeric_pass_value * (1 - tolerance) + return numeric_record >= numeric_pass_value + if check in "max": + if tolerance is not None: + return numeric_record <= numeric_pass_value * (1 + tolerance) + return numeric_record <= numeric_pass_value + if check in ["null_check", "distinct_check", "unique_check"]: + if tolerance is not None: + return ( + numeric_pass_value * (1 - tolerance) + <= numeric_record + <= numeric_pass_value * (1 + tolerance) + ) + return numeric_record == numeric_pass_value + + +class SQLTableCheckOperator(BaseSQLOperator): + """ + Performs one or more of the templated checks in the table_checks dictionary. 
+ Checks are performed on the table as aggregates. + + :param table: the table to run checks on. + :param checks: the dictionary of checks, e.g.: + { + "row_count_check": { + "pass_value": 100, + "tolerance": .05 + } + } + :param conn_id: the connection ID used to connect to the database. + :param database: name of database which overwrite the defined one in connection + """ + + table_checks = { + "row_count_check": "COUNT(*) AS row_count_check", + } + + def __init__( + self, + *, + table: str, + checks: Dict[str, Dict[str, Any]], + conn_id: Optional[str] = None, + database: Optional[str] = None, + **kwargs, + ): + super().__init__(conn_id=conn_id, database=database, **kwargs) + for check in checks.keys(): + if check not in self.table_checks: + raise AirflowException(f"Invalid table check: {check}.") + + self.table = table + self.checks = checks + self.sql = f"SELECT * FROM {self.table};" + + def execute(self, context=None): + hook = self.get_db_hook() + + checks_sql = ",".join([self.table_checks[check] for check in self.checks.keys()]) + + self.sql = f"SELECT {checks_sql} FROM {self.table};" + records = hook.get_first(self.sql) + + self.log.info(f"Record: {records}") + + if not records: + raise AirflowException("The query returned None") + + for check in self.checks.keys(): + for result in records: + pass_value_conv = _convert_to_float_if_possible(self.checks[check]["pass_value"]) + is_numeric_value_check = isinstance(pass_value_conv, float) + tolerance = self.checks[check]["tolerance"] if "tolerance" in self.checks[check] else None + + self.checks[check]["result"] = result + self.checks[check]["success"] = ( + self._get_numeric_match(result, self.checks[check]["pass_value"], tolerance) + if is_numeric_value_check + else (result == self.checks[check]["pass_value"]) + ) + + failed_tests = _get_failed_tests(self.checks) + if failed_tests: + raise AirflowException( + f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n" + "The following tests have failed:" + f"\n{', '.join(failed_tests)}" + ) + + self.log.info("All tests have passed") + + def _get_numeric_match(self, numeric_record, numeric_pass_value, tolerance=None): + if tolerance is not None: + return ( + numeric_pass_value * (1 - tolerance) <= numeric_record <= numeric_pass_value * (1 + tolerance) + ) + return numeric_record == numeric_pass_value + + class BranchSQLOperator(BaseSQLOperator, SkipMixin): """ Allows a DAG to "branch" or follow a specified path based on the results of a SQL query. 
@@ -555,3 +767,7 @@ def execute(self, context: Context): ) self.skip_all_except(context["ti"], follow_branch) + + self.skip_all_except(context["ti"], follow_branch) + self.skip_all_except(context["ti"], follow_branch) + self.skip_all_except(context["ti"], follow_branch) diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py index 2e73c3ac33487..780206aa0b7c7 100644 --- a/tests/operators/test_sql.py +++ b/tests/operators/test_sql.py @@ -28,7 +28,9 @@ from airflow.operators.sql import ( BranchSQLOperator, SQLCheckOperator, + SQLColumnCheckOperator, SQLIntervalCheckOperator, + SQLTableCheckOperator, SQLThresholdCheckOperator, SQLValueCheckOperator, ) @@ -385,6 +387,92 @@ def test_fail_min_sql_max_value(self, mock_get_db_hook): operator.execute() +class TestColumnCheckOperator(unittest.TestCase): + + valid_column_mapping = { + "X": { + "null_check": {"pass_value": 0}, + "distinct_check": {"pass_value": 10, "tolerance": 0.1}, + "unique_check": {"pass_value": 10}, + "min": {"pass_value": 1}, + "max": {"pass_value": 20}, + } + } + + invalid_column_mapping = {"Y": {"invalid_check_name": {"expectation": 5}}} + + def _construct_operator(self, column_mapping): + return SQLColumnCheckOperator(task_id="test_task", table="test_table", column_mapping=column_mapping) + + @mock.patch.object(SQLColumnCheckOperator, "get_db_hook") + def test_check_not_in_column_checks(self, mock_get_db_hook): + with pytest.raises(AirflowException, match="Invalid column check: invalid_check_name."): + self._construct_operator(self.invalid_column_mapping) + + @mock.patch.object(SQLColumnCheckOperator, "get_db_hook") + def test_pass_all_checks_check(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = (0, 10, 10, 1, 20) + mock_get_db_hook.return_value = mock_hook + operator = self._construct_operator(self.valid_column_mapping) + operator.execute() + + @mock.patch.object(SQLColumnCheckOperator, "get_db_hook") + def test_fail_all_checks_check(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = (1, 12, 11, -1, 50) + mock_get_db_hook.return_value = mock_hook + operator = self._construct_operator(self.valid_column_mapping) + with pytest.raises(AirflowException): + operator.execute() + + +class TestTableCheckOperator(unittest.TestCase): + + valid_checks = { + "row_count_check": {"pass_value": 10}, + } + + valid_checks_tolerance = { + "row_count_check": {"pass_value": 10, "tolerance": 0.2}, + } + + invalid_checks = {"invalid_check_name": {"pass_value": 5}} + + def _construct_operator(self, checks): + return SQLTableCheckOperator(task_id="test_task", table="test_table", checks=checks) + + @mock.patch.object(SQLTableCheckOperator, "get_db_hook") + def test_check_not_in_checks(self, mock_get_db_hook): + with pytest.raises(AirflowException, match="Invalid table check: invalid_check_name."): + self._construct_operator(self.invalid_checks) + + @mock.patch.object(SQLTableCheckOperator, "get_db_hook") + def test_pass_all_checks_check(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = (10,) + mock_get_db_hook.return_value = mock_hook + operator = self._construct_operator(self.valid_checks) + operator.execute() + + @mock.patch.object(SQLTableCheckOperator, "get_db_hook") + def test_pass_all_checks_with_tolerance_check(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = (11,) + mock_get_db_hook.return_value = mock_hook + operator = self._construct_operator(self.valid_checks_tolerance) + 
operator.execute() + + @mock.patch.object(SQLTableCheckOperator, "get_db_hook") + def test_fail_all_checks_check(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = (1,) + mock_get_db_hook.return_value = mock_hook + operator = self._construct_operator(self.valid_checks) + with pytest.raises(AirflowException): + operator.execute() + + class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): """ Test for SQL Branch Operator From a58a3c5422daae087a8253793f0c7e99cf36e695 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 25 May 2022 12:10:27 -0400 Subject: [PATCH 02/15] Use list comprehension in _get_failed_tests() Co-authored-by: Tzu-ping Chung --- airflow/operators/sql.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index 9c09ec20c5c64..59baae09fa89c 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -468,15 +468,13 @@ def push(self, meta_data): def _get_failed_tests(checks): - failed_tests = [] - for check, check_values in checks.items(): - if not check_values["success"]: - failed_tests.append( - f"\tCheck: {check}, " - f"Pass Value: {check_values['pass_value']}, " - f"Result: {check_values['result']}\n" - ) - return failed_tests + return [ + f"\tCheck: {check}, " + f"Pass Value: {check_values['pass_value']}, " + f"Result: {check_values['result']}\n" + for check, check_values in checks.items() + if not check_values["success"] + ] class SQLColumnCheckOperator(BaseSQLOperator): From c6950dece67edeea84602ed9be423831073cc15e Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Wed, 25 May 2022 12:11:13 -0400 Subject: [PATCH 03/15] Remove repeated lines of code Accidentally copied the last line of the file several times, this commit remedies that change. --- airflow/operators/sql.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index 59baae09fa89c..16ffce50c71d7 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -765,7 +765,3 @@ def execute(self, context: Context): ) self.skip_all_except(context["ti"], follow_branch) - - self.skip_all_except(context["ti"], follow_branch) - self.skip_all_except(context["ti"], follow_branch) - self.skip_all_except(context["ti"], follow_branch) From 181869181e8addcc819df151852021efd4beb9d4 Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Wed, 25 May 2022 16:59:32 -0400 Subject: [PATCH 04/15] Fix doc strings. --- airflow/operators/sql.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index 16ffce50c71d7..60cfae1848e7c 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -484,17 +484,17 @@ class SQLColumnCheckOperator(BaseSQLOperator): :param table: the table to run checks on. :param column_mapping: the dictionary of columns and their associated checks, e.g.: - { - "col_name": { - "null_check": { - "pass_value": 0, - }, - "min": { - "pass_value": 5, - "tolerance": 0.2, - } + { + 'col_name': { + 'null_check': { + 'pass_value': 0, + }, + 'min': { + 'pass_value': 5, + 'tolerance': 0.2, } } + } :param conn_id: the connection ID used to connect to the database. :param database: name of database which overwrite the defined one in connection """ @@ -601,12 +601,12 @@ class SQLTableCheckOperator(BaseSQLOperator): :param table: the table to run checks on. 
:param checks: the dictionary of checks, e.g.: - { - "row_count_check": { - "pass_value": 100, - "tolerance": .05 - } + { + 'row_count_check': { + 'pass_value': 100, + 'tolerance': .05 } + } :param conn_id: the connection ID used to connect to the database. :param database: name of database which overwrite the defined one in connection """ From fc8380ebb753ba849452352b5ad52c889911486e Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Fri, 3 Jun 2022 12:23:57 -0400 Subject: [PATCH 05/15] Update SQLColumnCheckOperator Improved the column check operator to include support for greater than, geq, less than, and leq functions. Updated logic and tests to reflect this change. --- airflow/operators/sql.py | 125 +++++++++++++++++++++++++++--------- tests/operators/test_sql.py | 24 ++++--- 2 files changed, 111 insertions(+), 38 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index 60cfae1848e7c..f4fe32335420c 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -469,9 +469,7 @@ def push(self, meta_data): def _get_failed_tests(checks): return [ - f"\tCheck: {check}, " - f"Pass Value: {check_values['pass_value']}, " - f"Result: {check_values['result']}\n" + f"\tCheck: {check}, " f"Check Values: {check_values}\n" for check, check_values in checks.items() if not check_values["success"] ] @@ -487,11 +485,17 @@ class SQLColumnCheckOperator(BaseSQLOperator): { 'col_name': { 'null_check': { - 'pass_value': 0, + 'equal_to': 0, }, 'min': { - 'pass_value': 5, + 'greater_than': 5, + 'leq_than': 10, 'tolerance': 0.2, + }, + 'max': { + 'less_than': 1000, + 'geq_than': 10, + 'tolerance': 0.01 } } } @@ -523,9 +527,8 @@ def __init__( ): super().__init__(conn_id=conn_id, database=database, **kwargs) for checks in column_mapping.values(): - for check in checks: - if check not in self.column_checks: - raise AirflowException(f"Invalid column check: {check}.") + for check, check_values in checks.items(): + self._column_mapping_validation(check, check_values) self.table = table self.column_mapping = column_mapping @@ -546,10 +549,6 @@ def execute(self, context=None): raise AirflowException("The query returned None") for idx, result in enumerate(records): - pass_value_conv = _convert_to_float_if_possible( - self.column_mapping[column][checks[idx]]["pass_value"] - ) - is_numeric_value_check = isinstance(pass_value_conv, float) tolerance = ( self.column_mapping[column][checks[idx]]["tolerance"] if "tolerance" in self.column_mapping[column][checks[idx]] @@ -557,12 +556,8 @@ def execute(self, context=None): ) self.column_mapping[column][checks[idx]]["result"] = result - self.column_mapping[column][checks[idx]]["success"] = ( - self._get_numeric_match( - checks[idx], result, self.column_mapping[column][checks[idx]]["pass_value"], tolerance - ) - if is_numeric_value_check - else (result == self.column_mapping[column][checks[idx]]["pass_value"]) + self.column_mapping[column][checks[idx]]["success"] = self._get_match( + self.column_mapping[column][checks[idx]], result, tolerance ) failed_tests.extend(_get_failed_tests(self.column_mapping[column])) @@ -575,23 +570,93 @@ def execute(self, context=None): self.log.info("All tests have passed") - def _get_numeric_match(self, check, numeric_record, numeric_pass_value, tolerance=None) -> bool: - if check in "min": + def _get_match(self, check_values, record, tolerance=None) -> bool: + # check if record is str or numeric + # if record is str, do pattern matching + # numeric record checks + if "geq_than" in check_values: if tolerance is not 
None: - return numeric_record >= numeric_pass_value * (1 - tolerance) - return numeric_record >= numeric_pass_value - if check in "max": + return record >= check_values["geq_than"] * (1 - tolerance) + return record >= check_values["geq_than"] + elif "greater_than" in check_values: if tolerance is not None: - return numeric_record <= numeric_pass_value * (1 + tolerance) - return numeric_record <= numeric_pass_value - if check in ["null_check", "distinct_check", "unique_check"]: + return record > check_values["greater_than"] * (1 - tolerance) + return record > check_values["greater_than"] + if "leq_than" in check_values: + if tolerance is not None: + return record <= check_values["leq_than"] * (1 + tolerance) + return record <= check_values["leq_than"] + elif "less_than" in check_values: + if tolerance is not None: + return record < check_values["less_than"] * (1 + tolerance) + return record < check_values["less_than"] + if "equal_to" in check_values: if tolerance is not None: return ( - numeric_pass_value * (1 - tolerance) - <= numeric_record - <= numeric_pass_value * (1 + tolerance) + check_values["equal_to"] * (1 - tolerance) + <= record + <= check_values["equal_to"] * (1 + tolerance) ) - return numeric_record == numeric_pass_value + return record == check_values["equal_to"] + + def _column_mapping_validation(self, check, check_values): + if check not in self.column_checks: + raise AirflowException(f"Invalid column check: {check}.") + if ( + "greater_than" not in check_values + and "geq_than" not in check_values + and "less_than" not in check_values + and "leq_than" not in check_values + and "equal_to" not in check_values + ): + raise ValueError( + "Please provide one or more of: less_than, leq_than, " + "greater_than, geq_than, or equal_to in the check's dict." + ) + + if "greater_than" in check_values and "less_than" in check_values: + if check_values["greater_than"] >= check_values["less_than"]: + raise ValueError( + "greater_than should be strictly less than to " + "less_than. Use geq_than or leq_than for " + "overlapping equality." + ) + + if "greater_than" in check_values and "leq_than" in check_values: + if check_values["greater_than"] >= check_values["leq_than"]: + raise ValueError( + "greater_than must be strictly less than leq_than. " + "Use geq_than with leq_than for overlapping equality." + ) + + if "geq_than" in check_values and "less_than" in check_values: + if check_values["geq_than"] >= check_values["less_than"]: + raise ValueError( + "geq_than should be strictly less than less_than. " + "Use leq_than with geq_than for overlapping equality." + ) + + if "geq_than" in check_values and "leq_than" in check_values: + if check_values["geq_than"] > check_values["leq_than"]: + raise ValueError("geq_than should be less than or equal to leq_than.") + + if "greater_than" in check_values and "geq_than" in check_values: + raise ValueError("Only supply one of greater_than or geq_than.") + + if "less_than" in check_values and "leq_than" in check_values: + raise ValueError("Only supply one of less_than or leq_than.") + + if ( + "greater_than" in check_values + or "geq_than" in check_values + or "less_than" in check_values + or "leq_than" in check_values + ) and "equal_to" in check_values: + raise ValueError( + "equal_to cannot be passed with a greater or less than " + "function. To specify 'greater than or equal to' or " + "'less than or equal to', use geq_than or leq_than." 
+ ) class SQLTableCheckOperator(BaseSQLOperator): diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py index 780206aa0b7c7..2c6e38df8d948 100644 --- a/tests/operators/test_sql.py +++ b/tests/operators/test_sql.py @@ -391,11 +391,11 @@ class TestColumnCheckOperator(unittest.TestCase): valid_column_mapping = { "X": { - "null_check": {"pass_value": 0}, - "distinct_check": {"pass_value": 10, "tolerance": 0.1}, - "unique_check": {"pass_value": 10}, - "min": {"pass_value": 1}, - "max": {"pass_value": 20}, + "null_check": {"equal_to": 0}, + "distinct_check": {"equal_to": 10, "tolerance": 0.1}, + "unique_check": {"geq_than": 10}, + "min": {"leq_than": 1}, + "max": {"less_than": 20, "greater_than": 10}, } } @@ -410,9 +410,17 @@ def test_check_not_in_column_checks(self, mock_get_db_hook): self._construct_operator(self.invalid_column_mapping) @mock.patch.object(SQLColumnCheckOperator, "get_db_hook") - def test_pass_all_checks_check(self, mock_get_db_hook): + def test_pass_all_checks_exact_check(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = (0, 10, 10, 1, 19) + mock_get_db_hook.return_value = mock_hook + operator = self._construct_operator(self.valid_column_mapping) + operator.execute() + + @mock.patch.object(SQLColumnCheckOperator, "get_db_hook") + def test_pass_all_checks_inexact_check(self, mock_get_db_hook): mock_hook = mock.Mock() - mock_hook.get_first.return_value = (0, 10, 10, 1, 20) + mock_hook.get_first.return_value = (0, 9, 12, 0, 15) mock_get_db_hook.return_value = mock_hook operator = self._construct_operator(self.valid_column_mapping) operator.execute() @@ -420,7 +428,7 @@ def test_pass_all_checks_check(self, mock_get_db_hook): @mock.patch.object(SQLColumnCheckOperator, "get_db_hook") def test_fail_all_checks_check(self, mock_get_db_hook): mock_hook = mock.Mock() - mock_hook.get_first.return_value = (1, 12, 11, -1, 50) + mock_hook.get_first.return_value = (1, 12, 11, -1, 20) mock_get_db_hook.return_value = mock_hook operator = self._construct_operator(self.valid_column_mapping) with pytest.raises(AirflowException): From 0f56b089031429e0380c631ba8161aa03c5d63e8 Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Fri, 3 Jun 2022 15:12:41 -0400 Subject: [PATCH 06/15] Update SQLTableCheckOperator Improved the table check operator by refactoring how checks are performed; this makes it much more customizeable at the expense of knowing the specific result of checks, as the results are now all just 0 or 1. --- airflow/operators/sql.py | 44 ++++++++++++++----------------------- tests/operators/test_sql.py | 32 ++++++--------------------- 2 files changed, 23 insertions(+), 53 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index f4fe32335420c..0880861c849da 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -661,24 +661,24 @@ def _column_mapping_validation(self, check, check_values): class SQLTableCheckOperator(BaseSQLOperator): """ - Performs one or more of the templated checks in the table_checks dictionary. - Checks are performed on the table as aggregates. + Performs one or more of the checks provided in the checks dictionary. + Checks should be written to return a boolean result. :param table: the table to run checks on. 
:param checks: the dictionary of checks, e.g.: { 'row_count_check': { - 'pass_value': 100, - 'tolerance': .05 + 'check_statement': 'COUNT(*) == 1000' + }, + 'column_sum_check': { + 'check_statement': 'col_a + col_b < col_c' } } :param conn_id: the connection ID used to connect to the database. :param database: name of database which overwrite the defined one in connection """ - table_checks = { - "row_count_check": "COUNT(*) AS row_count_check", - } + sql_check_template = "MIN(CASE WHEN check_statement THEN 1 ELSE 0 END AS check_name)" def __init__( self, @@ -690,9 +690,6 @@ def __init__( **kwargs, ): super().__init__(conn_id=conn_id, database=database, **kwargs) - for check in checks.keys(): - if check not in self.table_checks: - raise AirflowException(f"Invalid table check: {check}.") self.table = table self.checks = checks @@ -701,7 +698,14 @@ def __init__( def execute(self, context=None): hook = self.get_db_hook() - checks_sql = ",".join([self.table_checks[check] for check in self.checks.keys()]) + checks_sql = ",".join( + [ + self.sql_check_template.replace("check_statment", value["check_statement"]).replace( + "check_name", check_name + ) + for check_name, value in self.checks.items() + ] + ) self.sql = f"SELECT {checks_sql} FROM {self.table};" records = hook.get_first(self.sql) @@ -713,16 +717,7 @@ def execute(self, context=None): for check in self.checks.keys(): for result in records: - pass_value_conv = _convert_to_float_if_possible(self.checks[check]["pass_value"]) - is_numeric_value_check = isinstance(pass_value_conv, float) - tolerance = self.checks[check]["tolerance"] if "tolerance" in self.checks[check] else None - - self.checks[check]["result"] = result - self.checks[check]["success"] = ( - self._get_numeric_match(result, self.checks[check]["pass_value"], tolerance) - if is_numeric_value_check - else (result == self.checks[check]["pass_value"]) - ) + self.checks[check]["success"] = bool(result) failed_tests = _get_failed_tests(self.checks) if failed_tests: @@ -734,13 +729,6 @@ def execute(self, context=None): self.log.info("All tests have passed") - def _get_numeric_match(self, numeric_record, numeric_pass_value, tolerance=None): - if tolerance is not None: - return ( - numeric_pass_value * (1 - tolerance) <= numeric_record <= numeric_pass_value * (1 + tolerance) - ) - return numeric_record == numeric_pass_value - class BranchSQLOperator(BaseSQLOperator, SkipMixin): """ diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py index 2c6e38df8d948..329d06f57fbe4 100644 --- a/tests/operators/test_sql.py +++ b/tests/operators/test_sql.py @@ -437,46 +437,28 @@ def test_fail_all_checks_check(self, mock_get_db_hook): class TestTableCheckOperator(unittest.TestCase): - valid_checks = { - "row_count_check": {"pass_value": 10}, + checks = { + "row_count_check": {"check_statement": "COUNT(*) == 1000"}, + "column_sum_check": {"check_statement": "col_a + col_b < col_c"}, } - valid_checks_tolerance = { - "row_count_check": {"pass_value": 10, "tolerance": 0.2}, - } - - invalid_checks = {"invalid_check_name": {"pass_value": 5}} - def _construct_operator(self, checks): return SQLTableCheckOperator(task_id="test_task", table="test_table", checks=checks) - @mock.patch.object(SQLTableCheckOperator, "get_db_hook") - def test_check_not_in_checks(self, mock_get_db_hook): - with pytest.raises(AirflowException, match="Invalid table check: invalid_check_name."): - self._construct_operator(self.invalid_checks) - @mock.patch.object(SQLTableCheckOperator, "get_db_hook") def 
test_pass_all_checks_check(self, mock_get_db_hook): mock_hook = mock.Mock() - mock_hook.get_first.return_value = (10,) - mock_get_db_hook.return_value = mock_hook - operator = self._construct_operator(self.valid_checks) - operator.execute() - - @mock.patch.object(SQLTableCheckOperator, "get_db_hook") - def test_pass_all_checks_with_tolerance_check(self, mock_get_db_hook): - mock_hook = mock.Mock() - mock_hook.get_first.return_value = (11,) + mock_hook.get_first.return_value = (1000, 1) mock_get_db_hook.return_value = mock_hook - operator = self._construct_operator(self.valid_checks_tolerance) + operator = self._construct_operator(self.checks) operator.execute() @mock.patch.object(SQLTableCheckOperator, "get_db_hook") def test_fail_all_checks_check(self, mock_get_db_hook): mock_hook = mock.Mock() - mock_hook.get_first.return_value = (1,) + mock_hook.get_first.return_value = (998, 0) mock_get_db_hook.return_value = mock_hook - operator = self._construct_operator(self.valid_checks) + operator = self._construct_operator(self.checks) with pytest.raises(AirflowException): operator.execute() From 14afffbf83b953e2768bbafe2d54ae14ece19ff5 Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Mon, 6 Jun 2022 12:26:49 -0400 Subject: [PATCH 07/15] Fix doc errors in docstrings New docstrings were using example dictionaries in the param list, and the way they were written before caused doc build errors. The fix was to update the formatting and add missing words to the word list. --- airflow/operators/sql.py | 52 +++++++++++++++++++------------------- docs/spelling_wordlist.txt | 2 ++ 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index 0880861c849da..728a5737084f3 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -469,7 +469,7 @@ def push(self, meta_data): def _get_failed_tests(checks): return [ - f"\tCheck: {check}, " f"Check Values: {check_values}\n" + f"\tCheck: {check},\nCheck Values: {check_values}\n" for check, check_values in checks.items() if not check_values["success"] ] @@ -480,26 +480,26 @@ class SQLColumnCheckOperator(BaseSQLOperator): Performs one or more of the templated checks in the column_checks dictionary. Checks are performed on a per-column basis specified by the column_mapping. - :param table: the table to run checks on. - :param column_mapping: the dictionary of columns and their associated checks, e.g.: - { - 'col_name': { - 'null_check': { - 'equal_to': 0, - }, - 'min': { - 'greater_than': 5, - 'leq_than': 10, - 'tolerance': 0.2, - }, - 'max': { - 'less_than': 1000, - 'geq_than': 10, - 'tolerance': 0.01 + :param table: the table to run checks on + :param column_mapping: the dictionary of columns and their associated checks, e.g. + + .. code-block:: python + + { + "col_name": { + "null_check": { + "equal_to": 0, + }, + "min": { + "greater_than": 5, + "leq_than": 10, + "tolerance": 0.2, + }, + "max": {"less_than": 1000, "geq_than": 10, "tolerance": 0.01}, } } - } - :param conn_id: the connection ID used to connect to the database. + + :param conn_id: the connection ID used to connect to the database :param database: name of database which overwrite the defined one in connection """ @@ -666,14 +666,14 @@ class SQLTableCheckOperator(BaseSQLOperator): :param table: the table to run checks on. :param checks: the dictionary of checks, e.g.: - { - 'row_count_check': { - 'check_statement': 'COUNT(*) == 1000' - }, - 'column_sum_check': { - 'check_statement': 'col_a + col_b < col_c' + + .. 
code-block:: python + + { + "row_count_check": {"check_statement": "COUNT(*) == 1000"}, + "column_sum_check": {"check_statement": "col_a + col_b < col_c"}, } - } + :param conn_id: the connection ID used to connect to the database. :param database: name of database which overwrite the defined one in connection """ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index eb8f6d7795c27..d81545990e3a8 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -851,6 +851,7 @@ gcpcloudsql gcs gdbm generateUploadUrl +geq getattr getfqdn getframe @@ -1007,6 +1008,7 @@ latencies latin ldap ldaps +leq leveldb libs libz From eba96f3b097e5e85af28826f6fc1bfa16778883a Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Mon, 6 Jun 2022 12:38:20 -0400 Subject: [PATCH 08/15] Reorder logging Moves record logging below the AirflowException for no records, and makes the Exception message more clear. --- airflow/operators/sql.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index 728a5737084f3..375fd77bf16a9 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -543,10 +543,11 @@ def execute(self, context=None): self.sql = f"SELECT {checks_sql} FROM {self.table};" records = hook.get_first(self.sql) - self.log.info(f"Record: {records}") if not records: - raise AirflowException("The query returned None") + raise AirflowException(f"The following query returned zero rows: {self.sql}") + + self.log.info(f"Record: {records}") for idx, result in enumerate(records): tolerance = ( @@ -710,10 +711,10 @@ def execute(self, context=None): self.sql = f"SELECT {checks_sql} FROM {self.table};" records = hook.get_first(self.sql) - self.log.info(f"Record: {records}") - if not records: - raise AirflowException("The query returned None") + raise AirflowException(f"The following query returned zero rows: {self.sql}") + + self.log.info(f"Record: {records}") for check in self.checks.keys(): for result in records: From 2388dd8eb6edcb6417e2faea5a62dc2cd19b354b Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 6 Jun 2022 12:46:39 -0400 Subject: [PATCH 09/15] Apply suggestions from code review Co-authored-by: Josh Fell <48934154+josh-fell@users.noreply.github.com> --- airflow/operators/sql.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index 375fd77bf16a9..b8171df5d4f90 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -550,11 +550,7 @@ def execute(self, context=None): self.log.info(f"Record: {records}") for idx, result in enumerate(records): - tolerance = ( - self.column_mapping[column][checks[idx]]["tolerance"] - if "tolerance" in self.column_mapping[column][checks[idx]] - else None - ) + tolerance = self.column_mapping[column][checks[idx]].get("tolerance") self.column_mapping[column][checks[idx]]["result"] = result self.column_mapping[column][checks[idx]]["success"] = self._get_match( @@ -618,7 +614,7 @@ def _column_mapping_validation(self, check, check_values): if "greater_than" in check_values and "less_than" in check_values: if check_values["greater_than"] >= check_values["less_than"]: raise ValueError( - "greater_than should be strictly less than to " + "greater_than should be strictly less than " "less_than. Use geq_than or leq_than for " "overlapping equality." 
) From cf89bf577e18576db1dcd2398b208a00d8bc4cde Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Mon, 6 Jun 2022 14:42:07 -0400 Subject: [PATCH 10/15] Add more information to docstring --- airflow/operators/sql.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index b8171df5d4f90..d8e915a74d73e 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -479,6 +479,13 @@ class SQLColumnCheckOperator(BaseSQLOperator): """ Performs one or more of the templated checks in the column_checks dictionary. Checks are performed on a per-column basis specified by the column_mapping. + Each check can take one or more of the following options: + - equal_to: an exact value to equal, cannot be used with other comparison options + - greater_than: value that result should be strictly greater than + - less_than: value that results should be strictly less than + - geq_than: value that results should be greater than or equal to + - leq_than: value that results should be less than or equal to + - tolerance: the percentage that the result may be off from the expected value :param table: the table to run checks on :param column_mapping: the dictionary of columns and their associated checks, e.g. From c83c79e6bac2c9ec4e4ea03f0ada2c20ff586c6b Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Mon, 6 Jun 2022 16:05:19 -0400 Subject: [PATCH 11/15] Fix docstring indent error --- airflow/operators/sql.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index d8e915a74d73e..9e25612c60aee 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -480,12 +480,12 @@ class SQLColumnCheckOperator(BaseSQLOperator): Performs one or more of the templated checks in the column_checks dictionary. Checks are performed on a per-column basis specified by the column_mapping. Each check can take one or more of the following options: - - equal_to: an exact value to equal, cannot be used with other comparison options - - greater_than: value that result should be strictly greater than - - less_than: value that results should be strictly less than - - geq_than: value that results should be greater than or equal to - - leq_than: value that results should be less than or equal to - - tolerance: the percentage that the result may be off from the expected value + - equal_to: an exact value to equal, cannot be used with other comparison options + - greater_than: value that result should be strictly greater than + - less_than: value that results should be strictly less than + - geq_than: value that results should be greater than or equal to + - leq_than: value that results should be less than or equal to + - tolerance: the percentage that the result may be off from the expected value :param table: the table to run checks on :param column_mapping: the dictionary of columns and their associated checks, e.g. From 7136397567fd26c99a063804acb5c023bbd26bf5 Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Wed, 8 Jun 2022 16:58:10 -0400 Subject: [PATCH 12/15] Remove reminder comments; Add column alias for unique_check query. 
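For illustration, this is roughly what the aliased unique_check template renders to once SQLColumnCheckOperator.execute substitutes a real column name via its existing .replace("column", column) call; the column name "X" is a placeholder:

.. code-block:: python

    # Illustrative only: rendering the aliased unique_check template for a column named "X".
    template = "COUNT(DISTINCT(column)) = COUNT(column) AS column_unique_check"
    print(template.replace("column", "X"))
    # -> COUNT(DISTINCT(X)) = COUNT(X) AS X_unique_check
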
--- airflow/operators/sql.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index 9e25612c60aee..3c15abacc4fb6 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -516,7 +516,7 @@ class SQLColumnCheckOperator(BaseSQLOperator): # pass value should be number of acceptable distinct values "distinct_check": "COUNT(DISTINCT(column)) AS column_distinct_check", # pass value is implicit in the query, it does not need to be passed - "unique_check": "COUNT(DISTINCT(column)) = COUNT(column)", + "unique_check": "COUNT(DISTINCT(column)) = COUNT(column) AS column_unique_check", # pass value should be the minimum acceptable numeric value "min": "MIN(column) AS column_min", # pass value should be the maximum acceptable numeric value @@ -575,9 +575,6 @@ def execute(self, context=None): self.log.info("All tests have passed") def _get_match(self, check_values, record, tolerance=None) -> bool: - # check if record is str or numeric - # if record is str, do pattern matching - # numeric record checks if "geq_than" in check_values: if tolerance is not None: return record >= check_values["geq_than"] * (1 - tolerance) From cbc76c1aaeb1f416298928f2f3dd103e68471b25 Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Thu, 9 Jun 2022 11:43:41 -0400 Subject: [PATCH 13/15] Fix SQL rendering bug in SQLTableCheckOperator SQLTableCheckOperator was rendering an invalid SQL query, this fix changes the way the query template is built from the passed in strings. --- airflow/operators/sql.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index 3c15abacc4fb6..0c676b27005e9 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -539,7 +539,7 @@ def __init__( self.table = table self.column_mapping = column_mapping - self.sql = f"SELECT * FROM {self.table};" + self.sql = None def execute(self, context=None): hook = self.get_db_hook() @@ -671,7 +671,7 @@ class SQLTableCheckOperator(BaseSQLOperator): .. 
code-block:: python { - "row_count_check": {"check_statement": "COUNT(*) == 1000"}, + "row_count_check": {"check_statement": "COUNT(*) = 1000"}, "column_sum_check": {"check_statement": "col_a + col_b < col_c"}, } @@ -679,7 +679,8 @@ class SQLTableCheckOperator(BaseSQLOperator): :param database: name of database which overwrite the defined one in connection """ - sql_check_template = "MIN(CASE WHEN check_statement THEN 1 ELSE 0 END AS check_name)" + sql_check_template = "CASE WHEN check_statement THEN 1 ELSE 0 END AS check_name" + sql_min_template = "MIN(check_name)" def __init__( self, @@ -699,16 +700,20 @@ def __init__( def execute(self, context=None): hook = self.get_db_hook() + check_names = [*self.checks] + check_mins_sql = ",".join( + self.sql_min_template.replace("check_name", check_name) for check_name in check_names + ) checks_sql = ",".join( [ - self.sql_check_template.replace("check_statment", value["check_statement"]).replace( + self.sql_check_template.replace("check_statement", value["check_statement"]).replace( "check_name", check_name ) for check_name, value in self.checks.items() ] ) - self.sql = f"SELECT {checks_sql} FROM {self.table};" + self.sql = f"SELECT {check_mins_sql} FROM (SELECT {checks_sql} FROM {self.table});" records = hook.get_first(self.sql) if not records: From fa083eaf05856948b2f091bb28208b64d49fe99e Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Thu, 9 Jun 2022 11:55:30 -0400 Subject: [PATCH 14/15] Add note about OpenLineage compatability; add tab in _get_failed_test result string --- airflow/operators/sql.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index 0c676b27005e9..b23b188fbb295 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -469,7 +469,7 @@ def push(self, meta_data): def _get_failed_tests(checks): return [ - f"\tCheck: {check},\nCheck Values: {check_values}\n" + f"\tCheck: {check},\n\tCheck Values: {check_values}\n" for check, check_values in checks.items() if not check_values["success"] ] @@ -539,7 +539,8 @@ def __init__( self.table = table self.column_mapping = column_mapping - self.sql = None + # OpenLineage needs a valid SQL query with the input/output table(s) to parse + self.sql = f"SELECT * FROM {self.table};" def execute(self, context=None): hook = self.get_db_hook() @@ -695,6 +696,7 @@ def __init__( self.table = table self.checks = checks + # OpenLineage needs a valid SQL query with the input/output table(s) to parse self.sql = f"SELECT * FROM {self.table};" def execute(self, context=None): From 7439986826fef2d8651d2abdd0d9f38481265210 Mon Sep 17 00:00:00 2001 From: Benji Lampel Date: Thu, 9 Jun 2022 15:07:45 -0400 Subject: [PATCH 15/15] Remove quotes around column in null_check After some research, it seems the null_check query does *not* need quotes around the column name. 
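For illustration, the sketch below shows the difference the quotes make once the operator substitutes a real column name through its .replace("column", column) call; "X" is a placeholder column:

.. code-block:: python

    quoted = "SUM(CASE WHEN 'column' IS NULL THEN 1 ELSE 0 END) AS column_null_check"
    unquoted = "SUM(CASE WHEN column IS NULL THEN 1 ELSE 0 END) AS column_null_check"

    # With quotes, the rendered SQL compares the string literal 'X' to NULL,
    # which is never true, so the null count would always come back as 0.
    print(quoted.replace("column", "X"))
    # SUM(CASE WHEN 'X' IS NULL THEN 1 ELSE 0 END) AS X_null_check

    # Without quotes, the rendered SQL references the column itself.
    print(unquoted.replace("column", "X"))
    # SUM(CASE WHEN X IS NULL THEN 1 ELSE 0 END) AS X_null_check
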
--- airflow/operators/sql.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index b23b188fbb295..65cf92d0ced28 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -511,15 +511,10 @@ class SQLColumnCheckOperator(BaseSQLOperator): """ column_checks = { - # pass value should be number of acceptable nulls - "null_check": "SUM(CASE WHEN 'column' IS NULL THEN 1 ELSE 0 END) AS column_null_check", - # pass value should be number of acceptable distinct values + "null_check": "SUM(CASE WHEN column IS NULL THEN 1 ELSE 0 END) AS column_null_check", "distinct_check": "COUNT(DISTINCT(column)) AS column_distinct_check", - # pass value is implicit in the query, it does not need to be passed "unique_check": "COUNT(DISTINCT(column)) = COUNT(column) AS column_unique_check", - # pass value should be the minimum acceptable numeric value "min": "MIN(column) AS column_min", - # pass value should be the maximum acceptable numeric value "max": "MAX(column) AS column_max", }
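
Taken together, the series leaves the two operators usable roughly as sketched below. This is an illustrative DAG, not part of the patch: the connection ID, table, and column names are placeholders, and the checks mirror the docstring examples above.

.. code-block:: python

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.sql import SQLColumnCheckOperator, SQLTableCheckOperator

    with DAG(dag_id="sql_data_quality", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
        column_checks = SQLColumnCheckOperator(
            task_id="column_checks",
            conn_id="my_db",  # placeholder connection ID
            table="my_table",  # placeholder table
            column_mapping={
                "id": {"null_check": {"equal_to": 0}},
                "amount": {"min": {"geq_than": 0}, "max": {"less_than": 1_000_000, "tolerance": 0.1}},
            },
        )
        table_checks = SQLTableCheckOperator(
            task_id="table_checks",
            conn_id="my_db",
            table="my_table",
            checks={
                "row_count_check": {"check_statement": "COUNT(*) >= 1000"},
                "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
            },
        )
        column_checks >> table_checks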