From e3005fe4f7ddfb21cf1d491eb56c3bd01ac84980 Mon Sep 17 00:00:00 2001 From: IAN <50618241+hcyz33@users.noreply.github.com> Date: Tue, 22 Apr 2025 16:54:43 +0800 Subject: [PATCH] Support Muti Prefill in one node Co-authored-by: shuaills --- python/sglang/srt/disaggregation/decode.py | 2 +- python/sglang/srt/disaggregation/mini_lb.py | 53 ++++++++++++++++--- python/sglang/srt/managers/io_struct.py | 5 ++ python/sglang/srt/managers/schedule_batch.py | 2 + python/sglang/srt/managers/scheduler.py | 1 + .../sglang/srt/managers/tokenizer_manager.py | 1 + 6 files changed, 55 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index a45af7e3719..a2f34a98ac8 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -137,7 +137,7 @@ def add(self, req: Req) -> None: kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER) kv_receiver = kv_receiver_class( mgr=self.kv_manager, - bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}", + bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}", bootstrap_room=req.bootstrap_room, ) self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver)) diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py index d3194e6742d..32cbbfdfbb3 100644 --- a/python/sglang/srt/disaggregation/mini_lb.py +++ b/python/sglang/srt/disaggregation/mini_lb.py @@ -6,6 +6,7 @@ import random import urllib from itertools import chain +from typing import List import aiohttp import orjson @@ -14,13 +15,22 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse +class PrefillConfig: + def __init__(self, url: str, bootstrap_port: int): + self.url = url + self.bootstrap_port = bootstrap_port + + class MiniLoadBalancer: - def __init__(self, prefill_servers, decode_servers): - self.prefill_servers = prefill_servers + def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]): + self.prefill_configs = prefill_configs + self.prefill_servers = [p.url for p in prefill_configs] self.decode_servers = decode_servers def select_pair(self): - return random.choice(self.prefill_servers), random.choice(self.decode_servers) + prefill_config = random.choice(self.prefill_configs) + decode_server = random.choice(self.decode_servers) + return prefill_config.url, prefill_config.bootstrap_port, decode_server async def generate( self, modified_request, prefill_server, decode_server, endpoint @@ -160,7 +170,7 @@ async def get_model_info(): @app.post("/generate") async def handle_generate_request(request_data: dict): - prefill_server, decode_server = load_balancer.select_pair() + prefill_server, bootstrap_port, decode_server = load_balancer.select_pair() # Parse and transform prefill_server for bootstrap data parsed_url = urllib.parse.urlparse(prefill_server) @@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict): modified_request.update( { "bootstrap_host": [hostname] * batch_size, + "bootstrap_port": [bootstrap_port] * batch_size, "bootstrap_room": [ _generate_bootstrap_room() for _ in range(batch_size) ], @@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict): modified_request.update( { "bootstrap_host": hostname, + "bootstrap_port": bootstrap_port, "bootstrap_room": _generate_bootstrap_room(), } ) @@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict): @app.post("/v1/chat/completions") async def handle_completion_request(request_data: dict): - prefill_server, decode_server = load_balancer.select_pair() + prefill_server, bootstrap_port, decode_server = load_balancer.select_pair() # Parse and transform prefill_server for bootstrap data parsed_url = urllib.parse.urlparse(prefill_server) @@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict): modified_request.update( { "bootstrap_host": hostname, + "bootstrap_port": bootstrap_port, "bootstrap_room": random.randint(0, 2**63 - 1), } ) @@ -255,9 +268,9 @@ async def get_models(): raise HTTPException(status_code=500, detail=str(e)) -def run(prefill_addrs, decode_addrs, host, port): +def run(prefill_configs, decode_addrs, host, port): global load_balancer - load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs) + load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs) uvicorn.run(app, host=host, port=port) @@ -268,6 +281,11 @@ def run(prefill_addrs, decode_addrs, host, port): parser.add_argument( "--prefill", required=True, help="Comma-separated URLs for prefill servers" ) + parser.add_argument( + "--prefill-bootstrap-ports", + help="Comma-separated bootstrap ports for prefill servers", + default="8998", + ) parser.add_argument( "--decode", required=True, help="Comma-separated URLs for decode servers" ) @@ -278,4 +296,23 @@ def run(prefill_addrs, decode_addrs, host, port): "--port", type=int, default=8000, help="Port to bind the server (default: 8000)" ) args = parser.parse_args() - run(args.prefill.split(","), args.decode.split(","), args.host, args.port) + + prefill_urls = args.prefill.split(",") + bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")] + + if len(bootstrap_ports) == 1: + bootstrap_ports = bootstrap_ports * len(prefill_urls) + else: + if len(bootstrap_ports) != len(prefill_urls): + raise ValueError( + "Number of prefill URLs must match number of bootstrap ports" + ) + exit(1) + + prefill_configs = [] + for url, port in zip(prefill_urls, bootstrap_ports): + prefill_configs.append(PrefillConfig(url, port)) + + decode_addrs = args.decode.split(",") + + run(prefill_configs, decode_addrs, args.host, args.port) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 9d9b576c866..174656b2dcb 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -97,6 +97,7 @@ class GenerateReqInput: # For disaggregated inference bootstrap_host: Optional[Union[List[str], str]] = None + bootstrap_port: Optional[Union[List[int], int]] = None bootstrap_room: Optional[Union[List[int], int]] = None def normalize_batch_and_arguments(self): @@ -400,6 +401,9 @@ def __getitem__(self, i): bootstrap_host=( self.bootstrap_host[i] if self.bootstrap_host is not None else None ), + bootstrap_port=( + self.bootstrap_port[i] if self.bootstrap_port is not None else None + ), bootstrap_room=( self.bootstrap_room[i] if self.bootstrap_room is not None else None ), @@ -447,6 +451,7 @@ class TokenizedGenerateReqInput: # For disaggregated inference bootstrap_host: Optional[str] = None + bootstrap_port: Optional[int] = None bootstrap_room: Optional[int] = None diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f7ea34d5063..6dea9321c48 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -391,6 +391,7 @@ def __init__( return_hidden_states: bool = False, eos_token_ids: Optional[Set[int]] = None, bootstrap_host: Optional[str] = None, + bootstrap_port: Optional[int] = None, bootstrap_room: Optional[int] = None, ): # Input and output info @@ -526,6 +527,7 @@ def __init__( # For disaggregation self.bootstrap_host: str = bootstrap_host + self.bootstrap_port: Optional[int] = bootstrap_port self.bootstrap_room: Optional[int] = bootstrap_room self.disagg_kv_sender: Optional[BaseKVSender] = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 303c2205903..f0158b345ec 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -787,6 +787,7 @@ def handle_generate_request( return_hidden_states=recv_req.return_hidden_states, eos_token_ids=self.model_config.hf_eos_token_id, bootstrap_host=recv_req.bootstrap_host, + bootstrap_port=recv_req.bootstrap_port, bootstrap_room=recv_req.bootstrap_room, ) req.tokenizer = self.tokenizer diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 65ee4d6d306..82709b09592 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -498,6 +498,7 @@ def _create_tokenized_object( token_ids_logprob, obj.stream, bootstrap_host=obj.bootstrap_host, + bootstrap_port=obj.bootstrap_port, bootstrap_room=obj.bootstrap_room, lora_path=obj.lora_path, input_embeds=input_embeds,