Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
from airflow.providers.amazon.aws.triggers.emr import (
EmrAddStepsTrigger,
EmrContainerTrigger,
EmrCreateJobFlowTrigger,
EmrTerminateJobFlowTrigger,
)
Expand Down Expand Up @@ -480,6 +481,7 @@ class EmrContainerOperator(BaseOperator):
Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
:param tags: The tags assigned to job runs.
Defaults to None
:param deferrable: Run operator in the deferrable mode.
"""

template_fields: Sequence[str] = (
Expand Down Expand Up @@ -508,6 +510,7 @@ def __init__(
max_tries: int | None = None,
tags: dict | None = None,
max_polling_attempts: int | None = None,
deferrable: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -524,6 +527,7 @@ def __init__(
self.max_polling_attempts = max_polling_attempts
self.tags = tags
self.job_id: str | None = None
self.deferrable = deferrable

if max_tries:
warnings.warn(
Expand Down Expand Up @@ -556,27 +560,57 @@ def execute(self, context: Context) -> str | None:
self.client_request_token,
self.tags,
)
if self.deferrable:
query_status = self.hook.check_query_status(job_id=self.job_id)
self.check_failure(query_status)
if query_status in EmrContainerHook.SUCCESS_STATES:
return self.job_id
timeout = (
timedelta(seconds=self.max_polling_attempts * self.poll_interval)
if self.max_polling_attempts
else self.execution_timeout
)
self.defer(
timeout=timeout,
trigger=EmrContainerTrigger(
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
poll_interval=self.poll_interval,
),
method_name="execute_complete",
)
if self.wait_for_completion:
query_status = self.hook.poll_query_status(
self.job_id,
max_polling_attempts=self.max_polling_attempts,
poll_interval=self.poll_interval,
)

if query_status in EmrContainerHook.FAILURE_STATES:
error_message = self.hook.get_job_failure_reason(self.job_id)
raise AirflowException(
f"EMR Containers job failed. Final state is {query_status}. "
f"query_execution_id is {self.job_id}. Error: {error_message}"
)
elif not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES:
self.check_failure(query_status)
if not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES:
raise AirflowException(
f"Final state of EMR Containers job is {query_status}. "
f"Max tries of poll status exceeded, query_execution_id is {self.job_id}."
)

return self.job_id

def check_failure(self, query_status):
    """
    Raise if the job run ended in one of the hook's failure states.

    :param query_status: job state to inspect.
    :raises AirflowException: when ``query_status`` is listed in
        ``EmrContainerHook.FAILURE_STATES``; the message includes the failure
        reason fetched from the hook for ``self.job_id``.
    """
    if query_status not in EmrContainerHook.FAILURE_STATES:
        return

    failure_reason = self.hook.get_job_failure_reason(self.job_id)
    raise AirflowException(
        f"EMR Containers job failed. Final state is {query_status}. "
        f"query_execution_id is {self.job_id}. Error: {failure_reason}"
    )

def execute_complete(self, context, event=None):
    """
    Resume after the trigger has fired and return the job run id.

    :param context: task context (unused here).
    :param event: payload of the trigger event; a successful one carries
        ``status`` and ``job_id``.
    :return: the ``job_id`` reported by the event.
    :raises AirflowException: if no event was delivered or its status is not
        ``success``.
    """
    # ``event`` defaults to None, so guard before reading keys from it.
    if not event or event.get("status") != "success":
        raise AirflowException(f"Error while running job: {event}")

    # The trigger's success event has no "message" key; indexing it
    # unconditionally would fail a successful run with KeyError.
    message = event.get("message")
    if message:
        self.log.info("%s", message)
    return event["job_id"]

def on_kill(self) -> None:
"""Cancel the submitted job run."""
if self.job_id:
Expand Down
68 changes: 64 additions & 4 deletions airflow/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
from airflow.providers.amazon.aws.triggers.emr import EmrContainerSensorTrigger
from airflow.providers.amazon.aws.triggers.emr import (
EmrContainerTrigger,
EmrStepSensorTrigger,
EmrTerminateJobFlowTrigger,
)
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
Expand Down Expand Up @@ -310,7 +314,7 @@ def execute(self, context: Context):
)
self.defer(
timeout=timeout,
trigger=EmrContainerSensorTrigger(
trigger=EmrContainerTrigger(
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
Expand Down Expand Up @@ -406,9 +410,12 @@ class EmrJobFlowSensor(EmrBaseSensor):

:param job_flow_id: job_flow_id to check the state of
:param target_states: the target states, sensor waits until
job flow reaches any of these states
job flow reaches any of these states. In deferrable mode it will
run until the job flow reaches a terminal state.
:param failed_states: the failure states, sensor fails when
job flow reaches any of these states
:param max_attempts: Maximum number of tries before failing
:param deferrable: Run sensor in the deferrable mode.
"""

