Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
8e7bcc9
save work
ekagra-ranjan Oct 10, 2025
bee764b
move to dynamic folder and scripts are in working condition
ekagra-ranjan Jan 4, 2026
e61145f
sync main
ekagra-ranjan Jan 4, 2026
06916e5
pre
ekagra-ranjan Jan 4, 2026
5e63d51
start stiching
ekagra-ranjan Jan 4, 2026
b683f26
pipeline works and offline script moved
ekagra-ranjan Jan 5, 2026
618b5fd
load dynamic sd config
ekagra-ranjan Jan 15, 2026
c540a75
remove offline bkp
ekagra-ranjan Jan 15, 2026
dfb2b31
remove
ekagra-ranjan Jan 15, 2026
bb65365
add runtime AL to goodput after warmup
ekagra-ranjan Jan 15, 2026
5fcf59e
Update vllm/v1/spec_decode/dynamic/manager.py
ekagra-ranjan Feb 8, 2026
409eb69
revert offline decoder to save loc diff
ekagra-ranjan Feb 8, 2026
d7a149f
refactor
ekagra-ranjan Feb 8, 2026
f222372
conflict
ekagra-ranjan Feb 8, 2026
44aed5e
conflict
ekagra-ranjan Feb 8, 2026
7ed3353
refactor
ekagra-ranjan Feb 8, 2026
116e76b
add timeout
ekagra-ranjan Feb 8, 2026
d70bd1d
fix
ekagra-ranjan Feb 8, 2026
24304d5
reduce loc in favor of #34105
ekagra-ranjan Feb 9, 2026
8fb86b1
remove test from dynamic manager main()
ekagra-ranjan Feb 9, 2026
11d43a5
remove comment and fix lint
ekagra-ranjan Feb 9, 2026
5df4999
fix mypy
ekagra-ranjan Feb 17, 2026
e76ad8e
fix mypy
ekagra-ranjan Feb 17, 2026
c1e880b
lint
ekagra-ranjan Feb 17, 2026
72d3c6f
lint
ekagra-ranjan Feb 17, 2026
63c7e17
conflict
ekagra-ranjan Mar 13, 2026
534d2f2
add AL computation to generate_config
ekagra-ranjan Mar 13, 2026
3f7196e
fix padding for async sched
ekagra-ranjan Mar 13, 2026
64fab8d
Update vllm/config/speculative.py
ekagra-ranjan Mar 13, 2026
1198aa7
fix docstring
ekagra-ranjan Mar 13, 2026
c07afd1
lint
ekagra-ranjan Mar 13, 2026
594dc0f
make DSD compat with async and padded drafter
ekagra-ranjan Mar 14, 2026
7060a8b
Merge branch 'main' of https://github.com/vllm-project/vllm into er-d…
ekagra-ranjan Mar 14, 2026
cd19750
lint
ekagra-ranjan Mar 14, 2026
4307feb
optimize DSD async scheduling by minimizing delay in propagating opti…
ekagra-ranjan Mar 17, 2026
018b4bd
dsd config path field
ekagra-ranjan Mar 17, 2026
36f5a36
refactor to simplify propose signature and update test
ekagra-ranjan Mar 17, 2026
c54ac4c
conflict
ekagra-ranjan Mar 17, 2026
046c39a
fix comma
ekagra-ranjan Mar 17, 2026
b59662d
Merge branch 'main' of https://github.com/vllm-project/vllm into er-d…
ekagra-ranjan Mar 31, 2026
917e3de
move towards DSD scheduler
ekagra-ranjan Mar 31, 2026
3e94ad0
move towards DSD scheduler
ekagra-ranjan Mar 31, 2026
8aa39fb
fix padded drafter
ekagra-ranjan Apr 1, 2026
2724097
lint
ekagra-ranjan Apr 1, 2026
e57e624
lint
ekagra-ranjan Apr 1, 2026
d1601b3
conflict
ekagra-ranjan May 9, 2026
abb3a0b
works with offline profiler
ekagra-ranjan May 14, 2026
f7aa05c
remove offline profiler and use K manually
ekagra-ranjan May 14, 2026
51d6eaf
conflict
ekagra-ranjan May 14, 2026
8cdb086
lint
ekagra-ranjan May 14, 2026
8331255
lint and pytest assert
ekagra-ranjan May 19, 2026
a1e2185
remove config abstraction
ekagra-ranjan May 20, 2026
6954f68
revert noop
ekagra-ranjan Jun 1, 2026
3c7aef8
remove DSD manager and update DynamicSDSchedule data formet
ekagra-ranjan Jun 1, 2026
9fa1e9d
Merge branch 'main' of https://github.com/vllm-project/vllm into er-d…
ekagra-ranjan Jun 1, 2026
815fd9c
lint
ekagra-ranjan Jun 1, 2026
503935f
override DSD to piecewise
ekagra-ranjan Jun 2, 2026
b35f7c2
Merge branch 'main' of https://github.com/vllm-project/vllm into er-d…
ekagra-ranjan Jun 11, 2026
5484d2c
disable mrv2 for dsd
ekagra-ranjan Jun 11, 2026
6fa1d52
add doc
ekagra-ranjan Jun 12, 2026
4b9ae13
add doc
ekagra-ranjan Jun 12, 2026
417ddc4
add doc
ekagra-ranjan Jun 12, 2026
065e846
conflict
ekagra-ranjan Jun 12, 2026
5ea4979
@benchislett fix typo in doc
benchislett Jun 12, 2026
f4500b9
Merge branch 'main' into er-dynami-sd
ekagra-ranjan Jun 12, 2026
3ce1639
Merge branch 'main' into er-dynami-sd
ekagra-ranjan Jun 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 2 additions & 232 deletions examples/offline_inference/spec_decode.py
Original file line number Diff line number Diff line change
@@ -1,234 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
from vllm.inputs import TokensPrompt
from vllm.v1.metrics.reader import Counter, Vector

