Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 266 additions & 0 deletions airflow/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,272 @@ def push(self, meta_data):
self.log.info("Log from %s:\n%s", self.dag_id, info)


def _get_failed_tests(checks):
return [
f"\tCheck: {check},\n\tCheck Values: {check_values}\n"
for check, check_values in checks.items()
if not check_values["success"]
]


class SQLColumnCheckOperator(BaseSQLOperator):
"""
Performs one or more of the templated checks in the column_checks dictionary.
Checks are performed on a per-column basis specified by the column_mapping.
Each check can take one or more of the following options:
- equal_to: an exact value to equal, cannot be used with other comparison options
- greater_than: value that result should be strictly greater than
- less_than: value that results should be strictly less than
- geq_than: value that results should be greater than or equal to
- leq_than: value that results should be less than or equal to
- tolerance: the percentage that the result may be off from the expected value

:param table: the table to run checks on
:param column_mapping: the dictionary of columns and their associated checks, e.g.

.. code-block:: python

{
"col_name": {
"null_check": {
"equal_to": 0,
},
"min": {
"greater_than": 5,
"leq_than": 10,
"tolerance": 0.2,
},
"max": {"less_than": 1000, "geq_than": 10, "tolerance": 0.01},
}
}

:param conn_id: the connection ID used to connect to the database
:param database: name of database which overwrite the defined one in connection
"""

column_checks = {
"null_check": "SUM(CASE WHEN column IS NULL THEN 1 ELSE 0 END) AS column_null_check",
"distinct_check": "COUNT(DISTINCT(column)) AS column_distinct_check",
"unique_check": "COUNT(DISTINCT(column)) = COUNT(column) AS column_unique_check",
"min": "MIN(column) AS column_min",
"max": "MAX(column) AS column_max",
}

def __init__(
    self,
    *,
    table: str,
    column_mapping: Dict[str, Dict[str, Any]],
    conn_id: Optional[str] = None,
    database: Optional[str] = None,
    **kwargs,
):
    """
    Validate every configured check, then store the table and mapping on the operator.

    :param table: the table to run checks on
    :param column_mapping: per-column dictionary of check names to their check options
    :param conn_id: the connection ID, forwarded to the base operator
    :param database: name of the database, forwarded to the base operator
    """
    super().__init__(conn_id=conn_id, database=database, **kwargs)
    # Reject an invalid check name or option combination at construction time.
    for per_column_checks in column_mapping.values():
        for check_name in per_column_checks:
            self._column_mapping_validation(check_name, per_column_checks[check_name])

    self.table, self.column_mapping = table, column_mapping
    # OpenLineage needs a valid SQL query with the input/output table(s) to parse
    self.sql = f"SELECT * FROM {self.table};"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, why is this needed here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For OpenLineage. The Extractor needs some static query to render, unless we want to have the entire templated query rendered in __init__(). I can double-check this, though, it might have changed with the move to the Listener.

Copy link
Member

