
[Spec Decoding] Integrate DFlash into speculative decoding pipeline#1869

Open
aaronzhfeng wants to merge 1 commit into vllm-project:main from aaronzhfeng:pr_dflash_1b

Conversation

@aaronzhfeng

Description

Wire DFlash block-diffusion speculative decoding into the existing TPU inference pipeline. The DFlash model and proposer were added in #1868; this PR connects them to the runner, KV cache manager, and speculative decoding manager so DFlash can be used end-to-end.

No changes to existing Eagle3 or ngram code paths: DFlash gets its own propose_dflash_draft_token_ids method and a separate elif "dflash" dispatch branch.

Modified files:

  • tpu_inference/models/common/model_loader.py -- register DFlashDraftModel in model registry
  • tpu_inference/models/jax/qwen3.py -- collect aux_hidden_states from target layers during forward pass (needed by DFlash proposer to inject target context)
  • tpu_inference/runner/tpu_runner.py -- add DFlashProposer initialization for method="dflash"
  • tpu_inference/runner/speculative_decoding_manager.py -- add dflash method dispatch and propose_dflash_draft_token_ids (uses accepted_attn_metadata with correct seq_lens for drafter)
  • tpu_inference/runner/kv_cache_manager.py -- extend draft KV cache allocation to cover dflash, read num_hidden_layers from config instead of hardcoding 1
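
The separate dispatch branch might look roughly like this sketch. Only the method name `propose_dflash_draft_token_ids` and the `"dflash"` key come from this PR; the class shape, argument names, and other method names here are illustrative stand-ins:

```python
# Hypothetical sketch of the dflash dispatch; everything except "dflash" and
# propose_dflash_draft_token_ids is illustrative, not the actual vLLM code.
class SpeculativeDecodingManager:
    def __init__(self, method: str):
        self.method = method

    def propose_draft_token_ids(self, aux_hidden_states, accepted_attn_metadata):
        if self.method == "eagle3":
            return self._propose_eagle3(aux_hidden_states)
        elif self.method == "ngram":
            return self._propose_ngram(aux_hidden_states)
        elif self.method == "dflash":
            # New, separate branch: the Eagle3/ngram paths above are untouched.
            return self.propose_dflash_draft_token_ids(
                aux_hidden_states, accepted_attn_metadata)
        raise ValueError(f"unknown speculative method: {self.method}")

    def propose_dflash_draft_token_ids(self, aux_hidden_states, accepted_attn_metadata):
        # Placeholder body: the real method runs the DFlash drafter using
        # accepted_attn_metadata, whose seq_lens reflect accepted tokens.
        return [[101, 102, 103]]
```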

Usage (after both #1868 and this PR):

```python
args['speculative_config'] = {
    'model': 'z-lab/Qwen3-4B-DFlash-b16',
    'num_speculative_tokens': 5,
    'method': 'dflash',
    'draft_tensor_parallel_size': 1,
}
```

Tests

E2e tests are in a follow-up PR.

Checklist

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

Signed-off-by: aaronzhfeng <fzx333578@gmail.com>
Collaborator

@kyuyeunk kyuyeunk left a comment


I think the PR looks okay, but please add a unit test.

It doesn't have to be a big one -- like integrating this into a CI: #1870

Just a simple check -- e.g. making sure that the functions added in tpu_inference/runner/speculative_decoding_manager.py (like propose_dflash_draft_token_ids) are working correctly -- would give me better confidence that this won't break anything.
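
A minimal test in that spirit could fake the drafter so no TPU is needed and just verify that the method forwards the post-acceptance attention metadata. The manager and proposer classes below are simplified stand-ins for the real ones, not the actual implementation:

```python
# Illustrative unit test; FakeDFlashProposer and the simplified manager are
# stand-ins for the real classes in speculative_decoding_manager.py.
class FakeDFlashProposer:
    def propose(self, aux_hidden_states, attn_metadata):
        # Record what we were called with so the test can assert on it.
        self.seen_metadata = attn_metadata
        return [[11, 12, 13]]  # fixed fake draft token ids

class SpecDecodingManager:
    def __init__(self, proposer):
        self.proposer = proposer

    def propose_dflash_draft_token_ids(self, aux_hidden_states, accepted_attn_metadata):
        # The drafter must see seq_lens that reflect accepted tokens,
        # not the pre-acceptance metadata.
        return self.proposer.propose(aux_hidden_states, accepted_attn_metadata)

def test_propose_dflash_draft_token_ids():
    proposer = FakeDFlashProposer()
    mgr = SpecDecodingManager(proposer)
    drafts = mgr.propose_dflash_draft_token_ids(
        aux_hidden_states=None,
        accepted_attn_metadata={"seq_lens": [7]},
    )
    assert drafts == [[11, 12, 13]]
    assert proposer.seen_metadata["seq_lens"] == [7]
```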

@Lumosis
Collaborator

Lumosis commented Apr 1, 2026

We should precompile the jitted functions for dflash in compilation_manager.py
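
Something along these lines, using JAX's ahead-of-time `lower`/`compile` API; `dflash_draft_step` and the shapes here are stand-ins for the actual jitted drafter functions:

```python
import jax
import jax.numpy as jnp

@jax.jit
def dflash_draft_step(x):  # stand-in for the real jitted DFlash step
    return x * 2

# Ahead-of-time compile for the shapes/dtypes the runner will actually use,
# so the first real request doesn't pay tracing + compilation cost.
dummy = jnp.zeros((4, 8), jnp.float32)
compiled_step = dflash_draft_step.lower(dummy).compile()
```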
