Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
cdda524
add dflash in speculative config
dangoldbj Feb 6, 2026
ac2005c
treat dflash as graph-affecting
dangoldbj Feb 6, 2026
bb56db2
add DFlash architecture mapping in registry
dangoldbj Feb 6, 2026
560cc7b
add DFlash draft model to architecture test matrix
dangoldbj Feb 6, 2026
abcff56
add DFlash Qwen3 draft model executor and registry wiring
dangoldbj Feb 6, 2026
d952bda
add DFlash Qwen3 draft model executor and registry wiring
dangoldbj Feb 6, 2026
fa49ac4
Add dedicated DFlashProposer and route dflash runtime path separately…
dangoldbj Feb 6, 2026
ca522fe
Add DFlash BS1/backend fail-fast validation and unit tests
dangoldbj Feb 6, 2026
0516f13
add negative/positive coverage for DFlash backend validation
dangoldbj Feb 6, 2026
0d01f71
enable DFlash runtime dispatch (shared_eagle/block_drafting) and alig…
dangoldbj Feb 6, 2026
f7d825f
add DFlash spec decode docs/example path and expand proposer unit cov…
dangoldbj Feb 6, 2026
cd00076
runtime hardening for DFlash block drafting
dangoldbj Feb 6, 2026
4f516b0
harden DFlash block_drafting metadata handling and add config/runtime…
dangoldbj Feb 6, 2026
8086c46
implement BS>1 block_drafting in DFlashProposer without fallback
dangoldbj Feb 6, 2026
c076ad6
harden BS>1 coverage and enforce DFlash vocab-size safety checks
dangoldbj Feb 6, 2026
da4dbe3
add DFlash async e2e coverage, tighten metrics assertions, and refact…
dangoldbj Feb 6, 2026
a505be1
reduced per-step allocations in block drafting
dangoldbj Feb 6, 2026
c2efbfd
improve docs
dangoldbj Feb 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions docs/features/spec_decode/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,46 @@ A variety of EAGLE draft models are available on the Hugging Face hub:
| Qwen2-7B-Instruct | yuhuili/EAGLE-Qwen2-7B-Instruct | 0.26B |
| Qwen2-72B-Instruct | yuhuili/EAGLE-Qwen2-72B-Instruct | 1.05B |

## Speculating using DFlash draft models

vLLM also supports speculative decoding with DFlash draft models via
`"method": "dflash"`.

??? code

```python
from vllm import LLM, SamplingParams

prompts = [
"Summarize speculative decoding in one paragraph.",
]
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

llm = LLM(
model="Qwen/Qwen3-8B",
speculative_config={
"method": "dflash",
"model": "z-lab/Qwen3-8B-DFlash-b16",
"num_speculative_tokens": 16,
},
)

outputs = llm.generate(prompts, sampling_params)
print(outputs[0].outputs[0].text)
```

Important constraints for DFlash:

1. DFlash draft models are currently tested with Qwen3-based checkpoints.
2. `runtime_mode="block_drafting"` supports batched requests in vLLM V1, but
this path is still under active hardening and should be validated for your
workload/backends before production rollout.
3. DFlash requires an attention backend with non-causal drafting support.
If backend validation fails, switch to the FlashAttention or Triton attention backend.

The offline benchmark/example script supports this path:
`examples/offline_inference/spec_decode.py --method dflash --draft-model <dflash-model>`.

## Lossless guarantees of Speculative Decoding

In vLLM, speculative decoding aims to enhance inference efficiency while maintaining accuracy. This section addresses the lossless guarantees of
Expand Down
4 changes: 2 additions & 2 deletions examples/offline_inference/spec_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def parse_args():
"--method",
type=str,
default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
choices=["ngram", "eagle", "eagle3", "mtp", "draft_model", "dflash"],
)
parser.add_argument("--backend", type=str, default="openai")
parser.add_argument("--num-spec-tokens", type=int, default=2)
Expand Down Expand Up @@ -131,7 +131,7 @@ def main(args):
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method == "draft_model":
elif args.method in ("draft_model", "dflash"):
assert args.draft_model is not None and args.draft_model != ""
speculative_config = {
"method": args.method,
Expand Down
17 changes: 17 additions & 0 deletions tests/config/draft_model_arch_groundtruth.json
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,22 @@
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "float16"
},
"z-lab/Qwen3-8B-DFlash-b16": {
"architectures": [
"DFlashDraftModel"
],
"model_type": "qwen3",
"text_model_type": "qwen3",
"hidden_size": 4096,
"total_num_hidden_layers": 5,
"total_num_attention_heads": 32,
"head_size": 128,
"vocab_size": 151936,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
}
}
1 change: 1 addition & 0 deletions tests/config/test_model_arch_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
("eagle618/deepseek-v3-random", "eagle618/eagle-deepseek-v3-random", True),
("meta-llama/Meta-Llama-3-8B-Instruct", "yuhuili/EAGLE-LLaMA3-Instruct-8B", True),
("meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", True),
("Qwen/Qwen3-8B", "z-lab/Qwen3-8B-DFlash-b16", False),
]


