36 commits
48ecbc7
CKPT: 86% acceptance rate
tomasruizt Jan 28, 2026
a23f4f0
Put block_tables_by_gid in CommonAttentionMetadata
tomasruizt Jan 29, 2026
1434bfe
Simplify
tomasruizt Jan 29, 2026
c3fbddf
CKPT: 100% acceptance rate
tomasruizt Jan 29, 2026
6246a8c
Pass block size in CommonAttentionMetadata
tomasruizt Jan 29, 2026
874f733
Reduce if-else statements
tomasruizt Jan 29, 2026
6ad3311
Introduce slot_mapping_buffer_by_gid
tomasruizt Jan 29, 2026
0ced07a
Invalidate slot_mapping and block_table_tensor while they are not com…
tomasruizt Jan 29, 2026
3a52f2e
Use defaultdict for buffers
tomasruizt Jan 29, 2026
01e591c
Include MetadataBuilder in CommonAttentionMetadata
tomasruizt Jan 29, 2026
575b0a6
Make code concise
tomasruizt Jan 29, 2026
1e6551f
Minimize changes
tomasruizt Jan 29, 2026
d1c7932
Minor refactor
tomasruizt Jan 30, 2026
b1bd1b5
remove partially wrong stringdocs
tomasruizt Jan 30, 2026
6308f76
Separate GID for draft and target
tomasruizt Jan 30, 2026
e3f72b1
100% acceptance rate with VLLM_BATCH_INVARIANT=1
tomasruizt Jan 30, 2026
bce1bc1
Monkeypatch VLLM_BATCH_INVARIANT=1 when needed
tomasruizt Jan 30, 2026
8770b28
Only use FLASH_ATTN backend for batch invariance
tomasruizt Jan 30, 2026
5a01a30
Introduce constant DRAFT_MODEL_PREFIX
tomasruizt Jan 30, 2026
56227f3
missing use of constant
tomasruizt Jan 30, 2026
213d67e
Simplify code
tomasruizt Jan 30, 2026
5d7d3a4
Docs change
tomasruizt Jan 30, 2026
ca41d45
Appease linter
tomasruizt Jan 30, 2026
3cc5741
Merge branch 'main' into feature/spec-decode-gemma3-2
tomasruizt Jan 30, 2026
4875d96
Appease mypy
tomasruizt Jan 30, 2026
d93347f
Make tests pass (tests/v1/spec_decode/test_eagle.py)
tomasruizt Jan 30, 2026
4fa8664
Remove block size from CommonAttentionMetadata
tomasruizt Jan 30, 2026
405cedc
Use a single CommonAttentionMetadata per KV-cache group
tomasruizt Feb 2, 2026
e99e17a
Fold usage of single CommonAttentionMetadata into group
tomasruizt Feb 3, 2026
8378a05
Introduce custom type
tomasruizt Feb 3, 2026
88497b2
Simplify set_inputs_first_pass() signature
tomasruizt Feb 3, 2026
a7720bf
Fix gpt-oss problem for enforce_eager=True
tomasruizt Feb 3, 2026
1ee0b1d
Fix for enforce_eager=False
tomasruizt Feb 3, 2026
9841080
Merge branch 'main' into feature/spec-decode-gemma3-2
tomasruizt Feb 4, 2026
3d713c7
Revert forward_context changes
tomasruizt Feb 4, 2026
966e46b
Reduce changes, simplify code
tomasruizt Feb 4, 2026
49 changes: 39 additions & 10 deletions tests/v1/e2e/test_spec_decode.py
@@ -616,6 +616,15 @@ def test_mtp_correctness(
cleanup_dist_env_and_memory()


def some_high_acceptance_metrics() -> dict:
return {
"sampling_config": greedy_sampling(),
"num_speculative_tokens": 3,
"expected_acceptance_len": 2.8 + 1,
"expected_acceptance_rate": 0.90,
}


@dataclass
class ArgsTest:
target_model: str
@@ -631,6 +640,8 @@ class ArgsTest:
gpu_memory_utilization: float = 0.5
dataset: str = "test_prompts"
num_prompts: int = 100
# Some settings only get 100% acceptance_rate with VLLM_BATCH_INVARIANT=1
use_batch_invariance: bool = False


cases = [
@@ -652,12 +663,36 @@
expected_acceptance_len=2.8 + 1,
expected_acceptance_rate=0.9,
),
# Multiple KV Cache groups
ArgsTest(
target_model="google/gemma-3-270m-it",
draft_model="google/gemma-3-270m-it",
sampling_config=greedy_sampling(),
num_speculative_tokens=3,
expected_acceptance_len=4,
expected_acceptance_rate=1,
# Without batch invariance, acceptance rate is ~86%
use_batch_invariance=True,
),
# GPT-OSS MoE models with different target/draft sizes
# Tests MoE layer resolution in speculative decoding.
ArgsTest(
target_model="openai/gpt-oss-120b",
draft_model="openai/gpt-oss-20b",
# Leave some headroom for CUDA graph capture.
gpu_memory_utilization=0.85,
**some_high_acceptance_metrics(),
),
]


@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
def test_draft_model_correctness(
args: ArgsTest, enforce_eager: bool, monkeypatch: pytest.MonkeyPatch
):
if args.use_batch_invariance:
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
assert_draft_model_correctness(args, enforce_eager)


@@ -753,6 +788,8 @@ def test_draft_model_engine_args_rejects_invalid_tp_argname():
def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
"""Compare the outputs using and not using speculative decoding.
In the greedy decoding case, the outputs must match EXACTLY."""
attention_config = {"backend": "FLASH_ATTN"} if args.use_batch_invariance else None

test_prompts: list[Messages] = get_messages(
dataset=args.dataset, n=args.num_prompts
)
@@ -773,6 +810,7 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
tensor_parallel_size=args.target_tensor_parallel_size,
enforce_eager=enforce_eager,
disable_log_stats=False, # enables get_metrics()
attention_config=attention_config,
)
# we don't check the outputs, only check the metrics
spec_llm.chat(test_prompts, args.sampling_config)
@@ -804,15 +842,6 @@ def get_messages(dataset: str, n: int) -> list[Messages]:
raise NotImplementedError(f"Dataset '{dataset}' not implemented")


def some_high_acceptance_metrics() -> dict:
return {
"sampling_config": greedy_sampling(),
"num_speculative_tokens": 3,
"expected_acceptance_len": 2.8 + 1,
"expected_acceptance_rate": 0.90,
}


def test_merge_toks_kernel():
device = "cuda"
merged_len = 5 + 2 # len(target_toks) = 5, batch_size = 2
57 changes: 41 additions & 16 deletions tests/v1/spec_decode/test_eagle.py
@@ -148,14 +148,19 @@ def test_prepare_next_token_ids():
block_size=16,
device=device,
)
kv_cache_group_id = 0
common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata}
# Mock connection of layer names to kv_cache_group_ids
proposer.attn_layer_names = ["layer0"]
proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id}

expected_valid_sampled_tokens_count = torch.tensor(
[2, 5, 0, 0], dtype=torch.int32, device=device
)

next_token_ids_from_padded, valid_sampled_tokens_count = (
proposer.prepare_next_token_ids_padded(
common_attn_metadata,
common_attn_metadata_by_gid,
sampled_token_ids_tensor,
mock_requests,
mock_input_batch,
@@ -248,10 +253,17 @@ def test_prepare_inputs():
)
proposer = _create_proposer("eagle", 1)

updated_metadata, token_indices = proposer.prepare_inputs(
common_attn_metadata, sampled_token_ids, num_draft_tokens
# Mock connection of layer names to kv_cache_group_ids
kv_cache_group_id = 0
common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata}
proposer.attn_layer_names = ["layer0"]
proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id}

updated_cm_by_gid, token_indices = proposer.prepare_inputs(
common_attn_metadata_by_gid, sampled_token_ids, num_draft_tokens
)

updated_metadata = updated_cm_by_gid[kv_cache_group_id]
assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens)
assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
assert torch.equal(token_indices, expected_token_indices)
@@ -306,16 +318,25 @@ def test_prepare_inputs_padded():

proposer = _create_proposer("eagle", num_speculative_tokens)

output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = (
# Mock connection of layer names to kv_cache_group_ids
kv_cache_group_id = 0
common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata}
proposer.attn_layer_names = ["layer0"]
proposer.layer_names_to_kv_cache_gid = {"layer0": kv_cache_group_id}

output_cm_by_gid, token_indices_to_sample, num_rejected_tokens_gpu = (
proposer.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
common_attn_metadata_by_gid,
spec_decode_metadata,
valid_sampled_tokens_count,
)
)

# Verify num_rejected_tokens_gpu is calculated correctly
expected_num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32, device=device)
assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected)

output_metadata = output_cm_by_gid[kv_cache_group_id]
assert output_metadata.max_query_len == 3
assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
@@ -566,22 +587,24 @@ def create_deterministic_logits(token_ids):

# Mock runner for attention metadata building
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][
0
].get_metadata_builder.return_value = attn_metadata_builder
attn_group = mock.MagicMock()
attn_group.get_metadata_builder.return_value = attn_metadata_builder
attn_group.layer_names = proposer.attn_layer_names
proposer.runner.attn_groups = [[attn_group]]
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder
)
kv_cache_group_id = 0
common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata}

result = proposer.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
cm_by_gid=common_attn_metadata_by_gid,
)

assert result.shape == (batch_size, num_speculative_tokens)
@@ -702,11 +725,11 @@ def create_deterministic_logits(token_ids, k: int):

# Mock runner for attention metadata building.
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder]
proposer.runner.attn_groups[0][
0
].get_metadata_builder.return_value = attn_metadata_builder
attn_group = mock.MagicMock()
attn_group.layer_names = proposer.attn_layer_names
attn_group.metadata_builders = [attn_metadata_builder]
attn_group.get_metadata_builder.return_value = attn_metadata_builder
proposer.runner.attn_groups = [[attn_group]]
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder
)
@@ -729,6 +752,8 @@ def create_deterministic_logits(token_ids, k: int):
block_size=16,
device=device,
)
kv_cache_group_id = 0
common_attn_metadata_by_gid = {kv_cache_group_id: common_attn_metadata}
sampling_metadata = mock.MagicMock()

# Propose draft tokens.
@@ -738,8 +763,8 @@
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
cm_by_gid=common_attn_metadata_by_gid,
)
assert result.shape == (batch_size, num_speculative_tokens)

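The tests above repeatedly mock the proposer's mapping from attention layer names to KV cache group IDs. A self-contained sketch of how such a mapping can be derived from a per-group layer layout (the layout below is hypothetical, not taken from the diff):

# Hypothetical KV cache group layout: group ID -> layer names in that group.
kv_cache_group_layers: dict[int, list[str]] = {
    0: ["layer0", "layer1"],  # e.g. full-attention layers
    1: ["layer2"],            # e.g. sliding-window layers
}

# Invert it into the layer_name -> kv_cache_group_id mapping the tests mock.
layer_names_to_kv_cache_gid: dict[str, int] = {
    name: gid
    for gid, names in kv_cache_group_layers.items()
    for name in names
}

assert layer_names_to_kv_cache_gid == {"layer0": 0, "layer1": 0, "layer2": 1}
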
21 changes: 9 additions & 12 deletions tests/v1/worker/test_utils.py
@@ -3,6 +3,7 @@

import torch

from vllm.v1.core.kv_cache_utils import DRAFT_MODEL_PREFIX
from vllm.v1.worker.utils import bind_kv_cache


@@ -63,8 +64,8 @@ def test_bind_kv_cache_draft_model(default_vllm_config):
layer_names = [
"model.layers.0.attn",
"model.layers.1.attn",
"draft_model.layers.0.attn",
"draft_model.layers.1.attn",
f"{DRAFT_MODEL_PREFIX}.layers.0.attn",
f"{DRAFT_MODEL_PREFIX}.layers.1.attn",
]
ctx = {
layer_name: Attention(32, 128, 0.1, prefix=layer_name)
@@ -76,17 +77,13 @@

assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"]
assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"]
assert (
ctx["draft_model.layers.0.attn"].kv_cache[0]
is kv_cache["draft_model.layers.0.attn"]
)
assert (
ctx["draft_model.layers.1.attn"].kv_cache[0]
is kv_cache["draft_model.layers.1.attn"]
)
draft_layer_0 = f"{DRAFT_MODEL_PREFIX}.layers.0.attn"
draft_layer_1 = f"{DRAFT_MODEL_PREFIX}.layers.1.attn"
assert ctx[draft_layer_0].kv_cache[0] is kv_cache[draft_layer_0]
assert ctx[draft_layer_1].kv_cache[0] is kv_cache[draft_layer_1]

# caches are ordered by layer_index, interleaving target and draft model layers
assert runner_kv_caches[0] is kv_cache["model.layers.0.attn"]
assert runner_kv_caches[1] is kv_cache["draft_model.layers.0.attn"]
assert runner_kv_caches[1] is kv_cache[draft_layer_0]
assert runner_kv_caches[2] is kv_cache["model.layers.1.attn"]
assert runner_kv_caches[3] is kv_cache["draft_model.layers.1.attn"]
assert runner_kv_caches[3] is kv_cache[draft_layer_1]
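
The ordering assertions above rely on runner_kv_caches interleaving target and draft layers by layer index. A standalone sketch of that ordering rule (the sort key here is an assumption for illustration; bind_kv_cache may compute it differently):

layers = [
    "model.layers.0.attn",
    "model.layers.1.attn",
    "draft_model.layers.0.attn",
    "draft_model.layers.1.attn",
]

def layer_index(name: str) -> int:
    # "model.layers.<i>.attn" -> <i>
    return int(name.split(".layers.")[1].split(".")[0])

# Sort by (layer_index, is_draft): target before draft within the same index.
ordered = sorted(layers, key=lambda n: (layer_index(n), n.startswith("draft_model")))
assert ordered == [
    "model.layers.0.attn",
    "draft_model.layers.0.attn",
    "model.layers.1.attn",
    "draft_model.layers.1.attn",
]
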
6 changes: 6 additions & 0 deletions vllm/v1/attention/backend.py
@@ -410,6 +410,12 @@ def unpadded(
)


# Mapping from KV cache group ID to its CommonAttentionMetadata.
# Each KV cache group may have different block_table_tensor and slot_mapping,
# while sharing other fields like query_start_loc and seq_lens.
CommonAttnMetadataByGid = dict[int, CommonAttentionMetadata]


M = TypeVar("M")


Expand Down
2 changes: 0 additions & 2 deletions vllm/v1/attention/backends/utils.py
@@ -828,7 +828,6 @@ def get_dcp_local_seq_lens(
def extend_all_queries_by_1(
common_attn_metadata: CommonAttentionMetadata,
arange: torch.Tensor,
new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
"""
Creates a new CommonAttentionMetadata with all query lengths increased by 1.
@@ -852,7 +851,6 @@
# All query lens increase by 1, so max query len increases by 1
max_query_len=cad.max_query_len + 1,
max_seq_len=cad.max_seq_len + 1,
slot_mapping=new_slot_mapping,
)
return new_cad

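The query_start_loc update implied by "all query lengths increased by 1" is cumulative: the i-th start offset grows by i, which is presumably what the arange argument supports. A quick standalone check of that arithmetic:

# If every one of B queries grows by one token, the cumulative start offsets
# shift by their own index: new_qsl[i] = qsl[i] + i.
query_start_loc = [0, 3, 5]                 # 2 requests: lengths 3 and 2
arange = list(range(len(query_start_loc)))  # [0, 1, 2]
new_qsl = [q + i for q, i in zip(query_start_loc, arange)]
assert new_qsl == [0, 4, 7]                 # lengths are now 4 and 3
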
13 changes: 10 additions & 3 deletions vllm/v1/core/kv_cache_utils.py
@@ -74,6 +74,10 @@ def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash:

logger = init_logger(__name__)

# Prefix for draft model layer names in the KV cache (e.g. "draft_model.layers.0.attn").
# Must match the prefix passed to get_model() when loading the draft model.
DRAFT_MODEL_PREFIX = "draft_model"

# The hash seed for the first block of any prefix block sequence.
#
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
@@ -1011,12 +1015,15 @@ def _get_kv_cache_groups_uniform_page_size(
Returns:
The generated KVCacheGroupSpecs
"""
# Group all layers by kv_cache_spec.
# Group all layers by kv_cache_spec AND whether it's a draft or target layer.
# E.g., 2 full attention layers and 3 sliding window attention layers,
# -> (full.0, full.1), (sw.0, sw.1, sw.2).
same_type_layers: dict[KVCacheSpec, list[str]] = defaultdict(list)
# Draft and target layers are kept in separate groups to prevent KV cache
# cross-contamination during speculative decoding.
same_type_layers: dict[tuple[KVCacheSpec, bool], list[str]] = defaultdict(list)
for layer_name, layer_spec in kv_cache_spec.items():
same_type_layers[layer_spec].append(layer_name)
is_draft_layer = layer_name.startswith(DRAFT_MODEL_PREFIX)
same_type_layers[(layer_spec, is_draft_layer)].append(layer_name)

# Split each group into smaller groups, to make the number of layers in each
# group identical. Add padding to the last group of each type if necessary.
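A runnable sketch of the draft-aware grouping above, with layer specs reduced to plain strings (the real code keys on KVCacheSpec objects):

from collections import defaultdict

DRAFT_MODEL_PREFIX = "draft_model"

kv_cache_spec = {
    "model.layers.0.attn": "full",
    "model.layers.1.attn": "sliding_window",
    f"{DRAFT_MODEL_PREFIX}.layers.0.attn": "full",
}

# Key by (spec, is_draft_layer) so draft and target layers never share a group.
same_type_layers: defaultdict[tuple[str, bool], list[str]] = defaultdict(list)
for layer_name, layer_spec in kv_cache_spec.items():
    is_draft_layer = layer_name.startswith(DRAFT_MODEL_PREFIX)
    same_type_layers[(layer_spec, is_draft_layer)].append(layer_name)

# Target and draft full-attention layers land in separate groups.
assert same_type_layers[("full", False)] == ["model.layers.0.attn"]
assert same_type_layers[("full", True)] == ["draft_model.layers.0.attn"]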