Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
27bad96
Add extract_hidden_states speculation method
fynnsu Jan 3, 2026
806715e
Add ExampleHiddenStatesConnector
fynnsu Jan 3, 2026
ead1f8b
Cleanup / simplify components
fynnsu Jan 3, 2026
ac5d6ae
Add example usage
fynnsu Jan 3, 2026
d1ebb82
Fix shape issues
fynnsu Jan 3, 2026
e1a01ba
Add safetensors output to example script
fynnsu Jan 3, 2026
b9ab5cf
Update example script
fynnsu Feb 3, 2026
f332456
Small fixes to ExtractHiddenStatesProposer
fynnsu Feb 5, 2026
1c39645
Simplify custom attention components
fynnsu Feb 11, 2026
ab3181c
Update ExtractHiddenStatesModel for simpler CacheOnlyAttentionLayer
fynnsu Feb 11, 2026
f04dcf6
Clean up ExampleHiddenStatesConnector
fynnsu Feb 11, 2026
ab1ec95
Add basic kv cache insertion function
fynnsu Feb 11, 2026
b3e64b9
Clean up ExtractHiddenStatesProposer implementation
fynnsu Feb 12, 2026
2e3b4f4
Fix small issues
fynnsu Feb 12, 2026
294d842
Add support for merging KVConnectorOutput
fynnsu Feb 13, 2026
1ad24df
Improve config handling
fynnsu Feb 13, 2026
a624b72
Fix precommit issues
fynnsu Feb 13, 2026
28aa17a
Cleanup todos
fynnsu Feb 13, 2026
45e7e4a
precommit
fynnsu Feb 13, 2026
7c53b80
Fix docs issues
fynnsu Feb 13, 2026
ef326ad
Add tests for extract_hidden_states
fynnsu Feb 16, 2026
88d2fb4
Fix ExampleHiddenStatesConnector handling of batched prefill
fynnsu Feb 23, 2026
7638668
Handle review comments
fynnsu Feb 23, 2026
06396f5
lint
fynnsu Feb 26, 2026
4806c95
Add ExtractHiddenStatesModel to test model registry
fynnsu Feb 27, 2026
2268224
Update computation of padded tokens in extract_hidden_states
fynnsu Mar 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions examples/offline_inference/extract_hidden_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile

from safetensors import safe_open

from vllm import LLM, SamplingParams

# Example: Using the custom "extract_hidden_states" speculator method and
# ExampleHiddenStatesConnector to extract and save hidden states from vllm

# Hidden-state files are written beneath a throwaway directory and read back
# within this block, so nothing persists after the example completes.
with tempfile.TemporaryDirectory() as tmpdirname:
    llm = LLM(
        model="Qwen/Qwen3-8B",  # Your target model
        speculative_config={
            # Custom speculation method that extracts hidden states rather
            # than proposing draft tokens.
            "method": "extract_hidden_states",
            "num_speculative_tokens": 1,
            "draft_model_config": {
                "hf_config": {
                    "eagle_aux_hidden_state_layer_ids": [  # Target model layer indices
                        1,
                        2,
                        3,
                        4,
                    ],
                }
            },
        },
        # The connector persists the extracted states as safetensors files
        # under shared_storage_path and reports each file's path back on the
        # request's kv_transfer_params.
        kv_transfer_config={
            "kv_connector": "ExampleHiddenStatesConnector",
            "kv_role": "kv_producer",
            "kv_connector_extra_config": {
                "shared_storage_path": tmpdirname,
            },
        },
    )

    prompts = ["Generate a sentence with hidden states", "Write a python function"]
    # Only the prompt's hidden states matter here, so decode a single token.
    sampling_params = SamplingParams(max_tokens=1)
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        print("\nPrompt:", output.prompt)
        print("Prompt token ids:", output.prompt_token_ids)

        # Where the connector stored this request's extracted states.
        hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
        assert hidden_states_path is not None
        print("Prompt hidden states path:", hidden_states_path)

        with safe_open(hidden_states_path, "pt") as f:
            token_ids = f.get_tensor("token_ids")
            hidden_states = f.get_tensor("hidden_states")

        print("Extracted token ids:", token_ids)  # Matches prompt token ids
        print(
            "Extracted hidden states shape:", hidden_states.shape
        )  # [num_hidden_layers, prompt len, hidden size]
        # NOTE(review): the test suite asserts the saved layout is
        # [prompt_len, num_selected_layers, hidden_size] — confirm which
        # ordering the connector actually writes and align this comment.
        print("Extracted hidden states:", hidden_states)
6 changes: 5 additions & 1 deletion tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class _HfExamplesInfo:

use_original_num_layers: bool = False
"""
If True, use the original number of layers from the model config
If True, use the original number of layers from the model config
instead of minimal layers for testing.
"""

Expand Down Expand Up @@ -1160,6 +1160,10 @@ def check_available_online(
speculative_model="LGAI-EXAONE/K-EXAONE-236B-A23B",
min_transformers_version="5.1.0",
),
"ExtractHiddenStatesModel": _HfExamplesInfo(
"Qwen/Qwen3-8B",
speculative_method="extract_hidden_states",
),
"Glm4MoeMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.5",
speculative_model="zai-org/GLM-4.5",
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Predictable dummy model for testing extract_hidden_states.

Subclasses LlamaForCausalLM but overrides the model to produce deterministic
hidden states: layer i outputs values equal to (i).
"""

from collections.abc import Iterable

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.sequence import IntermediateTensors


class PredictableLlamaModel(nn.Module):
    """Stand-in for LlamaModel that emits constant, layer-indexed activations.

    Layer i "outputs" a tensor filled with the value i, and the final hidden
    state is filled with num_hidden_layers, so callers can verify exactly
    which layers the extraction machinery picked up.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        # No auxiliary layers are selected until a caller assigns them.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

        # Only the embedding table is instantiated; this dummy model has no
        # real transformer layers.
        from vllm.model_executor.layers.vocab_parallel_embedding import (
            VocabParallelEmbedding,
        )

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )

        # Pipeline parallelism expects this factory to exist on the model.
        from vllm.model_executor.models.utils import (
            make_empty_intermediate_tensors_factory,
        )

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for *input_ids*."""
        return self.embed_tokens(input_ids)

    def _constant_block(
        self, num_tokens: int, value: float, device: torch.device
    ) -> torch.Tensor:
        # One (num_tokens, hidden_size) tensor filled with a single value.
        return torch.full(
            (num_tokens, self.config.hidden_size),
            fill_value=value,
            device=device,
            dtype=torch.bfloat16,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        **extra_layer_kwargs,
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Produce deterministic, predictable hidden states.

        Returns:
            (hidden_states, aux_hidden_states) when aux_hidden_state_layers
            is non-empty, otherwise just hidden_states.
        """
        # Work out how many tokens are in the (flattened) batch.
        if inputs_embeds is not None:
            num_tokens = inputs_embeds.shape[0]
            device = inputs_embeds.device
        elif input_ids is not None:
            num_tokens = (
                input_ids.shape[0] if input_ids.ndim == 1 else input_ids.shape[-1]
            )
            device = input_ids.device
        else:
            raise ValueError("Either input_ids or inputs_embeds must be provided")

        # The final hidden state carries the last layer's index as its value.
        final_states = self._constant_block(
            num_tokens, float(self.config.num_hidden_layers), device
        )

        if not self.aux_hidden_state_layers:
            return final_states

        # Each requested aux layer is encoded by its own index value, in the
        # configured order, so slot<->layer association is testable.
        aux_states = [
            self._constant_block(num_tokens, float(layer_idx), device)
            for layer_idx in self.aux_hidden_state_layers
        ]
        return final_states, aux_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """No-op: this dummy model has no weights worth loading."""
        return set()


