diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 2a43a2f502931..de52edd72fdca 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -30,6 +30,7 @@
 import logging
 import os
 import uuid
+import warnings
 from copy import deepcopy
 from functools import cached_property, wraps
 from os import PathLike
@@ -53,6 +54,7 @@
 from airflow.exceptions import (
     AirflowException,
     AirflowNotFoundException,
+    AirflowProviderDeprecationWarning,
 )
 from airflow.hooks.base import BaseHook
 from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
@@ -966,6 +968,15 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
     provided in Airflow connection
     """
 
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "airflow.providers.amazon.aws.hooks.base_aws.BaseAsyncSessionFactory has been deprecated and "
+            "will be removed in a future release",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, **kwargs)
+
     async def get_role_credentials(self) -> dict:
         """Get the role_arn, method credentials from connection details and get the role credentials detail"""
         async with self._basic_session.create_client("sts", region_name=self.region_name) as client:
@@ -1059,6 +1070,15 @@ class AwsBaseAsyncHook(AwsBaseHook):
     :param config: Configuration for botocore client.
     """
 
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "airflow.providers.amazon.aws.hooks.base_aws.AwsBaseAsyncHook has been deprecated and "
+            "will be removed in a future release",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, **kwargs)
+
     def get_async_session(self) -> AioSession:
         """Get the underlying aiobotocore.session.AioSession(...)."""
         return BaseAsyncSessionFactory(
diff --git a/airflow/providers/amazon/aws/hooks/redshift_cluster.py b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
index 27fc25a1de546..56954d6a984da 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -17,11 +17,13 @@
 from __future__ import annotations
 
 import asyncio
+import warnings
 from typing import Any, Sequence
 
 import botocore.exceptions
 from botocore.exceptions import ClientError
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook
 
 
@@ -204,7 +206,13 @@ def get_cluster_snapshot_status(self, snapshot_identifier: str):
 
 
 class RedshiftAsyncHook(AwsBaseAsyncHook):
     """Interact with AWS Redshift using aiobotocore library"""
 
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook has been deprecated and "
+            "will be removed in a future release",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
         kwargs["client_type"] = "redshift"
         super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
index 6b1a16bc827e1..d1b6f4d29e4be 100644
--- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -18,79 +18,14 @@
 
 import asyncio
 from functools import cached_property
-from typing import Any, AsyncIterator
+from typing import Any
 
 from botocore.exceptions import WaiterError
 
-from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook
+from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
-class RedshiftClusterTrigger(BaseTrigger):
-    """AWS Redshift trigger"""
-
-    def __init__(
-        self,
-        task_id: str,
-        aws_conn_id: str,
-        cluster_identifier: str,
-        operation_type: str,
-        attempts: int,
-        poll_interval: float = 5.0,
-    ):
-        super().__init__()
-        self.task_id = task_id
-        self.poll_interval = poll_interval
-        self.aws_conn_id = aws_conn_id
-        self.cluster_identifier = cluster_identifier
-        self.operation_type = operation_type
-        self.attempts = attempts
-
-    def serialize(self) -> tuple[str, dict[str, Any]]:
-        return (
-            "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger",
-            {
-                "task_id": self.task_id,
-                "poll_interval": self.poll_interval,
-                "aws_conn_id": self.aws_conn_id,
-                "cluster_identifier": self.cluster_identifier,
-                "attempts": self.attempts,
-                "operation_type": self.operation_type,
-            },
-        )
-
-    async def run(self) -> AsyncIterator[TriggerEvent]:
-        hook = RedshiftAsyncHook(aws_conn_id=self.aws_conn_id)
-        while self.attempts >= 1:
-            self.attempts = self.attempts - 1
-            try:
-                if self.operation_type == "pause_cluster":
-                    response = await hook.pause_cluster(
-                        cluster_identifier=self.cluster_identifier,
-                        poll_interval=self.poll_interval,
-                    )
-                    if response.get("status") == "success":
-                        yield TriggerEvent(response)
-                    else:
-                        if self.attempts < 1:
-                            yield TriggerEvent({"status": "error", "message": f"{self.task_id} failed"})
-                elif self.operation_type == "resume_cluster":
-                    response = await hook.resume_cluster(
-                        cluster_identifier=self.cluster_identifier,
-                        polling_period_seconds=self.poll_interval,
-                    )
-                    if response:
-                        yield TriggerEvent(response)
-                    else:
-                        error_message = f"{self.task_id} failed"
-                        yield TriggerEvent({"status": "error", "message": error_message})
-                else:
-                    yield TriggerEvent(f"{self.operation_type} is not supported")
-            except Exception as e:
-                if self.attempts < 1:
-                    yield TriggerEvent({"status": "error", "message": str(e)})
-
-
 class RedshiftCreateClusterTrigger(BaseTrigger):
     """
     Trigger for RedshiftCreateClusterOperator.
diff --git a/tests/providers/amazon/aws/deferrable/triggers/__init__.py b/tests/providers/amazon/aws/deferrable/triggers/__init__.py
deleted file mode 100644
index 217e5db960782..0000000000000
--- a/tests/providers/amazon/aws/deferrable/triggers/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
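The hook changes above all use the standard provider deprecation recipe: warn at construction time with `AirflowProviderDeprecationWarning` (with `stacklevel=2` so the warning points at the caller, not the hook), then delegate to the parent `__init__`. A minimal sketch of how the warning surfaces to callers, assuming `aiobotocore` is installed and the deprecated hook is still importable:

```python
import warnings

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Constructing the deprecated hook fires the warning added in this diff;
    # no AWS connection is resolved at __init__ time, so this is side-effect free.
    RedshiftAsyncHook(aws_conn_id="aws_default")

assert any(issubclass(w.category, AirflowProviderDeprecationWarning) for w in caught)
```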
diff --git a/tests/providers/amazon/aws/deferrable/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/deferrable/triggers/test_redshift_cluster.py
deleted file mode 100644
index f2430f6d032f1..0000000000000
--- a/tests/providers/amazon/aws/deferrable/triggers/test_redshift_cluster.py
+++ /dev/null
@@ -1,168 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from unittest import mock
-
-import pytest
-
-from airflow.providers.amazon.aws.triggers.redshift_cluster import (
-    RedshiftClusterTrigger,
-)
-from airflow.triggers.base import TriggerEvent
-
-pytest.importorskip("aiobotocore")
-
-TASK_ID = "redshift_trigger_check"
-POLLING_PERIOD_SECONDS = 1.0
-
-
-class TestRedshiftClusterTrigger:
-    def test_pause_serialization(self):
-        """
-        Asserts that the RedshiftClusterTrigger correctly serializes its arguments
-        and classpath.
-        """
-        trigger = RedshiftClusterTrigger(
-            task_id=TASK_ID,
-            poll_interval=POLLING_PERIOD_SECONDS,
-            aws_conn_id="test_redshift_conn_id",
-            cluster_identifier="mock_cluster_identifier",
-            attempts=10,
-            operation_type="pause_cluster",
-        )
-        classpath, kwargs = trigger.serialize()
-        assert classpath == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger"
-        assert kwargs == {
-            "task_id": TASK_ID,
-            "poll_interval": POLLING_PERIOD_SECONDS,
-            "aws_conn_id": "test_redshift_conn_id",
-            "cluster_identifier": "mock_cluster_identifier",
-            "attempts": 10,
-            "operation_type": "pause_cluster",
-        }
-
-    @pytest.mark.asyncio
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.pause_cluster")
-    async def test_pause_trigger_run(self, mock_pause_cluster):
-        """
-        Test trigger event for the pause_cluster response
-        """
-        trigger = RedshiftClusterTrigger(
-            task_id=TASK_ID,
-            poll_interval=POLLING_PERIOD_SECONDS,
-            aws_conn_id="test_redshift_conn_id",
-            cluster_identifier="mock_cluster_identifier",
-            attempts=1,
-            operation_type="pause_cluster",
-        )
-        generator = trigger.run()
-        await generator.asend(None)
-        mock_pause_cluster.assert_called_once_with(
-            cluster_identifier="mock_cluster_identifier", poll_interval=1.0
-        )
-
-    @pytest.mark.asyncio
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.pause_cluster")
-    async def test_pause_trigger_failure(self, mock_pause_cluster):
-        """Test trigger event when pause cluster raise exception"""
-        mock_pause_cluster.side_effect = Exception("Test exception")
-        trigger = RedshiftClusterTrigger(
-            task_id=TASK_ID,
-            poll_interval=POLLING_PERIOD_SECONDS,
-            aws_conn_id="test_redshift_conn_id",
-            cluster_identifier="mock_cluster_identifier",
-            attempts=1,
-            operation_type="pause_cluster",
-        )
-        task = [i async for i in trigger.run()]
-        assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
-
-    @pytest.mark.asyncio
-    @pytest.mark.parametrize(
-        "operation_type,return_value,response",
-        [
-            (
-                "resume_cluster",
-                {"status": "error", "message": "test error"},
-                TriggerEvent({"status": "error", "message": "test error"}),
-            ),
-        ],
-    )
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.resume_cluster")
-    async def test_resume_trigger_run_error(
-        self, mock_resume_cluster, operation_type, return_value, response
-    ):
-        """Test RedshiftClusterTrigger resume cluster with success"""
-        mock_resume_cluster.return_value = return_value
-        trigger = RedshiftClusterTrigger(
-            task_id=TASK_ID,
-            poll_interval=POLLING_PERIOD_SECONDS,
-            aws_conn_id="test_redshift_conn_id",
-            cluster_identifier="mock_cluster_identifier",
-            operation_type=operation_type,
-            attempts=1,
-        )
-        generator = trigger.run()
-        actual = await generator.asend(None)
-        assert response == actual
-
-    @pytest.mark.asyncio
-    @pytest.mark.parametrize(
-        "operation_type,return_value,response",
-        [
-            (
-                "resume_cluster",
-                {"status": "success", "cluster_state": "available"},
-                TriggerEvent({"status": "success", "cluster_state": "available"}),
-            ),
-        ],
-    )
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.resume_cluster")
-    async def test_resume_trigger_run_success(
-        self, mock_resume_cluster, operation_type, return_value, response
-    ):
-        """Test RedshiftClusterTrigger resume cluster with success"""
-        mock_resume_cluster.return_value = return_value
-        trigger = RedshiftClusterTrigger(
-            task_id=TASK_ID,
-            poll_interval=POLLING_PERIOD_SECONDS,
-            aws_conn_id="test_redshift_conn_id",
-            cluster_identifier="mock_cluster_identifier",
-            operation_type=operation_type,
-            attempts=1,
-        )
-        generator = trigger.run()
-        actual = await generator.asend(None)
-        assert response == actual
-
-    @pytest.mark.asyncio
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.resume_cluster")
-    async def test_resume_trigger_failure(self, mock_resume_cluster):
-        """Test RedshiftClusterTrigger resume cluster with failure status"""
-        mock_resume_cluster.side_effect = Exception("Test exception")
-        trigger = RedshiftClusterTrigger(
-            task_id=TASK_ID,
-            poll_interval=POLLING_PERIOD_SECONDS,
-            aws_conn_id="test_redshift_conn_id",
-            cluster_identifier="mock_cluster_identifier",
-            operation_type="resume_cluster",
-            attempts=1,
-        )
-        task = [i async for i in trigger.run()]
-        assert len(task) == 1
-        assert TriggerEvent({"status": "error", "message": "Test exception"}) in task