diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 91ad8c04d4cb..4f510b61cd7b 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -892,18 +892,12 @@ def init_disaggregation(self): self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator( buffer_size ) + # Full-size buffer on both sides so the wire layout aligns + # under asymmetric P/D where one side may not run spec. self.disagg_metadata_buffers = MetadataBuffers( buffer_size, - hidden_size=( - model_config.spec_hidden_size - if self.spec_algorithm.is_eagle() - else 16 # minimal padding size for RDMA - ), - hidden_states_dtype=( - model_config.dtype - if self.spec_algorithm.is_eagle() - else torch.float32 - ), + hidden_size=model_config.spec_hidden_size, + hidden_states_dtype=model_config.dtype, custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), ) @@ -946,20 +940,12 @@ def init_disaggregation(self): self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator( buffer_size ) + # See decode branch above. Asymmetric P/D: prefill without a + # spec module ships zeros, decode mocks first-step conditioning. self.disagg_metadata_buffers = MetadataBuffers( buffer_size, - hidden_size=( - model_config.spec_hidden_size - if self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - else 16 # minimal padding size for RDMA - ), - hidden_states_dtype=( - model_config.dtype - if self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - else torch.float32 - ), + hidden_size=model_config.spec_hidden_size, + hidden_states_dtype=model_config.dtype, custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), ) diff --git a/test/registered/8-gpu-models/test_dsv4_pd_disagg_nixl.py b/test/registered/8-gpu-models/test_dsv4_pd_disagg_nixl.py index 15a4c4f4bd21..970868d6e35c 100644 --- a/test/registered/8-gpu-models/test_dsv4_pd_disagg_nixl.py +++ b/test/registered/8-gpu-models/test_dsv4_pd_disagg_nixl.py @@ -1,22 +1,7 @@ -"""DSv4 Flash PD-disaggregation test with NIXL transfer backend. - -Topology (1 H200 node, 8 GPUs total): - - Prefill: GPU 0-3, tp=4 — pure TP, **no EP** (no deepep), no DP - attention. Optimized for throughput on long prompts; each rank - holds the full MoE weights, no all-to-all dispatch traffic. - Spec config matches decode (PD ferry currently assumes symmetric - spec on both sides) so the prefill -> decode metadata buffer - is sized correctly for the spec module's hidden shape. - - Decode: GPU 4-7, tp=4 dp=4 enable-dp-attention + deepep + EAGLE - MTP — optimized for low-latency decode with spec decoding and - expert parallelism. - - Mini load balancer fronting both. - -Both sides use DSv4 Flash FP8 weights. Transfer backend is NIXL -(the focus of recent nixl/conn.py forward-delta work; this test is -the e2e check that the generic `send_state` / shared buffer-pool -changes do not break PD). -""" +"""DSv4 Flash PD-disagg with NIXL backend, asymmetric: prefill is pure +TP with no spec module, decode runs EAGLE MTP. Prefill ships a zero- +init hidden buffer; decode mocks first-step conditioning, verify keeps +the output correct.""" import unittest from types import SimpleNamespace @@ -38,9 +23,7 @@ DSV4_FLASH_ENV = { "SGLANG_DSV4_FP4_EXPERTS": "0", - # Decode side runs MTP with num_draft_tokens=4 → dispatch input scales - # by ~4x, so default 256 overflows once cuda-graph-max-bs * draft > 256. - # 1024 covers bs=128 * 4 with headroom (no-op on prefill which has no EP). + # MTP num_draft_tokens=4 scales dispatch by ~4x; 256 overflows at bs=128. "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024", "SGLANG_JIT_DEEPGEMM_PRECOMPILE": "0", } @@ -65,9 +48,6 @@ def setUpClass(cls): @classmethod def start_prefill(cls): - # Prefill: TP=4 (no EP, no DP attention). EAGLE config mirrors decode - # so the metadata buffer is sized for the spec module's hidden shape; - # PD ferry currently assumes both sides agree on the spec algorithm. prefill_args = [ "--trust-remote-code", "--disaggregation-mode", @@ -86,14 +66,6 @@ def start_prefill(cls): "4", "--disaggregation-decode-dp", "4", - "--speculative-algorithm", - "EAGLE", - "--speculative-num-steps", - "3", - "--speculative-eagle-topk", - "1", - "--speculative-num-draft-tokens", - "4", *cls.transfer_backend, *cls.rdma_devices, ] @@ -107,7 +79,6 @@ def start_prefill(cls): @classmethod def start_decode(cls): - # Decode: TP=4 + DP=4 attention + deepep EP + EAGLE MTP. decode_args = [ "--trust-remote-code", "--disaggregation-mode",