template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states")
Expand All @@ -424,12 +431,16 @@ def __init__(
job_flow_id: str,
target_states: Iterable[str] | None = None,
failed_states: Iterable[str] | None = None,
max_attempts: int = 60,
deferrable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.job_flow_id = job_flow_id
self.target_states = target_states or ["TERMINATED"]
self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"]
self.max_attempts = max_attempts
self.deferrable = deferrable

def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -488,6 +499,26 @@ def failure_message_from_response(response: dict[str, Any]) -> str | None:
)
return None

def execute(self, context: Context) -> None:
    """
    Run the sensor, deferring the wait to a trigger in deferrable mode.

    In non-deferrable mode this delegates to the parent ``execute``. In
    deferrable mode a single ``poke`` is done first; the task is deferred
    only when that poke is falsy.

    :param context: task context passed through to ``poke``.
    """
    if not self.deferrable:
        super().execute(context=context)
        return

    # Nothing to defer when the first check already succeeds.
    if self.poke(context):
        return

    max_wait = timedelta(seconds=self.poke_interval * self.max_attempts)
    trigger = EmrTerminateJobFlowTrigger(
        job_flow_id=self.job_flow_id,
        max_attempts=self.max_attempts,
        aws_conn_id=self.aws_conn_id,
        poll_interval=int(self.poke_interval),
    )
    self.defer(timeout=max_wait, trigger=trigger, method_name="execute_complete")

def execute_complete(self, context, event=None):
    """
    Callback invoked when the deferred trigger fires.

    :param context: task context (unused here).
    :param event: payload of the trigger event; expected to hold a ``status`` key.
    :raises AirflowException: if no event was delivered or its status is not
        ``success``.
    """
    # ``event`` defaults to None, so guard before reading keys from it;
    # otherwise a missing event surfaces as an unrelated TypeError.
    if not event or event.get("status") != "success":
        raise AirflowException(f"Error while running job: {event}")
    self.log.info("Job completed.")


class EmrStepSensor(EmrBaseSensor):
"""
Expand All @@ -503,9 +534,12 @@ class EmrStepSensor(EmrBaseSensor):
:param job_flow_id: job_flow_id which contains the step check the state of
:param step_id: step to check the state of
:param target_states: the target states, sensor waits until
step reaches any of these states
step reaches any of these states. In deferrable mode it will
wait for the step to reach a terminal state.
:param failed_states: the failure states, sensor fails when
step reaches any of these states
:param max_attempts: Maximum number of tries before failing
:param deferrable: Run sensor in the deferrable mode.
"""

template_fields: Sequence[str] = ("job_flow_id", "step_id", "target_states", "failed_states")
Expand All @@ -522,13 +556,17 @@ def __init__(
step_id: str,
target_states: Iterable[str] | None = None,
failed_states: Iterable[str] | None = None,
max_attempts: int = 60,
deferrable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.job_flow_id = job_flow_id
self.step_id = step_id
self.target_states = target_states or ["COMPLETED"]
self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"]
self.max_attempts = max_attempts
self.deferrable = deferrable

def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -587,3 +625,25 @@ def failure_message_from_response(response: dict[str, Any]) -> str | None:
f"with message {fail_details.get('Message')} and log file {fail_details.get('LogFile')}"
)
return None

def execute(self, context: Context) -> None:
    """
    Run the sensor, deferring the wait to a trigger in deferrable mode.

    In non-deferrable mode this delegates to the parent ``execute``. In
    deferrable mode a single ``poke`` is done first; the task is deferred
    only when that poke is falsy.

    :param context: task context passed through to ``poke``.
    """
    if not self.deferrable:
        super().execute(context=context)
        return

    # Nothing to defer when the first check already succeeds.
    if self.poke(context):
        return

    max_wait = timedelta(seconds=self.max_attempts * self.poke_interval)
    trigger = EmrStepSensorTrigger(
        job_flow_id=self.job_flow_id,
        step_id=self.step_id,
        aws_conn_id=self.aws_conn_id,
        max_attempts=self.max_attempts,
        poke_interval=int(self.poke_interval),
    )
    self.defer(timeout=max_wait, trigger=trigger, method_name="execute_complete")

def execute_complete(self, context, event=None):
    """
    Callback invoked when the deferred trigger fires.

    :param context: task context (unused here).
    :param event: payload of the trigger event; expected to hold a ``status`` key.
    :raises AirflowException: if no event was delivered or its status is not
        ``success``.
    """
    # ``event`` defaults to None, so guard before reading keys from it;
    # otherwise a missing event surfaces as an unrelated TypeError.
    if not event or event.get("status") != "success":
        raise AirflowException(f"Error while running job: {event}")

    self.log.info("Job completed.")
