diff --git a/miles/router/router.py b/miles/router/router.py
index 88179a293..2e8ecfc41 100644
--- a/miles/router/router.py
+++ b/miles/router/router.py
@@ -1,5 +1,7 @@
 import argparse
+import asyncio
 import json
+import logging
 
 import httpx
 import uvicorn
@@ -9,6 +11,8 @@
 
 from miles.utils.misc import load_function
 
+logger = logging.getLogger(__name__)
+
 
 def run_router(args):
     """
@@ -28,9 +32,14 @@ def __init__(self, args, verbose=False):
         self.verbose = verbose
 
         self.app = FastAPI()
-
-        # Worker information
-        self.worker_urls: dict[str, int] = {}
+        self.app.add_event_handler("startup", self._start_background_health_check)
+
+        # URL -> Active Request Count (load state)
+        self.worker_request_counts: dict[str, int] = {}
+        # URL -> Consecutive Failures
+        self.worker_failure_counts: dict[str, int] = {}
+        # Quarantined workers excluded from routing pool
+        self.dead_workers: set[str] = set()
         self.max_weight_version = None
 
         max_connections = getattr(args, "miles_router_max_connections", None)
@@ -63,9 +72,74 @@ def _setup_routes(self):
         # Catch-all route for proxying to SGLang - must be registered LAST
         self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy)
 
-    async def health_check(self, request: Request):
-        # TODO: do health check in background
-        pass
+    async def _start_background_health_check(self):
+        """Spawn the health-check loop when the FastAPI app starts."""
+        # asyncio keeps only weak references to tasks, so hold a strong one
+        # to stop the loop from being garbage-collected mid-execution.
+        self._health_check_task = asyncio.create_task(self._health_check_loop())
+
+    async def _check_worker_health(self, url):
+        """Probe ``{url}/health`` once and return ``(url, is_healthy)``.
+
+        Any non-200 status or request exception counts as unhealthy; failures
+        are logged at debug level rather than raised to the caller.
+        """
+        try:
+            response = await self.client.get(f"{url}/health", timeout=5.0)
+            if response.status_code == 200:
+                return url, True
+            logger.debug(f"[miles-router] Worker {url} is unhealthy (Status: {response.status_code})")
+        except Exception as e:
+            logger.debug(f"[miles-router] Worker {url} health check failed: {e}")
+        return url, False
+
+    async def _health_check_loop(self):
+        """Periodically probe live workers and quarantine repeat offenders.
+
+        Every ``rollout_health_check_interval`` seconds each worker not already
+        in ``dead_workers`` is probed concurrently. A worker that fails
+        ``miles_router_health_check_failure_threshold`` checks in a row is added
+        to ``dead_workers``; a successful check resets its failure count.
+        """
+        interval = self.args.rollout_health_check_interval
+        threshold = self.args.miles_router_health_check_failure_threshold
+
+        while True:
+            try:
+                await asyncio.sleep(interval)
+
+                urls = [u for u in self.worker_request_counts if u not in self.dead_workers]
+                if not urls:
+                    continue
+
+                results = await asyncio.gather(*(self._check_worker_health(url) for url in urls))
+
+                for url, is_healthy in results:
+                    if not is_healthy:
+                        failures = self.worker_failure_counts.get(url, 0) + 1
+                        self.worker_failure_counts[url] = failures
+
+                        if failures >= threshold:
+                            logger.warning(
+                                f"[miles-router] Worker {url} failed {threshold} consecutive health checks. Marking as DEAD."
+                            )
+                            self.dead_workers.add(url)
+                            # TODO (chenyang): Reconnecting 'dead' workers requires a mechanism to sync
+                            # model versions to avoid off-policy issues from stale weights, since these
+                            # dead workers' parameters may not be refitted.
+                    else:
+                        self.worker_failure_counts[url] = 0
+
+                logger.debug(
+                    f"[miles-router] Health check complete. {len(self.worker_request_counts) - len(self.dead_workers)} workers healthy."
+                )
+
+            except asyncio.CancelledError:
+                logger.warning("[miles-router] Background health check loop is being cancelled.")
+                raise
+            except Exception as e:
+                logger.error(f"[miles-router] Unexpected error in health check loop: {e}", exc_info=True)
+                await asyncio.sleep(5)
 
     async def proxy(self, request: Request, path: str):
         """Proxy all other requests to the SGLang router"""
@@ -124,16 +198,17 @@ async def add_worker(self, request: Request):
             )
 
         # Add if new, keep a simple request count per worker
-        if worker_url not in self.worker_urls:
-            self.worker_urls[worker_url] = 0
+        if worker_url not in self.worker_request_counts:
+            self.worker_request_counts[worker_url] = 0
+            self.worker_failure_counts[worker_url] = 0
             if self.verbose:
                 print(f"[miles-router] Added new worker: {worker_url}")
 
-        return {"status": "success", "worker_urls": self.worker_urls}
+        return {"status": "success", "worker_urls": self.worker_request_counts}
 
     async def list_workers(self, request: Request):
         """List all registered workers"""
-        return {"urls": list(self.worker_urls.keys())}
+        return {"urls": list(self.worker_request_counts.keys())}
 
     async def retrieve_from_text(self, request: Request):
         """Get token information from text input"""
@@ -158,19 +233,33 @@ async def retrieve_from_text(self, request: Request):
         return result
 
     def _use_url(self):
-        """Select a worker URL using round-robin strategy"""
-        assert len(self.worker_urls) > 0, "No workers available"
+        """Reserve and return the live worker with the fewest active requests.
+
+        Raises:
+            RuntimeError: If no worker is registered or every registered
+                worker is in ``dead_workers``.
+        """
+        counts = self.worker_request_counts
+        if self.dead_workers:
+            # Degraded path: select from workers not in dead_workers
+            candidates = (w for w in counts if w not in self.dead_workers)
+        else:
+            # Healthy path: select from all workers
+            candidates = counts
+        # default=None turns an empty pool on either path into the RuntimeError
+        # below instead of leaking min()'s bare ValueError.
+        url = min(candidates, key=counts.get, default=None)
+        if url is None:
+            raise RuntimeError("No healthy workers available in the pool")
 
-        # get the url with mininal count
-        url = min(self.worker_urls, key=self.worker_urls.get)
-        self.worker_urls[url] += 1
+        counts[url] += 1
         return url
 
     def _finish_url(self, url):
         """Mark the request to the given URL as finished"""
-        assert url in self.worker_urls, f"URL {url} not recognized"
-        self.worker_urls[url] -= 1
-        assert self.worker_urls[url] >= 0, f"URL {url} count went negative"
+        assert url in self.worker_request_counts, f"URL {url} not recognized"
+        self.worker_request_counts[url] -= 1
+        assert self.worker_request_counts[url] >= 0, f"URL {url} count went negative"
 
 
 if __name__ == "__main__":
diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py
index ce6e47161..cc324e46c 100644
--- a/miles/utils/arguments.py
+++ b/miles/utils/arguments.py
@@ -868,6 +868,12 @@ def add_router_arguments(parser):
         default=None,
         help="Max connections for MilesRouter HTTP client.",
     )
+    parser.add_argument(
+        "--miles-router-health-check-failure-threshold",
+        type=int,
+        default=3,
+        help="Number of consecutive failures before marking a worker as unhealthy.",
+    )
     RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
     return parser
 