20 changes: 20 additions & 0 deletions airflow/providers/amazon/aws/hooks/base_aws.py
@@ -30,6 +30,7 @@
import logging
import os
import uuid
import warnings
from copy import deepcopy
from functools import cached_property, wraps
from os import PathLike
@@ -53,6 +54,7 @@
from airflow.exceptions import (
AirflowException,
AirflowNotFoundException,
AirflowProviderDeprecationWarning,
)
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
@@ -966,6 +968,15 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
provided in Airflow connection
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"airflow.providers.amazon.aws.hook.base_aws.BaseAsyncSessionFactory has been deprecated and "
"will be removed in future",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)

async def get_role_credentials(self) -> dict:
"""Get the role_arn, method credentials from connection details and get the role credentials detail"""
async with self._basic_session.create_client("sts", region_name=self.region_name) as client:
@@ -1059,6 +1070,15 @@ class AwsBaseAsyncHook(AwsBaseHook):
:param config: Configuration for botocore client.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"airflow.providers.amazon.aws.hook.base_aws.AwsBaseAsyncHook has been deprecated and "
"will be removed in future",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)

def get_async_session(self) -> AioSession:
"""Get the underlying aiobotocore.session.AioSession(...)."""
return BaseAsyncSessionFactory(
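Both BaseAsyncSessionFactory and AwsBaseAsyncHook now emit AirflowProviderDeprecationWarning as soon as they are constructed. A minimal sketch (not part of this change) of how a test could assert the new warning; the aws_conn_id and client_type values are illustrative assumptions, and connection resolution in AwsBaseHook is lazy, so constructing the hook should not contact AWS.

import pytest

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook


def test_aws_base_async_hook_warns_on_init():
    # Constructing the hook is enough to hit warnings.warn() in __init__;
    # no connection lookup or AWS request should be made at this point.
    with pytest.warns(AirflowProviderDeprecationWarning, match="AwsBaseAsyncHook"):
        AwsBaseAsyncHook(aws_conn_id="aws_default", client_type="s3")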
10 changes: 9 additions & 1 deletion airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -17,11 +17,13 @@
from __future__ import annotations

import asyncio
import warnings
from typing import Any, Sequence

import botocore.exceptions
from botocore.exceptions import ClientError

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook


@@ -204,7 +206,13 @@ def get_cluster_snapshot_status(self, snapshot_identifier: str):
class RedshiftAsyncHook(AwsBaseAsyncHook):
"""Interact with AWS Redshift using aiobotocore library"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, *args, **kwargs):
warnings.warn(
"airflow.providers.amazon.aws.hook.base_aws.RedshiftAsyncHook has been deprecated and "
"will be removed in future",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
kwargs["client_type"] = "redshift"
super().__init__(*args, **kwargs)

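With RedshiftAsyncHook deprecated, callers are expected to rely on the synchronous RedshiftHook (plus the waiter-based triggers below) instead. A rough migration sketch, assuming the existing RedshiftHook API: get_cluster_snapshot_status is visible in the hunk above, while cluster_status, the connection id, and the identifiers are assumptions for illustration.

from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook

hook = RedshiftHook(aws_conn_id="aws_default")

# Cluster state such as "available", "paused" or "resuming" (assumed method name).
cluster_state = hook.cluster_status(cluster_identifier="my-redshift-cluster")

# Snapshot state, using the method shown in the diff context above.
snapshot_state = hook.get_cluster_snapshot_status(snapshot_identifier="my-snapshot")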
69 changes: 2 additions & 67 deletions airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -18,79 +18,14 @@

import asyncio
from functools import cached_property
from typing import Any, AsyncIterator
from typing import Any

from botocore.exceptions import WaiterError

from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class RedshiftClusterTrigger(BaseTrigger):
"""AWS Redshift trigger"""

def __init__(
self,
task_id: str,
aws_conn_id: str,
cluster_identifier: str,
operation_type: str,
attempts: int,
poll_interval: float = 5.0,
):
super().__init__()
self.task_id = task_id
self.poll_interval = poll_interval
self.aws_conn_id = aws_conn_id
self.cluster_identifier = cluster_identifier
self.operation_type = operation_type
self.attempts = attempts

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger",
{
"task_id": self.task_id,
"poll_interval": self.poll_interval,
"aws_conn_id": self.aws_conn_id,
"cluster_identifier": self.cluster_identifier,
"attempts": self.attempts,
"operation_type": self.operation_type,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
hook = RedshiftAsyncHook(aws_conn_id=self.aws_conn_id)
while self.attempts >= 1:
self.attempts = self.attempts - 1
try:
if self.operation_type == "pause_cluster":
response = await hook.pause_cluster(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
)
if response.get("status") == "success":
yield TriggerEvent(response)
else:
if self.attempts < 1:
yield TriggerEvent({"status": "error", "message": f"{self.task_id} failed"})
elif self.operation_type == "resume_cluster":
response = await hook.resume_cluster(
cluster_identifier=self.cluster_identifier,
polling_period_seconds=self.poll_interval,
)
if response:
yield TriggerEvent(response)
else:
error_message = f"{self.task_id} failed"
yield TriggerEvent({"status": "error", "message": error_message})
else:
yield TriggerEvent(f"{self.operation_type} is not supported")
except Exception as e:
if self.attempts < 1:
yield TriggerEvent({"status": "error", "message": str(e)})


class RedshiftCreateClusterTrigger(BaseTrigger):
"""
Trigger for RedshiftCreateClusterOperator.
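The removed RedshiftClusterTrigger polled RedshiftAsyncHook in a retry loop; the triggers that remain in this module (such as RedshiftCreateClusterTrigger) keep the same BaseTrigger contract that the deleted class shows above: serialize() returns the trigger's classpath plus constructor kwargs, and run() is an async generator that yields TriggerEvent objects. A minimal, illustrative sketch of that contract follows; the class name, classpath, and event payload are made up for the example.

from __future__ import annotations

from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent


class ExampleTrigger(BaseTrigger):
    """Illustration of the BaseTrigger contract, not a provider class."""

    def __init__(self, cluster_identifier: str, poll_interval: float = 5.0):
        super().__init__()
        self.cluster_identifier = cluster_identifier
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # The triggerer process re-creates the trigger from this classpath + kwargs.
        return (
            "my_package.triggers.ExampleTrigger",
            {"cluster_identifier": self.cluster_identifier, "poll_interval": self.poll_interval},
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        # A real trigger would await the resource state (e.g. via a botocore waiter)
        # before yielding; yielding a TriggerEvent resumes the deferred task.
        yield TriggerEvent({"status": "success", "cluster_identifier": self.cluster_identifier})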
17 changes: 0 additions & 17 deletions tests/providers/amazon/aws/deferrable/triggers/__init__.py

This file was deleted.