48 changes: 45 additions & 3 deletions airflow/providers/amazon/aws/operators/emr.py
@@ -28,7 +28,11 @@
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger, EmrCreateJobFlowTrigger
from airflow.providers.amazon.aws.triggers.emr import (
EmrAddStepsTrigger,
EmrCreateJobFlowTrigger,
EmrTerminateJobFlowTrigger,
)
from airflow.providers.amazon.aws.utils.waiter import waiter
from airflow.utils.helpers import exactly_one, prune_dict
from airflow.utils.types import NOTSET, ArgNotSet
@@ -842,6 +846,11 @@ class EmrTerminateJobFlowOperator(BaseOperator):

:param job_flow_id: id of the JobFlow to terminate. (templated)
:param aws_conn_id: aws connection to use
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check JobFlow status
:param waiter_max_attempts: The maximum number of times to poll for JobFlow status.
:param deferrable: If True, the operator will wait asynchronously for the JobFlow to terminate.
This implies waiting for completion. This mode requires the aiobotocore module to be installed.
(default: False)
"""

template_fields: Sequence[str] = ("job_flow_id",)
@@ -852,10 +861,22 @@ class EmrTerminateJobFlowOperator(BaseOperator):
EmrLogsLink(),
)

def __init__(self, *, job_flow_id: str, aws_conn_id: str = "aws_default", **kwargs):
def __init__(
self,
*,
job_flow_id: str,
aws_conn_id: str = "aws_default",
waiter_delay: int = 60,
waiter_max_attempts: int = 20,
deferrable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.job_flow_id = job_flow_id
self.aws_conn_id = aws_conn_id
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

def execute(self, context: Context) -> None:
emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
@@ -883,7 +904,28 @@ def execute(self, context: Context) -> None:
if not response["ResponseMetadata"]["HTTPStatusCode"] == 200:
raise AirflowException(f"JobFlow termination failed: {response}")
else:
self.log.info("JobFlow with id %s terminated", self.job_flow_id)
self.log.info("Terminating JobFlow with id %s", self.job_flow_id)

if self.deferrable:
self.defer(
trigger=EmrTerminateJobFlowTrigger(
job_flow_id=self.job_flow_id,
poll_interval=self.waiter_delay,
max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
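# e.g. with the defaults above (waiter_max_attempts=20, waiter_delay=60) this is 20 * 60 + 60 = 1260 seconds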
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error terminating JobFlow: {event}")
else:
self.log.info("Jobflow terminated successfully.")
return


class EmrServerlessCreateApplicationOperator(BaseOperator):
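As a usage sketch (not part of this PR), the deferrable path above could be exercised from a DAG like the following; the DAG id, schedule, and cluster ID are hypothetical:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.emr import EmrTerminateJobFlowOperator

    with DAG(
        dag_id="example_emr_terminate_deferrable",  # hypothetical DAG id
        start_date=datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ):
        # With deferrable=True, the task calls terminate_job_flows, then frees its
        # worker slot and resumes in execute_complete once the trigger fires.
        EmrTerminateJobFlowOperator(
            task_id="terminate_cluster",
            job_flow_id="j-XXXXXXXXXXXXX",  # hypothetical cluster ID
            waiter_delay=30,
            waiter_max_attempts=60,
            deferrable=True,
        )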
73 changes: 73 additions & 0 deletions airflow/providers/amazon/aws/triggers/emr.py
@@ -173,3 +173,76 @@ async def run(self):
"job_flow_id": self.job_flow_id,
}
)


class EmrTerminateJobFlowTrigger(BaseTrigger):
"""
Trigger that waits for an EMR Job Flow to terminate.
The trigger asynchronously polls the boto3 API and waits for the
JobFlow to finish terminating.

:param job_flow_id: ID of the EMR Job Flow to terminate
:param poll_interval: The amount of time in seconds to wait between attempts.
:param max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(
self,
job_flow_id: str,
poll_interval: int,
max_attempts: int,
aws_conn_id: str,
):
self.job_flow_id = job_flow_id
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.aws_conn_id = aws_conn_id

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
self.__class__.__module__ + "." + self.__class__.__qualname__,
{
"job_flow_id": self.job_flow_id,
"poll_interval": str(self.poll_interval),
"max_attempts": str(self.max_attempts),
"aws_conn_id": self.aws_conn_id,
},
)

async def run(self):
self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
async with self.hook.async_conn as client:
attempt = 0
waiter = self.hook.get_waiter("job_flow_terminated", deferrable=True, client=client)
while attempt < int(self.max_attempts):
attempt = attempt + 1
try:
await waiter.wait(
ClusterId=self.job_flow_id,
WaiterConfig=prune_dict(
{
"Delay": self.poll_interval,
"MaxAttempts": 1,
}
),
)
break
except WaiterError as error:
if "terminal failure" in str(error):
raise AirflowException(f"JobFlow termination failed: {error}")
self.log.info(
"Status of jobflow is %s - %s",
error.last_response["Cluster"]["Status"]["State"],
error.last_response["Cluster"]["Status"]["StateChangeReason"],
)
await asyncio.sleep(int(self.poll_interval))
if attempt >= int(self.max_attempts):
raise AirflowException(f"JobFlow termination failed - max attempts reached: {self.max_attempts}")
else:
yield TriggerEvent(
{
"status": "success",
"message": "JobFlow terminated successfully",
}
)
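One detail worth noting in serialize() above: poll_interval and max_attempts are stored as strings, which is why run() casts them back with int(). A minimal round-trip sketch (not part of this PR) of how the triggerer re-instantiates the trigger from its serialized form:

    from airflow.providers.amazon.aws.triggers.emr import EmrTerminateJobFlowTrigger

    trigger = EmrTerminateJobFlowTrigger(
        job_flow_id="j-XXXXXXXXXXXXX",  # hypothetical cluster ID
        poll_interval=30,
        max_attempts=60,
        aws_conn_id="aws_default",
    )

    # serialize() returns (classpath, kwargs); the triggerer imports the classpath
    # and re-instantiates the trigger with the kwargs in its own process.
    classpath, kwargs = trigger.serialize()
    assert classpath.endswith("EmrTerminateJobFlowTrigger")
    rehydrated = EmrTerminateJobFlowTrigger(**kwargs)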
19 changes: 19 additions & 0 deletions airflow/providers/amazon/aws/waiters/emr.json
@@ -75,6 +75,25 @@
"state": "failure"
}
]
},
"job_flow_terminated": {
"operation": "DescribeCluster",
"delay": 30,
"maxAttempts": 60,
"acceptors": [
{
"matcher": "path",
"argument": "Cluster.Status.State",
"expected": "TERMINATED",
"state": "success"
},
{
"matcher": "path",
"argument": "Cluster.Status.State",
"expected": "TERMINATED_WITH_ERRORS",
"state": "failure"
}
]
}
}
}
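Outside of the trigger, the new waiter can also be driven synchronously through the hook — a sketch assuming the hook's get_waiter supports plain synchronous use (the cluster ID is hypothetical):

    from airflow.providers.amazon.aws.hooks.emr import EmrHook

    hook = EmrHook(aws_conn_id="aws_default")

    # Polls DescribeCluster until Cluster.Status.State is TERMINATED; raises a
    # WaiterError if the state becomes TERMINATED_WITH_ERRORS, per the acceptors above.
    hook.get_waiter("job_flow_terminated").wait(
        ClusterId="j-XXXXXXXXXXXXX",  # hypothetical cluster ID
        WaiterConfig={"Delay": 30, "MaxAttempts": 60},
    )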
4 changes: 4 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/emr/emr.rst
@@ -111,6 +111,10 @@ Terminate an EMR job flow

To terminate an EMR Job Flow you can use
:class:`~airflow.providers.amazon.aws.operators.emr.EmrTerminateJobFlowOperator`.
This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter.
Using ``deferrable`` mode releases the worker slot while waiting, leading to more efficient
utilization of resources within the Airflow cluster. However, this mode requires the Airflow
triggerer to be available in your deployment.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr.py
:language: python
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/hooks/test_emr.py
@@ -32,9 +32,9 @@ class TestEmrHook:
def test_service_waiters(self):
hook = EmrHook(aws_conn_id=None)
official_waiters = hook.conn.waiter_names
custom_waiters = ["job_flow_waiting", "notebook_running", "notebook_stopped"]
custom_waiters = ["job_flow_waiting", "job_flow_terminated", "notebook_running", "notebook_stopped"]

assert hook.list_waiters() == [*official_waiters, *custom_waiters]
assert sorted(hook.list_waiters()) == sorted([*official_waiters, *custom_waiters])

@mock_emr
def test_get_conn_returns_a_boto3_connection(self):
@@ -19,8 +19,12 @@

from unittest.mock import MagicMock, patch

import pytest

from airflow.exceptions import TaskDeferred
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.emr import EmrTerminateJobFlowOperator
from airflow.providers.amazon.aws.triggers.emr import EmrTerminateJobFlowTrigger

TERMINATE_SUCCESS_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 200}}

@@ -48,3 +52,22 @@ def test_execute_terminates_the_job_flow_and_does_not_error(self, _):
)

operator.execute(MagicMock())

@patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
def test_terminate_job_flow_deferrable(self, _):
with patch("boto3.session.Session", self.boto3_session_mock), patch(
"airflow.providers.amazon.aws.hooks.base_aws.isinstance"
) as mock_isinstance:
mock_isinstance.return_value = True
operator = EmrTerminateJobFlowOperator(
task_id="test_task",
job_flow_id="j-8989898989",
aws_conn_id="aws_default",
deferrable=True,
)
with pytest.raises(TaskDeferred) as exc:
operator.execute(MagicMock())

assert isinstance(
exc.value.trigger, EmrTerminateJobFlowTrigger
), "Trigger is not a EmrTerminateJobFlowTrigger"