diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py index 140edec404996..ed1b56c24d2ba 100644 --- a/airflow/providers/amazon/aws/sensors/emr.py +++ b/airflow/providers/amazon/aws/sensors/emr.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +from datetime import timedelta from typing import TYPE_CHECKING, Any, Iterable, Sequence from deprecated import deprecated @@ -25,6 +26,7 @@ from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.links.emr import EmrLogsLink +from airflow.providers.amazon.aws.triggers.emr import EmrJobFlowSensorTrigger from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: @@ -381,6 +383,8 @@ class EmrJobFlowSensor(EmrBaseSensor): job flow reaches any of these states :param failed_states: the failure states, sensor fails when job flow reaches any of these states + :param deferrable: Run sensor in the deferrable mode. + :param max_attempts: Maximum number of tries before failing """ template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states") @@ -393,12 +397,16 @@ def __init__( job_flow_id: str, target_states: Iterable[str] | None = None, failed_states: Iterable[str] | None = None, + deferrable: bool = False, + max_attempts: int = 60, **kwargs, ): super().__init__(**kwargs) self.job_flow_id = job_flow_id self.target_states = target_states or ["TERMINATED"] self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"] + self.deferrable = deferrable + self.max_attempts = max_attempts def get_emr_response(self, context: Context) -> dict[str, Any]: """ @@ -450,6 +458,26 @@ def failure_message_from_response(response: dict[str, Any]) -> str | None: ) return None + def execute(self, context: Context) -> None: + if self.deferrable and not self.poke(context): + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=EmrJobFlowSensorTrigger( + job_flow_id=self.job_flow_id, + target_states=self.target_states, + aws_conn_id=self.aws_conn_id, + max_attempts=self.max_attempts, + poke_interval=self.poke_interval, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error while running job: {event}") + else: + self.log.info("Job completed.") + class EmrStepSensor(EmrBaseSensor): """ diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py new file mode 100644 index 0000000000000..7693af5fdd13f --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/emr.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any, AsyncIterator, Iterable + +from airflow.compat.functools import cached_property +from airflow.providers.amazon.aws.hooks.emr import EmrHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class EmrJobFlowSensorTrigger(BaseTrigger): + """ + Poll for the status of EMR cluster until reaches terminal state + + :param job_flow_id: EMR job flow id + :param target_states: the target states, sensor waits until + step reaches any of these states + :param aws_conn_id: Reference to AWS connection id + :param max_attempts: maximum try attempts for polling the status + :param poll_interval: polling period in seconds to check for the status + """ + + def __init__( + self, + job_flow_id: str, + target_states: Iterable[str], + aws_conn_id: str = "aws_default", + max_attempts: int = 60, + poll_interval: int = 30, + **kwargs: Any, + ): + self.job_flow_id = job_flow_id + self.target_states = target_states + self.aws_conn_id = aws_conn_id + self.max_attempts = max_attempts + self.poll_interval = poll_interval + super().__init__(**kwargs) + + @cached_property + def hook(self) -> EmrHook: + return EmrHook(aws_conn_id=self.aws_conn_id) + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.amazon.aws.triggers.emr.EmrJobFlowSensorTrigger", + { + "job_flow_id": self.job_flow_id, + "target_states": self.target_states, + "aws_conn_id": self.aws_conn_id, + "max_attempts": self.max_attempts, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + async with self.hook.async_conn as client: + waiter = self.hook.get_waiter("job_flow_terminate", deferrable=True, client=client) + await waiter.wait( + ClusterId=self.job_flow_id, + WaiterConfig={ + "Delay": self.poll_interval, + "MaxAttempts": self.max_attempts, + }, + ) + + response = self.hook.conn.describe_cluster(ClusterId=self.job_flow_id) + state = response["Cluster"]["Status"]["State"] + if state in self.target_states: + yield TriggerEvent({"status": "success"}) + else: + yield TriggerEvent({"status": "failed", "response": response}) diff --git a/airflow/providers/amazon/aws/waiters/emr.json b/airflow/providers/amazon/aws/waiters/emr.json index 78afee6b544d1..08441af0788be 100644 --- a/airflow/providers/amazon/aws/waiters/emr.json +++ b/airflow/providers/amazon/aws/waiters/emr.json @@ -75,6 +75,25 @@ "state": "failure" } ] + }, + "job_flow_terminate": { + "operation": "DescribeCluster", + "delay": 30, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "argument": "cluster.status", + "expected": "TERMINATED", + "state": "success" + }, + { + "matcher": "path", + "argument": "cluster.status", + "expected": "TERMINATED_WITH_ERRORS", + "state": "failure" + } + ] } } } diff --git a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py index c81d9d4855747..e567be3184d3a 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py +++ b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py @@ -24,8 +24,9 @@ import pytest from dateutil.tz import tzlocal -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.sensors.emr import EmrJobFlowSensor +from airflow.providers.amazon.aws.triggers.emr import EmrJobFlowSensorTrigger DESCRIBE_CLUSTER_STARTING_RETURN = { "Cluster": { @@ -276,3 +277,20 @@ def test_different_target_states(self): # make sure it was called with the job_flow_id calls = [mock.call(ClusterId="j-8989898989")] self.mock_emr_client.describe_cluster.assert_has_calls(calls) + + @mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrJobFlowSensor.poke") + def test_sensor_defer(self, mock_poke): + sensor = EmrJobFlowSensor( + task_id="test_task", + poke_interval=0, + job_flow_id="j-8989898989", + aws_conn_id="aws_default", + target_states=["RUNNING", "WAITING"], + deferrable=True, + ) + mock_poke.return_value = False + with pytest.raises(TaskDeferred) as exc: + sensor.execute(context=None) + assert isinstance( + exc.value.trigger, EmrJobFlowSensorTrigger + ), "Trigger is not a EmrJobFlowSensorTrigger " diff --git a/tests/providers/amazon/aws/triggers/test_emr.py b/tests/providers/amazon/aws/triggers/test_emr.py new file mode 100644 index 0000000000000..72dca9e62c674 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_emr.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import PropertyMock + +import pytest + +from airflow.providers.amazon.aws.triggers.emr import EmrJobFlowSensorTrigger +from tests.providers.amazon.aws.utils.compat import AsyncMock, async_mock + +JOB__FLOW_ID = "job-1234" +TARGET_STATE = ["TERMINATED"] +AWS_CONN_ID = "aws_emr_conn" +POLL_INTERVAL = 60 +MAX_ATTEMPTS = 5 + + +class TestEmrEmrJobFlowSensorTrigger: + def test_emr_job_flow_sensor_trigger_serialize(self): + emr_trigger = EmrJobFlowSensorTrigger( + job_flow_id=JOB__FLOW_ID, + target_states=TARGET_STATE, + aws_conn_id=AWS_CONN_ID, + poll_interval=POLL_INTERVAL, + max_attempts=MAX_ATTEMPTS, + ) + class_path, args = emr_trigger.serialize() + assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrJobFlowSensorTrigger" + assert args["job_flow_id"] == JOB__FLOW_ID + assert args["target_states"] == TARGET_STATE + assert args["aws_conn_id"] == AWS_CONN_ID + assert args["poll_interval"] == POLL_INTERVAL + assert args["max_attempts"] == MAX_ATTEMPTS + + @pytest.mark.asyncio + @async_mock.patch( + "airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook.conn", new_callable=PropertyMock + ) + @async_mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.get_waiter") + @async_mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.async_conn") + async def test_emr_flow_sensor_trigger_run(self, mock_async_conn, mock_get_waiter, mock_conn): + mock = async_mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock + + mock_get_waiter().wait = AsyncMock() + + emr_trigger = EmrJobFlowSensorTrigger( + job_flow_id=JOB__FLOW_ID, + target_states=TARGET_STATE, + aws_conn_id=AWS_CONN_ID, + poll_interval=POLL_INTERVAL, + max_attempts=MAX_ATTEMPTS, + ) + + generator = emr_trigger.run() + await generator.asend(None) + + mock_conn.return_value.describe_cluster.assert_called_once_with(ClusterId=JOB__FLOW_ID)