Skip to content
Merged
30 changes: 19 additions & 11 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,18 +1292,26 @@ def has_blocked_weights():
if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to disable HMA, but only if the user didn't express a preference.
if self.kv_transfer_config is not None:
# NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
need_disable_hybrid_kv_cache_manager = True
logger.warning(
"Turning off hybrid kv cache manager because "
"`--kv-transfer-config` is set. This will reduce the "
"performance of vLLM on LLMs with sliding window attention "
"or Mamba attention. If you are a developer of kv connector"
", please consider supporting hybrid kv cache manager for "
"your connector by making sure your connector is a subclass"
" of `SupportsHMA` defined in kv_connector/v1/base.py and"
" use --no-disable-hybrid-kv-cache-manager to start vLLM."
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
supports_hma,
)

connector_cls = KVConnectorFactory.get_connector_class(
self.kv_transfer_config
)
if not supports_hma(connector_cls):
need_disable_hybrid_kv_cache_manager = True
logger.warning(
"Turning off hybrid kv cache manager because "
"connector %s does not subclass `SupportsHMA`. "
"This will reduce performance on models with "
"sliding window or Mamba attention. See "
"kv_connector/v1/base.py for details.",
connector_cls.__name__,
)
Comment on lines +1371 to +1395

@NickLucche NickLucche May 14, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were already tracking this change in #41847 and #42024!

self.scheduler_config.disable_hybrid_kv_cache_manager = (
need_disable_hybrid_kv_cache_manager
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
SupportsHMA,
)
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
Expand All @@ -30,13 +31,9 @@ def extract_from_kv_cache(
slot_mapping: torch.Tensor,
num_tokens: int,
) -> torch.Tensor:
"""Extract data from KV cache
Assume the shape of the kv_cache is (num_pages, page_size, num_heads, head_size)
"""

padded_kv = kv_cache.flatten(0, 1)[slot_mapping]
# shape: [len(slot_mapping), num_heads, head_size]
return padded_kv[:num_tokens] # shape: [num_tokens, num_heads, head_size]
"""Extract data from KV cache."""
block_size = kv_cache.shape[1]
return kv_cache[slot_mapping // block_size, slot_mapping % block_size][:num_tokens]


@dataclass
Expand All @@ -47,8 +44,6 @@ class ReqMeta:
filename: str
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# Whether this request is a new request or partially computed already
new_req: bool

Expand All @@ -57,24 +52,12 @@ def make_meta(
req_id: str,
filename: str,
token_ids: list[int],
block_ids: list[int],
block_size: int,
new_req: bool,
) -> "ReqMeta":
token_ids_tensor = torch.tensor(token_ids)
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = (
block_offsets.reshape((1, block_size))
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
)
slot_mapping = slot_mapping.flatten()
return ReqMeta(
req_id=req_id,
filename=filename,
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
token_ids=torch.tensor(token_ids),
new_req=new_req,
)

Expand All @@ -88,18 +71,12 @@ def add_request(
req_id: str,
filename: str,
token_ids: list[int],
block_ids: list[int],
block_size: int,
new_req: bool = True,
) -> None:
self.requests.append(
ReqMeta.make_meta(
req_id, filename, token_ids, block_ids, block_size, new_req
)
)
self.requests.append(ReqMeta.make_meta(req_id, filename, token_ids, new_req))


class ExampleHiddenStatesConnector(KVConnectorBase_V1):
class ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA):
"""
Simple debug implementation of a HiddenStatesConnector.

Expand Down Expand Up @@ -206,9 +183,16 @@ def save_kv_layer(
assert isinstance(connector_metadata, ExampleHiddenStatesConnectorMetadata)

os.makedirs(self._storage_path, exist_ok=True)

slot_mapping = attn_metadata.slot_mapping
offset = 0
for request in connector_metadata.requests:
num_tokens = request.token_ids.shape[0]
req_slot_mapping = slot_mapping[offset : offset + num_tokens]
offset += num_tokens

hidden_states = extract_from_kv_cache(
kv_layer, request.slot_mapping, request.token_ids.shape[0]
kv_layer, req_slot_mapping, num_tokens
)
tensors = {
"hidden_states": hidden_states.detach().cpu(),
Expand Down Expand Up @@ -269,8 +253,6 @@ def build_connector_meta(
new_req.req_id,
filename=filename,
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
)
self._request_filenames[new_req.req_id] = filename
self._active_requests[new_req.req_id] = new_req
Expand Down Expand Up @@ -298,8 +280,6 @@ def build_connector_meta(
req_id=req_id,
filename=filename,
token_ids=cached_req.prompt_token_ids or [],
block_ids=req_block_ids,
block_size=self._block_size,
new_req=False,
)

Expand Down Expand Up @@ -331,6 +311,13 @@ def request_finished(

return False, {"hidden_states_path": req_filename}

def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
return self.request_finished(request, block_ids[0])

@classmethod
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
"""
Expand Down
19 changes: 8 additions & 11 deletions vllm/model_executor/models/extract_hidden_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
HiddenStateCacheSpec,
KVCacheSpec,
MLAAttentionSpec,
)

########## Custom Ops ########
Expand Down Expand Up @@ -79,13 +79,12 @@ def dummy_attention(layer_name, _placeholder):


