diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 15ef379aad7d3..b8ca53226e807 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -28,7 +28,11 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri -from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger, EmrCreateJobFlowTrigger +from airflow.providers.amazon.aws.triggers.emr import ( + EmrAddStepsTrigger, + EmrCreateJobFlowTrigger, + EmrTerminateJobFlowTrigger, +) from airflow.providers.amazon.aws.utils.waiter import waiter from airflow.utils.helpers import exactly_one, prune_dict from airflow.utils.types import NOTSET, ArgNotSet @@ -842,6 +846,11 @@ class EmrTerminateJobFlowOperator(BaseOperator): :param job_flow_id: id of the JobFlow to terminate. (templated) :param aws_conn_id: aws connection to uses + :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check JobFlow status + :param waiter_max_attempts: The maximum number of times to poll for JobFlow status. + :param deferrable: If True, the operator will wait asynchronously for the JobFlow to terminate. + This implies waiting for completion. This mode requires aiobotocore module to be installed. 
+ (default: False) """ template_fields: Sequence[str] = ("job_flow_id",) @@ -852,10 +861,22 @@ class EmrTerminateJobFlowOperator(BaseOperator): EmrLogsLink(), ) - def __init__(self, *, job_flow_id: str, aws_conn_id: str = "aws_default", **kwargs): + def __init__( + self, + *, + job_flow_id: str, + aws_conn_id: str = "aws_default", + waiter_delay: int = 60, + waiter_max_attempts: int = 20, + deferrable: bool = False, + **kwargs, + ): super().__init__(**kwargs) self.job_flow_id = job_flow_id self.aws_conn_id = aws_conn_id + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable def execute(self, context: Context) -> None: emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) @@ -883,7 +904,28 @@ def execute(self, context: Context) -> None: if not response["ResponseMetadata"]["HTTPStatusCode"] == 200: raise AirflowException(f"JobFlow termination failed: {response}") else: - self.log.info("JobFlow with id %s terminated", self.job_flow_id) + self.log.info("Terminating JobFlow with id %s", self.job_flow_id) + + if self.deferrable: + self.defer( + trigger=EmrTerminateJobFlowTrigger( + job_flow_id=self.job_flow_id, + poll_interval=self.waiter_delay, + max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + # timeout is set to ensure that if a trigger dies, the timeout does not restart + # 60 seconds is added to allow the trigger to exit gracefully (i.e. 
yield TriggerEvent) + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60), + ) + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error terminating JobFlow: {event}") + else: + self.log.info("Jobflow terminated successfully.") + return class EmrServerlessCreateApplicationOperator(BaseOperator): diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py index 76ee47bc8baa6..1c3c8bb8338b2 100644 --- a/airflow/providers/amazon/aws/triggers/emr.py +++ b/airflow/providers/amazon/aws/triggers/emr.py @@ -173,3 +173,76 @@ async def run(self): "job_flow_id": self.job_flow_id, } ) + + +class EmrTerminateJobFlowTrigger(BaseTrigger): + """ + Trigger that terminates a running EMR Job Flow. + The trigger will asynchronously poll the boto3 API and wait for the + JobFlow to finish terminating. + + :param job_flow_id: ID of the EMR Job Flow to terminate + :param poll_interval: The amount of time in seconds to wait between attempts. + :param max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__( + self, + job_flow_id: str, + poll_interval: int, + max_attempts: int, + aws_conn_id: str, + ): + self.job_flow_id = job_flow_id + self.poll_interval = poll_interval + self.max_attempts = max_attempts + self.aws_conn_id = aws_conn_id + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + self.__class__.__module__ + "." 
+ self.__class__.__qualname__, + { + "job_flow_id": self.job_flow_id, + "poll_interval": str(self.poll_interval), + "max_attempts": str(self.max_attempts), + "aws_conn_id": self.aws_conn_id, + }, + ) + + async def run(self): + self.hook = EmrHook(aws_conn_id=self.aws_conn_id) + async with self.hook.async_conn as client: + attempt = 0 + waiter = self.hook.get_waiter("job_flow_terminated", deferrable=True, client=client) + while attempt < int(self.max_attempts): + attempt = attempt + 1 + try: + await waiter.wait( + ClusterId=self.job_flow_id, + WaiterConfig=prune_dict( + { + "Delay": self.poll_interval, + "MaxAttempts": 1, + } + ), + ) + break + except WaiterError as error: + if "terminal failure" in str(error): + raise AirflowException(f"JobFlow termination failed: {error}") + self.log.info( + "Status of jobflow is %s - %s", + error.last_response["Cluster"]["Status"]["State"], + error.last_response["Cluster"]["Status"]["StateChangeReason"], + ) + await asyncio.sleep(int(self.poll_interval)) + if attempt >= int(self.max_attempts): + raise AirflowException(f"JobFlow termination failed - max attempts reached: {self.max_attempts}") + else: + yield TriggerEvent( + { + "status": "success", + "message": "JobFlow terminated successfully", + } + ) diff --git a/airflow/providers/amazon/aws/waiters/emr.json b/airflow/providers/amazon/aws/waiters/emr.json index 13bc5857e30cc..d27cd08e00145 100644 --- a/airflow/providers/amazon/aws/waiters/emr.json +++ b/airflow/providers/amazon/aws/waiters/emr.json @@ -75,6 +75,25 @@ "state": "failure" } ] + }, + "job_flow_terminated": { + "operation": "DescribeCluster", + "delay": 30, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "argument": "Cluster.Status.State", + "expected": "TERMINATED", + "state": "success" + }, + { + "matcher": "path", + "argument": "Cluster.Status.State", + "expected": "TERMINATED_WITH_ERRORS", + "state": "failure" + } + ] } } } diff --git a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst 
b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst index d26a427a64e19..8a2255ddbf4cc 100644 --- a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst +++ b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst @@ -111,6 +111,10 @@ Terminate an EMR job flow To terminate an EMR Job Flow you can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrTerminateJobFlowOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. +Using ``deferrable`` mode will release worker slots and lead to efficient utilization of +resources within the Airflow cluster. However, this mode will need the Airflow triggerer to be +available in your deployment. .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr.py :language: python diff --git a/tests/providers/amazon/aws/hooks/test_emr.py b/tests/providers/amazon/aws/hooks/test_emr.py index 918710e8ce4a1..8ff9bc5a36edf 100644 --- a/tests/providers/amazon/aws/hooks/test_emr.py +++ b/tests/providers/amazon/aws/hooks/test_emr.py @@ -32,9 +32,9 @@ class TestEmrHook: def test_service_waiters(self): hook = EmrHook(aws_conn_id=None) official_waiters = hook.conn.waiter_names - custom_waiters = ["job_flow_waiting", "notebook_running", "notebook_stopped"] + custom_waiters = ["job_flow_waiting", "job_flow_terminated", "notebook_running", "notebook_stopped"] - assert hook.list_waiters() == [*official_waiters, *custom_waiters] + assert sorted(hook.list_waiters()) == sorted([*official_waiters, *custom_waiters]) @mock_emr def test_get_conn_returns_a_boto3_connection(self): diff --git a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py index 9443402e4e684..509dcfee0ca65 100644 --- a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py +++ b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py @@ -19,8 +19,12 @@ from unittest.mock import MagicMock, patch 
+import pytest + +from airflow.exceptions import TaskDeferred from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.operators.emr import EmrTerminateJobFlowOperator +from airflow.providers.amazon.aws.triggers.emr import EmrTerminateJobFlowTrigger TERMINATE_SUCCESS_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 200}} @@ -48,3 +52,22 @@ def test_execute_terminates_the_job_flow_and_does_not_error(self, _): ) operator.execute(MagicMock()) + + @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri") + def test_create_job_flow_deferrable(self, _): + with patch("boto3.session.Session", self.boto3_session_mock), patch( + "airflow.providers.amazon.aws.hooks.base_aws.isinstance" + ) as mock_isinstance: + mock_isinstance.return_value = True + operator = EmrTerminateJobFlowOperator( + task_id="test_task", + job_flow_id="j-8989898989", + aws_conn_id="aws_default", + deferrable=True, + ) + with pytest.raises(TaskDeferred) as exc: + operator.execute(MagicMock()) + + assert isinstance( + exc.value.trigger, EmrTerminateJobFlowTrigger + ), "Trigger is not a EmrTerminateJobFlowTrigger" diff --git a/tests/providers/amazon/aws/triggers/test_emr.py b/tests/providers/amazon/aws/triggers/test_emr.py index c749c4ee9abcf..5d599801a280a 100644 --- a/tests/providers/amazon/aws/triggers/test_emr.py +++ b/tests/providers/amazon/aws/triggers/test_emr.py @@ -24,7 +24,7 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.emr import EmrHook -from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger +from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger, EmrTerminateJobFlowTrigger from airflow.triggers.base import TriggerEvent TEST_JOB_FLOW_ID = "test-job-flow-id" @@ -198,3 +198,155 @@ async def test_emr_create_job_flow_trigger_run_attempts_failed( assert str(exc.value) == f"JobFlow creation failed: {error_failed}" assert mock_get_waiter().wait.call_count == 3 + + +class 
TestEmrTerminateJobFlowTrigger: + def test_emr_terminate_job_flow_trigger_serialize(self): + emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + class_path, args = emr_terminate_job_flow_trigger.serialize() + assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrTerminateJobFlowTrigger" + assert args["job_flow_id"] == TEST_JOB_FLOW_ID + assert args["aws_conn_id"] == TEST_AWS_CONN_ID + assert args["poll_interval"] == str(TEST_POLL_INTERVAL) + assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS) + + @pytest.mark.asyncio + @mock.patch.object(EmrHook, "get_waiter") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_terminate_job_flow_trigger_run(self, mock_async_conn, mock_get_waiter): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + + mock_get_waiter().wait = AsyncMock() + + emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + + generator = emr_terminate_job_flow_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent( + { + "status": "success", + "message": "JobFlow terminated successfully", + } + ) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EmrHook, "get_waiter") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_terminate_job_flow_trigger_run_multiple_attempts( + self, mock_async_conn, mock_get_waiter, mock_sleep + ): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={ + "Cluster": {"Status": {"State": "TERMINATING", "StateChangeReason": "test-reason"}} + }, + ) + mock_get_waiter().wait.side_effect = 
AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + + generator = emr_terminate_job_flow_trigger.run() + response = await generator.asend(None) + + assert mock_get_waiter().wait.call_count == 3 + assert response == TriggerEvent( + { + "status": "success", + "message": "JobFlow terminated successfully", + } + ) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EmrHook, "get_waiter") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_terminate_job_flow_trigger_run_attempts_exceeded( + self, mock_async_conn, mock_get_waiter, mock_sleep + ): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={ + "Cluster": {"Status": {"State": "TERMINATING", "StateChangeReason": "test-reason"}} + }, + ) + mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=2, + ) + with pytest.raises(AirflowException) as exc: + generator = emr_terminate_job_flow_trigger.run() + await generator.asend(None) + + assert str(exc.value) == "JobFlow termination failed - max attempts reached: 2" + assert mock_get_waiter().wait.call_count == 2 + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EmrHook, "get_waiter") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_terminate_job_flow_trigger_run_attempts_failed( + self, mock_async_conn, mock_get_waiter, mock_sleep + ): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + 
error_starting = WaiterError( + name="test_name", + reason="test_reason", + last_response={ + "Cluster": {"Status": {"State": "TERMINATING", "StateChangeReason": "test-reason"}} + }, + ) + error_failed = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={ + "Cluster": {"Status": {"State": "TERMINATED_WITH_ERRORS", "StateChangeReason": "test-reason"}} + }, + ) + mock_get_waiter().wait.side_effect = AsyncMock( + side_effect=[error_starting, error_starting, error_failed] + ) + mock_sleep.return_value = True + + emr_terminate_job_flow_trigger = EmrTerminateJobFlowTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + with pytest.raises(AirflowException) as exc: + generator = emr_terminate_job_flow_trigger.run() + await generator.asend(None) + + assert str(exc.value) == f"JobFlow termination failed: {error_failed}" + assert mock_get_waiter().wait.call_count == 3