Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
e92d0cc
rust lb init
hnyls2002 May 19, 2025
6b65b4c
add timeout configurable
hnyls2002 May 25, 2025
6a2b515
udpate lb args
hnyls2002 May 25, 2025
64c4aef
optimize error handling
hnyls2002 May 25, 2025
25e0eca
Merge branch 'main' into lsyin-rust-lb
hnyls2002 May 25, 2025
d3b3da4
fake load report server
hnyls2002 May 25, 2025
d76b846
move code
hnyls2002 May 25, 2025
c0ea969
merge code
hnyls2002 May 25, 2025
a81dc9e
add post init check
hnyls2002 May 25, 2025
a7d0703
tmp fix for get_model_info
hnyls2002 May 26, 2025
d16248f
support batch in /generate
hnyls2002 May 26, 2025
be7cb59
fix
hnyls2002 May 26, 2025
417f35b
fix request type check
hnyls2002 May 26, 2025
94fedb3
fix get load
hnyls2002 May 26, 2025
807892d
split func
hnyls2002 May 26, 2025
b4331e7
fix launch_lb type hint
hnyls2002 May 26, 2025
41f4c56
update toml
hnyls2002 May 26, 2025
dd479ea
use anyhow server to handle error
hnyls2002 May 26, 2025
c2f8bb4
merge route and generate func
hnyls2002 May 26, 2025
f736a59
route one for get model info
hnyls2002 May 26, 2025
2b7bdf3
add proxy response struct
hnyls2002 May 26, 2025
95d23e1
fix token_ids is a list of list
hnyls2002 May 27, 2025
2975ce2
dispatch static req struct
hnyls2002 May 28, 2025
c726cdc
optimize error and collect
hnyls2002 May 28, 2025
0161989
remove duplicate clone() on server end
hnyls2002 May 28, 2025
b647b2a
add fixme
hnyls2002 May 28, 2025
74129a6
remove scripts
hnyls2002 May 28, 2025
55d2831
Merge branch 'main' into lsyin-rust-lb
hnyls2002 May 28, 2025
b86c21c
Merge branch 'main' into lsyin-rust-lb
zhyncs May 28, 2025
0c6440f
fix typo
hnyls2002 May 29, 2025
72310fd
add missing service
hnyls2002 May 29, 2025
a4ec1e1
use generics
hnyls2002 May 29, 2025
9c2c1fb
simplify code
hnyls2002 May 29, 2025
09a2ded
rename: ProxyResponseType -> ProxyResponseBody
hnyls2002 May 29, 2025
fba68af
optimize build method
hnyls2002 May 29, 2025
ca49a1a
use reqwest::Method
hnyls2002 May 29, 2025
2ec0af8
nits optimize
hnyls2002 May 29, 2025
6dbbb1a
move lb_state out of server.rs
hnyls2002 May 29, 2025
b501141
Merge branch 'main' into lsyin-rust-lb
hnyls2002 May 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions bench.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
MODEL=meta-llama/Llama-3.1-8B-Instruct
BACKEND=${BACKEND:-sglang}
INPUT=${INPUT:-8000}
OUTPUT=${OUTPUT:-500}
PORT=${PORT:-8000}

python3 -m sglang.bench_serving \
--backend $BACKEND \
--dataset-name random \
--num-prompts 500 \
--random-input $INPUT \
--random-output $OUTPUT \
--random-range-ratio 1 \
--port $PORT \
--dataset-name "random" \
--model $MODEL \
--pd-separated \
--request-rate 5
1 change: 1 addition & 0 deletions kill.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ps -ef | grep sgl | grep -v grep | awk '{print $2}' | xargs kill -9
8 changes: 8 additions & 0 deletions profile.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
PORT=${PORT:-30000}

curl -X POST http://localhost:$PORT/start_profile \
-H 'Content-Type: application/json' \
-d '{
"num_steps": 10,
"output_dir": "./"
}'
141 changes: 141 additions & 0 deletions python/sglang/srt/disaggregation/launch_lb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import argparse
import dataclasses
from typing import List


