[Spec Decoding] Add DFlash model and proposer #1868
aaronzhfeng wants to merge 3 commits into vllm-project:main
Conversation
This is a large PR. Can we break it down into several small PRs to make review easier?
Sorry about the large PR. The model, proposer, and attention kernel are tightly coupled (the proposer calls the model forward, and the model uses the attention kernel), so splitting them further would leave each PR non-functional on its own. All files here are new additions with no changes to existing code, which should make it easier to review. Broke the original PR down into 3:

PRs 2 and 3 coming shortly.
kyuyeunk left a comment
Hi @aaronzhfeng! Thank you for the contribution. A couple of questions:
- For those who aren't familiar with DFlash (like myself), can you give a brief overview and maybe a link where we can find more info?
- Is my understanding correct that this feature is not available in vLLM's PyTorch model implementation? If so, is there a way for a backend that utilizes vLLM's model implementation to leverage this spec decoding?
- Can you share a sample command for people to try out this feature while going through the review process?
Thanks for taking a look!

**DFlash overview:** DFlash is a block-diffusion speculative decoding method that predicts multiple tokens in parallel using discrete diffusion, instead of generating them one at a time autoregressively. Given a context, the draft model takes a block of masked/noise positions and denoises them in a single forward pass to produce K candidate tokens simultaneously. This makes drafting O(1) in the block size rather than O(K). Paper: "DFlash: Block Diffusion for Flash Speculative Decoding" (Chen et al., arXiv:2602.06036). The reference GPU implementation is at https://github.com/z-lab/dflash.

**PyTorch/vLLM availability:** Right now there is no DFlash support in vLLM's PyTorch backend. The DFlash authors have confirmed vLLM integration is still in progress on their end (see z-lab/dflash#6). SGLang has DFlash support via sgl-project/sglang#16818, but this PR would be the first DFlash integration in the vLLM ecosystem. It targets the JAX/TPU backend specifically, since the draft model uses non-causal attention, which required a different attention path from the standard causal pipeline. A PyTorch port is feasible but not in scope for this PR.

**Sample command:** The unit tests in this PR can be run without a full serving setup:

```
pytest tests/models/jax/test_qwen3_dflash_attention.py
pytest tests/models/jax/test_qwen3_dflash.py
pytest tests/spec_decode/test_dflash.py
```

End-to-end serving requires the pipeline integration in PR #1869 (already open). Once both are merged, with Qwen3-4B on a TPU v5p-8:

```
python -m tpu_inference.entrypoint \
  --model Qwen/Qwen3-4B \
  --speculative_config '{"model": "z-lab/Qwen3-4B-DFlash-b16", "num_speculative_tokens": 15, "method": "dflash", "draft_tensor_parallel_size": 1}'
```
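To make the "K tokens in one forward pass" idea concrete, here is a toy, hedged sketch of block-diffusion-style drafting. It is not the DFlash model; the "draft model" is a random projection, and all names (`draft_block`, the pooling scheme) are hypothetical, purely to contrast one-pass block drafting with K sequential autoregressive steps:

```python
import numpy as np

def draft_block(context_ids, block_size, vocab_size, seed=0):
    """Toy stand-in for a DFlash-style draft step: given a context,
    fill a block of K masked/noise positions in ONE forward pass,
    rather than K sequential autoregressive steps.

    The "model" here is a random projection -- purely illustrative,
    not the actual DFlash architecture.
    """
    rng = np.random.default_rng(seed)
    hidden = 32
    # Hypothetical frozen weights standing in for the draft model.
    embed = rng.standard_normal((vocab_size, hidden))
    unembed = rng.standard_normal((hidden, vocab_size))

    # Encode the context once (mean-pooled toy representation).
    ctx_h = embed[context_ids].mean(axis=0)

    # One forward pass over all K noise positions simultaneously:
    # each masked slot sees the context plus its position encoding.
    pos = rng.standard_normal((block_size, hidden))
    logits = (ctx_h + pos) @ unembed   # (K, vocab)
    return logits.argmax(axis=-1)      # K draft tokens at once

draft = draft_block(np.array([1, 5, 9]), block_size=16, vocab_size=100)
print(draft.shape)  # (16,) -- a full block of candidates per pass
```

The target model then verifies the whole block in one pass, which is where the speculative-decoding speedup comes from.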
Can you elaborate on what kind of feature is missing from the existing attention implementation that requires its own separate code? If it's due to bi-directional attention, we already have an implementation for that.
Sorry for the late reply. I've been heads-down on other work items, and I appreciate the thorough review.
The missing feature is not bi-directional / non-causal attention itself. In the original DFlash GPU repo (z-lab/dflash), the DFlash path also wraps standard attention backends (flash_attention_2, sdpa) in a custom Qwen3DFlashAttention module that concatenates context and noise K/V before calling the backend; the generation loop then manages the draft cache with DynamicCache.crop(). In tpu_inference, the existing paged-cache attention() path assumes a single (k, v) stream and couples cache update with attention over that stream via block tables, so it does not directly expose the DFlash concat behavior.
We did look into adapting the existing non-causal sharded_flash_attention(..., causal=False) by concatenating K/V upstream and passing a single tensor into the kernel. That may be possible, but for this path it would still require extra reshape/pad bookkeeping for ragged per-request batching and doubled KV length.
We reuse the existing attention() path where we can. In additive_legacy, we collapse the two streams into a single (k, v) and call attention() directly. Even in the concat path, we still use attention() for the KV-cache update on the noise stream; the separate helper is only for the output computation over [k_ctx; k_noise].
Happy to explore a better approach if you have anything in mind.
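For reviewers following along, a minimal single-head numpy sketch of the concat behavior being discussed: noise-block queries attend non-causally over `[context KV ; noise KV]`. This is illustration only (no GQA, no paging, no per-request batching), and `concat_attention` is a hypothetical name, not the kernel in this PR:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def concat_attention(q_noise, k_ctx, v_ctx, k_noise, v_noise):
    """Noise queries attend NON-causally over the concatenation of
    context KV and noise KV. Single head, dense, no cache paging."""
    k = np.concatenate([k_ctx, k_noise], axis=0)  # (T_ctx + K, d)
    v = np.concatenate([v_ctx, v_noise], axis=0)
    d = q_noise.shape[-1]
    scores = q_noise @ k.T / np.sqrt(d)           # (K, T_ctx + K)
    # No causal mask: every noise position may see every context
    # position and every other noise position in the block.
    return softmax(scores) @ v                    # (K, d)

rng = np.random.default_rng(0)
T_ctx, K, d = 8, 4, 16
out = concat_attention(rng.standard_normal((K, d)),
                       rng.standard_normal((T_ctx, d)),
                       rng.standard_normal((T_ctx, d)),
                       rng.standard_normal((K, d)),
                       rng.standard_normal((K, d)))
print(out.shape)  # (4, 16)
```

The paged `attention()` path couples cache update and attention over a single `(k, v)` stream, which is why this two-stream concat does not drop in directly.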
```python
@functools.partial(jax.jit, static_argnames=("max_query_len", ))
def dflash_concat_attention(
```
In general, I think this function is lacking a lot of comments explaining what each line does.
Good point, thank you. Added inline comments covering GQA head expansion, padding for static XLA slice sizes, per-request boundaries, KV concat, and masking.
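As a reference point for one of those commented steps, GQA head expansion is the operation of repeating each KV head across its group of query heads. A hedged numpy illustration (the helper name `expand_gqa_kv` is hypothetical; the actual kernel does this on-device):

```python
import numpy as np

def expand_gqa_kv(kv, num_q_heads):
    """GQA head expansion: a (T, H_kv, d) key or value tensor becomes
    (T, H_q, d) by repeating each KV head over its group of query
    heads, so query head j reads KV head j // (H_q // H_kv)."""
    t, h_kv, d = kv.shape
    assert num_q_heads % h_kv == 0
    group = num_q_heads // h_kv
    # Repeat along the head axis: KV head i serves query heads
    # [i * group, (i + 1) * group).
    return np.repeat(kv, group, axis=1)

k = np.arange(2 * 2 * 3).reshape(2, 2, 3)   # T=2, H_kv=2, d=3
k8 = expand_gqa_kv(k, num_q_heads=8)
print(k8.shape)  # (2, 8, 3)
```

Padding to static slice sizes and the `[k_ctx; k_noise]` masking follow the same spirit: fixed shapes for XLA, with masks marking which entries are real.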
kyuyeunk left a comment
I think this PR looks fine. It's all completely new files that don't touch any of our existing logic, so there isn't a huge concern that it could break other code paths. LGTM.
But it would be great if people who are more experienced with spec decoding (like @Lumosis) could take a look.
If there isn't any further feedback, I'll have it merged by EOD.
But CI is failing, so please fix it.
```python
# Allocate on-device KV caches
hf_config = self.draft_model_config.hf_config
from tpu_inference import utils
```
Move these to the top of the file?
Done. Moved `from tpu_inference import utils`, `ShardingAxisName`, and `get_mesh_shape_product` to the top-level import block.
```python
        cache_shape,
    )

def _project_aux_hidden(
```
Should we make this function jitted?
Great suggestion, thank you. Wrapped with `@functools.partial(jax.jit, static_argnums=(0,))`, following the same pattern as `_sample_block_draft_tokens`. Passed `state` as an explicit traced argument to avoid capturing mutable state through static `self`.
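For reviewers unfamiliar with that pattern, a minimal sketch of jitting a method with static `self` (class and names are hypothetical, not this PR's code): `self` must hold only hashable configuration, while anything mutable or array-valued is passed as a traced argument so it does not trigger retracing or get baked in as a constant.

```python
import functools
import jax
import jax.numpy as jnp

class Projector:
    """Sketch of the jit-on-a-method pattern: `self` is a static
    argument (hashable config only), while arrays are passed
    explicitly so they are traced, not captured via `self`."""

    def __init__(self, scale):
        self.scale = scale  # static, hashable configuration

    @functools.partial(jax.jit, static_argnums=(0,))
    def project(self, state, x):
        # `state` and `x` are traced; new VALUES with the same
        # shape/dtype reuse the compiled computation.
        return state @ x * self.scale

p = Projector(scale=2.0)
w = jnp.ones((4, 4))
x = jnp.arange(4.0)
print(p.project(w, x))  # [12. 12. 12. 12.]
```

One caveat of this pattern: because `self` is static, a new `Projector` instance compiles fresh, so it suits long-lived objects like a proposer.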
```python
new_ctx_np = self._update_context_buffer(projected, seq_len)

if new_ctx_np is None or len(new_ctx_np) == 0:
    # Full rejection — all padded entries are zeros, and noise
    # writes at cache_len + 0, completely overwriting them.
    actual_new_ctx_count = 0
    new_ctx_np = np.zeros((16, self.hidden_size), dtype=np.float32)
else:
    actual_new_ctx_count = len(new_ctx_np)
    new_ctx_np = self._pad_context(new_ctx_np)

# 5. Upload padded context to device.
# Padding to power-of-2 sizes (16/32/64/128) means JIT only
# traces ~4 unique shapes, eliminating per-token retracing.
new_ctx_jax = device_array(
    self.mesh,
    jnp.array(new_ctx_np, dtype=jnp.bfloat16),
)
```
It would be better if we could do this operation on the TPU, to avoid host <-> TPU transfer overhead.
Thank you for pointing this out. Removed the host-side numpy buffer entirely. Context slicing and power-of-2 padding now happen directly on device, eliminating the `np.asarray()` download and `jnp.array()` re-upload. The `_update_context_buffer` and `_pad_context` methods were replaced by inline on-device ops in `prepare_inputs`. Net result: -38 lines, zero host-to-TPU transfers in the context path.
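The shape-bucketing trick behind that padding can be shown in isolation. A hedged host-side illustration (the on-device version in the PR uses jnp ops; `pad_to_bucket` and the bucket tuple are illustrative names taken from the code comment above): padding the context row count up to one of a few fixed sizes bounds the number of distinct shapes a jitted consumer ever sees.

```python
import numpy as np

PAD_SIZES = (16, 32, 64, 128)  # buckets from the code comment above

def pad_to_bucket(ctx):
    """Pad the context row count up to the next bucket size, so a
    jitted consumer traces at most len(PAD_SIZES) unique shapes
    instead of one shape per possible context length."""
    n = ctx.shape[0]
    target = next(s for s in PAD_SIZES if s >= n)  # assumes n <= 128
    pad = np.zeros((target - n, ctx.shape[1]), dtype=ctx.dtype)
    # Return the true row count alongside, so downstream masking can
    # distinguish real rows from zero padding.
    return np.concatenate([ctx, pad], axis=0), n

padded, count = pad_to_bucket(np.ones((21, 8), dtype=np.float32))
print(padded.shape, count)  # (32, 8) 21
```

Done with `jnp.pad`/`jax.lax.dynamic_slice` on device, the same bucketing avoids both the retracing and the host round-trip.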
Left a few comments. Otherwise, the PR looks good.
Signed-off-by: aaronzhfeng <fzx333578@gmail.com>
…-device context

- Move lazy imports (utils, ShardingAxisName, get_mesh_shape_product) to the top-level import block.
- JIT _project_aux_hidden with explicit state arg, following the same pattern as _sample_block_draft_tokens.
- Replace host-side numpy context buffer with on-device slicing and padding, eliminating host<->TPU transfer in the context path.

Signed-off-by: aaronzhfeng <fzx333578@gmail.com>
Thanks for the reviews @kyuyeunk @Lumosis! Sorry for the delay on these. All three inline comments are addressed in 9eeae78, the DCO sign-off is added to all commits, and the branch is rebased on the latest main. A few things we're working on for this PR set:

Will follow up on these as they're ready. Let us know if any of this should be prioritized or handled differently.
Thanks! Some unit tests failed. PTAL
Description
Add DFlash draft model and proposer for block-diffusion speculative decoding on JAX/TPU. DFlash predicts multiple tokens in parallel using discrete diffusion, unlike Eagle3's autoregressive drafting. This follows the same proposer pattern as Eagle3.
This is PR 1 of 3 for DFlash support:
New files:
- `tpu_inference/models/jax/dflash.py` -- DFlash draft model (`DFlashForCausalLM`)
- `tpu_inference/models/jax/qwen3_dflash.py` -- Qwen3-specific DFlash variant with attention
- `tpu_inference/layers/common/dflash_attention_interface.py` -- `dflash_concat_attention` kernel
- `tpu_inference/spec_decode/jax/dflash.py` -- `DFlashProposer` (prepare_inputs, propose, sampling)
- `tests/models/jax/test_qwen3_dflash_attention.py` -- DFlash attention unit tests
- `tests/models/jax/test_qwen3_dflash.py` -- target layer ID selection tests
- `tests/spec_decode/test_dflash.py` -- proposer sampling tests

Tests

- `tests/models/jax/test_qwen3_dflash_attention.py`
- `tests/models/jax/test_qwen3_dflash.py`
- `tests/spec_decode/test_dflash.py`

Checklist