diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py index 98ca740fc..adb63d61b 100644 --- a/src/vllm_router/app.py +++ b/src/vllm_router/app.py @@ -164,6 +164,7 @@ def initialize_all(app: FastAPI, args): label_selector=args.k8s_label_selector, prefill_model_labels=args.prefill_model_labels, decode_model_labels=args.decode_model_labels, + watcher_timeout_seconds=args.k8s_watcher_timeout_seconds, ) else: diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py index 2786705fc..5bace01fb 100644 --- a/src/vllm_router/parsers/parser.py +++ b/src/vllm_router/parsers/parser.py @@ -97,6 +97,8 @@ def validate_args(args): validate_static_model_types(args.static_model_types) if args.service_discovery == "k8s" and args.k8s_port is None: raise ValueError("K8s port must be provided when using K8s service discovery.") + if args.k8s_watcher_timeout_seconds <= 0: + raise ValueError("k8s-watcher-timeout-seconds must be greater than 0.") if args.routing_logic == "session" and args.session_key is None: raise ValueError( "Session key must be provided when using session routing logic." @@ -188,6 +190,12 @@ def parse_args(): default="", help="The label selector to filter vLLM pods when using K8s service discovery.", ) + parser.add_argument( + "--k8s-watcher-timeout-seconds", + type=int, + default=30, + help="Timeout in seconds for Kubernetes watcher streams (default: 30).", + ) parser.add_argument( "--routing-logic", type=str, diff --git a/src/vllm_router/service_discovery.py b/src/vllm_router/service_discovery.py index de40a3125..7dc0fd651 100644 --- a/src/vllm_router/service_discovery.py +++ b/src/vllm_router/service_discovery.py @@ -350,6 +350,7 @@ def __init__( label_selector=None, prefill_model_labels: List[str] | None = None, decode_model_labels: List[str] | None = None, + watcher_timeout_seconds: int = 30, ): """ Initialize the Kubernetes service discovery module. This module @@ -363,6 +364,7 @@ def __init__( namespace: the namespace of the engine pods port: the port of the engines label_selector: the label selector of the engines + watcher_timeout_seconds: timeout in seconds for Kubernetes watcher streams (default: 30) """ self.app = app self.namespace = namespace @@ -370,6 +372,7 @@ def __init__( self.available_engines: Dict[str, EndpointInfo] = {} self.available_engines_lock = threading.Lock() self.label_selector = label_selector + self.watcher_timeout_seconds = watcher_timeout_seconds # Init kubernetes watcher try: @@ -566,15 +569,13 @@ def _get_model_label(self, pod) -> Optional[str]: return pod.metadata.labels.get("model") def _watch_engines(self): - # TODO (ApostaC): remove the hard-coded timeouts - while self.running: try: for event in self.k8s_watcher.stream( self.k8s_api.list_namespaced_pod, namespace=self.namespace, label_selector=self.label_selector, - timeout_seconds=30, + timeout_seconds=self.watcher_timeout_seconds, ): pod = event["object"] event_type = event["type"] @@ -754,6 +755,7 @@ def __init__( label_selector=None, prefill_model_labels: List[str] | None = None, decode_model_labels: List[str] | None = None, + watcher_timeout_seconds: int = 30, ): """ Initialize the Kubernetes service discovery module. This module @@ -782,6 +784,7 @@ def __init__( namespace: the namespace of the engine services port: the port of the engines label_selector: the label selector of the engines + watcher_timeout_seconds: timeout in seconds for Kubernetes watcher streams (default: 30) """ self.app = app self.namespace = namespace @@ -789,6 +792,7 @@ def __init__( self.available_engines: Dict[str, EndpointInfo] = {} self.available_engines_lock = threading.Lock() self.label_selector = label_selector + self.watcher_timeout_seconds = watcher_timeout_seconds # Init kubernetes watcher try: @@ -988,15 +992,13 @@ def _get_model_label(self, service) -> Optional[str]: return service.spec.selector.get("model") def _watch_engines(self): - # TODO (ApostaC): remove the hard-coded timeouts - while self.running: try: for event in self.k8s_watcher.stream( self.k8s_api.list_namespaced_service, namespace=self.namespace, label_selector=self.label_selector, - timeout_seconds=30, + timeout_seconds=self.watcher_timeout_seconds, ): service = event["object"] event_type = event["type"]