Skip to content
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
e92d0cc
rust lb init
hnyls2002 May 19, 2025
6b65b4c
add timeout configurable
hnyls2002 May 25, 2025
6a2b515
udpate lb args
hnyls2002 May 25, 2025
64c4aef
optimize error handling
hnyls2002 May 25, 2025
25e0eca
Merge branch 'main' into lsyin-rust-lb
hnyls2002 May 25, 2025
d3b3da4
fake load report server
hnyls2002 May 25, 2025
d76b846
move code
hnyls2002 May 25, 2025
c0ea969
merge code
hnyls2002 May 25, 2025
a81dc9e
add post init check
hnyls2002 May 25, 2025
a7d0703
tmp fix for get_model_info
hnyls2002 May 26, 2025
d16248f
support batch in /generate
hnyls2002 May 26, 2025
be7cb59
fix
hnyls2002 May 26, 2025
417f35b
fix request type check
hnyls2002 May 26, 2025
94fedb3
fix get load
hnyls2002 May 26, 2025
807892d
split func
hnyls2002 May 26, 2025
b4331e7
fix launch_lb type hint
hnyls2002 May 26, 2025
41f4c56
update toml
hnyls2002 May 26, 2025
dd479ea
use anyhow server to handle error
hnyls2002 May 26, 2025
c2f8bb4
merge route and generate func
hnyls2002 May 26, 2025
f736a59
route one for get model info
hnyls2002 May 26, 2025
2b7bdf3
add proxy response struct
hnyls2002 May 26, 2025
95d23e1
fix token_ids is a list of list
hnyls2002 May 27, 2025
2975ce2
dispatch static req struct
hnyls2002 May 28, 2025
c726cdc
optimize error and collect
hnyls2002 May 28, 2025
0161989
remove duplicate clone() on server end
hnyls2002 May 28, 2025
b647b2a
add fixme
hnyls2002 May 28, 2025
74129a6
remove scripts
hnyls2002 May 28, 2025
55d2831
Merge branch 'main' into lsyin-rust-lb
hnyls2002 May 28, 2025
b86c21c
Merge branch 'main' into lsyin-rust-lb
zhyncs May 28, 2025
0c6440f
fix typo
hnyls2002 May 29, 2025
72310fd
add missing service
hnyls2002 May 29, 2025
a4ec1e1
use generics
hnyls2002 May 29, 2025
9c2c1fb
simplify code
hnyls2002 May 29, 2025
09a2ded
rename: ProxyResponseType -> ProxyResponseBody
hnyls2002 May 29, 2025
fba68af
optimize build method
hnyls2002 May 29, 2025
ca49a1a
use reqwest::Method
hnyls2002 May 29, 2025
2ec0af8
nits optimize
hnyls2002 May 29, 2025
6dbbb1a
move lb_state out of server.rs
hnyls2002 May 29, 2025
b501141
Merge branch 'main' into lsyin-rust-lb
hnyls2002 May 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions python/sglang/srt/disaggregation/launch_lb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import argparse
import dataclasses


