84 changes: 46 additions & 38 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -24,10 +24,10 @@
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftCreateClusterSnapshotTrigger,
RedshiftCreateClusterTrigger,
RedshiftPauseClusterTrigger,
RedshiftResumeClusterTrigger,
)

if TYPE_CHECKING:
@@ -452,8 +452,11 @@ class RedshiftResumeClusterOperator(BaseOperator):
:param cluster_identifier: Unique identifier of the AWS Redshift cluster
:param aws_conn_id: The Airflow connection used for AWS credentials.
The default connection id is ``aws_default``
:param deferrable: Run operator in deferrable mode
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
:param max_attempts: The maximum number of attempts to check the state of the cluster.
:param wait_for_completion: If True, the operator will wait for the cluster to reach the
`available` state. Default is False.
:param deferrable: If True, the operator will run as a deferrable operator.
"""

template_fields: Sequence[str] = ("cluster_identifier",)
@@ -465,66 +468,71 @@ def __init__(
*,
cluster_identifier: str,
aws_conn_id: str = "aws_default",
wait_for_completion: bool = False,
deferrable: bool = False,
poll_interval: int = 10,
max_attempts: int = 10,
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
self.wait_for_completion = wait_for_completion
self.deferrable = deferrable
self.max_attempts = max_attempts
self.poll_interval = poll_interval
# These parameters are added to address an issue with the boto3 API where the API
# These parameters are used to address an issue with the boto3 API where the API
# prematurely reports the cluster as available to receive requests. This causes the cluster
# to reject initial attempts to resume the cluster despite reporting the correct state.
self._attempts = 10
self._remaining_attempts = 10
self._attempt_interval = 15

def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
self.log.info("Starting resume cluster")
while self._remaining_attempts >= 1:
try:
redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
break
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._remaining_attempts = self._remaining_attempts - 1

if self._remaining_attempts > 0:
self.log.error(
"Unable to resume cluster. %d attempts remaining.", self._remaining_attempts
)
time.sleep(self._attempt_interval)
else:
raise error
if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=RedshiftClusterTrigger(
task_id=self.task_id,
trigger=RedshiftResumeClusterTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
cluster_identifier=self.cluster_identifier,
attempts=self._attempts,
operation_type="resume_cluster",
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.max_attempts * self.poll_interval + 60),
)
if self.wait_for_completion:
waiter = redshift_hook.get_waiter("cluster_resumed")
waiter.wait(
ClusterIdentifier=self.cluster_identifier,
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": self.max_attempts,
},
)

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error resuming cluster: {event}")
else:
while self._attempts >= 1:
try:
redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
return
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to resume cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error

def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if "status" in event and event["status"] == "error":
msg = f"{event['status']}: {event['message']}"
raise AirflowException(msg)
elif "status" in event and event["status"] == "success":
self.log.info("%s completed successfully.", self.task_id)
self.log.info("Resumed cluster successfully")
else:
raise AirflowException("No event received from trigger")
self.log.info("Resumed cluster successfully")
return


class RedshiftPauseClusterOperator(BaseOperator):
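For reference, a minimal usage sketch of the two modes this file now supports; the DAG id, schedule, and cluster identifier are illustrative assumptions, not taken from this PR:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_cluster import (
    RedshiftResumeClusterOperator,
)

with DAG(dag_id="example_redshift_resume", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    # Blocking mode: occupy the worker slot and poll the new `cluster_resumed` waiter.
    resume_and_wait = RedshiftResumeClusterOperator(
        task_id="resume_and_wait",
        cluster_identifier="my-cluster",  # assumed identifier
        wait_for_completion=True,
        poll_interval=10,
        max_attempts=10,
    )

    # Deferrable mode: free the worker slot while RedshiftResumeClusterTrigger
    # polls asynchronously in the triggerer until the cluster is available.
    resume_deferred = RedshiftResumeClusterOperator(
        task_id="resume_deferred",
        cluster_identifier="my-cluster",  # assumed identifier
        deferrable=True,
    )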
72 changes: 72 additions & 0 deletions airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -285,3 +285,75 @@ async def run(self):
)
else:
yield TriggerEvent({"status": "success", "message": "Cluster Snapshot Created"})


class RedshiftResumeClusterTrigger(BaseTrigger):
"""
Trigger for RedshiftResumeClusterOperator.
The trigger will asynchronously poll the boto3 API and wait for the
Redshift cluster to be in the `available` state.

:param cluster_identifier: A unique identifier for the cluster.
:param poll_interval: The amount of time in seconds to wait between attempts.
:param max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(
self,
cluster_identifier: str,
poll_interval: int,
max_attempts: int,
aws_conn_id: str,
):
self.cluster_identifier = cluster_identifier
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.aws_conn_id = aws_conn_id

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftResumeClusterTrigger",
{
"cluster_identifier": self.cluster_identifier,
"poll_interval": str(self.poll_interval),
"max_attempts": str(self.max_attempts),
"aws_conn_id": self.aws_conn_id,
},
)

@cached_property
def hook(self) -> RedshiftHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)

async def run(self):
async with self.hook.async_conn as client:
attempt = 0
waiter = self.hook.get_waiter("cluster_resumed", deferrable=True, client=client)
while attempt < int(self.max_attempts):
attempt = attempt + 1
try:
await waiter.wait(
ClusterIdentifier=self.cluster_identifier,
WaiterConfig={
"Delay": int(self.poll_interval),
"MaxAttempts": 1,
},
)
break
except WaiterError as error:
if "terminal failure" in str(error):
yield TriggerEvent(
{"status": "failure", "message": f"Resume Cluster Failed: {error}"}
)
return
self.log.info(
"Status of cluster is %s", error.last_response["Clusters"][0]["ClusterStatus"]
)
await asyncio.sleep(int(self.poll_interval))
if attempt >= int(self.max_attempts):
yield TriggerEvent(
{"status": "failure", "message": "Resume Cluster Failed - max attempts reached."}
)
else:
yield TriggerEvent({"status": "success", "message": "Cluster resumed"})
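As a sanity check, a minimal sketch (not part of the PR) of driving the new trigger outside Airflow; it assumes valid AWS credentials behind the `aws_default` connection and an existing paused cluster:

import asyncio

from airflow.providers.amazon.aws.triggers.redshift_cluster import (
    RedshiftResumeClusterTrigger,
)

async def main():
    trigger = RedshiftResumeClusterTrigger(
        cluster_identifier="my-cluster",  # assumed identifier
        poll_interval=10,
        max_attempts=10,
        aws_conn_id="aws_default",
    )
    # run() is an async generator; the triggerer consumes it the same way and
    # resumes the deferred task with the first TriggerEvent received.
    async for event in trigger.run():
        print(event.payload)  # e.g. {"status": "success", "message": "Cluster resumed"}

asyncio.run(main())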
25 changes: 25 additions & 0 deletions airflow/providers/amazon/aws/waiters/redshift.json
@@ -25,6 +25,31 @@
"state": "failure"
}
]
},
"cluster_resumed": {
"operation": "DescribeClusters",
"delay": 30,
"maxAttempts": 60,
"acceptors": [
{
"matcher": "pathAll",
"argument": "Clusters[].ClusterStatus",
"expected": "available",
"state": "success"
},
{
"matcher": "error",
"argument": "Clusters[].ClusterStatus",
"expected": "ClusterNotFound",
"state": "retry"
},
{
"matcher": "pathAny",
"argument": "Clusters[].ClusterStatus",
"expected": "deleting",
"state": "failure"
}
]
}
}
}
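The `cluster_resumed` entry above is a custom boto3 waiter model shipped with the provider; a minimal sketch of consuming it through the hook, mirroring the `wait_for_completion` branch of the operator (the call-time WaiterConfig overrides the JSON defaults of delay=30 and maxAttempts=60):

from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook

hook = RedshiftHook(aws_conn_id="aws_default")  # assumed connection id
waiter = hook.get_waiter("cluster_resumed")
waiter.wait(
    ClusterIdentifier="my-cluster",  # assumed identifier
    WaiterConfig={"Delay": 10, "MaxAttempts": 10},
)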
61 changes: 33 additions & 28 deletions tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -32,9 +32,9 @@
RedshiftResumeClusterOperator,
)
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftCreateClusterSnapshotTrigger,
RedshiftPauseClusterTrigger,
RedshiftResumeClusterTrigger,
)


@@ -264,15 +264,15 @@ def test_init(self):
assert redshift_operator.cluster_identifier == "test_cluster"
assert redshift_operator.aws_conn_id == "aws_conn_test"

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
@mock.patch.object(RedshiftHook, "get_conn")
def test_resume_cluster_is_called_when_cluster_is_paused(self, mock_get_conn):
redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
)
redshift_operator.execute(None)
mock_get_conn.return_value.resume_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
@mock.patch.object(RedshiftHook, "conn")
@mock.patch("time.sleep", return_value=None)
def test_resume_cluster_multiple_attempts(self, mock_sleep, mock_conn):
exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
@@ -288,7 +288,7 @@ def test_resume_cluster_multiple_attempts(self, mock_sleep, mock_conn):
redshift_operator.execute(None)
assert mock_conn.resume_cluster.call_count == 3

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
@mock.patch.object(RedshiftHook, "conn")
@mock.patch("time.sleep", return_value=None)
def test_resume_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
@@ -306,16 +306,10 @@ def test_resume_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
redshift_operator.execute(None)
assert mock_conn.resume_cluster.call_count == 10

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.resume_cluster")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.get_client_async")
def test_resume_cluster(self, mock_async_client, mock_async_resume_cluster, mock_sync_cluster_status):
"""Test Resume cluster operator run"""
mock_sync_cluster_status.return_value = "paused"
mock_async_client.return_value.resume_cluster.return_value = {
"Cluster": {"ClusterIdentifier": "test_cluster", "ClusterStatus": "resuming"}
}
mock_async_resume_cluster.return_value = {"status": "success", "cluster_state": "available"}
@mock.patch.object(RedshiftHook, "conn")
def test_resume_cluster_deferrable(self, mock_conn):
"""Test Resume cluster operator deferrable"""
mock_conn.resume_cluster.return_value = True

redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test",
@@ -328,22 +322,33 @@ def test_resume_cluster(self, mock_async_client, mock_async_resume_cluster, mock
redshift_operator.execute({})

assert isinstance(
exc.value.trigger, RedshiftClusterTrigger
), "Trigger is not a RedshiftClusterTrigger"
exc.value.trigger, RedshiftResumeClusterTrigger
), "Trigger is not a RedshiftResumeClusterTrigger"

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.resume_cluster")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.get_client_async")
def test_resume_cluster_failure(
self, mock_async_client, mock_async_resume_cluster, mock_sync_cluster_statue
):
"""Test Resume cluster operator Failure"""
mock_sync_cluster_statue.return_value = "paused"
mock_async_client.return_value.resume_cluster.return_value = {
"Cluster": {"ClusterIdentifier": "test_cluster", "ClusterStatus": "resuming"}
}
mock_async_resume_cluster.return_value = {"status": "success", "cluster_state": "available"}
@mock.patch.object(RedshiftHook, "get_waiter")
@mock.patch.object(RedshiftHook, "conn")
def test_resume_cluster_wait_for_completion(self, mock_conn, mock_get_waiter):
"""Test Resume cluster operator wait for completion"""
mock_conn.resume_cluster.return_value = True
mock_get_waiter().wait.return_value = None

redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test",
cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test",
wait_for_completion=True,
)
redshift_operator.execute(None)
mock_conn.resume_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")

mock_get_waiter.assert_called_with("cluster_resumed")
assert mock_get_waiter.call_count == 2
mock_get_waiter().wait.assert_called_once_with(
ClusterIdentifier="test_cluster", WaiterConfig={"Delay": 10, "MaxAttempts": 10}
)

def test_resume_cluster_failure(self):
"""Test Resume cluster operator Failure"""
redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test",
cluster_identifier="test_cluster",
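The failure-path test is truncated above; for completeness, a minimal sketch (assumed, not from this diff) of asserting the new execute_complete contract, using the payload shape the trigger yields:

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.redshift_cluster import (
    RedshiftResumeClusterOperator,
)

def test_execute_complete_raises_on_failure_event():
    op = RedshiftResumeClusterOperator(task_id="task_test", cluster_identifier="test_cluster")
    # The trigger yields {"status": "failure", ...} on error; execute_complete
    # raises for any status other than "success".
    with pytest.raises(AirflowException):
        op.execute_complete(context=None, event={"status": "failure", "message": "test failure"})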