@potiuk potiuk Jun 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to reiterate - those "common" SQL classes will not be released until Airflow 2.4, which means that the first operator that will be able to use them as "mandatory" will come 12 months after 2.4 is released (that's our policy for providers), so I just want to make you aware @denimalpaca that those classes won't be too useful for quite some time.

Maybe a better idea would be to add a separate "base sql" provider which would not live in "airflow/operators" but would be a separate package that could be installed in addition to Airflow. We have not done so before, but maybe there is a good reason to introduce such a "shared sql" provider that will encapsulate all SQL-like base functionality that OpenLineage might build on? I think trying to implement it in Airflow core is a bad idea, if the goal is fast adoption by multiple providers.

Just saying.

Copy link
Contributor

@josh-fell josh-fell Jun 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a better Idea will be to add a separate "base sql" provider

I think this is an interesting idea and I support it. We could start with these operators and slowly move existing core SQL operators over too.

@potiuk Do you think this provider would be included with Airflow installs like HTTP, FTP, SQLite, and IMAP, or completely separate?


def execute(self, context=None):
"""
Run the configured checks for each column and fail if any check does not pass.

For every column in ``self.column_mapping`` a single ``SELECT`` is built from the
SQL templates in ``self.column_checks`` and run with the hook's ``get_first``.
Each returned value is stored under the check's ``"result"`` key and the outcome
of ``self._get_match`` under its ``"success"`` key, so ``self.column_mapping``
is modified in place.

:param context: not referenced by this method
:raises AirflowException: if a query returns no records, or if any check failed
"""
hook = self.get_db_hook()
failed_tests = []
for column in self.column_mapping:
# Check names for this column, in the same order as the SQL expressions below.
checks = [*self.column_mapping[column]]
# Every occurrence of the literal text "column" in a template is replaced by the column name.
checks_sql = ",".join([self.column_checks[check].replace("column", column) for check in checks])

self.sql = f"SELECT {checks_sql} FROM {self.table};"
records = hook.get_first(self.sql)

if not records:
raise AirflowException(f"The following query returned zero rows: {self.sql}")

self.log.info(f"Record: {records}")

# Results are matched to checks by position: records[idx] belongs to checks[idx].
for idx, result in enumerate(records):
tolerance = self.column_mapping[column][checks[idx]].get("tolerance")

self.column_mapping[column][checks[idx]]["result"] = result
self.column_mapping[column][checks[idx]]["success"] = self._get_match(
self.column_mapping[column][checks[idx]], result, tolerance
)

failed_tests.extend(_get_failed_tests(self.column_mapping[column]))
# NOTE(review): self.sql and records hold only the most recently executed query, while
# failed_tests may contain entries collected for other columns — confirm this is intended.
if failed_tests:
raise AirflowException(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's worth thinking about having a separate exception for Airflow checks/test especially if more data quality functionality is coming?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's a good idea, but we'd also have to replace exceptions in other SQL operators and I think that's outside the scope of this PR. But I can open an issue and implement that in a different PR. Something like an AirflowDataQualityException?

f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
"The following tests have failed:"
f"\n{''.join(failed_tests)}"
)

self.log.info("All tests have passed")

def _get_match(self, check_values, record, tolerance=None) -> bool:
if "geq_than" in check_values:
if tolerance is not None:
return record >= check_values["geq_than"] * (1 - tolerance)
return record >= check_values["geq_than"]
elif "greater_than" in check_values:
if tolerance is not None:
return record > check_values["greater_than"] * (1 - tolerance)
return record > check_values["greater_than"]
if "leq_than" in check_values:
if tolerance is not None:
return record <= check_values["leq_than"] * (1 + tolerance)
return record <= check_values["leq_than"]
elif "less_than" in check_values:
if tolerance is not None:
return record < check_values["less_than"] * (1 + tolerance)
return record < check_values["less_than"]
if "equal_to" in check_values:
if tolerance is not None:
return (
check_values["equal_to"] * (1 - tolerance)
<= record
<= check_values["equal_to"] * (1 + tolerance)
)
return record == check_values["equal_to"]

def _column_mapping_validation(self, check, check_values):
if check not in self.column_checks:
raise AirflowException(f"Invalid column check: {check}.")
if (
"greater_than" not in check_values
and "geq_than" not in check_values
and "less_than" not in check_values
and "leq_than" not in check_values
and "equal_to" not in check_values
):
raise ValueError(
"Please provide one or more of: less_than, leq_than, "
"greater_than, geq_than, or equal_to in the check's dict."
)

if "greater_than" in check_values and "less_than" in check_values:
if check_values["greater_than"] >= check_values["less_than"]:
raise ValueError(
"greater_than should be strictly less than "
"less_than. Use geq_than or leq_than for "
"overlapping equality."
)

if "greater_than" in check_values and "leq_than" in check_values:
if check_values["greater_than"] >= check_values["leq_than"]:
raise ValueError(
"greater_than must be strictly less than leq_than. "
"Use geq_than with leq_than for overlapping equality."
)

if "geq_than" in check_values and "less_than" in check_values:
if check_values["geq_than"] >= check_values["less_than"]:
raise ValueError(
"geq_than should be strictly less than less_than. "
"Use leq_than with geq_than for overlapping equality."
)

if "geq_than" in check_values and "leq_than" in check_values:
if check_values["geq_than"] > check_values["leq_than"]:
raise ValueError("geq_than should be less than or equal to leq_than.")

if "greater_than" in check_values and "geq_than" in check_values:
raise ValueError("Only supply one of greater_than or geq_than.")

if "less_than" in check_values and "leq_than" in check_values:
raise ValueError("Only supply one of less_than or leq_than.")

if (
"greater_than" in check_values
or "geq_than" in check_values
or "less_than" in check_values
or "leq_than" in check_values
) and "equal_to" in check_values:
raise ValueError(
"equal_to cannot be passed with a greater or less than "
"function. To specify 'greater than or equal to' or "
"'less than or equal to', use geq_than or leq_than."
)


class SQLTableCheckOperator(BaseSQLOperator):
    """
    Performs one or more of the checks provided in the checks dictionary.
    Checks should be written to return a boolean result.

    :param table: the table to run checks on.
    :param checks: the dictionary of checks, e.g.:

    .. code-block:: python

        {
            "row_count_check": {"check_statement": "COUNT(*) = 1000"},
            "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
        }

    :param conn_id: the connection ID used to connect to the database.
    :param database: name of the database which overwrites the one defined in the connection
    """

    # Turns a boolean check statement into a 1/0 flag column named after the check.
    sql_check_template = "CASE WHEN check_statement THEN 1 ELSE 0 END AS check_name"
    # Collapses a flag column to a single value: 1 only if every row produced 1.
    sql_min_template = "MIN(check_name)"

    def __init__(
        self,
        *,
        table: str,
        checks: Dict[str, Dict[str, Any]],
        conn_id: Optional[str] = None,
        database: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(conn_id=conn_id, database=database, **kwargs)

        self.table = table
        self.checks = checks
        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
        self.sql = f"SELECT * FROM {self.table};"

    def execute(self, context=None):
        """
        Run all checks in one query and fail if any of them does not hold.

        Each check's outcome is written to its ``"success"`` key in ``self.checks``.

        :param context: not referenced by this method
        :raises AirflowException: if the query returns no records or any check failed
        """
        hook = self.get_db_hook()

        check_names = [*self.checks]
        check_mins_sql = ",".join(
            self.sql_min_template.replace("check_name", check_name) for check_name in check_names
        )
        checks_sql = ",".join(
            [
                self.sql_check_template.replace("check_statement", value["check_statement"]).replace(
                    "check_name", check_name
                )
                for check_name, value in self.checks.items()
            ]
        )

        self.sql = f"SELECT {check_mins_sql} FROM (SELECT {checks_sql} FROM {self.table});"
        records = hook.get_first(self.sql)

        if not records:
            raise AirflowException(f"The following query returned zero rows: {self.sql}")

        self.log.info("Record: %s", records)

        # The outer SELECT emits one MIN(...) column per check, in check_names order.
        # Pair each column with its own check; previously the last column's value was
        # assigned to every check, so a passing final check masked earlier failures.
        results_by_check = dict(zip(check_names, records))
        for check_name in check_names:
            # A check with no matching result column is treated as failed.
            self.checks[check_name]["success"] = bool(results_by_check.get(check_name))

        failed_tests = _get_failed_tests(self.checks)
        if failed_tests:
            raise AirflowException(
                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
                "The following tests have failed:"
                f"\n{', '.join(failed_tests)}"
            )

        self.log.info("All tests have passed")


class BranchSQLOperator(BaseSQLOperator, SkipMixin):
"""
Allows a DAG to "branch" or follow a specified path based on the results of a SQL query.
Expand Down
2 changes: 2 additions & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,7 @@ gcpcloudsql
gcs
gdbm
generateUploadUrl
geq
getattr
getfqdn
getframe
Expand Down Expand Up @@ -1013,6 +1014,7 @@ latencies
latin
ldap
ldaps
leq
leveldb
libs
libz
Expand Down
78 changes: 78 additions & 0 deletions tests/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from airflow.operators.sql import (
BranchSQLOperator,
SQLCheckOperator,
SQLColumnCheckOperator,
SQLIntervalCheckOperator,
SQLTableCheckOperator,
SQLThresholdCheckOperator,
SQLValueCheckOperator,
)
Expand Down Expand Up @@ -385,6 +387,82 @@ def test_fail_min_sql_max_value(self, mock_get_db_hook):
operator.execute()


class TestColumnCheckOperator(unittest.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this can be a pytest test rather than continuing to use unittest.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to use one over the other? I was just copying the test pattern from the other ones.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prefer pytest-style tests for new tests, unless there is a good reason not to use them.


valid_column_mapping = {
"X": {
"null_check": {"equal_to": 0},
"distinct_check": {"equal_to": 10, "tolerance": 0.1},
"unique_check": {"geq_than": 10},
"min": {"leq_than": 1},
"max": {"less_than": 20, "greater_than": 10},
}
}

invalid_column_mapping = {"Y": {"invalid_check_name": {"expectation": 5}}}

def _construct_operator(self, column_mapping):
    """Build a SQLColumnCheckOperator on ``test_table`` with the given column mapping."""
    operator_kwargs = {
        "task_id": "test_task",
        "table": "test_table",
        "column_mapping": column_mapping,
    }
    return SQLColumnCheckOperator(**operator_kwargs)

@mock.patch.object(SQLColumnCheckOperator, "get_db_hook")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since every test case uses the same mocked object, you could think about using a pytest fixture for mocking get_db_hook. I don't have a strong opinion though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you use pytest fixtures, the monkeypatch fixture makes this easy to do in setup_method.

def test_check_not_in_column_checks(self, mock_get_db_hook):
    """An unknown check name is rejected while the operator is being constructed."""
    expected_message = "Invalid column check: invalid_check_name."
    with pytest.raises(AirflowException, match=expected_message):
        self._construct_operator(self.invalid_column_mapping)

@mock.patch.object(SQLColumnCheckOperator, "get_db_hook")
def test_pass_all_checks_exact_check(self, mock_get_db_hook):
    """A result row that satisfies every check lets execute() finish without raising."""
    # One value per check, in the order the checks appear in valid_column_mapping["X"].
    db_hook = mock.Mock(**{"get_first.return_value": (0, 10, 10, 1, 19)})
    mock_get_db_hook.return_value = db_hook
    self._construct_operator(self.valid_column_mapping).execute()

@mock.patch.object(SQLColumnCheckOperator, "get_db_hook")
def test_pass_all_checks_inexact_check(self, mock_get_db_hook):
    """Results that differ from the bounds but stay inside them still pass every check."""
    # distinct_check returns 9 against equal_to=10, which relies on its 0.1 tolerance.
    db_hook = mock.Mock(**{"get_first.return_value": (0, 9, 12, 0, 15)})
    mock_get_db_hook.return_value = db_hook
    self._construct_operator(self.valid_column_mapping).execute()

@mock.patch.object(SQLColumnCheckOperator, "get_db_hook")
def test_fail_all_checks_check(self, mock_get_db_hook):
    """Every check in the valid mapping fails and each one is named in the error."""
    mock_hook = mock.Mock()
    # Positional results for null_check, distinct_check, unique_check, min, max.
    # Each value violates its own check, so the test really exercises "fail all":
    # 1 != 0, 12 is outside 10 +/- 10%, 9 < 10, 2 > 1, and 5 is not above 10.
    mock_hook.get_first.return_value = (1, 12, 9, 2, 5)
    mock_get_db_hook.return_value = mock_hook
    operator = self._construct_operator(self.valid_column_mapping)
    with pytest.raises(AirflowException) as err:
        operator.execute()
    message = str(err.value)
    for check_name in self.valid_column_mapping["X"]:
        assert f"Check: {check_name}," in message


class TestTableCheckOperator(unittest.TestCase):
    """Tests for SQLTableCheckOperator using a mocked database hook."""

    checks = {
        # "=" is the standard SQL equality operator and matches the operator's docstring.
        "row_count_check": {"check_statement": "COUNT(*) = 1000"},
        "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
    }

    def _construct_operator(self, checks):
        """Build a SQLTableCheckOperator on ``test_table`` with the given checks."""
        return SQLTableCheckOperator(task_id="test_task", table="test_table", checks=checks)

    @mock.patch.object(SQLTableCheckOperator, "get_db_hook")
    def test_pass_all_checks_check(self, mock_get_db_hook):
        """A row of all-1 flags means every check passed, so execute() does not raise."""
        mock_hook = mock.Mock()
        # The operator's query returns one MIN(0/1 flag) column per check.
        mock_hook.get_first.return_value = (1, 1)
        mock_get_db_hook.return_value = mock_hook
        operator = self._construct_operator(self.checks)
        operator.execute()

    @mock.patch.object(SQLTableCheckOperator, "get_db_hook")
    def test_fail_all_checks_check(self, mock_get_db_hook):
        """A row of all-0 flags means every check failed, so execute() raises."""
        mock_hook = mock.Mock()
        mock_hook.get_first.return_value = (0, 0)
        mock_get_db_hook.return_value = mock_hook
        operator = self._construct_operator(self.checks)
        with pytest.raises(AirflowException):
            operator.execute()


class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
"""
Test for SQL Branch Operator
Expand Down