Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
31b1767
Change base_aws.py to support async_conn
syedahsn Mar 6, 2023
197d7d1
Fix failing test
syedahsn Mar 22, 2023
647b414
Fix static check failures
syedahsn Mar 23, 2023
50035cf
Skip system tests if aiobotocore is not added.
syedahsn Mar 27, 2023
ec21cf1
Remove debug statement, fix static checks
syedahsn Mar 27, 2023
05d2261
Put import of aiobotocore into function so it only imports when it is…
syedahsn Mar 28, 2023
c73f7b6
Remove importorskip from system tests and base_aws
syedahsn Mar 30, 2023
03d6cbc
mock isinstance call to allow tests to pass
syedahsn Apr 5, 2023
579e591
Add mocking to isinstance where it is needed
syedahsn Apr 5, 2023
53d99fd
Mock with context to work with python 3.7
syedahsn Apr 12, 2023
441a6c7
add importorskip to base_aws.py
syedahsn Apr 13, 2023
e645dc6
Merge branch 'main' into syedahsn/deferrable-redshift-create-cluster
syedahsn Apr 13, 2023
49f376a
Add note about async connections in base_aws. Remove typo in system test
syedahsn Apr 14, 2023
c292de3
Remove cached_property on async_conn to prevent awaiting a coro twice
syedahsn Apr 17, 2023
d4caf91
Merge branch 'main' into syedahsn/deferrable-redshift-create-cluster
syedahsn Apr 17, 2023
d65372d
Merge branch 'main' into syedahsn/deferrable-redshift-create-cluster
syedahsn Apr 17, 2023
8e83f4c
Update doc string
syedahsn Apr 17, 2023
93dbf17
Make async_conn a property
syedahsn Apr 19, 2023
02c82b9
Merge branch 'main' into syedahsn/deferrable-redshift-create-cluster
syedahsn Apr 19, 2023
a96d2ee
Merge branch 'main' into syedahsn/deferrable-redshift-create-cluster
syedahsn Apr 19, 2023
6ecef57
Merge branch 'main' into syedahsn/deferrable-redshift-create-cluster
vandonr-amz Apr 19, 2023
bc2f87e
Fixed formatting in trigger README.md.
syedahsn Apr 21, 2023
f8bb4a4
Fix static checks
syedahsn Apr 21, 2023
3c98304
Merge branch 'main' into syedahsn/deferrable-redshift-create-cluster
syedahsn Apr 21, 2023
98cdfc3
Merge branch 'main' into syedahsn/deferrable-redshift-create-cluster
syedahsn Apr 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 94 additions & 17 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
)
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
from airflow.providers_manager import ProvidersManager
from airflow.utils.helpers import exactly_one
from airflow.utils.log.logging_mixin import LoggingMixin
Expand All @@ -71,12 +70,15 @@

class BaseSessionFactory(LoggingMixin):
"""
Base AWS Session Factory class to handle boto3 session creation.
Base AWS Session Factory class to handle synchronous and async boto session creation.
It can handle most of the AWS supported authentication methods.

User can also derive from this class to have full control of boto3 session
creation or to support custom federation.

Note: Not all features implemented for synchronous sessions are available for async
sessions.

.. seealso::
- :ref:`howto/connection:aws:session-factory`
"""
Expand Down Expand Up @@ -126,17 +128,50 @@ def role_arn(self) -> str | None:
"""Assume Role ARN from AWS Connection"""
return self.conn.role_arn

def create_session(self) -> boto3.session.Session:
"""Create boto3 Session from connection config."""
def _apply_session_kwargs(self, session):
    """
    Copy the connection's ``session_kwargs`` onto an already created session.

    ``profile_name`` and ``region_name`` are written to the session as the
    ``profile`` and ``region`` config variables when they are not None.
    Credentials are set when at least one of ``aws_access_key_id``,
    ``aws_secret_access_key`` or ``aws_session_token`` has a truthy value.

    :param session: Session object exposing ``set_config_variable`` and
        ``set_credentials``.
    """
    session_kwargs = self.conn.session_kwargs

    profile_name = session_kwargs.get("profile_name")
    if profile_name is not None:
        session.set_config_variable("profile", profile_name)

    access_key = session_kwargs.get("aws_access_key_id")
    secret_key = session_kwargs.get("aws_secret_access_key")
    session_token = session_kwargs.get("aws_session_token")
    if access_key or secret_key or session_token:
        # Pass the ``.get`` results instead of subscripting the dict: the guard
        # only requires ONE of the three keys, so a connection with long-lived
        # credentials (no session token) must not fail with KeyError here.
        session.set_credentials(access_key, secret_key, session_token)

    region_name = session_kwargs.get("region_name")
    if region_name is not None:
        session.set_config_variable("region", region_name)

def get_async_session(self):
    """
    Return a new aiobotocore session.

    aiobotocore is imported inside the method, so the import only runs when
    this method is actually called.
    """
    from aiobotocore import session as aiobotocore_session

    return aiobotocore_session.get_session()

