Skip to content
Open
258 changes: 188 additions & 70 deletions src/vllm_router/service_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@

import aiohttp
import requests
import queue
from kubernetes import client, config, watch
from collections import OrderedDict

from vllm_router import utils
from vllm_router.log import init_logger
Expand Down Expand Up @@ -383,6 +385,13 @@ def __init__(
self.k8s_api = client.CoreV1Api()
self.k8s_watcher = watch.Watch()

# Event queue and processor
self.event_queue = queue.Queue()
self._event_queue_dict = OrderedDict()
self._event_queue_dict_lock = threading.Lock()
self.event_processor_task = None
self.resource_version = None

# Start watching engines
self.running = True
self.watcher_thread = threading.Thread(target=self._watch_engines, daemon=True)
Expand All @@ -409,7 +418,7 @@ def _is_pod_terminating(pod):
"""
return pod.metadata.deletion_timestamp is not None

def _get_engine_sleep_status(self, pod_ip) -> Optional[bool]:
async def _get_engine_sleep_status(self, pod_ip) -> Optional[bool]:
"""
Get the engine sleeping status by querying the engine's
'/is_sleeping' endpoint.
Expand All @@ -426,10 +435,14 @@ def _get_engine_sleep_status(self, pod_ip) -> Optional[bool]:
if VLLM_API_KEY := os.getenv("VLLM_API_KEY"):
logger.info("Using vllm server authentication")
headers = {"Authorization": f"Bearer {VLLM_API_KEY}"}
response = requests.get(url, headers=headers)
response.raise_for_status()
sleep = response.json()["is_sleeping"]
return sleep

# Use aiohttp for async HTTP requests
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
response.raise_for_status()
data = await response.json()
sleep = data["is_sleeping"]
return sleep
except Exception as e:
logger.warning(
f"Failed to get the sleep status for engine at {url} - sleep status is set to `False`: {e}"
Expand Down Expand Up @@ -491,7 +504,7 @@ def remove_sleep_label(self, pod_name):
except client.rest.ApiException as e:
logger.error(f"Error removing sleeping label: {e}")

def _get_model_names(self, pod_ip) -> List[str]:
async def _get_model_names(self, pod_ip) -> List[str]:
"""
Get the model names of the serving engine pod by querying the pod's
'/v1/models' endpoint.
Expand All @@ -508,23 +521,26 @@ def _get_model_names(self, pod_ip) -> List[str]:
if VLLM_API_KEY := os.getenv("VLLM_API_KEY"):
logger.info("Using vllm server authentication")
headers = {"Authorization": f"Bearer {VLLM_API_KEY}"}
response = requests.get(url, headers=headers)
response.raise_for_status()
models = response.json()["data"]

# Collect all model names, including both base models and adapters
model_names = []
for model in models:
model_id = model["id"]
model_names.append(model_id)

logger.info(f"Found models on pod {pod_ip}: {model_names}")
return model_names
# Use aiohttp for async HTTP requests
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
response.raise_for_status()
data = await response.json()
models = data["data"]

# Collect all model names, including both base models and adapters
model_names = []
for model in models:
model_id = model["id"]
model_names.append(model_id)

logger.info(f"Found models on pod {pod_ip}: {model_names}")
return model_names
except Exception as e:
logger.error(f"Failed to get model names from {url}: {e}")
return []

def _get_model_info(self, pod_ip) -> Dict[str, ModelInfo]:
async def _get_model_info(self, pod_ip) -> Dict[str, ModelInfo]:
"""
Get detailed model information from the serving engine pod.

Expand All @@ -540,16 +556,20 @@ def _get_model_info(self, pod_ip) -> Dict[str, ModelInfo]:
if VLLM_API_KEY := os.getenv("VLLM_API_KEY"):
logger.info("Using vllm server authentication")
headers = {"Authorization": f"Bearer {VLLM_API_KEY}"}
response = requests.get(url, headers=headers)
response.raise_for_status()
models = response.json()["data"]
# Create a dictionary of model information
model_info = {}
for model in models:
model_id = model["id"]
model_info[model_id] = ModelInfo.from_dict(model)

return model_info
# Use aiohttp for async HTTP requests
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
response.raise_for_status()
data = await response.json()
models = data["data"]
# Create a dictionary of model information
model_info = {}
for model in models:
model_id = model["id"]
model_info[model_id] = ModelInfo.from_dict(model)

return model_info
except Exception as e:
logger.error(f"Failed to get model info from {url}: {e}")
return {}
Expand All @@ -569,54 +589,140 @@ def _get_model_label(self, pod) -> Optional[str]:
return pod.metadata.labels.get("model")

def _watch_engines(self):
    """
    Watcher thread body: stream pod events from the Kubernetes API and
    enqueue them. All processing is done by the async event processor.

    Runs until ``self.running`` is cleared. On any watch error, logs and
    retries after a short back-off. Tracks the last seen resourceVersion
    so reconnects resume from where the previous watch left off.
    """
    while self.running:
        try:
            logger.debug(f"K8s watcher started{self.get_endpoint_info()}")
            # Use resource version for efficient watching
            watch_params = {
                "namespace": self.namespace,
                "label_selector": self.label_selector,
                # Prefer the configured watcher timeout; fall back to 30s
                # if the attribute is not set on this instance.
                "timeout_seconds": getattr(self, "watcher_timeout_seconds", 30),
            }
            if self.resource_version:
                watch_params["resource_version"] = self.resource_version

            for event in self.k8s_watcher.stream(
                self.k8s_api.list_namespaced_pod, **watch_params
            ):
                # Update resource version so a restarted watch resumes
                # from this point instead of re-listing everything.
                self.resource_version = event["object"].metadata.resource_version

                # Enqueue event by key (pod_name) to avoid duplicates
                self._enqueue_event(event)

        except Exception as e:
            logger.error(f"K8s watcher error: {e}")
            time.sleep(0.5)

if is_pod_ready:
model_names = self._get_model_names(pod_ip)
model_label = self._get_model_label(pod)
else:
model_names = []
model_label = None
def _enqueue_event(self, event: dict):
"""
Enqueue event by pod name key. If a newer event for the same pod exists,
replace the older event.
"""
pod_name = event["object"].metadata.name

# Record pod status for debugging
if is_container_ready and is_pod_terminating:
logger.info(
f"Pod {pod_name} has ready containers but is terminating - marking as unavailable"
)
# Add/update event in the ordered dict
with self._event_queue_dict_lock:
self._event_queue_dict[pod_name] = event

self._on_engine_update(
pod_name,
pod_ip,
event_type,
is_pod_ready,
model_names,
model_label,
)
# Put the pod name in the queue for processing
try:
self.event_queue.put_nowait(pod_name)
except queue.Full:
logger.warning(f"Event queue is full, dropping event for pod {pod_name}")

async def _start_event_processor(self):
"""
Start the async event processor.
"""
if self.event_processor_task is None:
self.event_processor_task = asyncio.create_task(self._process_events())

async def _process_events(self):
"""
Async event processor that handles events from the queue.
"""
while self.running:
try:
# Get pod name from queue
pod_name = await asyncio.get_event_loop().run_in_executor(
None, self.event_queue.get, True, 1.0
)

# Get the event from our ordered dict
with self._event_queue_dict_lock:
event = self._event_queue_dict.pop(pod_name, None)
if event is None:
continue

# Process the event asynchronously
await self._process_single_event(event)

except queue.Empty:
continue
except Exception as e:
logger.error(f"K8s watcher error: {e}")
time.sleep(0.5)
logger.error(f"Event processor error: {e}")
await asyncio.sleep(0.1)

def _add_engine(
async def _process_single_event(self, event: dict):
    """
    Process a single watch event: derive pod readiness and model
    information, then forward everything to the engine-update handler.

    Args:
        event: Kubernetes watch event dict with ``"object"`` (the pod)
            and ``"type"`` (ADDED/MODIFIED/DELETED).
    """
    pod, event_type = event["object"], event["type"]
    pod_name, pod_ip = pod.metadata.name, pod.status.pod_ip

    logger.debug(
        f"Processing event: pod_name: {pod_name} pod_ip: {pod_ip} event_type: {event_type}"
    )

    # Extract readiness / model names / model label for this pod.
    info = await self._preprocess_event(pod, pod_ip)

    # Hand off to the async engine update handler.
    await self._on_engine_update(
        pod_name,
        pod_ip,
        event_type,
        info["is_pod_ready"],
        info["model_names"],
        info["model_label"],
    )

async def _preprocess_event(self, pod, pod_ip: str) -> dict:
"""
Preprocess event data to extract information about model names/labels/pod readiness
"""
# Check if pod is terminating
is_pod_terminating = self._is_pod_terminating(pod)
is_container_ready = self._check_pod_ready(pod.status.container_statuses)
# Pod is ready if container is ready and pod is not terminating
is_pod_ready = is_container_ready and not is_pod_terminating

if is_pod_ready:
model_names = await self._get_model_names(pod_ip)
model_label = self._get_model_label(pod)
else:
model_names = []
model_label = None

# Record pod status for debugging
if is_container_ready and is_pod_terminating:
logger.info(
f"Pod {pod.metadata.name} has ready containers but is terminating - marking as unavailable"
)

return {
"is_pod_ready": is_pod_ready,
"model_names": model_names,
"model_label": model_label,
}

async def _add_engine(
self, engine_name: str, engine_ip: str, model_names: List[str], model_label: str
):
logger.info(
Expand All @@ -625,11 +731,11 @@ def _add_engine(
)

# Get detailed model information
model_info = self._get_model_info(engine_ip)
model_info = await self._get_model_info(engine_ip)

# Check if engine is enabled with sleep mode and set engine sleep status
if self._check_engine_sleep_mode(engine_name):
sleep_status = self._get_engine_sleep_status(engine_ip)
sleep_status = await self._get_engine_sleep_status(engine_ip)
else:
sleep_status = False

Expand All @@ -654,7 +760,7 @@ def _delete_engine(self, engine_name: str):
with self.available_engines_lock:
del self.available_engines[engine_name]

def _on_engine_update(
async def _on_engine_update(
self,
engine_name: str,
engine_ip: Optional[str],
Expand All @@ -673,7 +779,9 @@ def _on_engine_update(
if not model_names:
return

self._add_engine(engine_name, engine_ip, model_names, model_label)
await self._add_engine(
engine_name, engine_ip, model_names, model_label
)

elif event == "DELETED":
if engine_name not in self.available_engines:
Expand All @@ -686,7 +794,9 @@ def _on_engine_update(
return

if is_pod_ready and model_names:
self._add_engine(engine_name, engine_ip, model_names, model_label)
await self._add_engine(
engine_name, engine_ip, model_names, model_label
)
return

if (
Expand Down Expand Up @@ -721,13 +831,21 @@ def close(self):
"""
self.running = False
self.k8s_watcher.stop()

# Cancel the event processor task
if self.event_processor_task:
self.event_processor_task.cancel()

self.watcher_thread.join()

async def initialize_client_sessions(self) -> None:
"""
Initialize aiohttp ClientSession objects for prefill and decode endpoints.
This must be called from an async context during app startup.
"""
# Start the event processor
await self._start_event_processor()

if (
self.prefill_model_labels is not None
and self.decode_model_labels is not None
Expand Down
Loading