Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 140 additions & 7 deletions airflow/providers/amazon/aws/operators/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import warnings
from ast import literal_eval
from datetime import timedelta
from typing import TYPE_CHECKING, List, Sequence, cast
from typing import TYPE_CHECKING, Any, List, Sequence, cast

from botocore.exceptions import ClientError, WaiterError

Expand All @@ -30,8 +30,10 @@
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.providers.amazon.aws.triggers.eks import (
EksCreateClusterTrigger,
EksCreateFargateProfileTrigger,
EksCreateNodegroupTrigger,
EksDeleteClusterTrigger,
EksDeleteFargateProfileTrigger,
EksDeleteNodegroupTrigger,
)
Expand Down Expand Up @@ -187,6 +189,9 @@ class EksCreateClusterOperator(BaseOperator):
(templated)
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check cluster state
:param waiter_max_attempts: The maximum number of attempts to check cluster state
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)

"""

Expand Down Expand Up @@ -225,6 +230,7 @@ def __init__(
wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
waiter_delay: int = 30,
waiter_max_attempts: int = 40,
**kwargs,
Expand All @@ -237,7 +243,7 @@ def __init__(
self.nodegroup_role_arn = nodegroup_role_arn
self.fargate_pod_execution_role_arn = fargate_pod_execution_role_arn
self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or {}
self.wait_for_completion = wait_for_completion
self.wait_for_completion = False if deferrable else wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id
Expand All @@ -246,6 +252,7 @@ def __init__(
self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
self.fargate_selectors = fargate_selectors or [{"namespace": DEFAULT_NAMESPACE_NAME}]
self.fargate_profile_name = fargate_profile_name
self.deferrable = deferrable
super().__init__(
**kwargs,
)
Expand Down Expand Up @@ -274,12 +281,25 @@ def execute(self, context: Context):

# Short circuit early if we don't need to wait to attach compute
# and the caller hasn't requested to wait for the cluster either.
if not self.compute and not self.wait_for_completion:
if not any([self.compute, self.wait_for_completion, self.deferrable]):
return None

self.log.info("Waiting for EKS Cluster to provision. This will take some time.")
self.log.info("Waiting for EKS Cluster to provision. This will take some time.")
client = self.eks_hook.conn

if self.deferrable:
self.defer(
trigger=EksCreateClusterTrigger(
cluster_name=self.cluster_name,
aws_conn_id=self.aws_conn_id,
region_name=self.region,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
method_name="deferrable_create_cluster_next",
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
)

try:
client.get_waiter("cluster_active").wait(
name=self.cluster_name,
Expand Down Expand Up @@ -311,6 +331,89 @@ def execute(self, context: Context):
subnets=cast(List[str], self.resources_vpc_config.get("subnetIds")),
)

def deferrable_create_cluster_next(self, context: Context, event: dict[str, Any] | None = None) -> None:
    """Handle the event returned by the cluster-creation trigger.

    On failure the half-created cluster is torn down and we defer again to
    wait for the deletion to complete (``execute_failed`` re-raises the
    original error). On success the configured compute, if any, is created
    and we defer until it becomes active (``execute_complete``).

    :param context: Airflow task context.
    :param event: Payload from the trigger; expected to carry a ``status`` key.
    :raises AirflowException: If the trigger fired without an event.
    """
    if event is None:
        self.log.error("Trigger error: event is None")
        raise AirflowException("Trigger error: event is None")
    elif event["status"] == "failed":
        # Clean up the failed cluster so no orphaned AWS resources are leaked,
        # then wait asynchronously for the deletion to finish.
        self.log.error("Cluster failed to start and will be torn down.")
        self.eks_hook.delete_cluster(name=self.cluster_name)
        self.defer(
            trigger=EksDeleteClusterTrigger(
                cluster_name=self.cluster_name,
                waiter_delay=self.waiter_delay,
                waiter_max_attempts=self.waiter_max_attempts,
                aws_conn_id=self.aws_conn_id,
                region_name=self.region,
                force_delete_compute=False,
            ),
            method_name="execute_failed",
            timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
        )
    elif event["status"] == "success":
        self.log.info("Cluster is ready to provision compute.")
        _create_compute(
            compute=self.compute,
            cluster_name=self.cluster_name,
            aws_conn_id=self.aws_conn_id,
            region=self.region,
            wait_for_completion=self.wait_for_completion,
            waiter_delay=self.waiter_delay,
            waiter_max_attempts=self.waiter_max_attempts,
            nodegroup_name=self.nodegroup_name,
            nodegroup_role_arn=self.nodegroup_role_arn,
            create_nodegroup_kwargs=self.create_nodegroup_kwargs,
            fargate_profile_name=self.fargate_profile_name,
            fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn,
            fargate_selectors=self.fargate_selectors,
            create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
            subnets=cast(List[str], self.resources_vpc_config.get("subnetIds")),
        )
        if self.compute == "fargate":
            self.defer(
                trigger=EksCreateFargateProfileTrigger(
                    cluster_name=self.cluster_name,
                    fargate_profile_name=self.fargate_profile_name,
                    waiter_delay=self.waiter_delay,
                    waiter_max_attempts=self.waiter_max_attempts,
                    aws_conn_id=self.aws_conn_id,
                    region=self.region,
                ),
                method_name="execute_complete",
                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
            )
        elif self.compute == "nodegroup":
            # Only defer on a nodegroup trigger when a nodegroup was actually
            # requested. Previously the bare `else` branch also ran when
            # self.compute was None, deferring a trigger with
            # nodegroup_name=None; with no compute the active cluster is the
            # end state and the task can finish here.
            self.defer(
                trigger=EksCreateNodegroupTrigger(
                    nodegroup_name=self.nodegroup_name,
                    cluster_name=self.cluster_name,
                    aws_conn_id=self.aws_conn_id,
                    region_name=self.region,
                    waiter_delay=self.waiter_delay,
                    waiter_max_attempts=self.waiter_max_attempts,
                ),
                method_name="execute_complete",
                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
            )

def execute_failed(self, context: Context, event: dict[str, Any] | None = None) -> None:
    """Handle completion of the tear-down trigger after a failed cluster creation.

    Once the broken cluster has been deleted, re-raise the exception carried
    in the trigger event so the task fails with the original cause.

    :param context: Airflow task context.
    :param event: Payload from EksDeleteClusterTrigger; expected keys:
        ``status`` and ``exception``.
    :raises AirflowException: If the trigger fired without an event.
    """
    if event is None:
        # Log at error level: a missing event is a failure, not routine info.
        self.log.error("Trigger error: event is None")
        raise AirflowException("Trigger error: event is None")
    elif event["status"] == "deleted":
        # Fixed typo: the original compared against "delteted", which never
        # matched, so the creation failure was silently swallowed.
        self.log.info("Cluster deleted")
        raise event["exception"]

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
    """Handle the event returned by the compute-creation trigger.

    :param context: Airflow task context.
    :param event: Payload from the nodegroup/Fargate-profile trigger;
        expected to carry a ``status`` key.
    :raises AirflowException: If the event is missing or reports a
        non-success status.
    """
    resource = "fargate profile" if self.compute == "fargate" else self.compute
    if event is None:
        # Log at error level for consistency with the other trigger callbacks.
        self.log.error("Trigger error: event is None")
        raise AirflowException("Trigger error: event is None")
    elif event["status"] != "success":
        raise AirflowException(f"Error creating {resource}: {event}")

    self.log.info("%s created successfully", resource)


class EksCreateNodegroupOperator(BaseOperator):
"""
Expand Down Expand Up @@ -564,6 +667,11 @@ class EksDeleteClusterOperator(BaseOperator):
maintained on each worker node).
:param region: Which AWS region the connection should use. (templated)
If this is None or empty then the default boto3 behaviour is used.
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check cluster state
:param waiter_max_attempts: The maximum number of attempts to check cluster state
:param deferrable: If True, the operator will wait asynchronously for the cluster to be deleted.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)

