Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions airflow/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Iterable, Sequence

Expand All @@ -25,6 +26,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
from airflow.providers.amazon.aws.triggers.emr import EmrContainerSensorTrigger
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
Expand Down Expand Up @@ -241,6 +243,7 @@ class EmrContainerSensor(BaseSensorOperator):
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
:param poll_interval: Time in seconds to wait between two consecutive call to
check query status on athena, defaults to 10
:param deferrable: Run sensor in the deferrable mode.
"""

INTERMEDIATE_STATES = (
Expand All @@ -267,6 +270,7 @@ def __init__(
max_retries: int | None = None,
aws_conn_id: str = "aws_default",
poll_interval: int = 10,
deferrable: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -275,6 +279,11 @@ def __init__(
self.job_id = job_id
self.poll_interval = poll_interval
self.max_retries = max_retries
self.deferrable = deferrable

@cached_property
def hook(self) -> EmrContainerHook:
return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)

def poke(self, context: Context) -> bool:
state = self.hook.poll_query_status(
Expand All @@ -290,10 +299,31 @@ def poke(self, context: Context) -> bool:
return False
return True

@cached_property
def hook(self) -> EmrContainerHook:
"""Create and return an EmrContainerHook."""
return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
def execute(self, context: Context):
if not self.deferrable:
super().execute(context=context)
else:
timeout = (
timedelta(seconds=self.max_retries * self.poll_interval + 60)
if self.max_retries
else self.execution_timeout
)
self.defer(
timeout=timeout,
trigger=EmrContainerSensorTrigger(
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
poll_interval=self.poll_interval,
),
method_name="execute_complete",
)

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
else:
self.log.info(event["message"])


class EmrNotebookExecutionSensor(EmrBaseSensor):
Expand Down
75 changes: 73 additions & 2 deletions airflow/providers/amazon/aws/triggers/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
from __future__ import annotations

import asyncio
from typing import Any
from functools import cached_property
from typing import Any, AsyncIterator

from botocore.exceptions import WaiterError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrHook
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.helpers import prune_dict

Expand Down Expand Up @@ -246,3 +247,73 @@ async def run(self):
"message": "JobFlow terminated successfully",
}
)


class EmrContainerSensorTrigger(BaseTrigger):
"""
Poll for the status of EMR container until reaches terminal state.

:param virtual_cluster_id: Reference Emr cluster id
:param job_id: job_id to check the state
:param aws_conn_id: Reference to AWS connection id
:param poll_interval: polling period in seconds to check for the status
"""

def __init__(
self,
virtual_cluster_id: str,
job_id: str,
aws_conn_id: str = "aws_default",
poll_interval: int = 30,
**kwargs: Any,
):
self.virtual_cluster_id = virtual_cluster_id
self.job_id = job_id
self.aws_conn_id = aws_conn_id
self.poll_interval = poll_interval
super().__init__(**kwargs)

@cached_property
def hook(self) -> EmrContainerHook:
return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes EmrContainerSensorTrigger arguments and classpath."""
return (
"airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger",
{
"virtual_cluster_id": self.virtual_cluster_id,
"job_id": self.job_id,
"aws_conn_id": self.aws_conn_id,
"poll_interval": self.poll_interval,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
async with self.hook.async_conn as client:
waiter = self.hook.get_waiter("container_job_complete", deferrable=True, client=client)
attempt = 0
while True:
attempt = attempt + 1
try:
await waiter.wait(
id=self.job_id,
virtualClusterId=self.virtual_cluster_id,
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": 1,
},
)
break
except WaiterError as error:
if "terminal failure" in str(error):
yield TriggerEvent({"status": "failure", "message": f"Job Failed: {error}"})
break
self.log.info(
"Job status is %s. Retrying attempt %s",
error.last_response["jobRun"]["state"],
attempt,
)
await asyncio.sleep(int(self.poll_interval))

yield TriggerEvent({"status": "success", "job_id": self.job_id})
30 changes: 30 additions & 0 deletions airflow/providers/amazon/aws/waiters/emr-containers.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"version": 2,
"waiters": {
"container_job_complete": {
"operation": "DescribeJobRun",
"delay": 30,
"maxAttempts": 60,
"acceptors": [
{
"matcher": "path",
"argument": "jobRun.state",
"expected": "COMPLETED",
"state": "success"
},
{
"matcher": "path",
"argument": "jobRun.state",
"expected": "FAILED",
"state": "failure"
},
{
"matcher": "path",
"argument": "jobRun.state",
"expected": "CANCELLED",
"state": "failure"
}
]
}
}
}
13 changes: 12 additions & 1 deletion tests/providers/amazon/aws/sensors/test_emr_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@

import pytest

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor
from airflow.providers.amazon.aws.triggers.emr import EmrContainerSensorTrigger