class PredictableLlamaForCausalLM(LlamaForCausalLM):
    """Predictable Llama model for testing.

    Overrides _init_model to use PredictableLlamaModel instead of LlamaModel.
    """

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] | None = None,
    ):
        """Initialize with predictable model."""
        # layer_type is accepted for signature compatibility with the parent
        # class but deliberately ignored: the predictable model builds no
        # real transformer layers at all.
        return PredictableLlamaModel(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Skip weight loading for dummy model."""
        # Empty set => no parameters consumed; the dummy forward synthesizes
        # its outputs and never reads real weights.
        return set()
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc
import os

import pytest
import torch
from safetensors import safe_open

from vllm import LLM, ModelRegistry, SamplingParams


def get_and_check_output(output, expected_shape):
    """Validate one request's saved hidden states file.

    Checks that the connector reported a path, that the file holds the
    expected tensors, and returns (token_ids, hidden_states).
    """
    params = output.kv_transfer_params
    assert params is not None
    path = params.get("hidden_states_path")
    assert path is not None
    assert os.path.exists(path)

    # Read back what the connector persisted for this request.
    with safe_open(path, "pt") as reader:
        available = reader.keys()
        assert "token_ids" in available
        assert "hidden_states" in available
        token_ids = reader.get_tensor("token_ids")
        hidden_states = reader.get_tensor("hidden_states")

    # Saved token ids must mirror the prompt exactly.
    assert torch.equal(token_ids, torch.tensor(output.prompt_token_ids))

    assert hidden_states.shape == expected_shape

    # A file of zeros would mean nothing was actually computed/extracted.
    assert not torch.allclose(hidden_states, torch.zeros_like(hidden_states))

    return token_ids, hidden_states


@pytest.fixture(scope="module")
def predictable_llama_config_path(tmp_path_factory):
    """Write a tiny Llama config plus tokenizer to disk; return the directory."""
    from transformers import LlamaConfig, LlamaTokenizerFast

    out_dir = tmp_path_factory.mktemp("predictable_llama")

    # Small dimensions keep the dummy model cheap; 24 layers leaves room to
    # test assorted layer ids.
    LlamaConfig(
        vocab_size=1000,
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=24,
        num_attention_heads=4,
        num_key_value_heads=4,
        max_position_embeddings=128,
        architectures=["PredictableLlamaForCausalLM"],
    ).save_pretrained(out_dir)

    # Borrow a real tokenizer so prompts can be encoded; the dummy model
    # never reads its weights.
    LlamaTokenizerFast.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        cache_dir=os.path.expanduser("~/.cache/huggingface"),
    ).save_pretrained(out_dir)

    return str(out_dir)


@pytest.fixture(scope="module", autouse=True)
def register_predictable_model():
    """Ensure PredictableLlamaForCausalLM is registered before tests run."""
    from .predictable_llama import PredictableLlamaForCausalLM

    arch = "PredictableLlamaForCausalLM"
    # Registration is global; skip it if another module already did it.
    if arch not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model(arch, PredictableLlamaForCausalLM)
    yield


def test_extract_hidden_states_with_predictable_dummy_model(
    predictable_llama_config_path, tmp_path
):
    """Comprehensive test using a predictable dummy model with synthetic weights.

    The PredictableLlamaForCausalLM outputs deterministic hidden states where
    each layer produces values equal to (layer_index). This test verifies:
    1. Hidden states are correctly extracted from requested layers
    2. Values match the expected predictable pattern
    3. Layer ordering is preserved correctly (non-sequential layer IDs)
    4. Multiple prompts of different lengths produce consistent layer values
    """
    # Deliberately non-sequential ids so a wrong slot<->layer mapping is caught.
    layer_ids = [5, 2, 10]
    num_layers = len(layer_ids)

    llm = LLM(
        model=predictable_llama_config_path,
        speculative_config={
            "method": "extract_hidden_states",
            "num_speculative_tokens": 1,
            "draft_model_config": {
                "hf_config": {"eagle_aux_hidden_state_layer_ids": layer_ids}
            },
        },
        kv_transfer_config={
            "kv_connector": "ExampleHiddenStatesConnector",
            "kv_role": "kv_producer",
            "kv_connector_extra_config": {"shared_storage_path": tmp_path},
        },
        max_model_len=128,
        enforce_eager=True,
        trust_remote_code=True,
        load_format="dummy",  # Don't try to load real weights
    )

    # Varied prompt lengths (plus one duplicate) exercise the batching paths.
    prompts = [
        "Short",
        "Medium length",
        "Much longer prompt with many tokens",
        "Much longer prompt with many tokens",  # repeated prompt
    ]
    sampling_params = SamplingParams(max_tokens=1, temperature=0.0)
    hidden_size = llm.llm_engine.model_config.get_hidden_size()
    outputs = llm.generate(prompts, sampling_params)

    # Tear down the engine before inspecting the files on disk.
    del llm
    gc.collect()

    assert len(outputs) == len(prompts)

    for output in outputs:
        # Saved layout: [prompt_len, num_selected_layers, hidden_size].
        expected_shape = (len(output.prompt_token_ids), num_layers, hidden_size)
        _token_ids, hidden_states = get_and_check_output(output, expected_shape)

        # Slot idx must hold exactly the constant emitted by layer_ids[idx].
        for idx, layer_id in enumerate(layer_ids):
            layer_hidden = hidden_states[:, idx, :]
            assert torch.allclose(
                layer_hidden,
                torch.full_like(layer_hidden, layer_id),
                atol=1e-5,
            ), (
                f"Layer {layer_id} at position {idx} should output {float(layer_id)}, "
                f"but got mean={layer_hidden.mean():.3f}, "
                f"min={layer_hidden.min():.3f}, max={layer_hidden.max():.3f}"
            )
Loading