@dataclasses.dataclass
class LBArgs:
    """Command-line configuration for the PD disaggregation load balancer.

    Instances are normally built with :meth:`from_cli_args` from a namespace
    produced by a parser prepared with :meth:`add_cli_args`.

    Raises:
        ValueError: on construction, if ``rust_lb`` is false and ``policy``
            is anything other than ``"random"``.
    """

    # Select the Rust load balancer instead of the Python one.
    rust_lb: bool = False
    host: str = "0.0.0.0"
    port: int = 8000
    policy: str = "random"
    # ``(url, bootstrap_port)`` pairs when built by ``from_cli_args``; the
    # port is ``None`` when no bootstrap ports were given on the CLI.
    prefill_infos: list = dataclasses.field(default_factory=list)
    # Decode server URLs, as given to ``--decode``.
    decode_infos: list = dataclasses.field(default_factory=list)
    log_interval: int = 5
    timeout: int = 600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every load-balancer option on ``parser``."""
        parser.add_argument(
            "--rust-lb",
            action="store_true",
            help="Use Rust load balancer",
        )
        parser.add_argument(
            "--host",
            type=str,
            default=LBArgs.host,
            help=f"Host to bind the server (default: {LBArgs.host})",
        )
        parser.add_argument(
            "--port",
            type=int,
            default=LBArgs.port,
            help=f"Port to bind the server (default: {LBArgs.port})",
        )
        parser.add_argument(
            "--policy",
            type=str,
            default=LBArgs.policy,
            choices=["random", "po2"],
            help=f"Policy to use for load balancing (default: {LBArgs.policy})",
        )
        parser.add_argument(
            "--prefill",
            type=str,
            default=[],
            nargs="+",
            help="URLs for prefill servers",
        )
        parser.add_argument(
            "--decode",
            type=str,
            default=[],
            nargs="+",
            help="URLs for decode servers",
        )
        parser.add_argument(
            "--prefill-bootstrap-ports",
            type=int,
            nargs="+",
            help="Bootstrap ports for prefill servers",
        )
        parser.add_argument(
            "--log-interval",
            type=int,
            default=LBArgs.log_interval,
            help=f"Log interval in seconds (default: {LBArgs.log_interval})",
        )
        parser.add_argument(
            "--timeout",
            type=int,
            default=LBArgs.timeout,
            help=f"Timeout in seconds (default: {LBArgs.timeout})",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
        """Build an ``LBArgs`` from a namespace parsed by ``add_cli_args``.

        Bootstrap ports are paired with prefill URLs as follows: no ports
        given pairs every URL with ``None``; a single port is shared by all
        URLs; otherwise one port per URL is required.

        Raises:
            ValueError: if several bootstrap ports are given and their count
                differs from the number of prefill URLs, or if the resulting
                configuration is rejected by ``__post_init__``.
        """
        bootstrap_ports = args.prefill_bootstrap_ports
        if bootstrap_ports is None:
            bootstrap_ports = [None] * len(args.prefill)
        elif len(bootstrap_ports) == 1:
            # A single port is broadcast to every prefill server.
            bootstrap_ports = bootstrap_ports * len(args.prefill)
        elif len(bootstrap_ports) != len(args.prefill):
            raise ValueError(
                "Number of prefill URLs must match number of bootstrap ports"
            )

        # Lengths are equal at this point, so zip() drops nothing.
        prefill_infos = list(zip(args.prefill, bootstrap_ports))

        return cls(
            rust_lb=args.rust_lb,
            host=args.host,
            port=args.port,
            policy=args.policy,
            prefill_infos=prefill_infos,
            decode_infos=args.decode,
            log_interval=args.log_interval,
            timeout=args.timeout,
        )

    def __post_init__(self):
        """Reject policies the Python load balancer does not support."""
        # Raise explicitly instead of using ``assert``: assertions are
        # removed under ``python -O`` and the check would silently vanish.
        if not self.rust_lb and self.policy != "random":
            raise ValueError(
                "Only random policy is supported for Python load balancer"
            )


def main():
    """Parse the CLI and start the selected load balancer.

    With ``--rust-lb`` the Rust ``LoadBalancer`` from ``sgl_pdlb`` is built
    and started; otherwise the Python ``mini_lb.run`` entry point is used.
    Both backends are imported lazily so only the chosen one is required.
    """
    cli = argparse.ArgumentParser(
        description="PD Disaggregation Load Balancer Server"
    )
    LBArgs.add_cli_args(cli)
    lb_args = LBArgs.from_cli_args(cli.parse_args())

    if not lb_args.rust_lb:
        from sglang.srt.disaggregation.mini_lb import PrefillConfig, run

        # The Python balancer only receives the prefill/decode endpoints and
        # the bind address.
        prefill_configs = [
            PrefillConfig(prefill_url, bootstrap_port)
            for prefill_url, bootstrap_port in lb_args.prefill_infos
        ]
        run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
        return

    from sgl_pdlb._rust import LoadBalancer as RustLB

    balancer = RustLB(
        host=lb_args.host,
        port=lb_args.port,
        policy=lb_args.policy,
        prefill_infos=lb_args.prefill_infos,
        decode_infos=lb_args.decode_infos,
        log_interval=lb_args.log_interval,
        timeout=lb_args.timeout,
    )
    balancer.start()


if __name__ == "__main__":
    main()
41 changes: 3 additions & 38 deletions python/sglang/srt/disaggregation/mini_lb.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,42 +368,7 @@ def run(prefill_configs, decode_addrs, host, port):


if __name__ == "__main__":
import argparse
# FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
from sglang.srt.disaggregation.launch_lb import main

parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
parser.add_argument(
"--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
)
parser.add_argument(
"--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
)
parser.add_argument(
"--prefill-bootstrap-ports",
type=int,
nargs="+",
help="Bootstrap ports for prefill servers",
)
parser.add_argument(
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
)
args = parser.parse_args()

bootstrap_ports = args.prefill_bootstrap_ports
if bootstrap_ports is None:
bootstrap_ports = [None] * len(args.prefill)
elif len(bootstrap_ports) == 1:
bootstrap_ports = bootstrap_ports * len(args.prefill)
else:
if len(bootstrap_ports) != len(args.prefill):
raise ValueError(
"Number of prefill URLs must match number of bootstrap ports"
)

prefill_configs = [
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
]

run(prefill_configs, args.decode, args.host, args.port)
main()
5 changes: 5 additions & 0 deletions python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ async def get_server_info():
}


@app.get("/get_load")
async def get_load():
    """Return the load reported by the tokenizer manager's ``get_load()``."""
    tokenizer_manager = _global_state.tokenizer_manager
    return await tokenizer_manager.get_load()


@app.api_route("/set_internal_state", methods=["POST", "PUT"])
async def set_internal_state(obj: SetInternalStateReq, request: Request):
res = await _global_state.tokenizer_manager.set_internal_state(obj)
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class GenerateReqInput:

# For disaggregated inference
bootstrap_host: Optional[Union[List[str], str]] = None
bootstrap_port: Optional[Union[List[int], int]] = None
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
bootstrap_room: Optional[Union[List[int], int]] = None

