Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions airflow/providers/amazon/aws/hooks/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait

if TYPE_CHECKING:
from mypy_boto3_rds import RDSClient # noqa
Expand Down Expand Up @@ -269,9 +270,14 @@ def poke():
target_state = target_state.lower()
if target_state in ("available", "deleted"):
waiter = self.conn.get_waiter(f"db_instance_{target_state}") # type: ignore
waiter.wait(
DBInstanceIdentifier=db_instance_id,
WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts},
wait(
waiter=waiter,
waiter_delay=check_interval,
waiter_max_attempts=max_attempts,
args={"DBInstanceIdentifier": db_instance_id},
failure_message=f"Rdb DB instance failed to reach state {target_state}",
status_message="Rds DB instance state is",
status_args=["DBInstances[0].DBInstanceStatus"],
)
else:
self._wait_for_state(poke, target_state, check_interval, max_attempts)
Expand Down
100 changes: 94 additions & 6 deletions airflow/providers/amazon/aws/operators/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@
from __future__ import annotations

import json
from datetime import timedelta
from typing import TYPE_CHECKING, Sequence

from mypy_boto3_rds.type_defs import TagTypeDef

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.triggers.rds import RdsDbInstanceTrigger
from airflow.providers.amazon.aws.utils.rds import RdsDbType
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand All @@ -38,8 +42,8 @@ class RdsBaseOperator(BaseOperator):
ui_fgcolor = "#ffffff"

def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: dict | None = None, **kwargs):
hook_params = hook_params or {}
self.hook = RdsHook(aws_conn_id=aws_conn_id, **hook_params)
self.hook_params = hook_params or {}
self.hook = RdsHook(aws_conn_id=aws_conn_id, **self.hook_params)
super().__init__(*args, **kwargs)

self._await_interval = 60 # seconds
Expand Down Expand Up @@ -521,6 +525,11 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param wait_for_completion: If True, waits for creation of the DB instance to complete. (default: True)
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
:param waiter_max_attempts: The maximum number of attempts to check DB instance state
:param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

template_fields = ("db_instance_identifier", "db_instance_class", "engine", "rds_kwargs")
Expand All @@ -534,6 +543,9 @@ def __init__(
rds_kwargs: dict | None = None,
aws_conn_id: str = "aws_default",
wait_for_completion: bool = True,
deferrable: bool = False,
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
**kwargs,
):
super().__init__(aws_conn_id=aws_conn_id, **kwargs)
Expand All @@ -542,7 +554,11 @@ def __init__(
self.db_instance_class = db_instance_class
self.engine = engine
self.rds_kwargs = rds_kwargs or {}
self.wait_for_completion = wait_for_completion
self.wait_for_completion = False if deferrable else wait_for_completion
self.deferrable = deferrable
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id

def execute(self, context: Context) -> str:
self.log.info("Creating new DB instance %s", self.db_instance_identifier)
Expand All @@ -553,11 +569,41 @@ def execute(self, context: Context) -> str:
Engine=self.engine,
**self.rds_kwargs,
)
if self.deferrable:
self.defer(
trigger=RdsDbInstanceTrigger(
db_instance_identifier=self.db_instance_identifier,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
hook_params=self.hook_params,
waiter_name="db_instance_available",
# ignoring type because create_db_instance is a dict
response=create_db_instance, # type: ignore[arg-type]
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts),
)

if self.wait_for_completion:
self.hook.wait_for_db_instance_state(self.db_instance_identifier, target_state="available")
waiter = self.hook.conn.get_waiter("db_instance_available")
wait(
waiter=waiter,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
args={"DBInstanceIdentifier": self.db_instance_identifier},
failure_message="DB instance creation failed",
status_message="DB Instance status is",
status_args=["DBInstances[0].DBInstanceStatus"],
)
return json.dumps(create_db_instance, default=str)

