diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py index 7341ebcf70f24..27e194d3aa8d0 100644 --- a/airflow/providers/amazon/aws/hooks/athena.py +++ b/airflow/providers/amazon/aws/hooks/athena.py @@ -253,7 +253,7 @@ def poll_query_status( try: wait( waiter=self.get_waiter("query_complete"), - waiter_delay=sleep_time or self.sleep_time, + waiter_delay=self.sleep_time if sleep_time is None else sleep_time, waiter_max_attempts=max_polling_attempts or 120, args={"QueryExecutionId": query_execution_id}, failure_message=f"Error while waiting for query {query_execution_id} to complete", diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index 9dd954d05cd80..e6221ae3e0af4 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -41,7 +41,7 @@ from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink from airflow.providers.amazon.aws.triggers.batch import ( BatchCreateComputeEnvironmentTrigger, - BatchOperatorTrigger, + BatchJobTrigger, ) from airflow.providers.amazon.aws.utils import trim_none_values from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher @@ -221,12 +221,12 @@ def execute(self, context: Context): if self.deferrable: self.defer( timeout=self.execution_timeout, - trigger=BatchOperatorTrigger( + trigger=BatchJobTrigger( job_id=self.job_id, - max_retries=self.max_retries or 10, + waiter_max_attempts=self.max_retries or 10, aws_conn_id=self.aws_conn_id, region_name=self.region_name, - poll_interval=self.poll_interval, + waiter_delay=self.poll_interval, ), method_name="execute_complete", ) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index e5833bf4c3d53..6df72c6264f0e 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -33,7 
+33,11 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, should_retry_eni from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook -from airflow.providers.amazon.aws.triggers.ecs import ClusterWaiterTrigger, TaskDoneTrigger +from airflow.providers.amazon.aws.triggers.ecs import ( + ClusterActiveTrigger, + ClusterInactiveTrigger, + TaskDoneTrigger, +) from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher from airflow.utils.helpers import prune_dict from airflow.utils.session import provide_session @@ -139,13 +143,12 @@ def execute(self, context: Context): self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state) elif self.deferrable: self.defer( - trigger=ClusterWaiterTrigger( - waiter_name="cluster_active", + trigger=ClusterActiveTrigger( cluster_arn=cluster_details["clusterArn"], waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, - region=self.region, + region_name=self.region, ), method_name="_complete_exec_with_cluster_desc", # timeout is set to ensure that if a trigger dies, the timeout does not restart @@ -217,13 +220,12 @@ def execute(self, context: Context): self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state) elif self.deferrable: self.defer( - trigger=ClusterWaiterTrigger( - waiter_name="cluster_inactive", + trigger=ClusterInactiveTrigger( cluster_arn=cluster_details["clusterArn"], waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, - region=self.region, + region_name=self.region, ), method_name="_complete_exec_with_cluster_desc", # timeout is set to ensure that if a trigger dies, the timeout does not restart diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py index 56e9269f88d79..6858f801121bd 100644 --- 
a/airflow/providers/amazon/aws/operators/eks.py +++ b/airflow/providers/amazon/aws/operators/eks.py @@ -31,8 +31,9 @@ from airflow.providers.amazon.aws.hooks.eks import EksHook from airflow.providers.amazon.aws.triggers.eks import ( EksCreateFargateProfileTrigger, + EksCreateNodegroupTrigger, EksDeleteFargateProfileTrigger, - EksNodegroupTrigger, + EksDeleteNodegroupTrigger, ) from airflow.providers.amazon.aws.utils.waiter_with_logging import wait from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction @@ -413,12 +414,11 @@ def execute(self, context: Context): if self.deferrable: self.defer( - trigger=EksNodegroupTrigger( - waiter_name="nodegroup_active", + trigger=EksCreateNodegroupTrigger( cluster_name=self.cluster_name, nodegroup_name=self.nodegroup_name, aws_conn_id=self.aws_conn_id, - region=self.region, + region_name=self.region, waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, ), @@ -711,12 +711,11 @@ def execute(self, context: Context): eks_hook.delete_nodegroup(clusterName=self.cluster_name, nodegroupName=self.nodegroup_name) if self.deferrable: self.defer( - trigger=EksNodegroupTrigger( - waiter_name="nodegroup_deleted", + trigger=EksDeleteNodegroupTrigger( cluster_name=self.cluster_name, nodegroup_name=self.nodegroup_name, aws_conn_id=self.aws_conn_id, - region=self.region, + region_name=self.region, waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, ), diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 8330a586e4426..4cb070da733b1 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -577,7 +577,7 @@ def execute(self, context: Context) -> str | None: virtual_cluster_id=self.virtual_cluster_id, job_id=self.job_id, aws_conn_id=self.aws_conn_id, - poll_interval=self.poll_interval, + waiter_delay=self.poll_interval, ), method_name="execute_complete", ) @@ 
-943,8 +943,8 @@ def execute(self, context: Context) -> None: self.defer( trigger=EmrTerminateJobFlowTrigger( job_flow_id=self.job_flow_id, - poll_interval=self.waiter_delay, - max_attempts=self.waiter_max_attempts, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", diff --git a/airflow/providers/amazon/aws/operators/glue_crawler.py b/airflow/providers/amazon/aws/operators/glue_crawler.py index 71e2607039c35..c9fb298ee7c0d 100644 --- a/airflow/providers/amazon/aws/operators/glue_crawler.py +++ b/airflow/providers/amazon/aws/operators/glue_crawler.py @@ -96,7 +96,7 @@ def execute(self, context: Context): self.defer( trigger=GlueCrawlerCompleteTrigger( crawler_name=crawler_name, - poll_interval=self.poll_interval, + waiter_delay=self.poll_interval, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index cde4a32226e91..ed6aa79e9f2dd 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -267,8 +267,8 @@ def execute(self, context: Context): self.defer( trigger=RedshiftCreateClusterTrigger( cluster_identifier=self.cluster_identifier, - poll_interval=self.poll_interval, - max_attempt=self.max_attempt, + waiter_delay=self.poll_interval, + waiter_max_attempts=self.max_attempt, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", @@ -361,8 +361,8 @@ def execute(self, context: Context) -> Any: self.defer( trigger=RedshiftCreateClusterSnapshotTrigger( cluster_identifier=self.cluster_identifier, - poll_interval=self.poll_interval, - max_attempts=self.max_attempt, + waiter_delay=self.poll_interval, + waiter_max_attempts=self.max_attempt, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", @@ -510,8 +510,8 @@ def execute(self, 
context: Context): self.defer( trigger=RedshiftResumeClusterTrigger( cluster_identifier=self.cluster_identifier, - poll_interval=self.poll_interval, - max_attempts=self.max_attempts, + waiter_delay=self.poll_interval, + waiter_max_attempts=self.max_attempts, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", @@ -598,8 +598,8 @@ def execute(self, context: Context): self.defer( trigger=RedshiftPauseClusterTrigger( cluster_identifier=self.cluster_identifier, - poll_interval=self.poll_interval, - max_attempts=self.max_attempts, + waiter_delay=self.poll_interval, + waiter_max_attempts=self.max_attempts, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", @@ -690,8 +690,8 @@ def execute(self, context: Context): timeout=timedelta(seconds=self.max_attempts * self.poll_interval + 60), trigger=RedshiftDeleteClusterTrigger( cluster_identifier=self.cluster_identifier, - poll_interval=self.poll_interval, - max_attempts=self.max_attempts, + waiter_delay=self.poll_interval, + waiter_max_attempts=self.max_attempts, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py index 32da5b4cf2524..454d123f9471d 100644 --- a/airflow/providers/amazon/aws/sensors/batch.py +++ b/airflow/providers/amazon/aws/sensors/batch.py @@ -25,7 +25,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook -from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger +from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: @@ -98,11 +98,12 @@ def execute(self, context: Context) -> None: ) self.defer( timeout=timeout, - trigger=BatchSensorTrigger( + trigger=BatchJobTrigger( job_id=self.job_id, aws_conn_id=self.aws_conn_id, region_name=self.region_name, - 
poke_interval=self.poke_interval, + waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_retries, ), method_name="execute_complete", ) @@ -113,9 +114,10 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None: Relies on trigger to throw an exception, otherwise it assumes execution was successful. """ - if "status" in event and event["status"] == "failure": - raise AirflowException(event["message"]) - self.log.info(event["message"]) + if event["status"] != "success": + raise AirflowException(f"Error while running job: {event}") + job_id = event["job_id"] + self.log.info("Batch Job %s complete", job_id) @deprecated(reason="use `hook` property instead.") def get_hook(self) -> BatchClientHook: diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py index 9953dfa78260c..71935e0b9b6d9 100644 --- a/airflow/providers/amazon/aws/sensors/emr.py +++ b/airflow/providers/amazon/aws/sensors/emr.py @@ -316,7 +316,7 @@ def execute(self, context: Context): virtual_cluster_id=self.virtual_cluster_id, job_id=self.job_id, aws_conn_id=self.aws_conn_id, - poll_interval=self.poll_interval, + waiter_delay=self.poll_interval, ), method_name="execute_complete", ) @@ -501,9 +501,9 @@ def execute(self, context: Context) -> None: timeout=timedelta(seconds=self.poke_interval * self.max_attempts), trigger=EmrTerminateJobFlowTrigger( job_flow_id=self.job_flow_id, - max_attempts=self.max_attempts, + waiter_max_attempts=self.max_attempts, aws_conn_id=self.aws_conn_id, - poll_interval=int(self.poke_interval), + waiter_delay=int(self.poke_interval), ), method_name="execute_complete", ) @@ -628,9 +628,9 @@ def execute(self, context: Context) -> None: trigger=EmrStepSensorTrigger( job_flow_id=self.job_flow_id, step_id=self.step_id, + waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_attempts, aws_conn_id=self.aws_conn_id, - max_attempts=self.max_attempts, - poke_interval=int(self.poke_interval), ), 
method_name="execute_complete", ) diff --git a/airflow/providers/amazon/aws/triggers/athena.py b/airflow/providers/amazon/aws/triggers/athena.py index efae559470a28..636c1350598ea 100644 --- a/airflow/providers/amazon/aws/triggers/athena.py +++ b/airflow/providers/amazon/aws/triggers/athena.py @@ -16,14 +16,12 @@ # under the License. from __future__ import annotations -from typing import Any - from airflow.providers.amazon.aws.hooks.athena import AthenaHook -from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait -from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger -class AthenaTrigger(BaseTrigger): +class AthenaTrigger(AwsBaseWaiterTrigger): """ Trigger for RedshiftCreateClusterOperator. @@ -31,46 +29,30 @@ class AthenaTrigger(BaseTrigger): Redshift cluster to be in the `available` state. :param query_execution_id: ID of the Athena query execution to watch - :param poll_interval: The amount of time in seconds to wait between attempts. - :param max_attempt: The maximum number of attempts to be made. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. """ def __init__( self, query_execution_id: str, - poll_interval: int, - max_attempt: int, + waiter_delay: int, + waiter_max_attempts: int, aws_conn_id: str, ): - self.query_execution_id = query_execution_id - self.poll_interval = poll_interval - self.max_attempt = max_attempt - self.aws_conn_id = aws_conn_id - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - self.__class__.__module__ + "." 
+ self.__class__.__qualname__, - { - "query_execution_id": str(self.query_execution_id), - "poll_interval": str(self.poll_interval), - "max_attempt": str(self.max_attempt), - "aws_conn_id": str(self.aws_conn_id), - }, + super().__init__( + serialized_fields={"query_execution_id": query_execution_id}, + waiter_name="query_complete", + waiter_args={"QueryExecutionId": query_execution_id}, + failure_message=f"Error while waiting for query {query_execution_id} to complete", + status_message=f"Query execution id: {query_execution_id}", + status_queries=["QueryExecution.Status"], + return_value=query_execution_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, ) - async def run(self): - hook = AthenaHook(self.aws_conn_id) - async with hook.async_conn as client: - waiter = hook.get_waiter("query_complete", deferrable=True, client=client) - await async_wait( - waiter=waiter, - waiter_delay=self.poll_interval, - waiter_max_attempts=self.max_attempt, - args={"QueryExecutionId": self.query_execution_id}, - failure_message=f"Error while waiting for query {self.query_execution_id} to complete", - status_message=f"Query execution id: {self.query_execution_id}, " - "Query is still in non-terminal state", - status_args=["QueryExecution.Status.State"], - ) - yield TriggerEvent({"status": "success", "value": self.query_execution_id}) + def hook(self) -> AwsGenericHook: + return AthenaHook(self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/triggers/base.py b/airflow/providers/amazon/aws/triggers/base.py new file mode 100644 index 0000000000000..41f7d2dc33d79 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/base.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, AsyncIterator + +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook +from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class AwsBaseWaiterTrigger(BaseTrigger): + """ + Base class for all AWS Triggers that follow the "standard" model of just waiting on a waiter. + + Subclasses need to implement the hook() method. + + :param serialized_fields: Fields that are specific to the subclass trigger and need to be serialized + to be passed to the __init__ method on deserialization. + The conn id, region, and waiter delay & attempts are always serialized. + format: {<parameter_name>: <parameter_value>} + + :param waiter_name: The name of the (possibly custom) boto waiter to use. + + :param waiter_args: The arguments to pass to the waiter. + :param failure_message: The message to log if a failure state is reached. + :param status_message: The message logged when printing the status of the service. + :param status_queries: A list containing the JMESPath queries to retrieve status information from + the waiter response. See https://jmespath.org/tutorial.html + + :param return_key: The key to use for the return_value in the TriggerEvent this emits on success. + Defaults to "value".
+ :param return_value: A value that'll be returned in the return_key field of the TriggerEvent. + Set to None if there is nothing to return. + + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. To be used to build the hook. + :param region_name: The AWS region where the resources to watch are. To be used to build the hook. + """ + + def __init__( + self, + *, + serialized_fields: dict[str, Any], + waiter_name: str, + waiter_args: dict[str, Any], + failure_message: str, + status_message: str, + status_queries: list[str], + return_key: str = "value", + return_value: Any, + waiter_delay: int, + waiter_max_attempts: int, + aws_conn_id: str | None, + region_name: str | None = None, + ): + # parameters that should be hardcoded in the child's implementation + self.serialized_fields = serialized_fields + + self.waiter_name = waiter_name + self.waiter_args = waiter_args + self.failure_message = failure_message + self.status_message = status_message + self.status_queries = status_queries + + self.return_key = return_key + self.return_value = return_value + + # parameters that should be passed directly from the child's parameters + self.waiter_delay = waiter_delay + self.attempts = waiter_max_attempts + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + def serialize(self) -> tuple[str, dict[str, Any]]: + # here we put together the "common" params, + # and whatever extras we got from the subclass in serialized_fields + params = dict( + { + "waiter_delay": self.waiter_delay, + "waiter_max_attempts": self.attempts, + "aws_conn_id": self.aws_conn_id, + }, + **self.serialized_fields, + ) + if self.region_name: + # if we serialize the None value from this, it breaks subclasses that don't have it in their ctor.
+ params["region_name"] = self.region_name + return ( + # remember that self is an instance of the subclass here, not of this class. + self.__class__.__module__ + "." + self.__class__.__qualname__, + params, + ) + + @abstractmethod + def hook(self) -> AwsGenericHook: + """Override in subclasses to return the right hook.""" + ... + + async def run(self) -> AsyncIterator[TriggerEvent]: + hook = self.hook() + async with hook.async_conn as client: + waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=client) + await async_wait( + waiter, + self.waiter_delay, + self.attempts, + self.waiter_args, + self.failure_message, + self.status_message, + self.status_queries, + ) + yield TriggerEvent({"status": "success", self.return_key: self.return_value}) diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py index f7d335280d6ab..87be13933293f 100644 --- a/airflow/providers/amazon/aws/triggers/batch.py +++ b/airflow/providers/amazon/aws/triggers/batch.py @@ -21,12 +21,15 @@ from typing import Any from botocore.exceptions import WaiterError +from deprecated import deprecated +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook -from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger from airflow.triggers.base import BaseTrigger, TriggerEvent +@deprecated(reason="use BatchJobTrigger instead") class BatchOperatorTrigger(BaseTrigger): """ Asynchronously poll the boto3 API and wait for the Batch job to be in the `SUCCEEDED` state. 
@@ -106,6 +109,7 @@ async def run(self): yield TriggerEvent({"status": "success", "job_id": self.job_id}) +@deprecated(reason="use BatchJobTrigger instead") class BatchSensorTrigger(BaseTrigger): """ Checks for the status of a submitted job_id to AWS Batch until it reaches a failure or a success state. @@ -189,56 +193,78 @@ async def run(self): ) -class BatchCreateComputeEnvironmentTrigger(BaseTrigger): +class BatchJobTrigger(AwsBaseWaiterTrigger): + """ + Checks for the status of a submitted job_id to AWS Batch until it reaches a failure or a success state. + + :param job_id: the job ID, to poll for job completion or not + :param region_name: AWS region name to use + Override the region_name in connection (if provided) + :param aws_conn_id: connection id of AWS credentials / region name. If None, + credential boto3 strategy will be used + :param waiter_delay: polling period in seconds to check for the status of the job + :param waiter_max_attempts: The maximum number of attempts to be made. + """ + + def __init__( + self, + job_id: str | None, + region_name: str | None, + aws_conn_id: str | None = "aws_default", + waiter_delay: int = 5, + waiter_max_attempts: int = 720, + ): + super().__init__( + serialized_fields={"job_id": job_id}, + waiter_name="batch_job_complete", + waiter_args={"jobs": [job_id]}, + failure_message=f"Failure while running batch job {job_id}", + status_message=f"Batch job {job_id} not ready yet", + status_queries=["jobs[].status", "jobs[].statusReason"], + return_key="job_id", + return_value=job_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + region_name=region_name, + ) + + def hook(self) -> AwsGenericHook: + return BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + + +class BatchCreateComputeEnvironmentTrigger(AwsBaseWaiterTrigger): + """ + Asynchronously poll the boto3 API and wait for the compute environment to be ready.
- :param job_id: A unique identifier for the cluster. - :param max_retries: The maximum number of attempts to be made. + :param compute_env_arn: The ARN of the compute env. + :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. :param region_name: region name to use in AWS Hook - :param poll_interval: The amount of time in seconds to wait between attempts. + :param waiter_delay: The amount of time in seconds to wait between attempts. """ def __init__( self, - compute_env_arn: str | None = None, - poll_interval: int = 30, - max_retries: int = 10, + compute_env_arn: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 10, aws_conn_id: str | None = "aws_default", region_name: str | None = None, ): - super().__init__() - self.compute_env_arn = compute_env_arn - self.max_retries = max_retries - self.aws_conn_id = aws_conn_id - self.region_name = region_name - self.poll_interval = poll_interval - - def serialize(self) -> tuple[str, dict[str, Any]]: - """Serializes BatchOperatorTrigger arguments and classpath.""" - return ( - self.__class__.__module__ + "." 
+ self.__class__.__qualname__, - { - "compute_env_arn": self.compute_env_arn, - "max_retries": self.max_retries, - "aws_conn_id": self.aws_conn_id, - "region_name": self.region_name, - "poll_interval": self.poll_interval, - }, + super().__init__( + serialized_fields={"compute_env_arn": compute_env_arn}, + waiter_name="compute_env_ready", + waiter_args={"computeEnvironments": [compute_env_arn]}, + failure_message="Failure while creating Compute Environment", + status_message="Compute Environment not ready yet", + status_queries=["computeEnvironments[].status", "computeEnvironments[].statusReason"], + return_value=compute_env_arn, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + region_name=region_name, ) - async def run(self): - hook = BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) - async with hook.async_conn as client: - waiter = hook.get_waiter("compute_env_ready", deferrable=True, client=client) - await async_wait( - waiter=waiter, - waiter_delay=self.poll_interval, - waiter_max_attempts=self.max_retries, - args={"computeEnvironments": [self.compute_env_arn]}, - failure_message="Failure while creating Compute Environment", - status_message="Compute Environment not ready yet", - status_args=["computeEnvironments[].status", "computeEnvironments[].statusReason"], - ) - yield TriggerEvent({"status": "success", "value": self.compute_env_arn}) + def hook(self) -> AwsGenericHook: + return BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py index c8977d33e960a..29ad22e13e79c 100644 --- a/airflow/providers/amazon/aws/triggers/ecs.py +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -22,68 +22,89 @@ from botocore.exceptions import ClientError, WaiterError +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook from 
airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher -from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait from airflow.triggers.base import BaseTrigger, TriggerEvent -class ClusterWaiterTrigger(BaseTrigger): +class ClusterActiveTrigger(AwsBaseWaiterTrigger): """ - Polls the status of a cluster using a given waiter. Can be used to poll for an active or inactive cluster. + Polls the status of a cluster until it's active. - :param waiter_name: Name of the waiter to use, for instance 'cluster_active' or 'cluster_inactive' :param cluster_arn: ARN of the cluster to watch. :param waiter_delay: The amount of time in seconds to wait between attempts. :param waiter_max_attempts: The number of times to ping for status. Will fail after that many unsuccessful attempts. :param aws_conn_id: The Airflow connection used for AWS credentials. - :param region: The AWS region where the cluster is located. + :param region_name: The AWS region where the cluster is located. 
""" def __init__( self, - waiter_name: str, cluster_arn: str, - waiter_delay: int | None, - waiter_max_attempts: int | None, + waiter_delay: int, + waiter_max_attempts: int, aws_conn_id: str | None, - region: str | None, + region_name: str | None, ): - self.cluster_arn = cluster_arn - self.waiter_name = waiter_name - self.waiter_delay = waiter_delay if waiter_delay is not None else 15 # written like this to allow 0 - self.attempts = waiter_max_attempts or 999999999 - self.aws_conn_id = aws_conn_id - self.region = region + super().__init__( + serialized_fields={"cluster_arn": cluster_arn}, + waiter_name="cluster_active", + waiter_args={"clusters": [cluster_arn]}, + failure_message="Failure while waiting for cluster to be available", + status_message="Cluster is not ready yet", + status_queries=["clusters[].status", "failures"], + return_key="arn", + return_value=cluster_arn, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + region_name=region_name, + ) - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - self.__class__.__module__ + "." + self.__class__.__qualname__, - { - "waiter_name": self.waiter_name, - "cluster_arn": self.cluster_arn, - "waiter_delay": self.waiter_delay, - "waiter_max_attempts": self.attempts, - "aws_conn_id": self.aws_conn_id, - "region": self.region, - }, + def hook(self) -> AwsGenericHook: + return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + + +class ClusterInactiveTrigger(AwsBaseWaiterTrigger): + """ + Polls the status of a cluster until it's inactive. + + :param cluster_arn: ARN of the cluster to watch. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The number of times to ping for status. + Will fail after that many unsuccessful attempts. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: The AWS region where the cluster is located. 
+ """ + + def __init__( + self, + cluster_arn: str, + waiter_delay: int, + waiter_max_attempts: int, + aws_conn_id: str | None, + region_name: str | None, + ): + super().__init__( + serialized_fields={"cluster_arn": cluster_arn}, + waiter_name="cluster_inactive", + waiter_args={"clusters": [cluster_arn]}, + failure_message="Failure while waiting for cluster to be deactivated", + status_message="Cluster deactivation is not done yet", + status_queries=["clusters[].status", "failures"], + return_value=cluster_arn, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + region_name=region_name, ) - async def run(self) -> AsyncIterator[TriggerEvent]: - async with EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as client: - waiter = client.get_waiter(self.waiter_name) - await async_wait( - waiter, - self.waiter_delay, - self.attempts, - {"clusters": [self.cluster_arn]}, - "error when checking cluster status", - "Status of cluster", - ["clusters[].status"], - ) - yield TriggerEvent({"status": "success", "arn": self.cluster_arn}) + def hook(self) -> AwsGenericHook: + return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) class TaskDoneTrigger(BaseTrigger): diff --git a/airflow/providers/amazon/aws/triggers/eks.py b/airflow/providers/amazon/aws/triggers/eks.py index d01c88dc88c80..a6fb75eb80fa2 100644 --- a/airflow/providers/amazon/aws/triggers/eks.py +++ b/airflow/providers/amazon/aws/triggers/eks.py @@ -16,18 +16,15 @@ # under the License. 
from __future__ import annotations -import asyncio -from typing import Any +import warnings -from botocore.exceptions import WaiterError - -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook from airflow.providers.amazon.aws.hooks.eks import EksHook -from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait -from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger -class EksCreateFargateProfileTrigger(BaseTrigger): +class EksCreateFargateProfileTrigger(AwsBaseWaiterTrigger): """ Asynchronously wait for the fargate profile to be created. @@ -46,57 +43,35 @@ def __init__( waiter_max_attempts: int, aws_conn_id: str, region: str | None = None, + region_name: str | None = None, ): - self.cluster_name = cluster_name - self.fargate_profile_name = fargate_profile_name - self.waiter_delay = waiter_delay - self.waiter_max_attempts = waiter_max_attempts - self.aws_conn_id = aws_conn_id - self.region = region - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - self.__class__.__module__ + "." 
+ self.__class__.__qualname__, - { - "cluster_name": self.cluster_name, - "fargate_profile_name": self.fargate_profile_name, - "waiter_delay": str(self.waiter_delay), - "waiter_max_attempts": str(self.waiter_max_attempts), - "aws_conn_id": self.aws_conn_id, - "region": self.region, - }, + if region is not None: + warnings.warn( + "please use region_name param instead of region", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + region_name = region + + super().__init__( + serialized_fields={"cluster_name": cluster_name, "fargate_profile_name": fargate_profile_name}, + waiter_name="fargate_profile_active", + waiter_args={"clusterName": cluster_name, "fargateProfileName": fargate_profile_name}, + failure_message="Failure while creating Fargate profile", + status_message="Fargate profile not created yet", + status_queries=["fargateProfile.status"], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + region_name=region_name, ) - async def run(self): - self.hook = EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region) - async with self.hook.async_conn as client: - attempt = 0 - waiter = client.get_waiter("fargate_profile_active") - while attempt < int(self.waiter_max_attempts): - attempt += 1 - try: - await waiter.wait( - clusterName=self.cluster_name, - fargateProfileName=self.fargate_profile_name, - WaiterConfig={"Delay": int(self.waiter_delay), "MaxAttempts": 1}, - ) - break - except WaiterError as error: - if "terminal failure" in str(error): - raise AirflowException(f"Create Fargate Profile failed: {error}") - self.log.info( - "Status of fargate profile is %s", error.last_response["fargateProfile"]["status"] - ) - await asyncio.sleep(int(self.waiter_delay)) - if attempt >= int(self.waiter_max_attempts): - raise AirflowException( - f"Create Fargate Profile failed - max attempts reached: {self.waiter_max_attempts}" - ) - else: - yield TriggerEvent({"status": "success", "message": 
"Fargate Profile Created"}) + def hook(self) -> AwsGenericHook: + return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) -class EksDeleteFargateProfileTrigger(BaseTrigger): +class EksDeleteFargateProfileTrigger(AwsBaseWaiterTrigger): """ Asynchronously wait for the fargate profile to be deleted. @@ -115,59 +90,37 @@ def __init__( waiter_max_attempts: int, aws_conn_id: str, region: str | None = None, + region_name: str | None = None, ): - self.cluster_name = cluster_name - self.fargate_profile_name = fargate_profile_name - self.waiter_delay = waiter_delay - self.waiter_max_attempts = waiter_max_attempts - self.aws_conn_id = aws_conn_id - self.region = region - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - self.__class__.__module__ + "." + self.__class__.__qualname__, - { - "cluster_name": self.cluster_name, - "fargate_profile_name": self.fargate_profile_name, - "waiter_delay": str(self.waiter_delay), - "waiter_max_attempts": str(self.waiter_max_attempts), - "aws_conn_id": self.aws_conn_id, - "region": self.region, - }, + if region is not None: + warnings.warn( + "please use region_name param instead of region", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + region_name = region + + super().__init__( + serialized_fields={"cluster_name": cluster_name, "fargate_profile_name": fargate_profile_name}, + waiter_name="fargate_profile_deleted", + waiter_args={"clusterName": cluster_name, "fargateProfileName": fargate_profile_name}, + failure_message="Failure while deleting Fargate profile", + status_message="Fargate profile not deleted yet", + status_queries=["fargateProfile.status"], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + region_name=region_name, ) - async def run(self): - self.hook = EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region) - async with self.hook.async_conn as client: - attempt = 0 - waiter = 
client.get_waiter("fargate_profile_deleted") - while attempt < int(self.waiter_max_attempts): - attempt += 1 - try: - await waiter.wait( - clusterName=self.cluster_name, - fargateProfileName=self.fargate_profile_name, - WaiterConfig={"Delay": int(self.waiter_delay), "MaxAttempts": 1}, - ) - break - except WaiterError as error: - if "terminal failure" in str(error): - raise AirflowException(f"Delete Fargate Profile failed: {error}") - self.log.info( - "Status of fargate profile is %s", error.last_response["fargateProfile"]["status"] - ) - await asyncio.sleep(int(self.waiter_delay)) - if attempt >= int(self.waiter_max_attempts): - raise AirflowException( - f"Delete Fargate Profile failed - max attempts reached: {self.waiter_max_attempts}" - ) - else: - yield TriggerEvent({"status": "success", "message": "Fargate Profile Deleted"}) + def hook(self) -> AwsGenericHook: + return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) -class EksNodegroupTrigger(BaseTrigger): +class EksCreateNodegroupTrigger(AwsBaseWaiterTrigger): """ - Trigger for EksCreateNodegroupOperator and EksDeleteNodegroupOperator. + Trigger for EksCreateNodegroupOperator. The trigger will asynchronously poll the boto3 API and wait for the nodegroup to be in the state specified by the waiter. @@ -184,54 +137,70 @@ class EksNodegroupTrigger(BaseTrigger): def __init__( self, - waiter_name: str, cluster_name: str, nodegroup_name: str, waiter_delay: int, waiter_max_attempts: int, aws_conn_id: str, - region: str | None, + region_name: str | None, ): - self.waiter_name = waiter_name - self.cluster_name = cluster_name - self.nodegroup_name = nodegroup_name - self.aws_conn_id = aws_conn_id - self.waiter_delay = waiter_delay - self.waiter_max_attempts = waiter_max_attempts - self.region = region - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - self.__class__.__module__ + "." 
+ self.__class__.__qualname__, - { - "waiter_name": self.waiter_name, - "cluster_name": self.cluster_name, - "nodegroup_name": self.nodegroup_name, - "waiter_delay": str(self.waiter_delay), - "waiter_max_attempts": str(self.waiter_max_attempts), - "aws_conn_id": self.aws_conn_id, - "region": self.region, - }, + super().__init__( + serialized_fields={"cluster_name": cluster_name, "nodegroup_name": nodegroup_name}, + waiter_name="nodegroup_active", + waiter_args={"clusterName": cluster_name, "nodegroupName": nodegroup_name}, + failure_message="Error creating nodegroup", + status_message="Nodegroup status is", + status_queries=["nodegroup.status", "nodegroup.health.issues"], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + region_name=region_name, ) - async def run(self): - self.hook = EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region) - async with self.hook.async_conn as client: - waiter = client.get_waiter(self.waiter_name) - await async_wait( - waiter=waiter, - waiter_max_attempts=int(self.waiter_max_attempts), - waiter_delay=int(self.waiter_delay), - args={"clusterName": self.cluster_name, "nodegroupName": self.nodegroup_name}, - failure_message="Error checking nodegroup", - status_message="Nodegroup status is", - status_args=["nodegroup.status"], - ) + def hook(self) -> AwsGenericHook: + return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + + +class EksDeleteNodegroupTrigger(AwsBaseWaiterTrigger): + """ + Trigger for EksDeleteNodegroupOperator. + + The trigger will asynchronously poll the boto3 API and wait for the + nodegroup to be in the state specified by the waiter. 
- yield TriggerEvent( - { - "status": "success", - "cluster_name": self.cluster_name, - "nodegroup_name": self.nodegroup_name, - } + :param waiter_name: Name of the waiter to use, for instance 'nodegroup_active' or 'nodegroup_deleted' + :param cluster_name: The name of the EKS cluster associated with the node group. + :param nodegroup_name: The name of the nodegroup to check. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region: Which AWS region the connection should use. (templated) + If this is None or empty then the default boto3 behaviour is used. + """ + + def __init__( + self, + cluster_name: str, + nodegroup_name: str, + waiter_delay: int, + waiter_max_attempts: int, + aws_conn_id: str, + region_name: str | None, + ): + super().__init__( + serialized_fields={"cluster_name": cluster_name, "nodegroup_name": nodegroup_name}, + waiter_name="nodegroup_deleted", + waiter_args={"clusterName": cluster_name, "nodegroupName": nodegroup_name}, + failure_message="Error deleting nodegroup", + status_message="Nodegroup status is", + status_queries=["nodegroup.status", "nodegroup.health.issues"], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + region_name=region_name, ) + + def hook(self) -> AwsGenericHook: + return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py index dbf620e9cbec3..7deadc1f37467 100644 --- a/airflow/providers/amazon/aws/triggers/emr.py +++ b/airflow/providers/amazon/aws/triggers/emr.py @@ -17,16 +17,16 @@ from __future__ import annotations import asyncio -from functools import cached_property -from typing import Any, AsyncIterator +import warnings +from typing import Any from 
botocore.exceptions import WaiterError -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook -from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger from airflow.triggers.base import BaseTrigger, TriggerEvent -from airflow.utils.helpers import prune_dict class EmrAddStepsTrigger(BaseTrigger): @@ -102,157 +102,113 @@ async def run(self): yield TriggerEvent({"status": "success", "message": "Steps completed", "step_ids": self.step_ids}) -class EmrCreateJobFlowTrigger(BaseTrigger): +class EmrCreateJobFlowTrigger(AwsBaseWaiterTrigger): """ Asynchronously poll the boto3 API and wait for the JobFlow to finish executing. :param job_flow_id: The id of the job flow to wait for. - :param poll_interval: The amount of time in seconds to wait between attempts. - :param max_attempts: The maximum number of attempts to be made. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. """ def __init__( self, job_flow_id: str, - poll_interval: int, - max_attempts: int, - aws_conn_id: str, + poll_interval: int | None = None, # deprecated + max_attempts: int | None = None, # deprecated + aws_conn_id: str | None = None, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, ): - self.job_flow_id = job_flow_id - self.poll_interval = poll_interval - self.max_attempts = max_attempts - self.aws_conn_id = aws_conn_id - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - self.__class__.__module__ + "." 
+ self.__class__.__qualname__, - { - "job_flow_id": self.job_flow_id, - "poll_interval": str(self.poll_interval), - "max_attempts": str(self.max_attempts), - "aws_conn_id": self.aws_conn_id, - }, + if poll_interval is not None or max_attempts is not None: + warnings.warn( + "please use waiter_delay instead of poll_interval " + "and waiter_max_attempts instead of max_attempts", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_delay = poll_interval or waiter_delay + waiter_max_attempts = max_attempts or waiter_max_attempts + super().__init__( + serialized_fields={"job_flow_id": job_flow_id}, + waiter_name="job_flow_waiting", + waiter_args={"ClusterId": job_flow_id}, + failure_message="JobFlow creation failed", + status_message="JobFlow creation in progress", + status_queries=[ + "Cluster.Status.State", + "Cluster.Status.StateChangeReason", + "Cluster.Status.ErrorDetails", + ], + return_key="job_flow_id", + return_value=job_flow_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, ) - async def run(self): - self.hook = EmrHook(aws_conn_id=self.aws_conn_id) - async with self.hook.async_conn as client: - attempt = 0 - waiter = self.hook.get_waiter("job_flow_waiting", deferrable=True, client=client) - while attempt < int(self.max_attempts): - attempt = attempt + 1 - try: - await waiter.wait( - ClusterId=self.job_flow_id, - WaiterConfig=prune_dict( - { - "Delay": self.poll_interval, - "MaxAttempts": 1, - } - ), - ) - break - except WaiterError as error: - if "terminal failure" in str(error): - raise AirflowException(f"JobFlow creation failed: {error}") - self.log.info( - "Status of jobflow is %s - %s", - error.last_response["Cluster"]["Status"]["State"], - error.last_response["Cluster"]["Status"]["StateChangeReason"], - ) - await asyncio.sleep(int(self.poll_interval)) - if attempt >= int(self.max_attempts): - raise AirflowException(f"JobFlow creation failed - max attempts reached: {self.max_attempts}") - 
else: - yield TriggerEvent( - { - "status": "success", - "message": "JobFlow completed successfully", - "job_flow_id": self.job_flow_id, - } - ) + def hook(self) -> AwsGenericHook: + return EmrHook(aws_conn_id=self.aws_conn_id) -class EmrTerminateJobFlowTrigger(BaseTrigger): +class EmrTerminateJobFlowTrigger(AwsBaseWaiterTrigger): """ Asynchronously poll the boto3 API and wait for the JobFlow to finish terminating. :param job_flow_id: ID of the EMR Job Flow to terminate - :param poll_interval: The amount of time in seconds to wait between attempts. - :param max_attempts: The maximum number of attempts to be made. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. """ def __init__( self, job_flow_id: str, - poll_interval: int, - max_attempts: int, - aws_conn_id: str, + poll_interval: int | None = None, # deprecated + max_attempts: int | None = None, # deprecated + aws_conn_id: str | None = None, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, ): - self.job_flow_id = job_flow_id - self.poll_interval = poll_interval - self.max_attempts = max_attempts - self.aws_conn_id = aws_conn_id - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - self.__class__.__module__ + "." 
+ self.__class__.__qualname__, - { - "job_flow_id": self.job_flow_id, - "poll_interval": str(self.poll_interval), - "max_attempts": str(self.max_attempts), - "aws_conn_id": self.aws_conn_id, - }, + if poll_interval is not None or max_attempts is not None: + warnings.warn( + "please use waiter_delay instead of poll_interval " + "and waiter_max_attempts instead of max_attempts", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_delay = poll_interval or waiter_delay + waiter_max_attempts = max_attempts or waiter_max_attempts + super().__init__( + serialized_fields={"job_flow_id": job_flow_id}, + waiter_name="job_flow_terminated", + waiter_args={"ClusterId": job_flow_id}, + failure_message="JobFlow termination failed", + status_message="JobFlow termination in progress", + status_queries=[ + "Cluster.Status.State", + "Cluster.Status.StateChangeReason", + "Cluster.Status.ErrorDetails", + ], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, ) - async def run(self): - self.hook = EmrHook(aws_conn_id=self.aws_conn_id) - async with self.hook.async_conn as client: - attempt = 0 - waiter = self.hook.get_waiter("job_flow_terminated", deferrable=True, client=client) - while attempt < int(self.max_attempts): - attempt = attempt + 1 - try: - await waiter.wait( - ClusterId=self.job_flow_id, - WaiterConfig=prune_dict( - { - "Delay": self.poll_interval, - "MaxAttempts": 1, - } - ), - ) - break - except WaiterError as error: - if "terminal failure" in str(error): - raise AirflowException(f"JobFlow termination failed: {error}") - self.log.info( - "Status of jobflow is %s - %s", - error.last_response["Cluster"]["Status"]["State"], - error.last_response["Cluster"]["Status"]["StateChangeReason"], - ) - await asyncio.sleep(int(self.poll_interval)) - if attempt >= int(self.max_attempts): - raise AirflowException(f"JobFlow termination failed - max attempts reached: {self.max_attempts}") - else: - yield 
TriggerEvent( - { - "status": "success", - "message": "JobFlow terminated successfully", - } - ) + def hook(self) -> AwsGenericHook: + return EmrHook(aws_conn_id=self.aws_conn_id) -class EmrContainerTrigger(BaseTrigger): +class EmrContainerTrigger(AwsBaseWaiterTrigger): """ Poll for the status of EMR container until reaches terminal state. :param virtual_cluster_id: Reference Emr cluster id :param job_id: job_id to check the state :param aws_conn_id: Reference to AWS connection id - :param poll_interval: polling period in seconds to check for the status + :param waiter_delay: polling period in seconds to check for the status """ def __init__( @@ -260,116 +216,70 @@ def __init__( virtual_cluster_id: str, job_id: str, aws_conn_id: str = "aws_default", - poll_interval: int = 30, - **kwargs: Any, + poll_interval: int | None = None, # deprecated + waiter_delay: int = 30, + waiter_max_attempts: int = 600, ): - self.virtual_cluster_id = virtual_cluster_id - self.job_id = job_id - self.aws_conn_id = aws_conn_id - self.poll_interval = poll_interval - super().__init__(**kwargs) - - @cached_property - def hook(self) -> EmrContainerHook: - return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id) - - def serialize(self) -> tuple[str, dict[str, Any]]: - """Serializes EmrContainerTrigger arguments and classpath.""" - return ( - "airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger", - { - "virtual_cluster_id": self.virtual_cluster_id, - "job_id": self.job_id, - "aws_conn_id": self.aws_conn_id, - "poll_interval": self.poll_interval, - }, + if poll_interval is not None: + warnings.warn( + "please use waiter_delay instead of poll_interval.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_delay = poll_interval or waiter_delay + super().__init__( + serialized_fields={"virtual_cluster_id": virtual_cluster_id, "job_id": job_id}, + waiter_name="container_job_complete", + waiter_args={"id": job_id, "virtualClusterId": 
virtual_cluster_id}, + failure_message="Job failed", + status_message="Job in progress", + status_queries=["jobRun.state", "jobRun.failureReason"], + return_key="job_id", + return_value=job_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, ) - async def run(self) -> AsyncIterator[TriggerEvent]: - async with self.hook.async_conn as client: - waiter = self.hook.get_waiter("container_job_complete", deferrable=True, client=client) - attempt = 0 - while True: - attempt = attempt + 1 - try: - await waiter.wait( - id=self.job_id, - virtualClusterId=self.virtual_cluster_id, - WaiterConfig={ - "Delay": self.poll_interval, - "MaxAttempts": 1, - }, - ) - break - except WaiterError as error: - if "terminal failure" in str(error): - yield TriggerEvent({"status": "failure", "message": f"Job Failed: {error}"}) - break - self.log.info( - "Job status is %s. Retrying attempt %s", - error.last_response["jobRun"]["state"], - attempt, - ) - await asyncio.sleep(int(self.poll_interval)) - - yield TriggerEvent({"status": "success", "job_id": self.job_id}) + def hook(self) -> AwsGenericHook: + return EmrContainerHook(self.aws_conn_id) -class EmrStepSensorTrigger(BaseTrigger): +class EmrStepSensorTrigger(AwsBaseWaiterTrigger): """ Poll for the status of EMR container until reaches terminal state. 
:param job_flow_id: job_flow_id which contains the step check the state of :param step_id: step to check the state of + :param waiter_delay: polling period in seconds to check for the status + :param waiter_max_attempts: The maximum number of attempts to be made :param aws_conn_id: Reference to AWS connection id - :param max_attempts: The maximum number of attempts to be made - :param poke_interval: polling period in seconds to check for the status """ def __init__( self, job_flow_id: str, step_id: str, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, aws_conn_id: str = "aws_default", - max_attempts: int = 60, - poke_interval: int = 30, - **kwargs: Any, ): - self.job_flow_id = job_flow_id - self.step_id = step_id - self.aws_conn_id = aws_conn_id - self.max_attempts = max_attempts - self.poke_interval = poke_interval - super().__init__(**kwargs) - - @cached_property - def hook(self) -> EmrHook: - return EmrHook(self.aws_conn_id) - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - "airflow.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger", - { - "job_flow_id": self.job_flow_id, - "step_id": self.step_id, - "aws_conn_id": self.aws_conn_id, - "max_attempts": self.max_attempts, - "poke_interval": self.poke_interval, - }, + super().__init__( + serialized_fields={"job_flow_id": job_flow_id, "step_id": step_id}, + waiter_name="step_wait_for_terminal", + waiter_args={"ClusterId": job_flow_id, "StepId": step_id}, + failure_message=f"Error while waiting for step {step_id} to complete", + status_message=f"Step id: {step_id}, Step is still in non-terminal state", + status_queries=[ + "Step.Status.State", + "Step.Status.FailureDetails", + "Step.Status.StateChangeReason", + ], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, ) - async def run(self) -> AsyncIterator[TriggerEvent]: - - async with self.hook.async_conn as client: - waiter = 
client.get_waiter("step_wait_for_terminal", deferrable=True, client=client) - await async_wait( - waiter=waiter, - waiter_delay=self.poke_interval, - waiter_max_attempts=self.max_attempts, - args={"ClusterId": self.job_flow_id, "StepId": self.step_id}, - failure_message=f"Error while waiting for step {self.step_id} to complete", - status_message=f"Step id: {self.step_id}, Step is still in non-terminal state", - status_args=["Step.Status.State"], - ) - - yield TriggerEvent({"status": "success"}) + def hook(self) -> AwsGenericHook: + return EmrHook(self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/triggers/glue_crawler.py b/airflow/providers/amazon/aws/triggers/glue_crawler.py index e891e87627f78..092b05242983a 100644 --- a/airflow/providers/amazon/aws/triggers/glue_crawler.py +++ b/airflow/providers/amazon/aws/triggers/glue_crawler.py @@ -16,17 +16,15 @@ # under the License. from __future__ import annotations -import asyncio -from functools import cached_property -from typing import AsyncIterator - -from botocore.exceptions import WaiterError +import warnings +from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook -from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger -class GlueCrawlerCompleteTrigger(BaseTrigger): +class GlueCrawlerCompleteTrigger(AwsBaseWaiterTrigger): """ Watches for a glue crawl, triggers when it finishes. @@ -35,41 +33,33 @@ class GlueCrawlerCompleteTrigger(BaseTrigger): :param aws_conn_id: The Airflow connection used for AWS credentials. 
""" - def __init__(self, crawler_name: str, poll_interval: int, aws_conn_id: str): - super().__init__() - self.crawler_name = crawler_name - self.poll_interval = poll_interval - self.aws_conn_id = aws_conn_id - - def serialize(self) -> tuple[str, dict]: - return ( - # dynamically generate the fully qualified name of the class - self.__class__.__module__ + "." + self.__class__.__qualname__, - { - "crawler_name": self.crawler_name, - "poll_interval": self.poll_interval, - "aws_conn_id": self.aws_conn_id, - }, + def __init__( + self, + crawler_name: str, + poll_interval: int | None = None, + aws_conn_id: str = "aws_default", + waiter_delay: int = 5, + waiter_max_attempts: int = 1500, + ): + if poll_interval is not None: + warnings.warn( + "please use waiter_delay instead of poll_interval.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_delay = poll_interval or waiter_delay + super().__init__( + serialized_fields={"crawler_name": crawler_name}, + waiter_name="crawler_ready", + waiter_args={"Name": crawler_name}, + failure_message="Error while waiting for glue crawl to complete", + status_message="Status of glue crawl is", + status_queries=["Crawler.State", "Crawler.LastCrawl"], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, ) - @cached_property - def hook(self) -> GlueCrawlerHook: + def hook(self) -> AwsGenericHook: return GlueCrawlerHook(aws_conn_id=self.aws_conn_id) - - async def run(self) -> AsyncIterator[TriggerEvent]: - async with self.hook.async_conn as client: - waiter = self.hook.get_waiter("crawler_ready", deferrable=True, client=client) - while True: - try: - await waiter.wait( - Name=self.crawler_name, - WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": 1}, - ) - break # we reach this point only if the waiter met a success criteria - except WaiterError as error: - if "terminal failure" in str(error): - raise - self.log.info("Status of glue crawl is %s", 
error.last_response["Crawler"]["State"]) - await asyncio.sleep(int(self.poll_interval)) - - yield TriggerEvent({"status": "success", "message": "Crawl Complete"}) diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py index b3770bd395c6c..678f435364c08 100644 --- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py +++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py @@ -16,17 +16,15 @@ # under the License. from __future__ import annotations -import asyncio -from functools import cached_property -from typing import Any, AsyncIterator - -from botocore.exceptions import WaiterError +import warnings +from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook -from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger -class RedshiftCreateClusterTrigger(BaseTrigger): +class RedshiftCreateClusterTrigger(AwsBaseWaiterTrigger): """ Trigger for RedshiftCreateClusterOperator. @@ -34,51 +32,47 @@ class RedshiftCreateClusterTrigger(BaseTrigger): Redshift cluster to be in the `available` state. :param cluster_identifier: A unique identifier for the cluster. - :param poll_interval: The amount of time in seconds to wait between attempts. - :param max_attempt: The maximum number of attempts to be made. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. 
""" def __init__( self, cluster_identifier: str, - poll_interval: int, - max_attempt: int, - aws_conn_id: str, + poll_interval: int | None = None, + max_attempt: int | None = None, + aws_conn_id: str = "aws_default", + waiter_delay: int = 15, + waiter_max_attempts: int = 999999, ): - self.cluster_identifier = cluster_identifier - self.poll_interval = poll_interval - self.max_attempt = max_attempt - self.aws_conn_id = aws_conn_id - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftCreateClusterTrigger", - { - "cluster_identifier": str(self.cluster_identifier), - "poll_interval": str(self.poll_interval), - "max_attempt": str(self.max_attempt), - "aws_conn_id": str(self.aws_conn_id), - }, + if poll_interval is not None or max_attempt is not None: + warnings.warn( + "please use waiter_delay instead of poll_interval " + "and waiter_max_attempts instead of max_attempt.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_delay = poll_interval or waiter_delay + waiter_max_attempts = max_attempt or waiter_max_attempts + super().__init__( + serialized_fields={"cluster_identifier": cluster_identifier}, + waiter_name="cluster_available", + waiter_args={"ClusterIdentifier": cluster_identifier}, + failure_message="Error while creating the redshift cluster", + status_message="Redshift cluster creation in progress", + status_queries=["Clusters[].ClusterStatus"], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, ) - @cached_property - def hook(self) -> RedshiftHook: + def hook(self) -> AwsGenericHook: return RedshiftHook(aws_conn_id=self.aws_conn_id) - async def run(self): - async with self.hook.async_conn as client: - await client.get_waiter("cluster_available").wait( - ClusterIdentifier=self.cluster_identifier, - WaiterConfig={ - "Delay": int(self.poll_interval), - "MaxAttempts": int(self.max_attempt), - }, - ) - 
yield TriggerEvent({"status": "success", "message": "Cluster Created"}) - -class RedshiftPauseClusterTrigger(BaseTrigger): +class RedshiftPauseClusterTrigger(AwsBaseWaiterTrigger): """ Trigger for RedshiftPauseClusterOperator. @@ -86,70 +80,47 @@ class RedshiftPauseClusterTrigger(BaseTrigger): Redshift cluster to be in the `paused` state. :param cluster_identifier: A unique identifier for the cluster. - :param poll_interval: The amount of time in seconds to wait between attempts. - :param max_attempts: The maximum number of attempts to be made. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. """ def __init__( self, cluster_identifier: str, - poll_interval: int, - max_attempts: int, - aws_conn_id: str, + poll_interval: int | None = None, + max_attempts: int | None = None, + aws_conn_id: str = "aws_default", + waiter_delay: int = 15, + waiter_max_attempts: int = 999999, ): - self.cluster_identifier = cluster_identifier - self.poll_interval = poll_interval - self.max_attempts = max_attempts - self.aws_conn_id = aws_conn_id - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger", - { - "cluster_identifier": self.cluster_identifier, - "poll_interval": str(self.poll_interval), - "max_attempts": str(self.max_attempts), - "aws_conn_id": self.aws_conn_id, - }, + if poll_interval is not None or max_attempts is not None: + warnings.warn( + "please use waiter_delay instead of poll_interval " + "and waiter_max_attempts instead of max_attempt.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_delay = poll_interval or waiter_delay + waiter_max_attempts = max_attempts or waiter_max_attempts + super().__init__( + serialized_fields={"cluster_identifier": cluster_identifier}, + 
waiter_name="cluster_paused", + waiter_args={"ClusterIdentifier": cluster_identifier}, + failure_message="Error while pausing the redshift cluster", + status_message="Redshift cluster pausing in progress", + status_queries=["Clusters[].ClusterStatus"], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, ) - @cached_property - def hook(self) -> RedshiftHook: + def hook(self) -> AwsGenericHook: return RedshiftHook(aws_conn_id=self.aws_conn_id) - async def run(self): - async with self.hook.async_conn as client: - attempt = 0 - waiter = self.hook.get_waiter("cluster_paused", deferrable=True, client=client) - while attempt < int(self.max_attempts): - attempt = attempt + 1 - try: - await waiter.wait( - ClusterIdentifier=self.cluster_identifier, - WaiterConfig={ - "Delay": int(self.poll_interval), - "MaxAttempts": 1, - }, - ) - break - except WaiterError as error: - if "terminal failure" in str(error): - yield TriggerEvent({"status": "failure", "message": f"Pause Cluster Failed: {error}"}) - break - self.log.info( - "Status of cluster is %s", error.last_response["Clusters"][0]["ClusterStatus"] - ) - await asyncio.sleep(int(self.poll_interval)) - if attempt >= int(self.max_attempts): - yield TriggerEvent( - {"status": "failure", "message": "Pause Cluster Failed - max attempts reached."} - ) - else: - yield TriggerEvent({"status": "success", "message": "Cluster paused"}) - -class RedshiftCreateClusterSnapshotTrigger(BaseTrigger): +class RedshiftCreateClusterSnapshotTrigger(AwsBaseWaiterTrigger): """ Trigger for RedshiftCreateClusterSnapshotOperator. @@ -157,75 +128,47 @@ class RedshiftCreateClusterSnapshotTrigger(BaseTrigger): Redshift cluster snapshot to be in the `available` state. :param cluster_identifier: A unique identifier for the cluster. - :param poll_interval: The amount of time in seconds to wait between attempts. - :param max_attempts: The maximum number of attempts to be made. 
+ :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. """ def __init__( self, cluster_identifier: str, - poll_interval: int, - max_attempts: int, - aws_conn_id: str, + poll_interval: int | None = None, + max_attempts: int | None = None, + aws_conn_id: str = "aws_default", + waiter_delay: int = 15, + waiter_max_attempts: int = 999999, ): - self.cluster_identifier = cluster_identifier - self.poll_interval = poll_interval - self.max_attempts = max_attempts - self.aws_conn_id = aws_conn_id - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftCreateClusterSnapshotTrigger", - { - "cluster_identifier": self.cluster_identifier, - "poll_interval": str(self.poll_interval), - "max_attempts": str(self.max_attempts), - "aws_conn_id": self.aws_conn_id, - }, + if poll_interval is not None or max_attempts is not None: + warnings.warn( + "please use waiter_delay instead of poll_interval " + "and waiter_max_attempts instead of max_attempt.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_delay = poll_interval or waiter_delay + waiter_max_attempts = max_attempts or waiter_max_attempts + super().__init__( + serialized_fields={"cluster_identifier": cluster_identifier}, + waiter_name="snapshot_available", + waiter_args={"ClusterIdentifier": cluster_identifier}, + failure_message="Create Cluster Snapshot Failed", + status_message="Redshift Cluster Snapshot in progress", + status_queries=["Clusters[].ClusterStatus"], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, ) - @cached_property - def hook(self) -> RedshiftHook: + def hook(self) -> AwsGenericHook: return RedshiftHook(aws_conn_id=self.aws_conn_id) - async def run(self): - async with self.hook.async_conn 
as client: - attempt = 0 - waiter = client.get_waiter("snapshot_available") - while attempt < int(self.max_attempts): - attempt = attempt + 1 - try: - await waiter.wait( - ClusterIdentifier=self.cluster_identifier, - WaiterConfig={ - "Delay": int(self.poll_interval), - "MaxAttempts": 1, - }, - ) - break - except WaiterError as error: - if "terminal failure" in str(error): - yield TriggerEvent( - {"status": "failure", "message": f"Create Cluster Snapshot Failed: {error}"} - ) - break - self.log.info( - "Status of cluster snapshot is %s", error.last_response["Snapshots"][0]["Status"] - ) - await asyncio.sleep(int(self.poll_interval)) - if attempt >= int(self.max_attempts): - yield TriggerEvent( - { - "status": "failure", - "message": "Create Cluster Snapshot Cluster Failed - max attempts reached.", - } - ) - else: - yield TriggerEvent({"status": "success", "message": "Cluster Snapshot Created"}) - -class RedshiftResumeClusterTrigger(BaseTrigger): +class RedshiftResumeClusterTrigger(AwsBaseWaiterTrigger): """ Trigger for RedshiftResumeClusterOperator. @@ -233,141 +176,86 @@ class RedshiftResumeClusterTrigger(BaseTrigger): Redshift cluster to be in the `available` state. :param cluster_identifier: A unique identifier for the cluster. - :param poll_interval: The amount of time in seconds to wait between attempts. - :param max_attempts: The maximum number of attempts to be made. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. 
""" def __init__( self, cluster_identifier: str, - poll_interval: int, - max_attempts: int, - aws_conn_id: str, + poll_interval: int | None = None, + max_attempts: int | None = None, + aws_conn_id: str = "aws_default", + waiter_delay: int = 15, + waiter_max_attempts: int = 999999, ): - self.cluster_identifier = cluster_identifier - self.poll_interval = poll_interval - self.max_attempts = max_attempts - self.aws_conn_id = aws_conn_id - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftResumeClusterTrigger", - { - "cluster_identifier": self.cluster_identifier, - "poll_interval": str(self.poll_interval), - "max_attempts": str(self.max_attempts), - "aws_conn_id": self.aws_conn_id, - }, + if poll_interval is not None or max_attempts is not None: + warnings.warn( + "please use waiter_delay instead of poll_interval " + "and waiter_max_attempts instead of max_attempt.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_delay = poll_interval or waiter_delay + waiter_max_attempts = max_attempts or waiter_max_attempts + super().__init__( + serialized_fields={"cluster_identifier": cluster_identifier}, + waiter_name="cluster_resumed", + waiter_args={"ClusterIdentifier": cluster_identifier}, + failure_message="Resume Cluster Snapshot Failed", + status_message="Redshift Cluster resuming in progress", + status_queries=["Clusters[].ClusterStatus"], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, ) - @cached_property - def hook(self) -> RedshiftHook: + def hook(self) -> AwsGenericHook: return RedshiftHook(aws_conn_id=self.aws_conn_id) - async def run(self): - async with self.hook.async_conn as client: - attempt = 0 - waiter = self.hook.get_waiter("cluster_resumed", deferrable=True, client=client) - while attempt < int(self.max_attempts): - attempt = attempt + 1 - try: - await waiter.wait( - 
ClusterIdentifier=self.cluster_identifier, - WaiterConfig={ - "Delay": int(self.poll_interval), - "MaxAttempts": 1, - }, - ) - break - except WaiterError as error: - if "terminal failure" in str(error): - yield TriggerEvent( - {"status": "failure", "message": f"Resume Cluster Failed: {error}"} - ) - break - self.log.info( - "Status of cluster is %s", error.last_response["Clusters"][0]["ClusterStatus"] - ) - await asyncio.sleep(int(self.poll_interval)) - if attempt >= int(self.max_attempts): - yield TriggerEvent( - {"status": "failure", "message": "Resume Cluster Failed - max attempts reached."} - ) - else: - yield TriggerEvent({"status": "success", "message": "Cluster resumed"}) - -class RedshiftDeleteClusterTrigger(BaseTrigger): +class RedshiftDeleteClusterTrigger(AwsBaseWaiterTrigger): """ Trigger for RedshiftDeleteClusterOperator. :param cluster_identifier: A unique identifier for the cluster. - :param max_attempts: The maximum number of attempts to be made. + :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. - :param poll_interval: The amount of time in seconds to wait between attempts. + :param waiter_delay: The amount of time in seconds to wait between attempts. 
""" def __init__( self, cluster_identifier: str, - max_attempts: int = 30, + poll_interval: int | None = None, + max_attempts: int | None = None, aws_conn_id: str = "aws_default", - poll_interval: int = 30, + waiter_delay: int = 30, + waiter_max_attempts: int = 30, ): - super().__init__() - self.cluster_identifier = cluster_identifier - self.max_attempts = max_attempts - self.aws_conn_id = aws_conn_id - self.poll_interval = poll_interval - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftDeleteClusterTrigger", - { - "cluster_identifier": self.cluster_identifier, - "max_attempts": self.max_attempts, - "aws_conn_id": self.aws_conn_id, - "poll_interval": self.poll_interval, - }, + if poll_interval is not None or max_attempts is not None: + warnings.warn( + "please use waiter_delay instead of poll_interval " + "and waiter_max_attempts instead of max_attempt.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_delay = poll_interval or waiter_delay + waiter_max_attempts = max_attempts or waiter_max_attempts + super().__init__( + serialized_fields={"cluster_identifier": cluster_identifier}, + waiter_name="cluster_deleted", + waiter_args={"ClusterIdentifier": cluster_identifier}, + failure_message="Delete Cluster Failed", + status_message="Redshift Cluster deletion in progress", + status_queries=["Clusters[].ClusterStatus"], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, ) - @cached_property - def hook(self): + def hook(self) -> AwsGenericHook: return RedshiftHook(aws_conn_id=self.aws_conn_id) - - async def run(self) -> AsyncIterator[TriggerEvent]: - async with self.hook.async_conn as client: - attempt = 0 - waiter = client.get_waiter("cluster_deleted") - while attempt < self.max_attempts: - attempt = attempt + 1 - try: - await waiter.wait( - ClusterIdentifier=self.cluster_identifier, - WaiterConfig={ - 
"Delay": self.poll_interval, - "MaxAttempts": 1, - }, - ) - break - except WaiterError as error: - if "terminal failure" in str(error): - yield TriggerEvent( - {"status": "failure", "message": f"Delete Cluster Failed: {error}"} - ) - break - self.log.info( - "Cluster status is %s. Retrying attempt %s/%s", - error.last_response["Clusters"][0]["ClusterStatus"], - attempt, - self.max_attempts, - ) - await asyncio.sleep(int(self.poll_interval)) - - if attempt >= self.max_attempts: - yield TriggerEvent( - {"status": "failure", "message": "Delete Cluster Failed - max attempts reached."} - ) - else: - yield TriggerEvent({"status": "success", "message": "Cluster deleted."}) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 390a5866368ac..3622896504da4 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -516,6 +516,9 @@ hooks: - airflow.providers.amazon.aws.hooks.appflow triggers: + - integration-name: Amazon Web Services + python-modules: + - airflow.providers.amazon.aws.triggers.base - integration-name: Amazon Athena python-modules: - airflow.providers.amazon.aws.triggers.athena diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index 9e528525204c6..3eb6349850799 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -169,4 +169,3 @@ def test_is_deferred(self, mock_run_query): self.athena.execute(None) assert isinstance(deferred.value.trigger, AthenaTrigger) - assert deferred.value.trigger.query_execution_id == ATHENA_QUERY_ID diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 3aace0bb3e5bb..7255035b9d55b 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -29,7 +29,7 @@ # Use dummy AWS credentials 
from airflow.providers.amazon.aws.triggers.batch import ( BatchCreateComputeEnvironmentTrigger, - BatchOperatorTrigger, + BatchJobTrigger, ) AWS_REGION = "eu-west-1" @@ -276,7 +276,7 @@ def test_defer_if_deferrable_param_set(self, mock_client): with pytest.raises(TaskDeferred) as exc: batch.execute(context=None) - assert isinstance(exc.value.trigger, BatchOperatorTrigger), "Trigger is not a BatchOperatorTrigger" + assert isinstance(exc.value.trigger, BatchJobTrigger) @mock.patch.object(BatchClientHook, "get_job_description") @mock.patch.object(BatchClientHook, "wait_for_job") @@ -349,6 +349,5 @@ def test_defer(self, client_mock): operator.execute(None) assert isinstance(deferred.value.trigger, BatchCreateComputeEnvironmentTrigger) - assert deferred.value.trigger.compute_env_arn == "my_arn" - assert deferred.value.trigger.poll_interval == 456789 - assert deferred.value.trigger.max_retries == 123456 + assert deferred.value.trigger.waiter_delay == 456789 + assert deferred.value.trigger.attempts == 123456 diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index b89ea59c561be..f324918171927 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -724,7 +724,6 @@ def test_execute_deferrable(self, mock_client: MagicMock): with pytest.raises(TaskDeferred) as defer: op.execute(None) - assert defer.value.trigger.cluster_arn == "my arn" assert defer.value.trigger.waiter_delay == 12 assert defer.value.trigger.attempts == 34 @@ -789,7 +788,6 @@ def test_execute_deferrable(self, mock_client: MagicMock): with pytest.raises(TaskDeferred) as defer: op.execute(None) - assert defer.value.trigger.cluster_arn == "my arn" assert defer.value.trigger.waiter_delay == 12 assert defer.value.trigger.attempts == 34 diff --git a/tests/providers/amazon/aws/operators/test_eks.py b/tests/providers/amazon/aws/operators/test_eks.py index 9ea8ab72d72fa..8381a26ef151a 100644 
--- a/tests/providers/amazon/aws/operators/test_eks.py +++ b/tests/providers/amazon/aws/operators/test_eks.py @@ -36,8 +36,8 @@ ) from airflow.providers.amazon.aws.triggers.eks import ( EksCreateFargateProfileTrigger, + EksCreateNodegroupTrigger, EksDeleteFargateProfileTrigger, - EksNodegroupTrigger, ) from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction from airflow.typing_compat import TypedDict @@ -489,7 +489,7 @@ def test_create_nodegroup_deferrable(self, mock_create_nodegroup): ) with pytest.raises(TaskDeferred) as exc: operator.execute({}) - assert isinstance(exc.value.trigger, EksNodegroupTrigger), "Trigger is not a EksNodegroupTrigger" + assert isinstance(exc.value.trigger, EksCreateNodegroupTrigger) def test_create_nodegroup_deferrable_versus_wait_for_completion(self): op_kwargs = {**self.create_nodegroup_params} diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py index 42e9bffb5b688..74b348381e8ea 100644 --- a/tests/providers/amazon/aws/sensors/test_batch.py +++ b/tests/providers/amazon/aws/sensors/test_batch.py @@ -27,7 +27,7 @@ BatchJobQueueSensor, BatchSensor, ) -from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger +from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger TASK_ID = "batch_job_sensor" JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0" @@ -210,26 +210,10 @@ def test_batch_sensor_async(self): with pytest.raises(TaskDeferred) as exc: self.TASK.execute({}) - assert isinstance(exc.value.trigger, BatchSensorTrigger), "Trigger is not a BatchSensorTrigger" + assert isinstance(exc.value.trigger, BatchJobTrigger), "Trigger is not a BatchJobTrigger" def test_batch_sensor_async_execute_failure(self): """Tests that an AirflowException is raised in case of error event""" - with pytest.raises(AirflowException) as exc_info: - self.TASK.execute_complete( - context={}, event={"status": "failure", "message": "test failure message"} - ) 
- - assert str(exc_info.value) == "test failure message" - - @pytest.mark.parametrize( - "event", - [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}], - ) - def test_batch_sensor_async_execute_complete(self, caplog, event): - """Tests that execute_complete method returns None and that it prints expected log""" - - with mock.patch.object(self.TASK.log, "info") as mock_log_info: - assert self.TASK.execute_complete(context={}, event=event) is None - - mock_log_info.assert_called_with(event["message"]) + with pytest.raises(AirflowException): + self.TASK.execute_complete(context={}, event={"status": "failure"}) diff --git a/tests/providers/amazon/aws/triggers/test_athena.py b/tests/providers/amazon/aws/triggers/test_athena.py index 04e601f4392c4..d18bdc1553f78 100644 --- a/tests/providers/amazon/aws/triggers/test_athena.py +++ b/tests/providers/amazon/aws/triggers/test_athena.py @@ -16,38 +16,20 @@ # under the License. from __future__ import annotations -from unittest import mock -from unittest.mock import AsyncMock - -import pytest -from botocore.exceptions import WaiterError - -from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger class TestAthenaTrigger: - @pytest.mark.asyncio - @mock.patch.object(AthenaHook, "get_waiter") - @mock.patch.object(AthenaHook, "async_conn") # LatestBoto step of CI fails without this - async def test_run_with_error(self, conn_mock, waiter_mock): - waiter_mock.side_effect = WaiterError("name", "reason", {}) - - trigger = AthenaTrigger("query_id", 0, 5, None) + def test_serialize_recreate(self): + trigger = AthenaTrigger("query_id", 1, 5, "aws connection") - with pytest.raises(WaiterError): - generator = trigger.run() - await generator.asend(None) + class_path, args = trigger.serialize() - @pytest.mark.asyncio - @mock.patch.object(AthenaHook, "get_waiter") - @mock.patch.object(AthenaHook, "async_conn") # LatestBoto step of CI fails 
without this - async def test_run_success(self, conn_mock, waiter_mock): - waiter_mock().wait = AsyncMock() - trigger = AthenaTrigger("my_query_id", 0, 5, None) + class_name = class_path.split(".")[-1] + clazz = globals()[class_name] + instance = clazz(**args) - generator = trigger.run() - event = await generator.asend(None) + class_path2, args2 = instance.serialize() - assert event.payload["status"] == "success" - assert event.payload["value"] == "my_query_id" + assert class_path == class_path2 + assert args == args2 diff --git a/tests/providers/amazon/aws/triggers/test_base.py b/tests/providers/amazon/aws/triggers/test_base.py new file mode 100644 index 0000000000000..aad9ce45ec4a7 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_base.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock +from unittest.mock import MagicMock + +import pytest + +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger +from airflow.triggers.base import TriggerEvent + + +class TestImplem(AwsBaseWaiterTrigger): + """An empty implementation that allows instantiation for tests.""" + + __test__ = False + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def hook(self) -> AwsGenericHook: + return MagicMock() + + +class TestAwsBaseWaiterTrigger: + def setup_method(self): + self.trigger = TestImplem( + serialized_fields={}, + waiter_name="", + waiter_args={}, + failure_message="", + status_message="", + status_queries=[], + return_value=None, + waiter_delay=0, + waiter_max_attempts=0, + aws_conn_id="", + ) + + def test_region_serialized(self): + self.trigger.region_name = "my_region" + _, args = self.trigger.serialize() + + assert "region_name" in args + assert args["region_name"] == "my_region" + + def test_region_not_serialized_if_omitted(self): + _, args = self.trigger.serialize() + + assert "region_name" not in args + + def test_serialize_extra_fields(self): + self.trigger.serialized_fields = {"foo": "bar", "foz": "baz"} + + _, args = self.trigger.serialize() + + assert "foo" in args + assert args["foo"] == "bar" + assert "foz" in args + assert args["foz"] == "baz" + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.triggers.base.async_wait") + async def test_run(self, wait_mock: MagicMock): + self.trigger.return_key = "hello" + self.trigger.return_value = "world" + + generator = self.trigger.run() + res: TriggerEvent = await generator.asend(None) + + wait_mock.assert_called_once() + assert res.payload["status"] == "success" + assert res.payload["hello"] == "world" diff --git a/tests/providers/amazon/aws/triggers/test_batch.py b/tests/providers/amazon/aws/triggers/test_batch.py index 
e33736076237f..6ceee613321d3 100644 --- a/tests/providers/amazon/aws/triggers/test_batch.py +++ b/tests/providers/amazon/aws/triggers/test_batch.py @@ -16,20 +16,12 @@ # under the License. from __future__ import annotations -from unittest import mock -from unittest.mock import AsyncMock - import pytest -from botocore.exceptions import WaiterError -from airflow import AirflowException -from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook from airflow.providers.amazon.aws.triggers.batch import ( BatchCreateComputeEnvironmentTrigger, - BatchOperatorTrigger, - BatchSensorTrigger, + BatchJobTrigger, ) -from airflow.triggers.base import TriggerEvent BATCH_JOB_ID = "job_id" POLL_INTERVAL = 5 @@ -39,186 +31,34 @@ pytest.importorskip("aiobotocore") -class TestBatchOperatorTrigger: - def test_batch_operator_trigger_serialize(self): - batch_trigger = BatchOperatorTrigger( - job_id=BATCH_JOB_ID, - poll_interval=POLL_INTERVAL, - max_retries=MAX_ATTEMPT, - aws_conn_id=AWS_CONN_ID, - region_name=AWS_REGION, - ) - class_path, args = batch_trigger.serialize() - assert class_path == "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger" - assert args["job_id"] == BATCH_JOB_ID - assert args["poll_interval"] == POLL_INTERVAL - assert args["max_retries"] == MAX_ATTEMPT - assert args["aws_conn_id"] == AWS_CONN_ID - assert args["region_name"] == AWS_REGION - - @pytest.mark.asyncio - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn") - async def test_batch_job_trigger_run(self, mock_async_conn, mock_get_waiter): - the_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = the_mock - - mock_get_waiter().wait = AsyncMock() - - batch_trigger = BatchOperatorTrigger( - job_id=BATCH_JOB_ID, - poll_interval=POLL_INTERVAL, - max_retries=MAX_ATTEMPT, - aws_conn_id=AWS_CONN_ID, - region_name=AWS_REGION, - ) - - generator = 
batch_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent({"status": "success", "job_id": BATCH_JOB_ID}) - - -class TestBatchSensorTrigger: - TRIGGER = BatchSensorTrigger( - job_id=BATCH_JOB_ID, - region_name=AWS_REGION, - aws_conn_id=AWS_CONN_ID, - poke_interval=POLL_INTERVAL, +class TestBatchTrigger: + @pytest.mark.parametrize( + "trigger", + [ + BatchJobTrigger( + job_id=BATCH_JOB_ID, + waiter_delay=POLL_INTERVAL, + waiter_max_attempts=MAX_ATTEMPT, + aws_conn_id=AWS_CONN_ID, + region_name=AWS_REGION, + ), + BatchCreateComputeEnvironmentTrigger( + compute_env_arn="my_arn", + waiter_delay=POLL_INTERVAL, + waiter_max_attempts=MAX_ATTEMPT, + aws_conn_id=AWS_CONN_ID, + region_name=AWS_REGION, + ), + ], ) + def test_serialize_recreate(self, trigger): + class_path, args = trigger.serialize() - def test_batch_sensor_trigger_serialization(self): - """ - Asserts that the BatchSensorTrigger correctly serializes its arguments - and classpath. - """ - - classpath, kwargs = self.TRIGGER.serialize() - assert classpath == "airflow.providers.amazon.aws.triggers.batch.BatchSensorTrigger" - assert kwargs == { - "job_id": BATCH_JOB_ID, - "region_name": AWS_REGION, - "aws_conn_id": AWS_CONN_ID, - "poke_interval": POLL_INTERVAL, - } - - @pytest.mark.asyncio - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn") - async def test_batch_job_trigger_run(self, mock_async_conn, mock_get_waiter): - the_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = the_mock - - mock_get_waiter().wait = AsyncMock() - - batch_trigger = BatchOperatorTrigger( - job_id=BATCH_JOB_ID, - poll_interval=POLL_INTERVAL, - max_retries=MAX_ATTEMPT, - aws_conn_id=AWS_CONN_ID, - region_name=AWS_REGION, - ) - - generator = batch_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent({"status": 
"success", "job_id": BATCH_JOB_ID}) - - @pytest.mark.asyncio - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn") - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description") - async def test_batch_sensor_trigger_completed(self, mock_response, mock_async_conn, mock_get_waiter): - """Test if the success event is returned from trigger.""" - mock_response.return_value = {"status": "SUCCEEDED"} - - the_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = the_mock - - mock_get_waiter().wait = AsyncMock() - - trigger = BatchSensorTrigger( - job_id=BATCH_JOB_ID, - region_name=AWS_REGION, - aws_conn_id=AWS_CONN_ID, - ) - generator = trigger.run() - actual_response = await generator.asend(None) - assert ( - TriggerEvent( - {"status": "success", "job_id": BATCH_JOB_ID, "message": f"Job {BATCH_JOB_ID} Succeeded"} - ) - == actual_response - ) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description") - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn") - async def test_batch_sensor_trigger_failure( - self, mock_async_conn, mock_response, mock_get_waiter, mock_sleep - ): - """Test if the failure event is returned from trigger.""" - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - mock_response.return_value = {"status": "failed"} - - name = "batch_job_complete" - reason = ( - "An error occurred (UnrecognizedClientException): The security token included in the " - "request is invalid. 
" - ) - last_response = ({"Error": {"Message": "The security token included in the request is invalid."}},) - - error_failed = WaiterError( - name=name, - reason=reason, - last_response=last_response, - ) - - mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error_failed]) - mock_sleep.return_value = True - - trigger = BatchSensorTrigger(job_id=BATCH_JOB_ID, region_name=AWS_REGION, aws_conn_id=AWS_CONN_ID) - generator = trigger.run() - actual_response = await generator.asend(None) - assert actual_response == TriggerEvent( - {"status": "failure", "message": f"Job Failed: Waiter {name} failed: {reason}"} - ) - - -class TestBatchCreateComputeEnvironmentTrigger: - @pytest.mark.asyncio - @mock.patch.object(BatchClientHook, "async_conn") - @mock.patch.object(BatchClientHook, "get_waiter") - async def test_success(self, get_waiter_mock, conn_mock): - get_waiter_mock().wait = AsyncMock( - side_effect=[ - WaiterError( - "situation normal", "first try", {"computeEnvironments": [{"status": "my_status"}]} - ), - {}, - ] - ) - trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3) - - generator = trigger.run() - response: TriggerEvent = await generator.asend(None) - - assert response.payload["status"] == "success" - assert response.payload["value"] == "my_arn" + class_name = class_path.split(".")[-1] + clazz = globals()[class_name] + instance = clazz(**args) - @pytest.mark.asyncio - @mock.patch.object(BatchClientHook, "async_conn") - @mock.patch.object(BatchClientHook, "get_waiter") - async def test_failure(self, get_waiter_mock, conn_mock): - get_waiter_mock().wait = AsyncMock( - side_effect=[WaiterError("terminal failure", "terminal failure reason", {})] - ) - trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3) + class_path2, args2 = instance.serialize() - with pytest.raises(AirflowException): - generator = trigger.run() - await generator.asend(None) + assert class_path == class_path2 + assert 
args == args2 diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py b/tests/providers/amazon/aws/triggers/test_ecs.py index 09b5decbe631a..551ab39a44867 100644 --- a/tests/providers/amazon/aws/triggers/test_ecs.py +++ b/tests/providers/amazon/aws/triggers/test_ecs.py @@ -22,64 +22,16 @@ import pytest from botocore.exceptions import WaiterError -from airflow import AirflowException from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook -from airflow.providers.amazon.aws.triggers.ecs import ClusterWaiterTrigger, TaskDoneTrigger +from airflow.providers.amazon.aws.triggers.ecs import ( + ClusterActiveTrigger, + ClusterInactiveTrigger, + TaskDoneTrigger, +) from airflow.triggers.base import TriggerEvent -class TestClusterWaiterTrigger: - @pytest.mark.asyncio - @mock.patch.object(EcsHook, "async_conn") - async def test_run_max_attempts(self, client_mock): - a_mock = mock.MagicMock() - client_mock.__aenter__.return_value = a_mock - wait_mock = AsyncMock() - wait_mock.side_effect = WaiterError("name", "reason", {"clusters": [{"status": "my_status"}]}) - a_mock.get_waiter().wait = wait_mock - - max_attempts = 5 - trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, max_attempts, None, None) - - with pytest.raises(AirflowException): - generator = trigger.run() - await generator.asend(None) - - assert wait_mock.call_count == max_attempts - - @pytest.mark.asyncio - @mock.patch.object(EcsHook, "async_conn") - async def test_run_success(self, client_mock): - a_mock = mock.MagicMock() - client_mock.__aenter__.return_value = a_mock - wait_mock = AsyncMock() - a_mock.get_waiter().wait = wait_mock - - trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, 5, None, None) - - generator = trigger.run() - response: TriggerEvent = await generator.asend(None) - - assert response.payload["status"] == "success" - assert response.payload["arn"] == "cluster_arn" - - @pytest.mark.asyncio - 
@mock.patch.object(EcsHook, "async_conn") - async def test_run_error(self, client_mock): - a_mock = mock.MagicMock() - client_mock.__aenter__.return_value = a_mock - wait_mock = AsyncMock() - wait_mock.side_effect = WaiterError("terminal failure", "reason", {}) - a_mock.get_waiter().wait = wait_mock - - trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, 5, None, None) - - with pytest.raises(AirflowException): - generator = trigger.run() - await generator.asend(None) - - class TestTaskDoneTrigger: @pytest.mark.asyncio @mock.patch.object(EcsHook, "async_conn") @@ -121,3 +73,36 @@ async def test_run_success(self, _, client_mock): assert response.payload["status"] == "success" assert response.payload["task_arn"] == "my_task_arn" + + +class TestClusterTriggers: + @pytest.mark.parametrize( + "trigger", + [ + ClusterActiveTrigger( + cluster_arn="my_arn", + aws_conn_id="my_conn", + waiter_delay=1, + waiter_max_attempts=2, + region_name="my_region", + ), + ClusterInactiveTrigger( + cluster_arn="my_arn", + aws_conn_id="my_conn", + waiter_delay=1, + waiter_max_attempts=2, + region_name="my_region", + ), + ], + ) + def test_serialize_recreate(self, trigger): + class_path, args = trigger.serialize() + + class_name = class_path.split(".")[-1] + clazz = globals()[class_name] + instance = clazz(**args) + + class_path2, args2 = instance.serialize() + + assert class_path == class_path2 + assert args == args2 diff --git a/tests/providers/amazon/aws/triggers/test_eks.py b/tests/providers/amazon/aws/triggers/test_eks.py index 0b94e957db6cb..045519aea57e8 100644 --- a/tests/providers/amazon/aws/triggers/test_eks.py +++ b/tests/providers/amazon/aws/triggers/test_eks.py @@ -16,20 +16,14 @@ # under the License. 
from __future__ import annotations -from unittest import mock -from unittest.mock import AsyncMock - import pytest -from botocore.exceptions import WaiterError -from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.hooks.eks import EksHook from airflow.providers.amazon.aws.triggers.eks import ( EksCreateFargateProfileTrigger, + EksCreateNodegroupTrigger, EksDeleteFargateProfileTrigger, - EksNodegroupTrigger, + EksDeleteNodegroupTrigger, ) -from airflow.triggers.base import TriggerEvent TEST_CLUSTER_IDENTIFIER = "test-cluster" TEST_FARGATE_PROFILE_NAME = "test-fargate-profile" @@ -40,416 +34,50 @@ TEST_REGION = "test-region" -class TestEksCreateFargateProfileTrigger: - def test_eks_create_fargate_profile_serialize(self): - eks_create_fargate_profile_trigger = EksCreateFargateProfileTrigger( - cluster_name=TEST_CLUSTER_IDENTIFIER, - fargate_profile_name=TEST_FARGATE_PROFILE_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, - ) - - class_path, args = eks_create_fargate_profile_trigger.serialize() - assert class_path == "airflow.providers.amazon.aws.triggers.eks.EksCreateFargateProfileTrigger" - assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER - assert args["fargate_profile_name"] == TEST_FARGATE_PROFILE_NAME - assert args["aws_conn_id"] == TEST_AWS_CONN_ID - assert args["waiter_delay"] == str(TEST_WAITER_DELAY) - assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS) - - @pytest.mark.asyncio - @mock.patch.object(EksHook, "async_conn") - async def test_eks_create_fargate_profile_trigger_run(self, mock_async_conn): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - a_mock.get_waiter().wait = AsyncMock() - - eks_create_fargate_profile_trigger = EksCreateFargateProfileTrigger( - cluster_name=TEST_CLUSTER_IDENTIFIER, - fargate_profile_name=TEST_FARGATE_PROFILE_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - 
waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, - ) - - generator = eks_create_fargate_profile_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent({"status": "success", "message": "Fargate Profile Created"}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EksHook, "async_conn") - async def test_eks_create_fargate_profile_trigger_run_multiple_attempts( - self, mock_async_conn, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"fargateProfile": {"status": "CREATING"}}, - ) - a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - eks_create_fargate_profile_trigger = EksCreateFargateProfileTrigger( - cluster_name=TEST_CLUSTER_IDENTIFIER, - fargate_profile_name=TEST_FARGATE_PROFILE_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, - ) - - generator = eks_create_fargate_profile_trigger.run() - response = await generator.asend(None) - - assert a_mock.get_waiter().wait.call_count == 3 - assert response == TriggerEvent({"status": "success", "message": "Fargate Profile Created"}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EksHook, "async_conn") - async def test_eks_create_fargate_profile_trigger_run_attempts_exceeded( - self, mock_async_conn, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"fargateProfile": {"status": "CREATING"}}, - ) - a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - eks_create_fargate_profile_trigger = EksCreateFargateProfileTrigger( - cluster_name=TEST_CLUSTER_IDENTIFIER, - 
fargate_profile_name=TEST_FARGATE_PROFILE_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=2, - ) - with pytest.raises(AirflowException) as exc: - generator = eks_create_fargate_profile_trigger.run() - await generator.asend(None) - assert "Create Fargate Profile failed - max attempts reached:" in str(exc.value) - assert a_mock.get_waiter().wait.call_count == 2 - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EksHook, "async_conn") - async def test_eks_create_fargate_profile_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error_creating = WaiterError( - name="test_name", - reason="test_reason", - last_response={"fargateProfile": {"status": "CREATING"}}, - ) - error_failed = WaiterError( - name="test_name", - reason="Waiter encountered a terminal failure state:", - last_response={"fargateProfile": {"status": "CREATE_FAILED"}}, - ) - a_mock.get_waiter().wait = AsyncMock(side_effect=[error_creating, error_creating, error_failed]) - mock_sleep.return_value = True - - eks_create_fargate_profile_trigger = EksCreateFargateProfileTrigger( - cluster_name=TEST_CLUSTER_IDENTIFIER, - fargate_profile_name=TEST_FARGATE_PROFILE_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, - ) - - with pytest.raises(AirflowException) as exc: - generator = eks_create_fargate_profile_trigger.run() - await generator.asend(None) - assert f"Create Fargate Profile failed: {error_failed}" in str(exc.value) - assert a_mock.get_waiter().wait.call_count == 3 - - -class TestEksDeleteFargateProfileTrigger: - def test_eks_delete_fargate_profile_serialize(self): - eks_delete_fargate_profile_trigger = EksDeleteFargateProfileTrigger( - cluster_name=TEST_CLUSTER_IDENTIFIER, - fargate_profile_name=TEST_FARGATE_PROFILE_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - 
waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, - ) - - class_path, args = eks_delete_fargate_profile_trigger.serialize() - assert class_path == "airflow.providers.amazon.aws.triggers.eks.EksDeleteFargateProfileTrigger" - assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER - assert args["fargate_profile_name"] == TEST_FARGATE_PROFILE_NAME - assert args["aws_conn_id"] == TEST_AWS_CONN_ID - assert args["waiter_delay"] == str(TEST_WAITER_DELAY) - assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS) - - @pytest.mark.asyncio - @mock.patch.object(EksHook, "async_conn") - async def test_eks_delete_fargate_profile_trigger_run(self, mock_async_conn): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - a_mock.get_waiter().wait = AsyncMock() - - eks_delete_fargate_profile_trigger = EksDeleteFargateProfileTrigger( - cluster_name=TEST_CLUSTER_IDENTIFIER, - fargate_profile_name=TEST_FARGATE_PROFILE_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, - ) - - generator = eks_delete_fargate_profile_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent({"status": "success", "message": "Fargate Profile Deleted"}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EksHook, "async_conn") - async def test_eks_delete_fargate_profile_trigger_run_multiple_attempts( - self, mock_async_conn, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"fargateProfile": {"status": "DELETING"}}, - ) - a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - eks_delete_fargate_profile_trigger = EksDeleteFargateProfileTrigger( - cluster_name=TEST_CLUSTER_IDENTIFIER, - fargate_profile_name=TEST_FARGATE_PROFILE_NAME, - 
aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, - ) - - generator = eks_delete_fargate_profile_trigger.run() - response = await generator.asend(None) - assert a_mock.get_waiter().wait.call_count == 3 - assert response == TriggerEvent({"status": "success", "message": "Fargate Profile Deleted"}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EksHook, "async_conn") - async def test_eks_delete_fargate_profile_trigger_run_attempts_exceeded( - self, mock_async_conn, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"fargateProfile": {"status": "DELETING"}}, - ) - a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, True]) - mock_sleep.return_value = True - - eks_delete_fargate_profile_trigger = EksDeleteFargateProfileTrigger( - cluster_name=TEST_CLUSTER_IDENTIFIER, - fargate_profile_name=TEST_FARGATE_PROFILE_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=2, - ) - with pytest.raises(AirflowException) as exc: - generator = eks_delete_fargate_profile_trigger.run() - await generator.asend(None) - assert "Delete Fargate Profile failed - max attempts reached: 2" in str(exc.value) - assert a_mock.get_waiter().wait.call_count == 2 - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EksHook, "async_conn") - async def test_eks_delete_fargate_profile_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error_creating = WaiterError( - name="test_name", - reason="test_reason", - last_response={"fargateProfile": {"status": "DELETING"}}, - ) - error_failed = WaiterError( - name="test_name", - reason="Waiter encountered a terminal failure state:", - last_response={"fargateProfile": 
{"status": "DELETE_FAILED"}}, - ) - a_mock.get_waiter().wait = AsyncMock(side_effect=[error_creating, error_creating, error_failed]) - mock_sleep.return_value = True - - eks_delete_fargate_profile_trigger = EksDeleteFargateProfileTrigger( - cluster_name=TEST_CLUSTER_IDENTIFIER, - fargate_profile_name=TEST_FARGATE_PROFILE_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, - ) - with pytest.raises(AirflowException) as exc: - generator = eks_delete_fargate_profile_trigger.run() - await generator.asend(None) - assert f"Delete Fargate Profile failed: {error_failed}" in str(exc.value) - assert a_mock.get_waiter().wait.call_count == 3 - - -class TestEksNodegroupTrigger: - def test_eks_nodegroup_trigger_serialize(self): - eks_nodegroup_trigger = EksNodegroupTrigger( - waiter_name="test_waiter_name", - cluster_name=TEST_CLUSTER_IDENTIFIER, - nodegroup_name=TEST_NODEGROUP_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, - region=TEST_REGION, - ) - - class_path, args = eks_nodegroup_trigger.serialize() - assert class_path == "airflow.providers.amazon.aws.triggers.eks.EksNodegroupTrigger" - assert args["waiter_name"] == "test_waiter_name" - assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER - assert args["nodegroup_name"] == TEST_NODEGROUP_NAME - assert args["aws_conn_id"] == TEST_AWS_CONN_ID - assert args["waiter_delay"] == str(TEST_WAITER_DELAY) - assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS) - assert args["region"] == TEST_REGION - - @pytest.mark.asyncio - @mock.patch.object(EksHook, "async_conn") - async def test_eks_nodegroup_trigger_run(self, mock_async_conn): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - a_mock.get_waiter().wait = AsyncMock() - - eks_nodegroup_trigger = EksNodegroupTrigger( - waiter_name="test_waiter_name", - cluster_name=TEST_CLUSTER_IDENTIFIER, - 
nodegroup_name=TEST_NODEGROUP_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, - region=TEST_REGION, - ) - - generator = eks_nodegroup_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent( - { - "status": "success", - "cluster_name": TEST_CLUSTER_IDENTIFIER, - "nodegroup_name": TEST_NODEGROUP_NAME, - } - ) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EksHook, "async_conn") - async def test_eks_nodegroup_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep): - mock_sleep.return_value = True - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"nodegroup": {"status": "CREATING"}}, - ) - a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, True]) - - eks_nodegroup_trigger = EksNodegroupTrigger( - waiter_name="test_waiter_name", - cluster_name=TEST_CLUSTER_IDENTIFIER, - nodegroup_name=TEST_NODEGROUP_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, - region=TEST_REGION, - ) - - generator = eks_nodegroup_trigger.run() - response = await generator.asend(None) - assert a_mock.get_waiter().wait.call_count == 4 - assert response == TriggerEvent( - { - "status": "success", - "cluster_name": TEST_CLUSTER_IDENTIFIER, - "nodegroup_name": TEST_NODEGROUP_NAME, - } - ) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EksHook, "async_conn") - async def test_eks_nodegroup_trigger_run_attempts_exceeded(self, mock_async_conn, mock_sleep): - mock_sleep.return_value = True - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"nodegroup": {"status": "CREATING"}}, - ) - a_mock.get_waiter().wait = 
AsyncMock(side_effect=[error, error, error, True]) - - eks_nodegroup_trigger = EksNodegroupTrigger( - waiter_name="test_waiter_name", - cluster_name=TEST_CLUSTER_IDENTIFIER, - nodegroup_name=TEST_NODEGROUP_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=2, - region=TEST_REGION, - ) - - with pytest.raises(AirflowException) as exc: - generator = eks_nodegroup_trigger.run() - await generator.asend(None) - assert "Waiter error: max attempts reached" in str(exc.value) - assert a_mock.get_waiter().wait.call_count == 2 - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EksHook, "async_conn") - async def test_eks_nodegroup_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep): - mock_sleep.return_value = True - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error_creating = WaiterError( - name="test_name", - reason="test_reason", - last_response={"nodegroup": {"status": "CREATING"}}, - ) - error_failed = WaiterError( - name="test_name", - reason="Waiter encountered a terminal failure state:", - last_response={"nodegroup": {"status": "DELETE_FAILED"}}, - ) - a_mock.get_waiter().wait = AsyncMock(side_effect=[error_creating, error_creating, error_failed]) - mock_sleep.return_value = True - - eks_nodegroup_trigger = EksNodegroupTrigger( - waiter_name="test_waiter_name", - cluster_name=TEST_CLUSTER_IDENTIFIER, - nodegroup_name=TEST_NODEGROUP_NAME, - aws_conn_id=TEST_AWS_CONN_ID, - waiter_delay=TEST_WAITER_DELAY, - waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, - region=TEST_REGION, - ) - with pytest.raises(AirflowException) as exc: - generator = eks_nodegroup_trigger.run() - await generator.asend(None) - - assert "Error checking nodegroup" in str(exc.value) - assert a_mock.get_waiter().wait.call_count == 3 +class TestEksTriggers: + @pytest.mark.parametrize( + "trigger", + [ + EksCreateFargateProfileTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + 
fargate_profile_name=TEST_FARGATE_PROFILE_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, + ), + EksDeleteFargateProfileTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + fargate_profile_name=TEST_FARGATE_PROFILE_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, + ), + EksCreateNodegroupTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + nodegroup_name=TEST_NODEGROUP_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, + region_name=TEST_REGION, + ), + EksDeleteNodegroupTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + nodegroup_name=TEST_NODEGROUP_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, + region_name=TEST_REGION, + ), + ], + ) + def test_serialize_recreate(self, trigger): + class_path, args = trigger.serialize() + + class_name = class_path.split(".")[-1] + clazz = globals()[class_name] + instance = clazz(**args) + + class_path2, args2 = instance.serialize() + + assert class_path == class_path2 + assert args == args2 diff --git a/tests/providers/amazon/aws/triggers/test_emr.py b/tests/providers/amazon/aws/triggers/test_emr.py index ff57eab4a5d15..7066e817e4811 100644 --- a/tests/providers/amazon/aws/triggers/test_emr.py +++ b/tests/providers/amazon/aws/triggers/test_emr.py @@ -16,21 +16,14 @@ # under the License. 
from __future__ import annotations -from unittest import mock -from unittest.mock import AsyncMock - import pytest -from botocore.exceptions import WaiterError -from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.hooks.emr import EmrHook from airflow.providers.amazon.aws.triggers.emr import ( EmrContainerTrigger, EmrCreateJobFlowTrigger, EmrStepSensorTrigger, EmrTerminateJobFlowTrigger, ) -from airflow.triggers.base import TriggerEvent TEST_JOB_FLOW_ID = "test-job-flow-id" TEST_POLL_INTERVAL = 10 @@ -44,527 +37,44 @@ STEP_ID = "s-1234" -class TestEmrCreateJobFlowTrigger: - def test_emr_create_job_flow_trigger_serialize(self): - """Test serialize method to make sure all parameters are being serialized correctly.""" - emr_create_job_flow_trigger = EmrCreateJobFlowTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, - ) - class_path, args = emr_create_job_flow_trigger.serialize() - assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrCreateJobFlowTrigger" - assert args["job_flow_id"] == TEST_JOB_FLOW_ID - assert args["aws_conn_id"] == TEST_AWS_CONN_ID - assert args["poll_interval"] == str(TEST_POLL_INTERVAL) - assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS) - - @pytest.mark.asyncio - @mock.patch.object(EmrHook, "get_waiter") - @mock.patch.object(EmrHook, "async_conn") - async def test_emr_create_job_flow_trigger_run(self, mock_async_conn, mock_get_waiter): - """ - Test run method, with basic success case to assert TriggerEvent contains the - correct payload. 
- """ - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - mock_get_waiter().wait = AsyncMock() - - emr_create_job_flow_trigger = EmrCreateJobFlowTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, - ) - - generator = emr_create_job_flow_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent( - { - "status": "success", - "message": "JobFlow completed successfully", - "job_flow_id": TEST_JOB_FLOW_ID, - } - ) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EmrHook, "get_waiter") - @mock.patch.object(EmrHook, "async_conn") - async def test_emr_create_job_flow_trigger_run_multiple_attempts( - self, mock_async_conn, mock_get_waiter, mock_sleep - ): - """ - Test run method with multiple attempts to make sure the waiter retries - are working as expected. - """ - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"Cluster": {"Status": {"State": "STARTING", "StateChangeReason": "test-reason"}}}, - ) - mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - emr_create_job_flow_trigger = EmrCreateJobFlowTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, - ) - - generator = emr_create_job_flow_trigger.run() - response = await generator.asend(None) - - assert mock_get_waiter().wait.call_count == 3 - assert response == TriggerEvent( - { - "status": "success", - "message": "JobFlow completed successfully", - "job_flow_id": TEST_JOB_FLOW_ID, - } - ) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EmrHook, "get_waiter") - @mock.patch.object(EmrHook, "async_conn") - async def 
test_emr_create_job_flow_trigger_run_attempts_exceeded( - self, mock_async_conn, mock_get_waiter, mock_sleep - ): - """ - Test run method with max_attempts set to 2 to test the Trigger yields - the correct TriggerEvent in the case of max_attempts being exceeded. - """ - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"Cluster": {"Status": {"State": "STARTING", "StateChangeReason": "test-reason"}}}, - ) - mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - emr_create_job_flow_trigger = EmrCreateJobFlowTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=2, - ) - - with pytest.raises(AirflowException) as exc: - generator = emr_create_job_flow_trigger.run() - await generator.asend(None) - - assert str(exc.value) == "JobFlow creation failed - max attempts reached: 2" - assert mock_get_waiter().wait.call_count == 2 - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EmrHook, "get_waiter") - @mock.patch.object(EmrHook, "async_conn") - async def test_emr_create_job_flow_trigger_run_attempts_failed( - self, mock_async_conn, mock_get_waiter, mock_sleep - ): - """ - Test run method with a failure case to test Trigger yields the correct - failure TriggerEvent. 
- """ - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error_starting = WaiterError( - name="test_name", - reason="test_reason", - last_response={"Cluster": {"Status": {"State": "STARTING", "StateChangeReason": "test-reason"}}}, - ) - error_failed = WaiterError( - name="test_name", - reason="Waiter encountered a terminal failure state:", - last_response={ - "Cluster": {"Status": {"State": "TERMINATED_WITH_ERRORS", "StateChangeReason": "test-reason"}} - }, - ) - mock_get_waiter().wait.side_effect = AsyncMock( - side_effect=[error_starting, error_starting, error_failed] - ) - mock_sleep.return_value = True - - emr_create_job_flow_trigger = EmrCreateJobFlowTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, - ) - - with pytest.raises(AirflowException) as exc: - generator = emr_create_job_flow_trigger.run() - await generator.asend(None) - - assert str(exc.value) == f"JobFlow creation failed: {error_failed}" - assert mock_get_waiter().wait.call_count == 3 - - -class TestEmrTerminateJobFlowTrigger: - def test_emr_terminate_job_flow_trigger_serialize(self): - emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, - ) - class_path, args = emr_terminate_job_flow_trigger.serialize() - assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrTerminateJobFlowTrigger" - assert args["job_flow_id"] == TEST_JOB_FLOW_ID - assert args["aws_conn_id"] == TEST_AWS_CONN_ID - assert args["poll_interval"] == str(TEST_POLL_INTERVAL) - assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS) - - @pytest.mark.asyncio - @mock.patch.object(EmrHook, "get_waiter") - @mock.patch.object(EmrHook, "async_conn") - async def test_emr_terminate_job_flow_trigger_run(self, mock_async_conn, mock_get_waiter): - a_mock = mock.MagicMock() - 
mock_async_conn.__aenter__.return_value = a_mock - - mock_get_waiter().wait = AsyncMock() - - emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, - ) - - generator = emr_terminate_job_flow_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent( - { - "status": "success", - "message": "JobFlow terminated successfully", - } - ) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EmrHook, "get_waiter") - @mock.patch.object(EmrHook, "async_conn") - async def test_emr_terminate_job_flow_trigger_run_multiple_attempts( - self, mock_async_conn, mock_get_waiter, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={ - "Cluster": {"Status": {"State": "TERMINATING", "StateChangeReason": "test-reason"}} - }, - ) - mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, - ) - - generator = emr_terminate_job_flow_trigger.run() - response = await generator.asend(None) - - assert mock_get_waiter().wait.call_count == 3 - assert response == TriggerEvent( - { - "status": "success", - "message": "JobFlow terminated successfully", - } - ) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EmrHook, "get_waiter") - @mock.patch.object(EmrHook, "async_conn") - async def test_emr_terminate_job_flow_trigger_run_attempts_exceeded( - self, mock_async_conn, mock_get_waiter, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - 
name="test_name", - reason="test_reason", - last_response={ - "Cluster": {"Status": {"State": "TERMINATING", "StateChangeReason": "test-reason"}} - }, - ) - mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=2, - ) - with pytest.raises(AirflowException) as exc: - generator = emr_terminate_job_flow_trigger.run() - await generator.asend(None) - - assert str(exc.value) == "JobFlow termination failed - max attempts reached: 2" - assert mock_get_waiter().wait.call_count == 2 - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EmrHook, "get_waiter") - @mock.patch.object(EmrHook, "async_conn") - async def test_emr_terminate_job_flow_trigger_run_attempts_failed( - self, mock_async_conn, mock_get_waiter, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error_starting = WaiterError( - name="test_name", - reason="test_reason", - last_response={ - "Cluster": {"Status": {"State": "TERMINATING", "StateChangeReason": "test-reason"}} - }, - ) - error_failed = WaiterError( - name="test_name", - reason="Waiter encountered a terminal failure state:", - last_response={ - "Cluster": {"Status": {"State": "TERMINATED_WITH_ERRORS", "StateChangeReason": "test-reason"}} - }, - ) - mock_get_waiter().wait.side_effect = AsyncMock( - side_effect=[error_starting, error_starting, error_failed] - ) - mock_sleep.return_value = True - - emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - aws_conn_id=TEST_AWS_CONN_ID, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPTS, - ) - with pytest.raises(AirflowException) as exc: - generator = emr_terminate_job_flow_trigger.run() - await generator.asend(None) - - assert str(exc.value) == 
f"JobFlow termination failed: {error_failed}" - assert mock_get_waiter().wait.call_count == 3 - - -class TestEmrContainerTrigger: - def test_emr_container_trigger_serialize(self): - emr_trigger = EmrContainerTrigger( - virtual_cluster_id=VIRTUAL_CLUSTER_ID, - job_id=JOB_ID, - aws_conn_id=AWS_CONN_ID, - poll_interval=POLL_INTERVAL, - ) - class_path, args = emr_trigger.serialize() - assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger" - assert args["virtual_cluster_id"] == VIRTUAL_CLUSTER_ID - assert args["job_id"] == JOB_ID - assert args["aws_conn_id"] == AWS_CONN_ID - assert args["poll_interval"] == POLL_INTERVAL - - @pytest.mark.asyncio - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn") - async def test_emr_container_trigger_run(self, mock_async_conn, mock_get_waiter): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - mock_get_waiter().wait = AsyncMock() - - emr_trigger = EmrContainerTrigger( - virtual_cluster_id=VIRTUAL_CLUSTER_ID, - job_id=JOB_ID, - aws_conn_id=AWS_CONN_ID, - poll_interval=POLL_INTERVAL, - ) - - generator = emr_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent({"status": "success", "job_id": JOB_ID}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn") - async def test_emr_trigger_run_multiple_attempts(self, mock_async_conn, mock_get_waiter, mock_sleep): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"jobRun": {"state": "RUNNING"}}, - ) - mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value 
= True - - emr_trigger = EmrContainerTrigger( - virtual_cluster_id=VIRTUAL_CLUSTER_ID, - job_id=JOB_ID, - aws_conn_id=AWS_CONN_ID, - poll_interval=POLL_INTERVAL, - ) - - generator = emr_trigger.run() - response = await generator.asend(None) - - assert mock_get_waiter().wait.call_count == 3 - assert response == TriggerEvent({"status": "success", "job_id": JOB_ID}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn") - async def test_emr_trigger_run_attempts_failed(self, mock_async_conn, mock_get_waiter, mock_sleep): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - error_available = WaiterError( - name="test_name", - reason="Max attempts exceeded", - last_response={"jobRun": {"state": "FAILED"}}, - ) - error_failed = WaiterError( - name="test_name", - reason="Waiter encountered a terminal failure state", - last_response={"jobRun": {"state": "FAILED"}}, - ) - mock_get_waiter().wait.side_effect = AsyncMock( - side_effect=[error_available, error_available, error_failed] - ) - mock_sleep.return_value = True - - emr_trigger = EmrContainerTrigger( - virtual_cluster_id=VIRTUAL_CLUSTER_ID, - job_id=JOB_ID, - aws_conn_id=AWS_CONN_ID, - poll_interval=POLL_INTERVAL, - ) - - generator = emr_trigger.run() - response = await generator.asend(None) - - assert mock_get_waiter().wait.call_count == 3 - assert response == TriggerEvent({"status": "failure", "message": f"Job Failed: {error_failed}"}) - - -class TestEmrStepSensorTrigger: - def test_emr_step_trigger_serialize(self): - """Test trigger serialize object and path as expected""" - emr_trigger = EmrStepSensorTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - step_id=STEP_ID, - aws_conn_id=AWS_CONN_ID, - poke_interval=POLL_INTERVAL, - ) - class_path, args = emr_trigger.serialize() - assert class_path == 
"airflow.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger" - assert args["job_flow_id"] == TEST_JOB_FLOW_ID - assert args["step_id"] == STEP_ID - assert args["aws_conn_id"] == AWS_CONN_ID - assert args["max_attempts"] == 60 - assert args["poke_interval"] == POLL_INTERVAL - - @pytest.mark.asyncio - @mock.patch.object(EmrHook, "async_conn") - async def test_emr_step_trigger_run(self, mock_async_conn): - """Test trigger emit success if condition met""" - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - a_mock.get_waiter().wait = AsyncMock() - - emr_trigger = EmrStepSensorTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - step_id=STEP_ID, - aws_conn_id=AWS_CONN_ID, - poll_interval=POLL_INTERVAL, - ) - - generator = emr_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent({"status": "success"}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EmrHook, "async_conn") - async def test_emr_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep): - """Test trigger try max attempt if attempt not exceeded and job still running""" - mock_sleep.return_value = True - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"Step": {"Status": {"State": "RUNNING"}}}, - ) - a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, True]) - - emr_trigger = EmrStepSensorTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - step_id=STEP_ID, - aws_conn_id=AWS_CONN_ID, - poll_interval=POLL_INTERVAL, - ) - - generator = emr_trigger.run() - response = await generator.asend(None) - - assert a_mock.get_waiter().wait.call_count == 4 - assert response == TriggerEvent({"status": "success"}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(EmrHook, "async_conn") - async def test_emr_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep): - """Test 
trigger does fail if max attempt exceeded and job still not succeeded""" - mock_sleep.return_value = True - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - error_running = WaiterError( - name="test_name", - reason="test reason", - last_response={"Step": {"Status": {"State": "RUNNING"}}}, - ) - error_failed = WaiterError( - name="test_name", - reason="Waiter encountered a terminal failure state", - last_response={"Step": {"Status": {"State": "CANCELLED"}}}, - ) - - a_mock.get_waiter().wait = AsyncMock(side_effect=[error_running, error_failed]) - mock_sleep.return_value = True - - emr_trigger = EmrStepSensorTrigger( - job_flow_id=TEST_JOB_FLOW_ID, - step_id=STEP_ID, - aws_conn_id=AWS_CONN_ID, - poll_interval=POLL_INTERVAL, - ) - - with pytest.raises(AirflowException) as exc: - generator = emr_trigger.run() - await generator.asend(None) - - assert a_mock.get_waiter().wait.call_count == 2 - assert "Error while waiting for step s-1234 to complete" in str(exc.value) +class TestEmrTriggers: + @pytest.mark.parametrize( + "trigger", + [ + EmrCreateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ), + EmrTerminateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ), + EmrContainerTrigger( + virtual_cluster_id=VIRTUAL_CLUSTER_ID, + job_id=JOB_ID, + aws_conn_id=AWS_CONN_ID, + poll_interval=POLL_INTERVAL, + ), + EmrStepSensorTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + step_id=STEP_ID, + aws_conn_id=AWS_CONN_ID, + waiter_delay=POLL_INTERVAL, + ), + ], + ) + def test_serialize_recreate(self, trigger): + class_path, args = trigger.serialize() + + class_name = class_path.split(".")[-1] + clazz = globals()[class_name] + instance = clazz(**args) + + class_path2, args2 = instance.serialize() + + assert class_path == class_path2 + assert args == args2 
diff --git a/tests/providers/amazon/aws/triggers/test_emr_trigger.py b/tests/providers/amazon/aws/triggers/test_emr_trigger.py index 0ec3b5af6eb8c..fe28f86edb67d 100644 --- a/tests/providers/amazon/aws/triggers/test_emr_trigger.py +++ b/tests/providers/amazon/aws/triggers/test_emr_trigger.py @@ -23,7 +23,13 @@ from botocore.exceptions import WaiterError from airflow.providers.amazon.aws.hooks.emr import EmrHook -from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger +from airflow.providers.amazon.aws.triggers.emr import ( + EmrAddStepsTrigger, + EmrContainerTrigger, + EmrCreateJobFlowTrigger, + EmrStepSensorTrigger, + EmrTerminateJobFlowTrigger, +) from airflow.triggers.base import TriggerEvent TEST_JOB_FLOW_ID = "test_job_flow_id" @@ -168,3 +174,48 @@ async def test_emr_add_steps_trigger_run_attempts_failed(self, mock_async_conn, assert response == TriggerEvent( {"status": "failure", "message": f"Step {TEST_STEP_IDS[0]} failed: {error_failed}"} ) + + +class TestEmrTriggers: + @pytest.mark.parametrize( + "trigger", + [ + EmrCreateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + waiter_delay=TEST_POLL_INTERVAL, + waiter_max_attempts=TEST_MAX_ATTEMPTS, + ), + EmrTerminateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + waiter_delay=TEST_POLL_INTERVAL, + waiter_max_attempts=TEST_MAX_ATTEMPTS, + ), + EmrContainerTrigger( + virtual_cluster_id="my_cluster_id", + job_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + waiter_delay=TEST_POLL_INTERVAL, + waiter_max_attempts=TEST_MAX_ATTEMPTS, + ), + EmrStepSensorTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + step_id="my_step", + aws_conn_id=TEST_AWS_CONN_ID, + waiter_delay=TEST_POLL_INTERVAL, + waiter_max_attempts=TEST_MAX_ATTEMPTS, + ), + ], + ) + def test_serialize_recreate(self, trigger): + class_path, args = trigger.serialize() + + class_name = class_path.split(".")[-1] + clazz = globals()[class_name] + instance = clazz(**args) + + 
class_path2, args2 = instance.serialize() + + assert class_path == class_path2 + assert args == args2 diff --git a/tests/providers/amazon/aws/triggers/test_glue.py b/tests/providers/amazon/aws/triggers/test_glue.py index cc98ecc74831b..9c7cd61a712ee 100644 --- a/tests/providers/amazon/aws/triggers/test_glue.py +++ b/tests/providers/amazon/aws/triggers/test_glue.py @@ -24,6 +24,7 @@ from airflow import AirflowException from airflow.providers.amazon.aws.hooks.glue import GlueJobHook from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger +from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger class TestGlueJobTrigger: @@ -69,3 +70,21 @@ async def test_wait_job_failed(self, get_state_mock: mock.MagicMock): await trigger.run().asend(None) assert get_state_mock.call_count == 3 + + +class TestGlueCrawlerTrigger: + def test_serialize_recreate(self): + trigger = GlueCrawlerCompleteTrigger( + crawler_name="my_crawler", waiter_delay=2, aws_conn_id="my_conn_id" + ) + + class_path, args = trigger.serialize() + + class_name = class_path.split(".")[-1] + clazz = globals()[class_name] + instance = clazz(**args) + + class_path2, args2 = instance.serialize() + + assert class_path == class_path2 + assert args == args2 diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py index 43a2ee0ae1847..af5e4a9da1ad4 100644 --- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py @@ -16,13 +16,8 @@ # under the License. 
from __future__ import annotations -from unittest import mock -from unittest.mock import AsyncMock - import pytest -from botocore.exceptions import WaiterError -from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook from airflow.providers.amazon.aws.triggers.redshift_cluster import ( RedshiftCreateClusterSnapshotTrigger, RedshiftCreateClusterTrigger, @@ -30,7 +25,6 @@ RedshiftPauseClusterTrigger, RedshiftResumeClusterTrigger, ) -from airflow.triggers.base import TriggerEvent TEST_CLUSTER_IDENTIFIER = "test-cluster" TEST_POLL_INTERVAL = 10 @@ -38,598 +32,50 @@ TEST_AWS_CONN_ID = "test-aws-id" -class TestRedshiftCreateClusterTrigger: - def test_redshift_create_cluster_trigger_serialize(self): - redshift_create_cluster_trigger = RedshiftCreateClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempt=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - class_path, args = redshift_create_cluster_trigger.serialize() - assert ( - class_path - == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftCreateClusterTrigger" - ) - assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER - assert args["poll_interval"] == str(TEST_POLL_INTERVAL) - assert args["max_attempt"] == str(TEST_MAX_ATTEMPT) - assert args["aws_conn_id"] == TEST_AWS_CONN_ID - - @pytest.mark.asyncio - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") - async def test_redshift_create_cluster_trigger_run(self, mock_async_conn): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - a_mock.get_waiter().wait = AsyncMock() - - redshift_create_cluster_trigger = RedshiftCreateClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempt=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_create_cluster_trigger.run() - response = await generator.asend(None) - - assert response == 
TriggerEvent({"status": "success", "message": "Cluster Created"}) - - -class TestRedshiftPauseClusterTrigger: - def test_redshift_pause_cluster_trigger_serialize(self): - redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - class_path, args = redshift_pause_cluster_trigger.serialize() - assert ( - class_path == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger" - ) - assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER - assert args["poll_interval"] == str(TEST_POLL_INTERVAL) - assert args["max_attempts"] == str(TEST_MAX_ATTEMPT) - assert args["aws_conn_id"] == TEST_AWS_CONN_ID - - @pytest.mark.asyncio - @mock.patch.object(RedshiftHook, "get_waiter") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_pause_cluster_trigger_run(self, mock_async_conn, mock_get_waiter): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - mock_get_waiter().wait = AsyncMock() - - redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_pause_cluster_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent({"status": "success", "message": "Cluster paused"}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(RedshiftHook, "get_waiter") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_pause_cluster_trigger_run_multiple_attempts( - self, mock_async_conn, mock_get_waiter, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"Clusters": 
[{"ClusterStatus": "available"}]}, - ) - mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_pause_cluster_trigger.run() - response = await generator.asend(None) - - assert mock_get_waiter().wait.call_count == 3 - assert response == TriggerEvent({"status": "success", "message": "Cluster paused"}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(RedshiftHook, "get_waiter") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_pause_cluster_trigger_run_attempts_exceeded( - self, mock_async_conn, mock_get_waiter, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"Clusters": [{"ClusterStatus": "available"}]}, - ) - mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=2, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_pause_cluster_trigger.run() - response = await generator.asend(None) - - assert mock_get_waiter().wait.call_count == 2 - assert response == TriggerEvent( - {"status": "failure", "message": "Pause Cluster Failed - max attempts reached."} - ) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(RedshiftHook, "get_waiter") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_pause_cluster_trigger_run_attempts_failed( - self, mock_async_conn, mock_get_waiter, mock_sleep - ): - a_mock = mock.MagicMock() - 
mock_async_conn.__aenter__.return_value = a_mock - error_available = WaiterError( - name="test_name", - reason="Max attempts exceeded", - last_response={"Clusters": [{"ClusterStatus": "available"}]}, - ) - error_failed = WaiterError( - name="test_name", - reason="Waiter encountered a terminal failure state:", - last_response={"Clusters": [{"ClusterStatus": "available"}]}, - ) - mock_get_waiter().wait.side_effect = AsyncMock( - side_effect=[error_available, error_available, error_failed] - ) - mock_sleep.return_value = True - - redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_pause_cluster_trigger.run() - response = await generator.asend(None) - - assert mock_get_waiter().wait.call_count == 3 - assert response == TriggerEvent( - {"status": "failure", "message": f"Pause Cluster Failed: {error_failed}"} - ) - - -class TestRedshiftCreateClusterSnapshotTrigger: - def test_redshift_create_cluster_snapshot_trigger_serialize(self): - redshift_create_cluster_trigger = RedshiftCreateClusterSnapshotTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - class_path, args = redshift_create_cluster_trigger.serialize() - assert ( - class_path - == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftCreateClusterSnapshotTrigger" - ) - assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER - assert args["poll_interval"] == str(TEST_POLL_INTERVAL) - assert args["max_attempts"] == str(TEST_MAX_ATTEMPT) - assert args["aws_conn_id"] == TEST_AWS_CONN_ID - - @pytest.mark.asyncio - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_create_cluster_snapshot_trigger_run(self, mock_async_conn): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value 
= a_mock - a_mock.get_waiter().wait = AsyncMock() - - redshift_create_cluster_trigger = RedshiftCreateClusterSnapshotTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_create_cluster_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent({"status": "success", "message": "Cluster Snapshot Created"}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_create_cluster_snapshot_trigger_run_multiple_attempts( - self, mock_async_conn, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"Snapshots": [{"Status": "available"}]}, - ) - a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - redshift_create_cluster_snapshot_trigger = RedshiftCreateClusterSnapshotTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_create_cluster_snapshot_trigger.run() - response = await generator.asend(None) - - assert a_mock.get_waiter().wait.call_count == 3 - assert response == TriggerEvent({"status": "success", "message": "Cluster Snapshot Created"}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_create_cluster_snapshot_trigger_run_attempts_exceeded( - self, mock_async_conn, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = 
WaiterError( - name="test_name", - reason="test_reason", - last_response={"Snapshots": [{"Status": "available"}]}, - ) - a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - redshift_create_cluster_snapshot_trigger = RedshiftCreateClusterSnapshotTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=2, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_create_cluster_snapshot_trigger.run() - response = await generator.asend(None) - - assert a_mock.get_waiter().wait.call_count == 2 - assert response == TriggerEvent( - {"status": "failure", "message": "Create Cluster Snapshot Cluster Failed - max attempts reached."} - ) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_create_cluster_snapshot_trigger_run_attempts_failed( - self, mock_async_conn, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error_available = WaiterError( - name="test_name", - reason="test_reason", - last_response={"Snapshots": [{"Status": "available"}]}, - ) - - error_failed = WaiterError( - name="test_name", - reason="Waiter encountered a terminal failure state:", - last_response={"Snapshots": [{"Status": "available"}]}, - ) - a_mock.get_waiter().wait.side_effect = AsyncMock( - side_effect=[error_available, error_available, error_failed] - ) - mock_sleep.return_value = True - - redshift_create_cluster_snapshot_trigger = RedshiftCreateClusterSnapshotTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_create_cluster_snapshot_trigger.run() - response = await generator.asend(None) - - assert a_mock.get_waiter().wait.call_count == 3 - 
assert response == TriggerEvent( - {"status": "failure", "message": f"Create Cluster Snapshot Failed: {error_failed}"} - ) - - -class TestRedshiftResumeClusterTrigger: - def test_redshift_resume_cluster_trigger_serialize(self): - redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - class_path, args = redshift_resume_cluster_trigger.serialize() - assert ( - class_path - == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftResumeClusterTrigger" - ) - assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER - assert args["poll_interval"] == str(TEST_POLL_INTERVAL) - assert args["max_attempts"] == str(TEST_MAX_ATTEMPT) - assert args["aws_conn_id"] == TEST_AWS_CONN_ID - - @pytest.mark.asyncio - @mock.patch.object(RedshiftHook, "get_waiter") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_resume_cluster_trigger_run(self, mock_async_conn, mock_get_waiter): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - mock_get_waiter().wait = AsyncMock() - - redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_resume_cluster_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent({"status": "success", "message": "Cluster resumed"}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(RedshiftHook, "get_waiter") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_resume_cluster_trigger_run_multiple_attempts( - self, mock_async_conn, mock_get_waiter, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", 
- reason="test_reason", - last_response={"Clusters": [{"ClusterStatus": "available"}]}, - ) - mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_resume_cluster_trigger.run() - response = await generator.asend(None) - - assert mock_get_waiter().wait.call_count == 3 - assert response == TriggerEvent({"status": "success", "message": "Cluster resumed"}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(RedshiftHook, "get_waiter") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_resume_cluster_trigger_run_attempts_exceeded( - self, mock_async_conn, mock_get_waiter, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"Clusters": [{"ClusterStatus": "available"}]}, - ) - mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=2, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_resume_cluster_trigger.run() - response = await generator.asend(None) - - assert mock_get_waiter().wait.call_count == 2 - assert response == TriggerEvent( - {"status": "failure", "message": "Resume Cluster Failed - max attempts reached."} - ) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(RedshiftHook, "get_waiter") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_resume_cluster_trigger_run_attempts_failed( - self, mock_async_conn, 
mock_get_waiter, mock_sleep - ): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error_available = WaiterError( - name="test_name", - reason="Max attempts exceeded", - last_response={"Clusters": [{"ClusterStatus": "available"}]}, - ) - error_failed = WaiterError( - name="test_name", - reason="Waiter encountered a terminal failure state:", - last_response={"Clusters": [{"ClusterStatus": "available"}]}, - ) - mock_get_waiter().wait.side_effect = AsyncMock( - side_effect=[error_available, error_available, error_failed] - ) - mock_sleep.return_value = True - - redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_resume_cluster_trigger.run() - response = await generator.asend(None) - - assert mock_get_waiter().wait.call_count == 3 - assert response == TriggerEvent( - {"status": "failure", "message": f"Resume Cluster Failed: {error_failed}"} - ) - - -class TestRedshiftDeleteClusterTrigger: - def test_redshift_delete_cluster_trigger_serialize(self): - redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - class_path, args = redshift_delete_cluster_trigger.serialize() - assert ( - class_path - == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftDeleteClusterTrigger" - ) - assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER - assert args["poll_interval"] == TEST_POLL_INTERVAL - assert args["max_attempts"] == TEST_MAX_ATTEMPT - assert args["aws_conn_id"] == TEST_AWS_CONN_ID - - @pytest.mark.asyncio - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") - async def test_redshift_delete_cluster_trigger_run(self, mock_async_conn): - a_mock = 
mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - a_mock.get_waiter().wait = AsyncMock() - - redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_delete_cluster_trigger.run() - response = await generator.asend(None) - - assert response == TriggerEvent({"status": "success", "message": "Cluster deleted."}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_delete_cluster_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"Clusters": [{"ClusterStatus": "available"}]}, - ) - a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_delete_cluster_trigger.run() - response = await generator.asend(None) - - assert a_mock.get_waiter().wait.call_count == 3 - assert response == TriggerEvent({"status": "success", "message": "Cluster deleted."}) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_delete_cluster_trigger_run_attempts_exceeded(self, mock_async_conn, mock_sleep): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - error = WaiterError( - name="test_name", - reason="test_reason", - last_response={"Clusters": [{"ClusterStatus": "deleting"}]}, - ) - a_mock.get_waiter().wait.side_effect = 
AsyncMock(side_effect=[error, error, True]) - mock_sleep.return_value = True - - redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=2, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_delete_cluster_trigger.run() - response = await generator.asend(None) - - assert a_mock.get_waiter().wait.call_count == 2 - assert response == TriggerEvent( - {"status": "failure", "message": "Delete Cluster Failed - max attempts reached."} - ) - - @pytest.mark.asyncio - @mock.patch("asyncio.sleep") - @mock.patch.object(RedshiftHook, "async_conn") - async def test_redshift_delete_cluster_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep): - a_mock = mock.MagicMock() - mock_async_conn.__aenter__.return_value = a_mock - - error_available = WaiterError( - name="test_name", - reason="Max attempts exceeded", - last_response={"Clusters": [{"ClusterStatus": "deleting"}]}, - ) - error_failed = WaiterError( - name="test_name", - reason="Waiter encountered a terminal failure state:", - last_response={"Clusters": [{"ClusterStatus": "available"}]}, - ) - a_mock.get_waiter().wait.side_effect = AsyncMock( - side_effect=[error_available, error_available, error_failed] - ) - mock_sleep.return_value = True - - redshift_delete_cluster_trigger = RedshiftDeleteClusterTrigger( - cluster_identifier=TEST_CLUSTER_IDENTIFIER, - poll_interval=TEST_POLL_INTERVAL, - max_attempts=TEST_MAX_ATTEMPT, - aws_conn_id=TEST_AWS_CONN_ID, - ) - - generator = redshift_delete_cluster_trigger.run() - response = await generator.asend(None) - - assert a_mock.get_waiter().wait.call_count == 3 - assert response == TriggerEvent( - {"status": "failure", "message": f"Delete Cluster Failed: {error_failed}"} - ) +class TestRedshiftClusterTriggers: + @pytest.mark.parametrize( + "trigger", + [ + RedshiftCreateClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, 
+ max_attempt=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ), + RedshiftPauseClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ), + RedshiftCreateClusterSnapshotTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ), + RedshiftResumeClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ), + RedshiftDeleteClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ), + ], + ) + def test_serialize_recreate(self, trigger): + class_path, args = trigger.serialize() + + class_name = class_path.split(".")[-1] + clazz = globals()[class_name] + instance = clazz(**args) + + class_path2, args2 = instance.serialize() + + assert class_path == class_path2 + assert args == args2