Fix DSv4 attention backend for EAGLE per-step draft #24750
Conversation
Force-pushed from 4518746 to a7d3504 (Compare)
The ``DeepseekV4MultiStepDraftBackend`` (added in sgl-project#23882) constructs one ``DeepseekV4AttnBackend`` per spec step and forwards the same ``ForwardBatch`` to each step's ``init_forward_metadata``. For ``forward_mode == decode``, that method passed ``forward_batch.out_cache_loc`` straight through to ``init_forward_metadata_decode``, which then asserted ``out_cache_loc.shape[0] == req_pool_indices.shape[0] == seq_lens.shape[0]``.

But in EAGLE draft, ``forward_batch.out_cache_loc`` has shape ``[bs * speculative_num_steps]`` (the full spec-decode cache-loc tensor — every step writes one new slot per request), while ``req_pool_indices`` and ``seq_lens`` are still at the unrepeated ``[bs]`` shape. The assertion fires whenever a draft batch reaches this path with ``bs >= 2``:

    AssertionError: req_pool_indices.shape=torch.Size([2]) seq_lens.shape=torch.Size([2]) out_cache_loc.shape=torch.Size([6])

In larger configurations (TP=8 / EP=8 / DP=8 multi-node decode) the shapes can coincidentally line up so the assertion does not fire, but the metadata is still mis-aligned: GSM8K accuracy collapses from ~0.93 to ~0.42 and decoded outputs are visibly malformed (stray ``Weapon:`` / ``Weaponry`` / ``Weaponized`` prefixes).

Mirror the FA3 backend's draft-decode handling in ``flashattention_backend.py`` (``cache_seqlens_int32 = seqlens_in_batch + (self.speculative_step_id + 1)``): when the backend was constructed as a per-step draft step (``self.speculative_num_steps > 0``), slice ``out_cache_loc`` to this step's portion and advance ``seq_lens`` / ``max_seq_len`` by ``step_id + 1`` before calling ``init_forward_metadata_decode``.

Closes sgl-project#24747

Signed-off-by: Cheng Wan <wan4ch@gmail.com>
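To make the shape mismatch concrete, a small self-contained illustration with the same numbers as the traceback (plain tensors standing in for the `ForwardBatch` fields, not the backend's actual code):

```python
import torch

bs = 2                      # draft-batch size from the traceback above
speculative_num_steps = 3   # EAGLE draft steps

# What the per-step draft path hands to init_forward_metadata_decode:
req_pool_indices = torch.arange(bs)                       # shape [2]
seq_lens = torch.tensor([128, 256])                       # shape [2]
out_cache_loc = torch.arange(bs * speculative_num_steps)  # shape [6]: slots for all steps

# The backend's precondition; with these shapes it raises exactly as in the
# AssertionError quoted above.
assert (
    out_cache_loc.shape[0] == req_pool_indices.shape[0] == seq_lens.shape[0]
), f"{req_pool_indices.shape=} {seq_lens.shape=} {out_cache_loc.shape=}"
```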
Force-pushed from a7d3504 to ac6b542 (Compare)
Heads-up: posted an update on #24747 — this PR holds for the monolithic case (verified gsm8k 0.975 / 200 questions on a TP=4 single-node setup with the failing 5-shot returning the correct answer end-to-end), but the disaggregated prefill+decode + EAGLE+DSv4 path is still broken with the same patch applied. Same image, same recipe topology, prompt_tokens identical to the working monolithic path (1128), MTP accept-rate healthy, but the model resumes from a previous few-shot assistant turn instead of generating from […]
    if bucket == _GraphBucket.DECODE_OR_IDLE:
        assert out_cache_loc is not None
        assert len(out_cache_loc.shape) == 1, f"{out_cache_loc.shape=}"
        if self.speculative_num_steps > 0:
If the bug only happens when CUDA graph is disabled, are these lines unnecessary?
    if forward_batch.forward_mode.is_decode_or_idle():
        out_cache_loc = forward_batch.out_cache_loc
        if self.speculative_num_steps > 0:
Maybe only do this when CUDA graph is disabled.
Summary
The `DeepseekV4MultiStepDraftBackend` (added in #23882) constructs one `DeepseekV4AttnBackend` per spec step and forwards the same `ForwardBatch` to each step's `init_forward_metadata`. For `forward_mode == decode`, that method passed `forward_batch.out_cache_loc` straight through to `init_forward_metadata_decode`, which then asserted `out_cache_loc.shape[0] == req_pool_indices.shape[0] == seq_lens.shape[0]`.

But in EAGLE draft, `forward_batch.out_cache_loc` has shape `[bs * speculative_num_steps]` (the full spec-decode cache-loc tensor — every step writes one new slot per request), while `req_pool_indices` and `seq_lens` are still at the unrepeated `[bs]` shape. The assertion fires whenever a draft batch reaches this path with `bs >= 2`:

    AssertionError: req_pool_indices.shape=torch.Size([2]) seq_lens.shape=torch.Size([2]) out_cache_loc.shape=torch.Size([6])

In larger configurations (TP=8 / EP=8 / DP=8 multi-node decode) the shapes can coincidentally line up so the assertion does not fire, but the metadata is still mis-aligned: GSM8K accuracy collapses from ~0.93 to ~0.42 and decoded outputs are visibly malformed (stray `Weapon:` / `Weaponry` / `Weaponized` prefixes).

Mirror the FA3 backend's draft-decode handling in `flashattention_backend.py` (`cache_seqlens_int32 = seqlens_in_batch + (self.speculative_step_id + 1)`): when the backend was constructed as a per-step draft step (`self.speculative_num_steps > 0`), slice `out_cache_loc` to this step's portion and advance `seq_lens` / `max_seq_len` by `step_id + 1` before calling `init_forward_metadata_decode`.

Closes #24747.
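A minimal sketch of that per-step adjustment, assuming a step-major layout for `out_cache_loc` (step `i` owning slots `[i * bs, (i + 1) * bs)`) and hypothetical helper/parameter names; the backend's actual slicing order may differ:

```python
import torch

def per_step_draft_decode_inputs(
    out_cache_loc: torch.Tensor,  # [bs * speculative_num_steps]: slots for every draft step
    seq_lens: torch.Tensor,       # [bs]: unrepeated per-request lengths
    max_seq_len: int,
    step_id: int,
    speculative_num_steps: int,
):
    """Slice the spec-decode cache locations down to one draft step and advance
    the sequence lengths by step_id + 1, mirroring FA3's draft-decode handling."""
    bs = seq_lens.shape[0]
    # Assumed step-major layout: step i owns slots [i * bs, (i + 1) * bs).
    step_out_cache_loc = out_cache_loc.view(speculative_num_steps, bs)[step_id]
    step_seq_lens = seq_lens + (step_id + 1)
    step_max_seq_len = max_seq_len + (step_id + 1)
    return step_out_cache_loc, step_seq_lens, step_max_seq_len

# With bs = 2 and 3 draft steps, each step now receives shape-[2] tensors, so the
# shape assertion in init_forward_metadata_decode holds again.
loc, lens, max_len = per_step_draft_decode_inputs(
    torch.arange(6), torch.tensor([128, 256]), 256, step_id=1, speculative_num_steps=3
)
assert loc.shape == lens.shape == (2,)
```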
Reproducer
`lmsysorg/sglang:nightly-dev-cu13-20260509-9ee83034` (built from main `9ee83034`): Send any 5-shot gsm8k chat-completion request; the server crashes with the assertion above. With this patch, the server completes normally and returns the correct answer (verified on the same prompt that GB300 disagg-decode CI was getting wrong).
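For reference, one way to send such a request; the `/v1/chat/completions` endpoint is sglang's OpenAI-compatible API, while the port, model field, and the abbreviated few-shot content are assumptions for illustration:

```python
import requests

# Assumes an sglang server launched from the image above and listening on :30000.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "default",  # placeholder; use the served model name
        "messages": [
            # 5-shot gsm8k: five worked Q/A pairs followed by the question
            # under test (heavily abbreviated here).
            {"role": "user", "content": "Q: Natalia sold clips to 48 of her friends ..."},
            {"role": "assistant", "content": "A: ... The answer is 72."},
            {"role": "user", "content": "Q: <question under test>"},
        ],
        "max_tokens": 512,
        "temperature": 0,
    },
    timeout=600,
)
print(resp.json()["choices"][0]["message"]["content"])
```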
Test plan
Related
Signed-off-by: Cheng Wan <wan4ch@gmail.com>