Skip to content

Commit

Permalink
Airflow operator - raise structured exceptions and retrieve job termi…
Browse files Browse the repository at this point in the history
…nation reason

Co-authored-by: Martynas Asipauskas <[email protected]>
  • Loading branch information
masipauskas and Martynas Asipauskas authored Nov 25, 2024
1 parent 39ea0d3 commit 7313fb8
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 11 deletions.
12 changes: 12 additions & 0 deletions third_party/airflow/armada/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ def submit_job(

return RunningJobContext(queue, job.job_id, job_set_id, DateTime.utcnow())

@tenacity.retry(
    wait=tenacity.wait_random_exponential(max=3),
    stop=tenacity.stop_after_attempt(5),
    reraise=True,
)
@log_exceptions
def job_termination_reason(self, job_context: RunningJobContext) -> str:
    """
    Look up the error recorded for the job described by ``job_context``.

    Asks the client for the errors of this single job id and picks the
    entry stored under that id. The call is wrapped in a tenacity retry
    (at most 5 attempts, randomised exponential wait capped at 3); once
    the attempts are exhausted the last exception is re-raised.

    :param job_context: Context carrying the ``job_id`` to look up.
    :return: The error text for the job, or ``""`` when the response has
        no entry (or an empty entry) for it.
    :rtype: str
    """
    job_id = job_context.job_id
    errors_by_job = self.client.get_job_errors([job_id]).job_errors
    reason = errors_by_job.get(job_id, "")

    # Normalise any falsy entry to an empty string so callers always
    # receive a ``str``.
    return reason if reason else ""

@tenacity.retry(
wait=tenacity.wait_random_exponential(max=3),
stop=tenacity.stop_after_attempt(5),
Expand Down
19 changes: 14 additions & 5 deletions third_party/airflow/armada/operators/armada.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import jinja2
import tenacity
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.serialization.serde import deserialize
Expand All @@ -40,6 +39,7 @@
from google.protobuf.json_format import MessageToDict, ParseDict
from pendulum import DateTime

from .errors import ArmadaOperatorJobFailedError
from ..hooks import ArmadaHook
from ..model import RunningJobContext
from ..triggers import ArmadaPollJobTrigger
Expand Down Expand Up @@ -349,9 +349,11 @@ def _running_job_terminated(self, context: RunningJobContext):
f"job {context.job_id} terminated with state: {context.state.name}"
)
if context.state != JobState.SUCCEEDED:
raise AirflowException(
f"job {context.job_id} did not succeed. "
f"Final status was {context.state.name}"
raise ArmadaOperatorJobFailedError(
context.armada_queue,
context.job_id,
context.state,
self.hook.job_termination_reason(context),
)

def _not_acknowledged_within_timeout(self) -> bool:
Expand All @@ -363,6 +365,13 @@ def _not_acknowledged_within_timeout(self) -> bool:
return True
return False

def _should_have_a_pod_in_k8s(self) -> bool:
    """
    Tell whether the current job state is RUNNING, FAILED or SUCCEEDED.

    Used as the gate for fetching container logs: only jobs in one of
    these states are treated as having a pod to read logs from.

    :return: ``True`` when ``self.job_context.state`` is one of the three
        states above, ``False`` otherwise.
    :rtype: bool
    """
    current_state = self.job_context.state
    states_with_pod = {JobState.RUNNING, JobState.FAILED, JobState.SUCCEEDED}
    return current_state in states_with_pod

@log_exceptions
def _check_job_status_and_fetch_logs(self, context) -> None:
self.job_context = self.hook.refresh_context(
Expand All @@ -377,7 +386,7 @@ def _check_job_status_and_fetch_logs(self, context) -> None:
self.job_context = self.hook.cancel_job(self.job_context)
return

if self.job_context.cluster and self.container_logs:
if self._should_have_a_pod_in_k8s() and self.container_logs:
try:
last_log_time = self.pod_manager.fetch_container_logs(
k8s_context=self.job_context.cluster,
Expand Down
50 changes: 50 additions & 0 deletions third_party/airflow/armada/operators/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from airflow.exceptions import AirflowException

from armada_client.typings import JobState


class ArmadaOperatorJobFailedError(AirflowException):
    """
    Raised when an ArmadaOperator job has terminated unsuccessfully on Armada.

    :param queue: The queue the job was submitted to.
    :type queue: str
    :param job_id: The unique identifier of the job.
    :type job_id: str
    :param state: The termination state of the job.
    :type state: JobState
    :param reason: The termination reason, if provided.
    :type reason: str
    """

    def __init__(self, queue: str, job_id: str, state: JobState, reason: str = ""):
        self.job_id = job_id
        self.queue = queue
        self.state = state
        self.reason = reason
        # Build the message only after all fields are set; it reads them.
        self.message = self._generate_message()
        super().__init__(self.message)

    def __reduce__(self):
        """
        Support pickling and copying of the exception.

        The default ``BaseException`` reduction rebuilds the object as
        ``cls(*self.args)``. ``args`` only holds the formatted message, so
        that call would fail with a ``TypeError`` for the missing ``job_id``
        and ``state`` arguments. Rebuild from the original constructor
        arguments instead and carry the instance ``__dict__`` along as state.

        :return: Reduction tuple ``(callable, args, state)``.
        :rtype: tuple
        """
        return (
            type(self),
            (self.queue, self.job_id, self.state, self.reason),
            self.__dict__,
        )

    def _generate_message(self) -> str:
        """
        Generate a user-friendly error message.

        The state is rendered as its enum name with only the first letter
        upper-cased (e.g. ``Failed``). The termination reason is appended
        only when it is non-empty.

        :return: Formatted error message with job details.
        :rtype: str
        """
        message = (
            f"ArmadaOperator job '{self.job_id}' in queue '{self.queue}'"
            f" terminated with state '{self.state.name.capitalize()}'."
        )
        if self.reason:
            message += f" Termination reason: {self.reason}"
        return message

    def __str__(self) -> str:
        """
        Return the error message when the exception is converted to a string.

        :return: The error message.
        :rtype: str
        """
        return self.message
2 changes: 1 addition & 1 deletion third_party/airflow/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ readme='README.md'
authors = [{name = "Armada-GROSS", email = "[email protected]"}]
license = { text = "Apache Software License" }
dependencies=[
'armada-client>=0.4.6',
'armada-client>=0.4.7',
'apache-airflow>=2.6.3',
'types-protobuf==4.24.0.1',
'kubernetes>=23.6.0',
Expand Down
11 changes: 6 additions & 5 deletions third_party/airflow/test/unit/operators/test_armada.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from unittest.mock import MagicMock, patch

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.exceptions import TaskDeferred
from armada.model import GrpcChannelArgs, RunningJobContext
from armada.operators.armada import ArmadaOperator
from armada.operators.errors import ArmadaOperatorJobFailedError
from armada.triggers import ArmadaPollJobTrigger
from armada_client.armada.submit_pb2 import JobSubmitRequestItem
from armada_client.typings import JobState
Expand Down Expand Up @@ -166,12 +167,12 @@ def test_execute_fail(terminal_state, context):
for s in [JobState.RUNNING, terminal_state]
]

with pytest.raises(AirflowException) as exec_info:
with pytest.raises(ArmadaOperatorJobFailedError) as exec_info:
op.execute(context)

# Error message contains terminal state and job id
assert DEFAULT_JOB_ID in str(exec_info)
assert terminal_state.name in str(exec_info)
assert terminal_state.name.capitalize() in str(exec_info)

op.hook.submit_job.assert_called_once_with(
DEFAULT_QUEUE, DEFAULT_JOB_SET, op.job_request
Expand Down Expand Up @@ -199,12 +200,12 @@ def test_not_acknowledged_within_timeout_terminates_running_job(context):
op = operator(JobSubmitRequestItem(), job_acknowledgement_timeout_s=-1)
op.hook.refresh_context.return_value = job_context

with pytest.raises(AirflowException) as exec_info:
with pytest.raises(ArmadaOperatorJobFailedError) as exec_info:
op.execute(context)

# Error message contains terminal state and job id
assert DEFAULT_JOB_ID in str(exec_info)
assert JobState.CANCELLED.name in str(exec_info)
assert JobState.CANCELLED.name.capitalize() in str(exec_info)

# We also cancel already submitted job
op.hook.cancel_job.assert_called_once_with(job_context)
Expand Down
39 changes: 39 additions & 0 deletions third_party/airflow/test/unit/operators/test_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
from armada_client.typings import JobState
from armada.operators.errors import ArmadaOperatorJobFailedError


def test_constructor():
    """Constructor arguments are exposed unchanged as same-named attributes."""
    expected = {
        "job_id": "test-job",
        "queue": "default-queue",
        "state": JobState.FAILED,
        "reason": "Out of memory",
    }

    error = ArmadaOperatorJobFailedError(
        expected["queue"],
        expected["job_id"],
        expected["state"],
        expected["reason"],
    )

    # Checked in the same order as the original explicit assertions.
    for attribute in ("job_id", "queue", "state", "reason"):
        assert getattr(error, attribute) == expected[attribute]


@pytest.mark.parametrize(
    "reason,expected_message",
    [
        (
            "",
            "ArmadaOperator job 'test-job' in queue 'default-queue' terminated "
            "with state 'Failed'.",
        ),
        (
            "Out of memory",
            "ArmadaOperator job 'test-job' in queue 'default-queue' terminated "
            "with state 'Failed'. Termination reason: Out of memory",
        ),
    ],
)
def test_message(reason: str, expected_message: str):
    """The rendered message includes the reason only when one is given."""
    queue, job_id = "default-queue", "test-job"
    rendered = str(
        ArmadaOperatorJobFailedError(queue, job_id, JobState.FAILED, reason)
    )
    assert rendered == expected_message

0 comments on commit 7313fb8

Please sign in to comment.