Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
8e7bcc9
save work
ekagra-ranjan Oct 10, 2025
bee764b
move to dynamic folder and scripts are in working condition
ekagra-ranjan Jan 4, 2026
e61145f
sync main
ekagra-ranjan Jan 4, 2026
06916e5
pre
ekagra-ranjan Jan 4, 2026
5e63d51
start stitching
ekagra-ranjan Jan 4, 2026
b683f26
pipeline works and offline script moved
ekagra-ranjan Jan 5, 2026
618b5fd
load dynamic sd config
ekagra-ranjan Jan 15, 2026
c540a75
remove offline bkp
ekagra-ranjan Jan 15, 2026
dfb2b31
remove
ekagra-ranjan Jan 15, 2026
bb65365
add runtime AL to goodput after warmup
ekagra-ranjan Jan 15, 2026
5fcf59e
Update vllm/v1/spec_decode/dynamic/manager.py
ekagra-ranjan Feb 8, 2026
409eb69
revert offline decoder to save loc diff
ekagra-ranjan Feb 8, 2026
d7a149f
refactor
ekagra-ranjan Feb 8, 2026
f222372
conflict
ekagra-ranjan Feb 8, 2026
44aed5e
conflict
ekagra-ranjan Feb 8, 2026
7ed3353
refactor
ekagra-ranjan Feb 8, 2026
116e76b
add timeout
ekagra-ranjan Feb 8, 2026
d70bd1d
fix
ekagra-ranjan Feb 8, 2026
24304d5
reduce loc in favor of #34105
ekagra-ranjan Feb 9, 2026
8fb86b1
remove test from dynamic manager main()
ekagra-ranjan Feb 9, 2026
11d43a5
remove comment and fix lint
ekagra-ranjan Feb 9, 2026
5df4999
fix mypy
ekagra-ranjan Feb 17, 2026
e76ad8e
fix mypy
ekagra-ranjan Feb 17, 2026
c1e880b
lint
ekagra-ranjan Feb 17, 2026
72d3c6f
lint
ekagra-ranjan Feb 17, 2026
63c7e17
conflict
ekagra-ranjan Mar 13, 2026
534d2f2
add AL computation to generate_config
ekagra-ranjan Mar 13, 2026
3f7196e
fix padding for async sched
ekagra-ranjan Mar 13, 2026
64fab8d
Update vllm/config/speculative.py
ekagra-ranjan Mar 13, 2026
1198aa7
fix docstring
ekagra-ranjan Mar 13, 2026
c07afd1
lint
ekagra-ranjan Mar 13, 2026
594dc0f
make DSD compat with async and padded drafter
ekagra-ranjan Mar 14, 2026
7060a8b
Merge branch 'main' of https://github.com/vllm-project/vllm into er-d…
ekagra-ranjan Mar 14, 2026
cd19750
lint
ekagra-ranjan Mar 14, 2026
4307feb
optimize DSD async scheduling by minimizing delay in propagating opti…
ekagra-ranjan Mar 17, 2026
018b4bd
dsd config path field
ekagra-ranjan Mar 17, 2026
36f5a36
refactor to simplify propose signature and update test
ekagra-ranjan Mar 17, 2026
c54ac4c
conflict
ekagra-ranjan Mar 17, 2026
046c39a
fix comma
ekagra-ranjan Mar 17, 2026
b59662d
Merge branch 'main' of https://github.com/vllm-project/vllm into er-d…
ekagra-ranjan Mar 31, 2026
917e3de
move towards DSD scheduler
ekagra-ranjan Mar 31, 2026
3e94ad0
move towards DSD scheduler
ekagra-ranjan Mar 31, 2026
8aa39fb
fix padded drafter
ekagra-ranjan Apr 1, 2026
2724097
lint
ekagra-ranjan Apr 1, 2026
e57e624
lint
ekagra-ranjan Apr 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/v1/spec_decode/test_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,7 @@ def create_deterministic_logits(token_ids):
proposer.draft_attn_groups = [mock_attn_group]

