diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 612e563ce678b..0467fe6d11aed 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -20,8 +20,10 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
+from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -31,6 +33,9 @@ class AthenaOperator(BaseOperator):
     """
     An operator that submits a presto query to athena.
 
+    .. note:: if the task is killed while it runs, it'll cancel the athena query that was launched,
+        EXCEPT if running in deferrable mode.
+
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:AthenaOperator`
@@ -69,6 +74,7 @@ def __init__(
         sleep_time: int = 30,
         max_polling_attempts: int | None = None,
         log_query: bool = True,
+        deferrable: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -81,9 +87,10 @@ def __init__(
         self.query_execution_context = query_execution_context or {}
         self.result_configuration = result_configuration or {}
         self.sleep_time = sleep_time
-        self.max_polling_attempts = max_polling_attempts
+        self.max_polling_attempts = max_polling_attempts or 999999
         self.query_execution_id: str | None = None
         self.log_query: bool = log_query
+        self.deferrable = deferrable
 
     @cached_property
     def hook(self) -> AthenaHook:
@@ -101,6 +108,15 @@ def execute(self, context: Context) -> str | None:
             self.client_request_token,
             self.workgroup,
         )
+
+        if self.deferrable:
+            self.defer(
+                trigger=AthenaTrigger(
+                    self.query_execution_id, self.sleep_time, self.max_polling_attempts, self.aws_conn_id
+                ),
+                method_name="execute_complete",
+            )
+        # implicit else:
         query_status = self.hook.poll_query_status(
             self.query_execution_id,
             max_polling_attempts=self.max_polling_attempts,
@@ -121,6 +137,12 @@ def execute(self, context: Context) -> str | None:
 
         return self.query_execution_id
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        """Resume after the trigger fired: return the query execution id, or raise if it did not succeed."""
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Error while waiting for Athena query to complete: {event}")
+        return event["value"]
+
     def on_kill(self) -> None:
         """Cancel the submitted athena query."""
         if self.query_execution_id:
diff --git a/airflow/providers/amazon/aws/triggers/athena.py b/airflow/providers/amazon/aws/triggers/athena.py
new file mode 100644
index 0000000000000..efae559470a28
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/athena.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+
+from airflow.providers.amazon.aws.hooks.athena import AthenaHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class AthenaTrigger(BaseTrigger):
+    """
+    Trigger for AthenaOperator.
+
+    The trigger will asynchronously poll the boto3 API and wait for the
+    Athena query to complete.
+
+    :param query_execution_id: ID of the Athena query execution to watch
+    :param poll_interval: The amount of time in seconds to wait between attempts.
+    :param max_attempt: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        query_execution_id: str,
+        poll_interval: int,
+        max_attempt: int,
+        aws_conn_id: str | None,
+    ):
+        super().__init__()
+        self.query_execution_id = query_execution_id
+        self.poll_interval = poll_interval
+        self.max_attempt = max_attempt
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Return the classpath and the keyword arguments needed to re-create this trigger."""
+        # Values keep their native types: they are fed back to __init__ as keyword
+        # arguments, so str() would turn the ints into strings and None into "None".
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "query_execution_id": self.query_execution_id,
+                "poll_interval": self.poll_interval,
+                "max_attempt": self.max_attempt,
+                "aws_conn_id": self.aws_conn_id,
+            },
+        )
+
+    async def run(self):
+        """Wait on the ``query_complete`` waiter, then yield a single success event with the query id."""
+        hook = AthenaHook(self.aws_conn_id)
+        async with hook.async_conn as client:
+            waiter = hook.get_waiter("query_complete", deferrable=True, client=client)
+            await async_wait(
+                waiter=waiter,
+                waiter_delay=self.poll_interval,
+                waiter_max_attempts=self.max_attempt,
+                args={"QueryExecutionId": self.query_execution_id},
+                failure_message=f"Error while waiting for query {self.query_execution_id} to complete",
+                status_message=f"Query execution id: {self.query_execution_id}, "
+                "Query is still in non-terminal state",
+                status_args=["QueryExecution.Status.State"],
+            )
+        yield TriggerEvent({"status": "success", "value": self.query_execution_id})
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 5439f9c8cbb76..e4f16ce398460 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -515,6 +515,9 @@ hooks:
       - airflow.providers.amazon.aws.hooks.appflow
 
 triggers:
+  - integration-name: Amazon Athena
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.athena
   - integration-name: AWS Batch
     python-modules:
       - airflow.providers.amazon.aws.triggers.batch
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index cfc78697683d1..9e528525204c6 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -20,9 +20,11 @@
 
 import pytest
 
+from airflow.exceptions import TaskDeferred
 from airflow.models import DAG, DagRun, TaskInstance
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
 from airflow.providers.amazon.aws.operators.athena import AthenaOperator
+from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
 from airflow.utils import timezone
 from airflow.utils.timezone import datetime
 
@@ -158,3 +160,13 @@ def test_return_value(self, mock_conn, mock_run_query, mock_check_query_status):
         ti.dag_run = dag_run
 
         assert self.athena.execute(ti.get_template_context()) == ATHENA_QUERY_ID
+
+    @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
+    def test_is_deferred(self, mock_run_query):
+        self.athena.deferrable = True
+
+        with pytest.raises(TaskDeferred) as deferred:
+            self.athena.execute(None)
+
+        assert isinstance(deferred.value.trigger, AthenaTrigger)
+        assert deferred.value.trigger.query_execution_id == ATHENA_QUERY_ID
diff --git a/tests/providers/amazon/aws/triggers/test_athena.py b/tests/providers/amazon/aws/triggers/test_athena.py
new file mode 100644
index 0000000000000..04e601f4392c4
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_athena.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+from botocore.exceptions import WaiterError
+
+from airflow.providers.amazon.aws.hooks.athena import AthenaHook
+from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
+
+
+class TestAthenaTrigger:
+    @pytest.mark.asyncio
+    @mock.patch.object(AthenaHook, "get_waiter")
+    @mock.patch.object(AthenaHook, "async_conn")  # LatestBoto step of CI fails without this
+    async def test_run_with_error(self, conn_mock, waiter_mock):
+        waiter_mock.side_effect = WaiterError("name", "reason", {})
+
+        trigger = AthenaTrigger("query_id", 0, 5, None)
+
+        with pytest.raises(WaiterError):
+            generator = trigger.run()
+            await generator.asend(None)
+
+    @pytest.mark.asyncio
+    @mock.patch.object(AthenaHook, "get_waiter")
+    @mock.patch.object(AthenaHook, "async_conn")  # LatestBoto step of CI fails without this
+    async def test_run_success(self, conn_mock, waiter_mock):
+        waiter_mock().wait = AsyncMock()
+        trigger = AthenaTrigger("my_query_id", 0, 5, None)
+
+        generator = trigger.run()
+        event = await generator.asend(None)
+
+        assert event.payload["status"] == "success"
+        assert event.payload["value"] == "my_query_id"
+
+    def test_serialize_keeps_native_types(self):
+        trigger = AthenaTrigger("my_query_id", 30, 5, None)
+
+        classpath, kwargs = trigger.serialize()
+
+        assert classpath == "airflow.providers.amazon.aws.triggers.athena.AthenaTrigger"
+        assert kwargs == {
+            "query_execution_id": "my_query_id",
+            "poll_interval": 30,
+            "max_attempt": 5,
+            "aws_conn_id": None,
+        }