Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions docs/advanced_features/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -492,9 +492,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--disaggregation-mode` | Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated | `null` | `null`, `prefill`, `decode` |
| `--disaggregation-transfer-backend` | The backend for disaggregation transfer. Default is mooncake. | `mooncake` | `mooncake`, `nixl`, `ascend`, `fake` |
| `--disaggregation-bootstrap-port` | Bootstrap server port on the prefill server. Default is 8998. | `8998` | Type: int |
| `--disaggregation-decode-tp` | Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server. | `None` | Type: int |
| `--disaggregation-decode-dp` | Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server. | `None` | Type: int |
| `--disaggregation-prefill-pp` | Prefill pp size. If not set, it defaults to 1. This is only set on the decode server. | `1` | Type: int |
| `--disaggregation-ib-device` | The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). Default is None, which triggers automatic device detection when mooncake backend is enabled. | `None` | Type: str |
| `--disaggregation-decode-enable-offload-kvcache` | Enable async KV cache offloading on decode server (PD mode). | `False` | bool flag (set to enable) |
| `--num-reserved-decode-tokens` | Number of decode tokens that will have memory reserved when adding new request to the running batch. | `512` | Type: int |
Expand Down
3 changes: 0 additions & 3 deletions python/sglang/srt/disaggregation/base/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,10 @@ class KVArgs:
ib_device: str
ib_traffic_class: str
gpu_id: int
# for different tp
decode_tp_size: int
kv_head_num: int
total_kv_head_num: int
page_size: int
# for pp prefill
prefill_pp_size: int
pp_rank: int
prefill_start_layer: int
# for system dp
Expand Down
5 changes: 0 additions & 5 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,6 @@ def __init__(
gpu_id: int,
bootstrap_port: int,
max_total_num_tokens: int,
prefill_pp_size: int,
pp_rank: int,
num_reserved_decode_tokens: int,
transfer_backend: TransferBackend,
Expand All @@ -265,15 +264,13 @@ def __init__(
self.gpu_id = gpu_id
self.bootstrap_port = bootstrap_port
self.max_total_num_tokens = max_total_num_tokens
self.prefill_pp_size = prefill_pp_size
self.pp_rank = pp_rank
self.num_reserved_decode_tokens = num_reserved_decode_tokens
self.transfer_backend = transfer_backend
# Queue for requests pending pre-allocation
self.queue: List[DecodeRequest] = []
self.retracted_queue: List[Req] = []
self.pending_reqs: List[Req] = []
self.prefill_pp_size = prefill_pp_size
self.kv_manager = self._init_kv_manager()

if self.scheduler.tp_worker.is_hybrid_swa:
Expand All @@ -290,10 +287,8 @@ def _init_kv_manager(self) -> CommonKVManager:
attn_tp_size = get_attention_tp_size()
kv_args.engine_rank = self.tp_rank % (attn_tp_size)

kv_args.decode_tp_size = attn_tp_size
kv_args.pp_rank = self.pp_rank
kv_args.system_dp_rank = self.scheduler.dp_rank
kv_args.prefill_pp_size = self.prefill_pp_size
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
Expand Down
6 changes: 0 additions & 6 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ def __init__(
bootstrap_port: int,
gloo_group: ProcessGroup,
max_total_num_tokens: int,
decode_tp_size: int,
decode_dp_size: int,
scheduler: Scheduler,
pp_rank: int,
pp_size: int,
Expand All @@ -109,8 +107,6 @@ def __init__(
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
self.tp_rank = tp_rank
self.tp_size = tp_size
self.decode_tp_size = decode_tp_size
self.decode_dp_size = decode_dp_size
self.pp_rank = pp_rank
self.pp_size = pp_size
self.gpu_id = gpu_id
Expand All @@ -135,8 +131,6 @@ def _init_kv_manager(self) -> CommonKVManager:
kv_args.engine_rank = self.tp_rank
kv_args.pp_rank = self.pp_rank
kv_args.system_dp_rank = self.scheduler.dp_rank
kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
kv_args.prefill_pp_size = self.pp_size
kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
Expand Down
3 changes: 0 additions & 3 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,6 @@ def init_disaggregation(self):
gpu_id=self.gpu_id,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
max_total_num_tokens=self.max_total_num_tokens,
prefill_pp_size=self.server_args.disaggregation_prefill_pp,
pp_rank=self.pp_rank,
num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
transfer_backend=self.transfer_backend,
Expand Down Expand Up @@ -968,8 +967,6 @@ def init_disaggregation(self):
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
gloo_group=self.attn_tp_cpu_group,
max_total_num_tokens=self.max_total_num_tokens,
decode_tp_size=self.server_args.disaggregation_decode_tp,
decode_dp_size=self.server_args.disaggregation_decode_dp,
scheduler=self,
pp_rank=self.pp_rank,
pp_size=self.pp_size,
Expand Down
43 changes: 0 additions & 43 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,9 +662,6 @@ class ServerArgs:
disaggregation_mode: Literal["null", "prefill", "decode"] = "null"
disaggregation_transfer_backend: str = "mooncake"
disaggregation_bootstrap_port: int = 8998
disaggregation_decode_tp: Optional[int] = None
disaggregation_decode_dp: Optional[int] = None
disaggregation_prefill_pp: Optional[int] = 1
disaggregation_ib_device: Optional[str] = None
disaggregation_decode_enable_offload_kvcache: bool = False
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
Expand Down Expand Up @@ -2600,27 +2597,13 @@ def _handle_load_format(self):

def _handle_pd_disaggregation(self):
if self.disaggregation_mode == "decode":
assert (
self.disaggregation_decode_tp is None
), "Cannot set --disaggregation-decode-tp for the decode engine."
assert (
self.disaggregation_decode_dp is None
), "Cannot set --disaggregation-decode-dp for the decode engine."

self.disable_radix_cache = True
logger.warning("KV cache is forced as chunk cache for decode server")

elif self.disaggregation_mode == "prefill":
assert (
self.disaggregation_transfer_backend != "fake"
), "Prefill server does not support 'fake' as the transfer backend"
if self.disaggregation_decode_tp is None:
self.disaggregation_decode_tp = self.tp_size
if self.disaggregation_decode_dp is None:
self.disaggregation_decode_dp = self.dp_size

self.disaggregation_prefill_pp = self.pp_size
self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp)

if not self.enable_piecewise_cuda_graph:
self.disable_cuda_graph = True
Expand Down Expand Up @@ -4972,24 +4955,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=ServerArgs.disaggregation_bootstrap_port,
help="Bootstrap server port on the prefill server. Default is 8998.",
)
parser.add_argument(
"--disaggregation-decode-tp",
type=int,
default=ServerArgs.disaggregation_decode_tp,
help="Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server.",
)
parser.add_argument(
"--disaggregation-decode-dp",
type=int,
default=ServerArgs.disaggregation_decode_dp,
help="Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server.",
)
parser.add_argument(
"--disaggregation-prefill-pp",
type=int,
default=ServerArgs.disaggregation_prefill_pp,
help="Prefill pp size. If not set, it is default to 1. This is only set on the decode server.",
)
parser.add_argument(
"--disaggregation-ib-device",
type=str,
Expand Down Expand Up @@ -5545,14 +5510,6 @@ def check_lora_server_args(self):
and (self.max_lora_chunk_size & (self.max_lora_chunk_size - 1)) == 0
), "--max-lora-chunk-size must be a power of 2 between 16 and 128."

def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
larger_tp = max(decode_tp, prefill_tp)
smaller_tp = min(decode_tp, prefill_tp)
assert larger_tp % smaller_tp == 0, (
"Different tp size is supported only when one tp is multiple of the other. "
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
)

def validate_buckets_rule(self, arg_name: str, buckets_rule: List[str]):
if not buckets_rule:
return
Expand Down
Loading