result = proposer.propose(
num_speculative_tokens=num_speculative_tokens,
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
Expand Down Expand Up @@ -1150,6 +1151,7 @@ def create_deterministic_logits(token_ids, k: int):

# Propose draft tokens.
result = proposer.propose(
num_speculative_tokens=num_speculative_tokens,
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
Expand Down
2 changes: 2 additions & 0 deletions tests/v1/spec_decode/test_extract_hidden_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def test_propose():

# Call propose
draft_tokens = proposer.propose(
num_speculative_tokens=1,
sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata,
Expand Down Expand Up @@ -324,6 +325,7 @@ def test_propose_different_layer_counts(num_hidden_layers):
).unsqueeze(-1)

draft_tokens = proposer.propose(
num_speculative_tokens=1,
sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata,
Expand Down
1 change: 1 addition & 0 deletions tests/v1/spec_decode/test_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset):

# Run propose
result = proposer.propose(
num_speculative_tokens=num_speculative_tokens,
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
Expand Down
10 changes: 10 additions & 0 deletions tests/v1/spec_decode/test_ngram.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# No match.
token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
num_speculative_tokens=2,
sampled_token_ids=[[0]],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
Expand All @@ -90,6 +91,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# No match for 4-gram.
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
num_speculative_tokens=2,
sampled_token_ids=[[0]],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
Expand All @@ -99,6 +101,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# No match for 4-gram but match for 3-gram.
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
num_speculative_tokens=2,
sampled_token_ids=[[0]],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
Expand All @@ -109,6 +112,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# In this case, the proposer should return the 4-gram match.
token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
num_speculative_tokens=2,
sampled_token_ids=[[0]],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
Expand All @@ -118,6 +122,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# Match for 2-gram and 3-gram, but not 4-gram.
token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
num_speculative_tokens=2,
sampled_token_ids=[[0]],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
Expand All @@ -127,6 +132,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# Multiple 3-gram matched, but always pick the first one.
token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
num_speculative_tokens=2,
sampled_token_ids=[[0]],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
Expand All @@ -136,6 +142,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# check empty input
token_ids_cpu = np.array([[]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
num_speculative_tokens=2,
sampled_token_ids=[[0]],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
Expand All @@ -147,6 +154,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# second request has 3 tokens and no match. Padded with -1 for max len 5
token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
num_speculative_tokens=2,
sampled_token_ids=[[0], [1]],
num_tokens_no_spec=np.array([5, 3]),
token_ids_cpu=token_ids_cpu,
Expand All @@ -166,6 +174,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
num_tokens_no_spec = np.array([5, 3, 5], dtype=np.int32)
sampled_token_ids = [[2], [], [8]] # Empty list for request 1 simulates prefill
result = proposer.propose(
num_speculative_tokens=2,
sampled_token_ids=sampled_token_ids,
num_tokens_no_spec=num_tokens_no_spec,
token_ids_cpu=token_ids_cpu,
Expand Down Expand Up @@ -195,6 +204,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
input_2[:3] = [4, 5, 6]
token_ids_cpu = np.array([input_1, input_2])
result = ngram_proposer.propose(
num_speculative_tokens=2,
sampled_token_ids=[[0], [1]],
num_tokens_no_spec=np.array([len(input_1), 3]),
token_ids_cpu=token_ids_cpu,
Expand Down
51 changes: 51 additions & 0 deletions vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,41 @@
RejectionSampleMethod = Literal["strict", "probabilistic", "synthetic"]


@config
class DynamicSpeculativeConfig:
    """Configuration for dynamic speculative decoding (DSD).

    Holds profiling statistics (inter-token latency per batch size and draft
    length, plus per-position acceptance rates) that let the scheduler choose
    an optimal number of draft tokens for the current batch size at runtime.
    """

    is_online: bool = False
    """Whether the statistics are updated online or not during inference."""

    # NOTE(review): when this config is populated from a JSON file via
    # json.load, the outer and inner mapping keys arrive as `str`, not `int`
    # (JSON object keys are always strings) — confirm consumers coerce them.
    batch_stats: dict[int, dict[int, float]] | None = None
    """
    Batch statistics for different batch sizes and number of drafts.
    The structure is as follows:
    {
        batch_size: {
            num_drafts: itl (i.e., inter token latency in ms)
        }
    }

    e.g.,
    {
        1: { 0: 6.87, 3: 9.41, 5: 10.8},
        4: { 0: 7.3, 3: 9.95, 5: 11.59},
    }

    where bs 1 at K=3 has itl 9.41ms. K=0 means no speculative decoding.
    """

    max_num_speculative_tokens: int | None = None
    """Maximum number of speculative tokens supported in the statistics."""

    acceptance_rate_per_pos: list[float] | None = None
    """Acceptance rate per position on an offline dataset."""


@config
class SpeculativeConfig:
"""Configuration for speculative decoding."""
Expand Down Expand Up @@ -150,6 +185,12 @@ class SpeculativeConfig:
target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
"""The parallel configuration for the target model."""

# dynamic speculative decoding control
dynamic_config_path: str | None = None
"""Path to config file for dynamic speculative decoding, if provided."""
dynamic_config: SkipValidation[DynamicSpeculativeConfig] | None = None
"""Loaded dynamic speculative config, populated from dynamic_config_path."""

# params generated in the post-init stage
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the draft model initialized internal."""
Expand Down Expand Up @@ -628,6 +669,16 @@ def __post_init__(self):
self.target_parallel_config, self.draft_tensor_parallel_size
)
)

# load DynamicSpeculativeConfig: maybe use get_hf_file_to_dict() later
if self.dynamic_config_path is not None:
import json

with open(self.dynamic_config_path) as f:
data = json.load(f)
Comment thread
ekagra-ranjan marked this conversation as resolved.

self.dynamic_config = DynamicSpeculativeConfig(**data)

return self

def _validate_suffix_decoding(self):
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/core/sched/async_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def __init__(self, *args, **kwargs) -> None:
def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
super()._update_after_schedule(scheduler_output)
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
# Use the latest num of scheduled draft tokens in next step as placeholder.
self._spec_token_placeholders = [
-1
] * scheduler_output.num_spec_tokens_to_schedule
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
if request.is_prefill_chunk:
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/core/sched/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@ class SchedulerOutput:
# preventing stale NaN/data from corrupting attention or SSM computation.
new_block_ids_to_zero: list[int] | None = None

# Dynamic speculative decoding: optimal K chosen by scheduler.
# Number of spec tokens to schedule for the next step.
num_spec_tokens_to_schedule: int = 0

@classmethod
def make_empty(cls) -> "SchedulerOutput":
return cls(
Expand Down
23 changes: 21 additions & 2 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus, StreamingUpdate
from vllm.v1.spec_decode.dynamic.manager import DynamicSpeculativeDecodingManager
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
Expand Down Expand Up @@ -213,8 +214,15 @@ def __init__(
speculative_config = vllm_config.speculative_config
self.use_eagle = False
self.num_spec_tokens = self.num_lookahead_tokens = 0
self.dynamic_sd_manager: DynamicSpeculativeDecodingManager | None = None
if speculative_config:
self.num_spec_tokens = speculative_config.num_speculative_tokens
if speculative_config.dynamic_config:
self.dynamic_sd_manager = DynamicSpeculativeDecodingManager(
speculative_config.dynamic_config,
self.scheduler_config.max_num_seqs,
self.num_spec_tokens,
)
if speculative_config.use_eagle():
self.use_eagle = True
self.num_lookahead_tokens = self.num_spec_tokens
Expand Down Expand Up @@ -904,6 +912,13 @@ def schedule(self) -> SchedulerOutput:
else None
)

# Dynamic speculative decoding: compute optimal K
num_spec_tokens_to_schedule = self.num_spec_tokens
if self.dynamic_sd_manager is not None and len(num_scheduled_tokens) > 0:
num_spec_tokens_to_schedule = self.dynamic_sd_manager.step(
len(num_scheduled_tokens)
)

scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
Expand All @@ -920,6 +935,7 @@ def schedule(self) -> SchedulerOutput:
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
new_block_ids_to_zero=new_block_ids_to_zero,
num_spec_tokens_to_schedule=num_spec_tokens_to_schedule,
)

# NOTE(Kuntai): this function is designed for multiple purposes:
Expand Down Expand Up @@ -1975,12 +1991,15 @@ def make_spec_decoding_stats(
num_invalid_spec_tokens: dict[str, int] | None,
request_id: str,
) -> SpecDecodingStats | None:
if num_invalid_spec_tokens:
num_draft_tokens -= num_invalid_spec_tokens.get(request_id, 0)
if self.dynamic_sd_manager is not None and num_draft_tokens:
self.dynamic_sd_manager.observe_draft(num_draft_tokens, num_accepted_tokens)

if not self.log_stats or not num_draft_tokens:
return None
if spec_decoding_stats is None:
spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens)
if num_invalid_spec_tokens:
num_draft_tokens -= num_invalid_spec_tokens.get(request_id, 0)
spec_decoding_stats.observe_draft(
num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens
)
Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/spec_decode/dynamic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
Loading
Loading