diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 83bebfe0a20ca..541ef37d312a2 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -57,7 +57,6 @@ ) from airflow.hooks.base import BaseHook from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper -from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter from airflow.providers_manager import ProvidersManager from airflow.utils.helpers import exactly_one from airflow.utils.log.logging_mixin import LoggingMixin @@ -71,12 +70,15 @@ class BaseSessionFactory(LoggingMixin): """ - Base AWS Session Factory class to handle boto3 session creation. + Base AWS Session Factory class to handle synchronous and async boto session creation. It can handle most of the AWS supported authentication methods. User can also derive from this class to have full control of boto3 session creation or to support custom federation. + Note: Not all features implemented for synchronous sessions are available for async + sessions. + .. seealso:: - :ref:`howto/connection:aws:session-factory` """ @@ -126,17 +128,50 @@ def role_arn(self) -> str | None: """Assume Role ARN from AWS Connection""" return self.conn.role_arn - def create_session(self) -> boto3.session.Session: - """Create boto3 Session from connection config.""" + def _apply_session_kwargs(self, session): + if self.conn.session_kwargs.get("profile_name", None) is not None: + session.set_config_variable("profile", self.conn.session_kwargs["profile_name"]) + + if ( + self.conn.session_kwargs.get("aws_access_key_id", None) + or self.conn.session_kwargs.get("aws_secret_access_key", None) + or self.conn.session_kwargs.get("aws_session_token", None) + ): + session.set_credentials( + self.conn.session_kwargs["aws_access_key_id"], + self.conn.session_kwargs["aws_secret_access_key"], + self.conn.session_kwargs["aws_session_token"], + ) + + if self.conn.session_kwargs.get("region_name", None) is not None: + session.set_config_variable("region", self.conn.session_kwargs["region_name"]) + + def get_async_session(self): + from aiobotocore.session import get_session as async_get_session + + return async_get_session() + + def create_session(self, deferrable: bool = False) -> boto3.session.Session: + """Create boto3 or aiobotocore Session from connection config.""" if not self.conn: self.log.info( "No connection ID provided. Fallback on boto3 credential strategy (region_name=%r). " "See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html", self.region_name, ) - return boto3.session.Session(region_name=self.region_name) + if deferrable: + session = self.get_async_session() + self._apply_session_kwargs(session) + return session + else: + return boto3.session.Session(region_name=self.region_name) elif not self.role_arn: - return self.basic_session + if deferrable: + session = self.get_async_session() + self._apply_session_kwargs(session) + return session + else: + return self.basic_session # Values stored in ``AwsConnectionWrapper.session_kwargs`` are intended to be used only # to create the initial boto3 session. 
@@ -149,12 +184,18 @@ def create_session(self) -> boto3.session.Session: assume_session_kwargs = {} if self.conn.region_name: assume_session_kwargs["region_name"] = self.conn.region_name - return self._create_session_with_assume_role(session_kwargs=assume_session_kwargs) + return self._create_session_with_assume_role( + session_kwargs=assume_session_kwargs, deferrable=deferrable + ) def _create_basic_session(self, session_kwargs: dict[str, Any]) -> boto3.session.Session: return boto3.session.Session(**session_kwargs) - def _create_session_with_assume_role(self, session_kwargs: dict[str, Any]) -> boto3.session.Session: + def _create_session_with_assume_role( + self, session_kwargs: dict[str, Any], deferrable: bool = False + ) -> boto3.session.Session: + from aiobotocore.session import get_session as async_get_session + if self.conn.assume_role_method == "assume_role_with_web_identity": # Deferred credentials have no initial credentials credential_fetcher = self._get_web_identity_credential_fetcher() @@ -171,10 +212,10 @@ def _create_session_with_assume_role(self, session_kwargs: dict[str, Any]) -> bo method="sts-assume-role", ) - session = botocore.session.get_session() + session = async_get_session() if deferrable else botocore.session.get_session() + session._credentials = credentials - region_name = self.basic_session.region_name - session.set_config_variable("region", region_name) + session.set_config_variable("region", self.basic_session.region_name) return boto3.session.Session(botocore_session=session, **session_kwargs) @@ -530,11 +571,11 @@ def verify(self) -> bool | str | None: """Verify or not SSL certificates boto3 client/resource read-only property.""" return self.conn_config.verify - def get_session(self, region_name: str | None = None) -> boto3.session.Session: + def get_session(self, region_name: str | None = None, deferrable: bool = False) -> boto3.session.Session: """Get the underlying boto3.session.Session(region_name=region_name).""" return SessionFactory( conn=self.conn_config, region_name=region_name, config=self.config - ).create_session() + ).create_session(deferrable=deferrable) def _get_config(self, config: Config | None = None) -> Config: """ @@ -557,10 +598,19 @@ def get_client_type( self, region_name: str | None = None, config: Config | None = None, + deferrable: bool = False, ) -> boto3.client: """Get the underlying boto3 client using boto3 session""" client_type = self.client_type - session = self.get_session(region_name=region_name) + session = self.get_session(region_name=region_name, deferrable=deferrable) + if not isinstance(session, boto3.session.Session): + return session.create_client( + client_type, + endpoint_url=self.conn_config.endpoint_url, + config=self._get_config(config), + verify=self.verify, + ) + return session.client( client_type, endpoint_url=self.conn_config.endpoint_url, @@ -600,6 +650,14 @@ def conn(self) -> BaseAwsConnection: else: return self.get_resource_type(region_name=self.region_name) + @property + def async_conn(self): + """Get an aiobotocore client to use for async operations.""" + if not self.client_type: + raise ValueError("client_type must be specified.") + + return self.get_client_type(region_name=self.region_name, deferrable=True) + @cached_property def conn_client_meta(self) -> ClientMeta: """Get botocore client metadata from Hook connection (cached).""" @@ -753,18 +811,35 @@ def waiter_path(self) -> PathLike[str] | None: path = Path(__file__).parents[1].joinpath(f"waiters/{filename}.json").resolve() return path if 
path.exists() else None - def get_waiter(self, waiter_name: str, parameters: dict[str, str] | None = None) -> Waiter: + def get_waiter( + self, + waiter_name: str, + parameters: dict[str, str] | None = None, + deferrable: bool = False, + client=None, + ) -> Waiter: """ First checks if there is a custom waiter with the provided waiter_name and uses that if it exists, otherwise it will check the service client for a waiter that matches the name and pass that through. + If `deferrable` is True, the waiter will be an AIOWaiter, generated from the + client that is passed as a parameter. If `deferrable` is True, `client` must be + provided. + :param waiter_name: The name of the waiter. The name should exactly match the name of the key in the waiter model file (typically this is CamelCase). :param parameters: will scan the waiter config for the keys of that dict, and replace them with the corresponding value. If a custom waiter has such keys to be expanded, they need to be provided here. + :param deferrable: If True, the waiter is going to be an async custom waiter. + """ + from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter + + if deferrable and not client: + raise ValueError("client must be provided for a deferrable waiter.") + client = client or self.conn if self.waiter_path and (waiter_name in self._list_custom_waiters()): # Technically if waiter_name is in custom_waiters then self.waiter_path must # exist but MyPy doesn't like the fact that self.waiter_path could be None. @@ -772,7 +847,9 @@ def get_waiter(self, waiter_name: str, parameters: dict[str, str] | None = None) config = json.loads(config_file.read()) config = self._apply_parameters_value(config, waiter_name, parameters) - return BaseBotoWaiter(client=self.conn, model_config=config).waiter(waiter_name) + return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter( + waiter_name + ) # If there is no custom waiter found for the provided name, # then try checking the service's official waiters. return self.conn.get_waiter(waiter_name) @@ -941,7 +1018,7 @@ def _basic_session(self) -> AioSession: aio_session.set_config_variable("region", region_name) return aio_session - def create_session(self) -> AioSession: + def create_session(self, deferrable: bool = False) -> AioSession: """Create aiobotocore Session from connection and config.""" if not self._conn: self.log.info("No connection ID provided. Fallback on boto3 credential strategy") diff --git a/airflow/providers/amazon/aws/hooks/batch_waiters.py b/airflow/providers/amazon/aws/hooks/batch_waiters.py index dcf111591c9bf..cb852acf9d8b8 100644 --- a/airflow/providers/amazon/aws/hooks/batch_waiters.py +++ b/airflow/providers/amazon/aws/hooks/batch_waiters.py @@ -138,7 +138,9 @@ def waiter_model(self) -> botocore.waiter.WaiterModel: """ return self._waiter_model - def get_waiter(self, waiter_name: str, _: dict[str, str] | None = None) -> botocore.waiter.Waiter: + def get_waiter( + self, waiter_name: str, _: dict[str, str] | None = None, deferrable: bool = False, client=None + ) -> botocore.waiter.Waiter: """ Get an AWS Batch service waiter, using the configured ``.waiter_model``. 
diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index 5183dc2e2c120..77ac521c9baf6 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -22,7 +22,10 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook -from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger +from airflow.providers.amazon.aws.triggers.redshift_cluster import ( + RedshiftClusterTrigger, + RedshiftCreateClusterTrigger, +) if TYPE_CHECKING: from airflow.utils.context import Context @@ -88,6 +91,7 @@ class RedshiftCreateClusterOperator(BaseOperator): :param wait_for_completion: Whether wait for the cluster to be in ``available`` state :param max_attempt: The maximum number of attempts to be made. Default: 5 :param poll_interval: The amount of time in seconds to wait between attempts. Default: 60 + :param deferrable: If True, the operator will run in deferrable mode """ template_fields: Sequence[str] = ( @@ -140,6 +144,7 @@ def __init__( wait_for_completion: bool = False, max_attempt: int = 5, poll_interval: int = 60, + deferrable: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -180,6 +185,7 @@ def __init__( self.wait_for_completion = wait_for_completion self.max_attempt = max_attempt self.poll_interval = poll_interval + self.deferrable = deferrable self.kwargs = kwargs def execute(self, context: Context): @@ -252,6 +258,16 @@ def execute(self, context: Context): self.master_user_password, params, ) + if self.deferrable: + self.defer( + trigger=RedshiftCreateClusterTrigger( + cluster_identifier=self.cluster_identifier, + poll_interval=self.poll_interval, + max_attempt=self.max_attempt, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) if self.wait_for_completion: redshift_hook.get_conn().get_waiter("cluster_available").wait( ClusterIdentifier=self.cluster_identifier, @@ -264,6 +280,11 @@ def execute(self, context: Context): self.log.info("Created Redshift cluster %s", self.cluster_identifier) self.log.info(cluster) + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error creating cluster: {event}") + return + class RedshiftCreateClusterSnapshotOperator(BaseOperator): """ diff --git a/airflow/providers/amazon/aws/triggers/README.md b/airflow/providers/amazon/aws/triggers/README.md new file mode 100644 index 0000000000000..cd0c0baae5d53 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/README.md @@ -0,0 +1,153 @@ + + +# Writing Deferrable Operators for Amazon Provider Package + + +Before writing deferrable operators, it is strongly recommended to read and familiarize yourself with the official [documentation](https://airflow.apache.org/docs/apache-airflow/stable/authoring-and-scheduling/deferring.html) of Deferrable Operators. +The purpose of this guide is to provide a standardized way to convert existing Amazon Provider Package (AMPP) operators to deferrable operators. Due to the varied complexities of available operators, it is impossible to define one method that will work for every operator. +The method described in this guide should work for many of the AMPP operators, but it is important to study each operator before determining whether the steps outlined below are applicable. 
+
+Although it varies from operator to operator, a typical AMPP operator has 3 stages:
+
+1. A pre-processing stage, where information is looked up via boto3 API calls, parameters are formatted, etc. The complexity of this stage depends on the complexity of the task the operator is attempting to do. Some operators (e.g. SageMaker) have a lot of pre-processing, whereas others require little to no pre-processing.
+2. The "main" call to the boto3 API to start an operation. This is the task that the operator is attempting to complete. It could be a request to provision a resource, a request to change the state of a resource, a request to start a job on a resource, etc. Regardless of the operation, the boto3 API returns instantly (ignoring network delays) with a response detailing the result of the request. For example, in the case of a resource provisioning request, although the resource can take significant time to be allocated, the boto3 API responds to the caller without waiting for the operation to be completed.
+3. The last, often optional, stage is to wait for the operation initiated in stage 2 to be completed. This usually involves polling the boto3 API at set intervals and waiting for certain criteria to be met.
+
+In general, it is the last polling stage where we can defer the operator to a trigger, which can handle the polling operation. The botocore library defines waiters for certain services, which are built-in functions that poll a service and wait for a given condition to be met.
+As part of our work for writing deferrable operators, we have extended the built-in waiters to allow custom waiters, which follow the same logic, but for services not included in the botocore library.
+We can use these custom waiters, along with the built-in waiters, to implement the polling logic of the deferrable operators.
+
+The first step to making an existing operator deferrable is to add `deferrable` as a parameter to the operator and initialize it in the constructor of the operator.
+The next step is to determine where the operator should be deferred. This depends on what the operator does and how it is written. Although every operator is different, there are a few guidelines to determine the best place to defer an operator:
+
+1. If the operator has a `wait_for_completion` parameter, the `self.defer` method should be called right before the `wait_for_completion` check.
+2. If there is no `wait_for_completion`, look for the "main" task that the operator performs. Operators often make various describe calls to the boto3 API to verify certain conditions, or look up some information, before performing their "main" task. Right after the "main" call to the boto3 API is made is usually a good place to call `self.defer`.
+
+Once the location to defer is decided in the operator, call the `self.defer` method if the `deferrable` flag is `True`, as shown in the sketch after this list. The `self.defer` method takes several parameters, listed below:
+
+1. `trigger`: This is the trigger to which you want to pass execution. We will write this trigger in just a moment.
+2. `method_name`: This specifies the name of the method you want to execute once the trigger completes its execution. The trigger cannot pass execution back to the `execute` method of the operator. By convention, the name for this method is `execute_complete`.
+3. `timeout`: An optional parameter that controls how long the Trigger can execute before timing out. This defaults to `None`, meaning no timeout.
+4. `kwargs`: Additional keyword arguments to pass to `method_name`. Default is `{}`.
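+
+For example, `RedshiftCreateClusterOperator` in this change defers itself right before its `wait_for_completion` check, handing execution to the trigger written later in this guide (a minimal sketch of the pattern):
+
+```python
+if self.deferrable:
+    self.defer(
+        trigger=RedshiftCreateClusterTrigger(
+            cluster_identifier=self.cluster_identifier,
+            poll_interval=self.poll_interval,
+            max_attempt=self.max_attempt,
+            aws_conn_id=self.aws_conn_id,
+        ),
+        method_name="execute_complete",
+    )
+```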
+
+Triggers are the main component of deferrable operators. They must be placed in the `airflow/providers/amazon/aws/triggers/` folder. All Triggers must implement the following 3 methods:
+
+1. `__init__`: the constructor, which receives parameters from the operator. These must be JSON serializable.
+2. `serialize`: a method that returns the classpath of the trigger, as well as the keyword arguments for the `__init__` method, as a tuple.
+3. `run`: the asynchronous method that is responsible for awaiting the asynchronous operations.
+
+Ideally, when the operator has deferred itself, it has already initiated its "main" task and is now waiting for a certain resource to reach a certain state.
+As mentioned earlier, the botocore library defines a `Waiter` class for many services, which implements a `wait` method that can be configured via a config file to poll the boto3 API at set intervals and return when the success criteria are met.
+The aiobotocore library, which is used to make asynchronous botocore calls, defines an `AIOWaiter` class, which also implements a `wait` method that behaves identically to the botocore method, except that it works asynchronously.
+Therefore, any botocore waiter is available as an aiobotocore waiter and can be used to asynchronously poll a service until the desired criteria are met.
+
+To call the asynchronous `wait` function, first create a hook for the particular service. For example, for a Redshift hook, it would look like this:
+
+```python
+self.redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
+```
+
+With this hook, we can use the `async_conn` property to get access to the aiobotocore client:
+
+```python
+async with self.redshift_hook.async_conn as client:
+    await client.get_waiter("cluster_available").wait(
+        ClusterIdentifier=self.cluster_identifier,
+        WaiterConfig={
+            "Delay": int(self.poll_interval),
+            "MaxAttempts": int(self.max_attempt),
+        },
+    )
+```
+
+In this case, we are using the built-in `cluster_available` waiter. If we wanted to use a custom waiter, we would change the code slightly to use the `get_waiter` function from the hook, rather than the aiobotocore client:
+
+```python
+async with self.redshift_hook.async_conn as client:
+    waiter = self.redshift_hook.get_waiter("cluster_paused", deferrable=True, client=client)
+    await waiter.wait(
+        ClusterIdentifier=self.cluster_identifier,
+        WaiterConfig={
+            "Delay": int(self.poll_interval),
+            "MaxAttempts": int(self.max_attempt),
+        },
+    )
+```
+
+Here, we are calling the `get_waiter` function defined in `base_aws.py`, which takes an optional `deferrable` argument (set to `True`) and the aiobotocore client. `cluster_paused` is a custom boto waiter defined in `redshift.json` in the `airflow/providers/amazon/aws/waiters` folder. In general, the config file for a custom waiter should be named `<service_name>.json`.
+The config for `cluster_paused` is shown below:
+
+```json
+{
+    "version": 2,
+    "waiters": {
+        "cluster_paused": {
+            "operation": "DescribeClusters",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "pathAll",
+                    "argument": "Clusters[].ClusterStatus",
+                    "expected": "paused",
+                    "state": "success"
+                },
+                {
+                    "expected": "ClusterNotFound",
+                    "matcher": "error",
+                    "state": "retry"
+                },
+                {
+                    "expected": "deleting",
+                    "matcher": "pathAny",
+                    "state": "failure",
+                    "argument": "Clusters[].ClusterStatus"
+                }
+            ]
+        }
+    }
+}
+```
+
+For more information about writing custom waiters, see the [README.md](https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/waiters/README.md) for custom waiters.
+
+In some cases, a built-in or custom waiter may not be able to solve the problem. In such cases, the asynchronous method used to poll the boto3 API would need to be defined in the hook of the service being used. This method is essentially the same as the synchronous version, except that it will use the aiobotocore client and will be awaited. For the Redshift example, the async `describe_clusters` method would look as follows:
+
+```python
+async with self.async_conn as client:
+    response = await client.describe_clusters(ClusterIdentifier=self.cluster_identifier)
+```
+
+This async method can be used in the Trigger to poll the boto3 API. The polling logic will need to be implemented manually, taking care to use `asyncio.sleep()` rather than `time.sleep()`.
+
+The last step in the Trigger is to yield a `TriggerEvent` that will be used to alert the `Triggerer` that the Trigger has finished execution. The `TriggerEvent` can pass information from the trigger to the `method_name` method named in the `self.defer` call in the operator. In the Redshift example, the `TriggerEvent` would look as follows:
+
+```python
+yield TriggerEvent({"status": "success", "message": "Cluster Created"})
+```
+
+The object passed through the `TriggerEvent` can be captured in the `method_name` method through an `event` parameter. This can be used to determine what needs to be done based on the outcome of the Trigger execution. In the Redshift case, we can simply check the status of the event, and raise an exception if something went wrong.
+
+```python
+def execute_complete(self, context, event=None):
+    if event["status"] != "success":
+        raise AirflowException(f"Error creating cluster: {event}")
+    return
+```
diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
index a32a6efa19924..2f831fa14c2f1 100644
--- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -18,7 +18,8 @@
 from typing import Any, AsyncIterator
 
-from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook
+from airflow.compat.functools import cached_property
+from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
@@ -85,3 +86,54 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
         except Exception as e:
             if self.attempts < 1:
                 yield TriggerEvent({"status": "error", "message": str(e)})
+
+
+class RedshiftCreateClusterTrigger(BaseTrigger):
+    """
+    Trigger for RedshiftCreateClusterOperator.
+    The trigger will asynchronously poll the boto3 API and wait for the
+    Redshift cluster to be in the `available` state.
+ + :param cluster_identifier: A unique identifier for the cluster. + :param poll_interval: The amount of time in seconds to wait between attempts. + :param max_attempt: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__( + self, + cluster_identifier: str, + poll_interval: int, + max_attempt: int, + aws_conn_id: str, + ): + self.cluster_identifier = cluster_identifier + self.poll_interval = poll_interval + self.max_attempt = max_attempt + self.aws_conn_id = aws_conn_id + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftCreateClusterTrigger", + { + "cluster_identifier": str(self.cluster_identifier), + "poll_interval": str(self.poll_interval), + "max_attempt": str(self.max_attempt), + "aws_conn_id": str(self.aws_conn_id), + }, + ) + + @cached_property + def hook(self) -> RedshiftHook: + return RedshiftHook(aws_conn_id=self.aws_conn_id) + + async def run(self): + async with self.hook.async_conn as client: + await client.get_waiter("cluster_available").wait( + ClusterIdentifier=self.cluster_identifier, + WaiterConfig={ + "Delay": int(self.poll_interval), + "MaxAttempts": int(self.max_attempt), + }, + ) + yield TriggerEvent({"status": "success", "message": "Cluster Created"}) diff --git a/airflow/providers/amazon/aws/waiters/base_waiter.py b/airflow/providers/amazon/aws/waiters/base_waiter.py index 0d9f8a1d4e407..488767a084a21 100644 --- a/airflow/providers/amazon/aws/waiters/base_waiter.py +++ b/airflow/providers/amazon/aws/waiters/base_waiter.py @@ -28,9 +28,20 @@ class BaseBotoWaiter: For more details, see airflow/providers/amazon/aws/waiters/README.md """ - def __init__(self, client: boto3.client, model_config: dict) -> None: + def __init__(self, client: boto3.client, model_config: dict, deferrable: bool = False) -> None: self.model = WaiterModel(model_config) self.client = client + self.deferrable = deferrable + + def _get_async_waiter_with_client(self, waiter_name: str): + from aiobotocore.waiter import create_waiter_with_client as create_async_waiter_with_client + + return create_async_waiter_with_client( + waiter_name=waiter_name, waiter_model=self.model, client=self.client + ) def waiter(self, waiter_name: str) -> Waiter: + if self.deferrable: + return self._get_async_waiter_with_client(waiter_name=waiter_name) + return create_waiter_with_client(waiter_name=waiter_name, waiter_model=self.model, client=self.client) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index e2f3c142324ed..79640d68b8578 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -72,6 +72,7 @@ dependencies: - mypy-boto3-rds>=1.24.0 - mypy-boto3-redshift-data>=1.24.0 - mypy-boto3-appflow>=1.24.0 + - aiobotocore[boto3]>=2.2.0 integrations: - integration-name: Amazon Athena diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 79e919591815e..629a5b0613c3a 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -17,6 +17,7 @@ }, "amazon": { "deps": [ + "aiobotocore[boto3]>=2.2.0", "apache-airflow-providers-common-sql>=1.3.1", "apache-airflow>=2.3.0", "asgiref", diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py index a06a0d9fb3feb..3faafd6bbee17 100644 --- a/tests/providers/amazon/aws/hooks/test_base_aws.py +++ 
b/tests/providers/amazon/aws/hooks/test_base_aws.py @@ -47,6 +47,8 @@ from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper from tests.test_utils.config import conf_vars +pytest.importorskip("aiobotocore") + MOCK_AWS_CONN_ID = "mock-conn-id" MOCK_CONN_TYPE = "aws" MOCK_BOTO3_SESSION = mock.MagicMock(return_value="Mock boto3.session.Session") @@ -227,20 +229,39 @@ def test_create_session_from_credentials(self, mock_boto3_session, region_name, mock_boto3_session.assert_called_once_with(**expected_arguments) assert session == MOCK_BOTO3_SESSION + @pytest.mark.parametrize("region_name", ["eu-central-1", None]) + @pytest.mark.parametrize("profile_name", ["default", None]) + def test_async_create_session_from_credentials(self, region_name, profile_name): + mock_conn = Connection( + conn_type=MOCK_CONN_TYPE, conn_id=MOCK_AWS_CONN_ID, extra={"profile_name": profile_name} + ) + mock_conn_config = AwsConnectionWrapper(conn=mock_conn) + sf = BaseSessionFactory(conn=mock_conn_config, region_name=region_name, config=None) + async_session = sf.create_session(deferrable=True) + if region_name: + session_region = async_session.get_config_variable("region") + assert session_region == region_name + + session_profile = async_session.get_config_variable("profile") + + assert session_profile == profile_name + + config_for_credentials_test = [ + ( + "assume-with-initial-creds", + { + "aws_access_key_id": "mock_aws_access_key_id", + "aws_secret_access_key": "mock_aws_access_key_id", + "aws_session_token": "mock_aws_session_token", + }, + ), + ("assume-without-initial-creds", {}), + ] + @mock_sts @pytest.mark.parametrize( "conn_id, conn_extra", - [ - ( - "assume-with-initial-creds", - { - "aws_access_key_id": "mock_aws_access_key_id", - "aws_secret_access_key": "mock_aws_access_key_id", - "aws_session_token": "mock_aws_session_token", - }, - ), - ("assume-without-initial-creds", {}), - ], + config_for_credentials_test, ) @pytest.mark.parametrize("region_name", ["ap-southeast-2", "sa-east-1"]) def test_get_credentials_from_role_arn(self, conn_id, conn_extra, region_name): @@ -258,6 +279,41 @@ def test_get_credentials_from_role_arn(self, conn_id, conn_extra, region_name): # It shouldn't be 'explicit' which refers in this case to initial credentials. assert session.get_credentials().method == "sts-assume-role" + @pytest.mark.asyncio + @pytest.mark.parametrize( + "conn_id, conn_extra", + config_for_credentials_test, + ) + @pytest.mark.parametrize("region_name", ["ap-southeast-2", "sa-east-1"]) + async def test_async_get_credentials_from_role_arn(self, conn_id, conn_extra, region_name): + """Test RefreshableCredentials with assume_role for async_conn.""" + with mock.patch( + "airflow.providers.amazon.aws.hooks.base_aws.BaseSessionFactory._refresh_credentials" + ) as mock_refresh: + + def side_effect(): + return { + "access_key": "mock-AccessKeyId", + "secret_key": "mock-SecretAccessKey", + "token": "mock-SessionToken", + "expiry_time": datetime.now(timezone.utc).isoformat(), + } + + mock_refresh.side_effect = side_effect + extra = { + **conn_extra, + "role_arn": "arn:aws:iam::123456:role/role_arn", + "region_name": region_name, + } + conn = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=extra) + sf = BaseSessionFactory(conn=conn) + session = sf.create_session(deferrable=True) + assert session.region_name == region_name + # Validate method of botocore credentials provider. + # It shouldn't be 'explicit' which refers in this case to initial credentials. 
+ credentials = await session.get_credentials() + assert credentials.method == "sts-assume-role" + class TestAwsBaseHook: @mock_emr @@ -394,7 +450,12 @@ def mock_assume_role(**kwargs): with mock.patch( "airflow.providers.amazon.aws.hooks.base_aws.requests.Session.get" - ) as mock_get, mock.patch("airflow.providers.amazon.aws.hooks.base_aws.boto3") as mock_boto3: + ) as mock_get, mock.patch( + "airflow.providers.amazon.aws.hooks.base_aws.boto3" + ) as mock_boto3, mock.patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True mock_get.return_value.ok = True mock_client = mock_boto3.session.Session.return_value.client @@ -589,7 +650,12 @@ def mock_assume_role_with_saml(**kwargs): with mock.patch("builtins.__import__", side_effect=import_mock), mock.patch( "airflow.providers.amazon.aws.hooks.base_aws.requests.Session.get" - ) as mock_get, mock.patch("airflow.providers.amazon.aws.hooks.base_aws.boto3") as mock_boto3: + ) as mock_get, mock.patch( + "airflow.providers.amazon.aws.hooks.base_aws.boto3" + ) as mock_boto3, mock.patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True mock_get.return_value.ok = True mock_client = mock_boto3.session.Session.return_value.client diff --git a/tests/providers/amazon/aws/hooks/test_emr_containers.py b/tests/providers/amazon/aws/hooks/test_emr_containers.py index 8a5f1303a6921..9be48a08ba04a 100644 --- a/tests/providers/amazon/aws/hooks/test_emr_containers.py +++ b/tests/providers/amazon/aws/hooks/test_emr_containers.py @@ -54,8 +54,9 @@ def test_init(self): assert self.emr_containers.aws_conn_id == "aws_default" assert self.emr_containers.virtual_cluster_id == "vc1234" + @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True) @mock.patch("boto3.session.Session") - def test_create_emr_on_eks_cluster(self, mock_session): + def test_create_emr_on_eks_cluster(self, mock_session, mock_isinstance): emr_client_mock = mock.MagicMock() emr_client_mock.create_virtual_cluster.return_value = CREATE_EMR_ON_EKS_CLUSTER_RETURN emr_session_mock = mock.MagicMock() @@ -69,8 +70,9 @@ def test_create_emr_on_eks_cluster(self, mock_session): ) assert emr_on_eks_create_cluster_response == "vc1234" + @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True) @mock.patch("boto3.session.Session") - def test_submit_job(self, mock_session): + def test_submit_job(self, mock_session, mock_isinstance): # Mock out the emr_client creator emr_client_mock = mock.MagicMock() emr_client_mock.start_job_run.return_value = SUBMIT_JOB_SUCCESS_RETURN @@ -88,8 +90,9 @@ def test_submit_job(self, mock_session): ) assert emr_containers_job == "job123456" + @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True) @mock.patch("boto3.session.Session") - def test_query_status_polling_when_terminal(self, mock_session): + def test_query_status_polling_when_terminal(self, mock_session, mock_isinstance): emr_client_mock = mock.MagicMock() emr_session_mock = mock.MagicMock() emr_session_mock.client.return_value = emr_client_mock @@ -101,8 +104,9 @@ def test_query_status_polling_when_terminal(self, mock_session): emr_client_mock.describe_job_run.assert_called_once() assert query_status == "COMPLETED" + @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True) @mock.patch("boto3.session.Session") - def test_query_status_polling_with_timeout(self, 
mock_session): + def test_query_status_polling_with_timeout(self, mock_session, mock_isinstance): emr_client_mock = mock.MagicMock() emr_session_mock = mock.MagicMock() emr_session_mock.client.return_value = emr_client_mock diff --git a/tests/providers/amazon/aws/operators/test_cloud_formation.py b/tests/providers/amazon/aws/operators/test_cloud_formation.py index 2600096df2852..1a1088adab057 100644 --- a/tests/providers/amazon/aws/operators/test_cloud_formation.py +++ b/tests/providers/amazon/aws/operators/test_cloud_formation.py @@ -55,7 +55,10 @@ def test_create_stack(self): dag=DAG("test_dag_id", default_args=DEFAULT_ARGS), ) - with mock.patch("boto3.session.Session", self.boto3_session_mock): + with mock.patch("boto3.session.Session", self.boto3_session_mock), mock.patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True operator.execute(self.mock_context) self.cloudformation_client_mock.create_stack.assert_any_call( @@ -84,7 +87,10 @@ def test_delete_stack(self): dag=DAG("test_dag_id", default_args=DEFAULT_ARGS), ) - with mock.patch("boto3.session.Session", self.boto3_session_mock): + with mock.patch("boto3.session.Session", self.boto3_session_mock), mock.patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True operator.execute(self.mock_context) self.cloudformation_client_mock.delete_stack.assert_any_call(StackName=stack_name) diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py index 6f9c1c1b45922..67a4090563067 100644 --- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py +++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py @@ -151,7 +151,10 @@ def test_render_template_from_file(self): assert json.loads(test_task.steps) == file_steps # String in job_flow_overrides (i.e. 
from loaded as a file) is not "parsed" until inside execute() - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True test_task.execute(None) self.emr_client_mock.add_job_flow_steps.assert_called_once_with( @@ -161,7 +164,10 @@ def test_render_template_from_file(self): def test_execute_returns_step_id(self): self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True assert self.operator.execute(self.mock_context) == ["s-2LH3R5GW3A53T"] def test_init_with_cluster_name(self): @@ -169,7 +175,10 @@ def test_init_with_cluster_name(self): self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True with patch( "airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name" ) as mock_get_cluster_id_by_name: diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py index 7f0e2942f6c98..ddc11b15c56ce 100644 --- a/tests/providers/amazon/aws/operators/test_emr_containers.py +++ b/tests/providers/amazon/aws/operators/test_emr_containers.py @@ -85,7 +85,10 @@ def test_execute_with_polling(self, mock_check_query_status): emr_session_mock.client.return_value = emr_client_mock boto3_session_mock = MagicMock(return_value=emr_session_mock) - with patch("boto3.session.Session", boto3_session_mock): + with patch("boto3.session.Session", boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True assert self.emr_container.execute(None) == "job123456" assert mock_check_query_status.call_count == 5 @@ -130,7 +133,10 @@ def test_execute_with_polling_timeout(self, mock_check_query_status): max_polling_attempts=3, ) - with patch("boto3.session.Session", boto3_session_mock): + with patch("boto3.session.Session", boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True with pytest.raises(AirflowException) as ctx: timeout_container.execute(None) diff --git a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py index 1920f915b88db..cac14ddf59cd1 100644 --- a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py +++ b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py @@ -129,7 +129,10 @@ def test_render_template_from_file(self): boto3_session_mock = MagicMock(return_value=emr_session_mock) # String in job_flow_overrides (i.e. 
from loaded as a file) is not "parsed" until inside execute() - with patch("boto3.session.Session", boto3_session_mock): + with patch("boto3.session.Session", boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True self.operator.execute(self.mock_context) expected_args = { @@ -161,7 +164,10 @@ def test_execute_returns_job_id(self): emr_session_mock.client.return_value = self.emr_client_mock boto3_session_mock = MagicMock(return_value=emr_session_mock) - with patch("boto3.session.Session", boto3_session_mock): + with patch("boto3.session.Session", boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True assert self.operator.execute(self.mock_context) == JOB_FLOW_ID @mock.patch("botocore.waiter.get_service_module_name", return_value="emr") @@ -175,7 +181,10 @@ def test_execute_with_wait(self, mock_waiter, _): boto3_session_mock = MagicMock(return_value=emr_session_mock) self.operator.wait_for_completion = True - with patch("boto3.session.Session", boto3_session_mock): + with patch("boto3.session.Session", boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True assert self.operator.execute(self.mock_context) == JOB_FLOW_ID mock_waiter.assert_called_once_with(mock.ANY, ClusterId=JOB_FLOW_ID, WaiterConfig=mock.ANY) assert_expected_waiter_type(mock_waiter, "job_flow_waiting") diff --git a/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py b/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py index 12efbcf198326..6fc20ed430b1d 100644 --- a/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py +++ b/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py @@ -63,12 +63,18 @@ def test_init(self): def test_execute_returns_step_concurrency(self): self.emr_client_mock.modify_cluster.return_value = MODIFY_CLUSTER_SUCCESS_RETURN - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True assert self.operator.execute(self.mock_context) == 1 def test_execute_returns_error(self): self.emr_client_mock.modify_cluster.return_value = MODIFY_CLUSTER_ERROR_RETURN - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True with pytest.raises(AirflowException): self.operator.execute(self.mock_context) diff --git a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py index 594c6d775b9f9..4a639cc154719 100644 --- a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py +++ b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py @@ -37,7 +37,10 @@ def setup_method(self): self.boto3_session_mock = MagicMock(return_value=mock_emr_session) def test_execute_terminates_the_job_flow_and_does_not_error(self): - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as 
mock_isinstance: + mock_isinstance.return_value = True operator = EmrTerminateJobFlowOperator( task_id="test_task", job_flow_id="j-8989898989", aws_conn_id="aws_default" ) diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py index 8ec6f67eef152..64a276f14d02c 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py @@ -116,6 +116,21 @@ def test_create_multi_node_cluster(self, mock_get_conn): # wait_for_completion is False so check waiter is not called mock_get_conn.return_value.get_waiter.assert_not_called() + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn") + def test_create_cluster_deferrable(self, mock_get_conn): + redshift_operator = RedshiftCreateClusterOperator( + task_id="task_test", + cluster_identifier="test-cluster", + node_type="dc2.large", + master_username="adminuser", + master_user_password="Test123$", + cluster_type="single-node", + deferrable=True, + ) + + with pytest.raises(TaskDeferred): + redshift_operator.execute(None) + class TestRedshiftCreateClusterSnapshotOperator: @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status") diff --git a/tests/providers/amazon/aws/sensors/test_cloud_formation.py b/tests/providers/amazon/aws/sensors/test_cloud_formation.py index 14610df2671b5..9aef7fae6dd87 100644 --- a/tests/providers/amazon/aws/sensors/test_cloud_formation.py +++ b/tests/providers/amazon/aws/sensors/test_cloud_formation.py @@ -51,7 +51,10 @@ def test_poke(self): assert op.poke({}) def test_poke_false(self): - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True self.cloudformation_client_mock.describe_stacks.return_value = { "Stacks": [{"StackStatus": "CREATE_IN_PROGRESS"}] } @@ -59,7 +62,10 @@ def test_poke_false(self): assert not op.poke({}) def test_poke_stack_in_unsuccessful_state(self): - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True self.cloudformation_client_mock.describe_stacks.return_value = { "Stacks": [{"StackStatus": "bar"}] } @@ -91,7 +97,10 @@ def test_poke(self): assert op.poke({}) def test_poke_false(self): - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True self.cloudformation_client_mock.describe_stacks.return_value = { "Stacks": [{"StackStatus": "DELETE_IN_PROGRESS"}] } @@ -99,7 +108,10 @@ def test_poke_false(self): assert not op.poke({}) def test_poke_stack_in_unsuccessful_state(self): - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True self.cloudformation_client_mock.describe_stacks.return_value = { "Stacks": [{"StackStatus": "bar"}] } diff --git 
a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py index a10e68abb8483..c81d9d4855747 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py +++ b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py @@ -208,7 +208,10 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self DESCRIBE_CLUSTER_RUNNING_RETURN, DESCRIBE_CLUSTER_TERMINATED_RETURN, ] - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True operator = EmrJobFlowSensor( task_id="test_task", poke_interval=0, job_flow_id="j-8989898989", aws_conn_id="aws_default" ) @@ -227,7 +230,10 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_failed_state_with_e DESCRIBE_CLUSTER_RUNNING_RETURN, DESCRIBE_CLUSTER_TERMINATED_WITH_ERRORS_RETURN, ] - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True operator = EmrJobFlowSensor( task_id="test_task", poke_interval=0, job_flow_id="j-8989898989", aws_conn_id="aws_default" ) @@ -250,7 +256,10 @@ def test_different_target_states(self): DESCRIBE_CLUSTER_TERMINATED_RETURN, # will not be used DESCRIBE_CLUSTER_TERMINATED_WITH_ERRORS_RETURN, # will not be used ] - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True operator = EmrJobFlowSensor( task_id="test_task", poke_interval=0, diff --git a/tests/providers/amazon/aws/sensors/test_emr_step.py b/tests/providers/amazon/aws/sensors/test_emr_step.py index d053bda97c00f..cca5f417f83d0 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_step.py +++ b/tests/providers/amazon/aws/sensors/test_emr_step.py @@ -165,7 +165,10 @@ def test_step_completed(self): DESCRIBE_JOB_STEP_COMPLETED_RETURN, ] - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True self.sensor.execute(None) assert self.emr_client_mock.describe_step.call_count == 2 @@ -181,7 +184,10 @@ def test_step_cancelled(self): DESCRIBE_JOB_STEP_CANCELLED_RETURN, ] - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True with pytest.raises(AirflowException): self.sensor.execute(None) @@ -191,7 +197,10 @@ def test_step_failed(self): DESCRIBE_JOB_STEP_FAILED_RETURN, ] - with patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True with pytest.raises(AirflowException): self.sensor.execute(None) @@ -201,6 +210,9 @@ def test_step_interrupted(self): DESCRIBE_JOB_STEP_INTERRUPTED_RETURN, ] - with 
patch("boto3.session.Session", self.boto3_session_mock): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True with pytest.raises(AirflowException): self.sensor.execute(None) diff --git a/tests/providers/amazon/aws/triggers/__init__.py b/tests/providers/amazon/aws/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py new file mode 100644 index 0000000000000..941258659e9ae --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import sys + +import pytest + +from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftCreateClusterTrigger +from airflow.triggers.base import TriggerEvent + +if sys.version_info < (3, 8): + from asynctest import CoroutineMock as AsyncMock, mock as async_mock +else: + from unittest import mock as async_mock + from unittest.mock import AsyncMock + + +TEST_CLUSTER_IDENTIFIER = "test-cluster" +TEST_POLL_INTERVAL = 10 +TEST_MAX_ATTEMPT = 10 +TEST_AWS_CONN_ID = "test-aws-id" + + +class TestRedshiftCreateClusterTrigger: + def test_redshift_create_cluster_trigger_serialize(self): + redshift_create_cluster_trigger = RedshiftCreateClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempt=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ) + class_path, args = redshift_create_cluster_trigger.serialize() + assert ( + class_path + == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftCreateClusterTrigger" + ) + assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER + assert args["poll_interval"] == str(TEST_POLL_INTERVAL) + assert args["max_attempt"] == str(TEST_MAX_ATTEMPT) + assert args["aws_conn_id"] == TEST_AWS_CONN_ID + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") + async def test_redshift_create_cluster_trigger_run(self, mock_async_conn): + mock = async_mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock + mock.get_waiter().wait = AsyncMock() + + redshift_create_cluster_trigger = RedshiftCreateClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempt=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ) + + generator = redshift_create_cluster_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "message": "Cluster Created"})