Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 78 additions & 6 deletions airflow/providers/amazon/aws/operators/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@

import warnings
from ast import literal_eval
from datetime import timedelta
from typing import TYPE_CHECKING, Any, List, Sequence, cast

from botocore.exceptions import ClientError, WaiterError

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.providers.amazon.aws.triggers.eks import (
EksCreateFargateProfileTrigger,
EksDeleteFargateProfileTrigger,
)

try:
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
Expand Down Expand Up @@ -353,6 +358,11 @@ class EksCreateFargateProfileOperator(BaseOperator):
maintained on each worker node).
:param region: Which AWS region the connection should use. (templated)
If this is None or empty then the default boto3 behaviour is used.
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check profile status
:param waiter_max_attempts: The maximum number of attempts to check the status of the profile.
:param deferrable: If True, the operator will wait asynchronously for the profile to be created.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

template_fields: Sequence[str] = (
Expand All @@ -371,11 +381,14 @@ def __init__(
cluster_name: str,
pod_execution_role_arn: str,
selectors: list,
fargate_profile_name: str | None = DEFAULT_FARGATE_PROFILE_NAME,
fargate_profile_name: str = DEFAULT_FARGATE_PROFILE_NAME,
create_fargate_profile_kwargs: dict | None = None,
wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
waiter_delay: int = 10,
waiter_max_attempts: int = 60,
deferrable: bool = False,
**kwargs,
) -> None:
self.cluster_name = cluster_name
Expand All @@ -386,6 +399,9 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
super().__init__(**kwargs)

def execute(self, context: Context):
Expand All @@ -401,13 +417,35 @@ def execute(self, context: Context):
selectors=self.selectors,
**self.create_fargate_profile_kwargs,
)

if self.wait_for_completion:
if self.deferrable:
self.defer(
trigger=EksCreateFargateProfileTrigger(
cluster_name=self.cluster_name,
fargate_profile_name=self.fargate_profile_name,
aws_conn_id=self.aws_conn_id,
poll_interval=self.waiter_delay,
max_attempts=self.waiter_max_attempts,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=(self.waiter_max_attempts * self.waiter_delay + 60)),
)
elif self.wait_for_completion:
self.log.info("Waiting for Fargate profile to provision. This will take some time.")
eks_hook.conn.get_waiter("fargate_profile_active").wait(
clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name,
WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
)

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error creating Fargate profile: {event}")
else:
self.log.info("Fargate profile created successfully")
return


class EksDeleteClusterOperator(BaseOperator):
"""
Expand Down Expand Up @@ -587,6 +625,11 @@ class EksDeleteFargateProfileOperator(BaseOperator):
maintained on each worker node).
:param region: Which AWS region the connection should use. (templated)
If this is None or empty then the default boto3 behaviour is used.
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check profile status
:param waiter_max_attempts: The maximum number of attempts to check the status of the profile.
:param deferrable: If True, the operator will wait asynchronously for the profile to be deleted.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

template_fields: Sequence[str] = (
Expand All @@ -604,6 +647,9 @@ def __init__(
wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -612,6 +658,9 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

def execute(self, context: Context):
eks_hook = EksHook(
Expand All @@ -622,12 +671,35 @@ def execute(self, context: Context):
eks_hook.delete_fargate_profile(
clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name
)
if self.wait_for_completion:
if self.deferrable:
self.defer(
trigger=EksDeleteFargateProfileTrigger(
cluster_name=self.cluster_name,
fargate_profile_name=self.fargate_profile_name,
aws_conn_id=self.aws_conn_id,
poll_interval=self.waiter_delay,
max_attempts=self.waiter_max_attempts,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=(self.waiter_max_attempts * self.waiter_delay + 60)),
)
elif self.wait_for_completion:
self.log.info("Waiting for Fargate profile to delete. This will take some time.")
eks_hook.conn.get_waiter("fargate_profile_deleted").wait(
clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name,
WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
)

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error deleting Fargate profile: {event}")
else:
self.log.info("Fargate profile deleted successfully")
return


class EksPodOperator(KubernetesPodOperator):
"""
Expand Down
160 changes: 160 additions & 0 deletions airflow/providers/amazon/aws/triggers/eks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
from typing import Any

