diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py index 6858f801121bd..e7931cc07a075 100644 --- a/airflow/providers/amazon/aws/operators/eks.py +++ b/airflow/providers/amazon/aws/operators/eks.py @@ -21,7 +21,7 @@ import warnings from ast import literal_eval from datetime import timedelta -from typing import TYPE_CHECKING, List, Sequence, cast +from typing import TYPE_CHECKING, Any, List, Sequence, cast from botocore.exceptions import ClientError, WaiterError @@ -30,8 +30,10 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.eks import EksHook from airflow.providers.amazon.aws.triggers.eks import ( + EksCreateClusterTrigger, EksCreateFargateProfileTrigger, EksCreateNodegroupTrigger, + EksDeleteClusterTrigger, EksDeleteFargateProfileTrigger, EksDeleteNodegroupTrigger, ) @@ -187,6 +189,9 @@ class EksCreateClusterOperator(BaseOperator): (templated) :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check cluster state :param waiter_max_attempts: The maximum number of attempts to check cluster state + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. 
+ (default: False) """ @@ -225,6 +230,7 @@ def __init__( wait_for_completion: bool = False, aws_conn_id: str = DEFAULT_CONN_ID, region: str | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), waiter_delay: int = 30, waiter_max_attempts: int = 40, **kwargs, @@ -237,7 +243,7 @@ def __init__( self.nodegroup_role_arn = nodegroup_role_arn self.fargate_pod_execution_role_arn = fargate_pod_execution_role_arn self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or {} - self.wait_for_completion = wait_for_completion + self.wait_for_completion = False if deferrable else wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts self.aws_conn_id = aws_conn_id @@ -246,6 +252,7 @@ def __init__( self.create_nodegroup_kwargs = create_nodegroup_kwargs or {} self.fargate_selectors = fargate_selectors or [{"namespace": DEFAULT_NAMESPACE_NAME}] self.fargate_profile_name = fargate_profile_name + self.deferrable = deferrable super().__init__( **kwargs, ) @@ -274,12 +281,25 @@ def execute(self, context: Context): # Short circuit early if we don't need to wait to attach compute # and the caller hasn't requested to wait for the cluster either. - if not self.compute and not self.wait_for_completion: + if not any([self.compute, self.wait_for_completion, self.deferrable]): return None - self.log.info("Waiting for EKS Cluster to provision. This will take some time.") + self.log.info("Waiting for EKS Cluster to provision. 
This will take some time.") client = self.eks_hook.conn + if self.deferrable: + self.defer( + trigger=EksCreateClusterTrigger( + cluster_name=self.cluster_name, + aws_conn_id=self.aws_conn_id, + region_name=self.region, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="deferrable_create_cluster_next", + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay), + ) + try: client.get_waiter("cluster_active").wait( name=self.cluster_name, @@ -311,6 +331,89 @@ def execute(self, context: Context): subnets=cast(List[str], self.resources_vpc_config.get("subnetIds")), ) + def deferrable_create_cluster_next(self, context: Context, event: dict[str, Any] | None = None) -> None: + if event is None: + self.log.error("Trigger error: event is None") + raise AirflowException("Trigger error: event is None") + elif event["status"] == "failed": + self.log.error("Cluster failed to start and will be torn down.") + self.eks_hook.delete_cluster(name=self.cluster_name) + self.defer( + trigger=EksDeleteClusterTrigger( + cluster_name=self.cluster_name, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + region_name=self.region, + force_delete_compute=False, + ), + method_name="execute_failed", + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay), + ) + elif event["status"] == "success": + self.log.info("Cluster is ready to provision compute.") + _create_compute( + compute=self.compute, + cluster_name=self.cluster_name, + aws_conn_id=self.aws_conn_id, + region=self.region, + wait_for_completion=self.wait_for_completion, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + nodegroup_name=self.nodegroup_name, + nodegroup_role_arn=self.nodegroup_role_arn, + create_nodegroup_kwargs=self.create_nodegroup_kwargs, + fargate_profile_name=self.fargate_profile_name, + 
fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn, + fargate_selectors=self.fargate_selectors, + create_fargate_profile_kwargs=self.create_fargate_profile_kwargs, + subnets=cast(List[str], self.resources_vpc_config.get("subnetIds")), + ) + if self.compute == "fargate": + self.defer( + trigger=EksCreateFargateProfileTrigger( + cluster_name=self.cluster_name, + fargate_profile_name=self.fargate_profile_name, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + region=self.region, + ), + method_name="execute_complete", + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay), + ) + elif self.compute == "nodegroup": + self.defer( + trigger=EksCreateNodegroupTrigger( + nodegroup_name=self.nodegroup_name, + cluster_name=self.cluster_name, + aws_conn_id=self.aws_conn_id, + region_name=self.region, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay), + ) + + def execute_failed(self, context: Context, event: dict[str, Any] | None = None) -> None: + if event is None: + self.log.info("Trigger error: event is None") + raise AirflowException("Trigger error: event is None") + elif event["status"] == "deleted": + self.log.info("Cluster deleted") + raise AirflowException("Error creating cluster: cluster failed to start and was torn down") + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + resource = "fargate profile" if self.compute == "fargate" else self.compute + if event is None: + self.log.info("Trigger error: event is None") + raise AirflowException("Trigger error: event is None") + elif event["status"] != "success": + raise AirflowException(f"Error creating {resource}: {event}") + + self.log.info("%s created successfully", resource) + class EksCreateNodegroupOperator(BaseOperator): """ @@ -564,6 +667,11 @@ class EksDeleteClusterOperator(BaseOperator): maintained on each worker node). 
:param region: Which AWS region the connection should use. (templated) If this is None or empty then the default boto3 behaviour is used. + :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check cluster state + :param waiter_max_attempts: The maximum number of attempts to check cluster state + :param deferrable: If True, the operator will wait asynchronously for the cluster to be deleted. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) """ @@ -582,13 +690,19 @@ def __init__( wait_for_completion: bool = False, aws_conn_id: str = DEFAULT_CONN_ID, region: str | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + waiter_delay: int = 30, + waiter_max_attempts: int = 40, **kwargs, ) -> None: self.cluster_name = cluster_name self.force_delete_compute = force_delete_compute - self.wait_for_completion = wait_for_completion + self.wait_for_completion = False if deferrable else wait_for_completion self.aws_conn_id = aws_conn_id self.region = region + self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts super().__init__(**kwargs) def execute(self, context: Context): @@ -596,8 +710,20 @@ def execute(self, context: Context): aws_conn_id=self.aws_conn_id, region_name=self.region, ) - - if self.force_delete_compute: + if self.deferrable: + self.defer( + trigger=EksDeleteClusterTrigger( + cluster_name=self.cluster_name, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + region_name=self.region, + force_delete_compute=self.force_delete_compute, + ), + method_name="execute_complete", + timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts), + ) + elif self.force_delete_compute: self.delete_any_nodegroups(eks_hook) self.delete_any_fargate_profiles(eks_hook) @@ -645,6 +771,13 @@ def 
delete_any_fargate_profiles(self, eks_hook) -> None: ) self.log.info(SUCCESS_MSG.format(compute=FARGATE_FULL_NAME)) + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + if event is None: + self.log.error("Trigger error. Event is None") + raise AirflowException("Trigger error. Event is None") + elif event["status"] == "deleted": + self.log.info("Cluster deleted successfully.") + + class EksDeleteNodegroupOperator(BaseOperator): """ diff --git a/airflow/providers/amazon/aws/triggers/eks.py b/airflow/providers/amazon/aws/triggers/eks.py index a6fb75eb80fa2..ff99b512001ea 100644 --- a/airflow/providers/amazon/aws/triggers/eks.py +++ b/airflow/providers/amazon/aws/triggers/eks.py @@ -17,11 +17,178 @@ from __future__ import annotations import warnings +from typing import Any from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook from airflow.providers.amazon.aws.hooks.eks import EksHook from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger +from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait +from airflow.triggers.base import TriggerEvent + + +class EksCreateClusterTrigger(AwsBaseWaiterTrigger): + """ + Trigger for EksCreateClusterOperator. + + The trigger will asynchronously wait for the cluster to be created. + + :param cluster_name: The name of the EKS cluster + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: Which AWS region the connection should use. + If this is None or empty then the default boto3 behaviour is used. 
+ """ + + def __init__( + self, + cluster_name: str, + waiter_delay: int, + waiter_max_attempts: int, + aws_conn_id: str, + region_name: str | None, + ): + super().__init__( + serialized_fields={"cluster_name": cluster_name, "region_name": region_name}, + waiter_name="cluster_active", + waiter_args={"name": cluster_name}, + failure_message="Error checking Eks cluster", + status_message="Eks cluster status is", + status_queries=["cluster.status"], + return_value=None, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + region_name=region_name, + ) + + def hook(self) -> AwsGenericHook: + return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + + +class EksDeleteClusterTrigger(AwsBaseWaiterTrigger): + """ + Trigger for EksDeleteClusterOperator. + + The trigger will asynchronously wait for the cluster to be deleted. If there are + any nodegroups or fargate profiles associated with the cluster, they will be deleted + before the cluster is deleted. + + :param cluster_name: The name of the EKS cluster + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: Which AWS region the connection should use. + If this is None or empty then the default boto3 behaviour is used. + :param force_delete_compute: If True, any nodegroups or fargate profiles associated + with the cluster will be deleted before the cluster is deleted. 
+ """ + + def __init__( + self, + cluster_name, + waiter_delay: int, + waiter_max_attempts: int, + aws_conn_id: str, + region_name: str | None, + force_delete_compute: bool, + ): + self.cluster_name = cluster_name + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.force_delete_compute = force_delete_compute + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + self.__class__.__module__ + "." + self.__class__.__qualname__, + { + "cluster_name": self.cluster_name, + "waiter_delay": str(self.waiter_delay), + "waiter_max_attempts": str(self.waiter_max_attempts), + "aws_conn_id": self.aws_conn_id, + "region_name": self.region_name, + "force_delete_compute": self.force_delete_compute, + }, + ) + + def hook(self) -> AwsGenericHook: + return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + + async def run(self): + async with self.hook().async_conn as client: + waiter = client.get_waiter("cluster_deleted") + if self.force_delete_compute: + await self.delete_any_nodegroups(client=client) + await self.delete_any_fargate_profiles(client=client) + await client.delete_cluster(name=self.cluster_name) + await async_wait( + waiter=waiter, + waiter_delay=int(self.waiter_delay), + waiter_max_attempts=int(self.waiter_max_attempts), + args={"name": self.cluster_name}, + failure_message="Error deleting cluster", + status_message="Status of cluster is", + status_args=["cluster.status"], + ) + + yield TriggerEvent({"status": "deleted"}) + + async def delete_any_nodegroups(self, client) -> None: + """ + Deletes all EKS Nodegroups for a provided Amazon EKS Cluster. + + All the EKS Nodegroups are deleted simultaneously. We wait for + all Nodegroups to be deleted before returning. 
+ """ + nodegroups = await client.list_nodegroups(clusterName=self.cluster_name) + if nodegroups.get("nodegroups", None): + self.log.info("Deleting nodegroups") + # ignoring attr-defined here because aws_base hook defines get_waiter for all hooks + waiter = self.hook().get_waiter( # type: ignore[attr-defined] + "all_nodegroups_deleted", deferrable=True, client=client + ) + for group in nodegroups["nodegroups"]: + await client.delete_nodegroup(clusterName=self.cluster_name, nodegroupName=group) + await async_wait( + waiter=waiter, + waiter_delay=int(self.waiter_delay), + waiter_max_attempts=int(self.waiter_max_attempts), + args={"clusterName": self.cluster_name}, + failure_message=f"Error deleting nodegroup for cluster {self.cluster_name}", + status_message="Deleting nodegroups associated with the cluster", + status_args=["nodegroups"], + ) + self.log.info("All nodegroups deleted") + else: + self.log.info("No nodegroups associated with cluster %s", self.cluster_name) + + async def delete_any_fargate_profiles(self, client) -> None: + """ + Deletes all EKS Fargate profiles for a provided Amazon EKS Cluster. + + EKS Fargate profiles must be deleted one at a time, so we must wait + for one to be deleted before sending the next delete command. + """ + fargate_profiles = await client.list_fargate_profiles(clusterName=self.cluster_name) + if fargate_profiles.get("fargateProfileNames"): + self.log.info("Waiting for Fargate profiles to delete. 
This will take some time.") + for profile in fargate_profiles["fargateProfileNames"]: + await client.delete_fargate_profile(clusterName=self.cluster_name, fargateProfileName=profile) + await async_wait( + waiter=client.get_waiter("fargate_profile_deleted"), + waiter_delay=int(self.waiter_delay), + waiter_max_attempts=int(self.waiter_max_attempts), + args={"clusterName": self.cluster_name, "fargateProfileName": profile}, + failure_message=f"Error deleting fargate profile for cluster {self.cluster_name}", + status_message="Status of fargate profile is", + status_args=["fargateProfile.status"], + ) + self.log.info("All Fargate profiles deleted") + else: + self.log.info(f"No Fargate profiles associated with cluster {self.cluster_name}") class EksCreateFargateProfileTrigger(AwsBaseWaiterTrigger): @@ -145,7 +312,11 @@ def __init__( region_name: str | None, ): super().__init__( - serialized_fields={"cluster_name": cluster_name, "nodegroup_name": nodegroup_name}, + serialized_fields={ + "cluster_name": cluster_name, + "nodegroup_name": nodegroup_name, + "region_name": region_name, + }, waiter_name="nodegroup_active", waiter_args={"clusterName": cluster_name, "nodegroupName": nodegroup_name}, failure_message="Error creating nodegroup", diff --git a/docs/apache-airflow-providers-amazon/operators/eks.rst b/docs/apache-airflow-providers-amazon/operators/eks.rst index cff682db2aa8d..9f1cc9df61eec 100644 --- a/docs/apache-airflow-providers-amazon/operators/eks.rst +++ b/docs/apache-airflow-providers-amazon/operators/eks.rst @@ -76,6 +76,7 @@ Create an Amazon EKS cluster and AWS Fargate profile in one step To create an Amazon EKS cluster and an AWS Fargate profile in one command, you can use :class:`~airflow.providers.amazon.aws.operators.eks.EksCreateClusterOperator`. +You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``. 
Note: An AWS IAM role with the following permissions is required: ``ec2.amazon.aws.com`` must be in the Trusted Relationships @@ -97,6 +98,7 @@ Delete an Amazon EKS Cluster To delete an existing Amazon EKS Cluster you can use :class:`~airflow.providers.amazon.aws.operators.eks.EksDeleteClusterOperator`. +You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``. .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_eks_with_nodegroups.py :language: python diff --git a/tests/providers/amazon/aws/operators/test_eks.py b/tests/providers/amazon/aws/operators/test_eks.py index 8381a26ef151a..e537ab20a3336 100644 --- a/tests/providers/amazon/aws/operators/test_eks.py +++ b/tests/providers/amazon/aws/operators/test_eks.py @@ -337,6 +337,34 @@ def test_fargate_compute_missing_fargate_pod_execution_role_arn(self): ): missing_fargate_pod_execution_role_arn.execute({}) + @mock.patch.object(EksHook, "create_cluster") + def test_eks_create_cluster_short_circuit_early(self, mock_create_cluster, caplog): + mock_create_cluster.return_value = None + eks_create_cluster_operator = EksCreateClusterOperator( + task_id=TASK_ID, + **self.create_cluster_params, + compute=None, + wait_for_completion=False, + deferrable=False, + ) + eks_create_cluster_operator.execute({}) + assert len(caplog.records) == 0 + + @mock.patch.object(EksHook, "create_cluster") + def test_eks_create_cluster_with_deferrable(self, mock_create_cluster, caplog): + mock_create_cluster.return_value = None + + eks_create_cluster_operator = EksCreateClusterOperator( + task_id=TASK_ID, + **self.create_cluster_params, + compute=None, + wait_for_completion=False, + deferrable=True, + ) + with pytest.raises(TaskDeferred): + eks_create_cluster_operator.execute({}) + assert "Waiting for EKS Cluster to provision. This will take some time." 
in caplog.messages + class TestEksCreateFargateProfileOperator: def setup_method(self) -> None: @@ -542,6 +570,11 @@ def test_existing_cluster_not_in_use_with_wait( mock_waiter.assert_called_with(mock.ANY, name=CLUSTER_NAME) assert_expected_waiter_type(mock_waiter, "ClusterDeleted") + def test_eks_delete_cluster_operator_with_deferrable(self): + self.delete_cluster_operator.deferrable = True + with pytest.raises(TaskDeferred): + self.delete_cluster_operator.execute({}) + class TestEksDeleteNodegroupOperator: def setup_method(self) -> None: diff --git a/tests/providers/amazon/aws/triggers/test_eks.py b/tests/providers/amazon/aws/triggers/test_eks.py index 045519aea57e8..023f8d2a97d46 100644 --- a/tests/providers/amazon/aws/triggers/test_eks.py +++ b/tests/providers/amazon/aws/triggers/test_eks.py @@ -19,8 +19,10 @@ import pytest from airflow.providers.amazon.aws.triggers.eks import ( + EksCreateClusterTrigger, EksCreateFargateProfileTrigger, EksCreateNodegroupTrigger, + EksDeleteClusterTrigger, EksDeleteFargateProfileTrigger, EksDeleteNodegroupTrigger, ) @@ -68,6 +70,21 @@ class TestEksTriggers: waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, region_name=TEST_REGION, ), + EksCreateClusterTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_DELAY, + aws_conn_id=TEST_AWS_CONN_ID, + region_name=TEST_REGION, + ), + EksDeleteClusterTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_DELAY, + aws_conn_id=TEST_AWS_CONN_ID, + region_name=TEST_REGION, + force_delete_compute=True, + ), ], ) def test_serialize_recreate(self, trigger): diff --git a/tests/system/providers/amazon/aws/example_eks_with_fargate_in_one_step.py b/tests/system/providers/amazon/aws/example_eks_with_fargate_in_one_step.py index ae67a26588bdc..6b9907d836108 100644 --- a/tests/system/providers/amazon/aws/example_eks_with_fargate_in_one_step.py +++ 
b/tests/system/providers/amazon/aws/example_eks_with_fargate_in_one_step.py @@ -81,6 +81,9 @@ # Opting to use the same ARN for the cluster and the pod here, # but a different ARN could be configured and passed if desired. fargate_pod_execution_role_arn=fargate_pod_role_arn, + deferrable=True, + waiter_delay=30, + waiter_max_attempts=399, ) # [END howto_operator_eks_create_cluster_with_fargate_profile]