2 changes: 1 addition & 1 deletion python/sglang/srt/disaggregation/decode.py
@@ -137,7 +137,7 @@ def add(self, req: Req) -> None:
kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
kv_receiver = kv_receiver_class(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
)
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
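Why the one-word fix above matters: the old line built the bootstrap address from self.bootstrap_port, a single port shared by every request on the decode side, while the new line reads the port carried on the request itself, so each prefill server can advertise its own bootstrap port. A minimal standalone sketch of that address construction (hostnames and ports are made up):

# Illustrative only: different prefill servers may listen on different bootstrap ports,
# so the decode side must build the address from the per-request fields.
requests = [
    {"bootstrap_host": "prefill-0", "bootstrap_port": 8998},
    {"bootstrap_host": "prefill-1", "bootstrap_port": 8999},
]
for req in requests:
    bootstrap_addr = f"{req['bootstrap_host']}:{req['bootstrap_port']}"
    print(bootstrap_addr)  # prefill-0:8998, then prefill-1:8999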
53 changes: 45 additions & 8 deletions python/sglang/srt/disaggregation/mini_lb.py
@@ -6,6 +6,7 @@
import random
import urllib
from itertools import chain
from typing import List

import aiohttp
import orjson
@@ -14,13 +15,22 @@
from fastapi.responses import ORJSONResponse, Response, StreamingResponse


class PrefillConfig:
def __init__(self, url: str, bootstrap_port: int):
self.url = url
self.bootstrap_port = bootstrap_port


class MiniLoadBalancer:
def __init__(self, prefill_servers, decode_servers):
self.prefill_servers = prefill_servers
def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
self.prefill_configs = prefill_configs
self.prefill_servers = [p.url for p in prefill_configs]
self.decode_servers = decode_servers

def select_pair(self):
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
prefill_config = random.choice(self.prefill_configs)
decode_server = random.choice(self.decode_servers)
return prefill_config.url, prefill_config.bootstrap_port, decode_server

async def generate(
self, modified_request, prefill_server, decode_server, endpoint
@@ -160,7 +170,7 @@ async def get_model_info():

@app.post("/generate")
async def handle_generate_request(request_data: dict):
prefill_server, decode_server = load_balancer.select_pair()
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
modified_request.update(
{
"bootstrap_host": [hostname] * batch_size,
"bootstrap_port": [bootstrap_port] * batch_size,
"bootstrap_room": [
_generate_bootstrap_room() for _ in range(batch_size)
],
@@ -181,6 +192,7 @@
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": _generate_bootstrap_room(),
}
)
@@ -197,7 +209,7 @@

@app.post("/v1/chat/completions")
async def handle_completion_request(request_data: dict):
prefill_server, decode_server = load_balancer.select_pair()
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": random.randint(0, 2**63 - 1),
}
)
@@ -255,9 +268,9 @@ async def get_models():
raise HTTPException(status_code=500, detail=str(e))


def run(prefill_addrs, decode_addrs, host, port):
def run(prefill_configs, decode_addrs, host, port):
global load_balancer
load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs)
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
uvicorn.run(app, host=host, port=port)


@@ -268,6 +281,11 @@ def run(prefill_addrs, decode_addrs, host, port):
parser.add_argument(
"--prefill", required=True, help="Comma-separated URLs for prefill servers"
)
parser.add_argument(
"--prefill-bootstrap-ports",
help="Comma-separated bootstrap ports for prefill servers",
default="8998",
)
parser.add_argument(
"--decode", required=True, help="Comma-separated URLs for decode servers"
)
@@ -278,4 +296,23 @@ def run(prefill_addrs, decode_addrs, host, port):
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
)
args = parser.parse_args()
run(args.prefill.split(","), args.decode.split(","), args.host, args.port)

prefill_urls = args.prefill.split(",")
bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]

if len(bootstrap_ports) == 1:
bootstrap_ports = bootstrap_ports * len(prefill_urls)
else:
if len(bootstrap_ports) != len(prefill_urls):
raise ValueError(
"Number of prefill URLs must match number of bootstrap ports"
)

prefill_configs = []
for url, port in zip(prefill_urls, bootstrap_ports):
prefill_configs.append(PrefillConfig(url, port))

decode_addrs = args.decode.split(",")

run(prefill_configs, decode_addrs, args.host, args.port)
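With PrefillConfig in place, select_pair now returns a (prefill_url, bootstrap_port, decode_url) triple instead of a pair, a single --prefill-bootstrap-ports value is broadcast to every prefill URL, and multiple values must match the prefill URLs one-to-one. A hypothetical launch (hostnames and ports are illustrative) would pass --prefill http://prefill-0:30000,http://prefill-1:30000 --prefill-bootstrap-ports 8998,8999 --decode http://decode-0:30001. A self-contained sketch of the resulting pairing and selection, not the real module:

# Standalone sketch: pair each prefill URL with its bootstrap port and draw a
# random (prefill_url, bootstrap_port, decode_url) triple, as select_pair does.
import random

prefill_configs = [
    ("http://prefill-0:30000", 8998),
    ("http://prefill-1:30000", 8999),
]
decode_servers = ["http://decode-0:30001"]

prefill_url, bootstrap_port = random.choice(prefill_configs)
decode_url = random.choice(decode_servers)
print(prefill_url, bootstrap_port, decode_url)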
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -97,6 +97,7 @@ class GenerateReqInput:

# For disaggregated inference
bootstrap_host: Optional[Union[List[str], str]] = None
bootstrap_port: Optional[Union[List[int], int]] = None
bootstrap_room: Optional[Union[List[int], int]] = None

def normalize_batch_and_arguments(self):
@@ -400,6 +401,9 @@ def __getitem__(self, i):
bootstrap_host=(
self.bootstrap_host[i] if self.bootstrap_host is not None else None
),
bootstrap_port=(
self.bootstrap_port[i] if self.bootstrap_port is not None else None
),
bootstrap_room=(
self.bootstrap_room[i] if self.bootstrap_room is not None else None
),
Expand Down Expand Up @@ -447,6 +451,7 @@ class TokenizedGenerateReqInput:

# For disaggregated inference
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None


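GenerateReqInput now carries bootstrap_port alongside bootstrap_host and bootstrap_room, either as a scalar for a single request or as a per-request list for a batch, and __getitem__ slices each field for request i. A simplified, self-contained sketch of that indexing pattern (a hypothetical stand-in, not the real class, which has many more fields):

from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class BootstrapFields:
    # Stand-in for the three disaggregation fields on GenerateReqInput.
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[int], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    def __getitem__(self, i: int) -> "BootstrapFields":
        # Index into each per-request list; None stays None.
        return BootstrapFields(
            bootstrap_host=self.bootstrap_host[i] if self.bootstrap_host is not None else None,
            bootstrap_port=self.bootstrap_port[i] if self.bootstrap_port is not None else None,
            bootstrap_room=self.bootstrap_room[i] if self.bootstrap_room is not None else None,
        )

batch = BootstrapFields(["prefill-0", "prefill-0"], [8998, 8998], [111, 222])
print(batch[1])  # BootstrapFields(bootstrap_host='prefill-0', bootstrap_port=8998, bootstrap_room=222)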
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -391,6 +391,7 @@ def __init__(
return_hidden_states: bool = False,
eos_token_ids: Optional[Set[int]] = None,
bootstrap_host: Optional[str] = None,
bootstrap_port: Optional[int] = None,
bootstrap_room: Optional[int] = None,
):
# Input and output info
@@ -526,6 +527,7 @@ def __init__(

# For disaggregation
self.bootstrap_host: str = bootstrap_host
self.bootstrap_port: Optional[int] = bootstrap_port
self.bootstrap_room: Optional[int] = bootstrap_room
self.disagg_kv_sender: Optional[BaseKVSender] = None

1 change: 1 addition & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -787,6 +787,7 @@ def handle_generate_request(
return_hidden_states=recv_req.return_hidden_states,
eos_token_ids=self.model_config.hf_eos_token_id,
bootstrap_host=recv_req.bootstrap_host,
bootstrap_port=recv_req.bootstrap_port,
bootstrap_room=recv_req.bootstrap_room,
)
req.tokenizer = self.tokenizer
1 change: 1 addition & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -498,6 +498,7 @@ def _create_tokenized_object(
token_ids_logprob,
obj.stream,
bootstrap_host=obj.bootstrap_host,
bootstrap_port=obj.bootstrap_port,
bootstrap_room=obj.bootstrap_room,
lora_path=obj.lora_path,
input_embeds=input_embeds,
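Taken together, the changes thread bootstrap_port through the whole request path: the mini load balancer picks a prefill server and injects bootstrap_host, bootstrap_port, and bootstrap_room into the request body; the tokenizer manager copies them onto TokenizedGenerateReqInput; the scheduler copies them onto the Req; and the decode-side receiver finally connects to f"{bootstrap_host}:{bootstrap_port}". For a batch of two, the injected fields look roughly like this (values are illustrative):

modified_request = {
    # ... original request fields ...
    "bootstrap_host": ["prefill-0", "prefill-0"],
    "bootstrap_port": [8998, 8998],
    "bootstrap_room": [4242424242, 1313131313],
}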