Expand Down
194 changes: 194 additions & 0 deletions tests/config/test_speculative_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace
from unittest.mock import patch

import pytest

from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.qwen3_dflash import DFlashQwen3ForCausalLM


class _DummyDraftModelConfig:
    """Minimal stand-in for the draft-side ModelConfig used by these tests.

    Exposes only the attributes that SpeculativeConfig reads: ``model``,
    ``hf_config``/``hf_text_config``, ``architectures`` and ``max_model_len``.
    """

    def __init__(self, model: str, model_type: str, architectures: list[str]):
        # hf_config and hf_text_config intentionally alias the same namespace,
        # mirroring how a plain (non-multimodal) HF config behaves.
        hf_ns = SimpleNamespace(
            model_type=model_type,
            architectures=architectures,
        )
        self.model = model
        self.hf_config = hf_ns
        self.hf_text_config = hf_ns
        self.architectures = architectures
        self.max_model_len = 4096

    def verify_with_parallel_config(self, parallel_config: ParallelConfig) -> None:
        """No-op: parallel-config validation is irrelevant for this dummy."""
        del parallel_config


def _make_target_model_config(model_type: str = "qwen"):
return SimpleNamespace(
model="dummy-target",
tokenizer="dummy-target",
tokenizer_mode="auto",
trust_remote_code=False,
allowed_local_media_path=None,
allowed_media_domains=None,
dtype="float16",
seed=0,
revision=None,
tokenizer_revision=None,
max_model_len=4096,
quantization=None,
enforce_eager=False,
max_logprobs=20,
config_format="hf",
hf_text_config=SimpleNamespace(model_type=model_type),
get_vocab_size=lambda: 1000,
)


def _model_config_factory(
model_type: str = "dummy",
architectures: list[str] | None = None,
):
def _factory(*args, **kwargs):
del args
return _DummyDraftModelConfig(
model=kwargs["model"],
model_type=model_type,
architectures=architectures or ["DummyForCausalLM"],
)

return _factory


def test_dflash_method_is_supported_in_speculative_config():
    """Explicitly requesting method="dflash" is accepted and preserved."""
    patched_model_config = patch(
        "vllm.config.speculative.ModelConfig",
        new=_model_config_factory(),
    )
    with patched_model_config:
        spec = SpeculativeConfig(
            method="dflash",
            model="z-lab/Qwen3-8B-DFlash-b16",
            num_speculative_tokens=1,
            target_model_config=_make_target_model_config(),
            target_parallel_config=ParallelConfig(),
        )
    assert spec.method == "dflash"


def test_dflash_method_auto_detects_from_model_name():
    """With no explicit method, a DFlash-named draft model selects dflash."""
    patched_model_config = patch(
        "vllm.config.speculative.ModelConfig",
        new=_model_config_factory(),
    )
    with patched_model_config:
        spec = SpeculativeConfig(
            model="z-lab/Qwen3-8B-DFlash-b16",
            num_speculative_tokens=1,
            target_model_config=_make_target_model_config(),
            target_parallel_config=ParallelConfig(),
        )
    assert spec.method == "dflash"


def test_dflash_method_auto_detects_from_draft_architecture():
    """A DFlashDraftModel architecture triggers dflash even without a name hint."""
    patched_model_config = patch(
        "vllm.config.speculative.ModelConfig",
        new=_model_config_factory(architectures=["DFlashDraftModel"]),
    )
    with patched_model_config:
        spec = SpeculativeConfig(
            model="org/draft-model-without-name-hint",
            num_speculative_tokens=1,
            target_model_config=_make_target_model_config(),
            target_parallel_config=ParallelConfig(),
        )
    assert spec.method == "dflash"