def execute_complete(self, context, event=None) -> str:
    """
    Resume the task after the deferred wait for DB instance creation has finished.

    :param context: The task execution context (unused here).
    :param event: The payload emitted by the trigger; expected to carry a ``status`` key
        and, on success, the original ``create_db_instance`` response under ``response``.
    :return: The JSON-encoded ``create_db_instance`` response.
    :raises AirflowException: If no event was received or its status is not ``"success"``.
    """
    # ``event`` defaults to None, so guard before subscripting: a missing or malformed
    # event must surface as a task failure, not as a TypeError/KeyError.
    if not event or event.get("status") != "success":
        raise AirflowException(f"DB instance creation failed: {event}")
    return json.dumps(event["response"], default=str)


class RdsDeleteDbInstanceOperator(RdsBaseOperator):
"""
Expand All @@ -572,6 +618,11 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param wait_for_completion: If True, waits for deletion of the DB instance to complete. (default: True)
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
:param waiter_max_attempts: The maximum number of attempts to check DB instance state
:param deferrable: If True, the operator will wait asynchronously for the DB instance to be deleted.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

template_fields = ("db_instance_identifier", "rds_kwargs")
Expand All @@ -583,12 +634,19 @@ def __init__(
rds_kwargs: dict | None = None,
aws_conn_id: str = "aws_default",
wait_for_completion: bool = True,
deferrable: bool = False,
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
**kwargs,
):
super().__init__(aws_conn_id=aws_conn_id, **kwargs)
self.db_instance_identifier = db_instance_identifier
self.rds_kwargs = rds_kwargs or {}
self.wait_for_completion = wait_for_completion
self.wait_for_completion = False if deferrable else wait_for_completion
self.deferrable = deferrable
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id

def execute(self, context: Context) -> str:
self.log.info("Deleting DB instance %s", self.db_instance_identifier)
Expand All @@ -597,11 +655,41 @@ def execute(self, context: Context) -> str:
DBInstanceIdentifier=self.db_instance_identifier,
**self.rds_kwargs,
)
if self.deferrable:
self.defer(
trigger=RdsDbInstanceTrigger(
db_instance_identifier=self.db_instance_identifier,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
hook_params=self.hook_params,
waiter_name="db_instance_deleted",
# ignoring type because delete_db_instance is a dict
response=delete_db_instance, # type: ignore[arg-type]
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts),
)

if self.wait_for_completion:
self.hook.wait_for_db_instance_state(self.db_instance_identifier, target_state="deleted")
waiter = self.hook.conn.get_waiter("db_instance_deleted")
wait(
waiter=waiter,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
args={"DBInstanceIdentifier": self.db_instance_identifier},
failure_message="DB instance deletion failed",
status_message="DB Instance status is",
status_args=["DBInstances[0].DBInstanceStatus"],
)
return json.dumps(delete_db_instance, default=str)

def execute_complete(self, context, event=None) -> str:
    """
    Resume the task after the deferred wait for DB instance deletion has finished.

    :param context: The task execution context (unused here).
    :param event: The payload emitted by the trigger; expected to carry a ``status`` key
        and, on success, the original ``delete_db_instance`` response under ``response``.
    :return: The JSON-encoded ``delete_db_instance`` response.
    :raises AirflowException: If no event was received or its status is not ``"success"``.
    """
    # ``event`` defaults to None, so guard before subscripting: a missing or malformed
    # event must surface as a task failure, not as a TypeError/KeyError.
    if not event or event.get("status") != "success":
        raise AirflowException(f"DB instance deletion failed: {event}")
    return json.dumps(event["response"], default=str)


class RdsStartDbOperator(RdsBaseOperator):
"""
Expand Down
89 changes: 89 additions & 0 deletions airflow/providers/amazon/aws/triggers/rds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Any

from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent


