diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index 990f2ec414f13..612e563ce678b 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -20,10 +20,8 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence -from airflow import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.athena import AthenaHook -from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -71,7 +69,6 @@ def __init__( sleep_time: int = 30, max_polling_attempts: int | None = None, log_query: bool = True, - deferrable: bool = False, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -84,10 +81,9 @@ def __init__( self.query_execution_context = query_execution_context or {} self.result_configuration = result_configuration or {} self.sleep_time = sleep_time - self.max_polling_attempts = max_polling_attempts or 999999 + self.max_polling_attempts = max_polling_attempts self.query_execution_id: str | None = None self.log_query: bool = log_query - self.deferrable = deferrable @cached_property def hook(self) -> AthenaHook: @@ -105,15 +101,6 @@ def execute(self, context: Context) -> str | None: self.client_request_token, self.workgroup, ) - - if self.deferrable: - self.defer( - trigger=AthenaTrigger( - self.query_execution_id, self.sleep_time, self.max_polling_attempts, self.aws_conn_id - ), - method_name="execute_complete", - ) - # implicit else: query_status = self.hook.poll_query_status( self.query_execution_id, max_polling_attempts=self.max_polling_attempts, @@ -134,11 +121,6 @@ def execute(self, context: Context) -> str | None: return self.query_execution_id - def execute_complete(self, context, event=None): - if event["status"] != "success": - raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}") - return event["value"] - def on_kill(self) -> None: """Cancel the submitted athena query.""" if self.query_execution_id: diff --git a/airflow/providers/amazon/aws/triggers/athena.py b/airflow/providers/amazon/aws/triggers/athena.py deleted file mode 100644 index 780d9e9b98df2..0000000000000 --- a/airflow/providers/amazon/aws/triggers/athena.py +++ /dev/null @@ -1,76 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from typing import Any - -from airflow.providers.amazon.aws.hooks.athena import AthenaHook -from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait -from airflow.triggers.base import BaseTrigger, TriggerEvent - - -class AthenaTrigger(BaseTrigger): - """ - Trigger for RedshiftCreateClusterOperator. - - The trigger will asynchronously poll the boto3 API and wait for the - Redshift cluster to be in the `available` state. - - :param query_execution_id: ID of the Athena query execution to watch - :param poll_interval: The amount of time in seconds to wait between attempts. - :param max_attempt: The maximum number of attempts to be made. - :param aws_conn_id: The Airflow connection used for AWS credentials. - """ - - def __init__( - self, - query_execution_id: str, - poll_interval: int, - max_attempt: int, - aws_conn_id: str, - ): - self.query_execution_id = query_execution_id - self.poll_interval = poll_interval - self.max_attempt = max_attempt - self.aws_conn_id = aws_conn_id - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - self.__class__.__module__ + "." + self.__class__.__qualname__, - { - "query_execution_id": str(self.query_execution_id), - "poll_interval": str(self.poll_interval), - "max_attempt": str(self.max_attempt), - "aws_conn_id": str(self.aws_conn_id), - }, - ) - - async def run(self): - hook = AthenaHook(self.aws_conn_id) - async with hook.async_conn as client: - waiter = hook.get_waiter("query_complete", deferrable=True, client=client) - await async_wait( - waiter=waiter, - waiter_delay=self.poll_interval, - max_attempts=self.max_attempt, - args={"QueryExecutionId": self.query_execution_id}, - failure_message=f"Error while waiting for query {self.query_execution_id} to complete", - status_message=f"Query execution id: {self.query_execution_id}, " - "Query is still in non-terminal state", - status_args=["QueryExecution.Status.State"], - ) - yield TriggerEvent({"status": "success", "value": self.query_execution_id}) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index e4f16ce398460..5439f9c8cbb76 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -515,9 +515,6 @@ hooks: - airflow.providers.amazon.aws.hooks.appflow triggers: - - integration-name: Amazon Athena - python-modules: - - airflow.providers.amazon.aws.triggers.athena - integration-name: AWS Batch python-modules: - airflow.providers.amazon.aws.triggers.batch diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index 9e528525204c6..cfc78697683d1 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -20,11 +20,9 @@ import pytest -from airflow.exceptions import TaskDeferred from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.providers.amazon.aws.operators.athena import AthenaOperator -from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger from airflow.utils import timezone from airflow.utils.timezone import datetime @@ -160,13 +158,3 @@ def test_return_value(self, mock_conn, mock_run_query, mock_check_query_status): ti.dag_run = dag_run assert self.athena.execute(ti.get_template_context()) == ATHENA_QUERY_ID - - @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID) - def test_is_deferred(self, mock_run_query): - self.athena.deferrable = True - - with pytest.raises(TaskDeferred) as deferred: - self.athena.execute(None) - - assert isinstance(deferred.value.trigger, AthenaTrigger) - assert deferred.value.trigger.query_execution_id == ATHENA_QUERY_ID diff --git a/tests/providers/amazon/aws/triggers/test_athena.py b/tests/providers/amazon/aws/triggers/test_athena.py deleted file mode 100644 index 04e601f4392c4..0000000000000 --- a/tests/providers/amazon/aws/triggers/test_athena.py +++ /dev/null @@ -1,53 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from unittest import mock -from unittest.mock import AsyncMock - -import pytest -from botocore.exceptions import WaiterError - -from airflow.providers.amazon.aws.hooks.athena import AthenaHook -from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger - - -class TestAthenaTrigger: - @pytest.mark.asyncio - @mock.patch.object(AthenaHook, "get_waiter") - @mock.patch.object(AthenaHook, "async_conn") # LatestBoto step of CI fails without this - async def test_run_with_error(self, conn_mock, waiter_mock): - waiter_mock.side_effect = WaiterError("name", "reason", {}) - - trigger = AthenaTrigger("query_id", 0, 5, None) - - with pytest.raises(WaiterError): - generator = trigger.run() - await generator.asend(None) - - @pytest.mark.asyncio - @mock.patch.object(AthenaHook, "get_waiter") - @mock.patch.object(AthenaHook, "async_conn") # LatestBoto step of CI fails without this - async def test_run_success(self, conn_mock, waiter_mock): - waiter_mock().wait = AsyncMock() - trigger = AthenaTrigger("my_query_id", 0, 5, None) - - generator = trigger.run() - event = await generator.asend(None) - - assert event.payload["status"] == "success" - assert event.payload["value"] == "my_query_id"