@dataclasses.dataclass
class LBArgs:
rust_lb: bool = False
host: str = "0.0.0.0"
port: int = 8000
policy: str = "random"
prefill_infos: List[str] = dataclasses.field(default_factory=list)
decode_infos: List[str] = dataclasses.field(default_factory=list)
log_interval: int = 5
timeout: int = 600

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--rust-lb",
action="store_true",
help="Use Rust load balancer",
)
parser.add_argument(
"--host",
type=str,
default=LBArgs.host,
help=f"Host to bind the server (default: {LBArgs.host})",
)
parser.add_argument(
"--port",
type=int,
default=LBArgs.port,
help=f"Port to bind the server (default: {LBArgs.port})",
)
parser.add_argument(
"--policy",
type=str,
default=LBArgs.policy,
choices=["random", "po2"],
help=f"Policy to use for load balancing (default: {LBArgs.policy})",
)
parser.add_argument(
"--prefill",
type=str,
default=[],
nargs="+",
help="URLs for prefill servers",
)
parser.add_argument(
"--decode",
type=str,
default=[],
nargs="+",
help="URLs for decode servers",
)
parser.add_argument(
"--prefill-bootstrap-ports",
type=int,
nargs="+",
help="Bootstrap ports for prefill servers",
)
parser.add_argument(
"--log-interval",
type=int,
default=LBArgs.log_interval,
help=f"Log interval in seconds (default: {LBArgs.log_interval})",
)
parser.add_argument(
"--timeout",
type=int,
default=LBArgs.timeout,
help=f"Timeout in seconds (default: {LBArgs.timeout})",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
bootstrap_ports = args.prefill_bootstrap_ports
if bootstrap_ports is None:
bootstrap_ports = [None] * len(args.prefill)
elif len(bootstrap_ports) == 1:
bootstrap_ports = bootstrap_ports * len(args.prefill)
else:
if len(bootstrap_ports) != len(args.prefill):
raise ValueError(
"Number of prefill URLs must match number of bootstrap ports"
)

prefill_infos = [
(url, port) for url, port in zip(args.prefill, bootstrap_ports)
]

return cls(
rust_lb=args.rust_lb,
host=args.host,
port=args.port,
policy=args.policy,
prefill_infos=prefill_infos,
decode_infos=args.decode,
log_interval=args.log_interval,
timeout=args.timeout,
)

def __post_init__(self):
if not self.rust_lb:
assert (
self.policy == "random"
), "Only random policy is supported for Python load balancer"


def main():
parser = argparse.ArgumentParser(
description="PD Disaggregation Load Balancer Server"
)
LBArgs.add_cli_args(parser)
args = parser.parse_args()
lb_args = LBArgs.from_cli_args(args)

if lb_args.rust_lb:
from sgl_pdlb._rust import LoadBalancer as RustLB

RustLB(
host=lb_args.host,
port=lb_args.port,
policy=lb_args.policy,
prefill_infos=lb_args.prefill_infos,
decode_infos=lb_args.decode_infos,
log_interval=lb_args.log_interval,
timeout=lb_args.timeout,
).start()
else:
from sglang.srt.disaggregation.mini_lb import PrefillConfig, run

prefill_configs = [
PrefillConfig(url, port) for url, port in lb_args.prefill_infos
]
run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)


if __name__ == "__main__":
main()
41 changes: 3 additions & 38 deletions python/sglang/srt/disaggregation/mini_lb.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,42 +361,7 @@ def run(prefill_configs, decode_addrs, host, port):


if __name__ == "__main__":
import argparse
# FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
from sglang.srt.disaggregation.launch_lb import main

parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
parser.add_argument(
"--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
)
parser.add_argument(
"--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
)
parser.add_argument(
"--prefill-bootstrap-ports",
type=int,
nargs="+",
help="Bootstrap ports for prefill servers",
)
parser.add_argument(
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
)
args = parser.parse_args()

bootstrap_ports = args.prefill_bootstrap_ports
if bootstrap_ports is None:
bootstrap_ports = [None] * len(args.prefill)
elif len(bootstrap_ports) == 1:
bootstrap_ports = bootstrap_ports * len(args.prefill)
else:
if len(bootstrap_ports) != len(args.prefill):
raise ValueError(
"Number of prefill URLs must match number of bootstrap ports"
)

prefill_configs = [
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
]

run(prefill_configs, args.decode, args.host, args.port)
main()
5 changes: 5 additions & 0 deletions python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ async def get_server_info():
}


@app.get("/get_load")
async def get_load():
return {"load": await _global_state.tokenizer_manager.get_load()}


