[Bug] Add dsv4 state_type branch to mooncake disaggregation#24878
[Bug] Add dsv4 state_type branch to mooncake disaggregation#24878ch-wan merged 1 commit intosgl-project:mainfrom
Conversation
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
5ebc8be to
47da450
Compare
|
/tag-and-rerun-ci |
There was a problem hiding this comment.
nit: could elif state_type == "dsv4" branch join here as well?
There was a problem hiding this comment.
Done — folded into the existing branch (force-pushed). The new diff is +7/-1: just adds "dsv4" to the existing state_type in ["swa", "nsa"] list, with a one-line comment about how the compressed-MLA PP/MTP layout is already handled by get_mla_kv_ptrs_with_pp. Thanks!
PR sgl-project#23882 introduced ``state_type="dsv4"`` for the new DeepSeek-V4 flat heterogeneous state pool (SWA + compress + indexer pools) and added a matching branch to ``NixlKVManager.maybe_send_extra``, but the mooncake sibling was never updated. As a result, DSv4 disaggregated runs with the mooncake transfer backend silently fall through ``maybe_send_extra``'s final ``else: return 0`` branch -- the SWA / compress / indexer state is **never** transferred from prefill to decode, and the decode-side state pool keeps whatever it was initialized with, producing wrong outputs whenever the model attends to that state. Repro shape on GB300 disaggregated DSv4-Pro: - Same byte-identical ``prompt_token_ids`` ending on the literal ``<think>`` token (id ``128821``). - Monolithic sglang: correct (gsm8k Janet question -> ``#### 18``). - Disagg + **NIXL**: correct. - Disagg + **mooncake**: wrong (model regurgitates an earlier few-shot answer, e.g. ``Weapon: ... #### 84``). The corruption is most visible when the prompt's last attended position lives in the SWA / indexer state pool, which the missing branch silently leaves untransferred. Cases where attention happens to land entirely in the K/V pool transferred via ``send_kvcache`` don't surface the bug. ## Modifications Add the missing ``dsv4`` branch to ``MooncakeKVManager.maybe_send_extra`` that delegates to the existing ``_send_kvcache_generic`` -- the same helper the ``swa`` / ``nsa`` branches use. This routes DSv4's flat state pool through ``get_mla_kv_ptrs_with_pp``, which already understands the compressed-MLA PP/MTP layout (so the MTP/nextn tail entry on the decode side is naturally accommodated). The diff is purely additive and bails for asymmetric TP between prefill and decode (matching the constraint of the surrounding ``swa`` / ``nsa`` path). Existing ``mamba`` / ``swa`` / ``nsa`` / ``none`` paths are unchanged. ## Validation GB300 1P+1D DSv4-Pro disagg, 8k/1k sa-bench (10 prompts, conc=1) on mooncake with the patched ``conn.py`` hot-mounted into the latest nightly container. Engine bring-up clean, no transfer-worker errors, median TTFT ~950ms / TPOT ~11.5ms / output throughput ~80 tok/s. The pre-patch ``<think>``-ending repro that produced ``Weapon: ... #### 84`` on mooncake now matches the NIXL output. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
47da450 to
fd7e9ca
Compare
|
End-to-end MTP validation on the consolidated commit ( GB300 1P+1D DSv4-Pro disagg, mooncake backend, with EAGLE-3 MTP enabled ( Compared to the non-MTP run on the same setup (80 tok/s output, 11.5 ms TPOT), MTP is delivering its expected ~1.7× speedup, which is direct evidence the indexer/draft state pool is being transferred correctly — the failure mode this PR is fixing would either fire a So:
Both backends (NIXL pre-existing + this mooncake addition) now have working |
Picks up sgl-project/sglang#24878 (merged as c7f674e4), which adds the missing dsv4 state_type branch to MooncakeKVManager.maybe_send_extra. Combined with the prior revert of #1297's nixl switch (commit daa6785), the mooncake backend now correctly transfers DSv4's flat heterogeneous state pool for both non-MTP and MTP runs. Validated on GB300 1P+1D: comp_with_think.json (the prompt ending on the literal `<think>` token that previously surfaced the corruption) now returns the correct gsm8k Janet answer (`#### 18`) on mooncake disagg, matching mono and the NIXL control. MTP sa-bench delivers ~136 tok/s output throughput (~1.7x non-MTP), confirming draft acceptance is working. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Motivation
PR #23882 introduced
state_type="dsv4"for the new DeepSeek-V4 flat heterogeneous state pool (SWA + compress + indexer pools) and added a matching branch toNixlKVManager.maybe_send_extra. The mooncake sibling,MooncakeKVManager.maybe_send_extra, was never updated.DSv4 disaggregated runs over the mooncake transfer backend silently fall through
maybe_send_extra's finalelse: return 0branch -- the SWA / compress / indexer state pool is never transferred from prefill to decode. The decode-side state buffers keep whatever stale data they were initialized with, producing wrong outputs whenever the model attends to that state.Repro on GB300 disaggregated DSv4-Pro,
/v1/completionswith rawprompt_token_idsending on the literal<think>token (id128821):#### 18).Weapon: ... #### 84).The corruption is most visible when the prompt's last attended position lives in the SWA / indexer state pool. Other endings often look correct because the K/V cache itself does transfer via
send_kvcache; the<think>ending is the worst-case noise location for that token's attention pattern.Modifications
Add the missing
dsv4branch toMooncakeKVManager.maybe_send_extrathat delegates to the existing_send_kvcache_generic-- the same helper theswa/nsabranches use. This routes DSv4's flat state pool throughget_mla_kv_ptrs_with_pp, which iterates per-pool entry and uses prefill'sstate_item_lensfor offset arithmetic.Why delegation rather than a fresh per-page flat path:
src_state_item_lens[i] == dst_state_item_lens[i]for every entry. With MTP enabled, decode's indexer-pool entry is 2x prefill's (decode carries the EAGLE-3 draft layer), so the assertion fires on the very first transfer and the engine never makes progress._send_kvcache_genericmatches what the staging DSv4 branch did before this code path was split into a dedicateddsv4state_type. Prefill writes its half-size at the natural offset on decode; the decode-only MTP half is left untouched, which is correct -- decode populates its own draft state.The diff is purely additive and bails for asymmetric TP between prefill and decode (matching the constraint of the surrounding
swa/nsapath). Existingmamba/swa/nsa/nonepaths are unchanged.Validation
GB300 1P+1D DSv4-Pro disagg, mooncake backend, with the patched
conn.pyhot-mounted intolmsysorg/sglang:nightly-dev-cu13-20260509-9ee83034.<think>-ending repro returns wrong gsm8k answer (Weapon: ... #### 84)<think>-ending repro returns correct#### 18. sa-bench (10 prompts, conc=1, ISL=8192/OSL=1024) completed cleanly: TTFT ~950 ms, TPOT ~11.5 ms, output throughput ~80 tok/s.state_item_lensmismatch assertion on the first transfer; mooncake unpatched silently corruptsaccept rate 0.66-0.70and ~190 tok/s gen throughput. Zerostate_item_lensasserts, zero transfer-worker errors.Sample post-patch response for the original
<think>-ending repro on mooncake:Checklist
cc @hnyls2002 (PR #23882 author).