Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 26 additions & 35 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftCreateClusterTrigger,
RedshiftPauseClusterTrigger,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -506,7 +507,9 @@ class RedshiftPauseClusterOperator(BaseOperator):

:param cluster_identifier: id of the AWS Redshift Cluster
:param aws_conn_id: aws connection to use
:param deferrable: Run operator in the deferrable mode. This mode requires an additional aiobotocore>=
:param deferrable: Run operator in the deferrable mode
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
:param max_attempts: Maximum number of attempts to poll the cluster
"""

template_fields: Sequence[str] = ("cluster_identifier",)
Expand All @@ -520,64 +523,52 @@ def __init__(
aws_conn_id: str = "aws_default",
deferrable: bool = False,
poll_interval: int = 10,
max_attempts: int = 15,
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
self.deferrable = deferrable
self.max_attempts = max_attempts
self.poll_interval = poll_interval
# These parameters are added to address an issue with the boto3 API where the API
# These parameters are used to address an issue with the boto3 API where the API
# prematurely reports the cluster as available to receive requests. This causes the cluster
# to reject initial attempts to pause the cluster despite reporting the correct state.
self._attempts = 10
self._attempt_interval = 15

def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
while self._attempts >= 1:
Comment thread
syedahsn marked this conversation as resolved.
Outdated
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
break
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error
if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=RedshiftClusterTrigger(
task_id=self.task_id,
trigger=RedshiftPauseClusterTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempts=self.max_attempts,
Comment thread
syedahsn marked this conversation as resolved.
aws_conn_id=self.aws_conn_id,
cluster_identifier=self.cluster_identifier,
attempts=self._attempts,
operation_type="pause_cluster",
),
method_name="execute_complete",
)
else:
while self._attempts >= 1:
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
return
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error

def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if "status" in event and event["status"] == "error":
msg = f"{event['status']}: {event['message']}"
raise AirflowException(msg)
elif "status" in event and event["status"] == "success":
self.log.info("%s completed successfully.", self.task_id)
self.log.info("Paused cluster successfully")
def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error pausing cluster: {event}")
else:
raise AirflowException("No event received from trigger")
self.log.info("Paused cluster successfully")
return


class RedshiftDeleteClusterOperator(BaseOperator):
Expand Down
70 changes: 70 additions & 0 deletions airflow/providers/amazon/aws/triggers/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
# under the License.
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator

from botocore.exceptions import WaiterError

from airflow.compat.functools import cached_property
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand Down Expand Up @@ -137,3 +140,70 @@ async def run(self):
},
)
yield TriggerEvent({"status": "success", "message": "Cluster Created"})


class RedshiftPauseClusterTrigger(BaseTrigger):
"""
Trigger for RedshiftPauseClusterOperator.
The trigger will asynchronously poll the boto3 API and wait for the
Redshift cluster to be in the `paused` state.

:param cluster_identifier: A unique identifier for the cluster.
:param poll_interval: The amount of time in seconds to wait between attempts.
:param max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(
self,
cluster_identifier: str,
poll_interval: int,
max_attempts: int,
aws_conn_id: str,
):
self.cluster_identifier = cluster_identifier
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.aws_conn_id = aws_conn_id

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger",
{
"cluster_identifier": str(self.cluster_identifier),
Comment thread
syedahsn marked this conversation as resolved.
Outdated
"poll_interval": str(self.poll_interval),
"max_attempts": str(self.max_attempts),
"aws_conn_id": str(self.aws_conn_id),
},
)

@cached_property
def hook(self) -> RedshiftHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)

async def run(self):
async with self.hook.async_conn as client:
attempt = 0
while attempt < int(self.max_attempts):
attempt = attempt + 1
try:
waiter = self.hook.get_waiter("cluster_paused", deferrable=True, client=client)
Comment thread
syedahsn marked this conversation as resolved.
Outdated
await waiter.wait(
ClusterIdentifier=self.cluster_identifier,
WaiterConfig={
"Delay": int(self.poll_interval),
"MaxAttempts": 1,
Comment thread
syedahsn marked this conversation as resolved.
},
)
break
except WaiterError as error:
self.log.info(
"Status of cluster is %s", error.last_response["Clusters"][0]["ClusterStatus"]
)
await asyncio.sleep(int(self.poll_interval))
Comment thread
syedahsn marked this conversation as resolved.
if attempt >= int(self.max_attempts):
Comment thread
syedahsn marked this conversation as resolved.
yield TriggerEvent(
{"status": "failure", "message": "Resume Cluster Failed - max attempts reached."}
Comment thread
syedahsn marked this conversation as resolved.
Outdated
)
else:
Comment thread
syedahsn marked this conversation as resolved.
yield TriggerEvent({"status": "success", "message": "Cluster paused"})
30 changes: 30 additions & 0 deletions airflow/providers/amazon/aws/waiters/redshift.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"version": 2,
"waiters": {
"cluster_paused": {
"operation": "DescribeClusters",
"delay": 30,
"maxAttempts": 60,
"acceptors": [
{
"matcher": "pathAll",
"argument": "Clusters[].ClusterStatus",
"expected": "paused",
"state": "success"
},
{
"expected": "ClusterNotFound",
"argument": "Clusters[].ClusterStatus",
"matcher": "error",
"state": "retry"
},
{
"expected": "deleting",
"matcher": "pathAny",
"state": "failure",
"argument": "Clusters[].ClusterStatus"
Comment thread
syedahsn marked this conversation as resolved.
Outdated
}
]
}
}
}
14 changes: 9 additions & 5 deletions tests/providers/amazon/aws/operators/test_redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
RedshiftPauseClusterOperator,
RedshiftResumeClusterOperator,
)
from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftPauseClusterTrigger,
)