def test_dflash_target_model_validation():
    """A non-Qwen target model type is rejected for the dflash method."""
    with patch(
        "vllm.config.speculative.ModelConfig",
        new=_model_config_factory(),
    ):
        with pytest.raises(ValueError, match="dflash is only supported"):
            SpeculativeConfig(
                method="dflash",
                model="z-lab/Qwen3-8B-DFlash-b16",
                num_speculative_tokens=1,
                target_model_config=_make_target_model_config(model_type="opt"),
                target_parallel_config=ParallelConfig(),
            )


def test_dflash_hash_differs_from_non_eagle3_method():
    """dflash must contribute its own compute_hash, distinct from ngram's."""
    baseline = SpeculativeConfig(
        method="ngram",
        num_speculative_tokens=1,
    )
    with patch(
        "vllm.config.speculative.ModelConfig",
        new=_model_config_factory(),
    ):
        candidate = SpeculativeConfig(
            method="dflash",
            model="org/dflash-draft",
            num_speculative_tokens=1,
            target_model_config=_make_target_model_config(),
            target_parallel_config=ParallelConfig(),
        )
    # Distinct hashes ensure dflash is treated as graph-affecting on its own.
    assert candidate.compute_hash() != baseline.compute_hash()


def test_dflash_architecture_is_registered():
    """The DFlash draft architecture is wired into the model registry."""
    supported_archs = ModelRegistry.get_supported_archs()
    assert "DFlashDraftModel" in supported_archs
    loaded_cls = ModelRegistry._try_load_model_cls("DFlashDraftModel")
    assert loaded_cls is DFlashQwen3ForCausalLM


def test_use_eagle_returns_true_for_dflash():
    """dflash rides the eagle execution path, so use_eagle() must hold."""
    factory = _model_config_factory()
    with patch("vllm.config.speculative.ModelConfig", new=factory):
        spec = SpeculativeConfig(
            method="dflash",
            model="org/dflash-draft",
            num_speculative_tokens=1,
            target_model_config=_make_target_model_config(),
            target_parallel_config=ParallelConfig(),
        )
    assert spec.use_eagle()


def test_dflash_aux_layers_from_dflash_config():
    """layer_ids inside dflash_config take precedence for aux hidden states."""
    # __new__ bypasses __init__ so no weights or engine state are required.
    draft = DFlashQwen3ForCausalLM.__new__(DFlashQwen3ForCausalLM)
    draft.config = SimpleNamespace(
        dflash_config={"layer_ids": [31, 33, 35]},
        target_layer_count=36,
    )
    layers = draft.get_eagle3_aux_hidden_state_layers()
    assert layers == (31, 33, 35)


def test_dflash_aux_layers_fallback_to_eagle_aux_layer_ids():
    """Without dflash_config, eagle_aux_hidden_state_layer_ids is honored."""
    draft = DFlashQwen3ForCausalLM.__new__(DFlashQwen3ForCausalLM)
    draft.config = SimpleNamespace(
        eagle_aux_hidden_state_layer_ids=[7, 11, 15],
        target_layer_count=36,
    )
    layers = draft.get_eagle3_aux_hidden_state_layers()
    assert layers == (7, 11, 15)


def test_dflash_aux_layers_fallback_to_target_layer_count_default():
    """With no explicit ids, defaults derive from target_layer_count alone."""
    draft = DFlashQwen3ForCausalLM.__new__(DFlashQwen3ForCausalLM)
    draft.config = SimpleNamespace(target_layer_count=8)
    assert draft.get_eagle3_aux_hidden_state_layers() == (2, 4, 5)


def test_set_dflash_aux_layers_updates_dflash_config():
    """set_aux_hidden_state_layers persists layer ids into dflash_config."""
    draft = DFlashQwen3ForCausalLM.__new__(DFlashQwen3ForCausalLM)
    draft.config = SimpleNamespace(dflash_config={})
    draft.set_aux_hidden_state_layers((1, 3, 5))
    assert draft.config.dflash_config["layer_ids"] == [1, 3, 5]
5 changes: 5 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,11 @@ def check_available_online(


_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"DFlashDraftModel": _HfExamplesInfo(
"Qwen/Qwen3-8B",
speculative_model="z-lab/Qwen3-8B-DFlash-b16",
speculative_method="dflash",
),
"MedusaModel": _HfExamplesInfo(
"JackFram/llama-68m", speculative_model="abhigoyal/vllm-medusa-llama-68m-random"
),
Expand Down
Loading
Loading