diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 220a4ddea0032..fc0af516d1930 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -19,6 +19,7 @@ import ast import warnings +from datetime import timedelta from typing import TYPE_CHECKING, Any, Sequence from uuid import uuid4 @@ -26,6 +27,7 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.links.emr import EmrClusterLink +from airflow.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger from airflow.providers.amazon.aws.utils.waiter import waiter from airflow.utils.helpers import exactly_one, prune_dict from airflow.utils.types import NOTSET, ArgNotSet @@ -437,6 +439,7 @@ class EmrContainerOperator(BaseOperator): Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state. :param tags: The tags assigned to job runs. Defaults to None + :param deferrable: Run operator in the deferrable mode. 
def execute_complete(self, context, event=None):
    """
    Resume the task once the trigger has fired and report the job outcome.

    :param context: Airflow task context (not used by this method).
    :param event: payload of the event emitted by the trigger. It is expected to carry a
        ``status`` key and may carry ``message`` and ``job_id`` keys.
    :return: the job run id taken from the event, falling back to ``self.job_id``.
    :raises AirflowException: if no event was received or its status is not ``"success"``.
    """
    # A missing event is reported the same way as a failed job instead of crashing
    # with a TypeError on ``None["status"]``.
    if not event or event.get("status") != "success":
        raise AirflowException(f"Error while running job: {event}")

    # The success payload is not guaranteed to contain "message"; reading it with
    # ``[]`` would turn a successful job run into a KeyError.
    job_id = event.get("job_id", self.job_id)
    self.log.info(event.get("message", f"Job {job_id} completed."))
    # Return the job id so the deferred path yields the same value as ``execute``.
    return job_id
class EmrContainerOperatorTrigger(BaseTrigger):
    """
    Poll for the status of an EMR container job run until it reaches a terminal state.

    :param virtual_cluster_id: Reference Emr cluster id
    :param job_id: job_id to check the state
    :param aws_conn_id: Reference to AWS connection id
    :param poll_interval: polling period in seconds to check for the status
    :param max_attempts: maximum number of polling attempts before a failure event is
        emitted; ``None`` (the default) keeps polling until the job reaches a terminal state
    """

    def __init__(
        self,
        virtual_cluster_id: str,
        job_id: str,
        aws_conn_id: str = "aws_default",
        poll_interval: int = 10,
        max_attempts: int | None = None,
        **kwargs: Any,
    ):
        self.virtual_cluster_id = virtual_cluster_id
        self.job_id = job_id
        self.aws_conn_id = aws_conn_id
        self.poll_interval = poll_interval
        self.max_attempts = max_attempts
        super().__init__(**kwargs)

    @cached_property
    def hook(self) -> EmrContainerHook:
        """Return the EMR container hook bound to this trigger's connection and virtual cluster."""
        return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes EmrContainerOperatorTrigger arguments and classpath."""
        return (
            "airflow.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger",
            {
                "virtual_cluster_id": self.virtual_cluster_id,
                "job_id": self.job_id,
                "aws_conn_id": self.aws_conn_id,
                "poll_interval": self.poll_interval,
                "max_attempts": self.max_attempts,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Poll the job run and emit exactly one event describing its outcome.

        The event payload always has a ``status`` key (``"success"`` or ``"failure"``) and a
        ``message`` key; the success and max-attempts payloads also carry ``job_id``.
        """
        async with self.hook.async_conn as client:
            waiter = self.hook.get_waiter("container_job_complete", deferrable=True, client=client)
            attempt = 0
            while self.max_attempts is None or attempt < self.max_attempts:
                attempt += 1
                try:
                    # A single waiter attempt per iteration: the pause between polls is
                    # done below with ``asyncio.sleep`` so this loop controls the retries.
                    await waiter.wait(
                        id=self.job_id,
                        virtualClusterId=self.virtual_cluster_id,
                        WaiterConfig={
                            "Delay": self.poll_interval,
                            "MaxAttempts": 1,
                        },
                    )
                except WaiterError as error:
                    if "terminal failure" in str(error):
                        yield TriggerEvent({"status": "failure", "message": f"Job Failed: {error}"})
                        # Stop here so a success event is never emitted after a failure.
                        return
                    self.log.info(
                        "Job status is %s. Retrying attempt %s",
                        (error.last_response or {}).get("jobRun", {}).get("state"),
                        attempt,
                    )
                    await asyncio.sleep(int(self.poll_interval))
                else:
                    yield TriggerEvent(
                        {"status": "success", "message": "Job completed.", "job_id": self.job_id}
                    )
                    return

        # Only reachable when ``max_attempts`` is set and was exhausted without the job
        # reaching a terminal state.
        yield TriggerEvent(
            {"status": "failure", "message": "Job Failed: max attempts reached.", "job_id": self.job_id}
        )
@mock.patch.object(EmrContainerHook, "submit_job")
def test_operator_defer(self, mock_submit_job):
    """In deferrable mode ``execute`` must defer with an ``EmrContainerOperatorTrigger``."""
    operator = self.emr_container
    operator.deferrable = True
    operator.wait_for_completion = False

    # ``submit_job`` is patched out, so the only expected outcome is the deferral itself.
    with pytest.raises(TaskDeferred) as deferred:
        operator.execute(context=None)

    trigger = deferred.value.trigger
    assert isinstance(trigger, EmrContainerOperatorTrigger), "Trigger is not a EmrContainerOperatorTrigger"
class TestEmrContainerSensorTrigger:
    """Tests for ``EmrContainerOperatorTrigger`` serialization and its polling loop."""

    def test_emr_container_operator_trigger_serialize(self):
        """``serialize`` returns the trigger classpath and the constructor arguments."""
        emr_trigger = EmrContainerOperatorTrigger(
            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
            job_id=JOB_ID,
            aws_conn_id=AWS_CONN_ID,
            poll_interval=POLL_INTERVAL,
        )
        class_path, args = emr_trigger.serialize()
        assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger"
        # Only the arguments the trigger is known to serialize are checked here.
        assert args["virtual_cluster_id"] == VIRTUAL_CLUSTER_ID
        assert args["job_id"] == JOB_ID
        assert args["aws_conn_id"] == AWS_CONN_ID
        assert args["poll_interval"] == POLL_INTERVAL

    @pytest.mark.asyncio
    @async_mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
    @async_mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
    async def test_emr_container_trigger_run(self, mock_async_conn, mock_get_waiter):
        """A waiter that returns without raising makes the trigger emit a success event."""
        mock_client = async_mock.MagicMock()
        mock_async_conn.__aenter__.return_value = mock_client

        waiter_wait = AsyncMock()
        mock_get_waiter.return_value.wait = waiter_wait

        emr_trigger = EmrContainerOperatorTrigger(
            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
            job_id=JOB_ID,
            aws_conn_id=AWS_CONN_ID,
            poll_interval=POLL_INTERVAL,
        )

        generator = emr_trigger.run()
        response = await generator.asend(None)

        waiter_wait.assert_awaited_once()
        # Check the keys the operator relies on rather than the whole payload, so the
        # test is not coupled to optional fields such as a human-readable message.
        assert isinstance(response, TriggerEvent)
        assert response.payload["status"] == "success"
        assert response.payload["job_id"] == JOB_ID