[Spec Decode] Add hidden states extraction system #33736
Merged: LucasWilkinson merged 26 commits into vllm-project:main from fynnsu:extract_hidden_states on Mar 2, 2026.
Commits (26, all by fynnsu):

- 27bad96 Add extract_hidden_states speculation method
- 806715e Add ExampleHiddenStatesConnector
- ead1f8b Cleanup / simplify components
- ac5d6ae Add example usage
- d1ebb82 Fix shape issues
- e1a01ba Add safetensors output to example script
- b9ab5cf Update example script
- f332456 Small fixes to ExtractHiddenStatesProposer
- 1c39645 Simplify custom attention components
- ab3181c Update ExtractHiddenStatesModel for simpler CacheOnlyAttentionLayer
- f04dcf6 Clean up ExampleHiddenStatesConnector
- ab1ec95 Add basic kv cache insertion function
- b3e64b9 Clean up ExtractHiddenStatesProposer implementation
- 2e3b4f4 Fix small issues
- 294d842 Add support for merging KVConnectorOutput
- 1ad24df Improve config handling
- a624b72 Fix precommit issues
- 28aa17a Cleanup todos
- 45e7e4a precommit
- 7c53b80 Fix docs issues
- ef326ad Add tests for extract_hidden_states
- 88d2fb4 Fix ExampleHiddenStatesConnector handling of batched prefill
- 7638668 Handle review comments
- 06396f5 lint
- 4806c95 Add ExtractHiddenStatesModel to test model registry
- 2268224 Update computation of padded tokens in extract_hidden_states
Example script (new file, 58 added lines):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile

from safetensors import safe_open

from vllm import LLM, SamplingParams

# Example: Using the custom "extract_hidden_states" speculator method and
# ExampleHiddenStatesConnector to extract and save hidden states from vLLM

with tempfile.TemporaryDirectory() as tmpdirname:
    llm = LLM(
        model="Qwen/Qwen3-8B",  # Your target model
        speculative_config={
            "method": "extract_hidden_states",
            "num_speculative_tokens": 1,
            "draft_model_config": {
                "hf_config": {
                    "eagle_aux_hidden_state_layer_ids": [  # Target model layer indices
                        1,
                        2,
                        3,
                        4,
                    ],
                }
            },
        },
        kv_transfer_config={
            "kv_connector": "ExampleHiddenStatesConnector",
            "kv_role": "kv_producer",
            "kv_connector_extra_config": {
                "shared_storage_path": tmpdirname,
            },
        },
    )

    prompts = ["Generate a sentence with hidden states", "Write a python function"]
    sampling_params = SamplingParams(max_tokens=1)
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        print("\nPrompt:", output.prompt)
        print("Prompt token ids:", output.prompt_token_ids)

        hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
        assert hidden_states_path is not None
        print("Prompt hidden states path:", hidden_states_path)

        with safe_open(hidden_states_path, "pt") as f:
            token_ids = f.get_tensor("token_ids")
            hidden_states = f.get_tensor("hidden_states")

        print("Extracted token ids:", token_ids)  # Matches prompt token ids
        print(
            "Extracted hidden states shape:", hidden_states.shape
        )  # [num_hidden_layers, prompt len, hidden size]
        print("Extracted hidden states:", hidden_states)
```
tests/v1/kv_connector/extract_hidden_states_integration/predictable_llama.py (120 added lines):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Predictable dummy model for testing extract_hidden_states.

Subclasses LlamaForCausalLM but overrides the model to produce deterministic
hidden states: layer i outputs values equal to (i).
"""

from collections.abc import Iterable

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.sequence import IntermediateTensors


class PredictableLlamaModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.aux_hidden_state_layers = tuple[int, ...]()

        # Create minimal embed_tokens for embedding
        from vllm.model_executor.layers.vocab_parallel_embedding import (
            VocabParallelEmbedding,
        )

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )

        # Required for pipeline parallelism
        from vllm.model_executor.models.utils import (
            make_empty_intermediate_tensors_factory,
        )

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed input IDs."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        **extra_layer_kwargs,
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Forward pass that produces predictable outputs.

        Returns:
            If aux_hidden_state_layers is set: (hidden_states, aux_hidden_states)
            Otherwise: hidden_states
        """
        # Determine sequence length
        if inputs_embeds is not None:
            seq_len = inputs_embeds.shape[0]
            device = inputs_embeds.device
        elif input_ids is not None:
            seq_len = input_ids.shape[0] if input_ids.ndim == 1 else input_ids.shape[-1]
            device = input_ids.device
        else:
            raise ValueError("Either input_ids or inputs_embeds must be provided")

        # Final hidden states (last layer value)
        hidden_states = torch.full(
            (seq_len, self.config.hidden_size),
            fill_value=float(self.config.num_hidden_layers),
            device=device,
            dtype=torch.bfloat16,
        )

        # Check if we need auxiliary hidden states
        if len(self.aux_hidden_state_layers) > 0:
            aux_hidden_states = []
            for layer_idx in self.aux_hidden_state_layers:
                # Fill with (layer_idx) for predictability
                layer_hidden = torch.full(
                    (seq_len, self.config.hidden_size),
                    fill_value=float(layer_idx),
                    device=device,
                    dtype=torch.bfloat16,
                )
                aux_hidden_states.append(layer_hidden)

            return hidden_states, aux_hidden_states

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Skip weight loading."""
        return set()


class PredictableLlamaForCausalLM(LlamaForCausalLM):
    """Predictable Llama model for testing.

    Overrides _init_model to use PredictableLlamaModel instead of LlamaModel.
    """

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] | None = None,
    ):
        """Initialize with predictable model."""
        return PredictableLlamaModel(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Skip weight loading for dummy model."""
        return set()
```
tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py (155 added lines):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc
import os

import pytest
import torch
from safetensors import safe_open

from vllm import LLM, ModelRegistry, SamplingParams


def get_and_check_output(output, expected_shape):
    assert output.kv_transfer_params is not None
    hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
    assert hidden_states_path is not None
    assert os.path.exists(hidden_states_path)

    # Load and verify the saved tensors
    with safe_open(hidden_states_path, "pt") as f:
        # Check that token_ids and hidden_states are present
        tensor_names = f.keys()
        assert "token_ids" in tensor_names
        assert "hidden_states" in tensor_names

        token_ids = f.get_tensor("token_ids")
        hidden_states = f.get_tensor("hidden_states")

    prompt_token_ids = output.prompt_token_ids
    assert torch.equal(token_ids, torch.tensor(prompt_token_ids))

    assert hidden_states.shape == expected_shape

    # Verify hidden_states are not all zeros (i.e., they were actually computed)
    assert not torch.allclose(hidden_states, torch.zeros_like(hidden_states))

    return token_ids, hidden_states


@pytest.fixture(scope="module")
def predictable_llama_config_path(tmp_path_factory):
    """Create a minimal LlamaConfig for PredictableLlamaForCausalLM."""
    from transformers import LlamaConfig, LlamaTokenizerFast

    config_dir = tmp_path_factory.mktemp("predictable_llama")

    # Create a minimal Llama config with small dimensions
    config = LlamaConfig(
        vocab_size=1000,
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=24,  # Enough layers to test various layer_ids
        num_attention_heads=4,
        num_key_value_heads=4,
        max_position_embeddings=128,
        architectures=["PredictableLlamaForCausalLM"],
    )

    # Save config
    config.save_pretrained(config_dir)

    # Create a simple tokenizer
    tokenizer = LlamaTokenizerFast.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        cache_dir=os.path.expanduser("~/.cache/huggingface"),
    )
    tokenizer.save_pretrained(config_dir)

    return str(config_dir)


@pytest.fixture(scope="module", autouse=True)
def register_predictable_model():
    """Register the PredictableLlamaForCausalLM model."""
    from .predictable_llama import PredictableLlamaForCausalLM

    if "PredictableLlamaForCausalLM" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model(
            "PredictableLlamaForCausalLM", PredictableLlamaForCausalLM
        )
    yield


def test_extract_hidden_states_with_predictable_dummy_model(
    predictable_llama_config_path, tmp_path
):
    """Comprehensive test using a predictable dummy model with synthetic weights.

    The PredictableLlamaForCausalLM outputs deterministic hidden states where
    each layer produces values equal to (layer_index). This test verifies:
    1. Hidden states are correctly extracted from requested layers
    2. Values match the expected predictable pattern
    3. Layer ordering is preserved correctly (non-sequential layer IDs)
    4. Multiple prompts of different lengths produce consistent layer values
    """
    # Test with non-sequential layer ordering to verify correct association
    layer_ids = [5, 2, 10]
    num_layers = len(layer_ids)

    llm = LLM(
        model=predictable_llama_config_path,
        speculative_config={
            "method": "extract_hidden_states",
            "num_speculative_tokens": 1,
            "draft_model_config": {
                "hf_config": {"eagle_aux_hidden_state_layer_ids": layer_ids}
            },
        },
        kv_transfer_config={
            "kv_connector": "ExampleHiddenStatesConnector",
            "kv_role": "kv_producer",
            "kv_connector_extra_config": {"shared_storage_path": tmp_path},
        },
        max_model_len=128,
        enforce_eager=True,
        trust_remote_code=True,
        load_format="dummy",  # Don't try to load real weights
    )

    # Test with multiple prompts of different lengths
    prompts = [
        "Short",
        "Medium length",
        "Much longer prompt with many tokens",
        "Much longer prompt with many tokens",  # repeated prompt
    ]
    sampling_params = SamplingParams(max_tokens=1, temperature=0.0)
    hidden_size = llm.llm_engine.model_config.get_hidden_size()
    outputs = llm.generate(prompts, sampling_params)
    del llm
    gc.collect()

    assert len(outputs) == len(prompts)

    for output in outputs:
        # hidden_states shape is [prompt_len, num_hidden_layers, hidden_size]
        expected_shape = (
            len(output.prompt_token_ids),
            num_layers,
            hidden_size,
        )
        _token_ids, hidden_states = get_and_check_output(output, expected_shape)

        for idx, layer_id in enumerate(layer_ids):
            layer_hidden = hidden_states[:, idx, :]
            assert torch.allclose(
                layer_hidden,
                torch.full_like(layer_hidden, layer_id),
                atol=1e-5,
            ), (
                f"Layer {layer_id} at position {idx} should output {float(layer_id)}, "
                f"but got mean={layer_hidden.mean():.3f}, "
                f"min={layer_hidden.min():.3f}, max={layer_hidden.max():.3f}"
            )
```
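The ordering guarantee the test relies on can be illustrated without torch. This is a toy sketch under the assumption that extraction stacks the requested layers in request order, using the test's [prompt_len, num_layers, hidden_size] layout; the `extract` helper is invented for illustration:

```python
# Toy version of the ordering check: each "layer" fills its values with its
# own index, and extraction must return layers in the requested order.
def extract(layer_ids, prompt_len, hidden_size):
    """Emulate the predictable model plus extraction for requested layer_ids."""
    # [prompt_len, num_layers, hidden_size], layers in request order
    return [
        [[float(layer_id)] * hidden_size for layer_id in layer_ids]
        for _ in range(prompt_len)
    ]

layer_ids = [5, 2, 10]  # deliberately non-sequential
stacked = extract(layer_ids, prompt_len=4, hidden_size=3)
for idx, layer_id in enumerate(layer_ids):
    for token_row in stacked:
        # Position idx must hold the values of the idx-th *requested* layer
        assert token_row[idx] == [float(layer_id)] * 3
```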
Review comment: Can you specifically test that this works if max_num_batched_tokens is smaller than the longest sequence? We should make sure that chunked prefill works with hidden state extraction.
Reply (fynnsu): Good catch: we were actually failing in this case (only saving the first chunk). I've updated the code to make sure we handle in-progress requests correctly and keep track of their block ids. This ended up being changes only to the KV connector, which is now even more inefficient because it overwrites the file on disk as it grows.

That said, the purpose of this connector is really just to provide a baseline working solution for debugging and tests, so I think this is okay for now. We will develop a more performant implementation in a future PR.
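The overwrite-as-it-grows behavior described above can be sketched in a few lines. This is a toy stand-in (JSON in place of safetensors, and an invented `ChunkAccumulatingWriter` class), not the connector's actual code:

```python
# Sketch: accumulate per-chunk hidden-state rows for in-progress requests
# and rewrite the whole output file each time a new prefill chunk lands.
import json
import os
import tempfile


class ChunkAccumulatingWriter:
    def __init__(self, out_dir):
        self.out_dir = out_dir
        self._chunks = {}  # request_id -> accumulated per-token rows

    def on_chunk(self, request_id, rows):
        """rows: hidden-state rows for the tokens in this prefill chunk."""
        self._chunks.setdefault(request_id, []).extend(rows)
        # Inefficient but simple: rewrite the full file on every chunk, so
        # the file on disk always reflects everything saved so far.
        path = os.path.join(self.out_dir, f"{request_id}.json")
        with open(path, "w") as f:
            json.dump(self._chunks[request_id], f)
        return path


with tempfile.TemporaryDirectory() as d:
    writer = ChunkAccumulatingWriter(d)
    writer.on_chunk("req0", [[0.1, 0.2]])          # first prefill chunk
    path = writer.on_chunk("req0", [[0.3, 0.4]])   # second chunk overwrites file
    with open(path) as f:
        saved = json.load(f)
    assert saved == [[0.1, 0.2], [0.3, 0.4]]  # both chunks survive the rewrite
```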
Review comment: I agree with your sentiment; glad it caught a real issue!