[Spec Decoding] Add DFlash model and proposer #1868

Open
aaronzhfeng wants to merge 3 commits into vllm-project:main from aaronzhfeng:pr_dflash_1

Conversation

@aaronzhfeng

@aaronzhfeng aaronzhfeng commented Mar 5, 2026

Description

Add DFlash draft model and proposer for block-diffusion speculative decoding on JAX/TPU. DFlash predicts multiple tokens in parallel using discrete diffusion, unlike Eagle3's autoregressive drafting. This follows the same proposer pattern as Eagle3.

This is PR 1 of 3 for DFlash support:

  1. This PR: Model, proposer, and unit tests (all new files)
  2. Pipeline integration (modifications to existing files)
  3. E2E tests and Buildkite CI

New files:

  • tpu_inference/models/jax/dflash.py -- DFlash draft model (DFlashForCausalLM)
  • tpu_inference/models/jax/qwen3_dflash.py -- Qwen3-specific DFlash variant with attention
  • tpu_inference/layers/common/dflash_attention_interface.py -- dflash_concat_attention kernel
  • tpu_inference/spec_decode/jax/dflash.py -- DFlashProposer (prepare_inputs, propose, sampling)
  • tests/models/jax/test_qwen3_dflash_attention.py -- DFlash attention unit tests
  • tests/models/jax/test_qwen3_dflash.py -- target layer ID selection tests
  • tests/spec_decode/test_dflash.py -- proposer sampling tests

Tests

  • Unit tests for DFlash attention (concat, additive bias, GQA): tests/models/jax/test_qwen3_dflash_attention.py
  • Unit tests for target layer ID selection: tests/models/jax/test_qwen3_dflash.py
  • Unit tests for proposer sampling: tests/spec_decode/test_dflash.py

Checklist

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@Lumosis
Collaborator

Lumosis commented Mar 5, 2026

This is a large PR. Can we break it down into several small PRs to make review easier?

@aaronzhfeng aaronzhfeng changed the title from "[Spec Decoding] Add DFlash block-diffusion speculative decoding" to "[Spec Decoding] Add DFlash model and proposer" on Mar 5, 2026
@aaronzhfeng
Author

Sorry about the large PR. The model, proposer, and attention kernel are tightly coupled (proposer calls model forward, model uses the attention kernel), so splitting them further would leave each PR non-functional on its own. All files here are new additions with no changes to existing code, which should make it easier to review.

Broke the original PR down into 3:

  1. This PR (updated): DFlash model, proposer, and unit tests -- all new files, no existing files modified
  2. Pipeline integration -- modifications to existing files (speculative_decoding_manager, kv_cache_manager, qwen3, tpu_runner, model_loader)
  3. E2E tests + Buildkite CI

PRs 2 and 3 coming shortly.

@Lumosis Lumosis added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 5, 2026
Collaborator

@kyuyeunk kyuyeunk left a comment


Hi @aaronzhfeng! Thank you for the contribution. A couple of questions:

  • For those who aren't familiar with DFlash (like myself), can you give a brief overview and maybe a link where we can find more info?
  • Is my understanding correct that this feature is not available in vLLM's PyTorch model implementation? If so, is there a way for a backend that utilizes vLLM's model implementation to leverage this spec decoding?
  • Can you share a sample command for people to try out this feature while going through the review process?

@aaronzhfeng
Author

Thanks for taking a look!

DFlash overview: DFlash is a block-diffusion speculative decoding method that predicts multiple tokens in parallel using discrete diffusion, instead of generating them one at a time autoregressively. Given a context, the draft model takes a block of masked/noise positions and denoises them in a single forward pass to produce K candidate tokens simultaneously. This makes drafting O(1) in block size rather than O(K).
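The single-pass-versus-K-passes distinction can be illustrated with a toy sketch. Everything here is made up for illustration (a stand-in "model" that hashes the context), not DFlash's actual prediction logic; the point is only that the diffusion drafter fills every masked slot in one call, while an autoregressive drafter needs one call per token:

```python
MASK = -1  # sentinel for a masked/noise position

def diffusion_draft_step(context, block):
    """One denoising forward pass: fill every masked slot at once."""
    # Stand-in for the draft model: derive each masked token from the context.
    return [tok if tok != MASK else (sum(context) + i) % 100
            for i, tok in enumerate(block)]

def autoregressive_draft(context, k):
    """K sequential forward passes, one drafted token per pass."""
    out = []
    for _ in range(k):
        out.append((sum(context) + sum(out)) % 100)  # one forward pass each
    return out

context = [3, 7, 11]
block = [MASK] * 8                            # K = 8 noise positions
draft = diffusion_draft_step(context, block)  # a single forward pass
print(len(draft))  # 8 candidate tokens from one call
```

With K = 8, the diffusion step issues one forward pass where the autoregressive loop issues eight; acceptance/rejection of the candidates is then handled by the verifier, as with any speculative decoding method.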

Paper: "DFlash: Block Diffusion for Flash Speculative Decoding" (Chen et al., arXiv:2602.06036). The reference GPU implementation is at https://github.com/z-lab/dflash.

PyTorch/vLLM availability: Right now there is no DFlash support in vLLM's PyTorch backend. The DFlash authors have confirmed vLLM integration is still in progress on their end (see z-lab/dflash#6). SGLang has DFlash support via sgl-project/sglang#16818, but this PR would be the first DFlash integration in the vLLM ecosystem. It targets the JAX/TPU backend specifically, since the draft model uses non-causal attention which required a different attention path from the standard causal pipeline. A PyTorch port is feasible but not in scope for this PR.

Sample command: The unit tests in this PR can be run without a full serving setup:

pytest tests/models/jax/test_qwen3_dflash_attention.py
pytest tests/models/jax/test_qwen3_dflash.py
pytest tests/spec_decode/test_dflash.py

End-to-end serving requires the pipeline integration in PR #1869 (already open). Once both are merged, with Qwen3-4B on a TPU v5p-8:

python -m tpu_inference.entrypoint \
  --model Qwen/Qwen3-4B \
  --speculative_config '{"model": "z-lab/Qwen3-4B-DFlash-b16", "num_speculative_tokens": 15, "method": "dflash", "draft_tensor_parallel_size": 1}'

Collaborator


Can you elaborate on what feature is missing from the existing attention implementation such that it requires its own separate code? If it's due to bi-directional attention, we already have an implementation for that.

Author


Sorry for the late reply. I've been heads-down on other work items, and I appreciate the thorough review.

The missing feature is not bi-directional / non-causal attention itself. In the original DFlash GPU repo (z-lab/dflash), the DFlash path also wraps standard attention backends (flash_attention_2, sdpa) in a custom Qwen3DFlashAttention module that concatenates context and noise K/V before calling the backend; the generation loop then manages the draft cache with DynamicCache.crop(). In tpu_inference, the existing paged-cache attention() path assumes a single (k, v) stream and couples cache update with attention over that stream via block tables, so it does not directly expose the DFlash concat behavior.

We did look into adapting the existing non-causal sharded_flash_attention(..., causal=False) by concatenating K/V upstream and passing a single tensor into the kernel. That may be possible, but for this path it would still require extra reshape/pad bookkeeping for ragged per-request batching and doubled KV length.

We reuse the existing attention() path where we can. In additive_legacy, we collapse the two streams into a single (k, v) and call attention() directly. Even in the concat path, we still use attention() for the KV-cache update on the noise stream; the separate helper is only for the output computation over [k_ctx; k_noise].

Happy to explore a better approach if you have anything in mind.
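The concat behavior described above, where noise queries attend non-causally over context K/V followed by noise K/V, can be sketched framework-agnostically. This is a minimal NumPy sketch of the math only (the PR's actual kernel is JAX with GQA, padding, and per-request boundaries); all names here are illustrative:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def concat_attention(q_noise, k_ctx, v_ctx, k_noise, v_noise):
    """Non-causal attention of noise queries over [k_ctx; k_noise].

    Shapes: q_noise (T_n, D), k_ctx/v_ctx (T_c, D), k_noise/v_noise (T_n, D).
    """
    k = np.concatenate([k_ctx, k_noise], axis=0)     # (T_c + T_n, D)
    v = np.concatenate([v_ctx, v_noise], axis=0)
    scores = q_noise @ k.T / np.sqrt(q_noise.shape[-1])
    return softmax(scores) @ v                        # (T_n, D)
```

The key difference from the paged-cache `attention()` path is visible here: the output is computed over two K/V streams concatenated on the fly, rather than over a single stream addressed through block tables.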



@functools.partial(jax.jit, static_argnames=("max_query_len", ))
def dflash_concat_attention(
Collaborator


In general, I think this function is lacking comments explaining what each line does.

Author


Good point, thank you. Added inline comments covering GQA head expansion, padding for static XLA slice sizes, per-request boundaries, KV concat, and masking.
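The GQA head expansion mentioned here is a standard trick: each KV head is repeated so grouped-query K/V line up one-to-one with the query heads before the attention matmul. A minimal NumPy sketch (illustrative names, not the PR's code):

```python
import numpy as np

def expand_kv_heads(kv, num_q_heads):
    """Repeat each KV head so it lines up with its group of query heads.

    kv: (num_kv_heads, T, D); returns (num_q_heads, T, D).
    """
    num_kv_heads = kv.shape[0]
    assert num_q_heads % num_kv_heads == 0, "q heads must be a multiple of kv heads"
    return np.repeat(kv, num_q_heads // num_kv_heads, axis=0)

k = np.zeros((2, 5, 8))              # 2 KV heads, seq len 5, head dim 8
print(expand_kv_heads(k, 8).shape)   # (8, 5, 8): each KV head serves 4 query heads
```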

Collaborator

@kyuyeunk kyuyeunk left a comment


I think this PR looks fine.

These are completely new files that don't touch any of our existing logic, so there isn't a huge concern for me that this could break other logic.

LGTM.

But it would be great if people who are more experienced with spec decoding (like @Lumosis) could take a look.

If there isn't any further feedback, I'll have it merged by EOD.

@kyuyeunk
Collaborator

But CI is failing, so please fix it.


# Allocate on-device KV caches
hf_config = self.draft_model_config.hf_config
from tpu_inference import utils
Collaborator


Move these to the top of the file?

Author


Done. Moved from tpu_inference import utils, ShardingAxisName, and get_mesh_shape_product to the top-level import block.

cache_shape,
)

def _project_aux_hidden(
Collaborator


Should we make this function jitted?

Author


Great suggestion, thank you. Wrapped with @functools.partial(jax.jit, static_argnums=(0,)) following the same pattern as _sample_block_draft_tokens. Passed state as an explicit traced argument to avoid capturing mutable state through static self.
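The "static self, explicit traced state" pattern mentioned here can be sketched as follows. This is a hypothetical toy class, not the PR's `_project_aux_hidden`: `self` carries only hashable configuration (so it is safe as a static argument), while the weights are passed as a traced argument, so updating them never forces a retrace:

```python
import functools
import jax
import jax.numpy as jnp

class Projector:
    """Toy sketch: self is static (hashable config only); mutable weights
    are passed as an explicit traced argument."""

    def __init__(self, hidden_size):
        self.hidden_size = hidden_size  # static, hashable configuration

    @functools.partial(jax.jit, static_argnums=(0,))
    def project(self, weights, aux_hidden):
        # weights is traced, so swapping in new values reuses the compiled fn.
        return aux_hidden @ weights

proj = Projector(hidden_size=4)
w = jnp.eye(4)
x = jnp.ones((2, 4))
print(proj.project(w, x).shape)  # (2, 4)
```

If the weights were instead read off `self`, `jax.jit` would bake them into the trace as constants, which is exactly the mutable-state capture the comment above avoids.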

Comment on lines +266 to +283
new_ctx_np = self._update_context_buffer(projected, seq_len)

if new_ctx_np is None or len(new_ctx_np) == 0:
# Full rejection — all padded entries are zeros, and noise
# writes at cache_len + 0, completely overwriting them.
actual_new_ctx_count = 0
new_ctx_np = np.zeros((16, self.hidden_size), dtype=np.float32)
else:
actual_new_ctx_count = len(new_ctx_np)
new_ctx_np = self._pad_context(new_ctx_np)

# 5. Upload padded context to device.
# Padding to power-of-2 sizes (16/32/64/128) means JIT only
# traces ~4 unique shapes, eliminating per-token retracing.
new_ctx_jax = device_array(
self.mesh,
jnp.array(new_ctx_np, dtype=jnp.bfloat16),
)
Collaborator


It would be better if we can do this operation on TPU to avoid host <-> TPU transfer overhead

Author


Thank you for pointing this out. Removed the host-side numpy buffer entirely. Context slicing and power-of-2 padding now happen directly on device, eliminating the np.asarray() download and jnp.array() re-upload. The _update_context_buffer and _pad_context methods were replaced by inline on-device ops in prepare_inputs. Net result: -38 lines, zero host-to-TPU transfers in the context path.
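The bucketed padding described here (pad to the smallest of a few fixed sizes so JIT only ever sees a handful of shapes) can be sketched like this. Illustrative names and bucket sizes taken from the diff comment above, not the PR's actual helper:

```python
import jax.numpy as jnp

def pad_context(ctx, buckets=(16, 32, 64, 128)):
    """Zero-pad ctx rows up to the smallest bucket >= its length, so a
    jitted consumer only ever traces len(buckets) distinct shapes."""
    n = ctx.shape[0]                               # concrete Python int
    target = next(b for b in buckets if b >= n)    # bucket chosen on host
    return jnp.pad(ctx, ((0, target - n), (0, 0)))  # padding happens on device

ctx = jnp.ones((21, 8))
print(pad_context(ctx).shape)  # (32, 8): 21 rows bucketed up to 32
```

The bucket choice uses the static shape, which is known at trace time, so no device-to-host transfer is needed; only the pad itself runs on device.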

@Lumosis
Collaborator

Lumosis commented Apr 1, 2026

Left a few comments. Otherwise, the PR looks good.

Signed-off-by: aaronzhfeng <fzx333578@gmail.com>
Signed-off-by: aaronzhfeng <fzx333578@gmail.com>
…-device context

- Move lazy imports (utils, ShardingAxisName, get_mesh_shape_product)
  to the top-level import block.
- JIT _project_aux_hidden with explicit state arg, following the same
  pattern as _sample_block_draft_tokens.
- Replace host-side numpy context buffer with on-device slicing and
  padding, eliminating host<->TPU transfer in the context path.

Signed-off-by: aaronzhfeng <fzx333578@gmail.com>
@aaronzhfeng
Author

Thanks for the reviews @kyuyeunk @Lumosis! Sorry for the delay on these. All three inline comments addressed in 9eeae78, DCO sign-off added to all commits, and the branch is rebased on the latest main.

A few things we're working on for this PR set:

  1. Torchax coverage: adding a torchax proposer so DFlash works on the PyTorch serving path as well. Following the existing vllm_model_wrapper.py interop pattern.
  2. Environment guard: DFlash currently only activates when method="dflash" is set in speculative_config, so non-DFlash users are unaffected. Happy to add a more explicit flag if preferred.
  3. Framework extensibility: DFlash follows the same proposer interface as Eagle3 (load_model, prepare_inputs, propose). Open to suggestions on whether a formal base class or registry would be useful.

Will follow up on these as they're ready. Let us know if any of this should be prioritized or handled differently.

@Lumosis
Collaborator

Lumosis commented Apr 7, 2026


Thanks! Some unit tests failed. PTAL

