From d60913bfd921c4cd7094c5629b524a6cc7e75130 Mon Sep 17 00:00:00 2001 From: Pankaj Date: Tue, 25 Apr 2023 20:58:43 +0530 Subject: [PATCH 1/3] Add deferrable param in BatchOperator Add the deferrable param in BatchOperator. This allows running BatchOperator asynchronously: the worker only submits the Batch job, then defers to the trigger, which polls and waits for the job status, so the worker slot is not occupied for the whole period of task execution. --- .../providers/amazon/aws/operators/batch.py | 27 ++++++ .../providers/amazon/aws/triggers/batch.py | 82 +++++++++++++++++++ .../providers/amazon/aws/waiters/batch.json | 19 +++++ .../amazon/aws/operators/test_batch.py | 19 ++++- .../amazon/aws/triggers/test_batch.py | 69 ++++++++++++++++ 5 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 airflow/providers/amazon/aws/triggers/batch.py create mode 100644 airflow/providers/amazon/aws/waiters/batch.json create mode 100644 tests/providers/amazon/aws/triggers/test_batch.py diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index 272122d1093b9..e5dd4a87c39b5 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -38,6 +38,7 @@ BatchJobQueueLink, ) from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink +from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger from airflow.providers.amazon.aws.utils import trim_none_values if TYPE_CHECKING: @@ -79,6 +80,8 @@ class BatchOperator(BaseOperator): Override the region_name in connection (if provided) :param tags: collection of tags to apply to the AWS Batch job submission if None, no tags are submitted + :param deferrable: Run operator in the deferrable mode. + :param poll_interval: (Deferrable mode only) Time in second to wait between polling. .. 
note:: Any custom waiters must return a waiter for these calls: @@ -142,6 +145,8 @@ def __init__( region_name: str | None = None, tags: dict | None = None, wait_for_completion: bool = True, + deferrable: bool = False, + poll_interval: int = 30, **kwargs, ): @@ -175,6 +180,8 @@ def __init__( self.waiters = waiters self.tags = tags or {} self.wait_for_completion = wait_for_completion + self.deferrable = deferrable + self.poll_interval = poll_interval # params for hook self.max_retries = max_retries @@ -199,11 +206,31 @@ def execute(self, context: Context): """ self.submit_job(context) + if self.deferrable: + self.defer( + timeout=self.execution_timeout, + trigger=BatchOperatorTrigger( + job_id=self.job_id, + max_retries=self.max_retries or 10, + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + poll_interval=self.poll_interval, + ), + method_name="execute_complete", + ) + if self.wait_for_completion: self.monitor_job(context) return self.job_id + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error while running job: {event}") + else: + self.log.info("Job completed.") + return event["job_id"] + def on_kill(self): response = self.hook.client.terminate_job(jobId=self.job_id, reason="Task killed by the user") self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, response) diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py new file mode 100644 index 0000000000000..a67cf5b91f22e --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/batch.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any + +from airflow.compat.functools import cached_property +from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class BatchOperatorTrigger(BaseTrigger): + """ + Trigger for BatchOperator. + The trigger will asynchronously poll the boto3 API and wait for the + Batch job to be in the `SUCCEEDED` state. + + :param job_id: A unique identifier for the AWS Batch job. + :param max_retries: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: region name to use in AWS Hook + :param poll_interval: The amount of time in seconds to wait between attempts. 
+ """ + + def __init__( + self, + job_id: str | None = None, + max_retries: int = 10, + aws_conn_id: str | None = "aws_default", + region_name: str | None = None, + poll_interval: int = 30, + ): + super().__init__() + self.job_id = job_id + self.max_retries = max_retries + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes BatchOperatorTrigger arguments and classpath.""" + return ( + "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger", + { + "job_id": self.job_id, + "max_retries": self.max_retries, + "aws_conn_id": self.aws_conn_id, + "region_name": self.region_name, + "poll_interval": self.poll_interval, + }, + ) + + @cached_property + def hook(self) -> BatchClientHook: + return BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + + async def run(self): + + async with self.hook.async_conn as client: + waiter = self.hook.get_waiter("JobComplete", deferrable=True, client=client) + await waiter.wait( + jobs=[self.job_id], + WaiterConfig={ + "Delay": self.poll_interval, + "MaxAttempts": self.max_retries, + }, + ) + yield TriggerEvent({"status": "success", "job_id": self.job_id}) diff --git a/airflow/providers/amazon/aws/waiters/batch.json b/airflow/providers/amazon/aws/waiters/batch.json new file mode 100644 index 0000000000000..ed98c03d5a31d --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/batch.json @@ -0,0 +1,19 @@ +{ + "version": 2, + "waiters": { + "JobComplete": { + "delay": 300, + "operation": "DescribeJobs", + "maxAttempts": 100, + "description": "Wait until job is SUCCEEDED", + "acceptors": [ + { + "argument": "jobs[].status", + "expected": "SUCCEEDED", + "matcher": "pathAll", + "state": "success" + } + ] + } + } +} diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 42f7fae86c9e1..f559424dff8cc 100644 --- 
a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -22,11 +22,13 @@ import pytest -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator, BatchOperator # Use dummy AWS credentials +from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger + AWS_REGION = "eu-west-1" AWS_ACCESS_KEY_ID = "airflow_dummy_key" AWS_SECRET_ACCESS_KEY = "airflow_dummy_secret" @@ -256,6 +258,21 @@ def test_cant_set_old_and_new_override_param(self): container_overrides={"a": "b"}, ) + @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type") + def test_defer_if_deferrable_param_set(self, mock_client): + batch = BatchOperator( + task_id="task", + job_name=JOB_NAME, + job_queue="queue", + job_definition="hello-world", + do_xcom_push=False, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc: + batch.execute(context=None) + assert isinstance(exc.value.trigger, BatchOperatorTrigger), "Trigger is not a BatchOperatorTrigger" + class TestBatchCreateComputeEnvironmentOperator: @mock.patch.object(BatchClientHook, "client") diff --git a/tests/providers/amazon/aws/triggers/test_batch.py b/tests/providers/amazon/aws/triggers/test_batch.py new file mode 100644 index 0000000000000..54a8765ab4669 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_batch.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger +from airflow.triggers.base import TriggerEvent +from tests.providers.amazon.aws.utils.compat import AsyncMock, async_mock + +BATCH_JOB_ID = "job_id" +POLL_INTERVAL = 5 +MAX_ATTEMPT = 5 +AWS_CONN_ID = "aws_batch_job_conn" +AWS_REGION = "us-east-2" + + +class TestBatchOperatorTrigger: + def test_batch_operator_trigger_serialize(self): + batch_trigger = BatchOperatorTrigger( + job_id=BATCH_JOB_ID, + poll_interval=POLL_INTERVAL, + max_retries=MAX_ATTEMPT, + aws_conn_id=AWS_CONN_ID, + region_name=AWS_REGION, + ) + class_path, args = batch_trigger.serialize() + assert class_path == "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger" + assert args["job_id"] == BATCH_JOB_ID + assert args["poll_interval"] == POLL_INTERVAL + assert args["max_retries"] == MAX_ATTEMPT + assert args["aws_conn_id"] == AWS_CONN_ID + assert args["region_name"] == AWS_REGION + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter") + @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn") + async def test_batch_job_trigger_run(self, mock_async_conn, mock_get_waiter): + mock = async_mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock + + mock_get_waiter().wait = 
AsyncMock() + + batch_trigger = BatchOperatorTrigger( + job_id=BATCH_JOB_ID, + poll_interval=POLL_INTERVAL, + max_retries=MAX_ATTEMPT, + aws_conn_id=AWS_CONN_ID, + region_name=AWS_REGION, + ) + + generator = batch_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "job_id": BATCH_JOB_ID}) From b6d9a78afc5c0bef94d7a7a453d8cfc702f7f666 Mon Sep 17 00:00:00 2001 From: Pankaj Date: Thu, 25 May 2023 20:58:12 +0530 Subject: [PATCH 2/3] Add logs in trigger --- .../providers/amazon/aws/triggers/batch.py | 43 +++++++++++++++---- .../providers/amazon/aws/waiters/batch.json | 8 +++- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py index a67cf5b91f22e..fb60b7ea916a8 100644 --- a/airflow/providers/amazon/aws/triggers/batch.py +++ b/airflow/providers/amazon/aws/triggers/batch.py @@ -16,8 +16,11 @@ # under the License. from __future__ import annotations +import asyncio from typing import Any +from botocore.exceptions import WaiterError + from airflow.compat.functools import cached_property from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -71,12 +74,34 @@ def hook(self) -> BatchClientHook: async def run(self): async with self.hook.async_conn as client: - waiter = self.hook.get_waiter("JobComplete", deferrable=True, client=client) - await waiter.wait( - jobs=[self.job_id], - WaiterConfig={ - "Delay": self.poll_interval, - "MaxAttempts": self.max_retries, - }, - ) - yield TriggerEvent({"status": "success", "job_id": self.job_id}) + waiter = self.hook.get_waiter("batch_job_complete", deferrable=True, client=client) + attempt = 0 + while attempt < self.max_retries: + attempt = attempt + 1 + try: + await waiter.wait( + jobs=[self.job_id], + WaiterConfig={ + "Delay": self.poll_interval, + "MaxAttempts": 1, + }, + ) + break + 
except WaiterError as error: + if "terminal failure" in str(error): + yield TriggerEvent( + {"status": "failure", "message": f"Delete Cluster Failed: {error}"} + ) + break + self.log.info( + "Job status is %s. Retrying attempt %s/%s", + error.last_response["jobs"][0]["status"], + attempt, + self.max_retries, + ) + await asyncio.sleep(int(self.poll_interval)) + + if attempt >= self.max_retries: + yield TriggerEvent({"status": "failure", "message": "Job Failed - max attempts reached."}) + else: + yield TriggerEvent({"status": "success", "job_id": self.job_id}) diff --git a/airflow/providers/amazon/aws/waiters/batch.json b/airflow/providers/amazon/aws/waiters/batch.json index ed98c03d5a31d..fa9752ea14c41 100644 --- a/airflow/providers/amazon/aws/waiters/batch.json +++ b/airflow/providers/amazon/aws/waiters/batch.json @@ -1,7 +1,7 @@ { "version": 2, "waiters": { - "JobComplete": { + "batch_job_complete": { "delay": 300, "operation": "DescribeJobs", "maxAttempts": 100, @@ -12,6 +12,12 @@ "expected": "SUCCEEDED", "matcher": "pathAll", "state": "success" + }, + { + "argument": "jobs[].status", + "expected": "FAILED", + "matcher": "pathAll", + "state": "failed" } ] } From cc7da71e177fb53276837eb61daa6180d5fe141e Mon Sep 17 00:00:00 2001 From: Pankaj Singh <98807258+pankajastro@users.noreply.github.com> Date: Thu, 25 May 2023 23:32:37 +0530 Subject: [PATCH 3/3] Apply review suggestion Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- airflow/providers/amazon/aws/operators/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index e5dd4a87c39b5..cbeb0cbcba1a9 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -81,7 +81,7 @@ class BatchOperator(BaseOperator): :param tags: collection of tags to apply to the AWS Batch job submission if None, no tags are 
submitted :param deferrable: Run operator in the deferrable mode. - :param poll_interval: (Deferrable mode only) Time in second to wait between polling. + :param poll_interval: (Deferrable mode only) Time in seconds to wait between polling. .. note:: Any custom waiters must return a waiter for these calls: