From f4ea7d9157d650c1cca67addc983bb83c2656b74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Thu, 25 May 2023 16:40:56 -0700 Subject: [PATCH 01/20] add deferrable mode for ECS Create Cluster --- airflow/providers/amazon/aws/operators/ecs.py | 22 +++++ airflow/providers/amazon/aws/triggers/ecs.py | 87 +++++++++++++++++++ .../amazon/aws/operators/test_ecs.py | 25 +++++- 3 files changed, 132 insertions(+), 2 deletions(-) create mode 100644 airflow/providers/amazon/aws/triggers/ecs.py diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index d4f0bdf045c76..6a9d3d3d05001 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -36,6 +36,7 @@ EcsTaskLogFetcher, should_retry_eni, ) +from airflow.providers.amazon.aws.triggers.ecs import ClusterActiveTrigger from airflow.utils.helpers import prune_dict from airflow.utils.session import provide_session @@ -84,6 +85,9 @@ class EcsCreateClusterOperator(EcsBaseOperator): if not set then the default waiter value will be used. :param waiter_max_attempts: The maximum number of attempts to be made, if not set then the default waiter value will be used. + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) """ template_fields: Sequence[str] = ("cluster_name", "create_cluster_kwargs", "wait_for_completion") @@ -96,6 +100,7 @@ def __init__( wait_for_completion: bool = True, waiter_delay: int | None = None, waiter_max_attempts: int | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -104,6 +109,7 @@ def __init__( self.wait_for_completion = wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable def execute(self, context: Context): self.log.info( @@ -119,6 +125,17 @@ def execute(self, context: Context): # In some circumstances the ECS Cluster is created immediately, # and there is no reason to wait for completion. self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state) + elif self.deferrable: + self.defer( + trigger=ClusterActiveTrigger( + cluster_arn=cluster_details["clusterArn"], + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + region=self.region, + ), + method_name="execute_complete", + ) elif self.wait_for_completion: waiter = self.hook.get_waiter("cluster_active") waiter.wait( @@ -133,6 +150,11 @@ def execute(self, context: Context): return cluster_details + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error in cluster creation: {event}") + return event.get("value") + class EcsDeleteClusterOperator(EcsBaseOperator): """ diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py new file mode 100644 index 0000000000000..b07594326ff02 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import asyncio +from typing import Any, AsyncIterator + +from botocore.exceptions import WaiterError + +from airflow.providers.amazon.aws.hooks.ecs import EcsHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class ClusterActiveTrigger(BaseTrigger): + """ + Waits for a cluster to be active, triggers when it finishes + + :param cluster_arn: ARN of the cluster to watch. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The number of times to ping for status. + Will fail after that many unsuccessful attempts. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region: The AWS region where the cluster is located. + """ + + def __init__( + self, + cluster_arn: str, + waiter_delay: int | None, + waiter_max_attempts: int | None, + aws_conn_id: str | None, + region: str | None, + ): + self.cluster_arn = cluster_arn + self.waiter_delay = waiter_delay or 15 + self.attempts = waiter_max_attempts or 999999999 + self.aws_conn_id = aws_conn_id + self.region = region + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + self.__class__.__module__ + "." + self.__class__.__qualname__, + { + "cluster_arn": self.cluster_arn, + "waiter_delay": self.waiter_delay, + "waiter_max_attempts": self.attempts, + "aws_conn_id": self.aws_conn_id, + "region": self.region, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + hook = EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region) + async with hook.async_conn as client: + waiter = hook.get_waiter("cluster_active", deferrable=True, client=client) + while self.attempts >= 1: + self.attempts = self.attempts - 1 + try: + waiter.wait( + clusters=[self.cluster_arn], + WaiterConfig={ + "MaxAttempts": 1, + }, + ) + break # we reach this point only if the waiter met a success criteria + except WaiterError as error: + if "terminal failure" in str(error): + raise + self.log.info("Status of cluster is %s", error.last_response["clusters"][0]["status"]) + await asyncio.sleep(int(self.waiter_delay)) + + yield TriggerEvent({"status": "success", "value": self.cluster_arn}) diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index cadaa6e329462..f43dd330e6b2b 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -20,13 +20,14 @@ import sys from copy import deepcopy from unittest import mock +from unittest.mock import MagicMock import boto3 import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart -from airflow.providers.amazon.aws.hooks.ecs import EcsHook +from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook from airflow.providers.amazon.aws.operators.ecs import ( DEFAULT_CONN_ID, EcsBaseOperator, @@ -679,6 +680,26 @@ def test_execute_with_waiter(self, patch_hook_waiters, waiter_delay, waiter_max_ mocked_waiters.wait.assert_called_once_with(clusters=mock.ANY, WaiterConfig=expected_waiter_config) assert result is not None + @mock.patch.object(EcsCreateClusterOperator, "client") + def test_execute_deferrable(self, mock_client: MagicMock): + op = EcsCreateClusterOperator( + task_id="task", + cluster_name=CLUSTER_NAME, + deferrable=True, + waiter_delay=12, + waiter_max_attempts=34, + ) + mock_client.create_cluster.return_value = { + "cluster": {"status": EcsClusterStates.PROVISIONING, "clusterArn": "my arn"} + } + + with pytest.raises(TaskDeferred) as defer: + op.execute(None) + + assert defer.value.trigger.cluster_arn == "my arn" + assert defer.value.trigger.waiter_delay == 12 + assert defer.value.trigger.attempts == 34 + def test_execute_immediate_create(self, patch_hook_waiters): """Test if cluster created during initial request.""" op = EcsCreateClusterOperator(task_id="task", cluster_name=CLUSTER_NAME, wait_for_completion=True) From 4dddc9a50ed95c79a3ad0bb750c6303a01ffe157 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Mon, 5 Jun 2023 16:12:29 -0700 Subject: [PATCH 02/20] execute task - easy part, without the logs --- airflow/providers/amazon/aws/operators/ecs.py | 39 ++++++++++--- airflow/providers/amazon/aws/triggers/ecs.py | 56 +++++++++++++++++++ 2 files changed, 87 insertions(+), 8 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 6a9d3d3d05001..edff5bd156b82 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -36,7 +36,7 @@ EcsTaskLogFetcher, should_retry_eni, ) -from airflow.providers.amazon.aws.triggers.ecs import ClusterActiveTrigger +from airflow.providers.amazon.aws.triggers.ecs import ClusterActiveTrigger, TaskDoneTrigger from airflow.utils.helpers import prune_dict from airflow.utils.session import provide_session @@ -512,7 +512,22 @@ def execute(self, context, session=None): if self.reattach: self._try_reattach_task(context) - self._start_wait_check_task(context) + self._start_wait_task(context) + + self._after_execution(session) + + return None + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error in task execution: {event}") + self.arn = event["task_arn"] # restore arn to its updated value + self._after_execution() + return None + + @provide_session + def _after_execution(self, session=None): + self._check_success_task() self.log.info("ECS Task has been successfully executed") @@ -524,10 +539,8 @@ def execute(self, context, session=None): if self.do_xcom_push and self.task_log_fetcher: return self.task_log_fetcher.get_last_log_message() - return None - @AwsBaseHook.retry(should_retry_eni) - def _start_wait_check_task(self, context): + def _start_wait_task(self, context): if not self.arn: self._start_task(context) @@ -545,11 +558,20 @@ def _start_wait_check_task(self, context): self.task_log_fetcher.join() else: - if self.wait_for_completion: + if self.deferrable: + self.defer( + trigger=TaskDoneTrigger( + cluster=self.cluster, + task_arn=self.arn, + waiter_delay=self.waiter_delay, + aws_conn_id=self.aws_conn_id, + region=self.region, + ), + method_name="execute_complete", + ) + elif self.wait_for_completion: self._wait_for_task_ended() - self._check_success_task() - def _xcom_del(self, session, task_id): session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete() @@ -653,6 +675,7 @@ def _get_task_log_fetcher(self) -> EcsTaskLogFetcher: logger=self.log, ) + @AwsBaseHook.retry(should_retry_eni) def _check_success_task(self) -> None: if not self.client or not self.arn: return diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py index b07594326ff02..dcc32f32d7ecd 100644 --- a/airflow/providers/amazon/aws/triggers/ecs.py +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -85,3 +85,59 @@ async def run(self) -> AsyncIterator[TriggerEvent]: await asyncio.sleep(int(self.waiter_delay)) yield TriggerEvent({"status": "success", "value": self.cluster_arn}) + + +class TaskDoneTrigger(BaseTrigger): + """ + Waits for an ECS task to be done + + :param cluster: short name or full ARN of the cluster where the task is running. + :param task_arn: ARN of the task to watch. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The number of times to ping for status. + Will fail after that many unsuccessful attempts. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region: The AWS region where the cluster is located. + """ + + def __init__( + self, + cluster: str, + task_arn: str, + waiter_delay: int | None, + aws_conn_id: str | None, + region: str | None, + ): + self.cluster = cluster + self.task_arn = task_arn + self.waiter_delay = waiter_delay or 15 + self.aws_conn_id = aws_conn_id + self.region = region + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + self.__class__.__module__ + "." + self.__class__.__qualname__, + { + "cluster": self.cluster, + "task_arn": self.task_arn, + "waiter_delay": self.waiter_delay, + "aws_conn_id": self.aws_conn_id, + "region": self.region, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + hook = EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region) + async with hook.async_conn as client: + waiter = hook.get_waiter("tasks_stopped", deferrable=True, client=client) + while True: + try: + waiter.wait(cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1}) + break # we reach this point only if the waiter met a success criteria + except WaiterError as error: + if "terminal failure" in str(error): + raise + self.log.info("Status of the task is %s", error.last_response["tasks"][0]["lastStatus"]) + await asyncio.sleep(int(self.waiter_delay)) + + yield TriggerEvent({"status": "success", "task_arn": self.task_arn}) From a6e5c0e964a6d847cc9cdadce81d4fbfcc7476e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 6 Jun 2023 16:39:05 -0700 Subject: [PATCH 03/20] add logs support to run task operator --- airflow/providers/amazon/aws/hooks/ecs.py | 3 +- airflow/providers/amazon/aws/operators/ecs.py | 32 +++++---- airflow/providers/amazon/aws/triggers/ecs.py | 72 ++++++++++++++++--- 3 files changed, 81 insertions(+), 26 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/ecs.py b/airflow/providers/amazon/aws/hooks/ecs.py index f8e80f7110663..70308df918230 100644 --- a/airflow/providers/amazon/aws/hooks/ecs.py +++ b/airflow/providers/amazon/aws/hooks/ecs.py @@ -191,7 +191,8 @@ def _get_log_events(self, skip: int = 0) -> Generator: self.logger.warning("ConnectionClosedError on retrieving Cloudwatch log events", error) yield from () - def _event_to_str(self, event: dict) -> str: + @staticmethod + def _event_to_str(event: dict) -> str: event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0) formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3] message = event["message"] diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index edff5bd156b82..7ea81c72f4c89 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -397,6 +397,7 @@ class EcsRunTaskOperator(EcsBaseOperator): finished. :param awslogs_fetch_interval: the interval that the ECS task log fetcher should wait in between each Cloudwatch logs fetches. + If deferrable is set to True, that parameter is ignored and waiter_delay is used instead. :param quota_retry: Config if and how to retry the launch of a new ECS task, to handle transient errors. :param reattach: If set to True, will check if the task previously launched by the task_instance @@ -545,7 +546,20 @@ def _start_wait_task(self, context): if not self.arn: self._start_task(context) - if self._aws_logs_enabled(): + if self.deferrable: + self.defer( + trigger=TaskDoneTrigger( + cluster=self.cluster, + task_arn=self.arn, + waiter_delay=self.waiter_delay, + aws_conn_id=self.aws_conn_id, + region=self.region, + log_group=self.awslogs_group, + log_stream=f"{self.awslogs_stream_prefix}/{self.ecs_task_id}", + ), + method_name="execute_complete", + ) + elif self._aws_logs_enabled(): self.log.info("Starting ECS Task Log Fetcher") self.task_log_fetcher = self._get_task_log_fetcher() self.task_log_fetcher.start() @@ -557,20 +571,8 @@ def _start_wait_task(self, context): self.task_log_fetcher.stop() self.task_log_fetcher.join() - else: - if self.deferrable: - self.defer( - trigger=TaskDoneTrigger( - cluster=self.cluster, - task_arn=self.arn, - waiter_delay=self.waiter_delay, - aws_conn_id=self.aws_conn_id, - region=self.region, - ), - method_name="execute_complete", - ) - elif self.wait_for_completion: - self._wait_for_task_ended() + elif self.wait_for_completion: + self._wait_for_task_ended() def _xcom_del(self, session, task_id): session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete() diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py index dcc32f32d7ecd..f4c3b435ab5b9 100644 --- a/airflow/providers/amazon/aws/triggers/ecs.py +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -20,9 +20,10 @@ import asyncio from typing import Any, AsyncIterator -from botocore.exceptions import WaiterError +from botocore.exceptions import ClientError, WaiterError -from airflow.providers.amazon.aws.hooks.ecs import EcsHook +from airflow.providers.amazon.aws.hooks.ecs import EcsHook, EcsTaskLogFetcher +from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -65,13 +66,12 @@ def serialize(self) -> tuple[str, dict[str, Any]]: ) async def run(self) -> AsyncIterator[TriggerEvent]: - hook = EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region) - async with hook.async_conn as client: - waiter = hook.get_waiter("cluster_active", deferrable=True, client=client) + async with EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as client: + waiter = client.get_waiter("cluster_active") while self.attempts >= 1: self.attempts = self.attempts - 1 try: - waiter.wait( + await waiter.wait( clusters=[self.cluster_arn], WaiterConfig={ "MaxAttempts": 1, @@ -107,13 +107,19 @@ def __init__( waiter_delay: int | None, aws_conn_id: str | None, region: str | None, + log_group: str | None = None, + log_stream: str | None = None, ): self.cluster = cluster self.task_arn = task_arn + self.waiter_delay = waiter_delay or 15 self.aws_conn_id = aws_conn_id self.region = region + self.log_group = log_group + self.log_stream = log_stream + def serialize(self) -> tuple[str, dict[str, Any]]: return ( self.__class__.__module__ + "." + self.__class__.__qualname__, @@ -123,21 +129,67 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "waiter_delay": self.waiter_delay, "aws_conn_id": self.aws_conn_id, "region": self.region, + "log_group": self.log_group, + "log_stream": self.log_stream, }, ) async def run(self) -> AsyncIterator[TriggerEvent]: - hook = EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region) - async with hook.async_conn as client: - waiter = hook.get_waiter("tasks_stopped", deferrable=True, client=client) + # fmt: off + async with EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as ecs_client,\ + AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as logs_client: + # fmt: on + waiter = ecs_client.get_waiter("tasks_stopped") + logs_token = None while True: try: - waiter.wait(cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1}) + await waiter.wait( + cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1} + ) break # we reach this point only if the waiter met a success criteria except WaiterError as error: if "terminal failure" in str(error): raise self.log.info("Status of the task is %s", error.last_response["tasks"][0]["lastStatus"]) await asyncio.sleep(int(self.waiter_delay)) + finally: + if self.log_group and self.log_stream: + logs_token = await self._forward_logs(logs_client, logs_token) yield TriggerEvent({"status": "success", "task_arn": self.task_arn}) + + async def _forward_logs(self, logs_client, next_token: str | None = None) -> str | None: + """ + Reads logs from the cloudwatch stream and prints them to the task logs. + :return: the token to pass to the next iteration to resume where we started + """ + while True: + if next_token is not None: + token_arg: dict[str, str] = {"nextToken": next_token} + else: + token_arg = {} + try: + response = await logs_client.get_log_events( + logGroupName=self.log_group, + logStreamName=self.log_stream, + startFromHead=True, + **token_arg, + ) + except ClientError as ce: + if ce.response["Error"]["Code"] == "ResourceNotFoundException": + self.log.info( + "Tried to get logs from stream %s in group %s but it didn't exist (yet). " + "Will try again.", + self.log_stream, + self.log_group, + ) + return None + raise + + events = response["events"] + for log_event in events: + self.log.info(EcsTaskLogFetcher._event_to_str(log_event)) + + if len(events) == 0 or next_token == response["nextForwardToken"]: + return response["nextForwardToken"] + next_token = response["nextForwardToken"] From bca71da599156569cb4fa11d383dd3d7062ca8f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Fri, 9 Jun 2023 11:56:28 -0700 Subject: [PATCH 04/20] add tests --- airflow/providers/amazon/aws/operators/ecs.py | 10 +- airflow/providers/amazon/aws/triggers/ecs.py | 9 +- .../amazon/aws/operators/test_ecs.py | 28 ++++- .../providers/amazon/aws/triggers/test_ecs.py | 118 ++++++++++++++++++ 4 files changed, 156 insertions(+), 9 deletions(-) create mode 100644 tests/providers/amazon/aws/triggers/test_ecs.py diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 7ea81c72f4c89..521392a44d7e6 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -515,19 +515,17 @@ def execute(self, context, session=None): self._start_wait_task(context) - self._after_execution(session) - - return None + return self._after_execution(session) def execute_complete(self, context, event=None): if event["status"] != "success": raise AirflowException(f"Error in task execution: {event}") self.arn = event["task_arn"] # restore arn to its updated value self._after_execution() - return None + # TODO return last log line if necessary because task_log_fetcher will always be None here @provide_session - def _after_execution(self, session=None): + def _after_execution(self, session=None) -> str | None: self._check_success_task() self.log.info("ECS Task has been successfully executed") @@ -539,6 +537,8 @@ def _after_execution(self, session=None): if self.do_xcom_push and self.task_log_fetcher: return self.task_log_fetcher.get_last_log_message() + else: + return None @AwsBaseHook.retry(should_retry_eni) def _start_wait_task(self, context): diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py index f4c3b435ab5b9..10913e96ba631 100644 --- a/airflow/providers/amazon/aws/triggers/ecs.py +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -22,6 +22,7 @@ from botocore.exceptions import ClientError, WaiterError +from airflow import AirflowException from airflow.providers.amazon.aws.hooks.ecs import EcsHook, EcsTaskLogFetcher from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -48,7 +49,7 @@ def __init__( region: str | None, ): self.cluster_arn = cluster_arn - self.waiter_delay = waiter_delay or 15 + self.waiter_delay = waiter_delay if waiter_delay is not None else 15 # written like this to allow 0 self.attempts = waiter_max_attempts or 999999999 self.aws_conn_id = aws_conn_id self.region = region @@ -77,14 +78,16 @@ async def run(self) -> AsyncIterator[TriggerEvent]: "MaxAttempts": 1, }, ) - break # we reach this point only if the waiter met a success criteria + # we reach this point only if the waiter met a success criteria + yield TriggerEvent({"status": "success", "value": self.cluster_arn}) + return except WaiterError as error: if "terminal failure" in str(error): raise self.log.info("Status of cluster is %s", error.last_response["clusters"][0]["status"]) await asyncio.sleep(int(self.waiter_delay)) - yield TriggerEvent({"status": "success", "value": self.cluster_arn}) + raise AirflowException("Cluster still not active after the max number of tries has been reached") class TaskDoneTrigger(BaseTrigger): diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index f43dd330e6b2b..26be48a7e108b 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -20,7 +20,7 @@ import sys from copy import deepcopy from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import MagicMock, PropertyMock import boto3 import pytest @@ -38,6 +38,7 @@ EcsRunTaskOperator, EcsTaskLogFetcher, ) +from airflow.providers.amazon.aws.triggers.ecs import TaskDoneTrigger from airflow.utils.types import NOTSET CLUSTER_NAME = "test_cluster" @@ -654,6 +655,31 @@ def test_execute_xcom_disabled(self, log_fetcher_mock, client_mock): self.ecs.do_xcom_push = False assert self.ecs.execute(None) is None + @mock.patch.object(EcsRunTaskOperator, "client") + def test_with_defer(self, client_mock): + self.ecs.deferrable = True + + client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES + + with pytest.raises(TaskDeferred) as deferred: + self.ecs.execute(None) + + assert isinstance(deferred.value.trigger, TaskDoneTrigger) + assert deferred.value.trigger.task_arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}" + + @mock.patch.object(EcsRunTaskOperator, "client", new_callable=PropertyMock) + @mock.patch.object(EcsRunTaskOperator, "_xcom_del") + def test_execute_complete(self, xcom_del_mock: MagicMock, client_mock): + event = {"status": "success", "task_arn": "my_arn"} + self.ecs.reattach = True + + self.ecs.execute_complete(None, event) + + # task gets described to assert its success + client_mock().describe_tasks.assert_called_once_with(cluster="c", tasks=["my_arn"]) + # if reattach mode, xcom value is deleted on success + xcom_del_mock.assert_called_once() + class TestEcsCreateClusterOperator(EcsBaseTestCase): @pytest.mark.parametrize("waiter_delay, waiter_max_attempts", WAITERS_TEST_CASES) diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py b/tests/providers/amazon/aws/triggers/test_ecs.py new file mode 100644 index 0000000000000..cb6037a7c8031 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_ecs.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import AsyncMock + +import pytest +from botocore.exceptions import WaiterError + +from airflow import AirflowException +from airflow.providers.amazon.aws.hooks.ecs import EcsHook +from airflow.providers.amazon.aws.triggers.ecs import ClusterActiveTrigger, TaskDoneTrigger +from airflow.triggers.base import TriggerEvent + + +class TestClusterActiveTrigger: + @pytest.mark.asyncio + @mock.patch.object(EcsHook, "async_conn") + async def test_run_max_attempts(self, client_mock): + a_mock = mock.MagicMock() + client_mock.__aenter__.return_value = a_mock + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError("name", "reason", {"clusters": [{"status": "my_status"}]}) + a_mock.get_waiter().wait = wait_mock + + max_attempts = 5 + trigger = ClusterActiveTrigger("cluster_arn", 0, max_attempts, None, None) + + with pytest.raises(AirflowException): + generator = trigger.run() + await generator.asend(None) + + assert wait_mock.call_count == max_attempts + + @pytest.mark.asyncio + @mock.patch.object(EcsHook, "async_conn") + async def test_run_success(self, client_mock): + a_mock = mock.MagicMock() + client_mock.__aenter__.return_value = a_mock + wait_mock = AsyncMock() + a_mock.get_waiter().wait = wait_mock + + trigger = ClusterActiveTrigger("cluster_arn", 0, 5, None, None) + + generator = trigger.run() + response: TriggerEvent = await generator.asend(None) + + assert response.payload["status"] == "success" + assert response.payload["value"] == "cluster_arn" + + @pytest.mark.asyncio + @mock.patch.object(EcsHook, "async_conn") + async def test_run_error(self, client_mock): + a_mock = mock.MagicMock() + client_mock.__aenter__.return_value = a_mock + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError("terminal failure", "reason", {}) + a_mock.get_waiter().wait = wait_mock + + trigger = ClusterActiveTrigger("cluster_arn", 0, 5, None, None) + + with pytest.raises(WaiterError): + generator = trigger.run() + await generator.asend(None) + + +class TestTaskDoneTrigger: + @pytest.mark.asyncio + @mock.patch.object(EcsHook, "async_conn") + async def test_run_until_error(self, client_mock): + a_mock = mock.MagicMock() + client_mock.__aenter__.return_value = a_mock + wait_mock = AsyncMock() + wait_mock.side_effect = [ + WaiterError("name", "reason", {"tasks": [{"lastStatus": "my_status"}]}), + WaiterError("name", "reason", {"tasks": [{"lastStatus": "my_status"}]}), + WaiterError("terminal failure", "reason", {}), + ] + a_mock.get_waiter().wait = wait_mock + + trigger = TaskDoneTrigger("cluster", "task_arn", 0, None, None) + + with pytest.raises(WaiterError): + generator = trigger.run() + await generator.asend(None) + + assert wait_mock.call_count == 3 + + @pytest.mark.asyncio + @mock.patch.object(EcsHook, "async_conn") + async def test_run_success(self, client_mock): + a_mock = mock.MagicMock() + client_mock.__aenter__.return_value = a_mock + wait_mock = AsyncMock() + a_mock.get_waiter().wait = wait_mock + + trigger = TaskDoneTrigger("cluster", "my_task_arn", 0, None, None) + + generator = trigger.run() + response: TriggerEvent = await generator.asend(None) + + assert response.payload["status"] == "success" + assert response.payload["task_arn"] == "my_task_arn" From 59ff8f405ac13251978afe690959081c0d98e53d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 13 Jun 2023 11:10:16 -0700 Subject: [PATCH 05/20] add deferrable for delete cluster, by adapting the create cluster trigger --- airflow/providers/amazon/aws/operators/ecs.py | 40 ++++++++++++++----- airflow/providers/amazon/aws/triggers/ecs.py | 18 ++++++--- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 521392a44d7e6..bc6221a02ec9b 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -36,7 +36,10 @@ EcsTaskLogFetcher, should_retry_eni, ) -from airflow.providers.amazon.aws.triggers.ecs import ClusterActiveTrigger, TaskDoneTrigger +from airflow.providers.amazon.aws.triggers.ecs import ( + ClusterWaiterTrigger, + TaskDoneTrigger, +) from airflow.utils.helpers import prune_dict from airflow.utils.session import provide_session @@ -68,6 +71,15 @@ def execute(self, context: Context): """Must overwrite in child classes.""" raise NotImplementedError("Please implement execute() in subclass") + def _complete_exec_with_cluster_desc(self, context, event=None): + """To be used as trigger callback for operators that return the cluster description""" + if event["status"] != "success": + raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}") + cluster_arn = event.get("arn") + # We cannot get the cluster definition from the waiter on success, so we have to query it here. + details = self.hook.conn.describe_clusters(clusters=[cluster_arn])["clusters"][0] + return details + class EcsCreateClusterOperator(EcsBaseOperator): """ @@ -127,14 +139,15 @@ def execute(self, context: Context): self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state) elif self.deferrable: self.defer( - trigger=ClusterActiveTrigger( + trigger=ClusterWaiterTrigger( + waiter_name="cluster_active", cluster_arn=cluster_details["clusterArn"], waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, region=self.region, ), - method_name="execute_complete", + method_name="_complete_exec_with_cluster_desc", ) elif self.wait_for_completion: waiter = self.hook.get_waiter("cluster_active") @@ -150,11 +163,6 @@ def execute(self, context: Context): return cluster_details - def execute_complete(self, context, event=None): - if event["status"] != "success": - raise AirflowException(f"Error in cluster creation: {event}") - return event.get("value") - class EcsDeleteClusterOperator(EcsBaseOperator): """ @@ -196,9 +204,21 @@ def execute(self, context: Context): cluster_state = cluster_details.get("status") if cluster_state == EcsClusterStates.INACTIVE: - # In some circumstances the ECS Cluster is deleted immediately, - # so there is no reason to wait for completion. + # if the cluster doesn't have capacity providers that are associated with it, + # the deletion is instantaneous, and we don't need to wait for it. self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state) + elif self.deferrable: + self.defer( + trigger=ClusterWaiterTrigger( + waiter_name="cluster_inactive", + cluster_arn=cluster_details["clusterArn"], + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + region=self.region, + ), + method_name="_complete_exec_with_cluster_desc", + ) elif self.wait_for_completion: waiter = self.hook.get_waiter("cluster_inactive") waiter.wait( diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py index 10913e96ba631..95ef545f3c0b1 100644 --- a/airflow/providers/amazon/aws/triggers/ecs.py +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -28,10 +28,11 @@ from airflow.triggers.base import BaseTrigger, TriggerEvent -class ClusterActiveTrigger(BaseTrigger): +class ClusterWaiterTrigger(BaseTrigger): """ - Waits for a cluster to be active, triggers when it finishes + Polls the status of a cluster using a given waiter. Can be used to poll for an active or inactive cluster. + :param waiter_name: Name of the waiter to use, for instance 'cluster_active' or 'cluster_inactive' :param cluster_arn: ARN of the cluster to watch. :param waiter_delay: The amount of time in seconds to wait between attempts. :param waiter_max_attempts: The number of times to ping for status. @@ -42,6 +43,7 @@ class ClusterActiveTrigger(BaseTrigger): def __init__( self, + waiter_name: str, cluster_arn: str, waiter_delay: int | None, waiter_max_attempts: int | None, @@ -49,6 +51,7 @@ def __init__( region: str | None, ): self.cluster_arn = cluster_arn + self.waiter_name = waiter_name self.waiter_delay = waiter_delay if waiter_delay is not None else 15 # written like this to allow 0 self.attempts = waiter_max_attempts or 999999999 self.aws_conn_id = aws_conn_id @@ -58,6 +61,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( self.__class__.__module__ + "." + self.__class__.__qualname__, { + "waiter_name": self.waiter_name, "cluster_arn": self.cluster_arn, "waiter_delay": self.waiter_delay, "waiter_max_attempts": self.attempts, @@ -68,7 +72,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: async with EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as client: - waiter = client.get_waiter("cluster_active") + waiter = client.get_waiter(self.waiter_name) while self.attempts >= 1: self.attempts = self.attempts - 1 try: @@ -79,7 +83,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: }, ) # we reach this point only if the waiter met a success criteria - yield TriggerEvent({"status": "success", "value": self.cluster_arn}) + yield TriggerEvent({"status": "success", "arn": self.cluster_arn}) return except WaiterError as error: if "terminal failure" in str(error): @@ -87,12 +91,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: self.log.info("Status of cluster is %s", error.last_response["clusters"][0]["status"]) await asyncio.sleep(int(self.waiter_delay)) - raise AirflowException("Cluster still not active after the max number of tries has been reached") + raise AirflowException( + "Cluster still not in expected status after the max number of tries has been reached" + ) class TaskDoneTrigger(BaseTrigger): """ - Waits for an ECS task to be done + Waits for an ECS task to be done, while eventually polling logs. :param cluster: short name or full ARN of the cluster where the task is running. :param task_arn: ARN of the task to watch. From 5dc163b1b740bbe6580b34b5065fdc3c355c2589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 13 Jun 2023 11:22:23 -0700 Subject: [PATCH 06/20] add deferrable parameter --- airflow/providers/amazon/aws/operators/ecs.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index bc6221a02ec9b..110c4edcbacb1 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -102,7 +102,12 @@ class EcsCreateClusterOperator(EcsBaseOperator): (default: False) """ - template_fields: Sequence[str] = ("cluster_name", "create_cluster_kwargs", "wait_for_completion") + template_fields: Sequence[str] = ( + "cluster_name", + "create_cluster_kwargs", + "wait_for_completion", + "deferrable", + ) def __init__( self, @@ -178,9 +183,12 @@ class EcsDeleteClusterOperator(EcsBaseOperator): if not set then the default waiter value will be used. :param waiter_max_attempts: The maximum number of attempts to be made, if not set then the default waiter value will be used. + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) """ - template_fields: Sequence[str] = ("cluster_name", "wait_for_completion") + template_fields: Sequence[str] = ("cluster_name", "wait_for_completion", "deferrable") def __init__( self, @@ -189,6 +197,7 @@ def __init__( wait_for_completion: bool = True, waiter_delay: int | None = None, waiter_max_attempts: int | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -196,6 +205,7 @@ def __init__( self.wait_for_completion = wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable def execute(self, context: Context): self.log.info("Deleting cluster %r.", self.cluster_name) @@ -432,6 +442,9 @@ class EcsRunTaskOperator(EcsBaseOperator): if not set then the default waiter value will be used. :param waiter_max_attempts: The maximum number of attempts to be made, if not set then the default waiter value will be used. + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) """ ui_color = "#f0ede4" @@ -455,6 +468,7 @@ class EcsRunTaskOperator(EcsBaseOperator): "reattach", "number_logs_exception", "wait_for_completion", + "deferrable", ) template_fields_renderers = { "overrides": "json", @@ -489,6 +503,7 @@ def __init__( wait_for_completion: bool = True, waiter_delay: int | None = None, waiter_max_attempts: int | None = None, + deferrable: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -522,6 +537,7 @@ def __init__( self.wait_for_completion = wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable @provide_session def execute(self, context, session=None): From 71b9648cb1372d303c84079a06bf9f550f9ceb32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 13 Jun 2023 12:29:55 -0700 Subject: [PATCH 07/20] rearranging code around a bit --- airflow/providers/amazon/aws/operators/ecs.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 110c4edcbacb1..b5d0f8a92b2a3 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -551,17 +551,23 @@ def execute(self, context, session=None): self._start_wait_task(context) - return self._after_execution(session) + self._after_execution(session) + + if self.do_xcom_push and self.task_log_fetcher: + return self.task_log_fetcher.get_last_log_message() + else: + return None def execute_complete(self, context, event=None): if event["status"] != "success": raise AirflowException(f"Error in task execution: {event}") self.arn = event["task_arn"] # restore arn to its updated value self._after_execution() - # TODO return last log line if necessary because task_log_fetcher will always be None here + if self._aws_logs_enabled(): + ... # TODO return last log line but task_log_fetcher will always be None here @provide_session - def _after_execution(self, session=None) -> str | None: + def _after_execution(self, session=None): self._check_success_task() self.log.info("ECS Task has been successfully executed") @@ -571,11 +577,6 @@ def _after_execution(self, session=None) -> str | None: # as we can't reattach it anymore self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id)) - if self.do_xcom_push and self.task_log_fetcher: - return self.task_log_fetcher.get_last_log_message() - else: - return None - @AwsBaseHook.retry(should_retry_eni) def _start_wait_task(self, context): From ab0f65d395f00743784124bbf7e1a8ee30d7bbdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 13 Jun 2023 12:36:55 -0700 Subject: [PATCH 08/20] tests --- .../amazon/aws/operators/test_ecs.py | 21 +++++++++++++++++++ .../providers/amazon/aws/triggers/test_ecs.py | 10 ++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index 26be48a7e108b..dc2c8460e8391 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -188,6 +188,7 @@ def test_template_fields_overrides(self): "reattach", "number_logs_exception", "wait_for_completion", + "deferrable", ) @pytest.mark.parametrize( @@ -771,6 +772,26 @@ def test_execute_with_waiter(self, patch_hook_waiters, waiter_delay, waiter_max_ mocked_waiters.wait.assert_called_once_with(clusters=mock.ANY, WaiterConfig=expected_waiter_config) assert result is not None + @mock.patch.object(EcsDeleteClusterOperator, "client") + def test_execute_deferrable(self, mock_client: MagicMock): + op = EcsDeleteClusterOperator( + task_id="task", + cluster_name=CLUSTER_NAME, + deferrable=True, + waiter_delay=12, + waiter_max_attempts=34, + ) + mock_client.delete_cluster.return_value = { + "cluster": {"status": EcsClusterStates.DEPROVISIONING, "clusterArn": "my arn"} + } + + with pytest.raises(TaskDeferred) as defer: + op.execute(None) + + assert defer.value.trigger.cluster_arn == "my arn" + assert defer.value.trigger.waiter_delay == 12 + assert defer.value.trigger.attempts == 34 + def test_execute_immediate_delete(self, patch_hook_waiters): """Test if cluster deleted during initial request.""" op = EcsDeleteClusterOperator(task_id="task", cluster_name=CLUSTER_NAME, wait_for_completion=True) diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py b/tests/providers/amazon/aws/triggers/test_ecs.py index cb6037a7c8031..2d9cb9f6fa74d 100644 --- a/tests/providers/amazon/aws/triggers/test_ecs.py +++ b/tests/providers/amazon/aws/triggers/test_ecs.py @@ -24,11 +24,11 @@ from airflow import AirflowException from airflow.providers.amazon.aws.hooks.ecs import EcsHook -from airflow.providers.amazon.aws.triggers.ecs import ClusterActiveTrigger, TaskDoneTrigger +from airflow.providers.amazon.aws.triggers.ecs import ClusterWaiterTrigger, TaskDoneTrigger from airflow.triggers.base import TriggerEvent -class TestClusterActiveTrigger: +class TestClusterWaiterTrigger: @pytest.mark.asyncio @mock.patch.object(EcsHook, "async_conn") async def test_run_max_attempts(self, client_mock): @@ -39,7 +39,7 @@ async def test_run_max_attempts(self, client_mock): a_mock.get_waiter().wait = wait_mock max_attempts = 5 - trigger = ClusterActiveTrigger("cluster_arn", 0, max_attempts, None, None) + trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, max_attempts, None, None) with pytest.raises(AirflowException): generator = trigger.run() @@ -55,7 +55,7 @@ async def test_run_success(self, client_mock): wait_mock = AsyncMock() a_mock.get_waiter().wait = wait_mock - trigger = ClusterActiveTrigger("cluster_arn", 0, 5, None, None) + trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, 5, None, None) generator = trigger.run() response: TriggerEvent = await generator.asend(None) @@ -72,7 +72,7 @@ async def test_run_error(self, client_mock): wait_mock.side_effect = WaiterError("terminal failure", "reason", {}) a_mock.get_waiter().wait = wait_mock - trigger = ClusterActiveTrigger("cluster_arn", 0, 5, None, None) + trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, 5, None, None) with pytest.raises(WaiterError): generator = trigger.run() From a4214578ffd150cb78f3e3c2778d9c2014bd8cf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 13 Jun 2023 14:16:10 -0700 Subject: [PATCH 09/20] add dots in comments --- airflow/providers/amazon/aws/operators/ecs.py | 2 +- airflow/providers/amazon/aws/triggers/ecs.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index b5d0f8a92b2a3..70497be05c24e 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -72,7 +72,7 @@ def execute(self, context: Context): raise NotImplementedError("Please implement execute() in subclass") def _complete_exec_with_cluster_desc(self, context, event=None): - """To be used as trigger callback for operators that return the cluster description""" + """To be used as trigger callback for operators that return the cluster description.""" if event["status"] != "success": raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}") cluster_arn = event.get("arn") diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py index 95ef545f3c0b1..39ea94379496c 100644 --- a/airflow/providers/amazon/aws/triggers/ecs.py +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -170,7 +170,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: async def _forward_logs(self, logs_client, next_token: str | None = None) -> str | None: """ Reads logs from the cloudwatch stream and prints them to the task logs. - :return: the token to pass to the next iteration to resume where we started + :return: the token to pass to the next iteration to resume where we started. """ while True: if next_token is not None: From b2d1dae05c62ed8ff9f1d154900630893e72acd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 13 Jun 2023 14:25:19 -0700 Subject: [PATCH 10/20] fix test --- tests/providers/amazon/aws/triggers/test_ecs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py b/tests/providers/amazon/aws/triggers/test_ecs.py index 2d9cb9f6fa74d..6177454f90f41 100644 --- a/tests/providers/amazon/aws/triggers/test_ecs.py +++ b/tests/providers/amazon/aws/triggers/test_ecs.py @@ -61,7 +61,7 @@ async def test_run_success(self, client_mock): response: TriggerEvent = await generator.asend(None) assert response.payload["status"] == "success" - assert response.payload["value"] == "cluster_arn" + assert response.payload["arn"] == "cluster_arn" @pytest.mark.asyncio @mock.patch.object(EcsHook, "async_conn") @@ -94,6 +94,7 @@ async def test_run_until_error(self, client_mock): a_mock.get_waiter().wait = wait_mock trigger = TaskDoneTrigger("cluster", "task_arn", 0, None, None) + trigger.waiter_delay = 0 # cannot be set to 0 in __init__ because 0 is treated as None with pytest.raises(WaiterError): generator = trigger.run() From 305ea5c28875f554cafaa67d7bd10c4108b0f652 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 13 Jun 2023 15:08:25 -0700 Subject: [PATCH 11/20] add trigger to yaml --- airflow/providers/amazon/provider.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 05924eebc9742..9466cb6273a39 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -528,6 +528,9 @@ triggers: - integration-name: Amazon EMR python-modules: - airflow.providers.amazon.aws.triggers.emr + - integration-name: Amazon Elastic Containers Service (ECS) + python-modules: + - airflow.providers.amazon.aws.triggers.ecs transfers: - source-integration-name: Amazon DynamoDB From eab95e0ba6d1f17e6ea359a78192f4775697b312 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 13 Jun 2023 15:21:00 -0700 Subject: [PATCH 12/20] return last line of logs --- airflow/providers/amazon/aws/operators/ecs.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 1da93126743ad..45c83ead9bfe8 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -36,6 +36,7 @@ EcsTaskLogFetcher, should_retry_eni, ) +from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.triggers.ecs import ( ClusterWaiterTrigger, TaskDoneTrigger, @@ -561,10 +562,19 @@ def execute(self, context, session=None): def execute_complete(self, context, event=None): if event["status"] != "success": raise AirflowException(f"Error in task execution: {event}") - self.arn = event["task_arn"] # restore arn to its updated value + self.arn = event["task_arn"] # restore arn to its updated value, needed for next steps self._after_execution() if self._aws_logs_enabled(): - ... # TODO return last log line but task_log_fetcher will always be None here + # same behavior as non-deferrable mode, return last line of logs of the task. + logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn + one_log = logs_client.get_log_events( + logGroupName=self.awslogs_group, + logStreamName=self._get_logs_stream_name(), + startFromHead=False, + limit=1, + ) + if len(one_log["events"]) > 0: + return one_log["events"][0]["message"] @provide_session def _after_execution(self, session=None): @@ -700,16 +710,18 @@ def _wait_for_task_ended(self) -> None: def _aws_logs_enabled(self): return self.awslogs_group and self.awslogs_stream_prefix + def _get_logs_stream_name(self) -> str: + return f"{self.awslogs_stream_prefix}/{self.ecs_task_id}" + def _get_task_log_fetcher(self) -> EcsTaskLogFetcher: if not self.awslogs_group: raise ValueError("must specify awslogs_group to fetch task logs") - log_stream_name = f"{self.awslogs_stream_prefix}/{self.ecs_task_id}" return EcsTaskLogFetcher( aws_conn_id=self.aws_conn_id, region_name=self.awslogs_region, log_group=self.awslogs_group, - log_stream_name=log_stream_name, + log_stream_name=self._get_logs_stream_name(), fetch_interval=self.awslogs_fetch_interval, logger=self.log, ) From 02ec64e789b0a1bc01fc3eeaf71ed31c69bdd7ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 13 Jun 2023 16:09:40 -0700 Subject: [PATCH 13/20] is this the right integration name ? --- airflow/providers/amazon/provider.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 9466cb6273a39..257623f8ccb55 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -528,7 +528,7 @@ triggers: - integration-name: Amazon EMR python-modules: - airflow.providers.amazon.aws.triggers.emr - - integration-name: Amazon Elastic Containers Service (ECS) + - integration-name: Amazon ECS python-modules: - airflow.providers.amazon.aws.triggers.ecs From 960d1c5ed765a844ae72f0279221bd5cd43a3baa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Wed, 14 Jun 2023 16:08:25 -0700 Subject: [PATCH 14/20] add timeouts --- airflow/providers/amazon/aws/operators/ecs.py | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 45c83ead9bfe8..08af560f8419f 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -116,8 +116,8 @@ def __init__( cluster_name: str, create_cluster_kwargs: dict | None = None, wait_for_completion: bool = True, - waiter_delay: int | None = None, - waiter_max_attempts: int | None = None, + waiter_delay: int = 15, + waiter_max_attempts: int = 60, deferrable: bool = False, **kwargs, ) -> None: @@ -154,6 +154,9 @@ def execute(self, context: Context): region=self.region, ), method_name="_complete_exec_with_cluster_desc", + # timeout is set to ensure that if a trigger dies, the timeout does not restart + # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent) + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60), ) elif self.wait_for_completion: waiter = self.hook.get_waiter("cluster_active") @@ -196,8 +199,8 @@ def __init__( *, cluster_name: str, wait_for_completion: bool = True, - waiter_delay: int | None = None, - waiter_max_attempts: int | None = None, + waiter_delay: int = 15, + waiter_max_attempts: int = 60, deferrable: bool = False, **kwargs, ) -> None: @@ -229,6 +232,9 @@ def execute(self, context: Context): region=self.region, ), method_name="_complete_exec_with_cluster_desc", + # timeout is set to ensure that if a trigger dies, the timeout does not restart + # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent) + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60), ) elif self.wait_for_completion: waiter = self.hook.get_waiter("cluster_inactive") @@ -502,8 +508,8 @@ def __init__( reattach: bool = False, number_logs_exception: int = 10, wait_for_completion: bool = True, - waiter_delay: int | None = None, - waiter_max_attempts: int | None = None, + waiter_delay: int = 6, + waiter_max_attempts: int = 100, deferrable: bool = False, **kwargs, ): @@ -605,6 +611,9 @@ def _start_wait_task(self, context): log_stream=f"{self.awslogs_stream_prefix}/{self.ecs_task_id}", ), method_name="execute_complete", + # timeout is set to ensure that if a trigger dies, the timeout does not restart + # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent) + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60), ) elif self._aws_logs_enabled(): self.log.info("Starting ECS Task Log Fetcher") @@ -697,12 +706,10 @@ def _wait_for_task_ended(self) -> None: waiter.wait( cluster=self.cluster, tasks=[self.arn], - WaiterConfig=prune_dict( - { - "Delay": self.waiter_delay, - "MaxAttempts": self.waiter_max_attempts, - } - ), + WaiterConfig={ + "Delay": self.waiter_delay, + "MaxAttempts": self.waiter_max_attempts, + }, ) return From 7141b0ec997c3da981656ea1306ee6c804382087 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Wed, 14 Jun 2023 16:17:34 -0700 Subject: [PATCH 15/20] fix CI + some fix --- airflow/providers/amazon/aws/operators/ecs.py | 1 + airflow/providers/amazon/aws/triggers/ecs.py | 10 +++++++--- tests/providers/amazon/aws/triggers/test_ecs.py | 14 +++++++++----- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 08af560f8419f..c0def19d1bfb0 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -605,6 +605,7 @@ def _start_wait_task(self, context): cluster=self.cluster, task_arn=self.arn, waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, region=self.region, log_group=self.awslogs_group, diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py index 39ea94379496c..345cce2e84bae 100644 --- a/airflow/providers/amazon/aws/triggers/ecs.py +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -113,7 +113,8 @@ def __init__( self, cluster: str, task_arn: str, - waiter_delay: int | None, + waiter_delay: int, + waiter_max_attempts: int, aws_conn_id: str | None, region: str | None, log_group: str | None = None, @@ -122,7 +123,8 @@ def __init__( self.cluster = cluster self.task_arn = task_arn - self.waiter_delay = waiter_delay or 15 + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts self.aws_conn_id = aws_conn_id self.region = region @@ -136,6 +138,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "cluster": self.cluster, "task_arn": self.task_arn, "waiter_delay": self.waiter_delay, + "waiter_max_attempts": self.waiter_max_attempts, "aws_conn_id": self.aws_conn_id, "region": self.region, "log_group": self.log_group, @@ -150,7 +153,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # fmt: on waiter = ecs_client.get_waiter("tasks_stopped") logs_token = None - while True: + while self.waiter_max_attempts >= 1: + self.waiter_max_attempts = self.waiter_max_attempts - 1 try: await waiter.wait( cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1} diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py b/tests/providers/amazon/aws/triggers/test_ecs.py index 6177454f90f41..8a13585fb13a5 100644 --- a/tests/providers/amazon/aws/triggers/test_ecs.py +++ b/tests/providers/amazon/aws/triggers/test_ecs.py @@ -24,6 +24,7 @@ from airflow import AirflowException from airflow.providers.amazon.aws.hooks.ecs import EcsHook +from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.triggers.ecs import ClusterWaiterTrigger, TaskDoneTrigger from airflow.triggers.base import TriggerEvent @@ -82,7 +83,9 @@ async def test_run_error(self, client_mock): class TestTaskDoneTrigger: @pytest.mark.asyncio @mock.patch.object(EcsHook, "async_conn") - async def test_run_until_error(self, client_mock): + # this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step + @mock.patch.object(AwsLogsHook, "async_conn") + async def test_run_until_error(self, _, client_mock): a_mock = mock.MagicMock() client_mock.__aenter__.return_value = a_mock wait_mock = AsyncMock() @@ -93,8 +96,7 @@ async def test_run_until_error(self, client_mock): ] a_mock.get_waiter().wait = wait_mock - trigger = TaskDoneTrigger("cluster", "task_arn", 0, None, None) - trigger.waiter_delay = 0 # cannot be set to 0 in __init__ because 0 is treated as None + trigger = TaskDoneTrigger("cluster", "task_arn", 0, 10, None, None) with pytest.raises(WaiterError): generator = trigger.run() @@ -104,13 +106,15 @@ async def test_run_until_error(self, client_mock): @pytest.mark.asyncio @mock.patch.object(EcsHook, "async_conn") - async def test_run_success(self, client_mock): + # this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step + @mock.patch.object(AwsLogsHook, "async_conn") + async def test_run_success(self, _, client_mock): a_mock = mock.MagicMock() client_mock.__aenter__.return_value = a_mock wait_mock = AsyncMock() a_mock.get_waiter().wait = wait_mock - trigger = TaskDoneTrigger("cluster", "my_task_arn", 0, None, None) + trigger = TaskDoneTrigger("cluster", "my_task_arn", 0, 10, None, None) generator = trigger.run() response: TriggerEvent = await generator.asend(None) From 2eb2252d66569b3c43c0589d178bd4a8fabeef28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Thu, 15 Jun 2023 14:46:50 -0700 Subject: [PATCH 16/20] adjust expected value in test --- tests/providers/amazon/aws/operators/test_ecs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index dc2c8460e8391..9f9a53e67ad66 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -343,7 +343,7 @@ def test_wait_end_tasks(self, client_mock): self.ecs._wait_for_task_ended() client_mock.get_waiter.assert_called_once_with("tasks_stopped") client_mock.get_waiter.return_value.wait.assert_called_once_with( - cluster="c", tasks=["arn"], WaiterConfig={} + cluster="c", tasks=["arn"], WaiterConfig={"Delay": 6, "MaxAttempts": 100} ) assert sys.maxsize == client_mock.get_waiter.return_value.config.max_attempts From 6ecd01979e6d05459c2e9c9ccfd28229baf61d52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Fri, 16 Jun 2023 15:18:01 -0700 Subject: [PATCH 17/20] fix --- airflow/providers/amazon/aws/operators/ecs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 30e570b6ae693..751f824ac96a0 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -619,7 +619,7 @@ def _start_wait_task(self, context): aws_conn_id=self.aws_conn_id, region=self.region, log_group=self.awslogs_group, - log_stream=f"{self.awslogs_stream_prefix}/{self.ecs_task_id}", + log_stream=self._get_logs_stream_name(), ), method_name="execute_complete", # timeout is set to ensure that if a trigger dies, the timeout does not restart From 439bbd92ee615f7f4aec4b5cbb6e14f9c56806dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 20 Jun 2023 16:11:07 -0700 Subject: [PATCH 18/20] rename method in test --- tests/providers/amazon/aws/utils/test_task_log_fetcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/utils/test_task_log_fetcher.py b/tests/providers/amazon/aws/utils/test_task_log_fetcher.py index dbda751cfbead..a5598ebf552c0 100644 --- a/tests/providers/amazon/aws/utils/test_task_log_fetcher.py +++ b/tests/providers/amazon/aws/utils/test_task_log_fetcher.py @@ -112,7 +112,7 @@ def test_event_to_str(self): {"timestamp": 1617400367456, "message": "Second"}, {"timestamp": 1617400467789, "message": "Third"}, ] - assert [self.log_fetcher._event_to_str(event) for event in events] == ( + assert [self.log_fetcher.event_to_str(event) for event in events] == ( [ "[2021-04-02 21:51:07,123] First", "[2021-04-02 21:52:47,456] Second", From c5d50ad0b0889b2343fd3137b606570d5d5fab71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Thu, 22 Jun 2023 13:56:08 -0700 Subject: [PATCH 19/20] use newly available wait method --- airflow/providers/amazon/aws/triggers/ecs.py | 33 +++++++------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py index fe95115f88cf9..8ba8350588365 100644 --- a/airflow/providers/amazon/aws/triggers/ecs.py +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -22,10 +22,10 @@ from botocore.exceptions import ClientError, WaiterError -from airflow import AirflowException from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher +from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -74,27 +74,16 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: async with EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as client: waiter = client.get_waiter(self.waiter_name) - while self.attempts >= 1: - self.attempts = self.attempts - 1 - try: - await waiter.wait( - clusters=[self.cluster_arn], - WaiterConfig={ - "MaxAttempts": 1, - }, - ) - # we reach this point only if the waiter met a success criteria - yield TriggerEvent({"status": "success", "arn": self.cluster_arn}) - return - except WaiterError as error: - if "terminal failure" in str(error): - raise - self.log.info("Status of cluster is %s", error.last_response["clusters"][0]["status"]) - await asyncio.sleep(int(self.waiter_delay)) - - raise AirflowException( - "Cluster still not in expected status after the max number of tries has been reached" - ) + await async_wait( + waiter, + self.waiter_delay, + self.attempts, + {"clusters": [self.cluster_arn]}, + "error when checking cluster status", + "Status of cluster", + ["clusters[].status"], + ) + yield TriggerEvent({"status": "success", "arn": self.cluster_arn}) class TaskDoneTrigger(BaseTrigger): From 4e65d269a62eda66684c598c4f800654a228356f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Thu, 22 Jun 2023 16:12:54 -0700 Subject: [PATCH 20/20] fix test --- tests/providers/amazon/aws/triggers/test_ecs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py b/tests/providers/amazon/aws/triggers/test_ecs.py index 8a13585fb13a5..09b5decbe631a 100644 --- a/tests/providers/amazon/aws/triggers/test_ecs.py +++ b/tests/providers/amazon/aws/triggers/test_ecs.py @@ -75,7 +75,7 @@ async def test_run_error(self, client_mock): trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, 5, None, None) - with pytest.raises(WaiterError): + with pytest.raises(AirflowException): generator = trigger.run() await generator.asend(None)