diff --git a/third_party/airflow/armada/hooks.py b/third_party/airflow/armada/hooks.py index c1ba6349b74..bf52a4cc6eb 100644 --- a/third_party/airflow/armada/hooks.py +++ b/third_party/airflow/armada/hooks.py @@ -92,6 +92,18 @@ def submit_job( return RunningJobContext(queue, job.job_id, job_set_id, DateTime.utcnow()) + @tenacity.retry( + wait=tenacity.wait_random_exponential(max=3), + stop=tenacity.stop_after_attempt(5), + reraise=True, + ) + @log_exceptions + def job_termination_reason(self, job_context: RunningJobContext) -> str: + resp = self.client.get_job_errors([job_context.job_id]) + job_error = resp.job_errors.get(job_context.job_id, "") + + return job_error or "" + @tenacity.retry( wait=tenacity.wait_random_exponential(max=3), stop=tenacity.stop_after_attempt(5), diff --git a/third_party/airflow/armada/operators/armada.py b/third_party/airflow/armada/operators/armada.py index 614de1ab8ff..8d38a65d260 100644 --- a/third_party/airflow/armada/operators/armada.py +++ b/third_party/airflow/armada/operators/armada.py @@ -26,7 +26,6 @@ import jinja2 import tenacity from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.models.taskinstancekey import TaskInstanceKey from airflow.serialization.serde import deserialize @@ -40,6 +39,7 @@ from google.protobuf.json_format import MessageToDict, ParseDict from pendulum import DateTime +from .errors import ArmadaOperatorJobFailedError from ..hooks import ArmadaHook from ..model import RunningJobContext from ..triggers import ArmadaPollJobTrigger @@ -349,9 +349,11 @@ def _running_job_terminated(self, context: RunningJobContext): f"job {context.job_id} terminated with state: {context.state.name}" ) if context.state != JobState.SUCCEEDED: - raise AirflowException( - f"job {context.job_id} did not succeed. 
" - f"Final status was {context.state.name}" + raise ArmadaOperatorJobFailedError( + context.armada_queue, + context.job_id, + context.state, + self.hook.job_termination_reason(context), ) def _not_acknowledged_within_timeout(self) -> bool: @@ -363,6 +365,13 @@ def _not_acknowledged_within_timeout(self) -> bool: return True return False + def _should_have_a_pod_in_k8s(self) -> bool: + return self.job_context.state in { + JobState.RUNNING, + JobState.FAILED, + JobState.SUCCEEDED, + } + @log_exceptions def _check_job_status_and_fetch_logs(self, context) -> None: self.job_context = self.hook.refresh_context( @@ -377,7 +386,7 @@ def _check_job_status_and_fetch_logs(self, context) -> None: self.job_context = self.hook.cancel_job(self.job_context) return - if self.job_context.cluster and self.container_logs: + if self._should_have_a_pod_in_k8s() and self.container_logs: try: last_log_time = self.pod_manager.fetch_container_logs( k8s_context=self.job_context.cluster, diff --git a/third_party/airflow/armada/operators/errors.py b/third_party/airflow/armada/operators/errors.py new file mode 100644 index 00000000000..bd50d8c84a8 --- /dev/null +++ b/third_party/airflow/armada/operators/errors.py @@ -0,0 +1,50 @@ +from airflow.exceptions import AirflowException + +from armada_client.typings import JobState + + +class ArmadaOperatorJobFailedError(AirflowException): + """ + Raised when an ArmadaOperator job has terminated unsuccessfully on Armada. + + :param job_id: The unique identifier of the job. + :type job_id: str + :param queue: The queue the job was submitted to. + :type queue: str + :param state: The termination state of the job. + :type state: JobState + :param reason: The termination reason, if provided. 
+ :type reason: str + """ + + def __init__(self, queue: str, job_id: str, state: JobState, reason: str = ""): + self.job_id = job_id + self.queue = queue + self.state = state + self.reason = reason + self.message = self._generate_message() + super().__init__(self.message) + + def _generate_message(self) -> str: + """ + Generate a user-friendly error message. + + :return: Formatted error message with job details. + :rtype: str + """ + message = ( + f"ArmadaOperator job '{self.job_id}' in queue '{self.queue}'" + f" terminated with state '{self.state.name.capitalize()}'." + ) + if self.reason: + message += f" Termination reason: {self.reason}" + return message + + def __str__(self) -> str: + """ + Return the error message when the exception is converted to a string. + + :return: The error message. + :rtype: str + """ + return self.message diff --git a/third_party/airflow/pyproject.toml b/third_party/airflow/pyproject.toml index a4ae0607679..67c7b91f678 100644 --- a/third_party/airflow/pyproject.toml +++ b/third_party/airflow/pyproject.toml @@ -10,7 +10,7 @@ readme='README.md' authors = [{name = "Armada-GROSS", email = "armada@armadaproject.io"}] license = { text = "Apache Software License" } dependencies=[ - 'armada-client>=0.4.6', + 'armada-client>=0.4.7', 'apache-airflow>=2.6.3', 'types-protobuf==4.24.0.1', 'kubernetes>=23.6.0', diff --git a/third_party/airflow/test/unit/operators/test_armada.py b/third_party/airflow/test/unit/operators/test_armada.py index 4e3f8804e36..c8b3272db66 100644 --- a/third_party/airflow/test/unit/operators/test_armada.py +++ b/third_party/airflow/test/unit/operators/test_armada.py @@ -4,9 +4,10 @@ from unittest.mock import MagicMock, patch import pytest -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.exceptions import TaskDeferred from armada.model import GrpcChannelArgs, RunningJobContext from armada.operators.armada import ArmadaOperator +from armada.operators.errors import ArmadaOperatorJobFailedError from 
armada.triggers import ArmadaPollJobTrigger from armada_client.armada.submit_pb2 import JobSubmitRequestItem from armada_client.typings import JobState @@ -166,12 +167,12 @@ def test_execute_fail(terminal_state, context): for s in [JobState.RUNNING, terminal_state] ] - with pytest.raises(AirflowException) as exec_info: + with pytest.raises(ArmadaOperatorJobFailedError) as exec_info: op.execute(context) # Error message contain terminal state and job id assert DEFAULT_JOB_ID in str(exec_info) - assert terminal_state.name in str(exec_info) + assert terminal_state.name.capitalize() in str(exec_info) op.hook.submit_job.assert_called_once_with( DEFAULT_QUEUE, DEFAULT_JOB_SET, op.job_request @@ -199,12 +200,12 @@ def test_not_acknowledged_within_timeout_terminates_running_job(context): op = operator(JobSubmitRequestItem(), job_acknowledgement_timeout_s=-1) op.hook.refresh_context.return_value = job_context - with pytest.raises(AirflowException) as exec_info: + with pytest.raises(ArmadaOperatorJobFailedError) as exec_info: op.execute(context) # Error message contain terminal state and job id assert DEFAULT_JOB_ID in str(exec_info) - assert JobState.CANCELLED.name in str(exec_info) + assert JobState.CANCELLED.name.capitalize() in str(exec_info) # We also cancel already submitted job op.hook.cancel_job.assert_called_once_with(job_context) diff --git a/third_party/airflow/test/unit/operators/test_errors.py b/third_party/airflow/test/unit/operators/test_errors.py new file mode 100644 index 00000000000..e6b950b5996 --- /dev/null +++ b/third_party/airflow/test/unit/operators/test_errors.py @@ -0,0 +1,39 @@ +import pytest +from armada_client.typings import JobState +from armada.operators.errors import ArmadaOperatorJobFailedError + + +def test_constructor(): + job_id = "test-job" + queue = "default-queue" + state = JobState.FAILED + reason = "Out of memory" + + error = ArmadaOperatorJobFailedError(queue, job_id, state, reason) + + assert error.job_id == job_id + assert 
error.queue == queue + assert error.state == state + assert error.reason == reason + + +@pytest.mark.parametrize( + "reason,expected_message", + [ + ( + "", + "ArmadaOperator job 'test-job' in queue 'default-queue' terminated " + "with state 'Failed'.", + ), + ( + "Out of memory", + "ArmadaOperator job 'test-job' in queue 'default-queue' terminated " + "with state 'Failed'. Termination reason: Out of memory", + ), + ], +) +def test_message(reason: str, expected_message: str): + error = ArmadaOperatorJobFailedError( + "default-queue", "test-job", JobState.FAILED, reason + ) + assert str(error) == expected_message