try:
from vllm.utils.argparse_utils import FlexibleArgumentParser
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser


QUESTION = "What is the content of each image?"
IMAGE_URLS = [
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg",
]


def get_custom_mm_prompts(num_prompts):
    """Build `num_prompts` single-turn chat conversations over the sample images.

    Every conversation consists of one user message whose content is an
    image URL taken from `IMAGE_URLS` followed by the text `QUESTION`. When
    more prompts are requested than there are images, the image list is
    repeated until the request can be satisfied.
    """
    contents = [
        [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": QUESTION},
        ]
        for image_url in IMAGE_URLS
    ]

    pool_size = len(IMAGE_URLS)
    if num_prompts > pool_size:
        # Over-provision by one full repetition; the slice below trims it.
        contents = contents * (num_prompts // pool_size + 1)

    conversations = []
    for content in contents[:num_prompts]:
        conversations.append([{"role": "user", "content": content}])
    return conversations


def parse_args():
    """Parse the command-line options of the speculative decoding example.

    The parser is first handed to `add_dataset_parser`, then extended with
    the options below (speculative method, engine settings, sampling
    settings, output control).
    """
    cli = FlexibleArgumentParser()
    add_dataset_parser(cli)
    add = cli.add_argument

    # Speculative decoding method and its knobs.
    add("--test", action="store_true")
    add(
        "--method",
        type=str,
        default="eagle",
        choices=["ngram", "eagle", "eagle3", "mtp"],
    )
    add("--num-spec-tokens", type=int, default=2)
    add("--prompt-lookup-max", type=int, default=5)
    add("--prompt-lookup-min", type=int, default=2)

    # Engine settings.
    add("--tp", type=int, default=1)
    add("--enforce-eager", action="store_true")
    add("--enable-chunked-prefill", action="store_true")
    add("--max-model-len", type=int, default=16384)

    # Sampling settings and output control.
    add("--temp", type=float, default=0)
    add("--top-p", type=float, default=1.0)
    add("--top-k", type=int, default=-1)
    add("--print-output", action="store_true")
    add("--output-len", type=int, default=256)

    # Model locations and prompt source.
    add("--model-dir", type=str, default=None)
    add("--eagle-dir", type=str, default=None)
    add("--custom-mm-prompts", action="store_true")

    parsed = cli.parse_args()
    return parsed


def main(args):
    """Run offline generation with speculative decoding and report acceptance.

    Args:
        args: Namespace produced by `parse_args()`. It is mutated in place:
            `endpoint_type` and `custom_skip_chat_template` are set before
            the dataset samples are requested.

    Returns:
        The mean acceptance length, computed as
        `1 + num_accepted_tokens / num_drafts`, or `1` when no drafts were
        recorded.

    Raises:
        ValueError: If `--custom-mm-prompts` is used without `--model-dir`,
            or if `args.method` is not one of the supported methods.
    """
    args.endpoint_type = "openai-chat"

    model_dir = args.model_dir
    if model_dir is None:
        if args.custom_mm_prompts:
            # Separators added: the original adjacent literals ran together.
            raise ValueError(
                "custom_mm_prompts requires mm based models; "
                "default llama3.1-8b-instruct is not mm based; "
                "please specify model_dir to give a mm based model"
            )
        model_dir = "meta-llama/Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    args.custom_skip_chat_template = True

    if not args.custom_mm_prompts:
        prompts = get_samples(args, tokenizer)
        # add_special_tokens is False to avoid adding bos twice
        # when using chat templates
        prompt_ids = [
            tokenizer.encode(prompt.prompt, add_special_tokens=False)
            for prompt in prompts
        ]
    else:
        prompts = get_custom_mm_prompts(args.num_prompts)

    if args.method in ("eagle", "eagle3"):
        eagle_dir = args.eagle_dir
        if eagle_dir is None:
            # Fall back to the published draft model matching the method.
            if args.method == "eagle":
                eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
            else:
                eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
        speculative_config = {
            "method": args.method,
            "model": eagle_dir,
            "num_speculative_tokens": args.num_spec_tokens,
        }
    elif args.method == "ngram":
        speculative_config = {
            "method": "ngram",
            "num_speculative_tokens": args.num_spec_tokens,
            "prompt_lookup_max": args.prompt_lookup_max,
            "prompt_lookup_min": args.prompt_lookup_min,
        }
    elif args.method == "mtp":
        speculative_config = {
            "method": "mtp",
            "num_speculative_tokens": args.num_spec_tokens,
        }
    else:
        raise ValueError(f"unknown method: {args.method}")

    llm = LLM(
        model=model_dir,
        trust_remote_code=True,
        tensor_parallel_size=args.tp,
        enable_chunked_prefill=args.enable_chunked_prefill,
        enforce_eager=args.enforce_eager,
        gpu_memory_utilization=0.9,
        speculative_config=speculative_config,
        disable_log_stats=False,
        max_model_len=args.max_model_len,
        limit_mm_per_prompt={"image": 5},
        disable_chunked_mm_input=True,
    )

    # --top-p / --top-k were parsed but previously never forwarded, so the
    # flags had no effect; their CLI defaults (1.0 / -1) leave sampling
    # unrestricted, which matches the earlier behavior.
    sampling_params = SamplingParams(
        temperature=args.temp,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.output_len,
    )
    if not args.custom_mm_prompts:
        outputs = llm.generate(
            [TokensPrompt(prompt_token_ids=x) for x in prompt_ids],
            sampling_params=sampling_params,
        )
    else:
        outputs = llm.chat(prompts, sampling_params=sampling_params)

    # print the generated text
    if args.print_output:
        for output in outputs:
            print("-" * 50)
            print(f"prompt: {output.prompt}")
            print(f"generated text: {output.outputs[0].text}")
            print("-" * 50)

    metrics = llm.get_metrics()

    total_num_output_tokens = sum(
        len(output.outputs[0].token_ids) for output in outputs
    )
    num_drafts = 0
    num_draft_tokens = 0
    num_accepted_tokens = 0
    acceptance_counts = [0] * args.num_spec_tokens
    # Accumulate with += because the same metric name may appear more than
    # once in the returned list.
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_draft_tokens":
            assert isinstance(metric, Counter)
            num_draft_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
            assert isinstance(metric, Counter)
            num_accepted_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos, accepted in enumerate(metric.values):
                acceptance_counts[pos] += accepted

    print("-" * 50)
    print(f"total_num_output_tokens: {total_num_output_tokens}")
    print(f"num_drafts: {num_drafts}")
    print(f"num_draft_tokens: {num_draft_tokens}")
    print(f"num_accepted_tokens: {num_accepted_tokens}")
    acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
    print(f"mean acceptance length: {acceptance_length:.2f}")
    print("-" * 50)

    # print acceptance at each token position
    for i, accepted_at_pos in enumerate(acceptance_counts):
        acceptance_rate = accepted_at_pos / num_drafts if num_drafts > 0 else 0
        print(f"acceptance at token {i}: {acceptance_rate:.2f}")

    return acceptance_length

from vllm.v1.spec_decode.offline import entrypoint as spec_decode_main

# Script entry point: parse the CLI arguments, run main(), and, when --test is
# given, validate the run configuration and the measured acceptance length.
# NOTE(review): this view appears to interleave removed and added diff lines
# and has lost its indentation; the trailing spec_decode_main() call looks like
# the replacement entry point -- confirm against the actual file.
if __name__ == "__main__":
args = parse_args()
acceptance_length = main(args)

if args.test:
# takes ~30s to run on 1xH100
# The asserts below pin the exact configuration before the comparison;
# presumably the reference acceptance lengths only hold for it -- confirm.
assert args.method in ["eagle", "eagle3"]
assert args.tp == 1
assert args.num_spec_tokens == 3
assert args.dataset_name == "hf"
assert args.dataset_path == "philschmid/mt-bench"
assert args.num_prompts == 80
assert args.temp == 0
assert args.top_p == 1.0
assert args.top_k == -1
assert args.enable_chunked_prefill

# check acceptance length is within 2% of expected value
rtol = 0.02
# One reference value for "eagle", another for "eagle3".
expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811

# Two-sided relative tolerance check around the reference value.
assert (
acceptance_length <= (1 + rtol) * expected_acceptance_length
and acceptance_length >= (1 - rtol) * expected_acceptance_length
), (
f"acceptance_length {acceptance_length} is not "
f"within {rtol * 100}% of {expected_acceptance_length}"
)

print(
f"Test passed! Expected AL: "
f"{expected_acceptance_length}, got {acceptance_length}"
)
spec_decode_main()
54 changes: 53 additions & 1 deletion vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ast
from typing import TYPE_CHECKING, Any, Literal, get_args

from pydantic import Field, SkipValidation, model_validator
from pydantic import Field, SkipValidation, model_validator, BaseModel
from pydantic.dataclasses import dataclass
from typing_extensions import Self

Expand Down Expand Up @@ -49,6 +49,43 @@
]


