[Bugfix] Fix GDN conv + SSM state corruption with ngram spec decode #40738
tdoublep wants to merge 1 commit into
Code Review
This pull request updates the GDN and Mamba2 attention backends to handle SSM state indexing after speculative decoding by passing the full block table when no prefills are present. A review comment identifies a potential issue in mixed batches where decode requests might read stale SSM state data because the implementation defaults to the first block when prefills are present.
Force-pushed from bfccc78 to 1b2ba55.
When ngram speculative decoding accepts N tokens, the next non-spec decode step reads SSM and conv state from block 0 instead of block N-1, causing output corruption in GDN hybrid models (e.g. Qwen3.5).

Changes:
- Add spec_decode_src_indices to GDNAttentionMetadata, computed when accepted state lives in a different block than the decode target.
- Pre-copy SSM state from the accepted block to block 0 before kernel dispatch in both _forward_core and _forward_core_decode_non_spec.
- Pass num_accepted_tokens to causal_conv1d_fn/causal_conv1d_update so the kernel reads/writes conv state at the correct offset.
- Add IS_SPEC_DECODING constexpr and conv_state_token_offset logic to _causal_conv1d_fwd_kernel for offset-based conv state correction.
- Pad num_accepted_tokens with 1s for prefill sequences in mixed batches so the kernel sees offset=0 for those sequences.
- Gate num_accepted_tokens on spec_decode_src_indices to prevent stale values from leaking to non-correction steps.
- Pass num_accepted_tokens from gpu_model_runner to GDN/Mamba2 attention builders on non-spec steps.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
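The SSM pre-copy described in this commit message can be pictured with a small pure-Python sketch. It is not the code from the PR: the helper names build_copy_pairs and precopy_state are hypothetical, the state pool is modelled as a plain list indexed by physical block id, and it assumes that a request with N accepted tokens keeps its valid state in column N-1 of its block-table row while the decode kernel reads column 0.

```python
from typing import List, Tuple


def build_copy_pairs(
    block_table: List[List[int]], num_accepted: List[int]
) -> List[Tuple[int, int]]:
    """Return (src_block, dst_block) pairs for requests needing correction.

    Hypothetical helper: assumes the accepted state sits in column N-1
    of the request's block-table row and the decode target is column 0.
    """
    pairs = []
    for row, n in zip(block_table, num_accepted):
        src, dst = row[n - 1], row[0]
        if src != dst:  # N == 1 means the state is already in block 0
            pairs.append((src, dst))
    return pairs


def precopy_state(pool: List[List[float]], pairs: List[Tuple[int, int]]) -> None:
    """Copy state from the accepted block into the block the kernel reads."""
    for src, dst in pairs:
        pool[dst] = list(pool[src])


# Two requests: the first accepted 3 tokens, the second accepted 1.
block_table = [[4, 5, 6], [7, 8, 9]]
num_accepted = [3, 1]
pool = [[float(i)] for i in range(10)]  # toy state: block i holds [i]

pairs = build_copy_pairs(block_table, num_accepted)
precopy_state(pool, pairs)
```

Only the first request produces a copy pair here, which mirrors the gating in the commit message: requests whose accepted state already sits in block 0 are left untouched.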
Force-pushed from 1b2ba55 to 6b5632c.
@tdoublep, first, big thanks for this PR. It was the single highest-impact fix in our v7.13 investigation: the +30-40% clean-rate jump (from ~20% baseline to 53-70%) we got from backporting Phase 1 + Phase 2 was the data point that confirmed the GDN+ngram interaction layer was where our bug actually lived. Without your work + @bhaktatejas922's original report at #39273, we'd still be searching the wrong layers. Sharing our independent confirmation + one observation below.

Independent confirmation on a non-trivial production setup

Backported both Phase 1 (Python-only: gdn_attn.py, gdn_linear_attn.py, gpu_model_runner.py) and Phase 2 (Triton kernel + wrapper changes in causal_conv1d.py + gdn_linear_attn call sites) onto vLLM

Empirical clean-rate progression (n=30 reproducer per row)

The Phase 1 → Phase 1+Phase 2 delta confirms your "kernel fix is necessary" remark in the PR description: Phase 1 alone gets us partway, Phase 2 closes the rest of the GDN-state-corruption slice.

One observation about the existing
Batch size 1 seems to work well. Anything above batch size 1 seems to be broken and produces weird results.
PR vllm-project#40738 (commit 6b5632c) only fixes the all-non-spec branch in GDNAttentionMetadataBuilder.build(), i.e. the case where every request in the batch is doing regular decode this step. The else-branch (mixed batch: some reqs have drafts to verify, others don't) has the same state-corruption pattern unfixed:
- non_spec_state_indices_tensor reads block_table[~mask, 0] (column 0, stale if those reqs had spec on the previous step)
- num_accepted_tokens is filtered to spec-only, dropping the non-spec partition's counts entirely
- spec_decode_src_indices is never built → layer skips the SSM gather-scatter and conv kernel offset for the non-speccing partition

This is structurally invisible at BS=1 (no other req → every step is pure-spec or pure-non-spec, the if-branch always wins). At BS>1 with heterogeneous prompts (apply workloads: different code, different ngram histories per request) it triggers most steps where some req's drafter returns empty while others return drafts. Per-fire damage scales with num_speculative_tokens: at k=64 the non-speccing req in a mixed batch reads recurrent state up to ~64 tokens stale.

Empirical signature (Qwen3.5 fast-apply, 32-layer hybrid, FP8, k=64):

dataset 0926-0930__easy_random_prod (329 ex):
  BS=1 acc_sorted: 0.9605
  pre-fix BS=8: 0.5714
  post-fix BS=8: 0.9574 (within 0.003 of BS=1)

dataset 0923_0925_random_easy_regression_check (79 ex):
  BS=1 acc_sorted: 0.9747
  pre-fix BS=8: 0.6962
  post-fix BS=8: 0.9747 (parity)

Fix mirrors the if-branch's logic in the inner-else (the actual mixed-batch case): build spec_decode_src_indices from the non-speccing partition's num_accepted_tokens BEFORE the spec-only filter is applied, and add a non_spec_num_accepted field on GDNAttentionMetadata so the layer's prefill conv1d call uses the right per-row offsets.

A repro that triggers this needs heterogeneous prompts/lengths so the batch goes mixed (some reqs with drafts, some without). Identical prompts in parallel keep the batch homogeneous and miss it, which is likely why the original PR's BS>1 tests passed.
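The ordering issue described above (capture the non-speccing partition's counts before filtering to spec-only) can be sketched in plain Python. This is an illustration only, not the patch: split_mixed_batch and the has_drafts list are hypothetical, and block-table column N-1 is assumed to hold the accepted state, with column 0 as the decode target.

```python
from typing import Dict, List


def split_mixed_batch(
    block_table: List[List[int]],
    num_accepted: List[int],
    has_drafts: List[bool],
) -> Dict[str, List[int]]:
    """Partition a mixed batch without losing the non-spec accepted counts.

    Hypothetical helper: has_drafts[i] is True when request i has draft
    tokens to verify in this step.
    """
    non_spec_rows = [i for i, d in enumerate(has_drafts) if not d]

    # Step 1: read the non-spec partition's counts BEFORE any filtering.
    non_spec_num_accepted = [num_accepted[i] for i in non_spec_rows]
    src = [block_table[i][n - 1]
           for i, n in zip(non_spec_rows, non_spec_num_accepted)]
    dst = [block_table[i][0] for i in non_spec_rows]

    # Step 2: only now narrow num_accepted to the spec partition.
    spec_num_accepted = [n for n, d in zip(num_accepted, has_drafts) if d]

    return {
        "non_spec_num_accepted": non_spec_num_accepted,
        "src": src,
        "dst": dst,
        "spec_num_accepted": spec_num_accepted,
    }


# Request 0 has drafts this step; requests 1 and 2 do not.
# Request 1 accepted 3 tokens on its previous (spec) step.
result = split_mixed_batch(
    block_table=[[10, 11, 12], [20, 21, 22], [30, 31, 32]],
    num_accepted=[2, 3, 1],
    has_drafts=[True, False, False],
)
```

In this toy batch, request 1 would read stale state from block 20 unless block 22 is copied over first; request 2 has matching source and destination and needs no correction.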
This pull request has merge conflicts that must be resolved before it can be merged.
…lm-project#40880)

Backport two upstream fixes for degenerate output when combining speculative decoding with TurboQuant KV cache quantization:

P65 (vllm-project#40880): Downgrade TurboQuantMetadataBuilder._cudagraph_support from UNIFORM_BATCH to UNIFORM_SINGLE_TOKEN_DECODE. Prevents CUDA graph capture of spec-decode K+1 verify batches which baked stale cu_seqlens_k into the continuation-prefill path, causing the kernel to attend only to query tokens and ignore all prior cached KV.

vllm-project#40738: Fix GDN conv + SSM state corruption with spec decode. On non-spec steps after token acceptance, pre-copy SSM state from the accepted block and offset conv state reads by num_accepted_tokens-1.

Co-authored-by: Wibey CLI <genai-coding-assistants@walmart.com>
…vllm-project#40880)

Backport two upstream fixes for degenerate output when combining speculative decoding with TurboQuant KV cache quantization and Qwen3.5 Gated Delta Net (GDN) layers.

Downgrade TurboQuantMetadataBuilder._cudagraph_support from UNIFORM_BATCH to UNIFORM_SINGLE_TOKEN_DECODE. With UNIFORM_BATCH, the captured graph bakes cu_seqlens_k = cu_seqlens_q into the continuation-prefill path, so the kernel attends only to the current K+1 query tokens and ignores all prior cached KV, producing degenerate output. UNIFORM_SINGLE_TOKEN_DECODE forces spec-verify batches to eager/PIECEWISE where TQ's continuation branch correctly decompresses the full cache.

On non-spec steps after token acceptance, pre-copy SSM state from the accepted block to block 0 so the next decode reads correct state. Offset conv state reads by num_accepted_tokens-1 via a new parameter threaded through causal_conv1d_fn and the Triton conv kernel.
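To make the num_accepted_tokens-1 offset concrete, here is a toy pure-Python model of a causal conv state read. It is a sketch under assumptions, not the Triton kernel: the function names are hypothetical, and the conv state is assumed to be a flat buffer with width-1 history slots plus extra slots written during the spec step.

```python
from typing import List


def read_conv_window(
    conv_state: List[float], width: int, num_accepted: int
) -> List[float]:
    """Return the width-1 history values the next decode step should use.

    Assumed layout: when N tokens were accepted, the valid history
    starts N-1 slots into the buffer.
    """
    offset = num_accepted - 1
    return conv_state[offset:offset + width - 1]


def causal_conv_step(
    window: List[float], x_new: float, weights: List[float]
) -> float:
    """One output of a causal 1-D conv over the history plus the new input."""
    taps = window + [x_new]
    return sum(w * t for w, t in zip(weights, taps))


width = 4
conv_state = [1.0, 2.0, 3.0, 4.0, 5.0]  # width-1 history + 2 spec slots
weights = [1.0, 1.0, 1.0, 1.0]

stale = read_conv_window(conv_state, width, num_accepted=1)    # offset 0
correct = read_conv_window(conv_state, width, num_accepted=3)  # offset 2

y_stale = causal_conv_step(stale, 6.0, weights)
y_correct = causal_conv_step(correct, 6.0, weights)
```

With an offset of 0 the step convolves over history that predates the accepted tokens; with the N-1 offset it sees the most recent accepted inputs, which is the behaviour the commit message describes.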
Summary
Fix output corruption when using ngram speculative decoding with hybrid GDN models (e.g., Qwen3.5) in mamba_cache_mode="none".

After a spec decode step accepts N tokens, the next non-spec decode step must read SSM state from block N-1 and conv state from an offset position. Two bugs prevented this:
- num_accepted_tokens was not passed to SSM metadata builders on non-spec steps
- causal_conv1d_fn had no mechanism to offset conv state reads based on accepted tokens

Changes
- gdn_attn.py: Compute spec_decode_src_indices for SSM state correction; pad num_accepted_tokens with 1s for prefill sequences in mixed batches
- gdn_linear_attn.py: Pre-copy SSM state from accepted block to block 0; pass num_accepted_tokens to conv kernels gated on whether correction is needed
- causal_conv1d.py: Add IS_SPEC_DECODING path to _causal_conv1d_fwd_kernel that offsets conv state reads/writes by num_accepted_tokens - 1
- gpu_model_runner.py: Pass num_accepted_tokens to GDN/Mamba2 builders on non-spec steps

Test plan
Reproducers
Single-prompt:
Mixed-batch (short + long prompt, max_num_batched_tokens=64): see reproduce_gdn_ngram_mixed.py in the branch.

Fixes #39273
AI-assisted: Yes (Claude). Not duplicating any existing PR.