Add new SQLCheckOperators #23915
Changes from all commits
f9d3d3e
a58a3c5
c6950de
1818691
fc8380e
0f56b08
14afffb
eba96f3
2388dd8
cf89bf5
c83c79e
7136397
cbc76c1
fa083ea
7439986
99158b9
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -467,6 +467,272 @@ def push(self, meta_data): | |
| self.log.info("Log from %s:\n%s", self.dag_id, info) | ||
|
|
||
|
|
||
| def _get_failed_tests(checks): | ||
| return [ | ||
| f"\tCheck: {check},\n\tCheck Values: {check_values}\n" | ||
| for check, check_values in checks.items() | ||
| if not check_values["success"] | ||
| ] | ||
|
|
||
|
|
||
| class SQLColumnCheckOperator(BaseSQLOperator): | ||
| """ | ||
| Performs one or more of the templated checks in the column_checks dictionary. | ||
| Checks are performed on a per-column basis specified by the column_mapping. | ||
| Each check can take one or more of the following options: | ||
| - equal_to: an exact value to equal, cannot be used with other comparison options | ||
| - greater_than: value that result should be strictly greater than | ||
| - less_than: value that results should be strictly less than | ||
| - geq_than: value that results should be greater than or equal to | ||
| - leq_than: value that results should be less than or equal to | ||
| - tolerance: the percentage that the result may be off from the expected value | ||
|
|
||
| :param table: the table to run checks on | ||
| :param column_mapping: the dictionary of columns and their associated checks, e.g. | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| { | ||
| "col_name": { | ||
| "null_check": { | ||
| "equal_to": 0, | ||
| }, | ||
| "min": { | ||
| "greater_than": 5, | ||
| "leq_than": 10, | ||
| "tolerance": 0.2, | ||
| }, | ||
| "max": {"less_than": 1000, "geq_than": 10, "tolerance": 0.01}, | ||
| } | ||
| } | ||
|
|
||
| :param conn_id: the connection ID used to connect to the database | ||
| :param database: name of the database that overrides the one defined in the connection | ||
| """ | ||
|
|
||
| column_checks = { | ||
| "null_check": "SUM(CASE WHEN column IS NULL THEN 1 ELSE 0 END) AS column_null_check", | ||
| "distinct_check": "COUNT(DISTINCT(column)) AS column_distinct_check", | ||
| "unique_check": "COUNT(DISTINCT(column)) = COUNT(column) AS column_unique_check", | ||
| "min": "MIN(column) AS column_min", | ||
| "max": "MAX(column) AS column_max", | ||
| } | ||
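The templates above are expanded with plain string replacement of the word `column`, which also renames the result alias. A minimal standalone sketch of that expansion follows; `build_column_query` is a hypothetical helper written for illustration, not part of the PR, and the dictionary is copied from the diff.

```python
# Templates copied from the diff; "column" is a placeholder for the real column name.
COLUMN_CHECKS = {
    "null_check": "SUM(CASE WHEN column IS NULL THEN 1 ELSE 0 END) AS column_null_check",
    "distinct_check": "COUNT(DISTINCT(column)) AS column_distinct_check",
    "unique_check": "COUNT(DISTINCT(column)) = COUNT(column) AS column_unique_check",
    "min": "MIN(column) AS column_min",
    "max": "MAX(column) AS column_max",
}


def build_column_query(table, column, checks):
    """Hypothetical helper: build one SELECT covering every requested check for a column."""
    checks_sql = ",".join(COLUMN_CHECKS[check].replace("column", column) for check in checks)
    return f"SELECT {checks_sql} FROM {table};"


print(build_column_query("test_table", "X", ["null_check", "max"]))
# SELECT SUM(CASE WHEN X IS NULL THEN 1 ELSE 0 END) AS X_null_check,MAX(X) AS X_max FROM test_table;
```

Because every occurrence of `column` is replaced, the order of the checks in the mapping is also the order of the values in the returned row.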
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| table: str, | ||
| column_mapping: Dict[str, Dict[str, Any]], | ||
| conn_id: Optional[str] = None, | ||
| database: Optional[str] = None, | ||
| **kwargs, | ||
| ): | ||
| super().__init__(conn_id=conn_id, database=database, **kwargs) | ||
| for checks in column_mapping.values(): | ||
| for check, check_values in checks.items(): | ||
| self._column_mapping_validation(check, check_values) | ||
|
|
||
| self.table = table | ||
| self.column_mapping = column_mapping | ||
| # OpenLineage needs a valid SQL query with the input/output table(s) to parse | ||
| self.sql = f"SELECT * FROM {self.table};" | ||
|
|
||
| def execute(self, context=None): | ||
| hook = self.get_db_hook() | ||
| failed_tests = [] | ||
| for column in self.column_mapping: | ||
| checks = [*self.column_mapping[column]] | ||
| checks_sql = ",".join([self.column_checks[check].replace("column", column) for check in checks]) | ||
|
|
||
| self.sql = f"SELECT {checks_sql} FROM {self.table};" | ||
|
||
| records = hook.get_first(self.sql) | ||
|
|
||
| if not records: | ||
| raise AirflowException(f"The following query returned zero rows: {self.sql}") | ||
|
|
||
| self.log.info(f"Record: {records}") | ||
|
|
||
| for idx, result in enumerate(records): | ||
| tolerance = self.column_mapping[column][checks[idx]].get("tolerance") | ||
|
|
||
| self.column_mapping[column][checks[idx]]["result"] = result | ||
| self.column_mapping[column][checks[idx]]["success"] = self._get_match( | ||
| self.column_mapping[column][checks[idx]], result, tolerance | ||
| ) | ||
|
|
||
| failed_tests.extend(_get_failed_tests(self.column_mapping[column])) | ||
| if failed_tests: | ||
| raise AirflowException( | ||
|
Contributor
Maybe it's worth thinking about having a separate exception for Airflow checks/tests, especially if more data quality functionality is coming?
Contributor
Author
I think that's a good idea, but we'd also have to replace exceptions in other SQL operators, and I think that's outside the scope of this PR. I can open an issue and implement that in a different PR. Something like an
||
| f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n" | ||
| "The following tests have failed:" | ||
| f"\n{''.join(failed_tests)}" | ||
| ) | ||
|
|
||
| self.log.info("All tests have passed") | ||
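The review suggestion about a dedicated exception for failed checks could look roughly like the sketch below. Everything here is an assumption for illustration: the name `DataQualityCheckFailed` and the helper `raise_if_failed` are hypothetical and not taken from the thread, and `AirflowException` is redefined locally only so the snippet runs without Airflow installed.

```python
class AirflowException(Exception):
    """Local stand-in for airflow.exceptions.AirflowException (assumption for this sketch)."""


class DataQualityCheckFailed(AirflowException):
    """Hypothetical exception raised when one or more data quality checks fail."""

    def __init__(self, sql, failed_tests):
        self.sql = sql
        self.failed_tests = list(failed_tests)
        super().__init__(
            f"Test failed.\nQuery:\n{sql}\n"
            f"The following tests have failed:\n{''.join(self.failed_tests)}"
        )


def raise_if_failed(sql, failed_tests):
    """Hypothetical helper: raise the dedicated exception only when something failed."""
    if failed_tests:
        raise DataQualityCheckFailed(sql, failed_tests)
```

Subclassing the existing exception would keep any caller that already catches `AirflowException` working, while letting new code catch check failures specifically.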
|
|
||
| def _get_match(self, check_values, record, tolerance=None) -> bool: | ||
| if "geq_than" in check_values: | ||
| if tolerance is not None: | ||
| return record >= check_values["geq_than"] * (1 - tolerance) | ||
| return record >= check_values["geq_than"] | ||
| elif "greater_than" in check_values: | ||
| if tolerance is not None: | ||
| return record > check_values["greater_than"] * (1 - tolerance) | ||
| return record > check_values["greater_than"] | ||
| if "leq_than" in check_values: | ||
| if tolerance is not None: | ||
| return record <= check_values["leq_than"] * (1 + tolerance) | ||
| return record <= check_values["leq_than"] | ||
| elif "less_than" in check_values: | ||
| if tolerance is not None: | ||
| return record < check_values["less_than"] * (1 + tolerance) | ||
| return record < check_values["less_than"] | ||
| if "equal_to" in check_values: | ||
| if tolerance is not None: | ||
| return ( | ||
| check_values["equal_to"] * (1 - tolerance) | ||
| <= record | ||
| <= check_values["equal_to"] * (1 + tolerance) | ||
| ) | ||
| return record == check_values["equal_to"] | ||
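Each branch of `_get_match` returns as soon as it matches, so when a check supplies both a lower and an upper bound (for example `greater_than` with `less_than`), only the lower bound is evaluated. One possible alternative, sketched here as a standalone function and not as the PR's implementation, collects one result per supplied bound and requires all of them to hold. The name `get_match` and the decision to read `tolerance` from the same dictionary are assumptions.

```python
def get_match(check_values, record):
    """Return True only if the record satisfies every bound present in check_values."""
    tolerance = check_values.get("tolerance") or 0
    results = []
    if "geq_than" in check_values:
        results.append(record >= check_values["geq_than"] * (1 - tolerance))
    if "greater_than" in check_values:
        results.append(record > check_values["greater_than"] * (1 - tolerance))
    if "leq_than" in check_values:
        results.append(record <= check_values["leq_than"] * (1 + tolerance))
    if "less_than" in check_values:
        results.append(record < check_values["less_than"] * (1 + tolerance))
    if "equal_to" in check_values:
        expected = check_values["equal_to"]
        results.append(expected * (1 - tolerance) <= record <= expected * (1 + tolerance))
    # No bounds at all counts as a failure rather than a silent pass.
    return bool(results) and all(results)
```

With this variant, a `max` of 20 checked against `{"less_than": 20, "greater_than": 10}` fails, because the upper bound is evaluated as well.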
|
|
||
| def _column_mapping_validation(self, check, check_values): | ||
| if check not in self.column_checks: | ||
| raise AirflowException(f"Invalid column check: {check}.") | ||
| if ( | ||
| "greater_than" not in check_values | ||
| and "geq_than" not in check_values | ||
| and "less_than" not in check_values | ||
| and "leq_than" not in check_values | ||
| and "equal_to" not in check_values | ||
| ): | ||
| raise ValueError( | ||
| "Please provide one or more of: less_than, leq_than, " | ||
| "greater_than, geq_than, or equal_to in the check's dict." | ||
|
||
| ) | ||
|
|
||
| if "greater_than" in check_values and "less_than" in check_values: | ||
| if check_values["greater_than"] >= check_values["less_than"]: | ||
| raise ValueError( | ||
| "greater_than should be strictly less than " | ||
| "less_than. Use geq_than or leq_than for " | ||
| "overlapping equality." | ||
| ) | ||
|
|
||
| if "greater_than" in check_values and "leq_than" in check_values: | ||
| if check_values["greater_than"] >= check_values["leq_than"]: | ||
| raise ValueError( | ||
| "greater_than must be strictly less than leq_than. " | ||
| "Use geq_than with leq_than for overlapping equality." | ||
| ) | ||
|
|
||
| if "geq_than" in check_values and "less_than" in check_values: | ||
| if check_values["geq_than"] >= check_values["less_than"]: | ||
| raise ValueError( | ||
| "geq_than should be strictly less than less_than. " | ||
| "Use leq_than with geq_than for overlapping equality." | ||
| ) | ||
|
|
||
| if "geq_than" in check_values and "leq_than" in check_values: | ||
| if check_values["geq_than"] > check_values["leq_than"]: | ||
| raise ValueError("geq_than should be less than or equal to leq_than.") | ||
|
|
||
| if "greater_than" in check_values and "geq_than" in check_values: | ||
| raise ValueError("Only supply one of greater_than or geq_than.") | ||
|
|
||
| if "less_than" in check_values and "leq_than" in check_values: | ||
| raise ValueError("Only supply one of less_than or leq_than.") | ||
|
|
||
| if ( | ||
| "greater_than" in check_values | ||
| or "geq_than" in check_values | ||
| or "less_than" in check_values | ||
| or "leq_than" in check_values | ||
| ) and "equal_to" in check_values: | ||
| raise ValueError( | ||
| "equal_to cannot be passed with a greater or less than " | ||
| "function. To specify 'greater than or equal to' or " | ||
| "'less than or equal to', use geq_than or leq_than." | ||
| ) | ||
|
|
||
|
|
||
| class SQLTableCheckOperator(BaseSQLOperator): | ||
| """ | ||
| Performs one or more of the checks provided in the checks dictionary. | ||
| Checks should be written to return a boolean result. | ||
|
|
||
| :param table: the table to run checks on. | ||
| :param checks: the dictionary of checks, e.g.: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| { | ||
| "row_count_check": {"check_statement": "COUNT(*) = 1000"}, | ||
| "column_sum_check": {"check_statement": "col_a + col_b < col_c"}, | ||
| } | ||
|
|
||
| :param conn_id: the connection ID used to connect to the database. | ||
| :param database: name of the database that overrides the one defined in the connection | ||
| """ | ||
|
|
||
| sql_check_template = "CASE WHEN check_statement THEN 1 ELSE 0 END AS check_name" | ||
| sql_min_template = "MIN(check_name)" | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| table: str, | ||
| checks: Dict[str, Dict[str, Any]], | ||
| conn_id: Optional[str] = None, | ||
| database: Optional[str] = None, | ||
| **kwargs, | ||
| ): | ||
| super().__init__(conn_id=conn_id, database=database, **kwargs) | ||
|
|
||
| self.table = table | ||
| self.checks = checks | ||
| # OpenLineage needs a valid SQL query with the input/output table(s) to parse | ||
| self.sql = f"SELECT * FROM {self.table};" | ||
|
||
|
|
||
| def execute(self, context=None): | ||
| hook = self.get_db_hook() | ||
|
|
||
| check_names = [*self.checks] | ||
| check_mins_sql = ",".join( | ||
| self.sql_min_template.replace("check_name", check_name) for check_name in check_names | ||
| ) | ||
| checks_sql = ",".join( | ||
| [ | ||
| self.sql_check_template.replace("check_statement", value["check_statement"]).replace( | ||
| "check_name", check_name | ||
| ) | ||
| for check_name, value in self.checks.items() | ||
| ] | ||
| ) | ||
|
|
||
| self.sql = f"SELECT {check_mins_sql} FROM (SELECT {checks_sql} FROM {self.table});" | ||
| records = hook.get_first(self.sql) | ||
|
|
||
| if not records: | ||
| raise AirflowException(f"The following query returned zero rows: {self.sql}") | ||
|
|
||
| self.log.info(f"Record: {records}") | ||
|
|
||
| for check, result in zip(self.checks, records): | ||
| self.checks[check]["success"] = bool(result) | ||
|
|
||
| failed_tests = _get_failed_tests(self.checks) | ||
| if failed_tests: | ||
| raise AirflowException( | ||
| f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n" | ||
| "The following tests have failed:" | ||
| f"\n{', '.join(failed_tests)}" | ||
| ) | ||
|
|
||
| self.log.info("All tests have passed") | ||
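To make the table-check flow concrete, here is a standalone sketch of how the two templates combine into a single query and how each value in the returned row can be paired with its own check name. `build_table_query` and `record_results` are hypothetical helpers for illustration; the templates are copied from the diff.

```python
SQL_CHECK_TEMPLATE = "CASE WHEN check_statement THEN 1 ELSE 0 END AS check_name"
SQL_MIN_TEMPLATE = "MIN(check_name)"


def build_table_query(table, checks):
    """Hypothetical helper: inner SELECT scores each check as 1/0, outer SELECT takes the minimum."""
    mins = ",".join(SQL_MIN_TEMPLATE.replace("check_name", name) for name in checks)
    cases = ",".join(
        SQL_CHECK_TEMPLATE.replace("check_statement", value["check_statement"]).replace(
            "check_name", name
        )
        for name, value in checks.items()
    )
    return f"SELECT {mins} FROM (SELECT {cases} FROM {table});"


def record_results(checks, records):
    """Hypothetical helper: store each row value on the check at the same position."""
    for values, result in zip(checks.values(), records):
        values["success"] = bool(result)
    return checks


checks = {
    "row_count_check": {"check_statement": "COUNT(*) = 1000"},
    "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
}
print(build_table_query("test_table", checks))
```

Pairing by position relies on dictionaries preserving insertion order, so the order of `MIN(...)` columns in the query matches the order of the check names.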
|
|
||
|
|
||
| class BranchSQLOperator(BaseSQLOperator, SkipMixin): | ||
| """ | ||
| Allows a DAG to "branch" or follow a specified path based on the results of a SQL query. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,7 +28,9 @@ | |
| from airflow.operators.sql import ( | ||
| BranchSQLOperator, | ||
| SQLCheckOperator, | ||
| SQLColumnCheckOperator, | ||
| SQLIntervalCheckOperator, | ||
| SQLTableCheckOperator, | ||
| SQLThresholdCheckOperator, | ||
| SQLValueCheckOperator, | ||
| ) | ||
|
|
@@ -385,6 +387,82 @@ def test_fail_min_sql_max_value(self, mock_get_db_hook): | |
| operator.execute() | ||
|
|
||
|
|
||
| class TestColumnCheckOperator(unittest.TestCase): | ||
|
Contributor
IMO this can be a
Contributor
Author
Is there a reason to use one over the other? I was just copying the test pattern from the other ones.
Member
We prefer pytest-style tests for new tests unless there are good reasons not to.
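As a rough illustration of the pytest-style preference, the sketch below shows a plain test class with no `unittest.TestCase` base and bare `assert` statements. `FakeOperator` is a hypothetical stand-in for the operator under test so the snippet runs without Airflow; only `unittest.mock` from the standard library is used.

```python
from unittest import mock


class FakeOperator:
    """Hypothetical stand-in for SQLColumnCheckOperator (assumption for this sketch)."""

    def get_db_hook(self):
        raise RuntimeError("unit tests should not reach a real database")

    def execute(self):
        record = self.get_db_hook().get_first("SELECT 1;")
        if not record:
            raise ValueError("query returned zero rows")
        return record


class TestFakeOperator:  # plain class: pytest collects it without a TestCase base
    @mock.patch.object(FakeOperator, "get_db_hook")
    def test_execute_returns_record(self, mock_get_db_hook):
        mock_get_db_hook.return_value.get_first.return_value = (0, 10)
        assert FakeOperator().execute() == (0, 10)
```

pytest discovers classes whose names start with `Test` and methods whose names start with `test_`, so the existing test bodies would need little more than dropping the base class.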
||
|
|
||
| valid_column_mapping = { | ||
| "X": { | ||
| "null_check": {"equal_to": 0}, | ||
| "distinct_check": {"equal_to": 10, "tolerance": 0.1}, | ||
| "unique_check": {"geq_than": 10}, | ||
| "min": {"leq_than": 1}, | ||
| "max": {"less_than": 20, "greater_than": 10}, | ||
| } | ||
| } | ||
|
|
||
| invalid_column_mapping = {"Y": {"invalid_check_name": {"expectation": 5}}} | ||
|
|
||
| def _construct_operator(self, column_mapping): | ||
| return SQLColumnCheckOperator(task_id="test_task", table="test_table", column_mapping=column_mapping) | ||
|
|
||
| @mock.patch.object(SQLColumnCheckOperator, "get_db_hook") | ||
|
Contributor
Since every test case uses the same mocked object, you could think about using a pytest fixture for mocking
Member
If you use pytest fixtures, the
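A shared fixture for the mocked hook might be sketched as follows. This is an assumption about what the reviewers had in mind, not code from the PR: `FakeColumnCheckOperator` is a hypothetical stand-in for the real operator, and the fixture name `mock_hook` is illustrative.

```python
from unittest import mock

import pytest


class FakeColumnCheckOperator:
    """Hypothetical stand-in for the operator under test (assumption for this sketch)."""

    def get_db_hook(self):
        raise RuntimeError("no database in unit tests")

    def execute(self):
        return self.get_db_hook().get_first("SELECT 1;")


@pytest.fixture
def mock_hook():
    # Patch get_db_hook once; the patch is undone when the requesting test finishes.
    with mock.patch.object(FakeColumnCheckOperator, "get_db_hook") as mock_get_db_hook:
        yield mock_get_db_hook.return_value


def test_execute_uses_mocked_hook(mock_hook):
    # Each test only sets the return value it cares about.
    mock_hook.get_first.return_value = (0, 10)
    assert FakeColumnCheckOperator().execute() == (0, 10)
```

Compared with repeating `@mock.patch.object(...)` on every test, the fixture keeps the patching in one place and hands each test the hook object directly.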
||
| def test_check_not_in_column_checks(self, mock_get_db_hook): | ||
| with pytest.raises(AirflowException, match="Invalid column check: invalid_check_name."): | ||
| self._construct_operator(self.invalid_column_mapping) | ||
|
|
||
| @mock.patch.object(SQLColumnCheckOperator, "get_db_hook") | ||
| def test_pass_all_checks_exact_check(self, mock_get_db_hook): | ||
| mock_hook = mock.Mock() | ||
| mock_hook.get_first.return_value = (0, 10, 10, 1, 19) | ||
| mock_get_db_hook.return_value = mock_hook | ||
| operator = self._construct_operator(self.valid_column_mapping) | ||
| operator.execute() | ||
|
|
||
| @mock.patch.object(SQLColumnCheckOperator, "get_db_hook") | ||
| def test_pass_all_checks_inexact_check(self, mock_get_db_hook): | ||
| mock_hook = mock.Mock() | ||
| mock_hook.get_first.return_value = (0, 9, 12, 0, 15) | ||
| mock_get_db_hook.return_value = mock_hook | ||
| operator = self._construct_operator(self.valid_column_mapping) | ||
| operator.execute() | ||
|
||
|
|
||
| @mock.patch.object(SQLColumnCheckOperator, "get_db_hook") | ||
| def test_fail_all_checks_check(self, mock_get_db_hook): | ||
| mock_hook = mock.Mock() | ||
| mock_hook.get_first.return_value = (1, 12, 11, -1, 20) | ||
| mock_get_db_hook.return_value = mock_hook | ||
| operator = self._construct_operator(self.valid_column_mapping) | ||
| with pytest.raises(AirflowException): | ||
| operator.execute() | ||
|
|
||
|
|
||
| class TestTableCheckOperator(unittest.TestCase): | ||
|
|
||
| checks = { | ||
| "row_count_check": {"check_statement": "COUNT(*) = 1000"}, | ||
| "column_sum_check": {"check_statement": "col_a + col_b < col_c"}, | ||
| } | ||
|
|
||
| def _construct_operator(self, checks): | ||
| return SQLTableCheckOperator(task_id="test_task", table="test_table", checks=checks) | ||
|
|
||
| @mock.patch.object(SQLTableCheckOperator, "get_db_hook") | ||
| def test_pass_all_checks_check(self, mock_get_db_hook): | ||
| mock_hook = mock.Mock() | ||
| mock_hook.get_first.return_value = (1000, 1) | ||
| mock_get_db_hook.return_value = mock_hook | ||
| operator = self._construct_operator(self.checks) | ||
| operator.execute() | ||
|
|
||
| @mock.patch.object(SQLTableCheckOperator, "get_db_hook") | ||
| def test_fail_all_checks_check(self, mock_get_db_hook): | ||
| mock_hook = mock.Mock() | ||
| mock_hook.get_first.return_value = (998, 0) | ||
| mock_get_db_hook.return_value = mock_hook | ||
| operator = self._construct_operator(self.checks) | ||
| with pytest.raises(AirflowException): | ||
| operator.execute() | ||
|
|
||
|
|
||
| class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): | ||
| """ | ||
| Test for SQL Branch Operator | ||
|
|
||
Just curious, why is this needed here?
For OpenLineage. The Extractor needs some static query to render, unless we want to have the entire templated query rendered in `__init__()`. I can double-check this, though; it might have changed with the move to the Listener.
Just to reiterate: those "common" SQL classes will not be released until Airflow 2.4, which means that the first operator that will be able to use them as "mandatory" will come 12 months after 2.4 is released (that's our policy for providers). So I just want to make you aware, @denimalpaca, that those classes won't be too useful for quite some time.
Maybe a better idea would be to add a separate "base sql" provider, which would not live in "airflow/operators" but would be a separate package that could be installed in addition to Airflow. We have not done so before, but maybe there is a good reason to introduce such a "shared sql" provider that will encapsulate all SQL-like base functionality that OpenLineage might build on? I think trying to implement it in Airflow core is a bad idea if the goal is fast adoption by multiple providers.
Just saying.
I think this is an interesting idea and I support it. We could start with these operators and slowly move existing core SQL operators over too.
@potiuk Do you think this provider would be included with Airflow installs like HTTP, FTP, SQLite, and IMAP, or completely separate?