from botocore.exceptions import WaiterError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class EksCreateFargateProfileTrigger(BaseTrigger):
"""
Trigger for EksCreateFargateProfileOperator.
The trigger will asynchronously wait for the fargate profile to be created.

:param cluster_name: The name of the EKS cluster
:param fargate_profile_name: The name of the fargate profile
:param poll_interval: The amount of time in seconds to wait between attempts.
:param max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(
self,
cluster_name: str,
fargate_profile_name: str,
poll_interval: int,
max_attempts: int,
aws_conn_id: str,
):
self.cluster_name = cluster_name
self.fargate_profile_name = fargate_profile_name
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.aws_conn_id = aws_conn_id

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
self.__class__.__module__ + "." + self.__class__.__qualname__,
{
"cluster_name": self.cluster_name,
"fargate_profile_name": self.fargate_profile_name,
"poll_interval": str(self.poll_interval),
"max_attempts": str(self.max_attempts),
"aws_conn_id": self.aws_conn_id,
},
)

async def run(self):
self.hook = EksHook(aws_conn_id=self.aws_conn_id)
async with self.hook.async_conn as client:
attempt = 0
waiter = client.get_waiter("fargate_profile_active")
while attempt < int(self.max_attempts):
attempt += 1
try:
await waiter.wait(
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name,
WaiterConfig={"Delay": int(self.poll_interval), "MaxAttempts": 1},
)
break
except WaiterError as error:
if "terminal failure" in str(error):
raise AirflowException(f"Create Fargate Profile failed: {error}")
self.log.info(
"Status of fargate profile is %s", error.last_response["fargateProfile"]["status"]
)
await asyncio.sleep(int(self.poll_interval))
if attempt >= int(self.max_attempts):
raise AirflowException(
f"Create Fargate Profile failed - max attempts reached: {self.max_attempts}"
)
else:
yield TriggerEvent({"status": "success", "message": "Fargate Profile Created"})


class EksDeleteFargateProfileTrigger(BaseTrigger):
"""
Trigger for EksDeleteFargateProfileOperator.
The trigger will asynchronously wait for the fargate profile to be deleted.

:param cluster_name: The name of the EKS cluster
:param fargate_profile_name: The name of the fargate profile
:param poll_interval: The amount of time in seconds to wait between attempts.
:param max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(
self,
cluster_name: str,
fargate_profile_name: str,
poll_interval: int,
max_attempts: int,
aws_conn_id: str,
):
self.cluster_name = cluster_name
self.fargate_profile_name = fargate_profile_name
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.aws_conn_id = aws_conn_id

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
self.__class__.__module__ + "." + self.__class__.__qualname__,
{
"cluster_name": self.cluster_name,
"fargate_profile_name": self.fargate_profile_name,
"poll_interval": str(self.poll_interval),
"max_attempts": str(self.max_attempts),
"aws_conn_id": self.aws_conn_id,
},
)

async def run(self):
self.hook = EksHook(aws_conn_id=self.aws_conn_id)
async with self.hook.async_conn as client:
attempt = 0
waiter = client.get_waiter("fargate_profile_deleted")
while attempt < int(self.max_attempts):
attempt += 1
try:
await waiter.wait(
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name,
WaiterConfig={"Delay": int(self.poll_interval), "MaxAttempts": 1},
)
break
except WaiterError as error:
if "terminal failure" in str(error):
raise AirflowException(f"Delete Fargate Profile failed: {error}")
self.log.info(
"Status of fargate profile is %s", error.last_response["fargateProfile"]["status"]
)
await asyncio.sleep(int(self.poll_interval))
if attempt >= int(self.max_attempts):
raise AirflowException(
f"Delete Fargate Profile failed - max attempts reached: {self.max_attempts}"
)
else:
yield TriggerEvent({"status": "success", "message": "Fargate Profile Deleted"})
3 changes: 3 additions & 0 deletions airflow/providers/amazon/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,9 @@ triggers:
- integration-name: Amazon EMR
python-modules:
- airflow.providers.amazon.aws.triggers.emr
- integration-name: Amazon Elastic Kubernetes Service (EKS)
python-modules:
- airflow.providers.amazon.aws.triggers.eks

transfers:
- source-integration-name: Amazon DynamoDB
Expand Down
Loading