class RdsDbInstanceTrigger(BaseTrigger):
    """
    Trigger for RdsCreateDbInstanceOperator and RdsDeleteDbInstanceOperator.

    The trigger will asynchronously poll the boto3 API and wait for the
    DB instance to be in the state specified by the waiter.

    :param waiter_name: Name of the waiter to use, for instance 'db_instance_available'
        or 'db_instance_deleted'.
    :param db_instance_identifier: The DB instance identifier for the DB instance to be polled.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param hook_params: The parameters to pass to the RdsHook.
    :param response: The response from the RdsHook, to be passed back to the operator.
    """

    def __init__(
        self,
        waiter_name: str,
        db_instance_identifier: str,
        waiter_delay: int,
        waiter_max_attempts: int,
        aws_conn_id: str,
        hook_params: dict[str, Any],
        response: dict[str, Any],
    ):
        # Let the base trigger initialise its own state before adding ours.
        super().__init__()
        self.db_instance_identifier = db_instance_identifier
        # Coerce to int so the attributes always match their declared type, including
        # when the trigger is rebuilt from a payload that stored them as strings.
        self.waiter_delay = int(waiter_delay)
        self.waiter_max_attempts = int(waiter_max_attempts)
        self.aws_conn_id = aws_conn_id
        self.hook_params = hook_params
        self.waiter_name = waiter_name
        self.response = response

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return the trigger's import path and the keyword arguments needed to rebuild it."""
        return (
            # dynamically generate the fully qualified name of the class
            f"{self.__class__.__module__}.{self.__class__.__qualname__}",
            {
                "db_instance_identifier": self.db_instance_identifier,
                # Kept as ints so a round trip through serialize/__init__ preserves types.
                "waiter_delay": self.waiter_delay,
                "waiter_max_attempts": self.waiter_max_attempts,
                "aws_conn_id": self.aws_conn_id,
                "hook_params": self.hook_params,
                "waiter_name": self.waiter_name,
                "response": self.response,
            },
        )

    async def run(self):
        """
        Wait on the configured waiter, then emit a single success event.

        The event carries the ``response`` given at construction time so the operator
        can return it from ``execute_complete``. No event is emitted if ``async_wait``
        raises; the exception propagates to the caller.
        """
        self.hook = RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params)
        async with self.hook.async_conn as client:
            waiter = client.get_waiter(self.waiter_name)
            await async_wait(
                waiter=waiter,
                waiter_delay=self.waiter_delay,
                waiter_max_attempts=self.waiter_max_attempts,
                args={"DBInstanceIdentifier": self.db_instance_identifier},
                failure_message="Error checking DB Instance status",
                status_message="DB instance status is",
                status_args=["DBInstances[0].DBInstanceStatus"],
            )
        yield TriggerEvent({"status": "success", "response": self.response})
3 changes: 3 additions & 0 deletions airflow/providers/amazon/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,9 @@ triggers:
- integration-name: Amazon ECS
python-modules:
- airflow.providers.amazon.aws.triggers.ecs
- integration-name: Amazon RDS
python-modules:
- airflow.providers.amazon.aws.triggers.rds

transfers:
- source-integration-name: Amazon DynamoDB
Expand Down
2 changes: 2 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/rds.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ Create a database instance

To create an AWS DB instance you can use
:class:`~airflow.providers.amazon.aws.operators.rds.RdsCreateDbInstanceOperator`.
You can also run this operator in deferrable mode by setting the ``deferrable`` parameter to ``True``.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_rds_instance.py
:language: python
Expand All @@ -159,6 +160,7 @@ Delete a database instance

To delete an AWS DB instance you can use
:class:`~airflow.providers.amazon.aws.operators.rds.RdsDeleteDbInstanceOperator`.
You can also run this operator in deferrable mode by setting the ``deferrable`` parameter to ``True``.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_rds_instance.py
:language: python
Expand Down
1 change: 0 additions & 1 deletion tests/providers/amazon/aws/hooks/test_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def test_wait_for_db_instance_state_boto_waiters(self, rds_hook: RdsHook, db_ins
mock.return_value.wait.assert_called_once_with(
DBInstanceIdentifier=db_instance_id,
WaiterConfig={
"Delay": self.waiter_args["check_interval"],
"MaxAttempts": self.waiter_args["max_attempts"],
},
)
Expand Down
Loading