class TestRedshiftCreateClusterOperator:
Expand Down Expand Up @@ -377,9 +380,10 @@ def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
redshift_operator.execute(None)
assert mock_conn.pause_cluster.call_count == 10

def test_pause_cluster_deferrable_mode(self):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
Comment thread
syedahsn marked this conversation as resolved.
Outdated
def test_pause_cluster_deferrable_mode(self, mock_get_conn):
"""Test Pause cluster operator with defer when deferrable param is true"""

mock_get_conn().pause_cluster.return_value = True
redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", deferrable=True
)
Expand All @@ -388,8 +392,8 @@ def test_pause_cluster_deferrable_mode(self):
redshift_operator.execute(context=None)

assert isinstance(
exc.value.trigger, RedshiftClusterTrigger
), "Trigger is not a RedshiftClusterTrigger"
exc.value.trigger, RedshiftPauseClusterTrigger
), "Trigger is not a RedshiftPauseClusterTrigger"

def test_pause_cluster_execute_complete_success(self):
"""Asserts that logging occurs as expected"""
Expand Down
107 changes: 106 additions & 1 deletion tests/providers/amazon/aws/triggers/test_redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
import sys

import pytest
from botocore.exceptions import WaiterError

from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftCreateClusterTrigger
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftCreateClusterTrigger,
RedshiftPauseClusterTrigger,
)
from airflow.triggers.base import TriggerEvent

if sys.version_info < (3, 8):
Expand Down Expand Up @@ -72,3 +76,104 @@ async def test_redshift_create_cluster_trigger_run(self, mock_async_conn):
response = await generator.asend(None)

assert response == TriggerEvent({"status": "success", "message": "Cluster Created"})


class TestRedshiftPauseClusterTrigger:
def test_redshift_resume_cluster_trigger_serialize(self):
redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger(
cluster_identifier=TEST_CLUSTER_IDENTIFIER,
poll_interval=TEST_POLL_INTERVAL,
max_attempts=TEST_MAX_ATTEMPT,
aws_conn_id=TEST_AWS_CONN_ID,
)
class_path, args = redshift_resume_cluster_trigger.serialize()
assert (
class_path == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger"
)
assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER
assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
assert args["max_attempts"] == str(TEST_MAX_ATTEMPT)
assert args["aws_conn_id"] == TEST_AWS_CONN_ID

@pytest.mark.asyncio
@async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_waiter")
@async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn")
Comment thread
syedahsn marked this conversation as resolved.
Outdated
async def test_redshift_resume_cluster_trigger_run(self, mock_async_conn, mock_get_waiter):
mock = async_mock.MagicMock()
mock_async_conn.__aenter__.return_value = mock

mock_get_waiter().wait = AsyncMock()

redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger(
cluster_identifier=TEST_CLUSTER_IDENTIFIER,
poll_interval=TEST_POLL_INTERVAL,
max_attempts=TEST_MAX_ATTEMPT,
aws_conn_id=TEST_AWS_CONN_ID,
)

generator = redshift_resume_cluster_trigger.run()
response = await generator.asend(None)

assert response == TriggerEvent({"status": "success", "message": "Cluster paused"})

@pytest.mark.asyncio
@async_mock.patch("asyncio.sleep")
@async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_waiter")
@async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn")
async def test_redshift_resume_cluster_trigger_run_multiple_attempts(
self, mock_async_conn, mock_get_waiter, mock_sleep
):
mock = async_mock.MagicMock()
mock_async_conn.__aenter__.return_value = mock
error = WaiterError(
name="test_name",
reason="test_reason",
last_response={"Clusters": [{"ClusterStatus": "available"}]},
)
mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True])
mock_sleep.return_value = True

redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger(
cluster_identifier=TEST_CLUSTER_IDENTIFIER,
poll_interval=TEST_POLL_INTERVAL,
max_attempts=TEST_MAX_ATTEMPT,
aws_conn_id=TEST_AWS_CONN_ID,
)

generator = redshift_resume_cluster_trigger.run()
response = await generator.asend(None)

assert mock_get_waiter().wait.call_count == 3
assert response == TriggerEvent({"status": "success", "message": "Cluster paused"})

@pytest.mark.asyncio
@async_mock.patch("asyncio.sleep")
@async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_waiter")
@async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn")
async def test_redshift_resume_cluster_trigger_run_attempts_exceeded(
self, mock_async_conn, mock_get_waiter, mock_sleep
):
mock = async_mock.MagicMock()
mock_async_conn.__aenter__.return_value = mock
error = WaiterError(
name="test_name",
reason="test_reason",
last_response={"Clusters": [{"ClusterStatus": "available"}]},
)
mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True])
mock_sleep.return_value = True

redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger(
cluster_identifier=TEST_CLUSTER_IDENTIFIER,
poll_interval=TEST_POLL_INTERVAL,
max_attempts=2,
aws_conn_id=TEST_AWS_CONN_ID,
)

generator = redshift_resume_cluster_trigger.run()
response = await generator.asend(None)

assert mock_get_waiter().wait.call_count == 2
assert response == TriggerEvent(
{"status": "failure", "message": "Resume Cluster Failed - max attempts reached."}
)