"""

Expand All @@ -582,22 +690,40 @@ def __init__(
wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
waiter_delay: int = 30,
waiter_max_attempts: int = 40,
**kwargs,
) -> None:
self.cluster_name = cluster_name
self.force_delete_compute = force_delete_compute
self.wait_for_completion = wait_for_completion
self.wait_for_completion = False if deferrable else wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
self.deferrable = deferrable
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
super().__init__(**kwargs)

def execute(self, context: Context):
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)

if self.force_delete_compute:
if self.deferrable:
self.defer(
trigger=EksDeleteClusterTrigger(
cluster_name=self.cluster_name,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region,
force_delete_compute=self.force_delete_compute,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts),
)
elif self.force_delete_compute:
self.delete_any_nodegroups(eks_hook)
self.delete_any_fargate_profiles(eks_hook)

Expand Down Expand Up @@ -645,6 +771,13 @@ def delete_any_fargate_profiles(self, eks_hook) -> None:
)
self.log.info(SUCCESS_MSG.format(compute=FARGATE_FULL_NAME))

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
    """Handle the event returned by the cluster-deletion trigger.

    :param context: Airflow task context.
    :param event: Payload from EksDeleteClusterTrigger; expected to carry a
        ``status`` key.
    :raises AirflowException: If the event is missing or reports a
        non-success status.
    """
    if event is None:
        self.log.error("Trigger error. Event is None")
        raise AirflowException("Trigger error. Event is None")
    elif event["status"] == "success":
        self.log.info("Cluster deleted successfully.")
    else:
        # Previously a non-success event fell through silently and the task
        # was marked successful even though the cluster was not deleted.
        raise AirflowException(f"Error deleting cluster: {event}")


class EksDeleteNodegroupOperator(BaseOperator):
"""
Expand Down
Loading