Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/vllm_router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def initialize_all(app: FastAPI, args):
prefill_model_labels=args.prefill_model_labels,
decode_model_labels=args.decode_model_labels,
watcher_timeout_seconds=args.k8s_watcher_timeout_seconds,
health_check_timeout_seconds=args.backend_health_check_timeout_seconds,
)

else:
Expand Down
12 changes: 8 additions & 4 deletions src/vllm_router/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ def validate_args(args):
validate_static_model_types(args.static_model_types)
if args.service_discovery == "k8s" and args.k8s_port is None:
raise ValueError("K8s port must be provided when using K8s service discovery.")
if args.k8s_watcher_timeout_seconds <= 0:
raise ValueError("k8s-watcher-timeout-seconds must be greater than 0.")
if args.routing_logic == "session" and args.session_key is None:
raise ValueError(
"Session key must be provided when using session routing logic."
Expand Down Expand Up @@ -193,8 +191,14 @@ def parse_args():
parser.add_argument(
"--k8s-watcher-timeout-seconds",
type=int,
default=30,
help="Timeout in seconds for Kubernetes watcher streams (default: 30).",
default=0,
help="Timeout in seconds for Kubernetes watcher streams (default: 0).",
)
parser.add_argument(
"--backend-health-check-timeout-seconds",
type=int,
default=10,
help="Timeout in seconds for backend health check requests (default: 10).",
)
parser.add_argument(
"--routing-logic",
Expand Down
46 changes: 36 additions & 10 deletions src/vllm_router/service_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ def __init__(
label_selector=None,
prefill_model_labels: List[str] | None = None,
decode_model_labels: List[str] | None = None,
watcher_timeout_seconds: int = 30,
watcher_timeout_seconds: int = 0,
health_check_timeout_seconds: int = 10,
):
"""
Initialize the Kubernetes service discovery module. This module
Expand All @@ -364,7 +365,7 @@ def __init__(
namespace: the namespace of the engine pods
port: the port of the engines
label_selector: the label selector of the engines
watcher_timeout_seconds: timeout in seconds for Kubernetes watcher streams (default: 30)
watcher_timeout_seconds: timeout in seconds for Kubernetes watcher streams (default: 0)
"""
self.app = app
self.namespace = namespace
Expand All @@ -373,6 +374,7 @@ def __init__(
self.available_engines_lock = threading.Lock()
self.label_selector = label_selector
self.watcher_timeout_seconds = watcher_timeout_seconds
self.health_check_timeout_seconds = health_check_timeout_seconds

# Init kubernetes watcher
try:
Expand Down Expand Up @@ -426,7 +428,9 @@ def _get_engine_sleep_status(self, pod_ip) -> Optional[bool]:
if VLLM_API_KEY := os.getenv("VLLM_API_KEY"):
logger.info("Using vllm server authentication")
headers = {"Authorization": f"Bearer {VLLM_API_KEY}"}
response = requests.get(url, headers=headers)
response = requests.get(
url, headers=headers, timeout=self.health_check_timeout_seconds
)
response.raise_for_status()
sleep = response.json()["is_sleeping"]
return sleep
Expand Down Expand Up @@ -508,7 +512,9 @@ def _get_model_names(self, pod_ip) -> List[str]:
if VLLM_API_KEY := os.getenv("VLLM_API_KEY"):
logger.info("Using vllm server authentication")
headers = {"Authorization": f"Bearer {VLLM_API_KEY}"}
response = requests.get(url, headers=headers)
response = requests.get(
url, headers=headers, timeout=self.health_check_timeout_seconds
)
response.raise_for_status()
models = response.json()["data"]

Expand Down Expand Up @@ -540,7 +546,9 @@ def _get_model_info(self, pod_ip) -> Dict[str, ModelInfo]:
if VLLM_API_KEY := os.getenv("VLLM_API_KEY"):
logger.info("Using vllm server authentication")
headers = {"Authorization": f"Bearer {VLLM_API_KEY}"}
response = requests.get(url, headers=headers)
response = requests.get(
url, headers=headers, timeout=self.health_check_timeout_seconds
)
response.raise_for_status()
models = response.json()["data"]
# Create a dictionary of model information
Expand Down Expand Up @@ -582,6 +590,11 @@ def _watch_engines(self):
pod_name = pod.metadata.name
pod_ip = pod.status.pod_ip

if event_type == "DELETED":
if pod_name in self.available_engines:
self._delete_engine(pod_name)
continue

# Check if pod is terminating
is_pod_terminating = self._is_pod_terminating(pod)
is_container_ready = self._check_pod_ready(
Expand Down Expand Up @@ -755,7 +768,8 @@ def __init__(
label_selector=None,
prefill_model_labels: List[str] | None = None,
decode_model_labels: List[str] | None = None,
watcher_timeout_seconds: int = 30,
watcher_timeout_seconds: int = 0,
health_check_timeout_seconds: int = 10,
):
"""
Initialize the Kubernetes service discovery module. This module
Expand Down Expand Up @@ -784,7 +798,8 @@ def __init__(
namespace: the namespace of the engine services
port: the port of the engines
label_selector: the label selector of the engines
watcher_timeout_seconds: timeout in seconds for Kubernetes watcher streams (default: 30)
watcher_timeout_seconds: timeout in seconds for Kubernetes watcher streams (default: 0)
health_check_timeout_seconds: timeout in seconds for health check requests (default: 10)
"""
self.app = app
self.namespace = namespace
Expand All @@ -793,6 +808,7 @@ def __init__(
self.available_engines_lock = threading.Lock()
self.label_selector = label_selector
self.watcher_timeout_seconds = watcher_timeout_seconds
self.health_check_timeout_seconds = health_check_timeout_seconds

# Init kubernetes watcher
try:
Expand Down Expand Up @@ -837,7 +853,9 @@ def _get_engine_sleep_status(self, service_name) -> Optional[bool]:
if VLLM_API_KEY := os.getenv("VLLM_API_KEY"):
logger.info("Using vllm server authentication")
headers = {"Authorization": f"Bearer {VLLM_API_KEY}"}
response = requests.get(url, headers=headers)
response = requests.get(
url, headers=headers, timeout=self.health_check_timeout_seconds
)
response.raise_for_status()
sleep = response.json()["is_sleeping"]
return sleep
Expand Down Expand Up @@ -931,7 +949,9 @@ def _get_model_names(self, service_name) -> List[str]:
if VLLM_API_KEY := os.getenv("VLLM_API_KEY"):
logger.info("Using vllm server authentication")
headers = {"Authorization": f"Bearer {VLLM_API_KEY}"}
response = requests.get(url, headers=headers)
response = requests.get(
url, headers=headers, timeout=self.health_check_timeout_seconds
)
response.raise_for_status()
models = response.json()["data"]

Expand Down Expand Up @@ -963,7 +983,9 @@ def _get_model_info(self, service_name) -> Dict[str, ModelInfo]:
if VLLM_API_KEY := os.getenv("VLLM_API_KEY"):
logger.info("Using vllm server authentication")
headers = {"Authorization": f"Bearer {VLLM_API_KEY}"}
response = requests.get(url, headers=headers)
response = requests.get(
url, headers=headers, timeout=self.health_check_timeout_seconds
)
response.raise_for_status()
models = response.json()["data"]
# Create a dictionary of model information
Expand Down Expand Up @@ -1002,6 +1024,10 @@ def _watch_engines(self):
):
service = event["object"]
event_type = event["type"]
if event_type == "DELETED":
if service.metadata.name in self.available_engines:
self._delete_engine(service.metadata.name)
continue
service_name = service.metadata.name
is_service_ready = self._check_service_ready(
service_name, self.namespace
Expand Down