diff --git a/examples/diffusion_router/README.md b/examples/diffusion_router/README.md
new file mode 100644
index 000000000..55c38c4d0
--- /dev/null
+++ b/examples/diffusion_router/README.md
@@ -0,0 +1,74 @@
+# Miles Diffusion Router
+
+Load-balances requests across multiple `sglang-diffusion` worker instances using least-request routing with background health checks and worker quarantine.
+
+## Quick Start
+
+```bash
+# Start the router with two diffusion backends
+python examples/diffusion_router/demo.py --port 30080 \
+  --worker-urls http://localhost:10090 http://localhost:10091
+
+# Or start empty and add workers dynamically
+python examples/diffusion_router/demo.py --port 30080
+curl -X POST 'http://localhost:30080/add_worker?url=http://localhost:10090'
+```
+
+## API Endpoints
+
+| Method | Path | Description |
+|--------|------|-------------|
+| `POST` | `/generate` | Image generation (forwards to `/v1/images/generations`) |
+| `POST` | `/generate_video` | Video generation (forwards to `/v1/videos`) |
+| `GET` | `/health` | Aggregated router health |
+| `GET` | `/health_workers` | Per-worker health and load info |
+| `POST` | `/add_worker` | Register a diffusion worker (`?url=...` or JSON body) |
+| `GET` | `/list_workers` | List registered workers |
+| `POST` | `/update_weights_from_disk` | Broadcast weight reload to all workers |
+| `GET, POST, PUT, DELETE` | `/{path}` | Catch-all proxy to least-loaded worker |
+
+## Load Balancing
+
+The router defaults to a **least-request** strategy: each incoming request is forwarded to the worker with the fewest in-flight requests. This is workload-aware and avoids hot-spotting compared to round-robin. When a request completes, the worker's count is decremented, keeping the load state accurate in real time. Round-robin and random strategies are also available via `--routing-algorithm`.
+
+Workers that fail consecutive health checks (default: 3) are quarantined: they are removed from the routing pool and never re-added, since a reconnected worker could serve stale weights (see [discussion](https://github.com/radixark/miles/pull/260#discussion_r2654274449)). A background loop pings each worker's `GET /health` endpoint at a configurable interval (default: 10s).
+
+## Notes
+
+- Health check endpoint follows Miles/SGLang convention: `GET /health`.
+- Responses are fully buffered; streaming and large-response handling are not supported yet (planned for a follow-up PR).
+
+## Benchmark Scripts (under `examples/diffusion_router/`)
+
+- `bench_routing_algorithms.py` — top-level comparison runner
+- `bench_router.py` — single-algorithm benchmark (spawned per algorithm)
+- `demo.py` — router process (spawned by `bench_router.py`)
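+
+Example comparison run, mirroring the invocation documented in `bench_routing_algorithms.py` (the model and sizing below are illustrative):
+
+```bash
+python examples/diffusion_router/bench_routing_algorithms.py \
+  --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \
+  --num-workers 2 \
+  --num-prompts 10 \
+  --max-concurrency 2
+```
+
+Results land in `outputs/routing_algo_compare_<timestamp>/`: one `bench_<algorithm>.json` per algorithm plus a `routing_algorithm_comparison.csv` summary.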
+ +## Example Requests + +```bash +# Check health +curl http://localhost:30080/health + +# Generate an image +curl -X POST http://localhost:30080/generate \ + -H 'Content-Type: application/json' \ + -d '{"model": "stabilityai/stable-diffusion-3", "prompt": "a cat", "n": 1, "size": "1024x1024"}' + +# Reload weights on all workers +curl -X POST http://localhost:30080/update_weights_from_disk \ + -H 'Content-Type: application/json' \ + -d '{"model_path": "/path/to/new/weights"}' +``` + +## CLI Options + +``` +--host Bind address (default: 0.0.0.0) +--port Port (default: 30080) +--worker-urls Initial worker URLs +--max-connections Max concurrent connections (default: 100) +--timeout Request timeout in seconds for router-to-worker requests +--health-check-interval Seconds between health checks (default: 10) +--health-check-failure-threshold Failures before quarantine (default: 3) +--verbose Enable verbose logging +``` diff --git a/examples/diffusion_router/bench_router.py b/examples/diffusion_router/bench_router.py new file mode 100644 index 000000000..cd77c3ae0 --- /dev/null +++ b/examples/diffusion_router/bench_router.py @@ -0,0 +1,461 @@ +#!/usr/bin/env python3 +""" +Launch sglang-diffusion workers, start the Miles DiffusionRouter, then run +the sglang diffusion serving benchmark against the router. + +Example: + python examples/diffusion_router/bench_router.py \ + --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --num-workers 2 \ + --num-prompts 20 \ + --max-concurrency 4 +""" + +from __future__ import annotations + +import argparse +import os +import shlex +import signal +import socket +import subprocess +import sys +import time +from collections.abc import Iterable +from pathlib import Path + +import requests + + +def _require_non_empty_model(model: str) -> str: + normalized = model.strip() + if not normalized: + raise ValueError( + "--model must be a non-empty model ID/path. " + "Detected an empty value, which often means a shell variable such as " + "$MODEL was unset." + ) + return normalized + + +def _wait_for_health( + url: str, + timeout: int, + label: str, + proc: subprocess.Popen | None = None, +) -> None: + start = time.time() + last_print = 0.0 + while True: + elapsed = time.time() - start + + # Fail fast if the backing process has already exited + if proc is not None and proc.poll() is not None: + raise RuntimeError( + f"{label} process exited with code {proc.returncode}. " + "Run the worker command directly to see the error." + ) + + try: + resp = requests.get(f"{url}/health", timeout=1) + if resp.status_code == 200: + print(f" [bench] {label} is healthy ({elapsed:.0f}s)", flush=True) + return + except requests.RequestException: + pass + + if elapsed - last_print >= 30: + print(f" [bench] Still waiting for {label}... 
({elapsed:.0f}s elapsed)", flush=True) + last_print = elapsed + + if elapsed > timeout: + raise TimeoutError(f"Timed out waiting for {label} at {url}.") + time.sleep(1) + + +def _build_worker_urls(host: str, base_port: int, count: int, stride: int) -> list[str]: + return [f"http://{host}:{base_port + i * stride}" for i in range(count)] + + +def _infer_client_host(host: str) -> str: + if host in ("0.0.0.0", "::"): + return "127.0.0.1" + return host + + +def _is_port_available(host: str, port: int) -> bool: + if host in ("0.0.0.0", "::"): + host = "127.0.0.1" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(0.5) + return sock.connect_ex((host, port)) != 0 + + +def _reserve_available_port(host: str, preferred_port: int, used_ports: set[int]) -> int: + if preferred_port < 1 or preferred_port > 65535: + raise ValueError(f"Invalid port: {preferred_port}") + + for port in range(preferred_port, 65536): + if port in used_ports: + continue + if _is_port_available(host, port): + used_ports.add(port) + return port + + for port in range(1024, preferred_port): + if port in used_ports: + continue + if _is_port_available(host, port): + used_ports.add(port) + return port + + raise RuntimeError(f"Unable to reserve a free port for host {host}. " f"Preferred start={preferred_port}.") + + +def _parse_gpu_id_list(raw: str) -> list[str]: + return [item.strip() for item in raw.split(",") if item.strip()] + + +def _detect_gpu_count() -> int: + try: + import torch + + return int(torch.cuda.device_count()) + except Exception: + return 0 + + +def _resolve_gpu_pool(args: argparse.Namespace, env: dict[str, str]) -> list[str] | None: + if args.worker_gpu_ids: + return [str(x) for x in args.worker_gpu_ids] + + visible = env.get("CUDA_VISIBLE_DEVICES", "") + if visible: + parsed = _parse_gpu_id_list(visible) + if parsed: + return parsed + + gpu_count = _detect_gpu_count() + if gpu_count > 0: + return [str(i) for i in range(gpu_count)] + return None + + +def _terminate_all(processes: Iterable[subprocess.Popen]) -> None: + procs = list(processes) + + def _signal_group(proc: subprocess.Popen, sig: int) -> None: + try: + # Processes are launched with start_new_session=True so each has its own group. 
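+            # killpg therefore reaches any subprocesses the worker spawned
+            # (e.g. sglang's scheduler processes), not just the direct child.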
+ os.killpg(proc.pid, sig) + except ProcessLookupError: + pass + except Exception: + if proc.poll() is None: + try: + os.kill(proc.pid, sig) + except ProcessLookupError: + pass + + for proc in procs: + _signal_group(proc, signal.SIGTERM) + + for proc in procs: + try: + proc.wait(timeout=15) + except subprocess.TimeoutExpired: + _signal_group(proc, signal.SIGKILL) + + for proc in procs: + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + pass + + +def main() -> int: + parser = argparse.ArgumentParser(description="Benchmark Miles DiffusionRouter with sglang bench_serving.") + parser.add_argument("--model", type=str, required=True, help="Diffusion model HF ID or local path.") + parser.add_argument("--router-host", type=str, default="127.0.0.1", help="Router bind host.") + parser.add_argument("--router-port", type=int, default=30080, help="Router port.") + parser.add_argument( + "--routing-algorithm", + type=str, + default="least-request", + choices=["least-request", "round-robin", "random"], + help="Load-balancing algorithm for the router.", + ) + parser.add_argument("--router-verbose", action="store_true", help="Enable router verbose logging.") + parser.add_argument("--router-extra-args", type=str, default="", help="Extra args for the router demo script.") + + parser.add_argument("--worker-host", type=str, default="127.0.0.1", help="Worker bind host.") + parser.add_argument("--worker-urls", nargs="*", default=[], help="Existing worker URLs to use.") + parser.add_argument("--num-workers", type=int, default=1, help="Number of workers to launch.") + parser.add_argument("--worker-base-port", type=int, default=10090, help="Base port for launched workers.") + parser.add_argument( + "--worker-port-stride", + type=int, + default=2, + help="Port increment between launched workers. Keep >=2 to avoid sglang internal port collisions.", + ) + parser.add_argument( + "--worker-master-port-base", + type=int, + default=30005, + help="Base torch distributed master port for launched workers.", + ) + parser.add_argument( + "--worker-scheduler-port-base", + type=int, + default=5555, + help="Base scheduler port for launched workers.", + ) + parser.add_argument( + "--worker-internal-port-stride", + type=int, + default=1000, + help=( + "Stride used between workers for master/scheduler base ports. " + "Use >= 101 because sglang randomizes each by +[0,100]." + ), + ) + parser.add_argument("--num-gpus-per-worker", type=int, default=1, help="GPUs per worker.") + parser.add_argument( + "--worker-gpu-ids", + nargs="*", + default=None, + help=( + "Optional GPU IDs/UUIDs for launched workers. They are consumed in order, " + "in groups of --num-gpus-per-worker." 
+ ), + ) + parser.add_argument("--worker-extra-args", type=str, default="", help="Extra args for `sglang serve`.") + parser.add_argument("--skip-workers", action="store_true", help="Do not launch workers.") + + parser.add_argument("--dataset", type=str, default="random", choices=["vbench", "random"]) + parser.add_argument("--dataset-path", type=str, default=None) + parser.add_argument("--num-prompts", type=int, default=20) + parser.add_argument("--max-concurrency", type=int, default=1) + parser.add_argument("--request-rate", type=float, default=float("inf")) + parser.add_argument("--task", type=str, default=None) + parser.add_argument("--width", type=int, default=None) + parser.add_argument("--height", type=int, default=None) + parser.add_argument("--num-frames", type=int, default=None) + parser.add_argument("--fps", type=int, default=None) + parser.add_argument("--output-file", type=str, default=None) + parser.add_argument("--disable-tqdm", action="store_true") + parser.add_argument("--log-level", type=str, default="INFO") + parser.add_argument("--bench-extra-args", type=str, default="", help="Extra args for bench_serving.") + + parser.add_argument("--wait-timeout", type=int, default=1200, help="Seconds to wait for services to be healthy.") + + args = parser.parse_args() + args.model = _require_non_empty_model(args.model) + + try: + import sglang # noqa: F401 + except ImportError as exc: + raise RuntimeError( + "sglang is not installed.\n" 'Install with: uv pip install "sglang[diffusion]" --prerelease=allow' + ) from exc + env = dict(os.environ) + + worker_urls = list(args.worker_urls) + if not worker_urls: + if args.worker_port_stride < 1: + raise ValueError("--worker-port-stride must be >= 1") + if args.worker_internal_port_stride < 101: + raise ValueError("--worker-internal-port-stride must be >= 101") + worker_urls = _build_worker_urls( + args.worker_host, + args.worker_base_port, + args.num_workers, + args.worker_port_stride, + ) + + if args.skip_workers and not worker_urls: + raise ValueError("No workers specified. Provide --worker-urls or disable --skip-workers.") + + if not _is_port_available(args.router_host, args.router_port): + raise RuntimeError( + f"Router port {args.router_port} on {args.router_host} is already in use. " + "Stop the existing router/process or change --router-port." + ) + + processes: list[subprocess.Popen] = [] + try: + if not args.skip_workers: + reserved_ports: set[int] = {args.router_port} + worker_internal_ports: list[tuple[int, int]] = [] + for url in worker_urls: + port = int(url.rsplit(":", 1)[1]) + if not _is_port_available(args.worker_host, port): + raise RuntimeError( + f"Worker port {port} on {args.worker_host} is already in use. " + "Stop existing servers or change --worker-base-port/--worker-port-stride." 
+ ) + reserved_ports.add(port) + + for i, _ in enumerate(worker_urls): + preferred_master = args.worker_master_port_base + i * args.worker_internal_port_stride + preferred_scheduler = args.worker_scheduler_port_base + i * args.worker_internal_port_stride + master_port = _reserve_available_port(args.worker_host, preferred_master, reserved_ports) + scheduler_port = _reserve_available_port(args.worker_host, preferred_scheduler, reserved_ports) + worker_internal_ports.append((master_port, scheduler_port)) + if master_port != preferred_master or scheduler_port != preferred_scheduler: + print( + "[bench] Adjusted internal worker ports due to conflict: " + f"worker={i} master={master_port} scheduler={scheduler_port}", + flush=True, + ) + + sglang_bin = Path(sys.executable).resolve().parent / "sglang" + if sglang_bin.exists(): + sglang_cli_cmd = [str(sglang_bin)] + else: + sglang_cli_cmd = [sys.executable, "-c", "from sglang.cli.main import main; main()"] + gpu_pool = _resolve_gpu_pool(args, env) + total_gpus_needed = len(worker_urls) * args.num_gpus_per_worker + if gpu_pool is not None and len(gpu_pool) < total_gpus_needed: + raise ValueError( + f"Need {total_gpus_needed} GPU slots for {len(worker_urls)} worker(s) x " + f"{args.num_gpus_per_worker} GPU(s), but only {len(gpu_pool)} visible. " + "Set --worker-gpu-ids, reduce --num-workers, or reduce --num-gpus-per-worker." + ) + for i, url in enumerate(worker_urls): + port = int(url.rsplit(":", 1)[1]) + master_port, scheduler_port = worker_internal_ports[i] + cmd = [ + *sglang_cli_cmd, + "serve", + "--model-path", + args.model, + "--num-gpus", + str(args.num_gpus_per_worker), + "--host", + args.worker_host, + "--port", + str(port), + "--master-port", + str(master_port), + "--scheduler-port", + str(scheduler_port), + ] + if args.worker_extra_args: + cmd += shlex.split(args.worker_extra_args) + worker_env = env + if gpu_pool is not None: + start = i * args.num_gpus_per_worker + end = start + args.num_gpus_per_worker + assigned = gpu_pool[start:end] + worker_env = dict(env) + worker_env["CUDA_VISIBLE_DEVICES"] = ",".join(assigned) + print( + f"[bench] Worker {i} uses CUDA_VISIBLE_DEVICES={worker_env['CUDA_VISIBLE_DEVICES']}", + flush=True, + ) + print( + f"[bench] Launching worker {i} on port {port} " + f"(master={master_port}, scheduler={scheduler_port})...", + flush=True, + ) + processes.append(subprocess.Popen(cmd, env=worker_env, start_new_session=True)) + + if args.max_concurrency and worker_urls: + per_worker = (args.max_concurrency + len(worker_urls) - 1) // len(worker_urls) + if per_worker > 1: + print( + "[bench] Warning: " + f"max_concurrency={args.max_concurrency} over {len(worker_urls)} workers " + f"can drive up to ~{per_worker} concurrent requests per worker, which may OOM " + "large diffusion models.", + flush=True, + ) + + print( + f"[bench] Waiting for {len(worker_urls)} worker(s) to become healthy (this may take several minutes)...", + flush=True, + ) + for i, url in enumerate(worker_urls): + _wait_for_health(url, args.wait_timeout, f"worker {url}", proc=processes[i]) + + router_cmd = [ + sys.executable, + "examples/diffusion_router/demo.py", + "--host", + args.router_host, + "--port", + str(args.router_port), + "--worker-urls", + *worker_urls, + ] + if args.routing_algorithm: + router_cmd += ["--routing-algorithm", args.routing_algorithm] + if args.router_verbose: + router_cmd.append("--verbose") + if args.router_extra_args: + router_cmd += shlex.split(args.router_extra_args) + + if not _is_port_available(args.router_host, 
args.router_port): + raise RuntimeError( + f"Router port {args.router_port} on {args.router_host} is already in use. " + "Stop the existing router/process or change --router-port." + ) + + print(f"[bench] Launching router on port {args.router_port}...", flush=True) + router_proc = subprocess.Popen(router_cmd, start_new_session=True) + processes.append(router_proc) + + router_host = _infer_client_host(args.router_host) + base_url = f"http://{router_host}:{args.router_port}" + _wait_for_health(base_url, args.wait_timeout, "router", proc=router_proc) + + print(f"[bench] Running benchmark: {args.num_prompts} prompts, concurrency={args.max_concurrency}", flush=True) + + bench_cmd = [ + sys.executable, + "-m", + "sglang.multimodal_gen.benchmarks.bench_serving", + "--base-url", + base_url, + "--model", + args.model, + "--dataset", + args.dataset, + "--num-prompts", + str(args.num_prompts), + "--request-rate", + str(args.request_rate), + "--log-level", + args.log_level, + ] + if args.dataset_path: + bench_cmd += ["--dataset-path", args.dataset_path] + if args.max_concurrency: + bench_cmd += ["--max-concurrency", str(args.max_concurrency)] + if args.task: + bench_cmd += ["--task", args.task] + if args.width: + bench_cmd += ["--width", str(args.width)] + if args.height: + bench_cmd += ["--height", str(args.height)] + if args.num_frames: + bench_cmd += ["--num-frames", str(args.num_frames)] + if args.fps: + bench_cmd += ["--fps", str(args.fps)] + if args.output_file: + bench_cmd += ["--output-file", args.output_file] + if args.disable_tqdm: + bench_cmd.append("--disable-tqdm") + if args.bench_extra_args: + bench_cmd += shlex.split(args.bench_extra_args) + + bench_proc = subprocess.Popen(bench_cmd, env=env, start_new_session=True) + processes.append(bench_proc) + return bench_proc.wait() + finally: + _terminate_all(reversed(processes)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/diffusion_router/bench_routing_algorithms.py b/examples/diffusion_router/bench_routing_algorithms.py new file mode 100644 index 000000000..827417e08 --- /dev/null +++ b/examples/diffusion_router/bench_routing_algorithms.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 +""" +Compare routing algorithms by running bench_router.py for each algorithm +and collecting results into a summary table and CSV. + +Example: + python examples/diffusion_router/bench_routing_algorithms.py \ + --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --num-workers 2 \ + --num-prompts 10 \ + --max-concurrency 2 +""" + +from __future__ import annotations + +import argparse +import csv +import json +import shlex +import subprocess +import sys +from datetime import datetime +from pathlib import Path + +ALL_ALGORITHMS = ["least-request", "round-robin", "random"] + + +def _require_non_empty_model(model: str) -> str: + normalized = model.strip() + if not normalized: + raise ValueError( + "--model must be a non-empty model ID/path. " + "Detected an empty value, which often means a shell variable such as " + "$MODEL was unset." 
+ ) + return normalized + + +def main() -> int: + parser = argparse.ArgumentParser(description="Compare routing algorithms by running bench_router.py for each.") + parser.add_argument("--model", type=str, required=True, help="Diffusion model HF ID or local path.") + parser.add_argument( + "--algorithms", + nargs="+", + default=ALL_ALGORITHMS, + choices=ALL_ALGORITHMS, + help="Algorithms to compare (default: all three).", + ) + parser.add_argument("--output-dir", type=str, default=None, help="Output directory for results.") + + # Pass-through args for bench_router.py + parser.add_argument("--router-host", type=str, default="127.0.0.1") + parser.add_argument("--router-port", type=int, default=30080) + parser.add_argument("--router-verbose", action="store_true") + parser.add_argument("--router-extra-args", type=str, default="") + parser.add_argument("--worker-host", type=str, default="127.0.0.1") + parser.add_argument("--worker-urls", nargs="*", default=[]) + parser.add_argument("--num-workers", type=int, default=1) + parser.add_argument("--worker-base-port", type=int, default=10090) + parser.add_argument("--worker-port-stride", type=int, default=2) + parser.add_argument("--worker-master-port-base", type=int, default=30005) + parser.add_argument("--worker-scheduler-port-base", type=int, default=5555) + parser.add_argument("--worker-internal-port-stride", type=int, default=1000) + parser.add_argument("--num-gpus-per-worker", type=int, default=1) + parser.add_argument("--worker-gpu-ids", nargs="*", default=None) + parser.add_argument("--worker-extra-args", type=str, default="") + parser.add_argument("--skip-workers", action="store_true") + parser.add_argument("--dataset", type=str, default="random", choices=["vbench", "random"]) + parser.add_argument("--dataset-path", type=str, default=None) + parser.add_argument("--num-prompts", type=int, default=20) + parser.add_argument("--max-concurrency", type=int, default=1) + parser.add_argument("--request-rate", type=float, default=float("inf")) + parser.add_argument("--task", type=str, default=None) + parser.add_argument("--width", type=int, default=None) + parser.add_argument("--height", type=int, default=None) + parser.add_argument("--num-frames", type=int, default=None) + parser.add_argument("--fps", type=int, default=None) + parser.add_argument("--disable-tqdm", action="store_true") + parser.add_argument("--log-level", type=str, default="INFO") + parser.add_argument("--bench-extra-args", type=str, default="") + parser.add_argument("--wait-timeout", type=int, default=1200) + + args = parser.parse_args() + args.model = _require_non_empty_model(args.model) + + output_dir = ( + Path(args.output_dir) + if args.output_dir + else Path("outputs") / f"routing_algo_compare_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + ) + output_dir.mkdir(parents=True, exist_ok=True) + + py = sys.executable + results: dict[str, dict] = {} + + for algo in args.algorithms: + print(f"\n{'='*60}", flush=True) + print(f" Running benchmark with routing algorithm: {algo}", flush=True) + print(f"{'='*60}\n", flush=True) + + out_file = output_dir / f"bench_{algo}.json" + + bench_cmd = [ + py, + "examples/diffusion_router/bench_router.py", + "--model", + args.model, + "--routing-algorithm", + algo, + "--num-workers", + str(args.num_workers), + "--num-prompts", + str(args.num_prompts), + "--max-concurrency", + str(args.max_concurrency), + "--num-gpus-per-worker", + str(args.num_gpus_per_worker), + "--worker-base-port", + str(args.worker_base_port), + "--worker-port-stride", + 
str(args.worker_port_stride), + "--worker-master-port-base", + str(args.worker_master_port_base), + "--worker-scheduler-port-base", + str(args.worker_scheduler_port_base), + "--worker-internal-port-stride", + str(args.worker_internal_port_stride), + "--router-host", + args.router_host, + "--router-port", + str(args.router_port), + "--dataset", + args.dataset, + "--request-rate", + str(args.request_rate), + "--wait-timeout", + str(args.wait_timeout), + "--log-level", + args.log_level, + "--output-file", + str(out_file), + ] + if args.worker_urls: + bench_cmd += ["--worker-urls", *args.worker_urls] + if args.worker_gpu_ids: + bench_cmd += ["--worker-gpu-ids", *args.worker_gpu_ids] + if args.dataset_path: + bench_cmd += ["--dataset-path", args.dataset_path] + if args.task: + bench_cmd += ["--task", args.task] + if args.width: + bench_cmd += ["--width", str(args.width)] + if args.height: + bench_cmd += ["--height", str(args.height)] + if args.num_frames: + bench_cmd += ["--num-frames", str(args.num_frames)] + if args.fps: + bench_cmd += ["--fps", str(args.fps)] + if args.disable_tqdm: + bench_cmd.append("--disable-tqdm") + if args.router_verbose: + bench_cmd.append("--router-verbose") + if args.skip_workers: + bench_cmd.append("--skip-workers") + if args.worker_extra_args: + bench_cmd += ["--worker-extra-args", args.worker_extra_args] + if args.router_extra_args: + bench_cmd += ["--router-extra-args", args.router_extra_args] + if args.bench_extra_args: + bench_cmd += ["--bench-extra-args", args.bench_extra_args] + if args.worker_host != "127.0.0.1": + bench_cmd += ["--worker-host", args.worker_host] + + print("[run]", " ".join(shlex.quote(x) for x in bench_cmd), flush=True) + rc = subprocess.call(bench_cmd) + + if rc != 0: + print(f"[warn] bench_router.py exited with code {rc} for algorithm '{algo}'", flush=True) + results[algo] = {"error": f"exit_code={rc}"} + continue + + if out_file.exists(): + try: + results[algo] = json.loads(out_file.read_text()) + except json.JSONDecodeError as e: + print(f"[warn] Failed to parse {out_file}: {e}", flush=True) + results[algo] = {"error": f"json_parse_error: {e}"} + else: + print(f"[warn] Output file not found: {out_file}", flush=True) + results[algo] = {"error": "output_file_missing"} + + BASELINE = "random" + metric_keys = ["throughput_qps", "latency_mean", "latency_median", "latency_p99"] + + csv_rows: list[dict] = [] + parsed: dict[str, dict] = {} + for algo in args.algorithms: + data = results.get(algo, {}) + if "error" in data: + parsed[algo] = None + csv_rows.append( + { + "algorithm": algo, + "throughput_qps": "", + "latency_mean": "", + "latency_median": "", + "latency_p99": "", + "duration": "", + "completed_requests": "", + "failed_requests": "", + "throughput_qps_delta_pct": "", + "latency_mean_delta_pct": "", + "latency_median_delta_pct": "", + "latency_p99_delta_pct": "", + "error": data["error"], + } + ) + continue + + row = { + "algorithm": algo, + "throughput_qps": data.get("throughput_qps", ""), + "latency_mean": data.get("latency_mean", ""), + "latency_median": data.get("latency_median", ""), + "latency_p99": data.get("latency_p99", ""), + "duration": data.get("duration", ""), + "completed_requests": data.get("completed_requests", ""), + "failed_requests": data.get("failed_requests", ""), + "error": "", + } + parsed[algo] = row + csv_rows.append(row) + + baseline_row = parsed.get(BASELINE) + for row in csv_rows: + if row.get("error"): + continue + for key in metric_keys: + val = row.get(key, "") + base = baseline_row.get(key, "") if 
baseline_row else "" + delta_key = f"{key}_delta_pct" + if isinstance(val, (int, float)) and isinstance(base, (int, float)) and base: + row[delta_key] = ((val - base) / abs(base)) * 100 + else: + row[delta_key] = "" + + print(f"\n{'='*100}", flush=True) + print(f" Routing Algorithm Comparison (baseline: {BASELINE})", flush=True) + print(f"{'='*100}", flush=True) + + header = ( + f"{'Algorithm':<16} {'Throughput':>14} {'Tput Delta':>11}" + f" {'Mean Lat':>12} {'Delta':>8}" + f" {'Median Lat':>12} {'Delta':>8}" + f" {'P99 Lat':>12} {'Delta':>8}" + f" {'Done':>6} {'Fail':>6}" + ) + print(header, flush=True) + print("-" * len(header), flush=True) + + def _fmt_qps(v): + return f"{v:.2f} req/s" if isinstance(v, (int, float)) else str(v) + + def _fmt_lat(v): + return f"{v:.3f} s" if isinstance(v, (int, float)) else str(v) + + def _fmt_delta(v): + if isinstance(v, (int, float)): + sign = "+" if v >= 0 else "" + return f"{sign}{v:.1f}%" + return "" + + def _fmt_int(v): + return str(v) if v != "" else "N/A" + + for row in csv_rows: + if row.get("error"): + print(f"{row['algorithm']:<16} {'ERROR':>14} {row['error']}", flush=True) + continue + print( + f"{row['algorithm']:<16}" + f" {_fmt_qps(row['throughput_qps']):>14} {_fmt_delta(row.get('throughput_qps_delta_pct', '')):>11}" + f" {_fmt_lat(row['latency_mean']):>12} {_fmt_delta(row.get('latency_mean_delta_pct', '')):>8}" + f" {_fmt_lat(row['latency_median']):>12} {_fmt_delta(row.get('latency_median_delta_pct', '')):>8}" + f" {_fmt_lat(row['latency_p99']):>12} {_fmt_delta(row.get('latency_p99_delta_pct', '')):>8}" + f" {_fmt_int(row['completed_requests']):>6} {_fmt_int(row['failed_requests']):>6}", + flush=True, + ) + + csv_path = output_dir / "routing_algorithm_comparison.csv" + fieldnames = [ + "algorithm", + "throughput_qps", + "throughput_qps_delta_pct", + "latency_mean", + "latency_mean_delta_pct", + "latency_median", + "latency_median_delta_pct", + "latency_p99", + "latency_p99_delta_pct", + "duration", + "completed_requests", + "failed_requests", + "error", + ] + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + writer.writerows(csv_rows) + print(f"\n[done] CSV written to {csv_path}", flush=True) + print(f"[done] Per-algorithm JSON results in {output_dir}/", flush=True) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/diffusion_router/demo.py b/examples/diffusion_router/demo.py new file mode 100644 index 000000000..eac263798 --- /dev/null +++ b/examples/diffusion_router/demo.py @@ -0,0 +1,61 @@ +""" +Demo script for the Miles Diffusion Router. + +Starts a diffusion router that load-balances requests across multiple +sglang-diffusion worker instances using least-request routing. 
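+Round-robin and random strategies are also available via --routing-algorithm.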
+ +Usage: + # Start with no workers (add them dynamically via /add_worker): + python examples/diffusion_router/demo.py --port 30080 + + # Start with pre-registered workers: + python examples/diffusion_router/demo.py --port 30080 \ + --worker-urls http://localhost:10090 http://localhost:10091 + + # Then interact: + curl http://localhost:30080/health + curl -X POST 'http://localhost:30080/add_worker?url=http://localhost:10092' + curl http://localhost:30080/list_workers + curl -X POST http://localhost:30080/generate -H 'Content-Type: application/json' \ + -d '{"model": "stabilityai/stable-diffusion-3", "prompt": "a cat", "n": 1, "size": "1024x1024"}' +""" + +import argparse + +import uvicorn + +from miles.router.diffusion_router import DiffusionRouter + + +def main(): + parser = argparse.ArgumentParser(description="Miles Diffusion Router Demo") + parser.add_argument("--host", type=str, default="0.0.0.0", help="Router bind address") + parser.add_argument("--port", type=int, default=30080, help="Router port") + parser.add_argument("--worker-urls", nargs="*", default=[], help="Initial diffusion worker URLs") + parser.add_argument("--max-connections", type=int, default=100, help="Max concurrent connections to workers") + parser.add_argument("--timeout", type=float, default=None, help="Request timeout in seconds") + parser.add_argument("--health-check-interval", type=int, default=10, help="Seconds between health checks") + parser.add_argument("--health-check-failure-threshold", type=int, default=3, help="Failures before quarantine") + parser.add_argument( + "--routing-algorithm", + type=str, + default="least-request", + choices=["least-request", "round-robin", "random"], + help="Load-balancing algorithm", + ) + parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") + args = parser.parse_args() + + router = DiffusionRouter(args, verbose=args.verbose) + + # Pre-register any workers specified on the command line + for url in args.worker_urls: + router.register_worker(url) + + print(f"[demo] Starting diffusion router on {args.host}:{args.port}") + print(f"[demo] Workers: {list(router.worker_request_counts.keys()) or '(none — add via POST /add_worker)'}") + uvicorn.run(router.app, host=args.host, port=args.port, log_level="info") + + +if __name__ == "__main__": + main() diff --git a/miles/router/diffusion_router.py b/miles/router/diffusion_router.py new file mode 100644 index 000000000..d947da7bf --- /dev/null +++ b/miles/router/diffusion_router.py @@ -0,0 +1,296 @@ +import asyncio +import json +import logging +import random + +import httpx +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from starlette.responses import Response + +logger = logging.getLogger(__name__) + + +class DiffusionRouter: + def __init__(self, args, verbose=False): + """Initialize the diffusion router for load-balancing across sglang-diffusion workers.""" + self.args = args + self.verbose = verbose + + self.app = FastAPI() + self.app.add_event_handler("startup", self._start_background_health_check) + + # URL -> Active Request Count (load state) + self.worker_request_counts: dict[str, int] = {} + # URL -> Consecutive Failures + self.worker_failure_counts: dict[str, int] = {} + # Quarantined workers excluded from routing pool + self.dead_workers: set[str] = set() + + self.routing_algorithm = getattr(args, "routing_algorithm", "least-request") + self._rr_index = 0 + + max_connections = getattr(args, "max_connections", 100) + timeout = getattr(args, "timeout", None) + + 
self.client = httpx.AsyncClient( + limits=httpx.Limits(max_connections=max_connections), + timeout=httpx.Timeout(timeout), + ) + + self._setup_routes() + + def _setup_routes(self): + """Setup all the HTTP routes.""" + self.app.post("/add_worker")(self.add_worker) + self.app.get("/list_workers")(self.list_workers) + self.app.get("/health")(self.health) + self.app.get("/health_workers")(self.health_workers) + self.app.post("/generate")(self.generate) + self.app.post("/generate_video")(self.generate_video) + self.app.post("/update_weights_from_disk")(self.update_weights_from_disk) + # Catch-all route for proxying — must be registered LAST + self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy) + + async def _start_background_health_check(self): + asyncio.create_task(self._health_check_loop()) + + async def _check_worker_health(self, url): + try: + response = await self.client.get(f"{url}/health", timeout=5.0) + if response.status_code == 200: + return url, True + logger.debug(f"[diffusion-router] Worker {url} unhealthy (status {response.status_code})") + except Exception as e: + logger.debug(f"[diffusion-router] Worker {url} health check failed: {e}") + return url, False + + async def _health_check_loop(self): + """Background loop to monitor worker health and quarantine failing workers.""" + interval = getattr(self.args, "health_check_interval", 10) + threshold = getattr(self.args, "health_check_failure_threshold", 3) + + while True: + try: + await asyncio.sleep(interval) + + urls = [u for u in self.worker_request_counts if u not in self.dead_workers] + if not urls: + continue + + results = await asyncio.gather(*(self._check_worker_health(url) for url in urls)) + + for url, is_healthy in results: + if not is_healthy: + failures = self.worker_failure_counts.get(url, 0) + 1 + self.worker_failure_counts[url] = failures + if failures >= threshold: + logger.warning( + f"[diffusion-router] Worker {url} failed {threshold} consecutive checks. Marking DEAD." + ) + self.dead_workers.add(url) + # Dead workers are permanently excluded. Reconnecting them + # would risk serving stale weights after training has moved on. + else: + self.worker_failure_counts[url] = 0 + + healthy = len(self.worker_request_counts) - len(self.dead_workers) + logger.debug(f"[diffusion-router] Health check complete. 
{healthy} workers healthy.") + + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"[diffusion-router] Unexpected error in health check loop: {e}", exc_info=True) + await asyncio.sleep(5) + + def _use_url(self): + """Select a worker URL based on the configured routing algorithm.""" + if not self.worker_request_counts: + raise RuntimeError("No workers registered in the pool") + + valid_workers = [w for w in self.worker_request_counts if w not in self.dead_workers] + if not valid_workers: + raise RuntimeError("No healthy workers available in the pool") + + if self.routing_algorithm == "round-robin": + url = valid_workers[self._rr_index % len(valid_workers)] + self._rr_index = (self._rr_index + 1) % len(valid_workers) + elif self.routing_algorithm == "random": + url = random.choice(valid_workers) + else: # least-request (default) + url = min(valid_workers, key=self.worker_request_counts.get) + + self.worker_request_counts[url] += 1 + return url + + def _finish_url(self, url): + """Mark the request to the given URL as finished.""" + if url not in self.worker_request_counts: + raise ValueError(f"URL {url} not recognized") + self.worker_request_counts[url] -= 1 + if self.worker_request_counts[url] < 0: + raise RuntimeError(f"URL {url} count went negative") + + def _build_proxy_response(self, content: bytes, status_code: int, headers: dict) -> Response: + """ + Build an HTTP response from proxied bytes. + + Keep behavior consistent with `MilesRouter._build_proxy_response`: + - If the payload is JSON, return `JSONResponse` + - Otherwise, return raw `Response` + + Diffusion responses (e.g. `b64_json`) can be very large; decoding and re-encoding + the JSON can dominate CPU time. We therefore skip JSON decoding for large bodies. + """ + content_type = headers.get("content-type", "") + + # Size guard: don't pay JSON decode/re-encode costs on large payloads. + # This preserves exact bytes on the wire and avoids CPU/memory pressure. + max_json_reencode_bytes = 256 * 1024 + if len(content) <= max_json_reencode_bytes: + try: + data = json.loads(content) + return JSONResponse(content=data, status_code=status_code, headers=headers) + except Exception: + pass + + return Response(content=content, status_code=status_code, headers=headers, media_type=content_type) + + async def _forward_to_worker(self, request: Request, path: str) -> Response: + """Forward a request to a selected worker and return the response.""" + try: + worker_url = self._use_url() + except RuntimeError as exc: + return JSONResponse(status_code=503, content={"error": str(exc)}) + + # TODO: Support streaming responses; current implementation buffers full response. + query = request.url.query + url = f"{worker_url}/{path}" if not query else f"{worker_url}/{path}?{query}" + body = await request.body() + headers = dict(request.headers) + # Let httpx set the correct framing headers for the forwarded body. 
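+        # (A stale Content-Length copied from the inbound request could otherwise
+        # mismatch the body re-sent via `content=`.)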
+ if body is not None: + headers = {k: v for k, v in headers.items() if k.lower() not in ("content-length", "transfer-encoding")} + + try: + response = await self.client.request(request.method, url, content=body, headers=headers) + content = await response.aread() + except Exception as exc: + self._finish_url(worker_url) + logger.error(f"[diffusion-router] Failed to forward request to {worker_url}: {exc}") + return JSONResponse(status_code=502, content={"error": f"Worker request failed: {exc}"}) + else: + self._finish_url(worker_url) + + resp_headers = self._sanitize_response_headers(response.headers) + return self._build_proxy_response(content, response.status_code, resp_headers) + + async def _broadcast_to_workers(self, path: str, body: bytes, headers: dict) -> list[dict]: + """Send a request to ALL healthy workers and collect results.""" + urls = [u for u in self.worker_request_counts if u not in self.dead_workers] + if not urls: + return [] + + async def _send(worker_url): + try: + response = await self.client.post(f"{worker_url}/{path}", content=body, headers=headers) + content = await response.aread() + return {"worker_url": worker_url, "status_code": response.status_code, "body": json.loads(content)} + except Exception as e: + return {"worker_url": worker_url, "status_code": 502, "body": {"error": str(e)}} + + return await asyncio.gather(*(_send(u) for u in urls)) + + @staticmethod + def _sanitize_response_headers(headers) -> dict: + """Remove hop-by-hop and encoding headers that no longer match buffered content.""" + hop_by_hop = { + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + } + dropped = {"content-length", "content-encoding"} + return {k: v for k, v in headers.items() if k.lower() not in hop_by_hop | dropped} + + async def generate(self, request: Request): + """Route image generation to the least-loaded worker via /v1/images/generations.""" + return await self._forward_to_worker(request, "v1/images/generations") + + async def generate_video(self, request: Request): + """Route video generation to the least-loaded worker via /v1/videos.""" + return await self._forward_to_worker(request, "v1/videos") + + async def health(self, request: Request): + """Aggregated health status: healthy if at least one worker is alive.""" + total = len(self.worker_request_counts) + dead = len(self.dead_workers) + healthy = total - dead + status = "healthy" if healthy > 0 else "unhealthy" + code = 200 if healthy > 0 else 503 + return JSONResponse( + status_code=code, + content={"status": status, "healthy_workers": healthy, "total_workers": total}, + ) + + async def health_workers(self, request: Request): + """Per-worker health and load information.""" + workers = [] + for url, count in self.worker_request_counts.items(): + workers.append( + { + "url": url, + "active_requests": count, + "is_dead": url in self.dead_workers, + "consecutive_failures": self.worker_failure_counts.get(url, 0), + } + ) + return JSONResponse(content={"workers": workers}) + + # TODO: integrate with https://github.com/sgl-project/sglang/pull/18306 when it gets merged. 
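+    # Note: the endpoint below always returns 200 with a per-worker result list;
+    # callers should inspect each entry's status_code, since individual workers
+    # may have failed (reported as 502 in the list).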
+ async def update_weights_from_disk(self, request: Request): + """Broadcast weight reload to all healthy workers.""" + body = await request.body() + headers = dict(request.headers) + results = await self._broadcast_to_workers("update_weights_from_disk", body, headers) + return JSONResponse(content={"results": results}) + + def register_worker(self, url: str) -> None: + """Register a worker URL if not already known (sync, for startup use).""" + if url not in self.worker_request_counts: + self.worker_request_counts[url] = 0 + self.worker_failure_counts[url] = 0 + if self.verbose: + print(f"[diffusion-router] Added new worker: {url}") + + async def add_worker(self, request: Request): + """Register a new diffusion worker (HTTP endpoint).""" + worker_url = request.query_params.get("url") or request.query_params.get("worker_url") + + if not worker_url: + body = await request.body() + try: + payload = json.loads(body) if body else {} + except json.JSONDecodeError: + return JSONResponse(status_code=400, content={"error": "Invalid JSON body"}) + worker_url = payload.get("url") or payload.get("worker_url") + + if not worker_url: + return JSONResponse( + status_code=400, content={"error": "worker_url is required (use query ?url=... or JSON body)"} + ) + + self.register_worker(worker_url) + return {"status": "success", "worker_urls": list(self.worker_request_counts.keys())} + + async def list_workers(self, request: Request): + """List all registered workers.""" + return {"urls": list(self.worker_request_counts.keys())} + + async def proxy(self, request: Request, path: str): + """Catch-all: forward any unmatched request to the least-loaded worker.""" + return await self._forward_to_worker(request, path) diff --git a/tests/fast/router/test_diffusion_router.py b/tests/fast/router/test_diffusion_router.py new file mode 100644 index 000000000..8ab24203f --- /dev/null +++ b/tests/fast/router/test_diffusion_router.py @@ -0,0 +1,158 @@ +from argparse import Namespace + +import pytest + +from miles.router.diffusion_router import DiffusionRouter + + +def make_router_args(**overrides) -> Namespace: + """Create a Namespace with default DiffusionRouter args, applying overrides.""" + defaults = dict( + host="127.0.0.1", + port=30080, + max_connections=100, + timeout=None, + routing_algorithm="least-request", + ) + defaults.update(overrides) + return Namespace(**defaults) + + +@pytest.fixture +def router_factory(): + """Factory fixture that creates a DiffusionRouter with pre-set worker state.""" + + def _create( + workers: dict[str, int], + dead: set[str] | None = None, + **arg_overrides, + ) -> DiffusionRouter: + router = DiffusionRouter(make_router_args(**arg_overrides)) + router.worker_request_counts = dict(workers) + router.worker_failure_counts = {url: 0 for url in workers} + if dead: + router.dead_workers = set(dead) + return router + + return _create + + +@pytest.mark.unit +class TestLeastRequest: + """Test the least-request (default) load-balancing algorithm.""" + + def test_selects_min_load(self, router_factory): + router = router_factory({"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8}) + selected = router._use_url() + assert selected == "http://w2:8000" + assert router.worker_request_counts["http://w2:8000"] == 3 + + def test_excludes_dead_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8}, + dead={"http://w2:8000"}, + ) + selected = router._use_url() + assert selected == "http://w1:8000" + assert 
router.worker_request_counts["http://w1:8000"] == 6 + + +@pytest.mark.unit +class TestRoundRobin: + """Test the round-robin load-balancing algorithm.""" + + def test_cycles_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, + routing_algorithm="round-robin", + ) + results = [router._use_url() for _ in range(6)] + workers = list(router.worker_request_counts.keys()) + expected = [workers[i % 3] for i in range(6)] + assert results == expected + for url in workers: + assert router.worker_request_counts[url] == 2 + + def test_excludes_dead_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, + dead={"http://w2:8000"}, + routing_algorithm="round-robin", + ) + results = [router._use_url() for _ in range(4)] + assert "http://w2:8000" not in results + assert all(url in ("http://w1:8000", "http://w3:8000") for url in results) + + +@pytest.mark.unit +class TestRandom: + """Test the random load-balancing algorithm.""" + + def test_selects_from_valid_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, + routing_algorithm="random", + ) + seen = set() + for _ in range(30): + # Reset counts so they don't grow unbounded + for url in router.worker_request_counts: + router.worker_request_counts[url] = 0 + seen.add(router._use_url()) + assert seen == {"http://w1:8000", "http://w2:8000", "http://w3:8000"} + + def test_excludes_dead_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, + dead={"http://w2:8000"}, + routing_algorithm="random", + ) + for _ in range(20): + url = router._use_url() + assert url != "http://w2:8000" + router.worker_request_counts[url] -= 1 # reset increment + + +@pytest.mark.unit +class TestErrorCases: + """Test error handling across all routing algorithms.""" + + @pytest.mark.parametrize("algorithm", ["least-request", "round-robin", "random"]) + def test_raises_when_no_workers(self, router_factory, algorithm): + router = router_factory({}, routing_algorithm=algorithm) + with pytest.raises(RuntimeError, match="No workers registered"): + router._use_url() + + @pytest.mark.parametrize("algorithm", ["least-request", "round-robin", "random"]) + def test_raises_when_all_dead(self, router_factory, algorithm): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0}, + dead={"http://w1:8000", "http://w2:8000"}, + routing_algorithm=algorithm, + ) + with pytest.raises(RuntimeError, match="No healthy workers"): + router._use_url() + + +@pytest.mark.unit +class TestCountManagement: + """Test that _use_url / _finish_url correctly track active request counts.""" + + @pytest.mark.parametrize("algorithm", ["least-request", "round-robin", "random"]) + def test_increment_and_finish(self, router_factory, algorithm): + router = router_factory({"http://w1:8000": 0}, routing_algorithm=algorithm) + url = router._use_url() + assert router.worker_request_counts[url] == 1 + router._finish_url(url) + assert router.worker_request_counts[url] == 0 + + +@pytest.mark.unit +class TestDefaults: + """Test default routing algorithm when the attribute is absent.""" + + def test_default_algorithm_is_least_request(self): + args = Namespace(host="127.0.0.1", port=30080, max_connections=100, timeout=None) + # args has no routing_algorithm attribute + router = DiffusionRouter(args) + assert router.routing_algorithm == 
"least-request"