def contains_mm_input(self) -> bool:
Expand Down
28 changes: 25 additions & 3 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1911,6 +1911,27 @@ def flush_cache(self):
if_success = False
return if_success

def get_load(self):
    """Return a token-count estimate of this scheduler's current load.

    The estimate starts from ``max_total_num_tokens`` minus the allocator's
    available size and the tree cache's evictable size, then adds the input
    length of every request in ``waiting_queue``. In PREFILL mode the
    requests in ``disagg_prefill_bootstrap_queue`` are added as well; in
    DECODE mode those in ``disagg_decode_prealloc_queue`` are.
    """
    # TODO(lsyin): use dynamically maintained num_waiting_tokens
    capacity = self.max_total_num_tokens
    free_tokens = self.token_to_kv_pool_allocator.available_size()
    evictable_tokens = self.tree_cache.evictable_size()
    total = capacity - free_tokens - evictable_tokens

    waiting_tokens = 0
    for pending in self.waiting_queue:
        waiting_tokens += len(pending.origin_input_ids)
    total += waiting_tokens

    if self.disaggregation_mode == DisaggregationMode.PREFILL:
        bootstrap_tokens = 0
        for pending in self.disagg_prefill_bootstrap_queue.queue:
            bootstrap_tokens += len(pending.origin_input_ids)
        total += bootstrap_tokens
    elif self.disaggregation_mode == DisaggregationMode.DECODE:
        prealloc_tokens = 0
        # Entries here wrap the request: the input ids live on ``.req``.
        for entry in self.disagg_decode_prealloc_queue.queue:
            prealloc_tokens += len(entry.req.origin_input_ids)
        total += prealloc_tokens

    return total

def get_internal_state(self, recv_req: GetInternalStateReq):
ret = dict(global_server_args_dict)
ret["last_gen_throughput"] = self.last_gen_throughput
Expand All @@ -1920,9 +1941,10 @@ def get_internal_state(self, recv_req: GetInternalStateReq):
)
if RECORD_STEP_TIME:
ret["step_time_dict"] = self.step_time_dict
return GetInternalStateReqOutput(
internal_state=ret,
)

ret["load"] = self.get_load()

return GetInternalStateReqOutput(internal_state=ret)

def set_internal_state(self, recv_req: SetInternalStateReq):
server_args_dict = recv_req.server_args
Expand Down
11 changes: 11 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,9 @@ def __init__(
self.server_args.disaggregation_bootstrap_port
)

self.current_load = 0
self.current_load_lock = asyncio.Lock()

async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
Expand Down Expand Up @@ -983,6 +986,14 @@ async def get_internal_state(self) -> List[Dict[Any, Any]]:
# Many DP ranks
return [res.internal_state for res in responses]

async def get_load(self) -> dict:
    """Return ``{"load": <value>}`` using the cached or a refreshed load.

    If ``current_load_lock`` is already held, the cached ``current_load`` is
    returned without waiting. Otherwise the lock is taken,
    ``get_internal_state()`` is awaited, and the ``"load"`` entry of its
    first element replaces the cached value before it is returned.
    """
    # TODO(lsyin): fake load report server
    refresh_lock = self.current_load_lock
    if refresh_lock.locked():
        # A refresh is already in flight; answer with the last known value.
        return {"load": self.current_load}

    async with refresh_lock:
        states = await self.get_internal_state()
        self.current_load = states[0]["load"]
    return {"load": self.current_load}

async def set_internal_state(
self, obj: SetInternalStateReq
) -> SetInternalStateReqOutput:
Expand Down
2 changes: 2 additions & 0 deletions sgl-pdlb/.rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
reorder_imports = true
reorder_modules = true
28 changes: 28 additions & 0 deletions sgl-pdlb/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
[package]
edition = "2024"
name = "sgl-pdlb"
version = "0.1.0"

[lib]
crate-type = ["cdylib", "rlib"]
name = "sgl_pdlb_rs"

[dependencies]
actix-web = "4.11"
bytes = "1.8.0"
chrono = "0.4.38"
clap = { version = "4.4", features = ["derive"] }
dashmap = "6.1.0"
env_logger = "0.11.5"
futures = "0.3"
futures-util = "0.3"
http = "1.3.1"
log = "0.4.22"
pyo3 = { version = "0.25.0", features = ["extension-module"] }
rand = "0.9.0"
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.34", features = ["full"] }
anyhow = "1.0.98"
typetag = "0.2.20"
12 changes: 12 additions & 0 deletions sgl-pdlb/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
### Install dependencies

```bash
pip install "maturin[patchelf]"
```

### Build and install

```bash
maturin develop
pip install -e .
```
1 change: 1 addition & 0 deletions sgl-pdlb/py_src/sgl_pdlb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.0.1"
14 changes: 14 additions & 0 deletions sgl-pdlb/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[build-system]
requires = ["maturin>=1.8.0"]
build-backend = "maturin"

[project]
name = "sgl_pdlb"
version = "0.0.1"

[tool.maturin]
python-source = "py_src"
module-name = "sgl_pdlb._rust"

[tool.maturin.build-backend]
features = ["pyo3/extension-module"]
Loading
Loading