Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/vllm_router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def initialize_all(app: FastAPI, args):
label_selector=args.k8s_label_selector,
prefill_model_labels=args.prefill_model_labels,
decode_model_labels=args.decode_model_labels,
watcher_timeout_seconds=args.k8s_watcher_timeout_seconds,
)

else:
Expand Down
8 changes: 8 additions & 0 deletions src/vllm_router/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def validate_args(args):
validate_static_model_types(args.static_model_types)
if args.service_discovery == "k8s" and args.k8s_port is None:
raise ValueError("K8s port must be provided when using K8s service discovery.")
if args.k8s_watcher_timeout_seconds <= 0:
raise ValueError("k8s-watcher-timeout-seconds must be greater than 0.")
if args.routing_logic == "session" and args.session_key is None:
raise ValueError(
"Session key must be provided when using session routing logic."
Expand Down Expand Up @@ -188,6 +190,12 @@ def parse_args():
default="",
help="The label selector to filter vLLM pods when using K8s service discovery.",
)
parser.add_argument(
"--k8s-watcher-timeout-seconds",
type=int,
default=30,
help="Timeout in seconds for Kubernetes watcher streams (default: 30).",
)
parser.add_argument(
"--routing-logic",
type=str,
Expand Down
14 changes: 8 additions & 6 deletions src/vllm_router/service_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def __init__(
label_selector=None,
prefill_model_labels: List[str] | None = None,
decode_model_labels: List[str] | None = None,
watcher_timeout_seconds: int = 30,
):
"""
Initialize the Kubernetes service discovery module. This module
Expand All @@ -363,13 +364,15 @@ def __init__(
namespace: the namespace of the engine pods
port: the port of the engines
label_selector: the label selector of the engines
watcher_timeout_seconds: timeout in seconds for Kubernetes watcher streams (default: 30)
"""
self.app = app
self.namespace = namespace
self.port = port
self.available_engines: Dict[str, EndpointInfo] = {}
self.available_engines_lock = threading.Lock()
self.label_selector = label_selector
self.watcher_timeout_seconds = watcher_timeout_seconds

# Init kubernetes watcher
try:
Expand Down Expand Up @@ -566,15 +569,13 @@ def _get_model_label(self, pod) -> Optional[str]:
return pod.metadata.labels.get("model")

def _watch_engines(self):
# TODO (ApostaC): remove the hard-coded timeouts

while self.running:
try:
for event in self.k8s_watcher.stream(
self.k8s_api.list_namespaced_pod,
namespace=self.namespace,
label_selector=self.label_selector,
timeout_seconds=30,
timeout_seconds=self.watcher_timeout_seconds,
):
pod = event["object"]
event_type = event["type"]
Expand Down Expand Up @@ -754,6 +755,7 @@ def __init__(
label_selector=None,
prefill_model_labels: List[str] | None = None,
decode_model_labels: List[str] | None = None,
watcher_timeout_seconds: int = 30,
):
"""
Initialize the Kubernetes service discovery module. This module
Expand Down Expand Up @@ -782,13 +784,15 @@ def __init__(
namespace: the namespace of the engine services
port: the port of the engines
label_selector: the label selector of the engines
watcher_timeout_seconds: timeout in seconds for Kubernetes watcher streams (default: 30)
"""
self.app = app
self.namespace = namespace
self.port = port
self.available_engines: Dict[str, EndpointInfo] = {}
self.available_engines_lock = threading.Lock()
self.label_selector = label_selector
self.watcher_timeout_seconds = watcher_timeout_seconds

# Init kubernetes watcher
try:
Expand Down Expand Up @@ -988,15 +992,13 @@ def _get_model_label(self, service) -> Optional[str]:
return service.spec.selector.get("model")

def _watch_engines(self):
# TODO (ApostaC): remove the hard-coded timeouts

while self.running:
try:
for event in self.k8s_watcher.stream(
self.k8s_api.list_namespaced_service,
namespace=self.namespace,
label_selector=self.label_selector,
timeout_seconds=30,
timeout_seconds=self.watcher_timeout_seconds,
):
service = event["object"]
event_type = event["type"]
Expand Down