
Drafter Supports Multiple KVCache Groups#33318

Open
tomasruizt wants to merge 9 commits into vllm-project:main from tomasruizt:feature/spec-decode-gemma3-2

Conversation

Contributor

@tomasruizt tomasruizt commented Jan 29, 2026

Summary

This PR enables models with multiple KV-cache groups to be used as drafters in speculative decoding. Previously, the speculative decoding infrastructure assumed a single KV-cache group, which prevented the use of architectures like Gemma3 and GPT-OSS MoE models as drafters.

Key changes:

  • Refactored CommonAttentionMetadata handling to support a dictionary of metadata per KV-cache group ID (CommonAttnMetadataByGid)
  • Added per-group slot-mapping buffers for draft model inference
  • Introduced layer_names_to_kv_cache_gid mapping to correctly route attention layers to their corresponding KV-cache groups

Fixes #33133

Try it: GPT-OSS 120b + 20b

```shell
vllm serve openai/gpt-oss-120b \
  --no-enable-prefix-caching \
  --speculative-config '{"method": "draft_model", "model": "openai/gpt-oss-20b", "num_speculative_tokens": 3}'
```

Try it: Gemma-3 1b + 270m

```shell
vllm serve google/gemma-3-1b-it \
  --no-enable-prefix-caching \
  --speculative-config '{"method": "draft_model", "model": "google/gemma-3-270m-it", "num_speculative_tokens": 3}'
```

New Test Cases

Two new end-to-end test cases validate the feature:

  1. Gemma3 (270m): Tests a model architecture with multiple KV-cache groups (different head configurations across layers). Achieves 100% acceptance rate with VLLM_BATCH_INVARIANT=1.

  2. GPT-OSS MoE (120b/20b): Tests MoE layer resolution in speculative decoding with different target/draft model sizes. Initially, this combination exhibited low acceptance rates due to the cold-start MoE optimization interfering with speculative decoding. This was resolved in [torch.compile] Don't do the fast moe cold start optimization if there is speculative decoding #33624, which disables that optimization when speculative decoding is active. This test case ensures MoE models continue to work correctly with speculative decoding.

Test Plan

  • Existing unit tests pass (tests/v1/spec_decode/test_eagle.py)
  • New e2e tests for Gemma3 and GPT-OSS pass (tests/v1/e2e/test_spec_decode.py)

Benchmark Summary

1. Correctness: Acceptance Rates

Multi-KV cache drafters produce correct outputs, confirmed by healthy acceptance rates (GPU=H200, dataset=mt-bench, 50 prompts, K=3, temp=0.0):

| Configuration | Acc. Length | AR @ Pos 0 | AR @ Pos 1 | AR @ Pos 2 | Overall AR | Log |
| --- | --- | --- | --- | --- | --- | --- |
| gpt-oss-120b + gpt-oss-20b | 2.88 | 77.5% | 61.9% | 49.0% | 62.78% | log |
| gemma-3-1b + gemma-3-270m | 2.52 | 67.7% | 47.8% | 36.1% | 50.54% | log |

Note: On the main branch, these configurations fail with "All drafting layers should belong to the same kv cache group".
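As a sanity check on the table's internal consistency (my own arithmetic, not stated in the PR): the acceptance length matches 1 plus the sum of the per-position acceptance rates, and the overall AR is their mean:

```python
# Sanity check (my interpretation of the table, not stated in the PR):
# acceptance length ~= 1 + sum(per-position ARs); overall AR ~= their mean.
rows = {
    "gpt-oss-120b + gpt-oss-20b": (0.775, 0.619, 0.490),
    "gemma-3-1b + gemma-3-270m": (0.677, 0.478, 0.361),
}
for name, ars in rows.items():
    acc_len = 1 + sum(ars)         # 2.88 and 2.52, matching the table
    overall = sum(ars) / len(ars)  # 62.8% and 50.5%, matching the table
    print(f"{name}: acc_len={acc_len:.2f}, overall={overall:.1%}")
```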

2. Regression Tests: No Degradation

| Method | Configuration | Feature Branch | Main Branch | Diff | Logs |
| --- | --- | --- | --- | --- | --- |
| EAGLE3 | Llama-3.1-8B + EAGLE3 | 201.71 tok/s | 201.71 tok/s | 0% | feature, main |
| draft_model | Qwen3-32B + Qwen3-1.7B | 75.23 tok/s (avg) | 75.13 tok/s (avg) | 0% | feature-1, feature-2, main-1, main-2 |

3. Performance: Piecewise CUDA Graph Limitation

New multi-KV cache workloads are slower than expected due to piecewise CUDA graphs (#33341):

| Configuration | Standalone | SD | Slowdown | Logs |
| --- | --- | --- | --- | --- |
| gpt-oss-120b + gpt-oss-20b | 200.69 tok/s | 23.07 tok/s | 8.7x | baseline, SD |
| gemma-3-1b + gemma-3-270m | 204.27 tok/s | 92.19 tok/s | 2.2x | baseline, SD |

Evidence: Standalone models dispatch to cudaGraphLaunch (full CUDA graphs), but during speculative decoding the draft model doesn't use cudaGraphLaunch and instead dispatches many small ops. This results in a massive slowdown (~3ms standalone vs. ~95ms in SD, with the PyTorch profiler on). This slowdown is NOT caused by the changes in this PR. The PyTorch profiles are in this link.

Example Profiling Command

```shell
# Start server with profiling enabled
vllm serve openai/gpt-oss-120b \
  --no-enable-prefix-caching \
  --speculative-config '{"method": "draft_model", "model": "openai/gpt-oss-20b", "num_speculative_tokens": 3}'

# Run benchmark with profiling
vllm bench serve \
  --model openai/gpt-oss-120b \
  --dataset-name hf \
  --dataset-path philschmid/mt-bench \
  --num-prompts 1 \
  --output-len 10 \
  --profile \
  --endpoint /v1/completions
```

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for speculative decoding on models with multiple KV cache groups. The changes are extensive, primarily affecting vllm/v1/spec_decode/eagle.py, vllm/v1/spec_decode/draft_model.py, and vllm/v1/worker/gpu_model_runner.py. The implementation correctly computes and utilizes per-group attention metadata and slot mappings. A new test case for a multi-KV-group model (gemma-3-270m-it) has also been added, which is a valuable addition. I have identified one potential issue in vllm/v1/spec_decode/eagle.py concerning a fallback mechanism for slot mappings that could conceal bugs and result in incorrect behavior. My feedback includes a suggestion for a more robust implementation using assertions to ensure correctness.

@tomasruizt tomasruizt changed the title from "CKPT: 86% acceptance rate" to "Drafter Supports Multiple KVCache Groups" on Jan 29, 2026
@tomasruizt
Contributor Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for multiple KV cache groups in the drafter for speculative decoding. This is a significant and well-executed refactoring that enables speculative decoding for models with heterogeneous layer configurations, such as different attention mechanisms.

The changes primarily involve:

  • Extending CommonAttentionMetadata to carry per-group information like block tables, slot mappings, and block sizes.
  • Modifying the speculative decoding proposers (DraftModelProposer and EagleProposer) to iterate over KV cache groups and build attention metadata for each group individually.
  • Updating gpu_model_runner to populate this new per-group information.
  • Removing the previous assumption that all draft model layers belong to the same KV cache group.

The implementation appears robust and consistent across the modified files. A new end-to-end test case for a model with multiple KV cache groups (gemma-3-270m-it) has been added, which is great for ensuring correctness.

I have reviewed the changes and found no critical or high-severity issues. The code quality is good, and the changes are logical and well-contained.

@tomasruizt
Contributor Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for multiple KV-cache groups in the drafter for speculative decoding, a significant enhancement that enables the use of models with mixed attention types as drafters. The changes are well-structured and consistently applied across core attention logic, speculative decoding proposers, and tests. A key change is the introduction of per-group KV cache information management, which is crucial for correctness. The addition of a new test case using google/gemma-3-270m-it is a great way to validate this new capability. Overall, the implementation appears solid. I have one high-severity comment regarding a potential inconsistency that could lead to issues with CUDA graph execution.

@mergify

mergify bot commented Feb 4, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tomasruizt.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 4, 2026
@tomasruizt tomasruizt force-pushed the feature/spec-decode-gemma3-2 branch from 4d48715 to 966e46b on February 4, 2026 at 12:51
@mergify mergify bot removed the needs-rebase label Feb 4, 2026
@tomasruizt
Contributor Author

tomasruizt commented Feb 4, 2026

While trying to resolve the DCO issue in CI, I seem to have changed the authorship of the commits from the main merge, which added a bunch of unnecessary tags to this PR and requested many reviews unnecessarily 🤦 This is now reverted.

This enables models with multiple KV-cache groups (e.g., Gemma3, GPT-OSS MoE)
to be used as drafters in speculative decoding.

Key changes:
- Refactored CommonAttentionMetadata handling to support a dictionary of
  metadata per KV-cache group ID (CommonAttnMetadataByGid)
- Added per-group slot-mapping buffers for draft model inference
- Introduced layer_names_to_kv_cache_gid mapping to correctly route
  attention layers to their corresponding KV-cache groups

New test cases:
- Gemma3 (270m): multiple KV-cache groups with mixed attention
- GPT-OSS MoE (120b/20b): validates MoE layer resolution in spec decoding

Fixes vllm-project#33133

Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
@tomasruizt tomasruizt force-pushed the feature/spec-decode-gemma3-2 branch from 966e46b to 0dac88e on February 4, 2026 at 12:55
@tomasruizt
Contributor Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant and well-executed enhancement to the speculative decoding framework by enabling models with multiple KV-cache groups to function as drafters. The core of the change involves refactoring the handling of CommonAttentionMetadata to support a dictionary of metadata per KV-cache group ID (CommonAttnMetadataByGid), which is a clean and effective solution. The changes are consistently applied throughout the codebase, from the GPUModelRunner down to the SpecDecodeBaseProposer and its implementations. The removal of the validate_same_kv_cache_group limitation is a key outcome of this work. The addition of new end-to-end tests for Gemma3 and GPT-OSS MoE models provides strong validation for this new capability. Overall, the code is of high quality, and the changes are well-reasoned and correctly implemented.

@benchislett
Collaborator

I still feel like this is not the right way to build this. It feels like we're building up new utilities instead of leveraging the existing model runner code. My ideal implementation here would factor out some of the attn builder logic from gpu model runner and then reuse it in the EAGLE code. If the model runner can already handle hybrid model inference, we should not need to implement any new tooling or abstractions to get the same support for drafting. We may need new abstractions to unify and enable code reuse, but that is not what I am seeing in this revision.

…cache_group_id()

Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
@mergify

mergify bot commented Feb 5, 2026

Hi @tomasruizt, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
```shell
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint
```

Comment on lines +1659 to +1702
```python
def get_slot_mappings_by_layer(
    kv_cache_config: KVCacheConfig,
    slot_mappings_by_gid: dict[int, torch.Tensor],
    ubatch_slices: "UBatchSlices | None" = None,
) -> dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]:
    """
    Convert slot mappings from group ID indexing to layer name indexing.

    Args:
        kv_cache_config: KV cache configuration containing group-to-layer
            mappings.
        slot_mappings_by_gid: Slot mappings keyed by KV cache group ID.
        ubatch_slices: Optional ubatch slicing info for DBO (Dual Batch
            Overlap). When provided, returns a list of sliced mappings
            per ubatch.

    Returns:
        dict[str, torch.Tensor]: Slot mappings keyed by layer name for
        ForwardContext, or list[dict[str, torch.Tensor]] when ubatch_slices
        is provided.
    """
    slot_mappings_by_layer: dict[str, torch.Tensor] = {
        layer_name: slot_mappings_by_gid[gid]
        for layer_name, gid in kv_cache_group_id_by_layer(kv_cache_config).items()
    }

    if ubatch_slices is not None:
        result: list[dict[str, torch.Tensor]] = []
        for ubatch in ubatch_slices:
            sliced_mappings: dict[str, torch.Tensor] = {}
            for layer_name, slot_mapping in slot_mappings_by_layer.items():
                sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice]
            result.append(sliced_mappings)
        return result

    return slot_mappings_by_layer


def kv_cache_group_id_by_layer(kv_cache_config: KVCacheConfig) -> dict[str, int]:
    """Return a mapping from layer_name -> KV cache group ID."""
    gid_by_layer: dict[str, int] = {}
    for gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
        for layer_name in kv_cache_group.layer_names:
            gid_by_layer[layer_name] = gid
    return gid_by_layer
```
Contributor Author


These functions were factored out of GpuModelRunner
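With minimal stand-in types (`KVCacheGroup` and `KVCacheConfig` below are simplified sketches; the real vLLM `KVCacheConfig` has more fields), the layer-to-group mapping behaves like this:

```python
# Self-contained sketch of kv_cache_group_id_by_layer's logic, using
# simplified stand-in types instead of vLLM's real KVCacheConfig.
from dataclasses import dataclass


@dataclass
class KVCacheGroup:
    layer_names: list[str]


@dataclass
class KVCacheConfig:
    kv_cache_groups: list[KVCacheGroup]


def kv_cache_group_id_by_layer(kv_cache_config: KVCacheConfig) -> dict[str, int]:
    """Return a mapping from layer_name -> KV cache group ID."""
    gid_by_layer: dict[str, int] = {}
    for gid, group in enumerate(kv_cache_config.kv_cache_groups):
        for layer_name in group.layer_names:
            gid_by_layer[layer_name] = gid
    return gid_by_layer


# Layers 0 and 2 share group 0; layer 1 sits in group 1 (hypothetical names).
cfg = KVCacheConfig([
    KVCacheGroup(["layers.0.attn", "layers.2.attn"]),
    KVCacheGroup(["layers.1.attn"]),
])
print(kv_cache_group_id_by_layer(cfg))
# {'layers.0.attn': 0, 'layers.2.attn': 0, 'layers.1.attn': 1}
```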

Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
@tomasruizt
Contributor Author

tomasruizt commented Feb 5, 2026

@benchislett I understand where you are coming from!

I did factor out some code from GPUModelRunner to reuse in eagle.py (specifically get_slot_mappings_by_layer() and kv_cache_group_id_by_layer()). However, making the full attention metadata building logic reusable was difficult: _build_attn_metadata() has tight coupling with runner state (kv_cache_config, attn_groups, cuda graph buffers, closures with side-effects, etc.).

I'm happy to do that deeper refactor before or after this PR. I'd prefer merging this to avoid merge conflicts, but if it's worth blocking for, let's track the consolidation in a dedicated issue. What do you think?

The main aim of this PR is to enable more models for SD, and it does that successfully.

Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
@mergify

mergify bot commented Feb 6, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tomasruizt.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 6, 2026
Resolve merge conflicts to bring the multi-KV-cache-group speculative
decoding feature up to date with main. Key resolutions:

- eagle.py: Adapt set_inputs_first_pass to use cm_by_gid (dict of
  CommonAttentionMetadata per group) on top of main's refactored
  needs_extra_input_slots branching and parallel drafting support.
  Per-group slot mapping computation in the draft model path.
- draft_model.py: Take main's simplified _get_model() pattern; the
  PR's set_inputs_first_pass logic now lives in the base class.
- kv_cache_utils.py: Merge both sets of imports.
- attention/backends/utils.py: Drop extend_all_queries_by_1 (superseded
  by extend_all_queries_by_N in spec_decode/utils.py).
- Tests: Use main's token_indices_to_sample rename, keep PR's cm_by_gid.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
Adapt the three test_set_inputs_first_pass_* tests (from main) to use
the multi-gid API: wrap CommonAttentionMetadata in cm_by_gid dict,
set up layer_names_to_kv_cache_gid, and mock _get_metadata_builder
for per-group slot mapping in the draft model/parallel drafting paths.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
@tomasruizt
Contributor Author

tomasruizt commented Feb 25, 2026

@benchislett You're right about code reuse. As you've mentioned before, ModelRunnerV2 (MRV2) probably resolves these issues: The EagleSpeculator receives an InputBatch directly, eliminating the duplicated abstractions.

Since MRV2 is targeting default-on in Q1/Q2, investing into MRV1 would be likely throwaway work. The impactful path is porting method=draft_model and other SD methods to MRV2, where multi-KV-cache-group support may come out-of-the-box from the runner's architecture.

This PR enables multi-KV-group drafters (Gemma3, GPT-OSS MoE) on V1, which is the only runner supporting method="draft_model" today. We should probably flesh out the SD transition to MRV2 in a dedicated issue. What do you think?

@mergify

mergify bot commented Feb 25, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tomasruizt.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


Development

Successfully merging this pull request may close these issues.

[Bug]: Using GPT OSS 20B as Drafter Throws Error
