diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py index 32cbbfdfbb3..013b7732687 100644 --- a/python/sglang/srt/disaggregation/mini_lb.py +++ b/python/sglang/srt/disaggregation/mini_lb.py @@ -3,10 +3,12 @@ """ import asyncio +import dataclasses +import logging import random import urllib from itertools import chain -from typing import List +from typing import List, Optional import aiohttp import orjson @@ -14,11 +16,32 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import ORJSONResponse, Response, StreamingResponse +from sglang.srt.disaggregation.utils import PDRegistryRequest + +def setup_logger(): + logger = logging.getLogger("pdlb") + logger.setLevel(logging.INFO) + + formatter = logging.Formatter( + "[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + handler = logging.StreamHandler() + handler.setFormatter(formatter) + logger.addHandler(handler) + + return logger + + +logger = setup_logger() + + +@dataclasses.dataclass class PrefillConfig: - def __init__(self, url: str, bootstrap_port: int): - self.url = url - self.bootstrap_port = bootstrap_port + url: str + bootstrap_port: Optional[int] = None class MiniLoadBalancer: @@ -28,6 +51,10 @@ def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[st self.decode_servers = decode_servers def select_pair(self): + # TODO: return some message instead of panic + assert len(self.prefill_configs) > 0, "No prefill servers available" + assert len(self.decode_servers) > 0, "No decode servers available" + prefill_config = random.choice(self.prefill_configs) decode_server = random.choice(self.decode_servers) return prefill_config.url, prefill_config.bootstrap_port, decode_server @@ -47,7 +74,7 @@ async def generate( session.post(f"{decode_server}/{endpoint}", json=modified_request), ] # Wait for both responses to complete. Prefill should end first. - prefill_response, decode_response = await asyncio.gather(*tasks) + _, decode_response = await asyncio.gather(*tasks) return ORJSONResponse( content=await decode_response.json(), @@ -268,6 +295,32 @@ async def get_models(): raise HTTPException(status_code=500, detail=str(e)) +@app.post("/register") +async def register(obj: PDRegistryRequest): + if obj.mode == "prefill": + load_balancer.prefill_configs.append( + PrefillConfig(obj.registry_url, obj.bootstrap_port) + ) + logger.info( + f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}" + ) + elif obj.mode == "decode": + load_balancer.decode_servers.append(obj.registry_url) + logger.info(f"Registered decode server: {obj.registry_url}") + else: + raise HTTPException( + status_code=400, + detail="Invalid mode. Must be either PREFILL or DECODE.", + ) + + logger.info( + f"#Prefill servers: {len(load_balancer.prefill_configs)}, " + f"#Decode servers: {len(load_balancer.decode_servers)}" + ) + + return Response(status_code=200) + + def run(prefill_configs, decode_addrs, host, port): global load_balancer load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs) @@ -279,15 +332,16 @@ def run(prefill_configs, decode_addrs, host, port): parser = argparse.ArgumentParser(description="Mini Load Balancer Server") parser.add_argument( - "--prefill", required=True, help="Comma-separated URLs for prefill servers" + "--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers" ) parser.add_argument( - "--prefill-bootstrap-ports", - help="Comma-separated bootstrap ports for prefill servers", - default="8998", + "--decode", type=str, default=[], nargs="+", help="URLs for decode servers" ) parser.add_argument( - "--decode", required=True, help="Comma-separated URLs for decode servers" + "--prefill-bootstrap-ports", + type=int, + nargs="+", + help="Bootstrap ports for prefill servers", ) parser.add_argument( "--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)" @@ -297,22 +351,19 @@ def run(prefill_configs, decode_addrs, host, port): ) args = parser.parse_args() - prefill_urls = args.prefill.split(",") - bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")] - - if len(bootstrap_ports) == 1: - bootstrap_ports = bootstrap_ports * len(prefill_urls) + bootstrap_ports = args.prefill_bootstrap_ports + if bootstrap_ports is None: + bootstrap_ports = [None] * len(args.prefill) + elif len(bootstrap_ports) == 1: + bootstrap_ports = bootstrap_ports * len(args.prefill) else: - if len(bootstrap_ports) != len(prefill_urls): + if len(bootstrap_ports) != len(args.prefill): raise ValueError( "Number of prefill URLs must match number of bootstrap ports" ) - exit(1) - - prefill_configs = [] - for url, port in zip(prefill_urls, bootstrap_ports): - prefill_configs.append(PrefillConfig(url, port)) - decode_addrs = args.decode.split(",") + prefill_configs = [ + PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports) + ] - run(prefill_configs, decode_addrs, args.host, args.port) + run(prefill_configs, args.decode, args.host, args.port) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index b50a86c63ff..90fd6034b9e 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -1,13 +1,18 @@ from __future__ import annotations +import dataclasses +import warnings from collections import deque from enum import Enum -from typing import List +from typing import List, Optional import numpy as np +import requests import torch import torch.distributed as dist +from sglang.srt.utils import get_ip + class DisaggregationMode(Enum): NULL = "null" @@ -119,3 +124,41 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int): def kv_to_page_num(num_kv_indices: int, page_size: int): # ceil(num_kv_indices / page_size) return (num_kv_indices + page_size - 1) // page_size + + +@dataclasses.dataclass +class PDRegistryRequest: + """A request to register a machine itself to the LB.""" + + mode: str + registry_url: str + bootstrap_port: Optional[int] = None + + def __post_init__(self): + if self.mode == "prefill" and self.bootstrap_port is None: + raise ValueError("Bootstrap port must be set in PREFILL mode.") + elif self.mode == "decode" and self.bootstrap_port is not None: + raise ValueError("Bootstrap port must not be set in DECODE mode.") + elif self.mode not in ["prefill", "decode"]: + raise ValueError( + f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'." + ) + + +def register_disaggregation_server( + mode: str, server_port: int, bootstrap_port: int, pdlb_url: str +): + boostrap_port = bootstrap_port if mode == "prefill" else None + registry_request = PDRegistryRequest( + mode=mode, + registry_url=f"http://{get_ip()}:{server_port}", + bootstrap_port=boostrap_port, + ) + res = requests.post( + f"{pdlb_url}/register", + json=dataclasses.asdict(registry_request), + ) + if res.status_code != 200: + warnings.warn( + f"Failed to register disaggregation server: {res.status_code} {res.text}" + ) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 1064d2dbf88..30f4f2305f3 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -42,7 +42,10 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, Response, StreamingResponse -from sglang.srt.disaggregation.utils import FakeBootstrapHost +from sglang.srt.disaggregation.utils import ( + FakeBootstrapHost, + register_disaggregation_server, +) from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import ( @@ -871,5 +874,13 @@ def _wait_and_warmup( if server_args.debug_tensor_dump_input_file: kill_process_tree(os.getpid()) + if server_args.pdlb_url is not None: + register_disaggregation_server( + server_args.disaggregation_mode, + server_args.port, + server_args.disaggregation_bootstrap_port, + server_args.pdlb_url, + ) + if launch_callback is not None: launch_callback() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index fb95660930c..2158ccae6e8 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -925,6 +925,10 @@ def handle_generate_request( ) custom_logit_processor = None + if recv_req.bootstrap_port is None: + # Use default bootstrap port + recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port + req = Req( recv_req.rid, recv_req.input_text, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ec56c6b9493..56c1f916bfd 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -198,6 +198,7 @@ class ServerArgs: disaggregation_bootstrap_port: int = 8998 disaggregation_transfer_backend: str = "mooncake" disaggregation_ib_device: Optional[str] = None + pdlb_url: Optional[str] = None def __post_init__(self): # Expert parallelism @@ -1254,6 +1255,12 @@ def add_cli_args(parser: argparse.ArgumentParser): "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). " "Default is None, which triggers automatic device detection when mooncake backend is enabled.", ) + parser.add_argument( + "--pdlb-url", + type=str, + default=None, + help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace):