diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py index 290c323119ab7..0218b50f4c2c4 100644 --- a/airflow/providers/snowflake/operators/snowflake.py +++ b/airflow/providers/snowflake/operators/snowflake.py @@ -15,10 +15,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Optional, SupportsAbs +from typing import Any, Dict, Iterable, List, Mapping, Optional, SupportsAbs, Union from airflow.models import BaseOperator -from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator +from airflow.operators.sql import ( + BranchSQLOperator, + SQLCheckOperator, + SQLIntervalCheckOperator, + SQLThresholdCheckOperator, + SQLValueCheckOperator, +) from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook @@ -332,6 +338,19 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator): :param days_back: number of days between ds and the ds we want to check against. Defaults to 7 days :type days_back: int + :param date_filter_column: The column name for the dates to filter on. Defaults to 'ds' + :type date_filter_column: Optional[str] + :param ratio_formula: which formula to use to compute the ratio between + the two metrics. Assuming cur is the metric of today and ref is + the metric to today - days_back. + + max_over_min: computes max(cur, ref) / min(cur, ref) + relative_diff: computes abs(cur-ref) / ref + + Default: 'max_over_min' + :type ratio_formula: str + :param ignore_zero: whether we should ignore zero metrics + :type ignore_zero: bool :param metrics_thresholds: a dictionary of ratios indexed by metrics, for example 'COUNT(*)': 1.5 would require a 50 percent or less difference between the current day, and the prior days_back. @@ -373,9 +392,11 @@ def __init__( self, *, table: str, - metrics_thresholds: dict, + metrics_thresholds: Dict[str, int], date_filter_column: str = 'ds', days_back: SupportsAbs[int] = -7, + ratio_formula: Optional[str] = 'max_over_min', + ignore_zero: bool = True, snowflake_conn_id: str = 'snowflake_default', parameters: Optional[dict] = None, autocommit: bool = True, @@ -409,3 +430,175 @@ def __init__( def get_db_hook(self) -> SnowflakeHook: return get_db_hook(self) + + +class SnowflakeThresholdCheckOperator(SQLThresholdCheckOperator): + """ + Performs a value check using sql code against a minimum threshold + and a maximum threshold. Thresholds can be in the form of a numeric + value OR a sql statement that results a numeric. + + :param sql: the sql to be executed. (templated) + :type sql: str + :param snowflake_conn_id: Reference to + :ref:`Snowflake connection id` + :type snowflake_conn_id: str + :param min_threshold: numerical value or min threshold sql to be executed (templated) + :type min_threshold: numeric or str + :param max_threshold: numerical value or max threshold sql to be executed (templated) + :type max_threshold: numeric or str + :param autocommit: if True, each command is automatically committed. + (default value: True) + :type autocommit: bool + :param warehouse: name of warehouse (will overwrite any warehouse + defined in the connection's extra JSON) + :type warehouse: str + :param database: name of database (will overwrite database defined + in connection) + :type database: str + :param schema: name of schema (will overwrite schema defined in + connection) + :type schema: str + :param role: name of role (will overwrite any role defined in + connection's extra JSON) + :type role: str + :param authenticator: authenticator for Snowflake. + 'snowflake' (default) to use the internal Snowflake authenticator + 'externalbrowser' to authenticate using your web browser and + Okta, ADFS or any other SAML 2.0-compliant identify provider + (IdP) that has been defined for your account + 'https://.okta.com' to authenticate + through native Okta. + :type authenticator: str + :param session_parameters: You can set session-level parameters at + the time you connect to Snowflake + :type session_parameters: dict + """ + + def __init__( + self, + *, + sql: str, + min_threshold: Any, + max_threshold: Any, + snowflake_conn_id: str = 'snowflake_default', + parameters: Optional[dict] = None, + autocommit: bool = True, + do_xcom_push: bool = True, + warehouse: Optional[str] = None, + database: Optional[str] = None, + role: Optional[str] = None, + schema: Optional[str] = None, + authenticator: Optional[str] = None, + session_parameters: Optional[dict] = None, + **kwargs, + ) -> None: + super().__init__( + sql=sql, + min_threshold=min_threshold, + max_threshold=max_threshold, + **kwargs, + ) + + self.snowflake_conn_id = snowflake_conn_id + self.autocommit = autocommit + self.do_xcom_push = do_xcom_push + self.parameters = parameters + self.warehouse = warehouse + self.database = database + self.role = role + self.schema = schema + self.authenticator = authenticator + self.session_parameters = session_parameters + self.query_ids = [] + + def get_db_hook(self) -> SnowflakeHook: + return get_db_hook(self) + + +class BranchSnowflakeOperator(BranchSQLOperator): + """ + Executes sql code in a specific database + + :param sql: the sql code to be executed. (templated) + :type sql: Can receive a str representing a sql statement or reference to a template file. + Template reference are recognized by str ending in '.sql'. + Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1) + or string (true/y/yes/1/on/false/n/no/0/off). + :param follow_task_ids_if_true: task id or task ids to follow if query return true + :type follow_task_ids_if_true: str or list + :param follow_task_ids_if_false: task id or task ids to follow if query return true + :type follow_task_ids_if_false: str or list + :param snowflake_conn_id: Reference to + :ref:`Snowflake connection id` + :type snowflake_conn_id: str + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: mapping or iterable + :param autocommit: if True, each command is automatically committed. + (default value: True) + :type autocommit: bool + :param warehouse: name of warehouse (will overwrite any warehouse + defined in the connection's extra JSON) + :type warehouse: str + :param database: name of database (will overwrite database defined + in connection) + :type database: str + :param schema: name of schema (will overwrite schema defined in + connection) + :type schema: str + :param role: name of role (will overwrite any role defined in + connection's extra JSON) + :type role: str + :param authenticator: authenticator for Snowflake. + 'snowflake' (default) to use the internal Snowflake authenticator + 'externalbrowser' to authenticate using your web browser and + Okta, ADFS or any other SAML 2.0-compliant identify provider + (IdP) that has been defined for your account + 'https://.okta.com' to authenticate + through native Okta. + :type authenticator: str + :param session_parameters: You can set session-level parameters at + the time you connect to Snowflake + :type session_parameters: dict + """ + + def __init__( + self, + *, + sql: str, + follow_task_ids_if_true: List[str], + follow_task_ids_if_false: List[str], + snowflake_conn_id: str = 'snowflake_default', + autocommit: bool = True, + do_xcom_push: bool = True, + warehouse: Optional[str] = None, + database: Optional[str] = None, + role: Optional[str] = None, + schema: Optional[str] = None, + authenticator: Optional[str] = None, + session_parameters: Optional[dict] = None, + parameters: Optional[Union[Mapping, Iterable]] = None, + **kwargs, + ) -> None: + super().__init__( + sql=sql, + follow_task_ids_if_true=follow_task_ids_if_true, + follow_task_ids_if_false=follow_task_ids_if_false, + parameters=parameters, + **kwargs, + ) + + self.snowflake_conn_id = snowflake_conn_id + self.autocommit = autocommit + self.do_xcom_push = do_xcom_push + self.parameters = parameters + self.warehouse = warehouse + self.database = database + self.role = role + self.schema = schema + self.authenticator = authenticator + self.session_parameters = session_parameters + self.query_ids = [] + + def get_db_hook(self) -> SnowflakeHook: + return get_db_hook(self) diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py index 8cdcc0936d898..8ee3f093f81be 100644 --- a/tests/providers/snowflake/operators/test_snowflake.py +++ b/tests/providers/snowflake/operators/test_snowflake.py @@ -23,9 +23,11 @@ from airflow.models.dag import DAG from airflow.providers.snowflake.operators.snowflake import ( + BranchSnowflakeOperator, SnowflakeCheckOperator, SnowflakeIntervalCheckOperator, SnowflakeOperator, + SnowflakeThresholdCheckOperator, SnowflakeValueCheckOperator, ) from airflow.utils import timezone @@ -57,12 +59,38 @@ def test_snowflake_operator(self, mock_get_db_hook): operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) +@pytest.mark.parametrize( + "operator_class, kwargs", + [ + ( + BranchSnowflakeOperator, + dict(sql="SELECT 1", follow_task_ids_if_true="branch_1", follow_task_ids_if_false="branch_2"), + ) + ], +) +class TestBranchSnowflakeOperator: + @mock.patch("airflow.providers.snowflake.operators.snowflake.get_db_hook") + def test_get_db_hook( + self, + mock_get_db_hook, + operator_class, + kwargs, + ): + operator = operator_class(task_id='branch_snowflake', snowflake_conn_id='snowflake_default', **kwargs) + operator.get_db_hook() + mock_get_db_hook.assert_called_once() + + @pytest.mark.parametrize( "operator_class, kwargs", [ (SnowflakeCheckOperator, dict(sql='Select * from test_table')), (SnowflakeValueCheckOperator, dict(sql='Select * from test_table', pass_value=95)), (SnowflakeIntervalCheckOperator, dict(table='test-table-id', metrics_thresholds={'COUNT(*)': 1.5})), + ( + SnowflakeThresholdCheckOperator, + dict(sql='Select * from test_table', min_threshold=0, max_threshold=10), + ), ], ) class TestSnowflakeCheckOperators: