diff --git a/airflow/providers/amazon/aws/hooks/rds.py b/airflow/providers/amazon/aws/hooks/rds.py index 16c72c1b26cb2..fdaee7f742ab4 100644 --- a/airflow/providers/amazon/aws/hooks/rds.py +++ b/airflow/providers/amazon/aws/hooks/rds.py @@ -23,6 +23,7 @@ from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook +from airflow.providers.amazon.aws.utils.waiter_with_logging import wait if TYPE_CHECKING: from mypy_boto3_rds import RDSClient # noqa @@ -269,9 +270,14 @@ def poke(): target_state = target_state.lower() if target_state in ("available", "deleted"): waiter = self.conn.get_waiter(f"db_instance_{target_state}") # type: ignore - waiter.wait( - DBInstanceIdentifier=db_instance_id, - WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts}, + wait( + waiter=waiter, + waiter_delay=check_interval, + waiter_max_attempts=max_attempts, + args={"DBInstanceIdentifier": db_instance_id}, + failure_message=f"Rds DB instance failed to reach state {target_state}", + status_message="Rds DB instance state is", + status_args=["DBInstances[0].DBInstanceStatus"], ) else: self._wait_for_state(poke, target_state, check_interval, max_attempts) diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py index 4d9b0889789ae..1e6a740731750 100644 --- a/airflow/providers/amazon/aws/operators/rds.py +++ b/airflow/providers/amazon/aws/operators/rds.py @@ -18,14 +18,18 @@ from __future__ import annotations import json +from datetime import timedelta from typing import TYPE_CHECKING, Sequence from mypy_boto3_rds.type_defs import TagTypeDef +from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.rds import RdsHook +from airflow.providers.amazon.aws.triggers.rds import RdsDbInstanceTrigger from airflow.providers.amazon.aws.utils.rds import RdsDbType from 
airflow.providers.amazon.aws.utils.tags import format_tags +from airflow.providers.amazon.aws.utils.waiter_with_logging import wait if TYPE_CHECKING: from airflow.utils.context import Context @@ -38,8 +42,8 @@ class RdsBaseOperator(BaseOperator): ui_fgcolor = "#ffffff" def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: dict | None = None, **kwargs): - hook_params = hook_params or {} - self.hook = RdsHook(aws_conn_id=aws_conn_id, **hook_params) + self.hook_params = hook_params or {} + self.hook = RdsHook(aws_conn_id=aws_conn_id, **self.hook_params) super().__init__(*args, **kwargs) self._await_interval = 60 # seconds @@ -521,6 +525,11 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator): https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance :param aws_conn_id: The Airflow connection used for AWS credentials. :param wait_for_completion: If True, waits for creation of the DB instance to complete. (default: True) + :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state + :param waiter_max_attempts: The maximum number of attempts to check DB instance state + :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created. + This implies waiting for completion. This mode requires aiobotocore module to be installed. 
+ (default: False) """ template_fields = ("db_instance_identifier", "db_instance_class", "engine", "rds_kwargs") @@ -534,6 +543,9 @@ def __init__( rds_kwargs: dict | None = None, aws_conn_id: str = "aws_default", wait_for_completion: bool = True, + deferrable: bool = False, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, **kwargs, ): super().__init__(aws_conn_id=aws_conn_id, **kwargs) @@ -542,7 +554,11 @@ def __init__( self.db_instance_class = db_instance_class self.engine = engine self.rds_kwargs = rds_kwargs or {} - self.wait_for_completion = wait_for_completion + self.wait_for_completion = False if deferrable else wait_for_completion + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.aws_conn_id = aws_conn_id def execute(self, context: Context) -> str: self.log.info("Creating new DB instance %s", self.db_instance_identifier) @@ -553,11 +569,41 @@ def execute(self, context: Context) -> str: Engine=self.engine, **self.rds_kwargs, ) + if self.deferrable: + self.defer( + trigger=RdsDbInstanceTrigger( + db_instance_identifier=self.db_instance_identifier, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + hook_params=self.hook_params, + waiter_name="db_instance_available", + # ignoring type because create_db_instance is a dict + response=create_db_instance, # type: ignore[arg-type] + ), + method_name="execute_complete", + timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts), + ) if self.wait_for_completion: - self.hook.wait_for_db_instance_state(self.db_instance_identifier, target_state="available") + waiter = self.hook.conn.get_waiter("db_instance_available") + wait( + waiter=waiter, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + args={"DBInstanceIdentifier": self.db_instance_identifier}, + failure_message="DB instance creation failed", + status_message="DB Instance 
status is", + status_args=["DBInstances[0].DBInstanceStatus"], + ) return json.dumps(create_db_instance, default=str) + def execute_complete(self, context, event=None) -> str: + if event is None or event["status"] != "success":  # a missing event counts as a failure + raise AirflowException(f"DB instance creation failed: {event}") + else: + return json.dumps(event["response"], default=str) + class RdsDeleteDbInstanceOperator(RdsBaseOperator): """ @@ -572,6 +618,11 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator): https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance :param aws_conn_id: The Airflow connection used for AWS credentials. :param wait_for_completion: If True, waits for deletion of the DB instance to complete. (default: True) + :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state + :param waiter_max_attempts: The maximum number of attempts to check DB instance state + :param deferrable: If True, the operator will wait asynchronously for the DB instance to be deleted. + This implies waiting for completion. This mode requires aiobotocore module to be installed. 
+ (default: False) """ template_fields = ("db_instance_identifier", "rds_kwargs") @@ -583,12 +634,19 @@ def __init__( rds_kwargs: dict | None = None, aws_conn_id: str = "aws_default", wait_for_completion: bool = True, + deferrable: bool = False, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, **kwargs, ): super().__init__(aws_conn_id=aws_conn_id, **kwargs) self.db_instance_identifier = db_instance_identifier self.rds_kwargs = rds_kwargs or {} - self.wait_for_completion = wait_for_completion + self.wait_for_completion = False if deferrable else wait_for_completion + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.aws_conn_id = aws_conn_id def execute(self, context: Context) -> str: self.log.info("Deleting DB instance %s", self.db_instance_identifier) @@ -597,11 +655,41 @@ def execute(self, context: Context) -> str: DBInstanceIdentifier=self.db_instance_identifier, **self.rds_kwargs, ) + if self.deferrable: + self.defer( + trigger=RdsDbInstanceTrigger( + db_instance_identifier=self.db_instance_identifier, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + hook_params=self.hook_params, + waiter_name="db_instance_deleted", + # ignoring type because delete_db_instance is a dict + response=delete_db_instance, # type: ignore[arg-type] + ), + method_name="execute_complete", + timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts), + ) if self.wait_for_completion: - self.hook.wait_for_db_instance_state(self.db_instance_identifier, target_state="deleted") + waiter = self.hook.conn.get_waiter("db_instance_deleted") + wait( + waiter=waiter, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + args={"DBInstanceIdentifier": self.db_instance_identifier}, + failure_message="DB instance deletion failed", + status_message="DB Instance status is", + 
status_args=["DBInstances[0].DBInstanceStatus"], + ) return json.dumps(delete_db_instance, default=str) + def execute_complete(self, context, event=None) -> str: + if event is None or event["status"] != "success":  # a missing event counts as a failure + raise AirflowException(f"DB instance deletion failed: {event}") + else: + return json.dumps(event["response"], default=str) + class RdsStartDbOperator(RdsBaseOperator): """ diff --git a/airflow/providers/amazon/aws/triggers/rds.py b/airflow/providers/amazon/aws/triggers/rds.py new file mode 100644 index 0000000000000..0897f764bece9 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/rds.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any + +from airflow.providers.amazon.aws.hooks.rds import RdsHook +from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class RdsDbInstanceTrigger(BaseTrigger): + """ + Trigger for RdsCreateDbInstanceOperator and RdsDeleteDbInstanceOperator. + + The trigger will asynchronously poll the boto3 API and wait for the + DB instance to be in the state specified by the waiter. 
+ + :param waiter_name: Name of the waiter to use, for instance 'db_instance_available' + or 'db_instance_deleted'. + :param db_instance_identifier: The DB instance identifier for the DB instance to be polled. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param hook_params: The parameters to pass to the RdsHook. + :param response: The response from the RdsHook, to be passed back to the operator. + """ + + def __init__( + self, + waiter_name: str, + db_instance_identifier: str, + waiter_delay: int, + waiter_max_attempts: int, + aws_conn_id: str, + hook_params: dict[str, Any], + response: dict[str, Any], + ): + self.db_instance_identifier = db_instance_identifier + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.aws_conn_id = aws_conn_id + self.hook_params = hook_params + self.waiter_name = waiter_name + self.response = response + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + # dynamically generate the fully qualified name of the class + self.__class__.__module__ + "." 
+ self.__class__.__qualname__, + { + "db_instance_identifier": self.db_instance_identifier, + "waiter_delay": str(self.waiter_delay), + "waiter_max_attempts": str(self.waiter_max_attempts), + "aws_conn_id": self.aws_conn_id, + "hook_params": self.hook_params, + "waiter_name": self.waiter_name, + "response": self.response, + }, + ) + + async def run(self): + self.hook = RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params) + async with self.hook.async_conn as client: + waiter = client.get_waiter(self.waiter_name) + await async_wait( + waiter=waiter, + waiter_delay=int(self.waiter_delay), + waiter_max_attempts=int(self.waiter_max_attempts), + args={"DBInstanceIdentifier": self.db_instance_identifier}, + failure_message="Error checking DB Instance status", + status_message="DB instance status is", + status_args=["DBInstances[0].DBInstanceStatus"], + ) + yield TriggerEvent({"status": "success", "response": self.response}) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 5439f9c8cbb76..fb84ee4d36177 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -543,6 +543,9 @@ triggers: - integration-name: Amazon ECS python-modules: - airflow.providers.amazon.aws.triggers.ecs + - integration-name: Amazon RDS + python-modules: + - airflow.providers.amazon.aws.triggers.rds transfers: - source-integration-name: Amazon DynamoDB diff --git a/docs/apache-airflow-providers-amazon/operators/rds.rst b/docs/apache-airflow-providers-amazon/operators/rds.rst index bca9c64af1e0e..e27bbc2d2f699 100644 --- a/docs/apache-airflow-providers-amazon/operators/rds.rst +++ b/docs/apache-airflow-providers-amazon/operators/rds.rst @@ -145,6 +145,7 @@ Create a database instance To create a AWS DB instance you can use :class:`~airflow.providers.amazon.aws.operators.rds.RdsCreateDbInstanceOperator`. +You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``. .. 
exampleinclude:: /../../tests/system/providers/amazon/aws/example_rds_instance.py :language: python @@ -159,6 +160,7 @@ Delete a database instance To delete a AWS DB instance you can use :class:`~airflow.providers.amazon.aws.operators.rds.RDSDeleteDbInstanceOperator`. +You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``. .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_rds_instance.py :language: python diff --git a/tests/providers/amazon/aws/hooks/test_rds.py b/tests/providers/amazon/aws/hooks/test_rds.py index f98320a2f1eb2..a34e43c9443b6 100644 --- a/tests/providers/amazon/aws/hooks/test_rds.py +++ b/tests/providers/amazon/aws/hooks/test_rds.py @@ -153,7 +153,6 @@ def test_wait_for_db_instance_state_boto_waiters(self, rds_hook: RdsHook, db_ins mock.return_value.wait.assert_called_once_with( DBInstanceIdentifier=db_instance_id, WaiterConfig={ - "Delay": self.waiter_args["check_interval"], "MaxAttempts": self.waiter_args["max_attempts"], }, ) diff --git a/tests/providers/amazon/aws/triggers/test_rds.py b/tests/providers/amazon/aws/triggers/test_rds.py new file mode 100644 index 0000000000000..5ae64b83c3bde --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_rds.py @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import AsyncMock + +import pytest +from botocore.exceptions import WaiterError + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.rds import RdsHook +from airflow.providers.amazon.aws.triggers.rds import RdsDbInstanceTrigger +from airflow.triggers.base import TriggerEvent + +TEST_DB_INSTANCE_IDENTIFIER = "test-db-instance-identifier" +TEST_WAITER_DELAY = 10 +TEST_WAITER_MAX_ATTEMPTS = 10 +TEST_AWS_CONN_ID = "test-aws-id" +TEST_RESPONSE = { + "DBInstance": { + "DBInstanceIdentifier": "test-db-instance-identifier", + "DBInstanceStatus": "test-db-instance-status", + } +} + + +class TestRdsDbInstanceTrigger: + def test_rds_db_instance_trigger_serialize(self): + rds_db_instance_trigger = RdsDbInstanceTrigger( + waiter_name="test-waiter", + db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, + aws_conn_id=TEST_AWS_CONN_ID, + hook_params={}, + response=TEST_RESPONSE, + ) + class_path, args = rds_db_instance_trigger.serialize() + + assert class_path == "airflow.providers.amazon.aws.triggers.rds.RdsDbInstanceTrigger" + assert args["waiter_name"] == "test-waiter" + assert args["db_instance_identifier"] == TEST_DB_INSTANCE_IDENTIFIER + assert args["waiter_delay"] == str(TEST_WAITER_DELAY) + assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS) + assert args["aws_conn_id"] == TEST_AWS_CONN_ID + assert args["hook_params"] == {} + assert args["response"] == TEST_RESPONSE + + @pytest.mark.asyncio + @mock.patch.object(RdsHook, "async_conn") + async def test_rds_db_instance_trigger_run(self, mock_async_conn): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + + a_mock.get_waiter().wait = AsyncMock() + + rds_db_instance_trigger = 
RdsDbInstanceTrigger( + waiter_name="test-waiter", + db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, + aws_conn_id=TEST_AWS_CONN_ID, + hook_params={}, + response=TEST_RESPONSE, + ) + + generator = rds_db_instance_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "response": TEST_RESPONSE}) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(RdsHook, "async_conn") + async def test_rds_db_instance_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep): + mock_sleep.return_value = True + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]}, + ) + a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, True]) + + rds_db_instance_trigger = RdsDbInstanceTrigger( + waiter_name="test-waiter", + db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, + aws_conn_id=TEST_AWS_CONN_ID, + hook_params={}, + response=TEST_RESPONSE, + ) + + generator = rds_db_instance_trigger.run() + response = await generator.asend(None) + assert a_mock.get_waiter().wait.call_count == 4 + + assert response == TriggerEvent({"status": "success", "response": TEST_RESPONSE}) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(RdsHook, "async_conn") + async def test_rds_db_instance_trigger_run_attempts_exceeded(self, mock_async_conn, mock_sleep): + mock_sleep.return_value = True + + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]}, + ) + a_mock.get_waiter().wait = 
AsyncMock(side_effect=[error, error, error, True]) + + rds_db_instance_trigger = RdsDbInstanceTrigger( + waiter_name="test-waiter", + db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=2, + aws_conn_id=TEST_AWS_CONN_ID, + hook_params={}, + response=TEST_RESPONSE, + ) + + with pytest.raises(AirflowException) as exc: + generator = rds_db_instance_trigger.run() + await generator.asend(None) + + assert "Waiter error: max attempts reached" in str(exc.value) + assert a_mock.get_waiter().wait.call_count == 2 + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(RdsHook, "async_conn") + async def test_rds_db_instance_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + + error_creating = WaiterError( + name="test_name", + reason="test_reason", + last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]}, + ) + + error_failed = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"DBInstances": [{"DBInstanceStatus": "FAILED"}]}, + ) + a_mock.get_waiter().wait = AsyncMock(side_effect=[error_creating, error_creating, error_failed]) + mock_sleep.return_value = True + + rds_db_instance_trigger = RdsDbInstanceTrigger( + waiter_name="test-waiter", + db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, + aws_conn_id=TEST_AWS_CONN_ID, + hook_params={}, + response=TEST_RESPONSE, + ) + + with pytest.raises(AirflowException) as exc: + generator = rds_db_instance_trigger.run() + await generator.asend(None) + assert "Error checking DB Instance status" in str(exc.value) + assert a_mock.get_waiter().wait.call_count == 3