diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index 94ce2e58c28a7..abf0e68f39435 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -24,10 +24,10 @@
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.triggers.redshift_cluster import (
-    RedshiftClusterTrigger,
     RedshiftCreateClusterSnapshotTrigger,
     RedshiftCreateClusterTrigger,
     RedshiftPauseClusterTrigger,
+    RedshiftResumeClusterTrigger,
 )
 
 if TYPE_CHECKING:
@@ -452,8 +452,11 @@ class RedshiftResumeClusterOperator(BaseOperator):
     :param cluster_identifier: Unique identifier of the AWS Redshift cluster
     :param aws_conn_id: The Airflow connection used for AWS credentials.
         The default connection id is ``aws_default``
-    :param deferrable: Run operator in deferrable mode
     :param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
+    :param max_attempts: The maximum number of attempts to check the state of the cluster.
+    :param wait_for_completion: If True, the operator will wait until the cluster reaches the
+        ``available`` state. Default is False.
+    :param deferrable: If True, the operator will run as a deferrable operator.
     """
 
     template_fields: Sequence[str] = ("cluster_identifier",)
@@ -465,66 +468,71 @@ def __init__(
         *,
         cluster_identifier: str,
         aws_conn_id: str = "aws_default",
+        wait_for_completion: bool = False,
         deferrable: bool = False,
         poll_interval: int = 10,
+        max_attempts: int = 10,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.cluster_identifier = cluster_identifier
         self.aws_conn_id = aws_conn_id
+        self.wait_for_completion = wait_for_completion
         self.deferrable = deferrable
+        self.max_attempts = max_attempts
         self.poll_interval = poll_interval
-        # These parameters are added to address an issue with the boto3 API where the API
+        # These parameters are used to address an issue with the boto3 API where the API
         # prematurely reports the cluster as available to receive requests. This causes the cluster
         # to reject initial attempts to resume the cluster despite reporting the correct state.
-        self._attempts = 10
+        self._remaining_attempts = 10
         self._attempt_interval = 15
 
     def execute(self, context: Context):
         redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
+        self.log.info("Starting to resume cluster")
+        while self._remaining_attempts >= 1:
+            try:
+                redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
+                break
+            except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+                self._remaining_attempts = self._remaining_attempts - 1
+                if self._remaining_attempts > 0:
+                    self.log.error(
+                        "Unable to resume cluster. %d attempts remaining.", self._remaining_attempts
+                    )
+                    time.sleep(self._attempt_interval)
+                else:
+                    raise error
         if self.deferrable:
             self.defer(
-                timeout=self.execution_timeout,
-                trigger=RedshiftClusterTrigger(
-                    task_id=self.task_id,
+                trigger=RedshiftResumeClusterTrigger(
+                    cluster_identifier=self.cluster_identifier,
                     poll_interval=self.poll_interval,
+                    max_attempts=self.max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                    cluster_identifier=self.cluster_identifier,
-                    attempts=self._attempts,
-                    operation_type="resume_cluster",
                 ),
                 method_name="execute_complete",
+                # timeout is set so that if the trigger dies, the deferral timeout does not restart;
+                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield a TriggerEvent)
+                timeout=timedelta(seconds=self.max_attempts * self.poll_interval + 60),
             )
+        if self.wait_for_completion:
+            waiter = redshift_hook.get_waiter("cluster_resumed")
+            waiter.wait(
+                ClusterIdentifier=self.cluster_identifier,
+                WaiterConfig={
+                    "Delay": self.poll_interval,
+                    "MaxAttempts": self.max_attempts,
+                },
+            )
+
+    def execute_complete(self, context, event=None):
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Error resuming cluster: {event}")
         else:
-            while self._attempts >= 1:
-                try:
-                    redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
-                    return
-                except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
-                    self._attempts = self._attempts - 1
-
-                    if self._attempts > 0:
-                        self.log.error("Unable to resume cluster. %d attempts remaining.", self._attempts)
-                        time.sleep(self._attempt_interval)
-                    else:
-                        raise error
-
-    def execute_complete(self, context: Context, event: Any = None) -> None:
-        """
-        Callback for when the trigger fires - returns immediately.
-        Relies on trigger to throw an exception, otherwise it assumes execution was
-        successful.
-        """
-        if event:
-            if "status" in event and event["status"] == "error":
-                msg = f"{event['status']}: {event['message']}"
-                raise AirflowException(msg)
-            elif "status" in event and event["status"] == "success":
-                self.log.info("%s completed successfully.", self.task_id)
-                self.log.info("Resumed cluster successfully")
-        else:
-            raise AirflowException("No event received from trigger")
+            self.log.info("Resumed cluster successfully")
+        return
 
 
 class RedshiftPauseClusterOperator(BaseOperator):
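For orientation, here is a minimal usage sketch of the reworked operator. The DAG id, task ids, cluster name, schedule, and dates are illustrative assumptions, not part of this diff:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_cluster import RedshiftResumeClusterOperator

with DAG(dag_id="example_redshift_resume", start_date=datetime(2023, 1, 1), schedule=None):
    # Synchronous wait: occupies a worker slot while the cluster_resumed waiter polls.
    resume_and_wait = RedshiftResumeClusterOperator(
        task_id="resume_and_wait",
        cluster_identifier="my-cluster",
        wait_for_completion=True,
        poll_interval=10,
        max_attempts=30,
    )

    # Deferrable: frees the worker slot and polls from the triggerer instead.
    resume_deferrable = RedshiftResumeClusterOperator(
        task_id="resume_deferrable",
        cluster_identifier="my-cluster",
        deferrable=True,
        poll_interval=10,
        max_attempts=30,
    )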
diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
index 06f008d695d19..3224350e5403e 100644
--- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -285,3 +285,75 @@ async def run(self):
             )
         else:
             yield TriggerEvent({"status": "success", "message": "Cluster Snapshot Created"})
+
+
+class RedshiftResumeClusterTrigger(BaseTrigger):
+    """
+    Trigger for RedshiftResumeClusterOperator.
+    The trigger will asynchronously poll the boto3 API and wait for the
+    Redshift cluster to be in the `available` state.
+
+    :param cluster_identifier: A unique identifier for the cluster.
+    :param poll_interval: The amount of time in seconds to wait between attempts.
+    :param max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        cluster_identifier: str,
+        poll_interval: int,
+        max_attempts: int,
+        aws_conn_id: str,
+    ):
+        super().__init__()
+        self.cluster_identifier = cluster_identifier
+        self.poll_interval = poll_interval
+        self.max_attempts = max_attempts
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftResumeClusterTrigger",
+            {
+                "cluster_identifier": self.cluster_identifier,
+                "poll_interval": str(self.poll_interval),
+                "max_attempts": str(self.max_attempts),
+                "aws_conn_id": self.aws_conn_id,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> RedshiftHook:
+        return RedshiftHook(aws_conn_id=self.aws_conn_id)
+
+    async def run(self):
+        async with self.hook.async_conn as client:
+            attempt = 0
+            waiter = self.hook.get_waiter("cluster_resumed", deferrable=True, client=client)
+            while attempt < int(self.max_attempts):
+                attempt = attempt + 1
+                try:
+                    await waiter.wait(
+                        ClusterIdentifier=self.cluster_identifier,
+                        WaiterConfig={
+                            "Delay": int(self.poll_interval),
+                            "MaxAttempts": 1,
+                        },
+                    )
+                    # success, even when it happens on the final attempt
+                    yield TriggerEvent({"status": "success", "message": "Cluster resumed"})
+                    return
+                except WaiterError as error:
+                    if "terminal failure" in str(error):
+                        yield TriggerEvent(
+                            {"status": "failure", "message": f"Resume Cluster Failed: {error}"}
+                        )
+                        return
+                    self.log.info(
+                        "Status of cluster is %s", error.last_response["Clusters"][0]["ClusterStatus"]
+                    )
+                    await asyncio.sleep(int(self.poll_interval))
+            yield TriggerEvent(
+                {"status": "failure", "message": "Resume Cluster Failed - max attempts reached."}
+            )
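The serialize() contract above is what lets the triggerer process rebuild the trigger from its classpath and kwargs. Note that poll_interval and max_attempts are serialized as strings, which is why run() casts them back with int(). A sketch of the round-trip, simplified relative to Airflow's actual serialization machinery:

from importlib import import_module

from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftResumeClusterTrigger

trigger = RedshiftResumeClusterTrigger(
    cluster_identifier="my-cluster",
    poll_interval=10,
    max_attempts=30,
    aws_conn_id="aws_default",
)
classpath, kwargs = trigger.serialize()
module_path, class_name = classpath.rsplit(".", 1)
rebuilt = getattr(import_module(module_path), class_name)(**kwargs)
assert rebuilt.poll_interval == "10"  # serialized as a string, hence the int() casts in run()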
diff --git a/airflow/providers/amazon/aws/waiters/redshift.json b/airflow/providers/amazon/aws/waiters/redshift.json
index 587f8ce989702..8165eb3fc439a 100644
--- a/airflow/providers/amazon/aws/waiters/redshift.json
+++ b/airflow/providers/amazon/aws/waiters/redshift.json
@@ -25,6 +25,30 @@
                     "state": "failure"
                 }
             ]
+        },
+        "cluster_resumed": {
+            "operation": "DescribeClusters",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "pathAll",
+                    "argument": "Clusters[].ClusterStatus",
+                    "expected": "available",
+                    "state": "success"
+                },
+                {
+                    "matcher": "error",
+                    "expected": "ClusterNotFound",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAny",
+                    "argument": "Clusters[].ClusterStatus",
+                    "expected": "deleting",
+                    "state": "failure"
+                }
+            ]
         }
     }
 }
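The new waiter definition can be sanity-checked offline against botocore's waiter model. A sketch; it assumes the provider's installed package layout and that the JSON file carries the usual top-level "version": 2 key:

import json
from pathlib import Path

from botocore.waiter import WaiterModel

import airflow.providers.amazon.aws.waiters as waiters

config = json.loads((Path(waiters.__path__[0]) / "redshift.json").read_text())
model = WaiterModel(config)
assert "cluster_resumed" in model.waiter_names
cluster_resumed = model.get_waiter("cluster_resumed")
assert cluster_resumed.operation == "DescribeClusters"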
aws_conn_id="aws_conn_test" @@ -272,7 +272,7 @@ def test_resume_cluster_is_called_when_cluster_is_paused(self, mock_get_conn): redshift_operator.execute(None) mock_get_conn.return_value.resume_cluster.assert_called_once_with(ClusterIdentifier="test_cluster") - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn") + @mock.patch.object(RedshiftHook, "conn") @mock.patch("time.sleep", return_value=None) def test_resume_cluster_multiple_attempts(self, mock_sleep, mock_conn): exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test") @@ -288,7 +288,7 @@ def test_resume_cluster_multiple_attempts(self, mock_sleep, mock_conn): redshift_operator.execute(None) assert mock_conn.resume_cluster.call_count == 3 - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn") + @mock.patch.object(RedshiftHook, "conn") @mock.patch("time.sleep", return_value=None) def test_resume_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn): exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test") @@ -306,16 +306,10 @@ def test_resume_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn): redshift_operator.execute(None) assert mock_conn.resume_cluster.call_count == 10 - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status") - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.resume_cluster") - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.get_client_async") - def test_resume_cluster(self, mock_async_client, mock_async_resume_cluster, mock_sync_cluster_status): - """Test Resume cluster operator run""" - mock_sync_cluster_status.return_value = "paused" - mock_async_client.return_value.resume_cluster.return_value = { - "Cluster": {"ClusterIdentifier": "test_cluster", "ClusterStatus": "resuming"} - } - mock_async_resume_cluster.return_value = {"status": "success", "cluster_state": "available"} + @mock.patch.object(RedshiftHook, "conn") + def test_resume_cluster_deferrable(self, mock_conn): + """Test Resume cluster operator deferrable""" + mock_conn.resume_cluster.return_value = True redshift_operator = RedshiftResumeClusterOperator( task_id="task_test", @@ -328,22 +322,33 @@ def test_resume_cluster(self, mock_async_client, mock_async_resume_cluster, mock redshift_operator.execute({}) assert isinstance( - exc.value.trigger, RedshiftClusterTrigger - ), "Trigger is not a RedshiftClusterTrigger" + exc.value.trigger, RedshiftResumeClusterTrigger + ), "Trigger is not a RedshiftResumeClusterTrigger" - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status") - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.resume_cluster") - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.get_client_async") - def test_resume_cluster_failure( - self, mock_async_client, mock_async_resume_cluster, mock_sync_cluster_statue - ): - """Test Resume cluster operator Failure""" - mock_sync_cluster_statue.return_value = "paused" - mock_async_client.return_value.resume_cluster.return_value = { - "Cluster": {"ClusterIdentifier": "test_cluster", "ClusterStatus": "resuming"} - } - mock_async_resume_cluster.return_value = {"status": "success", "cluster_state": "available"} + @mock.patch.object(RedshiftHook, "get_waiter") + @mock.patch.object(RedshiftHook, "conn") + def test_resume_cluster_wait_for_completion(self, mock_conn, 
diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
index b79286a093ce8..92379b3ac3f24 100644
--- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
@@ -26,6 +26,7 @@
     RedshiftCreateClusterSnapshotTrigger,
     RedshiftCreateClusterTrigger,
     RedshiftPauseClusterTrigger,
+    RedshiftResumeClusterTrigger,
 )
 from airflow.triggers.base import TriggerEvent
 
@@ -364,3 +365,144 @@ async def test_redshift_create_cluster_snapshot_trigger_run_attempts_failed(
     assert response == TriggerEvent(
         {"status": "failure", "message": f"Create Cluster Snapshot Failed: {error_failed}"}
     )
+
+
+class TestRedshiftResumeClusterTrigger:
+    def test_redshift_resume_cluster_trigger_serialize(self):
+        redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPT,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+        class_path, args = redshift_resume_cluster_trigger.serialize()
+        assert (
+            class_path
+            == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftResumeClusterTrigger"
+        )
+        assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER
+        assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
+        assert args["max_attempts"] == str(TEST_MAX_ATTEMPT)
+        assert args["aws_conn_id"] == TEST_AWS_CONN_ID
+
+    @pytest.mark.asyncio
+    @async_mock.patch.object(RedshiftHook, "get_waiter")
+    @async_mock.patch.object(RedshiftHook, "async_conn")
+    async def test_redshift_resume_cluster_trigger_run(self, mock_async_conn, mock_get_waiter):
+        mock = async_mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = mock
+
+        mock_get_waiter().wait = AsyncMock()
+
+        redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPT,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+
+        generator = redshift_resume_cluster_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "success", "message": "Cluster resumed"})
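The remaining tests below all lean on one pattern: an AsyncMock whose side_effect is a list, so the first awaits raise WaiterError (simulating in-progress polls) and a later await returns, letting the trigger's retry loop be exercised without AWS. A sketch:

from unittest.mock import AsyncMock

from botocore.exceptions import WaiterError

error = WaiterError(
    name="cluster_resumed",
    reason="test_reason",
    last_response={"Clusters": [{"ClusterStatus": "resuming"}]},
)
wait = AsyncMock(side_effect=[error, error, True])  # two failing polls, then success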
+
+    @pytest.mark.asyncio
+    @async_mock.patch("asyncio.sleep")
+    @async_mock.patch.object(RedshiftHook, "get_waiter")
+    @async_mock.patch.object(RedshiftHook, "async_conn")
+    async def test_redshift_resume_cluster_trigger_run_multiple_attempts(
+        self, mock_async_conn, mock_get_waiter, mock_sleep
+    ):
+        mock = async_mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = mock
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"Clusters": [{"ClusterStatus": "available"}]},
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True])
+        mock_sleep.return_value = True
+
+        redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPT,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+
+        generator = redshift_resume_cluster_trigger.run()
+        response = await generator.asend(None)
+
+        assert mock_get_waiter().wait.call_count == 3
+        assert response == TriggerEvent({"status": "success", "message": "Cluster resumed"})
+
+    @pytest.mark.asyncio
+    @async_mock.patch("asyncio.sleep")
+    @async_mock.patch.object(RedshiftHook, "get_waiter")
+    @async_mock.patch.object(RedshiftHook, "async_conn")
+    async def test_redshift_resume_cluster_trigger_run_attempts_exceeded(
+        self, mock_async_conn, mock_get_waiter, mock_sleep
+    ):
+        mock = async_mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = mock
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"Clusters": [{"ClusterStatus": "available"}]},
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True])
+        mock_sleep.return_value = True
+
+        redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=2,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+
+        generator = redshift_resume_cluster_trigger.run()
+        response = await generator.asend(None)
+
+        assert mock_get_waiter().wait.call_count == 2
+        assert response == TriggerEvent(
+            {"status": "failure", "message": "Resume Cluster Failed - max attempts reached."}
+        )
+
+    @pytest.mark.asyncio
+    @async_mock.patch("asyncio.sleep")
+    @async_mock.patch.object(RedshiftHook, "get_waiter")
+    @async_mock.patch.object(RedshiftHook, "async_conn")
+    async def test_redshift_resume_cluster_trigger_run_attempts_failed(
+        self, mock_async_conn, mock_get_waiter, mock_sleep
+    ):
+        mock = async_mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = mock
+        error_available = WaiterError(
+            name="test_name",
+            reason="Max attempts exceeded",
+            last_response={"Clusters": [{"ClusterStatus": "available"}]},
+        )
+        error_failed = WaiterError(
+            name="test_name",
+            reason="Waiter encountered a terminal failure state:",
+            last_response={"Clusters": [{"ClusterStatus": "available"}]},
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(
+            side_effect=[error_available, error_available, error_failed]
+        )
+        mock_sleep.return_value = True
+
+        redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPT,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+
+        generator = redshift_resume_cluster_trigger.run()
+        response = await generator.asend(None)
+
+        assert mock_get_waiter().wait.call_count == 3
+        assert response == TriggerEvent(
+            {"status": "failure", "message": f"Resume Cluster Failed: {error_failed}"}
+        )
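Finally, on the "terminal failure" substring check that the trigger and these tests depend on: botocore renders WaiterError as `Waiter {name} failed: {reason}` and uses the reason `Waiter encountered a terminal failure state: ...` when a failure acceptor matches, so the substring test separates terminal failures from ordinary `Max attempts exceeded` retries. A sketch:

from botocore.exceptions import WaiterError

retryable = WaiterError(name="cluster_resumed", reason="Max attempts exceeded", last_response={})
terminal = WaiterError(
    name="cluster_resumed",
    reason="Waiter encountered a terminal failure state: matched expected path",
    last_response={},
)
assert "terminal failure" not in str(retryable)
assert "terminal failure" in str(terminal)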