41 changes: 26 additions & 15 deletions python/sglang/srt/server_args.py
@@ -73,7 +73,7 @@ class ServerArgs:
     chunked_prefill_size: Optional[int] = None
     max_prefill_tokens: int = 16384
     schedule_policy: str = "fcfs"
-    schedule_conservativeness: float = 1.0
+    schedule_conservativeness: Optional[float] = None
     cpu_offload_gb: int = 0
     page_size: int = 1

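Together with the `__post_init__` change below, this turns `schedule_conservativeness` into a None-as-sentinel field: the dataclass default stays `None` until `__post_init__` resolves it. A minimal standalone sketch of the pattern, using a hypothetical `Args` stand-in for `ServerArgs`:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Args:  # hypothetical stand-in for ServerArgs
    enable_dp_attention: bool = False
    # None means "not set by the user"; resolved in __post_init__.
    schedule_conservativeness: Optional[float] = None

    def __post_init__(self):
        if self.schedule_conservativeness is None:
            self.schedule_conservativeness = 0.3 if self.enable_dp_attention else 1.0

assert Args().schedule_conservativeness == 1.0
assert Args(enable_dp_attention=True).schedule_conservativeness == 0.3
assert Args(schedule_conservativeness=1.2).schedule_conservativeness == 1.2
```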
@@ -234,6 +234,11 @@ def __post_init__(self):
                 self.chunked_prefill_size = 2048
             else:
                 self.chunked_prefill_size = 8192
+        if self.enable_dp_attention:
+            self.chunked_prefill_size //= self.dp_size
+            logger.warning(
+                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
+            )
 
         assert self.chunked_prefill_size % self.page_size == 0

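Note on ordering: moving the division here places it before the page-size assert, so the assert now checks the per-DP-rank chunk size rather than the undivided one. A small sketch with illustrative values, assuming `dp_size` divides the default evenly:

```python
# Illustrative values only; the real ones come from ServerArgs defaults.
chunked_prefill_size = 8192  # default picked by the branch above
dp_size = 4
enable_dp_attention = True
page_size = 1                # ServerArgs default

if enable_dp_attention:
    chunked_prefill_size //= dp_size  # 8192 // 4 == 2048

# The divisibility check now sees the adjusted size.
assert chunked_prefill_size % page_size == 0
```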
@@ -280,17 +285,15 @@ def __post_init__(self):
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)

if self.schedule_conservativeness is None:
self.schedule_conservativeness = 0.3 if self.enable_dp_attention else 1.0

# Data parallelism attention
if self.enable_dp_attention:
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
assert (
self.dp_size > 1
), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
assert self.tp_size % self.dp_size == 0
self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
logger.warning(
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
)

self.enable_sp_layernorm = False
# DeepEP MoE
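Behavioral note, as I read the diff: previously a user-supplied `--schedule-conservativeness` was multiplied by 0.3 whenever DP attention was enabled; now 0.3 is only the default when the flag is left unset, and an explicit value is respected as-is. A sketch mirroring the new resolution logic with a hypothetical `resolve` helper:

```python
def resolve(user_value, enable_dp_attention):
    # Mirrors the new __post_init__ logic: 0.3 is a default for the
    # DP-attention case, not a multiplier on top of the user's setting.
    if user_value is None:
        return 0.3 if enable_dp_attention else 1.0
    return user_value

assert resolve(None, True) == 0.3
assert resolve(None, False) == 1.0
assert resolve(1.2, True) == 1.2  # the old code would have produced 0.36
```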
@@ -1208,29 +1211,37 @@ def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
elif server_args.dist_init_addr.startswith("["): # ipv6 address
port_num, host = configure_ipv6(server_args.dist_init_addr)
dist_init_addr = (host, str(port_num))
dist_init_addr = (host, port_num)
else:
dist_init_addr = server_args.dist_init_addr.split(":")
host, port_str = server_args.dist_init_addr.split(":")
dist_init_addr = (host, int(port_str))

assert (
len(dist_init_addr) == 2
), "please provide --dist-init-addr as host:port of head node"

dist_init_host, dist_init_port = dist_init_addr
port_base = int(dist_init_port) + 1
port_base = dist_init_port + 1
port_args = {}
for name in (
"tokenizer_ipc_name",
"detokenizer_ipc_name",
"rpc_ipc_name",
):
port_args[name] = (
f"tcp://{dist_init_host}:{port_base + len(port_args)}"
)
if dp_rank is None:
scheduler_input_port = (
port_base + 3
scheduler_input_port = port_base + len(
port_args
) # TokenizerManager to DataParallelController
else:
scheduler_input_port = port_base + 3 + 1 + dp_rank
scheduler_input_port = port_base + len(port_args) + 1 + dp_rank

return PortArgs(
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
**port_args,
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
nccl_port=port,
rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
)

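For reference, the loop reproduces the hard-coded offsets it replaces: the dict fills in declaration order, so `len(port_args)` acts as a running offset. A sketch with illustrative host/port values:

```python
# Illustrative values; in the real code they come from --dist-init-addr.
dist_init_host, dist_init_port = "10.0.0.1", 5000
port_base = dist_init_port + 1  # 5001

port_args = {}
for name in ("tokenizer_ipc_name", "detokenizer_ipc_name", "rpc_ipc_name"):
    # len(port_args) grows 0 -> 1 -> 2 as entries are inserted.
    port_args[name] = f"tcp://{dist_init_host}:{port_base + len(port_args)}"

assert port_args == {
    "tokenizer_ipc_name": "tcp://10.0.0.1:5001",
    "detokenizer_ipc_name": "tcp://10.0.0.1:5002",
    "rpc_ipc_name": "tcp://10.0.0.1:5003",
}
# The scheduler input port follows: port_base + 3 when dp_rank is None,
# else port_base + 3 + 1 + dp_rank, matching the old hard-coded offsets.
```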
