Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 8 additions & 22 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,18 +892,12 @@ def init_disaggregation(self):
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
# Allocate a full-size buffer on both sides so the wire layout stays aligned
# under asymmetric P/D, where one side may not run speculative decoding.
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=(
model_config.spec_hidden_size
if self.spec_algorithm.is_eagle()
else 16 # minimal padding size for RDMA
),
hidden_states_dtype=(
model_config.dtype
if self.spec_algorithm.is_eagle()
else torch.float32
),
hidden_size=model_config.spec_hidden_size,
hidden_states_dtype=model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)

Expand Down Expand Up @@ -946,20 +940,12 @@ def init_disaggregation(self):
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
# Same sizing as the decode branch above. Under asymmetric P/D, a prefill
# without a spec module ships zeros and decode mocks first-step conditioning.
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=(
model_config.spec_hidden_size
if self.spec_algorithm.is_eagle()
or self.spec_algorithm.is_standalone()
else 16 # minimal padding size for RDMA
),
hidden_states_dtype=(
model_config.dtype
if self.spec_algorithm.is_eagle()
or self.spec_algorithm.is_standalone()
else torch.float32
),
hidden_size=model_config.spec_hidden_size,
hidden_states_dtype=model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)

Expand Down
39 changes: 5 additions & 34 deletions test/registered/8-gpu-models/test_dsv4_pd_disagg_nixl.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,7 @@
"""DSv4 Flash PD-disaggregation test with NIXL transfer backend.

Topology (1 H200 node, 8 GPUs total):
- Prefill: GPU 0-3, tp=4 — pure TP, **no EP** (no deepep), no DP
attention. Optimized for throughput on long prompts; each rank
holds the full MoE weights, no all-to-all dispatch traffic.
Spec config matches decode (PD ferry currently assumes symmetric
spec on both sides) so the prefill -> decode metadata buffer
is sized correctly for the spec module's hidden shape.
- Decode: GPU 4-7, tp=4 dp=4 enable-dp-attention + deepep + EAGLE
MTP — optimized for low-latency decode with spec decoding and
expert parallelism.
- Mini load balancer fronting both.

Both sides use DSv4 Flash FP8 weights. Transfer backend is NIXL
(the focus of recent nixl/conn.py forward-delta work; this test is
the e2e check that the generic `send_state` / shared buffer-pool
changes do not break PD).
"""
"""DSv4 Flash PD-disagg with NIXL backend, asymmetric: prefill is pure
TP with no speculative-decoding module, while decode runs EAGLE MTP. Prefill
ships a zero-initialized hidden-state buffer; decode mocks first-step
conditioning, and the verify step keeps the output correct."""

import unittest
from types import SimpleNamespace
Expand All @@ -38,9 +23,7 @@

DSV4_FLASH_ENV = {
"SGLANG_DSV4_FP4_EXPERTS": "0",
# Decode side runs MTP with num_draft_tokens=4 → dispatch input scales
# by ~4x, so default 256 overflows once cuda-graph-max-bs * draft > 256.
# 1024 covers bs=128 * 4 with headroom (no-op on prefill which has no EP).
# MTP with num_draft_tokens=4 scales dispatch input by ~4x, so the default of
# 256 overflows at bs=128; 1024 covers bs=128 * 4 with headroom.
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024",
"SGLANG_JIT_DEEPGEMM_PRECOMPILE": "0",
}
Expand All @@ -65,9 +48,6 @@ def setUpClass(cls):

@classmethod
def start_prefill(cls):
# Prefill: TP=4 (no EP, no DP attention). EAGLE config mirrors decode
# so the metadata buffer is sized for the spec module's hidden shape;
# PD ferry currently assumes both sides agree on the spec algorithm.
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
Expand All @@ -86,14 +66,6 @@ def start_prefill(cls):
"4",
"--disaggregation-decode-dp",
"4",
"--speculative-algorithm",
"EAGLE",
"--speculative-num-steps",
"3",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"4",
*cls.transfer_backend,
*cls.rdma_devices,
]
Expand All @@ -107,7 +79,6 @@ def start_prefill(cls):

@classmethod
def start_decode(cls):
# Decode: TP=4 + DP=4 attention + deepep EP + EAGLE MTP.
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
Expand Down
Loading