def create_session(self, deferrable: bool = False) -> boto3.session.Session:
"""Create boto3 or aiobotocore Session from connection config."""
if not self.conn:
self.log.info(
"No connection ID provided. Fallback on boto3 credential strategy (region_name=%r). "
"See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html",
self.region_name,
)
return boto3.session.Session(region_name=self.region_name)
if deferrable:
session = self.get_async_session()
self._apply_session_kwargs(session)
return session
else:
return boto3.session.Session(region_name=self.region_name)
elif not self.role_arn:
return self.basic_session
if deferrable:
session = self.get_async_session()
self._apply_session_kwargs(session)
return session
else:
return self.basic_session

# Values stored in ``AwsConnectionWrapper.session_kwargs`` are intended to be used only
# to create the initial boto3 session.
Expand All @@ -149,12 +184,18 @@ def create_session(self) -> boto3.session.Session:
assume_session_kwargs = {}
if self.conn.region_name:
assume_session_kwargs["region_name"] = self.conn.region_name
return self._create_session_with_assume_role(session_kwargs=assume_session_kwargs)
return self._create_session_with_assume_role(
session_kwargs=assume_session_kwargs, deferrable=deferrable
)

def _create_basic_session(self, session_kwargs: dict[str, Any]) -> boto3.session.Session:
    """Build a boto3 session, forwarding ``session_kwargs`` as keyword arguments."""
    basic_session = boto3.session.Session(**session_kwargs)
    return basic_session

def _create_session_with_assume_role(self, session_kwargs: dict[str, Any]) -> boto3.session.Session:
def _create_session_with_assume_role(
self, session_kwargs: dict[str, Any], deferrable: bool = False
) -> boto3.session.Session:
from aiobotocore.session import get_session as async_get_session

if self.conn.assume_role_method == "assume_role_with_web_identity":
# Deferred credentials have no initial credentials
credential_fetcher = self._get_web_identity_credential_fetcher()
Expand All @@ -171,10 +212,10 @@ def _create_session_with_assume_role(self, session_kwargs: dict[str, Any]) -> bo
method="sts-assume-role",
)

session = botocore.session.get_session()
session = async_get_session() if deferrable else botocore.session.get_session()

session._credentials = credentials
region_name = self.basic_session.region_name
session.set_config_variable("region", region_name)
session.set_config_variable("region", self.basic_session.region_name)

return boto3.session.Session(botocore_session=session, **session_kwargs)

Expand Down Expand Up @@ -530,11 +571,11 @@ def verify(self) -> bool | str | None:
"""Verify or not SSL certificates boto3 client/resource read-only property."""
return self.conn_config.verify

def get_session(self, region_name: str | None = None, deferrable: bool = False) -> boto3.session.Session:
    """
    Create a session for this hook through ``SessionFactory``.

    :param region_name: Region name handed to the session factory.
    :param deferrable: Forwarded unchanged to ``SessionFactory.create_session``.
    """
    factory = SessionFactory(conn=self.conn_config, region_name=region_name, config=self.config)
    return factory.create_session(deferrable=deferrable)

def _get_config(self, config: Config | None = None) -> Config:
"""
Expand All @@ -557,10 +598,19 @@ def get_client_type(
self,
region_name: str | None = None,
config: Config | None = None,
deferrable: bool = False,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session"""
client_type = self.client_type
session = self.get_session(region_name=region_name)
session = self.get_session(region_name=region_name, deferrable=deferrable)
if not isinstance(session, boto3.session.Session):
return session.create_client(
client_type,
endpoint_url=self.conn_config.endpoint_url,
config=self._get_config(config),
verify=self.verify,
)