# @dataclass
class DynamicSpeculativeConfig(BaseModel):
    """Statistics used to drive dynamic speculative decoding.

    Holds inter-token-latency measurements per batch size and number of
    drafts, together with per-position acceptance rates. All statistics
    fields are optional and default to `None` when not supplied.
    """

    is_online: bool = False
    """Whether the statistics are updated online or not during inference."""

    batch_stats: dict[int, dict[int, float]] | None = None
    """
    Batch statistics for different batch sizes and number of drafts.
    The structure is as follows:
    {
        batch_size: {
            num_drafts: itl (i.e., inter token latency in ms)
        }
    }

    e.g.,
    {
        1: { 0: 6.87, 3: 9.41, 5: 10.8},
        4: { 0: 7.3, 3: 9.95, 5: 11.59},
    }

    where bs 1 at K=3 has itl 9.41ms. K=0 means no speculative decoding.
    """

    max_num_speculative_tokens: int | None = None
    """Maximum number of speculative tokens supported in the statistics."""

    acceptance_rate_per_pos: list[float] | None = None
    """Acceptance rate per position on an offline dataset."""
Comment thread
ekagra-ranjan marked this conversation as resolved.
Outdated


@config
@dataclass
class SpeculativeConfig:
Expand Down Expand Up @@ -119,6 +156,10 @@ class SpeculativeConfig:
target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
"""The parallel configuration for the target model."""

# dynamic speculative decoding control
"""Path to config file for dynamic speculative decoding, if provided."""
dynamic_config_path: str | None = None

# params generated in the post-init stage
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the draft model initialized internal."""
Expand Down Expand Up @@ -462,6 +503,17 @@ def __post_init__(self):
self.target_parallel_config, self.draft_tensor_parallel_size
)
)

Comment thread
ekagra-ranjan marked this conversation as resolved.
Outdated
# load DynamicSpeculativeConfig: maybe use get_hf_file_to_dict() later
if self.dynamic_config_path is not None:
import json
with open(self.dynamic_config_path) as f:
data = json.load(f)
Comment thread
ekagra-ranjan marked this conversation as resolved.
Outdated

self.dynamic_config = DynamicSpeculativeConfig.model_validate(data)
else:
self.dynamic_config = None

return self

def _validate_suffix_decoding(self):
Expand Down
Loading