From d0cceb89dc37e321f1278d5c54627a84f2690970 Mon Sep 17 00:00:00 2001 From: Dong Wang Date: Mon, 2 Feb 2026 00:25:19 -0800 Subject: [PATCH 01/14] WIP: diffuser router --- examples/diffusion_router/README.md | 69 +++++++ examples/diffusion_router/demo.py | 57 ++++++ miles/router/diffusion_router.py | 281 ++++++++++++++++++++++++++++ 3 files changed, 407 insertions(+) create mode 100644 examples/diffusion_router/README.md create mode 100644 examples/diffusion_router/demo.py create mode 100644 miles/router/diffusion_router.py diff --git a/examples/diffusion_router/README.md b/examples/diffusion_router/README.md new file mode 100644 index 000000000..135b24f9a --- /dev/null +++ b/examples/diffusion_router/README.md @@ -0,0 +1,69 @@ +# Miles Diffusion Router + +Load-balances requests across multiple `sglang-diffusion` worker instances using least-request routing with background health checks and worker quarantine. + +## Quick Start + +```bash +# Start the router with two diffusion backends +python examples/diffusion_router/demo.py --port 30080 \ + --worker-urls http://localhost:10090 http://localhost:10091 + +# Or start empty and add workers dynamically +python examples/diffusion_router/demo.py --port 30080 +curl -X POST 'http://localhost:30080/add_worker?url=http://localhost:10090' +``` + +## API Endpoints + +| Method | Path | Description | +|--------|------|-------------| +| `POST` | `/generate` | Image generation (forwards to `/v1/images/generations`) | +| `POST` | `/generate_video` | Video generation (forwards to `/v1/videos/generations`) | +| `GET` | `/health` | Aggregated router health | +| `GET` | `/health_workers` | Per-worker health and load info | +| `POST` | `/add_worker` | Register a diffusion worker (`?url=...` or JSON body) | +| `GET` | `/list_workers` | List registered workers | +| `POST` | `/update_weights_from_disk` | Broadcast weight reload to all workers | +| `*` | `/{path}` | Catch-all proxy to least-loaded worker | + +## Load Balancing + +The router uses a **least-request** strategy: each incoming request is forwarded to the worker with the fewest in-flight requests. This is workload-aware and avoids hot-spotting compared to round-robin. When a request completes, the worker's count is decremented, keeping the load state accurate in real time. + +Workers that fail consecutive health checks (default: 3) are quarantined and excluded from the routing pool. A background loop pings each worker's `GET /health` endpoint at a configurable interval (default: 10s). + +## Notes + +- Health check endpoint follows Miles/SGLang convention: `GET /health`. +- Responses are fully buffered; streaming and large-response handling are not supported yet (planned for a follow-up PR). + +## Example Requests + +```bash +# Check health +curl http://localhost:30080/health + +# Generate an image +curl -X POST http://localhost:30080/generate \ + -H 'Content-Type: application/json' \ + -d '{"model": "stabilityai/stable-diffusion-3", "prompt": "a cat", "n": 1, "size": "1024x1024"}' + +# Reload weights on all workers +curl -X POST http://localhost:30080/update_weights_from_disk \ + -H 'Content-Type: application/json' \ + -d '{"model_path": "/path/to/new/weights"}' +``` + +## CLI Options + +``` +--host Bind address (default: 0.0.0.0) +--port Port (default: 30080) +--worker-urls Initial worker URLs +--max-connections Max concurrent connections (default: 100) +--timeout Request timeout in seconds +--health-check-interval Seconds between health checks (default: 10) +--health-check-failure-threshold Failures before quarantine (default: 3) +--verbose Enable verbose logging +``` diff --git a/examples/diffusion_router/demo.py b/examples/diffusion_router/demo.py new file mode 100644 index 000000000..4bb17d522 --- /dev/null +++ b/examples/diffusion_router/demo.py @@ -0,0 +1,57 @@ +""" +Demo script for the Miles Diffusion Router. + +Starts a diffusion router that load-balances requests across multiple +sglang-diffusion worker instances using least-request routing. + +Usage: + # Start with no workers (add them dynamically via /add_worker): + python examples/diffusion_router/demo.py --port 30080 + + # Start with pre-registered workers: + python examples/diffusion_router/demo.py --port 30080 \ + --worker-urls http://localhost:10090 http://localhost:10091 + + # Then interact: + curl http://localhost:30080/health + curl -X POST 'http://localhost:30080/add_worker?url=http://localhost:10092' + curl http://localhost:30080/list_workers + curl -X POST http://localhost:30080/generate -H 'Content-Type: application/json' \ + -d '{"model": "stabilityai/stable-diffusion-3", "prompt": "a cat", "n": 1, "size": "1024x1024"}' +""" + +import argparse + +import uvicorn + +from miles.router.diffusion_router import DiffusionRouter + + +def main(): + parser = argparse.ArgumentParser(description="Miles Diffusion Router Demo") + parser.add_argument("--host", type=str, default="0.0.0.0", help="Router bind address") + parser.add_argument("--port", type=int, default=30080, help="Router port") + parser.add_argument("--worker-urls", nargs="*", default=[], help="Initial diffusion worker URLs") + parser.add_argument("--max-connections", type=int, default=100, help="Max concurrent connections to workers") + parser.add_argument("--timeout", type=float, default=None, help="Request timeout in seconds") + parser.add_argument("--health-check-interval", type=int, default=10, help="Seconds between health checks") + parser.add_argument("--health-check-failure-threshold", type=int, default=3, help="Failures before quarantine") + parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") + args = parser.parse_args() + + router = DiffusionRouter(args, verbose=args.verbose) + + # Pre-register any workers specified on the command line + for url in args.worker_urls: + router.worker_request_counts[url] = 0 + router.worker_failure_counts[url] = 0 + if args.verbose: + print(f"[demo] Pre-registered worker: {url}") + + print(f"[demo] Starting diffusion router on {args.host}:{args.port}") + print(f"[demo] Workers: {list(router.worker_request_counts.keys()) or '(none — add via POST /add_worker)'}") + uvicorn.run(router.app, host=args.host, port=args.port, log_level="info") + + +if __name__ == "__main__": + main() diff --git a/miles/router/diffusion_router.py b/miles/router/diffusion_router.py new file mode 100644 index 000000000..2b3b306c9 --- /dev/null +++ b/miles/router/diffusion_router.py @@ -0,0 +1,281 @@ +import argparse +import asyncio +import json +import logging + +import httpx +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from starlette.responses import Response + +logger = logging.getLogger(__name__) + + +def run_diffusion_router(args): + """Run the diffusion router with the specified configuration.""" + router = DiffusionRouter(args) + uvicorn.run(router.app, host=args.host, port=args.port, log_level="info") + + +class DiffusionRouter: + def __init__(self, args, verbose=False): + """Initialize the diffusion router for load-balancing across sglang-diffusion workers.""" + self.args = args + self.verbose = verbose + + self.app = FastAPI() + self.app.add_event_handler("startup", self._start_background_health_check) + + # URL -> Active Request Count (load state) + self.worker_request_counts: dict[str, int] = {} + # URL -> Consecutive Failures + self.worker_failure_counts: dict[str, int] = {} + # Quarantined workers excluded from routing pool + self.dead_workers: set[str] = set() + + max_connections = getattr(args, "max_connections", 100) + timeout = getattr(args, "timeout", None) + + self.client = httpx.AsyncClient( + limits=httpx.Limits(max_connections=max_connections), + timeout=httpx.Timeout(timeout), + ) + + self._setup_routes() + + def _setup_routes(self): + """Setup all the HTTP routes.""" + self.app.post("/add_worker")(self.add_worker) + self.app.get("/list_workers")(self.list_workers) + self.app.get("/health")(self.health) + self.app.get("/health_workers")(self.health_workers) + self.app.post("/generate")(self.generate) + self.app.post("/generate_video")(self.generate_video) + self.app.post("/update_weights_from_disk")(self.update_weights_from_disk) + # Catch-all route for proxying — must be registered LAST + self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy) + + # ── Health checks ──────────────────────────────────────────────── + + async def _start_background_health_check(self): + asyncio.create_task(self._health_check_loop()) + + async def _check_worker_health(self, url): + try: + response = await self.client.get(f"{url}/health", timeout=5.0) + if response.status_code == 200: + return url, True + logger.debug(f"[diffusion-router] Worker {url} unhealthy (status {response.status_code})") + except Exception as e: + logger.debug(f"[diffusion-router] Worker {url} health check failed: {e}") + return url, False + + async def _health_check_loop(self): + """Background loop to monitor worker health and quarantine failing workers.""" + interval = getattr(self.args, "health_check_interval", 10) + threshold = getattr(self.args, "health_check_failure_threshold", 3) + + while True: + try: + await asyncio.sleep(interval) + + urls = [u for u in self.worker_request_counts if u not in self.dead_workers] + if not urls: + continue + + results = await asyncio.gather(*(self._check_worker_health(url) for url in urls)) + + for url, is_healthy in results: + if not is_healthy: + failures = self.worker_failure_counts.get(url, 0) + 1 + self.worker_failure_counts[url] = failures + if failures >= threshold: + logger.warning( + f"[diffusion-router] Worker {url} failed {threshold} consecutive checks. Marking DEAD." + ) + self.dead_workers.add(url) + else: + self.worker_failure_counts[url] = 0 + + healthy = len(self.worker_request_counts) - len(self.dead_workers) + logger.debug(f"[diffusion-router] Health check complete. {healthy} workers healthy.") + + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"[diffusion-router] Unexpected error in health check loop: {e}", exc_info=True) + await asyncio.sleep(5) + + # ── Load balancing ─────────────────────────────────────────────── + + def _use_url(self): + """Select worker URL with minimal active requests.""" + if not self.worker_request_counts: + raise RuntimeError("No workers registered in the pool") + if not self.dead_workers: + url = min(self.worker_request_counts, key=self.worker_request_counts.get) + else: + valid_workers = (w for w in self.worker_request_counts if w not in self.dead_workers) + try: + url = min(valid_workers, key=self.worker_request_counts.get) + except ValueError: + raise RuntimeError("No healthy workers available in the pool") from None + + self.worker_request_counts[url] += 1 + return url + + def _finish_url(self, url): + """Mark the request to the given URL as finished.""" + assert url in self.worker_request_counts, f"URL {url} not recognized" + self.worker_request_counts[url] -= 1 + assert self.worker_request_counts[url] >= 0, f"URL {url} count went negative" + + # ── Proxy helpers ──────────────────────────────────────────────── + + async def _forward_to_worker(self, request: Request, path: str) -> Response: + """Forward a request to the least-loaded worker and return the response.""" + try: + worker_url = self._use_url() + except RuntimeError as exc: + return JSONResponse(status_code=503, content={"error": str(exc)}) + + # TODO: Support streaming responses; current implementation buffers full response. + query = request.url.query + url = f"{worker_url}/{path}" if not query else f"{worker_url}/{path}?{query}" + body = await request.body() + headers = dict(request.headers) + + try: + response = await self.client.request(request.method, url, content=body, headers=headers) + content = await response.aread() + finally: + self._finish_url(worker_url) + + resp_headers = self._sanitize_response_headers(response.headers) + content_type = resp_headers.get("content-type", "") + try: + data = json.loads(content) + return JSONResponse(content=data, status_code=response.status_code, headers=resp_headers) + except Exception: + return Response( + content=content, status_code=response.status_code, headers=resp_headers, media_type=content_type + ) + + async def _broadcast_to_workers(self, path: str, body: bytes, headers: dict) -> list[dict]: + """Send a request to ALL healthy workers and collect results.""" + urls = [u for u in self.worker_request_counts if u not in self.dead_workers] + if not urls: + return [] + + async def _send(worker_url): + try: + response = await self.client.post(f"{worker_url}/{path}", content=body, headers=headers) + content = await response.aread() + return {"worker_url": worker_url, "status_code": response.status_code, "body": json.loads(content)} + except Exception as e: + return {"worker_url": worker_url, "status_code": 502, "body": {"error": str(e)}} + + return await asyncio.gather(*(_send(u) for u in urls)) + + @staticmethod + def _sanitize_response_headers(headers) -> dict: + """Remove hop-by-hop and encoding headers that no longer match buffered content.""" + hop_by_hop = {"connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", + "transfer-encoding", "upgrade"} + dropped = {"content-length", "content-encoding"} + return {k: v for k, v in headers.items() if k.lower() not in hop_by_hop | dropped} + + # ── Route handlers ─────────────────────────────────────────────── + + async def generate(self, request: Request): + """Route image generation to the least-loaded worker via /v1/images/generations.""" + return await self._forward_to_worker(request, "v1/images/generations") + + async def generate_video(self, request: Request): + """Route video generation to the least-loaded worker via /v1/videos/generations.""" + return await self._forward_to_worker(request, "v1/videos/generations") + + async def health(self, request: Request): + """Aggregated health status: healthy if at least one worker is alive.""" + total = len(self.worker_request_counts) + dead = len(self.dead_workers) + healthy = total - dead + status = "healthy" if healthy > 0 else "unhealthy" + code = 200 if healthy > 0 else 503 + return JSONResponse( + status_code=code, + content={"status": status, "healthy_workers": healthy, "total_workers": total}, + ) + + async def health_workers(self, request: Request): + """Per-worker health and load information.""" + workers = [] + for url, count in self.worker_request_counts.items(): + workers.append({ + "url": url, + "active_requests": count, + "is_dead": url in self.dead_workers, + "consecutive_failures": self.worker_failure_counts.get(url, 0), + }) + return JSONResponse(content={"workers": workers}) + + async def update_weights_from_disk(self, request: Request): + """Broadcast weight reload to all healthy workers.""" + body = await request.body() + headers = dict(request.headers) + results = await self._broadcast_to_workers("update_weights_from_disk", body, headers) + return JSONResponse(content={"results": results}) + + async def add_worker(self, request: Request): + """Register a new diffusion worker.""" + worker_url = request.query_params.get("url") or request.query_params.get("worker_url") + + if not worker_url: + body = await request.body() + try: + payload = json.loads(body) if body else {} + except json.JSONDecodeError: + return JSONResponse(status_code=400, content={"error": "Invalid JSON body"}) + worker_url = payload.get("url") or payload.get("worker_url") + + if not worker_url: + return JSONResponse( + status_code=400, content={"error": "worker_url is required (use query ?url=... or JSON body)"} + ) + + if worker_url not in self.worker_request_counts: + self.worker_request_counts[worker_url] = 0 + self.worker_failure_counts[worker_url] = 0 + if self.verbose: + print(f"[diffusion-router] Added new worker: {worker_url}") + + return {"status": "success", "worker_urls": list(self.worker_request_counts.keys())} + + async def list_workers(self, request: Request): + """List all registered workers.""" + return {"urls": list(self.worker_request_counts.keys())} + + async def proxy(self, request: Request, path: str): + """Catch-all: forward any unmatched request to the least-loaded worker.""" + return await self._forward_to_worker(request, path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Miles Diffusion Router") + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=30080) + parser.add_argument("--worker-urls", nargs="*", default=[], help="Initial worker URLs to register") + parser.add_argument("--max-connections", type=int, default=100) + parser.add_argument("--timeout", type=float, default=None) + parser.add_argument("--health-check-interval", type=int, default=10) + parser.add_argument("--health-check-failure-threshold", type=int, default=3) + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + router = DiffusionRouter(args, verbose=args.verbose) + for url in args.worker_urls: + router.worker_request_counts[url] = 0 + router.worker_failure_counts[url] = 0 + + uvicorn.run(router.app, host=args.host, port=args.port, log_level="info") From 46a2e570cc2420877db47efc3146e2ba44610207 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=99=A8=E9=98=B3?= Date: Wed, 4 Feb 2026 15:14:49 -0800 Subject: [PATCH 02/14] Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- examples/diffusion_router/demo.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/diffusion_router/demo.py b/examples/diffusion_router/demo.py index 4bb17d522..9c1b4693e 100644 --- a/examples/diffusion_router/demo.py +++ b/examples/diffusion_router/demo.py @@ -43,8 +43,7 @@ def main(): # Pre-register any workers specified on the command line for url in args.worker_urls: - router.worker_request_counts[url] = 0 - router.worker_failure_counts[url] = 0 + router.add_worker_sync(url) if args.verbose: print(f"[demo] Pre-registered worker: {url}") From d39899d62193a02d7f6b67a36a9e4dabbbac8a2a Mon Sep 17 00:00:00 2001 From: Dong Wang Date: Thu, 5 Feb 2026 01:10:20 -0800 Subject: [PATCH 03/14] addressed comments: docs --- examples/diffusion_router/README.md | 4 ++-- miles/router/diffusion_router.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/diffusion_router/README.md b/examples/diffusion_router/README.md index 135b24f9a..b1fd0cc09 100644 --- a/examples/diffusion_router/README.md +++ b/examples/diffusion_router/README.md @@ -25,7 +25,7 @@ curl -X POST 'http://localhost:30080/add_worker?url=http://localhost:10090' | `POST` | `/add_worker` | Register a diffusion worker (`?url=...` or JSON body) | | `GET` | `/list_workers` | List registered workers | | `POST` | `/update_weights_from_disk` | Broadcast weight reload to all workers | -| `*` | `/{path}` | Catch-all proxy to least-loaded worker | +| `GET, POST, PUT, DELETE` | `/{path}` | Catch-all proxy to least-loaded worker | ## Load Balancing @@ -62,7 +62,7 @@ curl -X POST http://localhost:30080/update_weights_from_disk \ --port Port (default: 30080) --worker-urls Initial worker URLs --max-connections Max concurrent connections (default: 100) ---timeout Request timeout in seconds +--timeout Request timeout in seconds for router-to-worker requests --health-check-interval Seconds between health checks (default: 10) --health-check-failure-threshold Failures before quarantine (default: 3) --verbose Enable verbose logging diff --git a/miles/router/diffusion_router.py b/miles/router/diffusion_router.py index 2b3b306c9..121ba8e26 100644 --- a/miles/router/diffusion_router.py +++ b/miles/router/diffusion_router.py @@ -95,6 +95,8 @@ async def _health_check_loop(self): f"[diffusion-router] Worker {url} failed {threshold} consecutive checks. Marking DEAD." ) self.dead_workers.add(url) + # Dead workers are permanently excluded. Reconnecting them + # would risk serving stale weights after training has moved on. else: self.worker_failure_counts[url] = 0 From a3c34a52012837f203a47b326f09a3b37c5dadf5 Mon Sep 17 00:00:00 2001 From: sniper35 Date: Sat, 7 Feb 2026 08:27:05 +0000 Subject: [PATCH 04/14] fix func call --- examples/diffusion_router/demo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/diffusion_router/demo.py b/examples/diffusion_router/demo.py index 9c1b4693e..4bb17d522 100644 --- a/examples/diffusion_router/demo.py +++ b/examples/diffusion_router/demo.py @@ -43,7 +43,8 @@ def main(): # Pre-register any workers specified on the command line for url in args.worker_urls: - router.add_worker_sync(url) + router.worker_request_counts[url] = 0 + router.worker_failure_counts[url] = 0 if args.verbose: print(f"[demo] Pre-registered worker: {url}") From 1d93451354a218b391bc6ab3079e56d33681f84a Mon Sep 17 00:00:00 2001 From: sniper35 Date: Sat, 7 Feb 2026 08:35:28 +0000 Subject: [PATCH 05/14] simplify select url logic --- miles/router/diffusion_router.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/miles/router/diffusion_router.py b/miles/router/diffusion_router.py index 121ba8e26..e69497420 100644 --- a/miles/router/diffusion_router.py +++ b/miles/router/diffusion_router.py @@ -115,14 +115,12 @@ def _use_url(self): """Select worker URL with minimal active requests.""" if not self.worker_request_counts: raise RuntimeError("No workers registered in the pool") - if not self.dead_workers: - url = min(self.worker_request_counts, key=self.worker_request_counts.get) - else: - valid_workers = (w for w in self.worker_request_counts if w not in self.dead_workers) - try: - url = min(valid_workers, key=self.worker_request_counts.get) - except ValueError: - raise RuntimeError("No healthy workers available in the pool") from None + + valid_workers = (w for w in self.worker_request_counts if w not in self.dead_workers) + try: + url = min(valid_workers, key=self.worker_request_counts.get) + except ValueError: + raise RuntimeError("No healthy workers available in the pool") from None self.worker_request_counts[url] += 1 return url From fb4570a3aca2f1ce76746b1e741e54229b796c0b Mon Sep 17 00:00:00 2001 From: Dong Wang Date: Sat, 7 Feb 2026 10:52:55 -0800 Subject: [PATCH 06/14] Revert "fix func call" This reverts commit 021cf6b15486676068db53e474c288838efe19ce. --- examples/diffusion_router/demo.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/diffusion_router/demo.py b/examples/diffusion_router/demo.py index 4bb17d522..9c1b4693e 100644 --- a/examples/diffusion_router/demo.py +++ b/examples/diffusion_router/demo.py @@ -43,8 +43,7 @@ def main(): # Pre-register any workers specified on the command line for url in args.worker_urls: - router.worker_request_counts[url] = 0 - router.worker_failure_counts[url] = 0 + router.add_worker_sync(url) if args.verbose: print(f"[demo] Pre-registered worker: {url}") From ad9d879300bae6271108c47894747162094d9a83 Mon Sep 17 00:00:00 2001 From: sniper35 Date: Sun, 8 Feb 2026 07:32:14 +0000 Subject: [PATCH 07/14] added bench scripts and tests --- examples/diffusion_router/bench_router.py | 500 ++++++++++++++++++ .../bench_routing_algorithms.py | 287 ++++++++++ examples/diffusion_router/demo.py | 3 + miles/router/diffusion_router.py | 65 ++- tests/fast/router/test_diffusion_router.py | 167 ++++++ 5 files changed, 1006 insertions(+), 16 deletions(-) create mode 100644 examples/diffusion_router/bench_router.py create mode 100644 examples/diffusion_router/bench_routing_algorithms.py create mode 100644 tests/fast/router/test_diffusion_router.py diff --git a/examples/diffusion_router/bench_router.py b/examples/diffusion_router/bench_router.py new file mode 100644 index 000000000..ed9acce0c --- /dev/null +++ b/examples/diffusion_router/bench_router.py @@ -0,0 +1,500 @@ +#!/usr/bin/env python3 +""" +Launch sglang-diffusion workers, start the Miles DiffusionRouter, then run +the sglang diffusion serving benchmark against the router. + +Example: + python examples/diffusion_router/bench_router.py \ + --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --num-workers 2 \ + --num-prompts 20 \ + --max-concurrency 4 +""" + +from __future__ import annotations + +import argparse +import os +import shlex +import signal +import socket +import subprocess +import sys +import time +from pathlib import Path +from typing import Iterable + +import requests + + +def _require_non_empty_model(model: str) -> str: + normalized = model.strip() + if not normalized: + raise ValueError( + "--model must be a non-empty model ID/path. " + "Detected an empty value, which often means a shell variable such as " + "$MODEL was unset." + ) + return normalized + + +def _default_sglang_root() -> Path: + # Repo layout: miles/examples/diffusion_router/bench_router.py -> miles/ (parents[2]) + return Path(__file__).resolve().parents[2].parent / "sglang" + + +def _resolve_sglang_root(path: str | None) -> Path | None: + if path: + root = Path(path).expanduser().resolve() + sglang_pkg = root / "python" / "sglang" + if not sglang_pkg.exists(): + raise FileNotFoundError(f"sglang source not found at {root}. Expected {sglang_pkg}.") + return root + default = _default_sglang_root() + if (default / "python" / "sglang").exists(): + return default + # No source repo found — fall back to pip-installed sglang + return None + + +def _with_pythonpath(env: dict[str, str], extra_path: Path) -> dict[str, str]: + env = dict(env) + existing = env.get("PYTHONPATH") + extra = str(extra_path) + env["PYTHONPATH"] = f"{extra}{os.pathsep}{existing}" if existing else extra + return env + + +def _build_sglang_cli_cmd() -> list[str]: + """ + Build a command prefix that invokes the `sglang` CLI from the current + Python environment. + """ + sglang_bin = Path(sys.executable).resolve().parent / "sglang" + if sglang_bin.exists(): + return [str(sglang_bin)] + + # Fallback when the console script is missing. + return [sys.executable, "-c", "from sglang.cli.main import main; main()"] + + +def _wait_for_health( + url: str, timeout: int, label: str, proc: subprocess.Popen | None = None, +) -> None: + start = time.time() + last_print = 0.0 + while True: + elapsed = time.time() - start + + # Fail fast if the backing process has already exited + if proc is not None and proc.poll() is not None: + raise RuntimeError( + f"{label} process exited with code {proc.returncode}. " + "Run the worker command directly to see the error." + ) + + try: + resp = requests.get(f"{url}/health", timeout=1) + if resp.status_code == 200: + print(f" [bench] {label} is healthy ({elapsed:.0f}s)", flush=True) + return + except requests.RequestException: + pass + + if elapsed - last_print >= 30: + print(f" [bench] Still waiting for {label}... ({elapsed:.0f}s elapsed)", flush=True) + last_print = elapsed + + if elapsed > timeout: + raise TimeoutError(f"Timed out waiting for {label} at {url}.") + time.sleep(1) + + +def _build_worker_urls(host: str, base_port: int, count: int, stride: int) -> list[str]: + return [f"http://{host}:{base_port + i * stride}" for i in range(count)] + + +def _infer_client_host(host: str) -> str: + if host in ("0.0.0.0", "::"): + return "127.0.0.1" + return host + + +def _is_port_available(host: str, port: int) -> bool: + if host in ("0.0.0.0", "::"): + host = "127.0.0.1" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(0.5) + return sock.connect_ex((host, port)) != 0 + + +def _reserve_available_port(host: str, preferred_port: int, used_ports: set[int]) -> int: + if preferred_port < 1 or preferred_port > 65535: + raise ValueError(f"Invalid port: {preferred_port}") + + for port in range(preferred_port, 65536): + if port in used_ports: + continue + if _is_port_available(host, port): + used_ports.add(port) + return port + + for port in range(1024, preferred_port): + if port in used_ports: + continue + if _is_port_available(host, port): + used_ports.add(port) + return port + + raise RuntimeError( + f"Unable to reserve a free port for host {host}. " + f"Preferred start={preferred_port}." + ) + + +def _parse_gpu_id_list(raw: str) -> list[str]: + return [item.strip() for item in raw.split(",") if item.strip()] + + +def _detect_gpu_count() -> int: + try: + import torch + + return int(torch.cuda.device_count()) + except Exception: + return 0 + + +def _resolve_gpu_pool(args: argparse.Namespace, env: dict[str, str]) -> list[str] | None: + if args.worker_gpu_ids: + return [str(x) for x in args.worker_gpu_ids] + + visible = env.get("CUDA_VISIBLE_DEVICES", "") + if visible: + parsed = _parse_gpu_id_list(visible) + if parsed: + return parsed + + gpu_count = _detect_gpu_count() + if gpu_count > 0: + return [str(i) for i in range(gpu_count)] + return None + + +def _terminate_all(processes: Iterable[subprocess.Popen]) -> None: + procs = list(processes) + + def _signal_group(proc: subprocess.Popen, sig: int) -> None: + try: + # Processes are launched with start_new_session=True so each has its own group. + os.killpg(proc.pid, sig) + except ProcessLookupError: + pass + except Exception: + if proc.poll() is None: + try: + os.kill(proc.pid, sig) + except ProcessLookupError: + pass + + for proc in procs: + _signal_group(proc, signal.SIGTERM) + + for proc in procs: + try: + proc.wait(timeout=15) + except subprocess.TimeoutExpired: + _signal_group(proc, signal.SIGKILL) + + for proc in procs: + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + pass + + +def main() -> int: + parser = argparse.ArgumentParser(description="Benchmark Miles DiffusionRouter with sglang bench_serving.") + parser.add_argument("--model", type=str, required=True, help="Diffusion model HF ID or local path.") + parser.add_argument("--sglang-root", type=str, default=None, help="Path to sglang repo (default: ../sglang).") + + parser.add_argument("--router-host", type=str, default="127.0.0.1", help="Router bind host.") + parser.add_argument("--router-port", type=int, default=30080, help="Router port.") + parser.add_argument("--routing-algorithm", type=str, default="least-request", + choices=["least-request", "round-robin", "random"], + help="Load-balancing algorithm for the router.") + parser.add_argument("--router-verbose", action="store_true", help="Enable router verbose logging.") + parser.add_argument("--router-extra-args", type=str, default="", help="Extra args for the router demo script.") + + parser.add_argument("--worker-host", type=str, default="127.0.0.1", help="Worker bind host.") + parser.add_argument("--worker-urls", nargs="*", default=[], help="Existing worker URLs to use.") + parser.add_argument("--num-workers", type=int, default=1, help="Number of workers to launch.") + parser.add_argument("--worker-base-port", type=int, default=10090, help="Base port for launched workers.") + parser.add_argument( + "--worker-port-stride", + type=int, + default=2, + help="Port increment between launched workers. Keep >=2 to avoid sglang internal port collisions.", + ) + parser.add_argument( + "--worker-master-port-base", + type=int, + default=30005, + help="Base torch distributed master port for launched workers.", + ) + parser.add_argument( + "--worker-scheduler-port-base", + type=int, + default=5555, + help="Base scheduler port for launched workers.", + ) + parser.add_argument( + "--worker-internal-port-stride", + type=int, + default=1000, + help=( + "Stride used between workers for master/scheduler base ports. " + "Use >= 101 because sglang randomizes each by +[0,100]." + ), + ) + parser.add_argument("--num-gpus-per-worker", type=int, default=1, help="GPUs per worker.") + parser.add_argument( + "--worker-gpu-ids", + nargs="*", + default=None, + help=( + "Optional GPU IDs/UUIDs for launched workers. They are consumed in order, " + "in groups of --num-gpus-per-worker." + ), + ) + parser.add_argument("--worker-extra-args", type=str, default="", help="Extra args for `sglang serve`.") + parser.add_argument("--skip-workers", action="store_true", help="Do not launch workers.") + + parser.add_argument("--dataset", type=str, default="random", choices=["vbench", "random"]) + parser.add_argument("--dataset-path", type=str, default=None) + parser.add_argument("--num-prompts", type=int, default=20) + parser.add_argument("--max-concurrency", type=int, default=1) + parser.add_argument("--request-rate", type=float, default=float("inf")) + parser.add_argument("--task", type=str, default=None) + parser.add_argument("--width", type=int, default=None) + parser.add_argument("--height", type=int, default=None) + parser.add_argument("--num-frames", type=int, default=None) + parser.add_argument("--fps", type=int, default=None) + parser.add_argument("--output-file", type=str, default=None) + parser.add_argument("--disable-tqdm", action="store_true") + parser.add_argument("--log-level", type=str, default="INFO") + parser.add_argument("--bench-extra-args", type=str, default="", help="Extra args for bench_serving.") + + parser.add_argument("--wait-timeout", type=int, default=1200, help="Seconds to wait for services to be healthy.") + + args = parser.parse_args() + args.model = _require_non_empty_model(args.model) + + sglang_root = _resolve_sglang_root(args.sglang_root) + if sglang_root is not None: + sglang_python = sglang_root / "python" + env = _with_pythonpath(os.environ, sglang_python) + else: + # Verify pip-installed sglang is importable + try: + import sglang # noqa: F401 + except ImportError: + raise RuntimeError( + "sglang is not installed and no source repo found at ../sglang.\n" + "Install with: uv pip install \"sglang[diffusion]\" --prerelease=allow\n" + "Or point to the source repo with: --sglang-root /path/to/sglang" + ) + env = dict(os.environ) + + worker_urls = list(args.worker_urls) + if not worker_urls: + if args.worker_port_stride < 1: + raise ValueError("--worker-port-stride must be >= 1") + if args.worker_internal_port_stride < 101: + raise ValueError("--worker-internal-port-stride must be >= 101") + worker_urls = _build_worker_urls( + args.worker_host, + args.worker_base_port, + args.num_workers, + args.worker_port_stride, + ) + + if args.skip_workers and not worker_urls: + raise ValueError("No workers specified. Provide --worker-urls or disable --skip-workers.") + + if not _is_port_available(args.router_host, args.router_port): + raise RuntimeError( + f"Router port {args.router_port} on {args.router_host} is already in use. " + "Stop the existing router/process or change --router-port." + ) + + processes: list[subprocess.Popen] = [] + try: + if not args.skip_workers: + reserved_ports: set[int] = {args.router_port} + worker_internal_ports: list[tuple[int, int]] = [] + for url in worker_urls: + port = int(url.rsplit(":", 1)[1]) + if not _is_port_available(args.worker_host, port): + raise RuntimeError( + f"Worker port {port} on {args.worker_host} is already in use. " + "Stop existing servers or change --worker-base-port/--worker-port-stride." + ) + reserved_ports.add(port) + + for i, _ in enumerate(worker_urls): + preferred_master = args.worker_master_port_base + i * args.worker_internal_port_stride + preferred_scheduler = ( + args.worker_scheduler_port_base + i * args.worker_internal_port_stride + ) + master_port = _reserve_available_port(args.worker_host, preferred_master, reserved_ports) + scheduler_port = _reserve_available_port(args.worker_host, preferred_scheduler, reserved_ports) + worker_internal_ports.append((master_port, scheduler_port)) + if master_port != preferred_master or scheduler_port != preferred_scheduler: + print( + "[bench] Adjusted internal worker ports due to conflict: " + f"worker={i} master={master_port} scheduler={scheduler_port}", + flush=True, + ) + + sglang_cli_cmd = _build_sglang_cli_cmd() + gpu_pool = _resolve_gpu_pool(args, env) + total_gpus_needed = len(worker_urls) * args.num_gpus_per_worker + if gpu_pool is not None and len(gpu_pool) < total_gpus_needed: + raise ValueError( + f"Need {total_gpus_needed} GPU slots for {len(worker_urls)} worker(s) x " + f"{args.num_gpus_per_worker} GPU(s), but only {len(gpu_pool)} visible. " + "Set --worker-gpu-ids, reduce --num-workers, or reduce --num-gpus-per-worker." + ) + for i, url in enumerate(worker_urls): + port = int(url.rsplit(":", 1)[1]) + master_port, scheduler_port = worker_internal_ports[i] + cmd = [ + *sglang_cli_cmd, + "serve", + "--model-path", + args.model, + "--num-gpus", + str(args.num_gpus_per_worker), + "--host", + args.worker_host, + "--port", + str(port), + "--master-port", + str(master_port), + "--scheduler-port", + str(scheduler_port), + ] + if args.worker_extra_args: + cmd += shlex.split(args.worker_extra_args) + worker_env = env + if gpu_pool is not None: + start = i * args.num_gpus_per_worker + end = start + args.num_gpus_per_worker + assigned = gpu_pool[start:end] + worker_env = dict(env) + worker_env["CUDA_VISIBLE_DEVICES"] = ",".join(assigned) + print( + f"[bench] Worker {i} uses CUDA_VISIBLE_DEVICES={worker_env['CUDA_VISIBLE_DEVICES']}", + flush=True, + ) + print( + f"[bench] Launching worker {i} on port {port} " + f"(master={master_port}, scheduler={scheduler_port})...", + flush=True, + ) + processes.append(subprocess.Popen(cmd, env=worker_env, start_new_session=True)) + + if args.max_concurrency and worker_urls: + per_worker = (args.max_concurrency + len(worker_urls) - 1) // len(worker_urls) + if per_worker > 1: + print( + "[bench] Warning: " + f"max_concurrency={args.max_concurrency} over {len(worker_urls)} workers " + f"can drive up to ~{per_worker} concurrent requests per worker, which may OOM " + "large diffusion models.", + flush=True, + ) + + print(f"[bench] Waiting for {len(worker_urls)} worker(s) to become healthy (this may take several minutes)...", flush=True) + for i, url in enumerate(worker_urls): + _wait_for_health(url, args.wait_timeout, f"worker {url}", proc=processes[i]) + + router_cmd = [ + sys.executable, + "examples/diffusion_router/demo.py", + "--host", + args.router_host, + "--port", + str(args.router_port), + "--worker-urls", + *worker_urls, + ] + if args.routing_algorithm: + router_cmd += ["--routing-algorithm", args.routing_algorithm] + if args.router_verbose: + router_cmd.append("--verbose") + if args.router_extra_args: + router_cmd += shlex.split(args.router_extra_args) + + if not _is_port_available(args.router_host, args.router_port): + raise RuntimeError( + f"Router port {args.router_port} on {args.router_host} is already in use. " + "Stop the existing router/process or change --router-port." + ) + + print(f"[bench] Launching router on port {args.router_port}...", flush=True) + router_proc = subprocess.Popen(router_cmd, start_new_session=True) + processes.append(router_proc) + + router_host = _infer_client_host(args.router_host) + base_url = f"http://{router_host}:{args.router_port}" + _wait_for_health(base_url, args.wait_timeout, "router", proc=router_proc) + + print(f"[bench] Running benchmark: {args.num_prompts} prompts, concurrency={args.max_concurrency}", flush=True) + + bench_cmd = [ + sys.executable, + "-m", + "sglang.multimodal_gen.benchmarks.bench_serving", + "--base-url", + base_url, + "--model", + args.model, + "--dataset", + args.dataset, + "--num-prompts", + str(args.num_prompts), + "--request-rate", + str(args.request_rate), + "--log-level", + args.log_level, + ] + if args.dataset_path: + bench_cmd += ["--dataset-path", args.dataset_path] + if args.max_concurrency: + bench_cmd += ["--max-concurrency", str(args.max_concurrency)] + if args.task: + bench_cmd += ["--task", args.task] + if args.width: + bench_cmd += ["--width", str(args.width)] + if args.height: + bench_cmd += ["--height", str(args.height)] + if args.num_frames: + bench_cmd += ["--num-frames", str(args.num_frames)] + if args.fps: + bench_cmd += ["--fps", str(args.fps)] + if args.output_file: + bench_cmd += ["--output-file", args.output_file] + if args.disable_tqdm: + bench_cmd.append("--disable-tqdm") + if args.bench_extra_args: + bench_cmd += shlex.split(args.bench_extra_args) + + return subprocess.call(bench_cmd, env=env) + finally: + _terminate_all(reversed(processes)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/diffusion_router/bench_routing_algorithms.py b/examples/diffusion_router/bench_routing_algorithms.py new file mode 100644 index 000000000..60b9ec760 --- /dev/null +++ b/examples/diffusion_router/bench_routing_algorithms.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +""" +Compare routing algorithms by running bench_router.py for each algorithm +and collecting results into a summary table and CSV. + +Example: + python examples/diffusion_router/bench_routing_algorithms.py \ + --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --num-workers 2 \ + --num-prompts 10 \ + --max-concurrency 2 +""" + +from __future__ import annotations + +import argparse +import csv +import json +import shlex +import subprocess +import sys +from datetime import datetime +from pathlib import Path + +ALL_ALGORITHMS = ["least-request", "round-robin", "random"] + + +def _require_non_empty_model(model: str) -> str: + normalized = model.strip() + if not normalized: + raise ValueError( + "--model must be a non-empty model ID/path. " + "Detected an empty value, which often means a shell variable such as " + "$MODEL was unset." + ) + return normalized + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Compare routing algorithms by running bench_router.py for each." + ) + parser.add_argument("--model", type=str, required=True, help="Diffusion model HF ID or local path.") + parser.add_argument( + "--algorithms", nargs="+", default=ALL_ALGORITHMS, + choices=ALL_ALGORITHMS, help="Algorithms to compare (default: all three).", + ) + parser.add_argument("--output-dir", type=str, default=None, help="Output directory for results.") + + # Pass-through args for bench_router.py + parser.add_argument("--sglang-root", type=str, default=None) + parser.add_argument("--router-host", type=str, default="127.0.0.1") + parser.add_argument("--router-port", type=int, default=30080) + parser.add_argument("--router-verbose", action="store_true") + parser.add_argument("--router-extra-args", type=str, default="") + parser.add_argument("--worker-host", type=str, default="127.0.0.1") + parser.add_argument("--worker-urls", nargs="*", default=[]) + parser.add_argument("--num-workers", type=int, default=1) + parser.add_argument("--worker-base-port", type=int, default=10090) + parser.add_argument("--worker-port-stride", type=int, default=2) + parser.add_argument("--worker-master-port-base", type=int, default=30005) + parser.add_argument("--worker-scheduler-port-base", type=int, default=5555) + parser.add_argument("--worker-internal-port-stride", type=int, default=1000) + parser.add_argument("--num-gpus-per-worker", type=int, default=1) + parser.add_argument("--worker-gpu-ids", nargs="*", default=None) + parser.add_argument("--worker-extra-args", type=str, default="") + parser.add_argument("--skip-workers", action="store_true") + parser.add_argument("--dataset", type=str, default="random", choices=["vbench", "random"]) + parser.add_argument("--dataset-path", type=str, default=None) + parser.add_argument("--num-prompts", type=int, default=20) + parser.add_argument("--max-concurrency", type=int, default=1) + parser.add_argument("--request-rate", type=float, default=float("inf")) + parser.add_argument("--task", type=str, default=None) + parser.add_argument("--width", type=int, default=None) + parser.add_argument("--height", type=int, default=None) + parser.add_argument("--num-frames", type=int, default=None) + parser.add_argument("--fps", type=int, default=None) + parser.add_argument("--disable-tqdm", action="store_true") + parser.add_argument("--log-level", type=str, default="INFO") + parser.add_argument("--bench-extra-args", type=str, default="") + parser.add_argument("--wait-timeout", type=int, default=1200) + + args = parser.parse_args() + args.model = _require_non_empty_model(args.model) + + output_dir = ( + Path(args.output_dir) + if args.output_dir + else Path("outputs") / f"routing_algo_compare_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + ) + output_dir.mkdir(parents=True, exist_ok=True) + + py = sys.executable + results: dict[str, dict] = {} + + for algo in args.algorithms: + print(f"\n{'='*60}", flush=True) + print(f" Running benchmark with routing algorithm: {algo}", flush=True) + print(f"{'='*60}\n", flush=True) + + out_file = output_dir / f"bench_{algo}.json" + + bench_cmd = [ + py, "examples/diffusion_router/bench_router.py", + "--model", args.model, + "--routing-algorithm", algo, + "--num-workers", str(args.num_workers), + "--num-prompts", str(args.num_prompts), + "--max-concurrency", str(args.max_concurrency), + "--num-gpus-per-worker", str(args.num_gpus_per_worker), + "--worker-base-port", str(args.worker_base_port), + "--worker-port-stride", str(args.worker_port_stride), + "--worker-master-port-base", str(args.worker_master_port_base), + "--worker-scheduler-port-base", str(args.worker_scheduler_port_base), + "--worker-internal-port-stride", str(args.worker_internal_port_stride), + "--router-host", args.router_host, + "--router-port", str(args.router_port), + "--dataset", args.dataset, + "--request-rate", str(args.request_rate), + "--wait-timeout", str(args.wait_timeout), + "--log-level", args.log_level, + "--output-file", str(out_file), + ] + if args.sglang_root: + bench_cmd += ["--sglang-root", args.sglang_root] + if args.worker_urls: + bench_cmd += ["--worker-urls", *args.worker_urls] + if args.worker_gpu_ids: + bench_cmd += ["--worker-gpu-ids", *args.worker_gpu_ids] + if args.dataset_path: + bench_cmd += ["--dataset-path", args.dataset_path] + if args.task: + bench_cmd += ["--task", args.task] + if args.width: + bench_cmd += ["--width", str(args.width)] + if args.height: + bench_cmd += ["--height", str(args.height)] + if args.num_frames: + bench_cmd += ["--num-frames", str(args.num_frames)] + if args.fps: + bench_cmd += ["--fps", str(args.fps)] + if args.disable_tqdm: + bench_cmd.append("--disable-tqdm") + if args.router_verbose: + bench_cmd.append("--router-verbose") + if args.skip_workers: + bench_cmd.append("--skip-workers") + if args.worker_extra_args: + bench_cmd += ["--worker-extra-args", args.worker_extra_args] + if args.router_extra_args: + bench_cmd += ["--router-extra-args", args.router_extra_args] + if args.bench_extra_args: + bench_cmd += ["--bench-extra-args", args.bench_extra_args] + if args.worker_host != "127.0.0.1": + bench_cmd += ["--worker-host", args.worker_host] + + print("[run]", " ".join(shlex.quote(x) for x in bench_cmd), flush=True) + rc = subprocess.call(bench_cmd) + + if rc != 0: + print(f"[warn] bench_router.py exited with code {rc} for algorithm '{algo}'", flush=True) + results[algo] = {"error": f"exit_code={rc}"} + continue + + if out_file.exists(): + try: + results[algo] = json.loads(out_file.read_text()) + except json.JSONDecodeError as e: + print(f"[warn] Failed to parse {out_file}: {e}", flush=True) + results[algo] = {"error": f"json_parse_error: {e}"} + else: + print(f"[warn] Output file not found: {out_file}", flush=True) + results[algo] = {"error": "output_file_missing"} + + # ── Collect per-algorithm metrics ──────────────────────────── + BASELINE = "random" + metric_keys = ["throughput_qps", "latency_mean", "latency_median", "latency_p99"] + + csv_rows: list[dict] = [] + parsed: dict[str, dict] = {} + for algo in args.algorithms: + data = results.get(algo, {}) + if "error" in data: + parsed[algo] = None + csv_rows.append({ + "algorithm": algo, "throughput_qps": "", "latency_mean": "", + "latency_median": "", "latency_p99": "", "duration": "", + "completed_requests": "", "failed_requests": "", + "throughput_qps_delta_pct": "", "latency_mean_delta_pct": "", + "latency_median_delta_pct": "", "latency_p99_delta_pct": "", + "error": data["error"], + }) + continue + + row = { + "algorithm": algo, + "throughput_qps": data.get("throughput_qps", ""), + "latency_mean": data.get("latency_mean", ""), + "latency_median": data.get("latency_median", ""), + "latency_p99": data.get("latency_p99", ""), + "duration": data.get("duration", ""), + "completed_requests": data.get("completed_requests", ""), + "failed_requests": data.get("failed_requests", ""), + "error": "", + } + parsed[algo] = row + csv_rows.append(row) + + # ── Compute deltas vs baseline ─────────────────────────────── + baseline_row = parsed.get(BASELINE) + for row in csv_rows: + if row.get("error"): + continue + for key in metric_keys: + val = row.get(key, "") + base = baseline_row.get(key, "") if baseline_row else "" + delta_key = f"{key}_delta_pct" + if isinstance(val, (int, float)) and isinstance(base, (int, float)) and base: + row[delta_key] = ((val - base) / abs(base)) * 100 + else: + row[delta_key] = "" + + # ── Print comparison table ─────────────────────────────────── + print(f"\n{'='*100}", flush=True) + print(f" Routing Algorithm Comparison (baseline: {BASELINE})", flush=True) + print(f"{'='*100}", flush=True) + + header = ( + f"{'Algorithm':<16} {'Throughput':>14} {'Tput Delta':>11}" + f" {'Mean Lat':>12} {'Delta':>8}" + f" {'Median Lat':>12} {'Delta':>8}" + f" {'P99 Lat':>12} {'Delta':>8}" + f" {'Done':>6} {'Fail':>6}" + ) + print(header, flush=True) + print("-" * len(header), flush=True) + + def _fmt_qps(v): + return f"{v:.2f} req/s" if isinstance(v, (int, float)) else str(v) + + def _fmt_lat(v): + return f"{v:.3f} s" if isinstance(v, (int, float)) else str(v) + + def _fmt_delta(v): + if isinstance(v, (int, float)): + sign = "+" if v >= 0 else "" + return f"{sign}{v:.1f}%" + return "" + + def _fmt_int(v): + return str(v) if v != "" else "N/A" + + for row in csv_rows: + if row.get("error"): + print(f"{row['algorithm']:<16} {'ERROR':>14} {row['error']}", flush=True) + continue + print( + f"{row['algorithm']:<16}" + f" {_fmt_qps(row['throughput_qps']):>14} {_fmt_delta(row.get('throughput_qps_delta_pct', '')):>11}" + f" {_fmt_lat(row['latency_mean']):>12} {_fmt_delta(row.get('latency_mean_delta_pct', '')):>8}" + f" {_fmt_lat(row['latency_median']):>12} {_fmt_delta(row.get('latency_median_delta_pct', '')):>8}" + f" {_fmt_lat(row['latency_p99']):>12} {_fmt_delta(row.get('latency_p99_delta_pct', '')):>8}" + f" {_fmt_int(row['completed_requests']):>6} {_fmt_int(row['failed_requests']):>6}", + flush=True, + ) + + # ── Write CSV ──────────────────────────────────────────────── + csv_path = output_dir / "routing_algorithm_comparison.csv" + fieldnames = [ + "algorithm", "throughput_qps", "throughput_qps_delta_pct", + "latency_mean", "latency_mean_delta_pct", + "latency_median", "latency_median_delta_pct", + "latency_p99", "latency_p99_delta_pct", + "duration", "completed_requests", "failed_requests", "error", + ] + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + writer.writerows(csv_rows) + print(f"\n[done] CSV written to {csv_path}", flush=True) + print(f"[done] Per-algorithm JSON results in {output_dir}/", flush=True) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/diffusion_router/demo.py b/examples/diffusion_router/demo.py index 9c1b4693e..95947fc23 100644 --- a/examples/diffusion_router/demo.py +++ b/examples/diffusion_router/demo.py @@ -36,6 +36,9 @@ def main(): parser.add_argument("--timeout", type=float, default=None, help="Request timeout in seconds") parser.add_argument("--health-check-interval", type=int, default=10, help="Seconds between health checks") parser.add_argument("--health-check-failure-threshold", type=int, default=3, help="Failures before quarantine") + parser.add_argument("--routing-algorithm", type=str, default="least-request", + choices=["least-request", "round-robin", "random"], + help="Load-balancing algorithm") parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") args = parser.parse_args() diff --git a/miles/router/diffusion_router.py b/miles/router/diffusion_router.py index e69497420..d77671083 100644 --- a/miles/router/diffusion_router.py +++ b/miles/router/diffusion_router.py @@ -2,6 +2,7 @@ import asyncio import json import logging +import random import httpx import uvicorn @@ -34,6 +35,9 @@ def __init__(self, args, verbose=False): # Quarantined workers excluded from routing pool self.dead_workers: set[str] = set() + self.routing_algorithm = getattr(args, "routing_algorithm", "least-request") + self._rr_index = 0 + max_connections = getattr(args, "max_connections", 100) timeout = getattr(args, "timeout", None) @@ -112,15 +116,21 @@ async def _health_check_loop(self): # ── Load balancing ─────────────────────────────────────────────── def _use_url(self): - """Select worker URL with minimal active requests.""" + """Select a worker URL based on the configured routing algorithm.""" if not self.worker_request_counts: raise RuntimeError("No workers registered in the pool") - valid_workers = (w for w in self.worker_request_counts if w not in self.dead_workers) - try: + valid_workers = [w for w in self.worker_request_counts if w not in self.dead_workers] + if not valid_workers: + raise RuntimeError("No healthy workers available in the pool") + + if self.routing_algorithm == "round-robin": + url = valid_workers[self._rr_index % len(valid_workers)] + self._rr_index = (self._rr_index + 1) % len(valid_workers) + elif self.routing_algorithm == "random": + url = random.choice(valid_workers) + else: # least-request (default) url = min(valid_workers, key=self.worker_request_counts.get) - except ValueError: - raise RuntimeError("No healthy workers available in the pool") from None self.worker_request_counts[url] += 1 return url @@ -133,8 +143,33 @@ def _finish_url(self, url): # ── Proxy helpers ──────────────────────────────────────────────── + def _build_proxy_response(self, content: bytes, status_code: int, headers: dict) -> Response: + """ + Build an HTTP response from proxied bytes. + + Keep behavior consistent with `MilesRouter._build_proxy_response`: + - If the payload is JSON, return `JSONResponse` + - Otherwise, return raw `Response` + + Diffusion responses (e.g. `b64_json`) can be very large; decoding and re-encoding + the JSON can dominate CPU time. We therefore skip JSON decoding for large bodies. + """ + content_type = headers.get("content-type", "") + + # Size guard: don't pay JSON decode/re-encode costs on large payloads. + # This preserves exact bytes on the wire and avoids CPU/memory pressure. + max_json_reencode_bytes = 256 * 1024 + if len(content) <= max_json_reencode_bytes: + try: + data = json.loads(content) + return JSONResponse(content=data, status_code=status_code, headers=headers) + except Exception: + pass + + return Response(content=content, status_code=status_code, headers=headers, media_type=content_type) + async def _forward_to_worker(self, request: Request, path: str) -> Response: - """Forward a request to the least-loaded worker and return the response.""" + """Forward a request to a selected worker and return the response.""" try: worker_url = self._use_url() except RuntimeError as exc: @@ -145,6 +180,9 @@ async def _forward_to_worker(self, request: Request, path: str) -> Response: url = f"{worker_url}/{path}" if not query else f"{worker_url}/{path}?{query}" body = await request.body() headers = dict(request.headers) + # Let httpx set the correct framing headers for the forwarded body. + if body is not None: + headers = {k: v for k, v in headers.items() if k.lower() not in ("content-length", "transfer-encoding")} try: response = await self.client.request(request.method, url, content=body, headers=headers) @@ -153,14 +191,7 @@ async def _forward_to_worker(self, request: Request, path: str) -> Response: self._finish_url(worker_url) resp_headers = self._sanitize_response_headers(response.headers) - content_type = resp_headers.get("content-type", "") - try: - data = json.loads(content) - return JSONResponse(content=data, status_code=response.status_code, headers=resp_headers) - except Exception: - return Response( - content=content, status_code=response.status_code, headers=resp_headers, media_type=content_type - ) + return self._build_proxy_response(content, response.status_code, resp_headers) async def _broadcast_to_workers(self, path: str, body: bytes, headers: dict) -> list[dict]: """Send a request to ALL healthy workers and collect results.""" @@ -193,8 +224,8 @@ async def generate(self, request: Request): return await self._forward_to_worker(request, "v1/images/generations") async def generate_video(self, request: Request): - """Route video generation to the least-loaded worker via /v1/videos/generations.""" - return await self._forward_to_worker(request, "v1/videos/generations") + """Route video generation to the least-loaded worker via /v1/videos.""" + return await self._forward_to_worker(request, "v1/videos") async def health(self, request: Request): """Aggregated health status: healthy if at least one worker is alive.""" @@ -270,6 +301,8 @@ async def proxy(self, request: Request, path: str): parser.add_argument("--timeout", type=float, default=None) parser.add_argument("--health-check-interval", type=int, default=10) parser.add_argument("--health-check-failure-threshold", type=int, default=3) + parser.add_argument("--routing-algorithm", type=str, default="least-request", + choices=["least-request", "round-robin", "random"]) parser.add_argument("--verbose", action="store_true") args = parser.parse_args() diff --git a/tests/fast/router/test_diffusion_router.py b/tests/fast/router/test_diffusion_router.py new file mode 100644 index 000000000..65a2df394 --- /dev/null +++ b/tests/fast/router/test_diffusion_router.py @@ -0,0 +1,167 @@ +from argparse import Namespace + +import pytest + +from miles.router.diffusion_router import DiffusionRouter + + +def make_router_args(**overrides) -> Namespace: + """Create a Namespace with default DiffusionRouter args, applying overrides.""" + defaults = dict( + host="127.0.0.1", + port=30080, + max_connections=100, + timeout=None, + routing_algorithm="least-request", + ) + defaults.update(overrides) + return Namespace(**defaults) + + +@pytest.fixture +def router_factory(): + """Factory fixture that creates a DiffusionRouter with pre-set worker state.""" + + def _create( + workers: dict[str, int], + dead: set[str] | None = None, + **arg_overrides, + ) -> DiffusionRouter: + router = DiffusionRouter(make_router_args(**arg_overrides)) + router.worker_request_counts = dict(workers) + router.worker_failure_counts = {url: 0 for url in workers} + if dead: + router.dead_workers = set(dead) + return router + + return _create + + +# ── Least-request ──────────────────────────────────────────────── + + +@pytest.mark.unit +class TestLeastRequest: + """Test the least-request (default) load-balancing algorithm.""" + + def test_selects_min_load(self, router_factory): + router = router_factory({"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8}) + selected = router._use_url() + assert selected == "http://w2:8000" + assert router.worker_request_counts["http://w2:8000"] == 3 + + +# ── Round-robin ────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestRoundRobin: + """Test the round-robin load-balancing algorithm.""" + + def test_cycles_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, + routing_algorithm="round-robin", + ) + results = [router._use_url() for _ in range(6)] + workers = list(router.worker_request_counts.keys()) + expected = [workers[i % 3] for i in range(6)] + assert results == expected + for url in workers: + assert router.worker_request_counts[url] == 2 + + def test_excludes_dead_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, + dead={"http://w2:8000"}, + routing_algorithm="round-robin", + ) + results = [router._use_url() for _ in range(4)] + assert "http://w2:8000" not in results + assert all(url in ("http://w1:8000", "http://w3:8000") for url in results) + + +# ── Random ─────────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestRandom: + """Test the random load-balancing algorithm.""" + + def test_selects_from_valid_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, + routing_algorithm="random", + ) + seen = set() + for _ in range(30): + # Reset counts so they don't grow unbounded + for url in router.worker_request_counts: + router.worker_request_counts[url] = 0 + seen.add(router._use_url()) + assert seen == {"http://w1:8000", "http://w2:8000", "http://w3:8000"} + + def test_excludes_dead_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, + dead={"http://w2:8000"}, + routing_algorithm="random", + ) + for _ in range(20): + url = router._use_url() + assert url != "http://w2:8000" + router.worker_request_counts[url] -= 1 # reset increment + + +# ── Error cases ────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestErrorCases: + """Test error handling across all routing algorithms.""" + + @pytest.mark.parametrize("algorithm", ["least-request", "round-robin", "random"]) + def test_raises_when_no_workers(self, router_factory, algorithm): + router = router_factory({}, routing_algorithm=algorithm) + with pytest.raises(RuntimeError, match="No workers registered"): + router._use_url() + + @pytest.mark.parametrize("algorithm", ["least-request", "round-robin", "random"]) + def test_raises_when_all_dead(self, router_factory, algorithm): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0}, + dead={"http://w1:8000", "http://w2:8000"}, + routing_algorithm=algorithm, + ) + with pytest.raises(RuntimeError, match="No healthy workers"): + router._use_url() + + +# ── Count management ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestCountManagement: + """Test that _use_url / _finish_url correctly track active request counts.""" + + @pytest.mark.parametrize("algorithm", ["least-request", "round-robin", "random"]) + def test_increment_and_finish(self, router_factory, algorithm): + router = router_factory({"http://w1:8000": 0}, routing_algorithm=algorithm) + url = router._use_url() + assert router.worker_request_counts[url] == 1 + router._finish_url(url) + assert router.worker_request_counts[url] == 0 + + +# ── Default algorithm ──────────────────────────────────────────── + + +@pytest.mark.unit +class TestDefaults: + """Test default routing algorithm when the attribute is absent.""" + + def test_default_algorithm_is_least_request(self): + args = Namespace(host="127.0.0.1", port=30080, max_connections=100, timeout=None) + # args has no routing_algorithm attribute + router = DiffusionRouter(args) + assert router.routing_algorithm == "least-request" From 73b065e7206f8c94176d79881908c92e431bf63a Mon Sep 17 00:00:00 2001 From: sniper35 Date: Sun, 8 Feb 2026 08:15:18 +0000 Subject: [PATCH 08/14] fixed demo and updated docs --- examples/diffusion_router/README.md | 7 ++++++- examples/diffusion_router/demo.py | 4 +++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/diffusion_router/README.md b/examples/diffusion_router/README.md index b1fd0cc09..c4af0e7b5 100644 --- a/examples/diffusion_router/README.md +++ b/examples/diffusion_router/README.md @@ -31,7 +31,12 @@ curl -X POST 'http://localhost:30080/add_worker?url=http://localhost:10090' The router uses a **least-request** strategy: each incoming request is forwarded to the worker with the fewest in-flight requests. This is workload-aware and avoids hot-spotting compared to round-robin. When a request completes, the worker's count is decremented, keeping the load state accurate in real time. -Workers that fail consecutive health checks (default: 3) are quarantined and excluded from the routing pool. A background loop pings each worker's `GET /health` endpoint at a configurable interval (default: 10s). +Workers that fail consecutive health checks (default: 3) are quarantined and excluded from the routing pool and will not be added back to avoid stale weights(See [discussion](https://github.com/radixark/miles/pull/260#discussion_r2654274449)). A background loop pings each worker's `GET /health` endpoint at a configurable interval (default: 10s). + +## Benchmark script(under examples/diffusion_router/): +- examples/diffusion_router/bench_routing_algorithms.py — top-level comparison runner +- examples/diffusion_router/bench_router.py — single-algorithm benchmark (spawned per algorithm) +- examples/diffusion_router/demo.py — router process (spawned by bench_router.py) ## Notes diff --git a/examples/diffusion_router/demo.py b/examples/diffusion_router/demo.py index 95947fc23..7db7e5235 100644 --- a/examples/diffusion_router/demo.py +++ b/examples/diffusion_router/demo.py @@ -46,7 +46,9 @@ def main(): # Pre-register any workers specified on the command line for url in args.worker_urls: - router.add_worker_sync(url) + if url not in router.worker_request_counts: + router.worker_request_counts[url] = 0 + router.worker_failure_counts[url] = 0 if args.verbose: print(f"[demo] Pre-registered worker: {url}") From 930fcbfd364010e3122bae5d4ac1a2f89873258a Mon Sep 17 00:00:00 2001 From: sniper35 Date: Sun, 8 Feb 2026 08:47:58 +0000 Subject: [PATCH 09/14] remove local sglang source-repo assumptions --- examples/diffusion_router/bench_router.py | 71 ++++--------------- .../bench_routing_algorithms.py | 3 - 2 files changed, 13 insertions(+), 61 deletions(-) diff --git a/examples/diffusion_router/bench_router.py b/examples/diffusion_router/bench_router.py index ed9acce0c..a20d64ac6 100644 --- a/examples/diffusion_router/bench_router.py +++ b/examples/diffusion_router/bench_router.py @@ -38,46 +38,6 @@ def _require_non_empty_model(model: str) -> str: return normalized -def _default_sglang_root() -> Path: - # Repo layout: miles/examples/diffusion_router/bench_router.py -> miles/ (parents[2]) - return Path(__file__).resolve().parents[2].parent / "sglang" - - -def _resolve_sglang_root(path: str | None) -> Path | None: - if path: - root = Path(path).expanduser().resolve() - sglang_pkg = root / "python" / "sglang" - if not sglang_pkg.exists(): - raise FileNotFoundError(f"sglang source not found at {root}. Expected {sglang_pkg}.") - return root - default = _default_sglang_root() - if (default / "python" / "sglang").exists(): - return default - # No source repo found — fall back to pip-installed sglang - return None - - -def _with_pythonpath(env: dict[str, str], extra_path: Path) -> dict[str, str]: - env = dict(env) - existing = env.get("PYTHONPATH") - extra = str(extra_path) - env["PYTHONPATH"] = f"{extra}{os.pathsep}{existing}" if existing else extra - return env - - -def _build_sglang_cli_cmd() -> list[str]: - """ - Build a command prefix that invokes the `sglang` CLI from the current - Python environment. - """ - sglang_bin = Path(sys.executable).resolve().parent / "sglang" - if sglang_bin.exists(): - return [str(sglang_bin)] - - # Fallback when the console script is missing. - return [sys.executable, "-c", "from sglang.cli.main import main; main()"] - - def _wait_for_health( url: str, timeout: int, label: str, proc: subprocess.Popen | None = None, ) -> None: @@ -216,8 +176,6 @@ def _signal_group(proc: subprocess.Popen, sig: int) -> None: def main() -> int: parser = argparse.ArgumentParser(description="Benchmark Miles DiffusionRouter with sglang bench_serving.") parser.add_argument("--model", type=str, required=True, help="Diffusion model HF ID or local path.") - parser.add_argument("--sglang-root", type=str, default=None, help="Path to sglang repo (default: ../sglang).") - parser.add_argument("--router-host", type=str, default="127.0.0.1", help="Router bind host.") parser.add_argument("--router-port", type=int, default=30080, help="Router port.") parser.add_argument("--routing-algorithm", type=str, default="least-request", @@ -290,21 +248,14 @@ def main() -> int: args = parser.parse_args() args.model = _require_non_empty_model(args.model) - sglang_root = _resolve_sglang_root(args.sglang_root) - if sglang_root is not None: - sglang_python = sglang_root / "python" - env = _with_pythonpath(os.environ, sglang_python) - else: - # Verify pip-installed sglang is importable - try: - import sglang # noqa: F401 - except ImportError: - raise RuntimeError( - "sglang is not installed and no source repo found at ../sglang.\n" - "Install with: uv pip install \"sglang[diffusion]\" --prerelease=allow\n" - "Or point to the source repo with: --sglang-root /path/to/sglang" - ) - env = dict(os.environ) + try: + import sglang # noqa: F401 + except ImportError: + raise RuntimeError( + "sglang is not installed.\n" + "Install with: uv pip install \"sglang[diffusion]\" --prerelease=allow" + ) + env = dict(os.environ) worker_urls = list(args.worker_urls) if not worker_urls: @@ -357,7 +308,11 @@ def main() -> int: flush=True, ) - sglang_cli_cmd = _build_sglang_cli_cmd() + sglang_bin = Path(sys.executable).resolve().parent / "sglang" + if sglang_bin.exists(): + sglang_cli_cmd = [str(sglang_bin)] + else: + sglang_cli_cmd = [sys.executable, "-c", "from sglang.cli.main import main; main()"] gpu_pool = _resolve_gpu_pool(args, env) total_gpus_needed = len(worker_urls) * args.num_gpus_per_worker if gpu_pool is not None and len(gpu_pool) < total_gpus_needed: diff --git a/examples/diffusion_router/bench_routing_algorithms.py b/examples/diffusion_router/bench_routing_algorithms.py index 60b9ec760..64dca91a8 100644 --- a/examples/diffusion_router/bench_routing_algorithms.py +++ b/examples/diffusion_router/bench_routing_algorithms.py @@ -48,7 +48,6 @@ def main() -> int: parser.add_argument("--output-dir", type=str, default=None, help="Output directory for results.") # Pass-through args for bench_router.py - parser.add_argument("--sglang-root", type=str, default=None) parser.add_argument("--router-host", type=str, default="127.0.0.1") parser.add_argument("--router-port", type=int, default=30080) parser.add_argument("--router-verbose", action="store_true") @@ -121,8 +120,6 @@ def main() -> int: "--log-level", args.log_level, "--output-file", str(out_file), ] - if args.sglang_root: - bench_cmd += ["--sglang-root", args.sglang_root] if args.worker_urls: bench_cmd += ["--worker-urls", *args.worker_urls] if args.worker_gpu_ids: From 079e7c219ee45eba53309cbb7ccd58ed7066d0c1 Mon Sep 17 00:00:00 2001 From: sniper35 Date: Sun, 8 Feb 2026 09:01:29 +0000 Subject: [PATCH 10/14] robust & defensive check --- examples/diffusion_router/README.md | 2 +- examples/diffusion_router/bench_router.py | 4 +++- miles/router/diffusion_router.py | 12 +++++++++--- tests/fast/router/test_diffusion_router.py | 9 +++++++++ 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/examples/diffusion_router/README.md b/examples/diffusion_router/README.md index c4af0e7b5..55c38c4d0 100644 --- a/examples/diffusion_router/README.md +++ b/examples/diffusion_router/README.md @@ -19,7 +19,7 @@ curl -X POST 'http://localhost:30080/add_worker?url=http://localhost:10090' | Method | Path | Description | |--------|------|-------------| | `POST` | `/generate` | Image generation (forwards to `/v1/images/generations`) | -| `POST` | `/generate_video` | Video generation (forwards to `/v1/videos/generations`) | +| `POST` | `/generate_video` | Video generation (forwards to `/v1/videos`) | | `GET` | `/health` | Aggregated router health | | `GET` | `/health_workers` | Per-worker health and load info | | `POST` | `/add_worker` | Register a diffusion worker (`?url=...` or JSON body) | diff --git a/examples/diffusion_router/bench_router.py b/examples/diffusion_router/bench_router.py index a20d64ac6..70f78783c 100644 --- a/examples/diffusion_router/bench_router.py +++ b/examples/diffusion_router/bench_router.py @@ -446,7 +446,9 @@ def main() -> int: if args.bench_extra_args: bench_cmd += shlex.split(args.bench_extra_args) - return subprocess.call(bench_cmd, env=env) + bench_proc = subprocess.Popen(bench_cmd, env=env, start_new_session=True) + processes.append(bench_proc) + return bench_proc.wait() finally: _terminate_all(reversed(processes)) diff --git a/miles/router/diffusion_router.py b/miles/router/diffusion_router.py index d77671083..facc9c51d 100644 --- a/miles/router/diffusion_router.py +++ b/miles/router/diffusion_router.py @@ -137,9 +137,11 @@ def _use_url(self): def _finish_url(self, url): """Mark the request to the given URL as finished.""" - assert url in self.worker_request_counts, f"URL {url} not recognized" + if url not in self.worker_request_counts: + raise ValueError(f"URL {url} not recognized") self.worker_request_counts[url] -= 1 - assert self.worker_request_counts[url] >= 0, f"URL {url} count went negative" + if self.worker_request_counts[url] < 0: + raise RuntimeError(f"URL {url} count went negative") # ── Proxy helpers ──────────────────────────────────────────────── @@ -187,7 +189,11 @@ async def _forward_to_worker(self, request: Request, path: str) -> Response: try: response = await self.client.request(request.method, url, content=body, headers=headers) content = await response.aread() - finally: + except Exception as exc: + self._finish_url(worker_url) + logger.error(f"[diffusion-router] Failed to forward request to {worker_url}: {exc}") + return JSONResponse(status_code=502, content={"error": f"Worker request failed: {exc}"}) + else: self._finish_url(worker_url) resp_headers = self._sanitize_response_headers(response.headers) diff --git a/tests/fast/router/test_diffusion_router.py b/tests/fast/router/test_diffusion_router.py index 65a2df394..0d2d2a134 100644 --- a/tests/fast/router/test_diffusion_router.py +++ b/tests/fast/router/test_diffusion_router.py @@ -50,6 +50,15 @@ def test_selects_min_load(self, router_factory): assert selected == "http://w2:8000" assert router.worker_request_counts["http://w2:8000"] == 3 + def test_excludes_dead_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8}, + dead={"http://w2:8000"}, + ) + selected = router._use_url() + assert selected == "http://w1:8000" + assert router.worker_request_counts["http://w1:8000"] == 6 + # ── Round-robin ────────────────────────────────────────────────── From 6111a94710415dc573da6a375a4aca50925a562a Mon Sep 17 00:00:00 2001 From: sniper35 Date: Sun, 8 Feb 2026 09:27:38 +0000 Subject: [PATCH 11/14] remove __main__ block --- examples/diffusion_router/demo.py | 6 +--- miles/router/diffusion_router.py | 47 +++++++------------------------ 2 files changed, 11 insertions(+), 42 deletions(-) diff --git a/examples/diffusion_router/demo.py b/examples/diffusion_router/demo.py index 7db7e5235..26c807aff 100644 --- a/examples/diffusion_router/demo.py +++ b/examples/diffusion_router/demo.py @@ -46,11 +46,7 @@ def main(): # Pre-register any workers specified on the command line for url in args.worker_urls: - if url not in router.worker_request_counts: - router.worker_request_counts[url] = 0 - router.worker_failure_counts[url] = 0 - if args.verbose: - print(f"[demo] Pre-registered worker: {url}") + router.register_worker(url) print(f"[demo] Starting diffusion router on {args.host}:{args.port}") print(f"[demo] Workers: {list(router.worker_request_counts.keys()) or '(none — add via POST /add_worker)'}") diff --git a/miles/router/diffusion_router.py b/miles/router/diffusion_router.py index facc9c51d..a0e700f33 100644 --- a/miles/router/diffusion_router.py +++ b/miles/router/diffusion_router.py @@ -1,11 +1,9 @@ -import argparse import asyncio import json import logging import random import httpx -import uvicorn from fastapi import FastAPI, Request from fastapi.responses import JSONResponse from starlette.responses import Response @@ -13,12 +11,6 @@ logger = logging.getLogger(__name__) -def run_diffusion_router(args): - """Run the diffusion router with the specified configuration.""" - router = DiffusionRouter(args) - uvicorn.run(router.app, host=args.host, port=args.port, log_level="info") - - class DiffusionRouter: def __init__(self, args, verbose=False): """Initialize the diffusion router for load-balancing across sglang-diffusion workers.""" @@ -264,8 +256,16 @@ async def update_weights_from_disk(self, request: Request): results = await self._broadcast_to_workers("update_weights_from_disk", body, headers) return JSONResponse(content={"results": results}) + def register_worker(self, url: str) -> None: + """Register a worker URL if not already known (sync, for startup use).""" + if url not in self.worker_request_counts: + self.worker_request_counts[url] = 0 + self.worker_failure_counts[url] = 0 + if self.verbose: + print(f"[diffusion-router] Added new worker: {url}") + async def add_worker(self, request: Request): - """Register a new diffusion worker.""" + """Register a new diffusion worker (HTTP endpoint).""" worker_url = request.query_params.get("url") or request.query_params.get("worker_url") if not worker_url: @@ -281,12 +281,7 @@ async def add_worker(self, request: Request): status_code=400, content={"error": "worker_url is required (use query ?url=... or JSON body)"} ) - if worker_url not in self.worker_request_counts: - self.worker_request_counts[worker_url] = 0 - self.worker_failure_counts[worker_url] = 0 - if self.verbose: - print(f"[diffusion-router] Added new worker: {worker_url}") - + self.register_worker(worker_url) return {"status": "success", "worker_urls": list(self.worker_request_counts.keys())} async def list_workers(self, request: Request): @@ -296,25 +291,3 @@ async def list_workers(self, request: Request): async def proxy(self, request: Request, path: str): """Catch-all: forward any unmatched request to the least-loaded worker.""" return await self._forward_to_worker(request, path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Miles Diffusion Router") - parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=30080) - parser.add_argument("--worker-urls", nargs="*", default=[], help="Initial worker URLs to register") - parser.add_argument("--max-connections", type=int, default=100) - parser.add_argument("--timeout", type=float, default=None) - parser.add_argument("--health-check-interval", type=int, default=10) - parser.add_argument("--health-check-failure-threshold", type=int, default=3) - parser.add_argument("--routing-algorithm", type=str, default="least-request", - choices=["least-request", "round-robin", "random"]) - parser.add_argument("--verbose", action="store_true") - args = parser.parse_args() - - router = DiffusionRouter(args, verbose=args.verbose) - for url in args.worker_urls: - router.worker_request_counts[url] = 0 - router.worker_failure_counts[url] = 0 - - uvicorn.run(router.app, host=args.host, port=args.port, log_level="info") From 749fdd02e5d93c4ffbf1b0ae468d69c0da3a7b34 Mon Sep 17 00:00:00 2001 From: sniper35 Date: Sun, 8 Feb 2026 09:32:46 +0000 Subject: [PATCH 12/14] remove bizzare comment lines --- .../bench_routing_algorithms.py | 4 ---- miles/router/diffusion_router.py | 8 -------- tests/fast/router/test_diffusion_router.py | 18 ------------------ 3 files changed, 30 deletions(-) diff --git a/examples/diffusion_router/bench_routing_algorithms.py b/examples/diffusion_router/bench_routing_algorithms.py index 64dca91a8..549878a47 100644 --- a/examples/diffusion_router/bench_routing_algorithms.py +++ b/examples/diffusion_router/bench_routing_algorithms.py @@ -169,7 +169,6 @@ def main() -> int: print(f"[warn] Output file not found: {out_file}", flush=True) results[algo] = {"error": "output_file_missing"} - # ── Collect per-algorithm metrics ──────────────────────────── BASELINE = "random" metric_keys = ["throughput_qps", "latency_mean", "latency_median", "latency_p99"] @@ -203,7 +202,6 @@ def main() -> int: parsed[algo] = row csv_rows.append(row) - # ── Compute deltas vs baseline ─────────────────────────────── baseline_row = parsed.get(BASELINE) for row in csv_rows: if row.get("error"): @@ -217,7 +215,6 @@ def main() -> int: else: row[delta_key] = "" - # ── Print comparison table ─────────────────────────────────── print(f"\n{'='*100}", flush=True) print(f" Routing Algorithm Comparison (baseline: {BASELINE})", flush=True) print(f"{'='*100}", flush=True) @@ -261,7 +258,6 @@ def _fmt_int(v): flush=True, ) - # ── Write CSV ──────────────────────────────────────────────── csv_path = output_dir / "routing_algorithm_comparison.csv" fieldnames = [ "algorithm", "throughput_qps", "throughput_qps_delta_pct", diff --git a/miles/router/diffusion_router.py b/miles/router/diffusion_router.py index a0e700f33..753e4dfba 100644 --- a/miles/router/diffusion_router.py +++ b/miles/router/diffusion_router.py @@ -52,8 +52,6 @@ def _setup_routes(self): # Catch-all route for proxying — must be registered LAST self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy) - # ── Health checks ──────────────────────────────────────────────── - async def _start_background_health_check(self): asyncio.create_task(self._health_check_loop()) @@ -105,8 +103,6 @@ async def _health_check_loop(self): logger.error(f"[diffusion-router] Unexpected error in health check loop: {e}", exc_info=True) await asyncio.sleep(5) - # ── Load balancing ─────────────────────────────────────────────── - def _use_url(self): """Select a worker URL based on the configured routing algorithm.""" if not self.worker_request_counts: @@ -135,8 +131,6 @@ def _finish_url(self, url): if self.worker_request_counts[url] < 0: raise RuntimeError(f"URL {url} count went negative") - # ── Proxy helpers ──────────────────────────────────────────────── - def _build_proxy_response(self, content: bytes, status_code: int, headers: dict) -> Response: """ Build an HTTP response from proxied bytes. @@ -215,8 +209,6 @@ def _sanitize_response_headers(headers) -> dict: dropped = {"content-length", "content-encoding"} return {k: v for k, v in headers.items() if k.lower() not in hop_by_hop | dropped} - # ── Route handlers ─────────────────────────────────────────────── - async def generate(self, request: Request): """Route image generation to the least-loaded worker via /v1/images/generations.""" return await self._forward_to_worker(request, "v1/images/generations") diff --git a/tests/fast/router/test_diffusion_router.py b/tests/fast/router/test_diffusion_router.py index 0d2d2a134..8ab24203f 100644 --- a/tests/fast/router/test_diffusion_router.py +++ b/tests/fast/router/test_diffusion_router.py @@ -37,9 +37,6 @@ def _create( return _create -# ── Least-request ──────────────────────────────────────────────── - - @pytest.mark.unit class TestLeastRequest: """Test the least-request (default) load-balancing algorithm.""" @@ -60,9 +57,6 @@ def test_excludes_dead_workers(self, router_factory): assert router.worker_request_counts["http://w1:8000"] == 6 -# ── Round-robin ────────────────────────────────────────────────── - - @pytest.mark.unit class TestRoundRobin: """Test the round-robin load-balancing algorithm.""" @@ -90,9 +84,6 @@ def test_excludes_dead_workers(self, router_factory): assert all(url in ("http://w1:8000", "http://w3:8000") for url in results) -# ── Random ─────────────────────────────────────────────────────── - - @pytest.mark.unit class TestRandom: """Test the random load-balancing algorithm.""" @@ -122,9 +113,6 @@ def test_excludes_dead_workers(self, router_factory): router.worker_request_counts[url] -= 1 # reset increment -# ── Error cases ────────────────────────────────────────────────── - - @pytest.mark.unit class TestErrorCases: """Test error handling across all routing algorithms.""" @@ -146,9 +134,6 @@ def test_raises_when_all_dead(self, router_factory, algorithm): router._use_url() -# ── Count management ───────────────────────────────────────────── - - @pytest.mark.unit class TestCountManagement: """Test that _use_url / _finish_url correctly track active request counts.""" @@ -162,9 +147,6 @@ def test_increment_and_finish(self, router_factory, algorithm): assert router.worker_request_counts[url] == 0 -# ── Default algorithm ──────────────────────────────────────────── - - @pytest.mark.unit class TestDefaults: """Test default routing algorithm when the attribute is absent.""" From 50745a3f88579d3b8f4356aef57b81dafd952199 Mon Sep 17 00:00:00 2001 From: sniper35 Date: Sun, 8 Feb 2026 09:41:25 +0000 Subject: [PATCH 13/14] more style fix --- examples/diffusion_router/bench_router.py | 38 +++--- .../bench_routing_algorithms.py | 111 ++++++++++++------ examples/diffusion_router/demo.py | 10 +- miles/router/diffusion_router.py | 26 ++-- 4 files changed, 120 insertions(+), 65 deletions(-) diff --git a/examples/diffusion_router/bench_router.py b/examples/diffusion_router/bench_router.py index 70f78783c..cd77c3ae0 100644 --- a/examples/diffusion_router/bench_router.py +++ b/examples/diffusion_router/bench_router.py @@ -21,8 +21,8 @@ import subprocess import sys import time +from collections.abc import Iterable from pathlib import Path -from typing import Iterable import requests @@ -39,7 +39,10 @@ def _require_non_empty_model(model: str) -> str: def _wait_for_health( - url: str, timeout: int, label: str, proc: subprocess.Popen | None = None, + url: str, + timeout: int, + label: str, + proc: subprocess.Popen | None = None, ) -> None: start = time.time() last_print = 0.0 @@ -106,10 +109,7 @@ def _reserve_available_port(host: str, preferred_port: int, used_ports: set[int] used_ports.add(port) return port - raise RuntimeError( - f"Unable to reserve a free port for host {host}. " - f"Preferred start={preferred_port}." - ) + raise RuntimeError(f"Unable to reserve a free port for host {host}. " f"Preferred start={preferred_port}.") def _parse_gpu_id_list(raw: str) -> list[str]: @@ -178,9 +178,13 @@ def main() -> int: parser.add_argument("--model", type=str, required=True, help="Diffusion model HF ID or local path.") parser.add_argument("--router-host", type=str, default="127.0.0.1", help="Router bind host.") parser.add_argument("--router-port", type=int, default=30080, help="Router port.") - parser.add_argument("--routing-algorithm", type=str, default="least-request", - choices=["least-request", "round-robin", "random"], - help="Load-balancing algorithm for the router.") + parser.add_argument( + "--routing-algorithm", + type=str, + default="least-request", + choices=["least-request", "round-robin", "random"], + help="Load-balancing algorithm for the router.", + ) parser.add_argument("--router-verbose", action="store_true", help="Enable router verbose logging.") parser.add_argument("--router-extra-args", type=str, default="", help="Extra args for the router demo script.") @@ -250,11 +254,10 @@ def main() -> int: try: import sglang # noqa: F401 - except ImportError: + except ImportError as exc: raise RuntimeError( - "sglang is not installed.\n" - "Install with: uv pip install \"sglang[diffusion]\" --prerelease=allow" - ) + "sglang is not installed.\n" 'Install with: uv pip install "sglang[diffusion]" --prerelease=allow' + ) from exc env = dict(os.environ) worker_urls = list(args.worker_urls) @@ -295,9 +298,7 @@ def main() -> int: for i, _ in enumerate(worker_urls): preferred_master = args.worker_master_port_base + i * args.worker_internal_port_stride - preferred_scheduler = ( - args.worker_scheduler_port_base + i * args.worker_internal_port_stride - ) + preferred_scheduler = args.worker_scheduler_port_base + i * args.worker_internal_port_stride master_port = _reserve_available_port(args.worker_host, preferred_master, reserved_ports) scheduler_port = _reserve_available_port(args.worker_host, preferred_scheduler, reserved_ports) worker_internal_ports.append((master_port, scheduler_port)) @@ -371,7 +372,10 @@ def main() -> int: flush=True, ) - print(f"[bench] Waiting for {len(worker_urls)} worker(s) to become healthy (this may take several minutes)...", flush=True) + print( + f"[bench] Waiting for {len(worker_urls)} worker(s) to become healthy (this may take several minutes)...", + flush=True, + ) for i, url in enumerate(worker_urls): _wait_for_health(url, args.wait_timeout, f"worker {url}", proc=processes[i]) diff --git a/examples/diffusion_router/bench_routing_algorithms.py b/examples/diffusion_router/bench_routing_algorithms.py index 549878a47..827417e08 100644 --- a/examples/diffusion_router/bench_routing_algorithms.py +++ b/examples/diffusion_router/bench_routing_algorithms.py @@ -37,13 +37,14 @@ def _require_non_empty_model(model: str) -> str: def main() -> int: - parser = argparse.ArgumentParser( - description="Compare routing algorithms by running bench_router.py for each." - ) + parser = argparse.ArgumentParser(description="Compare routing algorithms by running bench_router.py for each.") parser.add_argument("--model", type=str, required=True, help="Diffusion model HF ID or local path.") parser.add_argument( - "--algorithms", nargs="+", default=ALL_ALGORITHMS, - choices=ALL_ALGORITHMS, help="Algorithms to compare (default: all three).", + "--algorithms", + nargs="+", + default=ALL_ALGORITHMS, + choices=ALL_ALGORITHMS, + help="Algorithms to compare (default: all three).", ) parser.add_argument("--output-dir", type=str, default=None, help="Output directory for results.") @@ -100,25 +101,44 @@ def main() -> int: out_file = output_dir / f"bench_{algo}.json" bench_cmd = [ - py, "examples/diffusion_router/bench_router.py", - "--model", args.model, - "--routing-algorithm", algo, - "--num-workers", str(args.num_workers), - "--num-prompts", str(args.num_prompts), - "--max-concurrency", str(args.max_concurrency), - "--num-gpus-per-worker", str(args.num_gpus_per_worker), - "--worker-base-port", str(args.worker_base_port), - "--worker-port-stride", str(args.worker_port_stride), - "--worker-master-port-base", str(args.worker_master_port_base), - "--worker-scheduler-port-base", str(args.worker_scheduler_port_base), - "--worker-internal-port-stride", str(args.worker_internal_port_stride), - "--router-host", args.router_host, - "--router-port", str(args.router_port), - "--dataset", args.dataset, - "--request-rate", str(args.request_rate), - "--wait-timeout", str(args.wait_timeout), - "--log-level", args.log_level, - "--output-file", str(out_file), + py, + "examples/diffusion_router/bench_router.py", + "--model", + args.model, + "--routing-algorithm", + algo, + "--num-workers", + str(args.num_workers), + "--num-prompts", + str(args.num_prompts), + "--max-concurrency", + str(args.max_concurrency), + "--num-gpus-per-worker", + str(args.num_gpus_per_worker), + "--worker-base-port", + str(args.worker_base_port), + "--worker-port-stride", + str(args.worker_port_stride), + "--worker-master-port-base", + str(args.worker_master_port_base), + "--worker-scheduler-port-base", + str(args.worker_scheduler_port_base), + "--worker-internal-port-stride", + str(args.worker_internal_port_stride), + "--router-host", + args.router_host, + "--router-port", + str(args.router_port), + "--dataset", + args.dataset, + "--request-rate", + str(args.request_rate), + "--wait-timeout", + str(args.wait_timeout), + "--log-level", + args.log_level, + "--output-file", + str(out_file), ] if args.worker_urls: bench_cmd += ["--worker-urls", *args.worker_urls] @@ -178,14 +198,23 @@ def main() -> int: data = results.get(algo, {}) if "error" in data: parsed[algo] = None - csv_rows.append({ - "algorithm": algo, "throughput_qps": "", "latency_mean": "", - "latency_median": "", "latency_p99": "", "duration": "", - "completed_requests": "", "failed_requests": "", - "throughput_qps_delta_pct": "", "latency_mean_delta_pct": "", - "latency_median_delta_pct": "", "latency_p99_delta_pct": "", - "error": data["error"], - }) + csv_rows.append( + { + "algorithm": algo, + "throughput_qps": "", + "latency_mean": "", + "latency_median": "", + "latency_p99": "", + "duration": "", + "completed_requests": "", + "failed_requests": "", + "throughput_qps_delta_pct": "", + "latency_mean_delta_pct": "", + "latency_median_delta_pct": "", + "latency_p99_delta_pct": "", + "error": data["error"], + } + ) continue row = { @@ -260,11 +289,19 @@ def _fmt_int(v): csv_path = output_dir / "routing_algorithm_comparison.csv" fieldnames = [ - "algorithm", "throughput_qps", "throughput_qps_delta_pct", - "latency_mean", "latency_mean_delta_pct", - "latency_median", "latency_median_delta_pct", - "latency_p99", "latency_p99_delta_pct", - "duration", "completed_requests", "failed_requests", "error", + "algorithm", + "throughput_qps", + "throughput_qps_delta_pct", + "latency_mean", + "latency_mean_delta_pct", + "latency_median", + "latency_median_delta_pct", + "latency_p99", + "latency_p99_delta_pct", + "duration", + "completed_requests", + "failed_requests", + "error", ] with open(csv_path, "w", newline="") as f: writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") diff --git a/examples/diffusion_router/demo.py b/examples/diffusion_router/demo.py index 26c807aff..eac263798 100644 --- a/examples/diffusion_router/demo.py +++ b/examples/diffusion_router/demo.py @@ -36,9 +36,13 @@ def main(): parser.add_argument("--timeout", type=float, default=None, help="Request timeout in seconds") parser.add_argument("--health-check-interval", type=int, default=10, help="Seconds between health checks") parser.add_argument("--health-check-failure-threshold", type=int, default=3, help="Failures before quarantine") - parser.add_argument("--routing-algorithm", type=str, default="least-request", - choices=["least-request", "round-robin", "random"], - help="Load-balancing algorithm") + parser.add_argument( + "--routing-algorithm", + type=str, + default="least-request", + choices=["least-request", "round-robin", "random"], + help="Load-balancing algorithm", + ) parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") args = parser.parse_args() diff --git a/miles/router/diffusion_router.py b/miles/router/diffusion_router.py index 753e4dfba..a6eca21d8 100644 --- a/miles/router/diffusion_router.py +++ b/miles/router/diffusion_router.py @@ -204,8 +204,16 @@ async def _send(worker_url): @staticmethod def _sanitize_response_headers(headers) -> dict: """Remove hop-by-hop and encoding headers that no longer match buffered content.""" - hop_by_hop = {"connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", - "transfer-encoding", "upgrade"} + hop_by_hop = { + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + } dropped = {"content-length", "content-encoding"} return {k: v for k, v in headers.items() if k.lower() not in hop_by_hop | dropped} @@ -233,12 +241,14 @@ async def health_workers(self, request: Request): """Per-worker health and load information.""" workers = [] for url, count in self.worker_request_counts.items(): - workers.append({ - "url": url, - "active_requests": count, - "is_dead": url in self.dead_workers, - "consecutive_failures": self.worker_failure_counts.get(url, 0), - }) + workers.append( + { + "url": url, + "active_requests": count, + "is_dead": url in self.dead_workers, + "consecutive_failures": self.worker_failure_counts.get(url, 0), + } + ) return JSONResponse(content={"workers": workers}) async def update_weights_from_disk(self, request: Request): From 385805dff0a7472f7bd86771050eee616111a48f Mon Sep 17 00:00:00 2001 From: sniper35 Date: Sun, 8 Feb 2026 09:50:25 +0000 Subject: [PATCH 14/14] add TODO --- miles/router/diffusion_router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/miles/router/diffusion_router.py b/miles/router/diffusion_router.py index a6eca21d8..d947da7bf 100644 --- a/miles/router/diffusion_router.py +++ b/miles/router/diffusion_router.py @@ -251,6 +251,7 @@ async def health_workers(self, request: Request): ) return JSONResponse(content={"workers": workers}) + # TODO: integrate with https://github.com/sgl-project/sglang/pull/18306 when it gets merged. async def update_weights_from_disk(self, request: Request): """Broadcast weight reload to all healthy workers.""" body = await request.body()