class TestEmrContainerSensor:
Expand Down Expand Up @@ -73,3 +74,13 @@ def test_poke_cancel_pending(self, mock_check_query_status):
with pytest.raises(AirflowException) as ctx:
self.sensor.poke(None)
assert "EMR Containers sensor failed" in str(ctx.value)

@mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor.poke")
def test_sensor_defer(self, mock_poke):
self.sensor.deferrable = True
mock_poke.return_value = False
with pytest.raises(TaskDeferred) as exc:
self.sensor.execute(context=None)
assert isinstance(
exc.value.trigger, EmrContainerSensorTrigger
), "Trigger is not a EmrContainerSensorTrigger"
112 changes: 111 additions & 1 deletion tests/providers/amazon/aws/triggers/test_emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,21 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrHook
from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger, EmrTerminateJobFlowTrigger
from airflow.providers.amazon.aws.triggers.emr import (
EmrContainerSensorTrigger,
EmrCreateJobFlowTrigger,
EmrTerminateJobFlowTrigger,
)
from airflow.triggers.base import TriggerEvent

TEST_JOB_FLOW_ID = "test-job-flow-id"
TEST_POLL_INTERVAL = 10
TEST_MAX_ATTEMPTS = 10
TEST_AWS_CONN_ID = "test-aws-id"
VIRTUAL_CLUSTER_ID = "vzwemreks"
JOB_ID = "job-1234"
AWS_CONN_ID = "aws_emr_conn"
POLL_INTERVAL = 60


class TestEmrCreateJobFlowTrigger:
Expand Down Expand Up @@ -350,3 +358,105 @@ async def test_emr_terminate_job_flow_trigger_run_attempts_failed(

assert str(exc.value) == f"JobFlow termination failed: {error_failed}"
assert mock_get_waiter().wait.call_count == 3


class TestEmrContainerSensorTrigger:
def test_emr_container_sensor_trigger_serialize(self):
emr_trigger = EmrContainerSensorTrigger(
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
aws_conn_id=AWS_CONN_ID,
poll_interval=POLL_INTERVAL,
)
class_path, args = emr_trigger.serialize()
assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger"
assert args["virtual_cluster_id"] == VIRTUAL_CLUSTER_ID
assert args["job_id"] == JOB_ID
assert args["aws_conn_id"] == AWS_CONN_ID
assert args["poll_interval"] == POLL_INTERVAL

@pytest.mark.asyncio
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
async def test_emr_container_trigger_run(self, mock_async_conn, mock_get_waiter):
a_mock = mock.MagicMock()
mock_async_conn.__aenter__.return_value = a_mock

mock_get_waiter().wait = AsyncMock()

emr_trigger = EmrContainerSensorTrigger(
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
aws_conn_id=AWS_CONN_ID,
poll_interval=POLL_INTERVAL,
)

generator = emr_trigger.run()
response = await generator.asend(None)

assert response == TriggerEvent({"status": "success", "job_id": JOB_ID})

@pytest.mark.asyncio
@mock.patch("asyncio.sleep")
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
async def test_emr_trigger_run_multiple_attempts(self, mock_async_conn, mock_get_waiter, mock_sleep):
a_mock = mock.MagicMock()
mock_async_conn.__aenter__.return_value = a_mock

error = WaiterError(
name="test_name",
reason="test_reason",
last_response={"jobRun": {"state": "RUNNING"}},
)
mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True])
mock_sleep.return_value = True

emr_trigger = EmrContainerSensorTrigger(
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
aws_conn_id=AWS_CONN_ID,
poll_interval=POLL_INTERVAL,
)

generator = emr_trigger.run()
response = await generator.asend(None)

assert mock_get_waiter().wait.call_count == 3
assert response == TriggerEvent({"status": "success", "job_id": JOB_ID})

@pytest.mark.asyncio
@mock.patch("asyncio.sleep")
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
async def test_emr_trigger_run_attempts_failed(self, mock_async_conn, mock_get_waiter, mock_sleep):
a_mock = mock.MagicMock()
mock_async_conn.__aenter__.return_value = a_mock

error_available = WaiterError(
name="test_name",
reason="Max attempts exceeded",
last_response={"jobRun": {"state": "FAILED"}},
)
error_failed = WaiterError(
name="test_name",
reason="Waiter encountered a terminal failure state",
last_response={"jobRun": {"state": "FAILED"}},
)
mock_get_waiter().wait.side_effect = AsyncMock(
side_effect=[error_available, error_available, error_failed]
)
mock_sleep.return_value = True

emr_trigger = EmrContainerSensorTrigger(
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
aws_conn_id=AWS_CONN_ID,
poll_interval=POLL_INTERVAL,
)

generator = emr_trigger.run()
response = await generator.asend(None)

assert mock_get_waiter().wait.call_count == 3
assert response == TriggerEvent({"status": "failure", "message": f"Job Failed: {error_failed}"})