return session.client(
client_type,
endpoint_url=self.conn_config.endpoint_url,
Expand Down Expand Up @@ -600,6 +650,14 @@ def conn(self) -> BaseAwsConnection:
else:
return self.get_resource_type(region_name=self.region_name)

@property
def async_conn(self):
    """
    Return an aiobotocore client to use for async operations.

    This is a plain property (not cached): every access calls
    ``get_client_type`` again with ``deferrable=True``.

    :raises ValueError: if ``client_type`` is not set on the hook.
    """
    if self.client_type:
        return self.get_client_type(region_name=self.region_name, deferrable=True)

    raise ValueError("client_type must be specified.")

@cached_property
def conn_client_meta(self) -> ClientMeta:
"""Get botocore client metadata from Hook connection (cached)."""
Expand Down Expand Up @@ -753,26 +811,45 @@ def waiter_path(self) -> PathLike[str] | None:
path = Path(__file__).parents[1].joinpath(f"waiters/{filename}.json").resolve()
return path if path.exists() else None

def get_waiter(
    self,
    waiter_name: str,
    parameters: dict[str, str] | None = None,
    deferrable: bool = False,
    client=None,
) -> Waiter:
    """
    Return the waiter registered under ``waiter_name``.

    First checks if there is a custom waiter with the provided waiter_name and
    uses that if it exists, otherwise it will check the service client for a
    waiter that matches the name and pass that through.

    If `deferrable` is True, the waiter is generated from the client that is
    passed as a parameter, so `client` must be provided in that case.

    :param waiter_name: The name of the waiter. The name should exactly match the
        name of the key in the waiter model file (typically this is CamelCase).
    :param parameters: will scan the waiter config for the keys of that dict, and replace them with the
        corresponding value. If a custom waiter has such keys to be expanded, they need to be provided
        here.
    :param deferrable: If True, the waiter is going to be an async custom waiter.
    :param client: The client the waiter is built from. Defaults to ``self.conn``
        when not deferrable; required when ``deferrable`` is True.
    :raises ValueError: if ``deferrable`` is True and no ``client`` is given.
    """
    if deferrable and not client:
        raise ValueError("client must be provided for a deferrable waiter.")
    client = client or self.conn

    if self.waiter_path and (waiter_name in self._list_custom_waiters()):
        # Imported locally and only on the branch that needs it, so the
        # official-waiter fallback below does not depend on this module.
        from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter

        # Technically if waiter_name is in custom_waiters then self.waiter_path must
        # exist but MyPy doesn't like the fact that self.waiter_path could be None.
        with open(self.waiter_path) as config_file:
            config = json.load(config_file)

        config = self._apply_parameters_value(config, waiter_name, parameters)
        return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter(
            waiter_name
        )

    # If there is no custom waiter found for the provided name, then try the
    # service's official waiters. Use the resolved ``client`` (not ``self.conn``)
    # so that a deferrable caller gets a waiter bound to the client it supplied.
    return client.get_waiter(waiter_name)
Expand Down Expand Up @@ -941,7 +1018,7 @@ def _basic_session(self) -> AioSession:
aio_session.set_config_variable("region", region_name)
return aio_session

def create_session(self) -> AioSession:
def create_session(self, deferrable: bool = False) -> AioSession:
"""Create aiobotocore Session from connection and config."""
if not self._conn:
self.log.info("No connection ID provided. Fallback on boto3 credential strategy")
Expand Down
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/hooks/batch_waiters.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ def waiter_model(self) -> botocore.waiter.WaiterModel:
"""
return self._waiter_model

def get_waiter(self, waiter_name: str, _: dict[str, str] | None = None) -> botocore.waiter.Waiter:
def get_waiter(
self, waiter_name: str, _: dict[str, str] | None = None, deferrable: bool = False, client=None
) -> botocore.waiter.Waiter:
"""
Get an AWS Batch service waiter, using the configured ``.waiter_model``.

Expand Down
23 changes: 22 additions & 1 deletion airflow/providers/amazon/aws/operators/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftCreateClusterTrigger,
)

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -88,6 +91,7 @@ class RedshiftCreateClusterOperator(BaseOperator):
:param wait_for_completion: Whether wait for the cluster to be in ``available`` state
:param max_attempt: The maximum number of attempts to be made. Default: 5
:param poll_interval: The amount of time in seconds to wait between attempts. Default: 60
:param deferrable: If True, the operator will run in deferrable mode
"""

template_fields: Sequence[str] = (
Expand Down Expand Up @@ -140,6 +144,7 @@ def __init__(
wait_for_completion: bool = False,
max_attempt: int = 5,
poll_interval: int = 60,
deferrable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -180,6 +185,7 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.max_attempt = max_attempt
self.poll_interval = poll_interval
self.deferrable = deferrable
self.kwargs = kwargs

def execute(self, context: Context):
Expand Down Expand Up @@ -252,6 +258,16 @@ def execute(self, context: Context):
self.master_user_password,
params,
)
if self.deferrable:
self.defer(
trigger=RedshiftCreateClusterTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempt=self.max_attempt,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
)
if self.wait_for_completion:
redshift_hook.get_conn().get_waiter("cluster_available").wait(
ClusterIdentifier=self.cluster_identifier,
Expand All @@ -264,6 +280,11 @@ def execute(self, context: Context):
self.log.info("Created Redshift cluster %s", self.cluster_identifier)
self.log.info(cluster)

def execute_complete(self, context, event=None):
    """
    Resume point after deferral; ``execute`` defers with ``method_name="execute_complete"``.

    :param context: Task context (not used here).
    :param event: Payload emitted by the trigger; expected to be a mapping whose
        ``status`` entry equals ``"success"``.
    :raises AirflowException: if no event was delivered or its status is not
        ``"success"``.
    """
    # ``event`` defaults to None, so guard before reading it: a missing event is
    # reported as an AirflowException rather than an opaque TypeError.
    if not event or event.get("status") != "success":
        raise AirflowException(f"Error creating cluster: {event}")


class RedshiftCreateClusterSnapshotOperator(BaseOperator):
"""
Expand Down
Loading