[Bugfix] Fix GDN conv + SSM state corruption with ngram spec decode#40738

Open
tdoublep wants to merge 1 commit into vllm-project:main from tdoublep:fix/gdn-ngram-spec-decode-state

Conversation

@tdoublep
Member

@tdoublep tdoublep commented Apr 23, 2026

Summary

Fix output corruption when using ngram speculative decoding with hybrid GDN models (e.g., Qwen3.5) in mamba_cache_mode="none".

After a spec decode step accepts N tokens, the next non-spec decode step must read SSM state from block N-1 and conv state from an offset position. Two bugs prevented this:

  1. num_accepted_tokens was not passed to SSM metadata builders on non-spec steps
  2. causal_conv1d_fn had no mechanism to offset conv state reads based on accepted tokens
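As a rough illustration of the two reads described above, the sketch below computes, for a single request, which block holds the SSM state and which conv-state offset applies after N accepted tokens. The helper name `state_read_position` and the list-based block table are hypothetical and only stand in for the real metadata; this is not vLLM code.

```python
def state_read_position(block_row, num_accepted):
    """Return (ssm_block_id, conv_offset) for the next non-spec decode step.

    block_row: per-request list of state block ids (hypothetical layout,
               one block per speculative position).
    num_accepted: tokens accepted in the previous spec decode step (>= 1).
    """
    if num_accepted < 1:
        raise ValueError("num_accepted must be at least 1")
    ssm_block_id = block_row[num_accepted - 1]  # block N-1 holds the accepted state
    conv_offset = num_accepted - 1              # shift applied to conv state reads
    return ssm_block_id, conv_offset


# A request whose state blocks are 7, 12, 3, 9 and which accepted 3 tokens.
print(state_read_position([7, 12, 3, 9], 3))  # (3, 2)
```

With one accepted token the helper returns block 0 and offset 0, which is the only case the pre-fix code path handled correctly.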

Changes

  • gdn_attn.py: Compute spec_decode_src_indices for SSM state correction; pad num_accepted_tokens with 1s for prefill sequences in mixed batches
  • gdn_linear_attn.py: Pre-copy SSM state from accepted block to block 0; pass num_accepted_tokens to conv kernels gated on whether correction is needed
  • causal_conv1d.py: Add IS_SPEC_DECODING path to _causal_conv1d_fwd_kernel that offsets conv state reads/writes by num_accepted_tokens - 1
  • gpu_model_runner.py: Pass num_accepted_tokens to GDN/Mamba2 builders on non-spec steps
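The padding mentioned for gdn_attn.py can be pictured with a small sketch. It assumes decode rows precede prefill rows in the batch and that the conv kernel derives its offset as `num_accepted_tokens - 1`; the function name `pad_num_accepted` is illustrative and does not appear in the PR.

```python
def pad_num_accepted(decode_num_accepted, num_prefills):
    """Append 1 for each prefill sequence so its conv offset (n - 1) is 0."""
    return list(decode_num_accepted) + [1] * num_prefills


# Two decode rows (3 and 1 accepted tokens) followed by two prefill rows.
padded = pad_num_accepted([3, 1], num_prefills=2)
offsets = [n - 1 for n in padded]
print(padded)   # [3, 1, 1, 1]
print(offsets)  # [2, 0, 0, 0]
```

Padding with 1 rather than 0 keeps the derived offset non-negative for prefill rows, which have no accepted tokens to correct for.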

Test plan

  • Single-prompt: baseline vs ngram match token-for-token (Qwen3.5-0.8B)
  • Mixed-batch: short prompt matches baseline; long prompt generates coherent output
  • Kernel fix verified necessary: disabling conv offset causes regression
  • Existing GDN + spec decode CI tests
Reproducers

Single-prompt:

from vllm import LLM, SamplingParams
MODEL, PROMPT = "Qwen/Qwen3.5-0.8B", "<code>\nclass Calculator:\n    def add(self, a, b):\n        return a + b\n</code>\n<update>\nAdd subtract and multiply methods\n</update>"
ARGS = dict(model=MODEL, trust_remote_code=True, enforce_eager=True, enable_chunked_prefill=True, max_model_len=4096)
SPEC = {"method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 10, "prompt_lookup_min": 2}
S = SamplingParams(max_tokens=200, temperature=0)
b = list(LLM(**ARGS).generate([PROMPT], S)[0].outputs[0].token_ids)
n = list(LLM(**ARGS, speculative_config=SPEC).generate([PROMPT], S)[0].outputs[0].token_ids)
print("PASS" if b == n else "FAIL")

Mixed-batch (short + long prompt, max_num_batched_tokens=64): see reproduce_gdn_ngram_mixed.py in the branch.

Fixes #39273

AI-assisted: Yes (Claude). Not duplicating any existing PR.

@mergify mergify Bot added v1 bug Something isn't working labels Apr 23, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request updates the GDN and Mamba2 attention backends to handle SSM state indexing after speculative decoding by passing the full block table when no prefills are present. A review comment identifies a potential issue in mixed batches where decode requests might read stale SSM state data because the implementation defaults to the first block when prefills are present.

Comment thread vllm/v1/attention/backends/gdn_attn.py
Comment thread vllm/v1/attention/backends/gdn_attn.py Outdated
@tdoublep tdoublep force-pushed the fix/gdn-ngram-spec-decode-state branch 4 times, most recently from bfccc78 to 1b2ba55 Compare April 24, 2026 22:39
@tdoublep tdoublep changed the title from "[Bugfix] Fix GDN SSM state corruption with ngram speculative decoding" to "[Bugfix] Fix GDN conv + SSM state corruption with ngram spec decode" Apr 24, 2026
When ngram speculative decoding accepts N tokens, the next non-spec
decode step reads SSM and conv state from block 0 instead of block
N-1, causing output corruption in GDN hybrid models (e.g. Qwen3.5).

Changes:
- Add spec_decode_src_indices to GDNAttentionMetadata, computed when
  accepted state lives in a different block than the decode target.
- Pre-copy SSM state from the accepted block to block 0 before kernel
  dispatch in both _forward_core and _forward_core_decode_non_spec.
- Pass num_accepted_tokens to causal_conv1d_fn/causal_conv1d_update so
  the kernel reads/writes conv state at the correct offset.
- Add IS_SPEC_DECODING constexpr and conv_state_token_offset logic to
  _causal_conv1d_fwd_kernel for offset-based conv state correction.
- Pad num_accepted_tokens with 1s for prefill sequences in mixed
  batches so the kernel sees offset=0 for those sequences.
- Gate num_accepted_tokens on spec_decode_src_indices to prevent stale
  values from leaking to non-correction steps.
- Pass num_accepted_tokens from gpu_model_runner to GDN/Mamba2
  attention builders on non-spec steps.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
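The pre-copy step listed in the commit message above can be sketched in plain Python. The state pool here is a dict of lists rather than a GPU tensor, and `precopy_ssm_state` together with its `dst_indices` argument are hypothetical names chosen for the example; only `spec_decode_src_indices` is a name taken from the PR.

```python
def precopy_ssm_state(ssm_state, src_indices, dst_indices):
    """Copy state from each accepted block into the block the decode kernel reads."""
    # Gather everything first so overlapping source and destination
    # blocks cannot overwrite each other mid-copy.
    gathered = [list(ssm_state[s]) for s in src_indices]
    for dst, values in zip(dst_indices, gathered):
        ssm_state[dst] = values
    return ssm_state


# Block 2 holds the accepted state; block 0 is what the next decode reads.
pool = {0: [0.0, 0.0], 1: [0.1, 0.1], 2: [0.2, 0.2]}
precopy_ssm_state(pool, src_indices=[2], dst_indices=[0])
print(pool[0])  # [0.2, 0.2]
```

When source and destination are the same block for every request, nothing needs to move, which is presumably why the PR only computes the source indices when the accepted state lives in a different block.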
@tdoublep tdoublep force-pushed the fix/gdn-ngram-spec-decode-state branch from 1b2ba55 to 6b5632c Compare April 24, 2026 22:42
@tdoublep tdoublep marked this pull request as ready for review April 24, 2026 22:44

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@Sandermage
Contributor

@tdoublep — first, big thanks for this PR. It was the single highest-impact fix in our v7.13 investigation — the +30-40% clean-rate jump (from ~20% baseline to 53-70%) we got from backporting Phase 1 + Phase 2 was the data point that confirmed the GDN+ngram interaction layer was where our bug actually lived. Without your work + @bhaktatejas922's original report at #39273, we'd still be searching the wrong layers. Sharing our independent confirmation + one observation below.

Independent confirmation on a non-trivial production setup

Backported both Phase 1 (Python-only: gdn_attn.py, gdn_linear_attn.py, gpu_model_runner.py) and Phase 2 (Triton kernel + wrapper changes in causal_conv1d.py + gdn_linear_attn call sites) onto vLLM 0.19.2rc1.dev205+g07351e088 running on 2× RTX A5000 (Ampere SM 8.6) with Qwen3.6-35B-A3B-FP8 (the hybrid GDN model class @bhaktatejas922 originally reported in #39273).

Empirical clean-rate progression (n=30 reproducer per row)

| Config | Tool-call clean rate |
|---|---|
| baseline (no fix) | ~20% |
| + Phase 1 (SSM state pre-copy only) | 43-60% |
| + Phase 2 (Triton conv-state offset) | 53-70% |

The Phase 1 → Phase 1+Phase 2 delta confirms your "kernel fix is necessary" remark in the PR description — Phase 1 alone gets us partway, Phase 2 closes the rest of the GDN-state-corruption slice.

One observation about the existing _causal_conv1d_update_kernel

Found that _causal_conv1d_update_kernel (decode-only) in dev205 already has IS_SPEC_DECODING constexpr + num_accepted_tokens_ptr from an earlier commit. Your PR adds the equivalent to _causal_conv1d_fwd_kernel (prefill/mixed) which was missing. Worth noting in the PR body if reviewers ask "isn't this already there" — the asymmetry between the two kernels was the actual gap.
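To make the offset read concrete, here is a toy model of what both kernels need to do. It assumes the conv state buffer holds `width - 1` plus the number of draft tokens entries, ending with the tokens processed in the previous spec step; that layout and the helper `read_conv_window` are assumptions made for illustration and are not the Triton implementation.

```python
def read_conv_window(conv_state, width, num_accepted):
    """Return the width - 1 history entries the next step should convolve over."""
    offset = num_accepted - 1  # same quantity the kernel derives from num_accepted_tokens
    window = conv_state[offset:offset + width - 1]
    if len(window) != width - 1:
        raise ValueError("conv state too short for this offset")
    return window


# Kernel width 4 means 3 history entries. The buffer holds two older inputs
# (h1, h2) followed by the three tokens processed in the spec step (t0..t2).
state = ["h1", "h2", "t0", "t1", "t2"]
print(read_conv_window(state, width=4, num_accepted=1))  # ['h1', 'h2', 't0']
print(read_conv_window(state, width=4, num_accepted=3))  # ['t0', 't1', 't2']
```

Reading at offset 0 regardless of the accepted count would always return the first window, which drops accepted tokens from the history whenever more than one was accepted.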

Other things our setup needed in addition to your PR

Just for completeness — Phase 1+2 alone got us to ~50-70%. To reach 100% on tool-call workloads we additionally needed prompt_lookup_min=8 in the ngram speculative config — the residual was an orthogonal ngram acceptance bias toward XML-repeat patterns in system-prompt tool definitions, separate from the GDN state corruption your PR fixes. Mentioning so reviewers don't assume PR #40738 = silver bullet for ALL ngram-related corruption (it's correctly scoped to the GDN state corruption).
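For reference, a speculative config with the higher lookup minimum might look like the sketch below. The keys mirror the SPEC dict from the reproducer in the PR description; the `num_speculative_tokens` and `prompt_lookup_max` values are carried over from that reproducer and are not the values used on this rig.

```python
# Keys follow the reproducer's SPEC dict; only prompt_lookup_min is changed.
SPEC = {
    "method": "ngram",
    "num_speculative_tokens": 5,  # from the PR reproducer, not this setup
    "prompt_lookup_max": 10,      # from the PR reproducer, not this setup
    "prompt_lookup_min": 8,       # raised from 2, as described above
}

# The minimum n-gram length must not exceed the maximum.
assert SPEC["prompt_lookup_min"] <= SPEC["prompt_lookup_max"]
print(SPEC["prompt_lookup_min"])  # 8
```

A longer minimum match makes the drafter propose tokens only when a longer prompt n-gram repeats, which trades some acceptance opportunities for fewer drafts triggered by short repeated patterns.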

Backport reference

Phase 1: patch_60_gdn_ngram_state_recovery.py
Phase 2: patch_60b_gdn_ngram_triton_kernel.py

Both implemented as opt-in text patches with anchor validation + drift markers (auto-no-op if your PR lands upstream). Credit to you and @bhaktatejas922 in the docstrings + CREDITS.md.

Hope the data point helps the review. Happy to test additional changes if you want a second rig to verify against.

@bhaktatejas922

Batch size 1 seems to work well. Anything above batch size 1 seems to be broken and produces weird results.

bhaktatejas922 added a commit to bhaktatejas922/vllm that referenced this pull request May 6, 2026
PR vllm-project#40738 (commit 6b5632c) only fixes the all-non-spec branch in
GDNAttentionMetadataBuilder.build() — the case where every request
in the batch is doing regular decode this step.

The else-branch (mixed batch — some reqs have drafts to verify, others
don't) has the same state-corruption pattern unfixed:
- non_spec_state_indices_tensor reads block_table[~mask, 0] (column 0,
  stale if those reqs had spec on the previous step)
- num_accepted_tokens is filtered to spec-only, dropping the non-spec
  partition's counts entirely
- spec_decode_src_indices is never built → layer skips the SSM gather-
  scatter and conv kernel offset for the non-speccing partition

This is structurally invisible at BS=1 (no other req → every step is
pure-spec or pure-non-spec, the if-branch always wins). At BS>1 with
heterogeneous prompts (apply workloads — different code, different
ngram histories per request) it triggers most steps where some req's
drafter returns empty while others return drafts. Per-fire damage scales
with num_speculative_tokens — at k=64 the non-speccing req in a mixed
batch reads recurrent state up to ~64 tokens stale.

Empirical signature (Qwen3.5 fast-apply, 32-layer hybrid, FP8, k=64):
  dataset 0926-0930__easy_random_prod (329 ex):
    BS=1 acc_sorted: 0.9605
    pre-fix BS=8:    0.5714
    post-fix BS=8:   0.9574  (within 0.003 of BS=1)
  dataset 0923_0925_random_easy_regression_check (79 ex):
    BS=1 acc_sorted: 0.9747
    pre-fix BS=8:    0.6962
    post-fix BS=8:   0.9747  (parity)

Fix mirrors the if-branch's logic in the inner-else (the actual mixed-
batch case): build spec_decode_src_indices from the non-speccing
partition's num_accepted_tokens BEFORE the spec-only filter is applied,
and add a non_spec_num_accepted field on GDNAttentionMetadata so the
layer's prefill conv1d call uses the right per-row offsets.

A repro that triggers this needs heterogeneous prompts/lengths so the
batch goes mixed (some reqs with drafts, some without). Identical
prompts in parallel keep the batch homogeneous and miss it — that's
likely why the original PR's BS>1 tests passed.
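The mixed-batch case described in this commit message can be sketched as follows. The sketch assumes one state block per speculative position and uses plain lists in place of tensors; `non_spec_src_indices` is a hypothetical helper, not the code in the referenced commit.

```python
def non_spec_src_indices(block_table, spec_mask, num_accepted):
    """For each non-speccing row, pick the block holding its accepted state.

    Reading column 0 unconditionally is the stale-read pattern described
    above; column num_accepted - 1 is where the previous step left the
    accepted state.
    """
    src = []
    for row, is_spec, n in zip(block_table, spec_mask, num_accepted):
        if not is_spec:
            src.append(row[n - 1])
    return src


block_table = [[10, 11, 12], [20, 21, 22], [30, 31, 32]]
spec_mask = [True, False, False]  # row 0 has drafts to verify this step
num_accepted = [2, 3, 1]          # accepted counts from the previous step
print(non_spec_src_indices(block_table, spec_mask, num_accepted))  # [22, 30]
```

The counts have to be consumed before any spec-only filtering, since filtering first would discard the values for rows 1 and 2 that this lookup depends on.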
@mergify
Contributor

mergify Bot commented May 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tdoublep.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 23, 2026
JEF1056 pushed a commit to JEF1056/vllm-turboquant that referenced this pull request Jun 3, 2026
…lm-project#40880)

Backport two upstream fixes for degenerate output when combining
speculative decoding with TurboQuant KV cache quantization:

P65 (vllm-project#40880): Downgrade TurboQuantMetadataBuilder._cudagraph_support
from UNIFORM_BATCH to UNIFORM_SINGLE_TOKEN_DECODE. Prevents CUDA graph
capture of spec-decode K+1 verify batches which baked stale
cu_seqlens_k into the continuation-prefill path, causing the kernel
to attend only to query tokens and ignore all prior cached KV.

vllm-project#40738: Fix GDN conv + SSM state corruption with spec decode. On
non-spec steps after token acceptance, pre-copy SSM state from the
accepted block and offset conv state reads by num_accepted_tokens-1.

Co-authored-by: Wibey CLI <genai-coding-assistants@walmart.com>
JEF1056 pushed a commit to JEF1056/vllm-turboquant that referenced this pull request Jun 3, 2026
…vllm-project#40880)

Backport two upstream fixes for degenerate output when combining
speculative decoding with TurboQuant KV cache quantization and Qwen3.5
Gated Delta Net (GDN) layers.

Downgrade TurboQuantMetadataBuilder._cudagraph_support from
UNIFORM_BATCH to UNIFORM_SINGLE_TOKEN_DECODE. With UNIFORM_BATCH, the
captured graph bakes cu_seqlens_k = cu_seqlens_q into the continuation-
prefill path, so the kernel attends only to the current K+1 query tokens
and ignores all prior cached KV, producing degenerate output.
UNIFORM_SINGLE_TOKEN_DECODE forces spec-verify batches to eager/PIECEWISE
where TQ's continuation branch correctly decompresses the full cache.

On non-spec steps after token acceptance, pre-copy SSM state from the
accepted block to block 0 so the next decode reads correct state. Offset
conv state reads by num_accepted_tokens-1 via a new parameter threaded
through causal_conv1d_fn and the Triton conv kernel.
Labels

bug Something isn't working needs-rebase v1

Development

Successfully merging this pull request may close these issues.

[Bug]: Ngram speculative decoding produces corrupted output on hybrid GDN (Qwen3.5) models

3 participants