-
Notifications
You must be signed in to change notification settings - Fork 137
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Deferred and non-deferred operator cancels armada jobs when task is m…
…arked as failed. * Deferred and non-deferred operator cancels armada jobs when task is marked as failed. * Minor test / quality of life fixes.
- Loading branch information
1 parent
6923240
commit 9b6f770
Showing
23 changed files
with
835 additions
and
676 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from airflow.serialization.serde import _extra_allowed | ||
|
||
_extra_allowed.add("armada.model.RunningJobContext") | ||
_extra_allowed.add("armada.model.GrpcChannelArgs") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,8 @@ | ||
from typing import Any, Dict, Protocol, Tuple | ||
from typing import Protocol | ||
|
||
""" We use this interface for objects fetching Kubernetes auth tokens. Since | ||
it's used within the Trigger, it must be serialisable.""" | ||
|
||
|
||
class TokenRetriever(Protocol): | ||
def get_token(self) -> str: ... | ||
|
||
def serialize(self) -> Tuple[str, Dict[str, Any]]: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import dataclasses | ||
import json | ||
import threading | ||
from functools import cached_property | ||
from typing import Dict, Optional | ||
|
||
import grpc | ||
from airflow.exceptions import AirflowException | ||
from airflow.serialization.serde import serialize | ||
from airflow.utils.log.logging_mixin import LoggingMixin | ||
from armada.model import GrpcChannelArgs | ||
from armada_client.armada.job_pb2 import JobRunDetails | ||
from armada_client.armada.submit_pb2 import JobSubmitRequestItem | ||
from armada_client.client import ArmadaClient | ||
from armada_client.typings import JobState | ||
from pendulum import DateTime | ||
|
||
from .model import RunningJobContext | ||
|
||
|
||
class ArmadaClientFactory: | ||
CLIENTS_LOCK = threading.Lock() | ||
CLIENTS: Dict[str, ArmadaClient] = {} | ||
|
||
@staticmethod | ||
def client_for(args: GrpcChannelArgs) -> ArmadaClient: | ||
""" | ||
Armada clients, maintain GRPC connection to Armada API. | ||
We cache them per channel args config in class level cache. | ||
Access to this method can be from multiple-threads. | ||
""" | ||
channel_args_key = json.dumps(serialize(args)) | ||
with ArmadaClientFactory.CLIENTS_LOCK: | ||
if channel_args_key not in ArmadaClientFactory.CLIENTS: | ||
ArmadaClientFactory.CLIENTS[channel_args_key] = ArmadaClient( | ||
channel=ArmadaClientFactory._create_channel(args) | ||
) | ||
return ArmadaClientFactory.CLIENTS[channel_args_key] | ||
|
||
@staticmethod | ||
def _create_channel(args: GrpcChannelArgs) -> grpc.Channel: | ||
if args.auth is None: | ||
return grpc.insecure_channel( | ||
target=args.target, options=args.options, compression=args.compression | ||
) | ||
|
||
return grpc.secure_channel( | ||
target=args.target, | ||
options=args.options, | ||
compression=args.compression, | ||
credentials=grpc.composite_channel_credentials( | ||
grpc.ssl_channel_credentials(), | ||
grpc.metadata_call_credentials(args.auth), | ||
), | ||
) | ||
|
||
|
||
class ArmadaHook(LoggingMixin): | ||
def __init__(self, args: GrpcChannelArgs): | ||
self.args = args | ||
|
||
@cached_property | ||
def client(self): | ||
return ArmadaClientFactory.client_for(self.args) | ||
|
||
def cancel_job(self, job_context: RunningJobContext) -> RunningJobContext: | ||
try: | ||
result = self.client.cancel_jobs( | ||
queue=job_context.armada_queue, | ||
job_set_id=job_context.job_set_id, | ||
job_id=job_context.job_id, | ||
) | ||
if len(list(result.cancelled_ids)) > 0: | ||
self.log.info(f"Cancelled job with id {result.cancelled_ids}") | ||
else: | ||
self.log.warning(f"Failed to cancel job with id {job_context.job_id}") | ||
except Exception as e: | ||
self.log.warning(f"Failed to cancel job with id {job_context.job_id}: {e}") | ||
finally: | ||
return dataclasses.replace(job_context, job_state=JobState.CANCELLED.name) | ||
|
||
def submit_job( | ||
self, queue: str, job_set_id: str, job_request: JobSubmitRequestItem | ||
) -> RunningJobContext: | ||
resp = self.client.submit_jobs(queue, job_set_id, [job_request]) | ||
num_responses = len(resp.job_response_items) | ||
|
||
# We submitted exactly one job to armada, so we expect a single response | ||
if num_responses != 1: | ||
raise AirflowException( | ||
f"No valid received from Armada (expected 1 job to be created " | ||
f"but got {num_responses})" | ||
) | ||
job = resp.job_response_items[0] | ||
|
||
# Throw if armada told us we had submitted something bad | ||
if job.error: | ||
raise AirflowException(f"Error submitting job to Armada: {job.error}") | ||
|
||
return RunningJobContext(queue, job.job_id, job_set_id, DateTime.utcnow()) | ||
|
||
def refresh_context( | ||
self, job_context: RunningJobContext, tracking_url: str | ||
) -> RunningJobContext: | ||
response = self.client.get_job_status([job_context.job_id]) | ||
state = JobState(response.job_states[job_context.job_id]) | ||
if state != job_context.state: | ||
self.log.info( | ||
f"job {job_context.job_id} is in state: {state.name}. " | ||
f"{tracking_url}" | ||
) | ||
|
||
cluster = job_context.cluster | ||
if not cluster: | ||
# Job is running / or completed already | ||
if state == JobState.RUNNING or state.is_terminal(): | ||
run_details = self._get_latest_job_run_details(job_context.job_id) | ||
if run_details: | ||
cluster = run_details.cluster | ||
return dataclasses.replace(job_context, job_state=state.name, cluster=cluster) | ||
|
||
def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]: | ||
job_details = self.client.get_job_details([job_id]).job_details[job_id] | ||
if job_details and job_details.latest_run_id: | ||
for run in job_details.job_runs: | ||
if run.run_id == job_details.latest_run_id: | ||
return run | ||
return None |
Oops, something went wrong.