67 changes: 64 additions & 3 deletions airflow/providers/amazon/aws/triggers/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.helpers import prune_dict

Expand Down Expand Up @@ -249,7 +250,7 @@ async def run(self):
)


class EmrContainerSensorTrigger(BaseTrigger):
class EmrContainerTrigger(BaseTrigger):
"""
Poll for the status of EMR container until reaches terminal state.

Expand Down Expand Up @@ -278,9 +279,9 @@ def hook(self) -> EmrContainerHook:
return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes EmrContainerSensorTrigger arguments and classpath."""
"""Serializes EmrContainerTrigger arguments and classpath."""
return (
"airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger",
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger",
{
"virtual_cluster_id": self.virtual_cluster_id,
"job_id": self.job_id,
Expand Down Expand Up @@ -317,3 +318,63 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
await asyncio.sleep(int(self.poll_interval))

yield TriggerEvent({"status": "success", "job_id": self.job_id})


class EmrStepSensorTrigger(BaseTrigger):
    """
    Poll for the status of an EMR step until it reaches a terminal state.

    :param job_flow_id: job_flow_id which contains the step check the state of
    :param step_id: step to check the state of
    :param aws_conn_id: Reference to AWS connection id
    :param max_attempts: The maximum number of attempts to be made
    :param poke_interval: polling period in seconds to check for the status
    """

    def __init__(
        self,
        job_flow_id: str,
        step_id: str,
        aws_conn_id: str = "aws_default",
        max_attempts: int = 60,
        poke_interval: int = 30,
        **kwargs: Any,
    ):
        self.job_flow_id = job_flow_id
        self.step_id = step_id
        self.aws_conn_id = aws_conn_id
        self.max_attempts = max_attempts
        self.poke_interval = poke_interval
        super().__init__(**kwargs)

    @cached_property
    def hook(self) -> EmrHook:
        """EMR hook bound to the configured AWS connection."""
        return EmrHook(self.aws_conn_id)

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize EmrStepSensorTrigger arguments and classpath."""
        return (
            "airflow.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger",
            {
                "job_flow_id": self.job_flow_id,
                "step_id": self.step_id,
                "aws_conn_id": self.aws_conn_id,
                "max_attempts": self.max_attempts,
                "poke_interval": self.poke_interval,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Wait on the ``step_wait_for_terminal`` waiter, then emit a success event."""
        async with self.hook.async_conn as client:
            # "step_wait_for_terminal" is a custom waiter, so it has to be
            # resolved through the hook (which accepts ``deferrable`` and
            # ``client``), not through the raw client's get_waiter().
            waiter = self.hook.get_waiter("step_wait_for_terminal", deferrable=True, client=client)
            await async_wait(
                waiter=waiter,
                waiter_delay=self.poke_interval,
                waiter_max_attempts=self.max_attempts,
                args={"ClusterId": self.job_flow_id, "StepId": self.step_id},
                failure_message=f"Error while waiting for step {self.step_id} to complete",
                status_message=f"Step id: {self.step_id}, Step is still in non-terminal state",
                status_args=["Step.Status.State"],
            )

        yield TriggerEvent({"status": "success"})
31 changes: 31 additions & 0 deletions airflow/providers/amazon/aws/waiters/emr.json
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,37 @@
"state": "failure"
}
]
},
"step_wait_for_terminal": {
"operation": "DescribeStep",
"delay": 30,
"maxAttempts": 60,
"acceptors": [
{
"matcher": "path",
"argument": "Step.Status.State",
"expected": "COMPLETED",
"state": "success"
},
{
"matcher": "path",
"argument": "Step.Status.State",
"expected": "CANCELLED",
"state": "failure"
},
{
"matcher": "path",
"argument": "Step.Status.State",
"expected": "FAILED",
"state": "failure"
},
{
"matcher": "path",
"argument": "Step.Status.State",
"expected": "INTERRUPTED",
"state": "failure"
}
]
}
}
}
8 changes: 7 additions & 1 deletion tests/providers/amazon/aws/hooks/test_emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ class TestEmrHook:
def test_service_waiters(self):
    """The hook must list the official EMR waiters plus every custom waiter."""
    hook = EmrHook(aws_conn_id=None)
    official_waiters = hook.conn.waiter_names
    # Single definition of the custom waiter list; the earlier four-item
    # assignment was immediately overwritten and has been dropped.
    custom_waiters = [
        "job_flow_waiting",
        "job_flow_terminated",
        "notebook_running",
        "notebook_stopped",
        "step_wait_for_terminal",
    ]

    assert sorted(hook.list_waiters()) == sorted([*official_waiters, *custom_waiters])

Expand Down
Loading