diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index abf0e68f39435..5b90ed3cdd756 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -26,6 +26,7 @@
 from airflow.providers.amazon.aws.triggers.redshift_cluster import (
     RedshiftCreateClusterSnapshotTrigger,
     RedshiftCreateClusterTrigger,
+    RedshiftDeleteClusterTrigger,
     RedshiftPauseClusterTrigger,
     RedshiftResumeClusterTrigger,
 )
@@ -629,6 +630,8 @@ class RedshiftDeleteClusterOperator(BaseOperator):
         The default value is ``True``
     :param aws_conn_id: aws connection to use
     :param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
+    :param deferrable: Run operator in deferrable mode.
+    :param max_attempts: (Deferrable mode only) The maximum number of attempts to be made.
     """

     template_fields: Sequence[str] = ("cluster_identifier",)
@@ -643,7 +646,9 @@ def __init__(
         final_cluster_snapshot_identifier: str | None = None,
         wait_for_completion: bool = True,
         aws_conn_id: str = "aws_default",
-        poll_interval: float = 30.0,
+        poll_interval: int = 30,
+        deferrable: bool = False,
+        max_attempts: int = 30,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -658,8 +663,11 @@ def __init__(
         self._attempts = 10
         self._attempt_interval = 15
         self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id)
+        self.aws_conn_id = aws_conn_id
+        self.deferrable = deferrable
+        self.max_attempts = max_attempts

     def execute(self, context: Context):
         while self._attempts >= 1:
             try:
                 self.redshift_hook.delete_cluster(
@@ -676,10 +684,26 @@ def execute(self, context: Context):
                 time.sleep(self._attempt_interval)
             else:
                 raise
-
-        if self.wait_for_completion:
+        if self.deferrable:
+            self.defer(
+                timeout=timedelta(seconds=self.max_attempts * self.poll_interval + 60),
+                trigger=RedshiftDeleteClusterTrigger(
+                    cluster_identifier=self.cluster_identifier,
+                    poll_interval=self.poll_interval,
+                    max_attempts=self.max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
             waiter = self.redshift_hook.get_conn().get_waiter("cluster_deleted")
             waiter.wait(
                 ClusterIdentifier=self.cluster_identifier,
-                WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": 30},
+                WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": self.max_attempts},
             )
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error deleting cluster: {event}")
+        else:
+            self.log.info("Cluster deleted successfully")
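For orientation, a minimal usage sketch of the new deferrable path follows. The DAG id,
schedule, and cluster identifier are illustrative assumptions, not part of this patch:

    # Hypothetical DAG using the new deferrable delete path.
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.redshift_cluster import RedshiftDeleteClusterOperator

    with DAG(dag_id="example_redshift_delete", start_date=datetime(2023, 1, 1), schedule=None) as dag:
        delete_cluster = RedshiftDeleteClusterOperator(
            task_id="delete_cluster",
            cluster_identifier="my-redshift-cluster",  # assumed identifier
            deferrable=True,  # polling moves off the worker onto the triggerer
            max_attempts=30,  # with the default poll_interval=30 this bounds waiting to ~15 minutes
        )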
+ """ + + def __init__( + self, + cluster_identifier: str, + max_attempts: int = 30, + aws_conn_id: str = "aws_default", + poll_interval: int = 30, + ): + super().__init__() + self.cluster_identifier = cluster_identifier + self.max_attempts = max_attempts + self.aws_conn_id = aws_conn_id + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftDeleteClusterTrigger", + { + "cluster_identifier": self.cluster_identifier, + "max_attempts": self.max_attempts, + "aws_conn_id": self.aws_conn_id, + "poll_interval": self.poll_interval, + }, + ) + + @cached_property + def hook(self): + return RedshiftHook(aws_conn_id=self.aws_conn_id) + + async def run(self) -> AsyncIterator[TriggerEvent]: + async with self.hook.async_conn as client: + attempt = 0 + waiter = client.get_waiter("cluster_deleted") + while attempt < self.max_attempts: + attempt = attempt + 1 + try: + await waiter.wait( + ClusterIdentifier=self.cluster_identifier, + WaiterConfig={ + "Delay": self.poll_interval, + "MaxAttempts": 1, + }, + ) + break + except WaiterError as error: + if "terminal failure" in str(error): + yield TriggerEvent( + {"status": "failure", "message": f"Delete Cluster Failed: {error}"} + ) + break + self.log.info( + "Cluster status is %s. Retrying attempt %s/%s", + error.last_response["Clusters"][0]["ClusterStatus"], + attempt, + self.max_attempts, + ) + await asyncio.sleep(int(self.poll_interval)) + + if attempt >= self.max_attempts: + yield TriggerEvent( + {"status": "failure", "message": "Delete Cluster Failed - max attempts reached."} + ) + else: + yield TriggerEvent({"status": "success", "message": "Cluster deleted."}) diff --git a/docs/apache-airflow-providers-amazon/operators/redshift/redshift_cluster.rst b/docs/apache-airflow-providers-amazon/operators/redshift/redshift_cluster.rst index 7c5d22486b39f..a50e56a760bbd 100644 --- a/docs/apache-airflow-providers-amazon/operators/redshift/redshift_cluster.rst +++ b/docs/apache-airflow-providers-amazon/operators/redshift/redshift_cluster.rst @@ -53,7 +53,8 @@ Resume an Amazon Redshift cluster To resume a 'paused' Amazon Redshift cluster you can use :class:`RedshiftResumeClusterOperator ` -You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True`` +You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``. +This will ensure that the task is deferred from the Airflow worker slot and polling for the task status happens on the trigger. .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_redshift.py :language: python @@ -110,7 +111,8 @@ Delete an Amazon Redshift cluster ================================= To delete an Amazon Redshift cluster you can use -:class:`RedshiftDeleteClusterOperator ` +:class:`RedshiftDeleteClusterOperator `. +You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True`` .. 
diff --git a/docs/apache-airflow-providers-amazon/operators/redshift/redshift_cluster.rst b/docs/apache-airflow-providers-amazon/operators/redshift/redshift_cluster.rst
index 7c5d22486b39f..a50e56a760bbd 100644
--- a/docs/apache-airflow-providers-amazon/operators/redshift/redshift_cluster.rst
+++ b/docs/apache-airflow-providers-amazon/operators/redshift/redshift_cluster.rst
@@ -53,7 +53,8 @@ Resume an Amazon Redshift cluster
 
 To resume a 'paused' Amazon Redshift cluster you can use
 :class:`RedshiftResumeClusterOperator <airflow.providers.amazon.aws.operators.redshift_cluster.RedshiftResumeClusterOperator>`
-You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``
+You can also run this operator in deferrable mode by setting the ``deferrable`` param to ``True``.
+The task then frees up the Airflow worker slot, and polling for the task status happens on the triggerer.
 
 .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_redshift.py
    :language: python
@@ -110,7 +111,8 @@ Delete an Amazon Redshift cluster
 =================================
 
 To delete an Amazon Redshift cluster you can use
-:class:`RedshiftDeleteClusterOperator <airflow.providers.amazon.aws.operators.redshift_cluster.RedshiftDeleteClusterOperator>`
+:class:`RedshiftDeleteClusterOperator <airflow.providers.amazon.aws.operators.redshift_cluster.RedshiftDeleteClusterOperator>`.
+You can also run this operator in deferrable mode by setting the ``deferrable`` param to ``True``.
 
 .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_redshift.py
    :language: python
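The operator tests that follow assert the deferral handshake directly: defer() raises
TaskDeferred, and the exception object carries the trigger and the resume method name. A
minimal sketch of what they check, reusing the hypothetical delete_cluster task from the
first sketch above:

    from airflow.exceptions import TaskDeferred

    try:
        delete_cluster.execute(context=None)
    except TaskDeferred as deferred:
        assert type(deferred.trigger).__name__ == "RedshiftDeleteClusterTrigger"
        assert deferred.method_name == "execute_complete"
        # deferred.timeout is the timedelta computed from max_attempts and poll_interval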
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftDeleteClusterTrigger" + ) + assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER + assert args["poll_interval"] == TEST_POLL_INTERVAL + assert args["max_attempts"] == TEST_MAX_ATTEMPT + assert args["aws_conn_id"] == TEST_AWS_CONN_ID + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") + async def test_redshift_delete_cluster_trigger_run(self, mock_async_conn): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + a_mock.get_waiter().wait = AsyncMock() + + redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ) + + generator = redshift_delete_cluster_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "message": "Cluster deleted."}) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(RedshiftHook, "async_conn") + async def test_redshift_delete_cluster_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Clusters": [{"ClusterStatus": "available"}]}, + ) + a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ) + + generator = redshift_delete_cluster_trigger.run() + response = await generator.asend(None) + + assert a_mock.get_waiter().wait.call_count == 3 + assert response == TriggerEvent({"status": "success", "message": "Cluster deleted."}) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(RedshiftHook, "async_conn") + async def test_redshift_delete_cluster_trigger_run_attempts_exceeded(self, mock_async_conn, mock_sleep): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Clusters": [{"ClusterStatus": "deleting"}]}, + ) + a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=2, + aws_conn_id=TEST_AWS_CONN_ID, + ) + + generator = redshift_delete_cluster_trigger.run() + response = await generator.asend(None) + + assert a_mock.get_waiter().wait.call_count == 2 + assert response == TriggerEvent( + {"status": "failure", "message": "Delete Cluster Failed - max attempts reached."} + ) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(RedshiftHook, "async_conn") + async def test_redshift_delete_cluster_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + + error_available = WaiterError( + name="test_name", + reason="Max attempts exceeded", + last_response={"Clusters": [{"ClusterStatus": "deleting"}]}, + ) + error_failed = WaiterError( + name="test_name", + reason="Waiter 
diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
index 3835aae41af37..43a2ee0ae1847 100644
--- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
@@ -26,6 +26,7 @@
 from airflow.providers.amazon.aws.triggers.redshift_cluster import (
     RedshiftCreateClusterSnapshotTrigger,
     RedshiftCreateClusterTrigger,
+    RedshiftDeleteClusterTrigger,
     RedshiftPauseClusterTrigger,
     RedshiftResumeClusterTrigger,
 )
@@ -500,3 +501,135 @@ async def test_redshift_resume_cluster_trigger_run_attempts_failed(
     assert response == TriggerEvent(
         {"status": "failure", "message": f"Resume Cluster Failed: {error_failed}"}
     )
+
+
+class TestRedshiftDeleteClusterTrigger:
+    def test_redshift_delete_cluster_trigger_serialize(self):
+        redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPT,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+        class_path, args = redshift_delete_cluster_trigger.serialize()
+        assert (
+            class_path
+            == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftDeleteClusterTrigger"
+        )
+        assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER
+        assert args["poll_interval"] == TEST_POLL_INTERVAL
+        assert args["max_attempts"] == TEST_MAX_ATTEMPT
+        assert args["aws_conn_id"] == TEST_AWS_CONN_ID
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn")
+    async def test_redshift_delete_cluster_trigger_run(self, mock_async_conn):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+        a_mock.get_waiter().wait = AsyncMock()
+
+        redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPT,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+
+        generator = redshift_delete_cluster_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "success", "message": "Cluster deleted."})
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(RedshiftHook, "async_conn")
+    async def test_redshift_delete_cluster_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"Clusters": [{"ClusterStatus": "available"}]},
+        )
+        a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True])
+        mock_sleep.return_value = True
+
+        redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPT,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+
+        generator = redshift_delete_cluster_trigger.run()
+        response = await generator.asend(None)
+
+        assert a_mock.get_waiter().wait.call_count == 3
+        assert response == TriggerEvent({"status": "success", "message": "Cluster deleted."})
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(RedshiftHook, "async_conn")
+    async def test_redshift_delete_cluster_trigger_run_attempts_exceeded(self, mock_async_conn, mock_sleep):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"Clusters": [{"ClusterStatus": "deleting"}]},
+        )
+        a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True])
+        mock_sleep.return_value = True
+
+        redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=2,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+
+        generator = redshift_delete_cluster_trigger.run()
+        response = await generator.asend(None)
+
+        assert a_mock.get_waiter().wait.call_count == 2
+        assert response == TriggerEvent(
+            {"status": "failure", "message": "Delete Cluster Failed - max attempts reached."}
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(RedshiftHook, "async_conn")
+    async def test_redshift_delete_cluster_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        error_available = WaiterError(
+            name="test_name",
+            reason="Max attempts exceeded",
+            last_response={"Clusters": [{"ClusterStatus": "deleting"}]},
+        )
+        error_failed = WaiterError(
+            name="test_name",
+            reason="Waiter encountered a terminal failure state:",
+            last_response={"Clusters": [{"ClusterStatus": "available"}]},
+        )
+        a_mock.get_waiter().wait.side_effect = AsyncMock(
+            side_effect=[error_available, error_available, error_failed]
+        )
+        mock_sleep.return_value = True
+
+        redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPT,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+
+        generator = redshift_delete_cluster_trigger.run()
+        response = await generator.asend(None)
+
+        assert a_mock.get_waiter().wait.call_count == 3
+        assert response == TriggerEvent(
+            {"status": "failure", "message": f"Delete Cluster Failed: {error_failed}"}
+        )
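Note that these tests rely on pytest-asyncio (via the pytest.mark.asyncio marker) and drive
the trigger by advancing its async generator manually rather than running a triggerer.
Roughly:

    import asyncio

    async def first_event(trigger):
        generator = trigger.run()
        return await generator.asend(None)  # advance to the first TriggerEvent

    # response = asyncio.run(first_event(redshift_delete_cluster_trigger))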