Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 68 additions & 6 deletions src/vllm_router/service_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ class EndpointInfo:
# Endpoint's sleep status
sleep: bool

# Model type (e.g., "transcription", "chat", etc.)
model_type: Optional[str] = None

# Pod name
pod_name: Optional[str] = None

Expand Down Expand Up @@ -317,6 +320,7 @@ def get_endpoint_info(self) -> List[EndpointInfo]:
sleep=False,
added_timestamp=self.added_timestamp,
model_label=model_label,
model_type=self.model_types[i] if self.model_types else None,
model_info=self._get_model_info(model),
)
endpoint_infos.append(endpoint_info)
Expand Down Expand Up @@ -604,6 +608,20 @@ def _get_model_info(self, pod_ip) -> Dict[str, ModelInfo]:
logger.error(f"Failed to get model info from {url}: {e}")
return {}

def _get_model_type(self, pod) -> str:
"""
Get the model type from the pod's metadata labels.

Args:
pod: The Kubernetes pod object

Returns:
The model type if found, chat otherwise
"""
if isinstance(pod, str) or not pod.metadata.labels:
return "chat" # Default to chat model type
return pod.metadata.labels.get("model-type", "chat")

def _get_model_label(self, pod) -> Optional[str]:
"""
Get the model label from the pod's metadata labels.
Expand Down Expand Up @@ -649,9 +667,11 @@ def _watch_engines(self):
if is_pod_ready:
model_names = self._get_model_names(pod_ip)
model_label = self._get_model_label(pod)
model_type = self._get_model_type(pod)
else:
model_names = []
model_label = None
model_type = None

# Record pod status for debugging
if is_container_ready and is_pod_terminating:
Expand All @@ -666,13 +686,19 @@ def _watch_engines(self):
is_pod_ready,
model_names,
model_label,
model_type,
)
except Exception as e:
logger.error(f"K8s watcher error: {e}")
time.sleep(0.5)

def _add_engine(
self, engine_name: str, engine_ip: str, model_names: List[str], model_label: str
self,
engine_name: str,
engine_ip: str,
model_names: List[str],
model_label: str,
model_type: Optional[str],
):
logger.info(
f"Discovered new serving engine {engine_name} at "
Expand All @@ -689,6 +715,10 @@ def _add_engine(
sleep_status = False

with self.available_engines_lock:
# Determine model type for each model
model_types = [self._get_model_type(model) for model in model_names]
model_type = model_types[0] if model_types else None

self.available_engines[engine_name] = EndpointInfo(
url=f"http://{engine_ip}:{self.port}",
model_names=model_names,
Expand All @@ -699,10 +729,12 @@ def _add_engine(
pod_name=engine_name,
namespace=self.namespace,
model_info=model_info,
model_type=model_type,
)

# Store model information in the endpoint info
self.available_engines[engine_name].model_info = model_info
self.available_engines[engine_name].model_type = model_type

# Track all models we've ever seen
with self.known_models_lock:
Expand All @@ -721,6 +753,7 @@ def _on_engine_update(
is_pod_ready: bool,
model_names: List[str],
model_label: Optional[str],
model_type: Optional[str] = None,
) -> None:
if event == "ADDED":
if engine_ip is None:
Expand All @@ -732,7 +765,9 @@ def _on_engine_update(
if not model_names:
return

self._add_engine(engine_name, engine_ip, model_names, model_label)
self._add_engine(
engine_name, engine_ip, model_names, model_label, model_type
)

elif event == "DELETED":
if engine_name not in self.available_engines:
Expand All @@ -745,7 +780,9 @@ def _on_engine_update(
return

if is_pod_ready and model_names:
self._add_engine(engine_name, engine_ip, model_names, model_label)
self._add_engine(
engine_name, engine_ip, model_names, model_label, model_type
)
return

if (
Expand Down Expand Up @@ -1055,6 +1092,20 @@ def _get_model_info(self, service_name) -> Dict[str, ModelInfo]:
logger.error(f"Failed to get model info from {url}: {e}")
return {}

def _get_model_type(self, service) -> str:
"""
Get the model label from the service's selector.

Args:
service: The Kubernetes service object

Returns:
The model selector if found, chat otherwise
"""
if not service.spec.selector:
return "chat"
return service.spec.selector.get("model-type", "chat")

def _get_model_label(self, service) -> Optional[str]:
"""
Get the model label from the service's selector.
Expand Down Expand Up @@ -1094,18 +1145,27 @@ def _watch_engines(self):
else:
model_names = []
model_label = None
model_type = self._get_model_type(service)

self._on_engine_update(
service_name,
event_type,
is_service_ready,
model_names,
model_label,
model_type,
)
except Exception as e:
logger.error(f"K8s watcher error: {e}")
time.sleep(0.5)

def _add_engine(self, engine_name: str, model_names: List[str], model_label: str):
def _add_engine(
self,
engine_name: str,
model_names: List[str],
model_label: str,
model_type: str,
):
logger.info(
f"Discovered new serving engine {engine_name} at "
f"running models: {model_names}"
Expand All @@ -1131,6 +1191,7 @@ def _add_engine(self, engine_name: str, model_names: List[str], model_label: str
service_name=engine_name,
namespace=self.namespace,
model_info=model_info,
model_type=model_type,
)

# Store model information in the endpoint info
Expand All @@ -1148,6 +1209,7 @@ def _on_engine_update(
is_service_ready: bool,
model_names: List[str],
model_label: Optional[str],
model_type: str,
) -> None:
if event == "ADDED":
if not engine_name:
Expand All @@ -1159,7 +1221,7 @@ def _on_engine_update(
if not model_names:
return

self._add_engine(engine_name, model_names, model_label)
self._add_engine(engine_name, model_names, model_label, model_type)

elif event == "DELETED":
if engine_name not in self.available_engines:
Expand All @@ -1172,7 +1234,7 @@ def _on_engine_update(
return

if is_service_ready and model_names:
self._add_engine(engine_name, model_names, model_label)
self._add_engine(engine_name, model_names, model_label, model_type)
return

if (
Expand Down
4 changes: 2 additions & 2 deletions src/vllm_router/services/request_service/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,13 +585,13 @@ async def route_general_transcriptions(

endpoints = service_discovery.get_endpoint_info()

# filter the endpoints url by model name and model label for transcriptions
# filter the endpoints url by model name and model type for transcriptions
transcription_endpoints = []
for ep in endpoints:
for model_name in ep.model_names:
if (
model == model_name
and ep.model_label == "transcription"
and ep.model_type == "transcription"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually need model type here? Wouldn't model name itself be enough since we know which model is for transcription when querying.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zerofishnoodles Not sure I can follow here. How do we know which model name is capable of transcription?

E.g. we're using whisper-large-v3-turbo, but there are others.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean when we query we would specify the model name which means we have already known the model is capable of something. You wouldn't query llama for transcription. Unless we want one model to serve specific task, in that case If we really want to add another annotation for that, I would say model_task is more suitable since transcription is not a model type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahhh you're right. I made #712 so that we can focus on that. Do we still want the model types for the other Kubernetes based discovery modes?

We're not using it, but I guess it would be useful. Right now the real healthcheck just works for the static service discovery.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I guess we can do that and it will be helpful. Currently we only use /health to inspect the health. If we want to do that, I think we need to add some readme or tutorials so that people know they need to have label on it. Speaking of that, we may also need to modify the helm chart since pod level label is not supported right now.

and not ep.sleep
):
transcription_endpoints.append(ep)
Expand Down