Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/vllm_router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def initialize_all(app: FastAPI, args):
label_selector=args.k8s_label_selector,
prefill_model_labels=args.prefill_model_labels,
decode_model_labels=args.decode_model_labels,
timeout_seconds=args.k8s_timeout_seconds,
)

else:
Expand Down
8 changes: 8 additions & 0 deletions src/vllm_router/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def validate_args(args):
validate_static_model_types(args.static_model_types)
if args.service_discovery == "k8s" and args.k8s_port is None:
raise ValueError("K8s port must be provided when using K8s service discovery.")
if args.k8s_timeout_seconds <= 0:
raise ValueError("k8s-timeout-seconds must be greater than 0.")
if args.routing_logic == "session" and args.session_key is None:
raise ValueError(
"Session key must be provided when using session routing logic."
Expand Down Expand Up @@ -188,6 +190,12 @@ def parse_args():
default="",
help="The label selector to filter vLLM pods when using K8s service discovery.",
)
parser.add_argument(
"--k8s-timeout-seconds",
type=int,
default=30,
help="Timeout in seconds for Kubernetes watcher streams (default: 30).",
)
parser.add_argument(
"--routing-logic",
type=str,
Expand Down
14 changes: 8 additions & 6 deletions src/vllm_router/service_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def __init__(
label_selector=None,
prefill_model_labels: List[str] | None = None,
decode_model_labels: List[str] | None = None,
timeout_seconds: int = 30,
):
"""
Initialize the Kubernetes service discovery module. This module
Expand All @@ -363,13 +364,15 @@ def __init__(
namespace: the namespace of the engine pods
port: the port of the engines
label_selector: the label selector of the engines
timeout_seconds: timeout in seconds for Kubernetes watcher streams (default: 30)
"""
self.app = app
self.namespace = namespace
self.port = port
self.available_engines: Dict[str, EndpointInfo] = {}
self.available_engines_lock = threading.Lock()
self.label_selector = label_selector
self.timeout_seconds = timeout_seconds

# Init kubernetes watcher
try:
Expand Down Expand Up @@ -566,15 +569,13 @@ def _get_model_label(self, pod) -> Optional[str]:
return pod.metadata.labels.get("model")

def _watch_engines(self):
# TODO (ApostaC): remove the hard-coded timeouts

while self.running:
try:
for event in self.k8s_watcher.stream(
self.k8s_api.list_namespaced_pod,
namespace=self.namespace,
label_selector=self.label_selector,
timeout_seconds=30,
timeout_seconds=self.timeout_seconds,
):
pod = event["object"]
event_type = event["type"]
Expand Down Expand Up @@ -754,6 +755,7 @@ def __init__(
label_selector=None,
prefill_model_labels: List[str] | None = None,
decode_model_labels: List[str] | None = None,
timeout_seconds: int = 30,
):
"""
Initialize the Kubernetes service discovery module. This module
Expand Down Expand Up @@ -782,13 +784,15 @@ def __init__(
namespace: the namespace of the engine services
port: the port of the engines
label_selector: the label selector of the engines
timeout_seconds: timeout in seconds for Kubernetes watcher streams (default: 30)
"""
self.app = app
self.namespace = namespace
self.port = port
self.available_engines: Dict[str, EndpointInfo] = {}
self.available_engines_lock = threading.Lock()
self.label_selector = label_selector
self.timeout_seconds = timeout_seconds

# Init kubernetes watcher
try:
Expand Down Expand Up @@ -988,15 +992,13 @@ def _get_model_label(self, service) -> Optional[str]:
return service.spec.selector.get("model")

def _watch_engines(self):
# TODO (ApostaC): remove the hard-coded timeouts

while self.running:
try:
for event in self.k8s_watcher.stream(
self.k8s_api.list_namespaced_service,
namespace=self.namespace,
label_selector=self.label_selector,
timeout_seconds=30,
timeout_seconds=self.timeout_seconds,
):
service = event["object"]
event_type = event["type"]
Expand Down