148 changes: 122 additions & 26 deletions sky/utils/kubernetes/gpu_labeler.py
@@ -3,6 +3,7 @@
import hashlib
import os
import subprocess
import time
from typing import Dict, Optional, Tuple

import colorama
@@ -13,6 +14,9 @@
from sky.utils import directory_utils
from sky.utils import rich_utils

# Polling interval in seconds for job completion checks
JOB_COMPLETION_POLL_INTERVAL = 5


def _format_string(str_to_format: str, colorama_format: str) -> str:
    return f'{colorama_format}{str_to_format}{colorama.Style.RESET_ALL}'
@@ -166,6 +170,73 @@ def label(context: Optional[str] = None, wait_for_completion: bool = True):
'`skypilot.co/accelerator: <gpu_name>`. ')


def _poll_jobs_completion(jobs_to_node_names: Dict[str, str],
                          namespace: str,
                          context: Optional[str] = None,
                          timeout: int = 60 * 20) -> bool:
    """Fallback polling method to check job completion status.

    This method polls the Kubernetes API to check job status instead of
    using the watch API. It is used as a fallback when the watch API fails
    due to resource version mismatches.

    Args:
        jobs_to_node_names: A dictionary mapping job names to node names.
        namespace: The namespace the jobs are in.
        context: Optional Kubernetes context to use.
        timeout: Timeout in seconds (default: 1200 seconds = 20 minutes).

    Returns:
        True if all jobs completed successfully, False if any job failed
        or the timeout was reached.
    """
    batch_v1 = kubernetes.batch_api(context=context)
    start_time = time.time()
    completed_jobs = []

    print(
        _format_string('Using polling method to check job completion...',
                       colorama.Style.DIM))

    while time.time() - start_time < timeout:
        try:
            jobs = batch_v1.list_namespaced_job(namespace=namespace)
            for job in jobs.items:
                job_name = job.metadata.name
                if job_name in jobs_to_node_names:
                    node_name = jobs_to_node_names[job_name]
                    if job.status and job.status.completion_time:
                        if job_name not in completed_jobs:
                            print(
                                _format_string(
                                    f'GPU labeler job for node {node_name} '
                                    'completed successfully',
                                    colorama.Style.DIM))
                            completed_jobs.append(job_name)
                    elif job.status and job.status.failed:
                        print(
                            _format_string(
                                f'GPU labeler job for node {node_name} failed',
                                colorama.Style.DIM))
                        return False

            if len(completed_jobs) == len(jobs_to_node_names):
                return True

            time.sleep(JOB_COMPLETION_POLL_INTERVAL)
        except kubernetes.api_exception() as poll_error:
            print(
                _format_string(f'Polling error: {str(poll_error)}',
                               colorama.Fore.RED))
            time.sleep(JOB_COMPLETION_POLL_INTERVAL)

    print(
        _format_string(
            f'Timed out after waiting {timeout} seconds '
            'for jobs to complete', colorama.Style.DIM))
    return False
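
For reference, the two status checks above lean on standard Jobs-API semantics: `status.completion_time` is only set once a job has finished successfully, and `status.failed` is a count of failed pods, so any non-zero value is truthy. A minimal sketch with the upstream kubernetes client, checking a single job (the helper name is illustrative, not part of this PR):

from typing import Optional

from kubernetes import client, config


def job_finished(namespace: str, name: str) -> Optional[bool]:
    """True on success, False on failure, None while still running."""
    config.load_kube_config()  # Use the local kubeconfig for this sketch.
    job = client.BatchV1Api().read_namespaced_job(name, namespace)
    if job.status.completion_time:  # Set only after the job succeeds.
        return True
    if job.status.failed:  # Count of failed pods; non-zero is truthy.
        return False
    return None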


def wait_for_jobs_completion(jobs_to_node_names: Dict[str, str],
                             namespace: str,
                             context: Optional[str] = None,
@@ -181,38 +252,63 @@ def wait_for_jobs_completion(jobs_to_node_names: Dict[str, str],
        True if all jobs completed successfully, False if any failed or
        timed out.
    """
    batch_v1 = kubernetes.batch_api(context=context)
    completed_jobs = []

    def _watch_jobs():
        """Helper function to watch jobs with error handling."""
        w = kubernetes.watch()
        for event in w.stream(func=batch_v1.list_namespaced_job,
                              namespace=namespace,
                              timeout_seconds=timeout):
            job = event['object']
            job_name = job.metadata.name
            if job_name in jobs_to_node_names:
                node_name = jobs_to_node_names[job_name]
                if job.status and job.status.completion_time:
                    print(
                        _format_string(
                            f'GPU labeler job for node {node_name} '
                            'completed successfully', colorama.Style.DIM))
                    completed_jobs.append(job_name)
                    num_remaining_jobs = len(jobs_to_node_names) - len(
                        completed_jobs)
                    if num_remaining_jobs == 0:
                        w.stop()
                        return True
                elif job.status and job.status.failed:
                    print(
                        _format_string(
                            f'GPU labeler job for node {node_name} failed',
                            colorama.Style.DIM))
                    w.stop()
                    return False
        return None  # Timeout

    try:
        result = _watch_jobs()
        if result is not None:
            return result
    except kubernetes.api_exception() as e:
        if e.status == 504 and 'Too large resource version' in str(e):
            print(
                _format_string(
                    'Watch failed due to resource version mismatch. '
                    'Falling back to polling method...', colorama.Fore.YELLOW))
            # Fall back to polling instead of the watch API. The watch API
            # is unreliable when resource versions are changing rapidly or
            # when there are multiple API server instances with different
            # cache states.
            return _poll_jobs_completion(jobs_to_node_names, namespace,
                                         context, timeout)
        else:
            # Re-raise other API exceptions.
            raise
Review thread (on the fallback to _poll_jobs_completion):

Collaborator: Can we just retry the watch on this error, given that it only happens when the client has reconnected to another server?

zpoint (author): I tried retrying the watch after a few seconds, but it didn't work. I think if we insist on retrying, we might need a longer timeout for the watch to work with Nebius K8s.

Collaborator: Then is there any reason other than the one in the issue description? If the cause is multi-replica API server failover, as the description analyzed, then retrying the list and watch should address the resource-version issue, since the connection to a new server replica would be rebuilt.

zpoint (author): Will look into this again.

zpoint (author, Nov 11, 2025): Yes, you're right. Changed the implementation; please take a look again. Thanks.
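
The retry approach discussed in the thread could look roughly like the following sketch, using the upstream kubernetes client directly rather than SkyPilot's wrappers (the function name and structure are illustrative assumptions, not this PR's implementation): re-list to obtain a fresh resourceVersion, then restart the watch from it.

import time

from kubernetes import client, config, watch
from kubernetes.client.rest import ApiException


def watch_jobs_with_retry(namespace: str, timeout: int = 1200) -> None:
    """Illustrative: re-list and re-watch on resource-version errors."""
    config.load_kube_config()
    batch_v1 = client.BatchV1Api()
    deadline = time.time() + timeout
    while time.time() < deadline:
        # A fresh list rebuilds the connection (possibly to a different
        # API server replica) and returns a current resourceVersion,
        # giving the new watch a valid starting point.
        jobs = batch_v1.list_namespaced_job(namespace=namespace)
        resource_version = jobs.metadata.resource_version
        w = watch.Watch()
        try:
            for event in w.stream(batch_v1.list_namespaced_job,
                                  namespace=namespace,
                                  resource_version=resource_version,
                                  timeout_seconds=max(
                                      1, int(deadline - time.time()))):
                print(event['type'], event['object'].metadata.name)
            return  # The watch window elapsed without an error.
        except ApiException as e:
            # 410 Gone (expired resourceVersion) or the 504 seen here:
            # loop back to re-list and restart the watch.
            if e.status in (410, 504):
                continue
            raise

Whether this converges on Nebius K8s is exactly what the thread debates; at this revision the PR falls back to polling instead.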

    print(
        _format_string(
            f'Timed out after waiting {timeout} seconds '
            'for jobs to complete', colorama.Style.DIM))
    return False  # Timed out
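
A hypothetical call site for reference (the job names, node names, and namespace below are made up for illustration):

# Block until both labeler jobs finish (or the default timeout expires).
jobs_to_nodes = {
    'sky-gpu-labeler-abc12': 'gpu-node-1',
    'sky-gpu-labeler-def34': 'gpu-node-2',
}
success = wait_for_jobs_completion(jobs_to_nodes, namespace='kube-system')
print('All GPU labeler jobs completed' if success else 'Labeling failed')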


def main():