Commit 33d81dd

Speculative Decoding with Draft Model
Signed-off-by: Tomas Ruiz <[email protected]>
1 parent 94866d7 commit 33d81dd

15 files changed (+331 lines, -87 lines)

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -1,3 +1,6 @@
+# Scripts for development
+scripts/
+
 # version file generated by setuptools-scm
 /vllm/_version.py

pyproject.toml

Lines changed: 5 additions & 0 deletions

@@ -154,6 +154,11 @@ markers = [
     "skip_v1: do not run this test with v1",
     "optional: optional tests that are automatically skipped, include --optional to run them",
 ]
+# Show print statements and logs during test execution
+addopts = "-s --tb=short --log-cli-level=INFO"
+log_cli = true
+log_cli_format = "%(asctime)s [%(levelname)8s] %(name)s: %(message)s"
+log_cli_date_format = "%Y-%m-%d %H:%M:%S"
 
 [tool.ty.src]
 root = "./vllm"

tests/v1/e2e/test_spec_decode.py

Lines changed: 115 additions & 0 deletions

@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import random
+from dataclasses import dataclass
 from typing import Any, Union
 
 import pytest
@@ -13,7 +14,9 @@
 from vllm.assets.base import VLLM_S3_BUCKET_URL
 from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
+from vllm.v1.spec_decode.metrics import compute_acceptance_rate
 
 
 def get_test_prompts(mm_enabled: bool):
@@ -69,9 +72,17 @@ def get_test_prompts(mm_enabled: bool):
 
 @pytest.fixture
 def sampling_config():
+    return greedy_sampling()
+
+
+def greedy_sampling() -> SamplingParams:
     return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
 
 
+def stochastic_sampling() -> SamplingParams:
+    return SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False)
+
+
 @pytest.fixture
 def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
@@ -230,3 +241,107 @@ def test_eagle_correctness(
     del spec_llm
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()
+
+
+@dataclass
+class ArgsTest:
+    model: str
+    draft_model: str
+    sampling_config: SamplingParams
+    expected_acceptance_rate: float
+    expected_same_output_fraction: float
+    # Defaults
+    enforce_eager: bool = True
+    max_model_len: int = 1024
+    gpu_memory_utilization: float = 0.5
+
+
+cases = [
+    ArgsTest(
+        model="meta-llama/Llama-3.2-1B-Instruct",
+        draft_model="meta-llama/Llama-3.2-1B-Instruct",
+        sampling_config=greedy_sampling(),
+        expected_acceptance_rate=0.85,
+        expected_same_output_fraction=0.5,
+    ),
+    ArgsTest(
+        model="Qwen/Qwen3-1.7B",
+        draft_model="Qwen/Qwen3-0.6B",
+        sampling_config=stochastic_sampling(),
+        expected_acceptance_rate=0.9,
+        expected_same_output_fraction=0.9,
+    ),
+    ArgsTest(
+        model="Qwen/Qwen3-1.7B",
+        draft_model="Qwen/Qwen3-0.6B",
+        sampling_config=greedy_sampling(),
+        expected_acceptance_rate=1.0,
+        expected_same_output_fraction=1.0,
+    ),
+]
+
+
+@pytest.mark.parametrize("args", cases)
+def test_draft_model_correctness(args: ArgsTest,
+                                 monkeypatch: pytest.MonkeyPatch):
+    """Compare the outputs using and not using speculative decoding.
+    In the greedy decoding case, the outputs must match EXACTLY."""
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    test_prompts = get_test_prompts(mm_enabled=False)
+
+    spec_llm = LLM(
+        model=args.model,
+        speculative_config={
+            "model": args.draft_model,
+            "method": "draft_model",
+            "num_speculative_tokens": 3,
+            "max_model_len": args.max_model_len,
+            "enforce_eager": args.enforce_eager,
+        },
+        max_model_len=args.max_model_len,
+        gpu_memory_utilization=args.gpu_memory_utilization,
+        enforce_eager=args.enforce_eager,
+        disable_log_stats=False  # enables get_metrics()
+    )
+    spec_outputs = spec_llm.chat(test_prompts, args.sampling_config)
+    acceptance_rate = compute_acceptance_rate(spec_llm.get_metrics())
+    del spec_llm  # CLEANUP
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    ref_llm = LLM(
+        model=args.model,
+        max_model_len=args.max_model_len,
+        gpu_memory_utilization=args.gpu_memory_utilization,
+        enforce_eager=args.enforce_eager,
+    )
+    ref_outputs = ref_llm.chat(test_prompts, args.sampling_config)
+    del ref_llm  # CLEANUP
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    assert len(ref_outputs) > 0
+    assert len(ref_outputs) == len(spec_outputs)
+
+    assert_outputs_match(ref_outputs, spec_outputs,
+                         args.expected_same_output_fraction)
+
+    assert acceptance_rate >= args.expected_acceptance_rate
+
+
+def assert_outputs_match(ref_outputs: list[RequestOutput],
+                         spec_outputs: list[RequestOutput], fraction: float):
+    """Assert that at least "fraction" of the prompts match exactly"""
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
+
+    # Heuristic: at least a certain fraction of the outputs must match exactly.
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches >= int(fraction * len(ref_outputs))
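
For quick reference, here is a minimal standalone sketch of the same draft-model setup outside pytest. It reuses the model names and speculative_config keys from the test cases above; the prompt is illustrative and a GPU with enough memory for both models is assumed.

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-1.7B",
    speculative_config={
        "model": "Qwen/Qwen3-0.6B",   # smaller draft model
        "method": "draft_model",
        "num_speculative_tokens": 3,  # draft tokens proposed per step
        "max_model_len": 1024,
    },
    max_model_len=1024,
    gpu_memory_utilization=0.5,
    enforce_eager=True,
    disable_log_stats=False,  # keep stats enabled so get_metrics() has data
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=10))
print(outputs[0].outputs[0].text)

With greedy sampling the speculative and non-speculative outputs should match exactly, which is what the strictest test case above asserts.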

vllm/benchmarks/throughput.py

Lines changed: 35 additions & 4 deletions

@@ -31,14 +31,17 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import merge_async_iterators
+from vllm.v1.metrics.reader import Metric
+from vllm.v1.spec_decode.metrics import compute_acceptance_rate
 
 
 def run_vllm(
     requests: list[SampleRequest],
     n: int,
     engine_args: EngineArgs,
+    do_profile: bool,
     disable_detokenize: bool = False,
-) -> tuple[float, Optional[list[RequestOutput]]]:
+) -> "Results":
     from vllm import LLM, SamplingParams
     llm = LLM(**dataclasses.asdict(engine_args))
     assert all(
@@ -74,12 +77,16 @@ def run_vllm(
 
     outputs = None
     if not use_beam_search:
+        if do_profile:
+            llm.start_profile()
         start = time.perf_counter()
         outputs = llm.generate(prompts,
                                sampling_params,
                                lora_request=lora_requests,
                                use_tqdm=True)
         end = time.perf_counter()
+        if do_profile:
+            llm.stop_profile()
     else:
         assert lora_requests is None, "BeamSearch API does not support LoRA"
         prompts = [request.prompt for request in requests]
@@ -96,7 +103,8 @@ def run_vllm(
                 ignore_eos=True,
             ))
         end = time.perf_counter()
-    return end - start, outputs
+    runtime = end - start
+    return Results(runtime=runtime, metrics=llm.get_metrics(), outputs=outputs)
 
 
 def run_vllm_chat(
@@ -138,6 +146,13 @@ def run_vllm_chat(
     return end - start, outputs
 
 
+@dataclasses.dataclass
+class Results:
+    runtime: float
+    metrics: list[Metric]
+    outputs: list
+
+
 async def run_vllm_async(
     requests: list[SampleRequest],
     n: int,
@@ -496,6 +511,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
+    parser.add_argument(
+        "--print-acceptance-rate",
+        action="store_true",
+        default=False,
+        help="Print the acceptance rate of the speculative decoding model.",
+    )
     parser.add_argument("--async-engine",
                         action='store_true',
                         default=False,
@@ -543,6 +564,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
                         type=str,
                         default=None,
                         help="Split of the HF dataset.")
+    parser.add_argument("--profile",
+                        action="store_true",
+                        default=False,
+                        help="Profile the model.")
 
     # prefix repetition dataset
     prefix_repetition_group = parser.add_argument_group(
@@ -604,9 +629,12 @@ def main(args: argparse.Namespace):
                     args.disable_detokenize,
                 ))
         else:
-            elapsed_time, request_outputs = run_vllm(
+            bresults = run_vllm(
                 requests, args.n, EngineArgs.from_cli_args(args),
-                args.disable_detokenize)
+                do_profile=args.profile,
+                disable_detokenize=args.disable_detokenize)
+            elapsed_time = bresults.runtime
+            request_outputs = bresults.outputs
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -651,6 +679,9 @@ def main(args: argparse.Namespace):
           f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
     print(f"Total num prompt tokens: {total_prompt_tokens}")
    print(f"Total num output tokens: {total_output_tokens}")
+    if args.print_acceptance_rate:
+        rate = compute_acceptance_rate(bresults.metrics)
+        print(f"Acceptance rate: {rate:.2f}")
 
     # Output JSON results if specified
     if args.output_json:
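
compute_acceptance_rate is imported from vllm.v1.spec_decode.metrics but is not defined in this diff. The sketch below shows one plausible way such a helper could aggregate llm.get_metrics() output; it is an assumption, not the actual implementation, and the counter names are guesses based on vLLM's spec-decode stats.

from vllm.v1.metrics.reader import Counter, Metric


def compute_acceptance_rate(metrics: list[Metric]) -> float:
    """Fraction of proposed draft tokens that the target model accepted."""
    accepted = 0
    drafted = 0
    for metric in metrics:
        if not isinstance(metric, Counter):
            continue
        # Assumed metric names; verify against the real spec-decode counters.
        if metric.name == "vllm:spec_decode_num_accepted_tokens":
            accepted += metric.value
        elif metric.name == "vllm:spec_decode_num_draft_tokens":
            drafted += metric.value
    return accepted / drafted if drafted > 0 else 0.0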

vllm/config/__init__.py

Lines changed: 4 additions & 5 deletions

@@ -2168,6 +2168,7 @@ def __post_init__(self):
                     code_revision=self.code_revision,
                     tokenizer_revision=self.target_model_config.
                     tokenizer_revision,
+                    max_model_len=self.max_model_len,
                     spec_target_max_model_len=self.target_model_config.
                     max_model_len,
                     quantization=self.quantization,
@@ -2209,11 +2210,6 @@ def __post_init__(self):
                     )
                 else:
                     self.method = "draft_model"
-                    raise NotImplementedError(
-                        "Speculative decoding with draft model is not "
-                        "supported yet. Please consider using other "
-                        "speculative decoding methods such as ngram, medusa, "
-                        "eagle, or deepseek_mtp.")
 
             # Replace hf_config for EAGLE draft_model
             if self.method in ("eagle", "eagle3"):
@@ -2424,6 +2420,9 @@ def num_lookahead_slots(self) -> int:
     def use_eagle(self) -> bool:
         return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
 
+    def uses_draft_model(self) -> bool:
+        return self.method == "draft_model"
+
     def __repr__(self) -> str:
         method = self.method
         model = None if method == "ngram" else self.draft_model_config.model

vllm/engine/arg_utils.py

Lines changed: 1 addition & 4 deletions

@@ -1474,10 +1474,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         # V1 supports N-gram, Medusa, and Eagle speculative decoding.
         if (self.speculative_config is not None
                 and self.speculative_config.get("method") == "draft_model"):
-            raise NotImplementedError(
-                "Speculative decoding with draft model is not supported yet. "
-                "Please consider using other speculative decoding methods "
-                "such as ngram, medusa, eagle, or deepseek_mtp.")
+            return True
 
         V1_BACKENDS = [
             "FLASH_ATTN_VLLM_V1",

vllm/model_executor/model_loader/__init__.py

Lines changed: 4 additions & 2 deletions

@@ -111,12 +111,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
 
 def get_model(*,
               vllm_config: VllmConfig,
-              model_config: Optional[ModelConfig] = None) -> nn.Module:
+              model_config: Optional[ModelConfig] = None,
+              prefix: str = "") -> nn.Module:
     loader = get_model_loader(vllm_config.load_config)
     if model_config is None:
         model_config = vllm_config.model_config
     return loader.load_model(vllm_config=vllm_config,
-                             model_config=model_config)
+                             model_config=model_config,
+                             prefix=prefix)
 
 
 __all__ = [
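
The new prefix argument threads from get_model down to initialize_model so a second model's parameters can be created under a distinct name scope. Below is a hypothetical sketch of how a worker might use it to load the draft model; load_draft_model and the "draft_model." prefix are illustrative names, not part of this commit.

from torch import nn

from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.model_loader import get_model


def load_draft_model(vllm_config: VllmConfig,
                     draft_model_config: ModelConfig) -> nn.Module:
    # Namespacing the draft model's weights keeps them separate from the
    # target model's parameters in the same process.
    return get_model(vllm_config=vllm_config,
                     model_config=draft_model_config,
                     prefix="draft_model.")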

vllm/model_executor/model_loader/base_loader.py

Lines changed: 6 additions & 3 deletions

@@ -31,8 +31,10 @@ def load_weights(self, model: nn.Module,
         inplace weights loading for an already-initialized model"""
         raise NotImplementedError
 
-    def load_model(self, vllm_config: VllmConfig,
-                   model_config: ModelConfig) -> nn.Module:
+    def load_model(self,
+                   vllm_config: VllmConfig,
+                   model_config: ModelConfig,
+                   prefix: str = "") -> nn.Module:
         """Load a model with the given configurations."""
         device_config = vllm_config.device_config
         load_config = vllm_config.load_config
@@ -42,7 +44,8 @@ def load_model(self, vllm_config: VllmConfig,
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
                 model = initialize_model(vllm_config=vllm_config,
-                                          model_config=model_config)
+                                          model_config=model_config,
+                                          prefix=prefix)
 
         logger.debug("Loading weights on %s ...", load_device)
         # Quantization does not happen in `load_weights` but after it

vllm/model_executor/model_loader/gguf_loader.py

Lines changed: 6 additions & 3 deletions

@@ -123,8 +123,10 @@ def load_weights(self, model: nn.Module,
         model.load_weights(
             self._get_weights_iterator(local_model_path, gguf_weights_map))
 
-    def load_model(self, vllm_config: VllmConfig,
-                   model_config: ModelConfig) -> nn.Module:
+    def load_model(self,
+                   vllm_config: VllmConfig,
+                   model_config: ModelConfig,
+                   prefix: str = "") -> nn.Module:
         device_config = vllm_config.device_config
         local_model_path = self._prepare_weights(model_config.model)
         gguf_weights_map = self._get_gguf_weights_map(model_config)
@@ -147,7 +149,8 @@ def load_model(self, vllm_config: VllmConfig,
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
-                model = initialize_model(vllm_config=vllm_config)
+                model = initialize_model(vllm_config=vllm_config,
+                                          prefix=prefix)
         self.load_weights(model, model_config)
 
         process_weights_after_loading(model, model_config, target_device)