199 changes: 196 additions & 3 deletions airflow/providers/snowflake/operators/snowflake.py
@@ -15,10 +15,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Optional, SupportsAbs
from typing import Any, Dict, Iterable, List, Mapping, Optional, SupportsAbs, Union

from airflow.models import BaseOperator
from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator
from airflow.operators.sql import (
BranchSQLOperator,
SQLCheckOperator,
SQLIntervalCheckOperator,
SQLThresholdCheckOperator,
SQLValueCheckOperator,
)
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook


@@ -332,6 +338,19 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
:param days_back: number of days between ds and the ds we want to check
against. Defaults to 7 days
:type days_back: int
:param date_filter_column: The column name for the dates to filter on. Defaults to 'ds'
:type date_filter_column: Optional[str]
:param ratio_formula: which formula to use to compute the ratio between
the two metrics. Assuming cur is the metric for today and ref is
the metric for today - days_back.

max_over_min: computes max(cur, ref) / min(cur, ref)
relative_diff: computes abs(cur - ref) / ref

Default: 'max_over_min'
:type ratio_formula: str
:param ignore_zero: whether we should ignore zero metrics
:type ignore_zero: bool
:param metrics_thresholds: a dictionary of ratios indexed by metrics; for
example, 'COUNT(*)': 1.5 would require a 50 percent or less difference
between the current day and the prior days_back.
@@ -373,9 +392,11 @@ def __init__(
self,
*,
table: str,
metrics_thresholds: dict,
metrics_thresholds: Dict[str, int],
date_filter_column: str = 'ds',
days_back: SupportsAbs[int] = -7,
ratio_formula: Optional[str] = 'max_over_min',
ignore_zero: bool = True,
snowflake_conn_id: str = 'snowflake_default',
parameters: Optional[dict] = None,
autocommit: bool = True,
@@ -409,3 +430,175 @@ def __init__(

def get_db_hook(self) -> SnowflakeHook:
return get_db_hook(self)
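
For context, a minimal usage sketch of SnowflakeIntervalCheckOperator with the ratio_formula and ignore_zero parameters documented above (the DAG id, task id, and table name are illustrative assumptions, not part of this diff). As a worked example of the two formulas: with cur = 100 and ref = 80, max_over_min gives 1.25 and relative_diff gives 0.25, so a 'COUNT(*)': 1.5 threshold would pass under either.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.snowflake.operators.snowflake import SnowflakeIntervalCheckOperator

    with DAG(
        dag_id="example_snowflake_interval_check",  # assumed DAG id
        start_date=datetime(2021, 1, 1),
        schedule_interval=None,
    ) as dag:
        # Compare today's metrics against the same metrics from 7 days ago,
        # allowing at most a 1.5x ratio between the two values.
        interval_check = SnowflakeIntervalCheckOperator(
            task_id="check_row_count_drift",
            table="MY_DB.MY_SCHEMA.ORDERS",  # assumed table name
            metrics_thresholds={"COUNT(*)": 1.5},
            date_filter_column="ds",
            days_back=-7,
            ratio_formula="max_over_min",
            ignore_zero=True,
            snowflake_conn_id="snowflake_default",
        )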


class SnowflakeThresholdCheckOperator(SQLThresholdCheckOperator):
"""
Performs a value check using sql code against a minimum threshold
and a maximum threshold. Thresholds can be in the form of a numeric
value OR a sql statement that returns a numeric value.

:param sql: the sql to be executed. (templated)
:type sql: str
:param snowflake_conn_id: Reference to
:ref:`Snowflake connection id<howto/connection:snowflake>`
:type snowflake_conn_id: str
:param min_threshold: numerical value or min threshold sql to be executed (templated)
:type min_threshold: numeric or str
:param max_threshold: numerical value or max threshold sql to be executed (templated)
:type max_threshold: numeric or str
:param autocommit: if True, each command is automatically committed.
(default value: True)
:type autocommit: bool
:param warehouse: name of warehouse (will overwrite any warehouse
defined in the connection's extra JSON)
:type warehouse: str
:param database: name of database (will overwrite database defined
in connection)
:type database: str
:param schema: name of schema (will overwrite schema defined in
connection)
:type schema: str
:param role: name of role (will overwrite any role defined in
connection's extra JSON)
:type role: str
:param authenticator: authenticator for Snowflake.
'snowflake' (default) to use the internal Snowflake authenticator
'externalbrowser' to authenticate using your web browser and
Okta, ADFS or any other SAML 2.0-compliant identity provider
(IdP) that has been defined for your account
'https://<your_okta_account_name>.okta.com' to authenticate
through native Okta.
:type authenticator: str
:param session_parameters: You can set session-level parameters at
the time you connect to Snowflake
:type session_parameters: dict
"""

def __init__(
self,
*,
sql: str,
min_threshold: Any,
max_threshold: Any,
snowflake_conn_id: str = 'snowflake_default',
parameters: Optional[dict] = None,
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: Optional[str] = None,
database: Optional[str] = None,
role: Optional[str] = None,
schema: Optional[str] = None,
authenticator: Optional[str] = None,
session_parameters: Optional[dict] = None,
**kwargs,
) -> None:
super().__init__(
sql=sql,
min_threshold=min_threshold,
max_threshold=max_threshold,
**kwargs,
)

self.snowflake_conn_id = snowflake_conn_id
self.autocommit = autocommit
self.do_xcom_push = do_xcom_push
self.parameters = parameters
self.warehouse = warehouse
self.database = database
self.role = role
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids = []

def get_db_hook(self) -> SnowflakeHook:
return get_db_hook(self)
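
A similar sketch for the new SnowflakeThresholdCheckOperator, assuming the same DAG context as the sketch above (the query text, table names, and threshold values are assumptions); either threshold may be a literal number or a SQL statement that returns a single numeric value:

    from airflow.providers.snowflake.operators.snowflake import SnowflakeThresholdCheckOperator

    # Inside a DAG context: fail the task if the daily row count falls outside
    # [min_threshold, max_threshold]; the max here is itself computed by SQL at runtime.
    threshold_check = SnowflakeThresholdCheckOperator(
        task_id="check_daily_row_count",
        sql="SELECT COUNT(*) FROM MY_DB.MY_SCHEMA.ORDERS WHERE ds = '{{ ds }}'",  # assumed query
        min_threshold=1,
        max_threshold="SELECT AVG(daily_count) * 2 FROM MY_DB.MY_SCHEMA.ORDER_STATS",  # assumed query
        snowflake_conn_id="snowflake_default",
    )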


class BranchSnowflakeOperator(BranchSQLOperator):
Review comment — eladkal (Contributor), Sep 21, 2021:
Wouldn't an approach like https://github.com/apache/airflow/pull/18394/files be better?
Saves the trouble of setting up unique classes specifically for Snowflake.

Reply — Contributor Author:
I like this solution a lot; it would also handle all the other cases (BigQuery specifically, which this and my last PR were based on). It seems like it would make more sense for me to open a separate PR for that solution, though.

"""
Executes a SQL query in Snowflake and branches the DAG based on the query result.

:param sql: the sql code to be executed; can receive a str representing a sql
statement or a reference to a template file (templated). Template references
are recognized by str ending in '.sql'. The query is expected to return a
Boolean (True/False), an integer (0 = False, any other value = True), or a
string (true/y/yes/1/on/false/n/no/0/off).
:type sql: str
:param follow_task_ids_if_true: task id or task ids to follow if query returns true
:type follow_task_ids_if_true: str or list
:param follow_task_ids_if_false: task id or task ids to follow if query returns false
:type follow_task_ids_if_false: str or list
:param snowflake_conn_id: Reference to
:ref:`Snowflake connection id<howto/connection:snowflake>`
:type snowflake_conn_id: str
:param parameters: (optional) the parameters to render the SQL query with.
:type parameters: mapping or iterable
:param autocommit: if True, each command is automatically committed.
(default value: True)
:type autocommit: bool
:param warehouse: name of warehouse (will overwrite any warehouse
defined in the connection's extra JSON)
:type warehouse: str
:param database: name of database (will overwrite database defined
in connection)
:type database: str
:param schema: name of schema (will overwrite schema defined in
connection)
:type schema: str
:param role: name of role (will overwrite any role defined in
connection's extra JSON)
:type role: str
:param authenticator: authenticator for Snowflake.
'snowflake' (default) to use the internal Snowflake authenticator
'externalbrowser' to authenticate using your web browser and
Okta, ADFS or any other SAML 2.0-compliant identity provider
(IdP) that has been defined for your account
'https://<your_okta_account_name>.okta.com' to authenticate
through native Okta.
:type authenticator: str
:param session_parameters: You can set session-level parameters at
the time you connect to Snowflake
:type session_parameters: dict
"""

def __init__(
self,
*,
sql: str,
follow_task_ids_if_true: List[str],
follow_task_ids_if_false: List[str],
snowflake_conn_id: str = 'snowflake_default',
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: Optional[str] = None,
database: Optional[str] = None,
role: Optional[str] = None,
schema: Optional[str] = None,
authenticator: Optional[str] = None,
session_parameters: Optional[dict] = None,
parameters: Optional[Union[Mapping, Iterable]] = None,
**kwargs,
) -> None:
super().__init__(
sql=sql,
follow_task_ids_if_true=follow_task_ids_if_true,
follow_task_ids_if_false=follow_task_ids_if_false,
parameters=parameters,
**kwargs,
)

self.snowflake_conn_id = snowflake_conn_id
self.autocommit = autocommit
self.do_xcom_push = do_xcom_push
self.parameters = parameters
self.warehouse = warehouse
self.database = database
self.role = role
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids = []

def get_db_hook(self) -> SnowflakeHook:
return get_db_hook(self)
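
And a sketch for BranchSnowflakeOperator as introduced in this diff, again assuming a surrounding DAG (the query and downstream task ids are assumptions): the boolean-like query result decides which downstream task ids continue and which are skipped.

    from airflow.providers.snowflake.operators.snowflake import BranchSnowflakeOperator

    # Inside a DAG context: branch on whether any new rows arrived for the
    # current data interval; only the chosen downstream branch keeps running.
    branch = BranchSnowflakeOperator(
        task_id="branch_on_new_rows",
        sql="SELECT COUNT(*) > 0 FROM MY_DB.MY_SCHEMA.ORDERS WHERE ds = '{{ ds }}'",  # assumed query
        follow_task_ids_if_true=["load_orders"],  # assumed task ids
        follow_task_ids_if_false=["skip_load"],
        snowflake_conn_id="snowflake_default",
    )
    # Downstream wiring would follow, e.g.: branch >> [load_orders, skip_load]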
28 changes: 28 additions & 0 deletions tests/providers/snowflake/operators/test_snowflake.py
@@ -23,9 +23,11 @@

from airflow.models.dag import DAG
from airflow.providers.snowflake.operators.snowflake import (
BranchSnowflakeOperator,
SnowflakeCheckOperator,
SnowflakeIntervalCheckOperator,
SnowflakeOperator,
SnowflakeThresholdCheckOperator,
SnowflakeValueCheckOperator,
)
from airflow.utils import timezone
@@ -57,12 +59,38 @@ def test_snowflake_operator(self, mock_get_db_hook):
operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)


@pytest.mark.parametrize(
"operator_class, kwargs",
[
(
BranchSnowflakeOperator,
dict(sql="SELECT 1", follow_task_ids_if_true="branch_1", follow_task_ids_if_false="branch_2"),
)
],
)
class TestBranchSnowflakeOperator:
@mock.patch("airflow.providers.snowflake.operators.snowflake.get_db_hook")
def test_get_db_hook(
self,
mock_get_db_hook,
operator_class,
kwargs,
):
operator = operator_class(task_id='branch_snowflake', snowflake_conn_id='snowflake_default', **kwargs)
operator.get_db_hook()
mock_get_db_hook.assert_called_once()


@pytest.mark.parametrize(
"operator_class, kwargs",
[
(SnowflakeCheckOperator, dict(sql='Select * from test_table')),
(SnowflakeValueCheckOperator, dict(sql='Select * from test_table', pass_value=95)),
(SnowflakeIntervalCheckOperator, dict(table='test-table-id', metrics_thresholds={'COUNT(*)': 1.5})),
(
SnowflakeThresholdCheckOperator,
dict(sql='Select * from test_table', min_threshold=0, max_threshold=10),
),
],
)
class TestSnowflakeCheckOperators: