Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 68 additions & 10 deletions airflow/providers/amazon/aws/operators/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from airflow.providers.amazon.aws.triggers.eks import (
EksCreateFargateProfileTrigger,
EksDeleteFargateProfileTrigger,
EksNodegroupTrigger,
)
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait

Expand Down Expand Up @@ -183,8 +184,8 @@ class EksCreateClusterOperator(BaseOperator):
:param fargate_selectors: The selectors to match for pods to use this AWS Fargate profile. (templated)
:param create_fargate_profile_kwargs: Optional parameters to pass to the CreateFargateProfile API
(templated)
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check cluster status
:param waiter_max_attempts: The maximum number of attempts to check the status of the cluster.
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check cluster state
:param waiter_max_attempts: The maximum number of attempts to check cluster state

"""

Expand Down Expand Up @@ -333,8 +334,11 @@ class EksCreateNodegroupOperator(BaseOperator):
maintained on each worker node).
:param region: Which AWS region the connection should use. (templated)
If this is None or empty then the default boto3 behaviour is used.
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check nodegroup status
:param waiter_max_attempts: The maximum number of attempts to check the status of the nodegroup.
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check nodegroup state
:param waiter_max_attempts: The maximum number of attempts to check nodegroup state
:param deferrable: If True, the operator will wait asynchronously for the nodegroup to be created.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)

"""

Expand All @@ -361,6 +365,7 @@ def __init__(
region: str | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 80,
deferrable: bool = False,
**kwargs,
) -> None:
self.nodegroup_subnets = nodegroup_subnets
Expand All @@ -369,15 +374,13 @@ def __init__(
self.nodegroup_role_arn = nodegroup_role_arn
self.nodegroup_name = nodegroup_name
self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
self.wait_for_completion = wait_for_completion
self.wait_for_completion = False if deferrable else wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts

super().__init__(
**kwargs,
)
self.deferrable = deferrable
super().__init__(**kwargs)

def execute(self, context: Context):
self.log.info(self.task_id)
Expand All @@ -393,6 +396,7 @@ def execute(self, context: Context):
self.nodegroup_subnets,
)
self.nodegroup_subnets = nodegroup_subnets_list

_create_compute(
compute=self.compute,
cluster_name=self.cluster_name,
Expand All @@ -407,6 +411,28 @@ def execute(self, context: Context):
subnets=self.nodegroup_subnets,
)

if self.deferrable:
self.defer(
trigger=EksNodegroupTrigger(
waiter_name="nodegroup_active",
cluster_name=self.cluster_name,
nodegroup_name=self.nodegroup_name,
aws_conn_id=self.aws_conn_id,
region=self.region,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)

def execute_complete(self, context, event=None):
    """
    Resume the task after the deferred nodegroup-creation trigger has fired.

    :param context: Task context supplied by the caller; not used here.
    :param event: Payload of the trigger event. It must be a mapping whose
        ``status`` entry equals ``"success"``.
    :raises AirflowException: If no event was delivered, or the event does not
        report a ``"success"`` status.
    """
    # ``event`` defaults to None, so guard before reading it: a missing payload or a
    # missing "status" key must fail the task with a clear message, not a TypeError/KeyError.
    if event is None or event.get("status") != "success":
        raise AirflowException(f"Error creating nodegroup: {event}")


class EksCreateFargateProfileOperator(BaseOperator):
"""
Expand Down Expand Up @@ -638,6 +664,11 @@ class EksDeleteNodegroupOperator(BaseOperator):
maintained on each worker node).
:param region: Which AWS region the connection should use. (templated)
If this is None or empty then the default boto3 behaviour is used.
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check nodegroup state
:param waiter_max_attempts: The maximum number of attempts to check nodegroup state
:param deferrable: If True, the operator will wait asynchronously for the nodegroup to be deleted.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)

"""

Expand All @@ -656,13 +687,19 @@ def __init__(
wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 40,
deferrable: bool = False,
**kwargs,
) -> None:
self.cluster_name = cluster_name
self.nodegroup_name = nodegroup_name
self.wait_for_completion = wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
super().__init__(**kwargs)

def execute(self, context: Context):
Expand All @@ -672,12 +709,33 @@ def execute(self, context: Context):
)

eks_hook.delete_nodegroup(clusterName=self.cluster_name, nodegroupName=self.nodegroup_name)
if self.wait_for_completion:
if self.deferrable:
self.defer(
trigger=EksNodegroupTrigger(
waiter_name="nodegroup_deleted",
cluster_name=self.cluster_name,
nodegroup_name=self.nodegroup_name,
aws_conn_id=self.aws_conn_id,
region=self.region,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
elif self.wait_for_completion:
self.log.info("Waiting for nodegroup to delete. This will take some time.")
eks_hook.conn.get_waiter("nodegroup_deleted").wait(
clusterName=self.cluster_name, nodegroupName=self.nodegroup_name
)

def execute_complete(self, context, event=None):
    """
    Resume the task after the deferred nodegroup-deletion trigger has fired.

    :param context: Task context supplied by the caller; not used here.
    :param event: Payload of the trigger event. It must be a mapping whose
        ``status`` entry equals ``"success"``.
    :raises AirflowException: If no event was delivered, or the event does not
        report a ``"success"`` status.
    """
    # ``event`` defaults to None, so guard before reading it: a missing payload or a
    # missing "status" key must fail the task with a clear message, not a TypeError/KeyError.
    if event is None or event.get("status") != "success":
        raise AirflowException(f"Error deleting nodegroup: {event}")


class EksDeleteFargateProfileOperator(BaseOperator):
"""
Expand Down
73 changes: 73 additions & 0 deletions airflow/providers/amazon/aws/triggers/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent


Expand Down Expand Up @@ -164,3 +165,75 @@ async def run(self):
)
else:
yield TriggerEvent({"status": "success", "message": "Fargate Profile Deleted"})


class EksNodegroupTrigger(BaseTrigger):
    """
    Trigger for EksCreateNodegroupOperator and EksDeleteNodegroupOperator.

    The trigger will asynchronously poll the boto3 API and wait for the
    nodegroup to be in the state specified by the waiter.

    :param waiter_name: Name of the waiter to use, for instance 'nodegroup_active' or 'nodegroup_deleted'
    :param cluster_name: The name of the EKS cluster associated with the node group.
    :param nodegroup_name: The name of the nodegroup to check.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region: Which AWS region the connection should use. (templated)
        If this is None or empty then the default boto3 behaviour is used.
    """

    def __init__(
        self,
        waiter_name: str,
        cluster_name: str,
        nodegroup_name: str,
        waiter_delay: int,
        waiter_max_attempts: int,
        aws_conn_id: str,
        region: str | None,
    ):
        # Let the base trigger set up whatever state it maintains for every trigger.
        super().__init__()
        self.waiter_name = waiter_name
        self.cluster_name = cluster_name
        self.nodegroup_name = nodegroup_name
        self.aws_conn_id = aws_conn_id
        # serialize() emits these two values as strings, so a trigger rebuilt from its
        # serialized kwargs receives ``str`` here. Normalise once so the attributes
        # always match their ``int`` annotation.
        self.waiter_delay = int(waiter_delay)
        self.waiter_max_attempts = int(waiter_max_attempts)
        self.region = region

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return the trigger's import path and the kwargs needed to rebuild it."""
        return (
            self.__class__.__module__ + "." + self.__class__.__qualname__,
            {
                "waiter_name": self.waiter_name,
                "cluster_name": self.cluster_name,
                "nodegroup_name": self.nodegroup_name,
                # Kept as strings so the serialized format stays unchanged;
                # __init__ converts them back to int.
                "waiter_delay": str(self.waiter_delay),
                "waiter_max_attempts": str(self.waiter_max_attempts),
                "aws_conn_id": self.aws_conn_id,
                "region": self.region,
            },
        )

    async def run(self):
        """
        Wait on the configured waiter, then yield a single ``success`` event.

        The only event this method yields has ``status == "success"``.
        NOTE(review): failures are presumably surfaced by ``async_wait`` raising;
        confirm against its implementation.
        """
        self.hook = EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
        async with self.hook.async_conn as client:
            waiter = client.get_waiter(self.waiter_name)
            await async_wait(
                waiter=waiter,
                waiter_max_attempts=self.waiter_max_attempts,
                waiter_delay=self.waiter_delay,
                args={"clusterName": self.cluster_name, "nodegroupName": self.nodegroup_name},
                failure_message="Error checking nodegroup",
                status_message="Nodegroup status is",
                status_args=["nodegroup.status"],
            )

        yield TriggerEvent(
            {
                "status": "success",
                "cluster_name": self.cluster_name,
                "nodegroup_name": self.nodegroup_name,
            }
        )
2 changes: 2 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/eks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ Create an Amazon EKS managed node group

To create an Amazon EKS managed node group you can use
:class:`~airflow.providers.amazon.aws.operators.eks.EksCreateNodegroupOperator`.
You can also run this operator in deferrable mode by setting the ``deferrable`` parameter to ``True``.

Note: An AWS IAM role with the following permissions is required:
``ec2.amazonaws.com`` must be in the Trusted Relationships
Expand All @@ -140,6 +141,7 @@ Delete an Amazon EKS managed node group

To delete an existing Amazon EKS managed node group you can use
:class:`~airflow.providers.amazon.aws.operators.eks.EksDeleteNodegroupOperator`.
You can also run this operator in deferrable mode by setting the ``deferrable`` parameter to ``True``.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_eks_with_nodegroups.py
:language: python
Expand Down
31 changes: 31 additions & 0 deletions tests/providers/amazon/aws/operators/test_eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from airflow.providers.amazon.aws.triggers.eks import (
EksCreateFargateProfileTrigger,
EksDeleteFargateProfileTrigger,
EksNodegroupTrigger,
)
from airflow.typing_compat import TypedDict
from tests.providers.amazon.aws.utils.eks_test_constants import (
Expand Down Expand Up @@ -476,6 +477,36 @@ def test_execute_with_wait_when_nodegroup_does_not_already_exist(
mock_waiter.assert_called_with(mock.ANY, clusterName=CLUSTER_NAME, nodegroupName=NODEGROUP_NAME)
assert_expected_waiter_type(mock_waiter, "NodegroupActive")

@mock.patch.object(EksHook, "create_nodegroup")
def test_create_nodegroup_deferrable(self, mock_create_nodegroup):
    """``execute()`` in deferrable mode raises ``TaskDeferred`` carrying an ``EksNodegroupTrigger``."""
    mock_create_nodegroup.return_value = True
    deferrable_op = EksCreateNodegroupOperator(
        task_id=TASK_ID,
        deferrable=True,
        **self.create_nodegroup_params,
    )

    with pytest.raises(TaskDeferred) as deferral:
        deferrable_op.execute({})

    raised_trigger = deferral.value.trigger
    assert isinstance(raised_trigger, EksNodegroupTrigger), "Trigger is not a EksNodegroupTrigger"

def test_create_nodegroup_deferrable_versus_wait_for_completion(self):
    """``deferrable=True`` forces ``wait_for_completion`` to False; otherwise the flag is kept."""
    shared_kwargs = dict(self.create_nodegroup_params)

    # Deferrable wins: the requested blocking wait is switched off.
    deferred_op = EksCreateNodegroupOperator(
        task_id=TASK_ID,
        wait_for_completion=True,
        deferrable=True,
        **shared_kwargs,
    )
    assert deferred_op.wait_for_completion is False

    # Without deferral the caller's choice is preserved.
    blocking_op = EksCreateNodegroupOperator(
        task_id=TASK_ID,
        wait_for_completion=True,
        deferrable=False,
        **shared_kwargs,
    )
    assert blocking_op.wait_for_completion is True


class TestEksDeleteClusterOperator:
def setup_method(self) -> None:
Expand Down
Loading