@app.api_route("/set_internal_state", methods=["POST", "PUT"])
async def set_internal_state(obj: SetInternalStateReq, request: Request):
res = await _global_state.tokenizer_manager.set_internal_state(obj)
Expand Down
22 changes: 20 additions & 2 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1920,9 +1920,27 @@ def get_internal_state(self, recv_req: GetInternalStateReq):
)
if RECORD_STEP_TIME:
ret["step_time_dict"] = self.step_time_dict
return GetInternalStateReqOutput(
internal_state=ret,

num_used_tokens = (
self.max_total_num_tokens
- self.token_to_kv_pool_allocator.available_size()
- self.tree_cache.evictable_size()
)
ret["load"] = num_used_tokens
# TODO(lsyin): use dynamically maintained num_waiting_tokens
ret["load"] += sum(len(req.origin_input_ids) for req in self.waiting_queue)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
ret["load"] += sum(
len(req.origin_input_ids)
for req in self.disagg_prefill_bootstrap_queue.queue
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
ret["load"] += sum(
len(req.req.origin_input_ids)
for req in self.disagg_decode_prealloc_queue.queue
)

return GetInternalStateReqOutput(internal_state=ret)

def set_internal_state(self, recv_req: SetInternalStateReq):
server_args_dict = recv_req.server_args
Expand Down
11 changes: 11 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,9 @@ def __init__(
self.server_args.disaggregation_bootstrap_port
)

self.current_load = 0
self.current_load_lock = asyncio.Lock()

async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
Expand Down Expand Up @@ -983,6 +986,14 @@ async def get_internal_state(self) -> List[Dict[Any, Any]]:
# Many DP ranks
return [res.internal_state for res in responses]

async def get_load(self) -> int:
# TODO(lsyin): fake load report server
if not self.current_load_lock.locked():
async with self.current_load_lock:
internal_state = await self.get_internal_state()
self.current_load = internal_state[0]["load"]
return self.current_load

async def set_internal_state(
self, obj: SetInternalStateReq
) -> SetInternalStateReqOutput:
Expand Down
5 changes: 5 additions & 0 deletions register.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
curl -X POST http://localhost:8000/register -H "Content-Type: application/json" -d '{
"mode": "prefill",
"url": "http://localhost:30000",
"bootstrap_port": 8998
}'
34 changes: 34 additions & 0 deletions run_engines.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
MODEL=meta-llama/Llama-3.1-8B-Instruct
DEVICES="mlx5_1,mlx5_2,mlx5_3,mlx5_4"
PAGE_SIZE=32
ATTENTION=fa3
TRANSFER=mooncake
HOST=0.0.0.0
PDLB=http://localhost:8000

python -m sglang.launch_server \
--model-path $MODEL \
--disaggregation-mode prefill \
--host $HOST \
--port 30000 \
--page-size $PAGE_SIZE \
--disaggregation-ib-device $DEVICES \
--disable-radix-cache \
--disaggregation-transfer-backend $TRANSFER \
--attention-backend $ATTENTION \
--pdlb-url $PDLB \
--tp-size 2 &

python -m sglang.launch_server \
--model-path $MODEL \
--disaggregation-mode decode \
--host $HOST \
--port 40000 \
--page-size $PAGE_SIZE \
--base-gpu-id 2 \
--disaggregation-ib-device $DEVICES \
--disable-radix-cache \
--disaggregation-transfer-backend $TRANSFER \
--attention-backend $ATTENTION \
--pdlb-url $PDLB \
--tp-size 2
3 changes: 3 additions & 0 deletions run_lb.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# python3 -m sglang.srt.disaggregation.mini_lb --prefill http://localhost:30000 --decode http://localhost:40000 --host 0.0.0.0 --port 8000

python3 -m sgl_pdlb.launch_lb --prefill http://localhost:30000 --decode http://localhost:40000 --host 0.0.0.0 --port 8000
14 changes: 14 additions & 0 deletions run_normal.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
MODEL=/models/hub/Meta-Llama-3.1-8B-Instruct/
PAGE_SIZE=1
ATTENTION=fa3

python -m sglang.launch_server \
--model-path $MODEL \
--port 33000 \
--page-size $PAGE_SIZE \
--host localhost \
--attention-backend $ATTENTION \
--disable-radix-cache \
--max-total-tokens 80000 \
--base-gpu-id 4 \
--tp-size 2
39 changes: 39 additions & 0 deletions run_pd.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
MODEL=meta-llama/Llama-3.1-8B-Instruct
DEVICES="mlx5_101,mlx5_102,mlx5_103"
PAGE_SIZE=64
ATTENTION=fa3
TRANSFER=mooncake

# python3 -m sglang.srt.disaggregation.mini_lb --prefill http://localhost:30000 --decode http://localhost:40000 --host 0.0.0.0 --port 8000 &

python3 -m sgl_pdlb.launch_lb --prefill http://localhost:30000 --decode http://localhost:40000 --host 0.0.0.0 --port 8000 &

sleep 3

python -m sglang.launch_server \
--model-path $MODEL \
--disaggregation-mode prefill \
--port 30000 \
--page-size $PAGE_SIZE \
--host localhost \
--disaggregation-ib-device $DEVICES \
--disable-radix-cache \
--disaggregation-transfer-backend $TRANSFER \
--attention-backend $ATTENTION \
--tp-size 2 &

sleep 10

python -m sglang.launch_server \
--model-path $MODEL \
--disaggregation-mode decode \
--port 40000 \
--page-size $PAGE_SIZE \
--base-gpu-id 2 \
--disaggregation-ib-device $DEVICES \
--disable-radix-cache \
--disaggregation-transfer-backend $TRANSFER \
--attention-backend $ATTENTION \
--tp-size 2

sleep 10
9 changes: 9 additions & 0 deletions send_msg.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
curl -X POST http://127.0.0.1:8000/generate -H "Content-Type: application/json" -d '{
"text": "Where are you from?",
"sampling_params": {
"temperature": 0
},
"stream": false
}'

# curl http://127.0.0.1:8000/get_loads
2 changes: 2 additions & 0 deletions sgl-pdlb/.rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
reorder_imports = true
reorder_modules = true
Loading
Loading