def basic_cache(
    to_cache: torch.Tensor,  # shape: [seq_len, num_heads, head_size]
    kv_cache: torch.Tensor,  # shape: [num_blocks, block_size, num_heads, head_size]
    slot_mapping: torch.Tensor,  # shape: [seq_len]
) -> None:
    """Scatter per-token tensors into a paged cache, in place.

    ``slot_mapping[i]`` is treated as a flat slot index
    (``block * block_size + offset``) and ``to_cache[i]`` is written to
    ``kv_cache[block, offset]``. Nothing is returned; ``kv_cache`` is mutated.
    """
    block_size = kv_cache.shape[1]
    # Address the cache by (block, offset) instead of going through a
    # flattened ``view``, so the write does not rely on kv_cache being
    # viewable as one contiguous [num_blocks * block_size, ...] array.
    kv_cache[slot_mapping // block_size, slot_mapping % block_size] = to_cache


######### CacheOnlyAttentionBackend ########
Expand Down Expand Up @@ -322,11 +321,9 @@ def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend

def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Note: we use MLAAttentionSpec here because it will
# produce page sizes of (block_size * num_kv_heads * head_size * dtype_size)
# whereas FullAttentionSpec will add an additional factor of 2
return MLAAttentionSpec(
block_size=self.block_size,
# Re-read block_size: hybrid models may bump it after __init__.
return HiddenStateCacheSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=self.num_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
Expand Down
32 changes: 26 additions & 6 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from vllm.utils.hashing import sha256_cbor, xxhash_cbor
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.mem_utils import format_gib
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
FullAttentionSpec,
HiddenStateCacheSpec,
KVCacheConfig,
KVCacheGroupSpec,
KVCacheSpec,
Expand Down Expand Up @@ -1650,15 +1652,33 @@ def get_kv_cache_groups(
_annotate_eagle_groups_deepseek_v4(vllm_config, kv_cache_spec, kv_cache_groups)
return kv_cache_groups

# Pull HiddenStateCacheSpec layers out before the general multi-group
# path so they don't affect page-size unification or grouping.
hidden_specs = {
k: v for k, v in kv_cache_spec.items() if isinstance(v, HiddenStateCacheSpec)
}
filtered_spec = {
k: v
for k, v in kv_cache_spec.items()
if not isinstance(v, HiddenStateCacheSpec)
}

# As KVCacheManager can only allocate memory of one size, we need to unify
# the page size of the layers. For cases that cannot be unified, this function
# will raise an error.
kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec)
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
filtered_spec = unify_kv_cache_spec_page_size(filtered_spec)
groups = _get_kv_cache_groups_uniform_page_size(filtered_spec)

# Add hidden-state layers back with page aligned to the common page.
if hidden_specs:
common_page = get_uniform_page_size([g.kv_cache_spec for g in groups])
for name, spec in hidden_specs.items():
per_token = spec.num_kv_heads * spec.head_size * get_dtype_size(spec.dtype)
new_bs = max(common_page // per_token, 1)
aligned = replace(spec, block_size=new_bs, page_size_padded=common_page)
groups.append(KVCacheGroupSpec([name], aligned))

return groups


def generate_scheduler_kv_cache_config(
Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/core/single_type_kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
FullAttentionSpec,
HiddenStateCacheSpec,
KVCacheSpec,
MambaSpec,
MLAAttentionSpec,
Expand Down Expand Up @@ -1143,6 +1144,7 @@ def __init__(
FullAttentionSpec: FullAttentionManager,
TQFullAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
HiddenStateCacheSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
SlidingWindowMLASpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
Expand Down
7 changes: 7 additions & 0 deletions vllm/v1/kv_cache_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,13 @@ def merge(cls, specs: list[Self]) -> Self:
)


@dataclass(frozen=True, kw_only=True)
class HiddenStateCacheSpec(MLAAttentionSpec):
    """Marker for hidden-state cache layers used by extract_hidden_states.

    Declares no fields or methods of its own; everything is inherited from
    ``MLAAttentionSpec``. The distinct type lets callers single these layers
    out with ``isinstance``.
    """


@dataclass(frozen=True, kw_only=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
attention_chunk_size: int
Expand Down
14 changes: 9 additions & 5 deletions vllm/v1/spec_decode/extract_hidden_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self, vllm_config: VllmConfig, device):
self.model: nn.Module | None = None
self.attn_layer_names: list[str] = []
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
self.kv_cache_gid: int = -1

# Maximum number of tokens for buffers
max_batch_size = vllm_config.scheduler_config.max_num_seqs
Expand Down Expand Up @@ -374,9 +375,12 @@ def load_model(self, target_model: nn.Module) -> None:
)

def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
    """Validate all drafting layers belong to the same KV cache group
    and record the group index for common_attn_metadata selection.

    There is exactly one attention layer (asserted below), so the "same
    group" requirement is trivially satisfied; the remaining work is to find
    which group lists that layer and store its index in ``self.kv_cache_gid``.

    Raises:
        ValueError: if the layer is not listed in any KV cache group.
    """
    assert len(self.attn_layer_names) == 1
    layer = self.attn_layer_names[0]
    for gid, group in enumerate(kv_cache_config.kv_cache_groups):
        if layer in group.layer_names:
            self.kv_cache_gid = gid
            return
    raise ValueError(f"Cache-only layer {layer!r} not in any KV cache group")
8 changes: 7 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2329,7 +2329,13 @@ def _build_attn_group_metadata(

if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(
self.drafter, (EagleProposer, DFlashProposer, Gemma4Proposer)
self.drafter,
(
EagleProposer,
DFlashProposer,
Gemma4Proposer,
ExtractHiddenStatesProposer,
),
):
if self.drafter.kv_cache_gid == kv_cache_gid:
spec_decode_common_attn_metadata = cm
Expand Down
Loading