diff --git a/src/vllm_router/service_discovery.py b/src/vllm_router/service_discovery.py index 7d18d16c..83ac835e 100644 --- a/src/vllm_router/service_discovery.py +++ b/src/vllm_router/service_discovery.py @@ -97,6 +97,9 @@ class EndpointInfo: # Endpoint's sleep status sleep: bool + # Model type (e.g., "transcription", "chat", etc.) + model_type: Optional[str] = None + # Pod name pod_name: Optional[str] = None @@ -317,6 +320,7 @@ def get_endpoint_info(self) -> List[EndpointInfo]: sleep=False, added_timestamp=self.added_timestamp, model_label=model_label, + model_type=self.model_types[i] if self.model_types else None, model_info=self._get_model_info(model), ) endpoint_infos.append(endpoint_info) @@ -604,6 +608,20 @@ def _get_model_info(self, pod_ip) -> Dict[str, ModelInfo]: logger.error(f"Failed to get model info from {url}: {e}") return {} + def _get_model_type(self, pod) -> str: + """ + Get the model type from the pod's metadata labels. + + Args: + pod: The Kubernetes pod object + + Returns: + The model type if found, chat otherwise + """ + if isinstance(pod, str) or not pod.metadata.labels: + return "chat" # Default to chat model type + return pod.metadata.labels.get("model-type", "chat") + def _get_model_label(self, pod) -> Optional[str]: """ Get the model label from the pod's metadata labels. @@ -649,9 +667,11 @@ def _watch_engines(self): if is_pod_ready: model_names = self._get_model_names(pod_ip) model_label = self._get_model_label(pod) + model_type = self._get_model_type(pod) else: model_names = [] model_label = None + model_type = None # Record pod status for debugging if is_container_ready and is_pod_terminating: @@ -666,13 +686,19 @@ def _watch_engines(self): is_pod_ready, model_names, model_label, + model_type, ) except Exception as e: logger.error(f"K8s watcher error: {e}") time.sleep(0.5) def _add_engine( - self, engine_name: str, engine_ip: str, model_names: List[str], model_label: str + self, + engine_name: str, + engine_ip: str, + model_names: List[str], + model_label: str, + model_type: Optional[str], ): logger.info( f"Discovered new serving engine {engine_name} at " @@ -689,6 +715,10 @@ def _add_engine( sleep_status = False with self.available_engines_lock: + # Determine model type for each model + model_types = [self._get_model_type(model) for model in model_names] + model_type = model_types[0] if model_types else None + self.available_engines[engine_name] = EndpointInfo( url=f"http://{engine_ip}:{self.port}", model_names=model_names, @@ -699,10 +729,12 @@ def _add_engine( pod_name=engine_name, namespace=self.namespace, model_info=model_info, + model_type=model_type, ) # Store model information in the endpoint info self.available_engines[engine_name].model_info = model_info + self.available_engines[engine_name].model_type = model_type # Track all models we've ever seen with self.known_models_lock: @@ -721,6 +753,7 @@ def _on_engine_update( is_pod_ready: bool, model_names: List[str], model_label: Optional[str], + model_type: Optional[str] = None, ) -> None: if event == "ADDED": if engine_ip is None: @@ -732,7 +765,9 @@ def _on_engine_update( if not model_names: return - self._add_engine(engine_name, engine_ip, model_names, model_label) + self._add_engine( + engine_name, engine_ip, model_names, model_label, model_type + ) elif event == "DELETED": if engine_name not in self.available_engines: @@ -745,7 +780,9 @@ def _on_engine_update( return if is_pod_ready and model_names: - self._add_engine(engine_name, engine_ip, model_names, model_label) + self._add_engine( + engine_name, engine_ip, model_names, model_label, model_type + ) return if ( @@ -1055,6 +1092,20 @@ def _get_model_info(self, service_name) -> Dict[str, ModelInfo]: logger.error(f"Failed to get model info from {url}: {e}") return {} + def _get_model_type(self, service) -> str: + """ + Get the model label from the service's selector. + + Args: + service: The Kubernetes service object + + Returns: + The model selector if found, chat otherwise + """ + if not service.spec.selector: + return "chat" + return service.spec.selector.get("model-type", "chat") + def _get_model_label(self, service) -> Optional[str]: """ Get the model label from the service's selector. @@ -1094,18 +1145,27 @@ def _watch_engines(self): else: model_names = [] model_label = None + model_type = self._get_model_type(service) + self._on_engine_update( service_name, event_type, is_service_ready, model_names, model_label, + model_type, ) except Exception as e: logger.error(f"K8s watcher error: {e}") time.sleep(0.5) - def _add_engine(self, engine_name: str, model_names: List[str], model_label: str): + def _add_engine( + self, + engine_name: str, + model_names: List[str], + model_label: str, + model_type: str, + ): logger.info( f"Discovered new serving engine {engine_name} at " f"running models: {model_names}" @@ -1131,6 +1191,7 @@ def _add_engine(self, engine_name: str, model_names: List[str], model_label: str service_name=engine_name, namespace=self.namespace, model_info=model_info, + model_type=model_type, ) # Store model information in the endpoint info @@ -1148,6 +1209,7 @@ def _on_engine_update( is_service_ready: bool, model_names: List[str], model_label: Optional[str], + model_type: str, ) -> None: if event == "ADDED": if not engine_name: @@ -1159,7 +1221,7 @@ def _on_engine_update( if not model_names: return - self._add_engine(engine_name, model_names, model_label) + self._add_engine(engine_name, model_names, model_label, model_type) elif event == "DELETED": if engine_name not in self.available_engines: @@ -1172,7 +1234,7 @@ def _on_engine_update( return if is_service_ready and model_names: - self._add_engine(engine_name, model_names, model_label) + self._add_engine(engine_name, model_names, model_label, model_type) return if ( diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index 93e8e376..6734eeef 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -585,13 +585,13 @@ async def route_general_transcriptions( endpoints = service_discovery.get_endpoint_info() - # filter the endpoints url by model name and model label for transcriptions + # filter the endpoints url by model name and model type for transcriptions transcription_endpoints = [] for ep in endpoints: for model_name in ep.model_names: if ( model == model_name - and ep.model_label == "transcription" + and ep.model_type == "transcription" and not ep.sleep ): transcription_endpoints.append(ep)