Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions tests/v1/worker/test_gpu_model_runner_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.platforms import current_platform

DEVICE_TYPE = current_platform.device_type


def test_v2_block_tables_kernel_block_expansion():
    """Check that KV-manager block IDs are expanded into kernel block IDs.

    With a manager block size of 128 and a kernel block size of 64, the test
    expects two kernel blocks per manager block, so manager blocks
    ``[0, 1, 2]`` must show up in the block table as ``[0, 1, 2, 3, 4, 5]``.
    """
    from vllm.v1.worker.gpu.block_table import BlockTables

    device = torch.device(DEVICE_TYPE)
    tables = BlockTables(
        block_sizes=[128],
        kernel_block_sizes=[64],
        max_num_reqs=4,
        max_num_batched_tokens=256,
        max_num_blocks_per_group=[10],
        device=device,
    )

    # One KV cache group, three manager blocks for request index 0.
    manager_block_ids = ([0, 1, 2],)
    tables.append_block_ids(0, manager_block_ids, overwrite=True)
    tables.apply_staged_writes()

    assert tables.blocks_per_kv_block == [2]

    first_row = tables.block_tables[0].gpu[0, :6].cpu().tolist()
    assert first_row == [0, 1, 2, 3, 4, 5]
49 changes: 40 additions & 9 deletions vllm/v1/worker/gpu/attn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
UniformTypeKVCacheSpecs,
)
from vllm.v1.worker.gpu.model_states.interface import ModelSpecificAttnMetadata
from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache
from vllm.v1.worker.utils import (
AttentionGroup,
bind_kv_cache,
prepare_kernel_block_sizes,
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -51,6 +55,7 @@ def init_attn_backend(
dict[str, type[AttentionBackend]],
list[list[AttentionGroup]],
AttentionCGSupportInfo,
list[int],
]:
attn_backends: dict[str, type[AttentionBackend]] = {}
attn_groups: list[list[AttentionGroup]] = []
Expand Down Expand Up @@ -91,12 +96,21 @@ def init_attn_backend(
else:
group_map[key].layer_names.append(layer_name)

groups = [group_map[key] for key in group_order]
attn_groups.append([group_map[key] for key in group_order])

kernel_block_sizes = prepare_kernel_block_sizes(kv_cache_config, attn_groups)
for kv_cache_group_id, groups in enumerate(attn_groups):
kernel_block_size = (
kernel_block_sizes[kv_cache_group_id]
if kv_cache_group_id < len(kernel_block_sizes)
else None
)
Comment on lines +101 to +107
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The indexing of `kernel_block_sizes` by `kv_cache_group_id` is incorrect because `prepare_kernel_block_sizes` (in vllm/v1/worker/utils.py) skips `EncoderOnlyAttentionSpec` groups. This results in a length mismatch and a misaligned mapping between groups and their kernel block sizes. If an encoder-only group exists, subsequent groups will receive the wrong block size or `None`, causing them to fall back to logical block sizes and defeating the purpose of this fix. `prepare_kernel_block_sizes` should be updated to return a list of the same length as `kv_cache_groups` (e.g., by using `None` for skipped groups).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prepare_kernel_block_sizes() currently returns a compact list because EncoderOnlyAttentionSpec does not allocate KV cache. In the current KV cache config construction, encoder-only groups are appended after regular KV cache groups, so the compact list remains aligned for all non-encoder-only groups, and the trailing encoder-only group is skipped by the existing guard.

This PR keeps that existing behavior and focuses on the Qwen3 + FlashInfer + NIXL MRV2 regression. If MRV2 later allows encoder-only groups before regular attention groups, we should revisit this indexing contract separately.

kv_cache_group_spec = kv_cache_config.kv_cache_groups[kv_cache_group_id]
for group in groups:
group.create_metadata_builders(
vllm_config=vllm_config,
device=device,
kernel_block_size=None,
kernel_block_size=kernel_block_size,
num_metadata_builders=1,
)
builder = group.get_metadata_builder(0)
Expand All @@ -113,8 +127,7 @@ def init_attn_backend(
)
if cg_support.value < min_cg_support.value:
min_cg_support = cg_support
min_cg_attn_backend = attn_backend.__name__
attn_groups.append(groups)
min_cg_attn_backend = group.backend.__name__

return (
attn_backends,
Expand All @@ -123,6 +136,7 @@ def init_attn_backend(
min_cg_support=min_cg_support,
min_cg_attn_backend=min_cg_attn_backend,
),
kernel_block_sizes,
)


Expand All @@ -147,11 +161,16 @@ def _reshape_kv_cache(
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
attn_backends: dict[str, type[AttentionBackend]],
kernel_block_sizes: list[int],
cache_dtype: str,
) -> dict[str, Any]:
kv_caches: dict[str, Any] = {}
has_attn, has_mamba = False, False
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
for kv_cache_group_id, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups
):
if kv_cache_group_id >= len(kernel_block_sizes):
continue
Comment on lines +172 to +173
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In _reshape_kv_cache, the loop incorrectly assumes that kernel_block_sizes is 1-to-1 with kv_cache_config.kv_cache_groups. Because prepare_kernel_block_sizes skips certain specs, len(kernel_block_sizes) may be less than the number of groups. The continue on line 173 will cause the last groups in the configuration to be skipped entirely, leading to missing KV caches for those layers.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same underlying concern as above. Under the current KV cache config construction, the groups skipped by this guard are trailing encoder-only groups, which do not allocate KV cache.

for layer_name in kv_cache_group_spec.layer_names:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
Expand All @@ -164,9 +183,16 @@ def _reshape_kv_cache(
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
attn_backend = attn_backends[layer_name]
kernel_block_size = kernel_block_sizes[kv_cache_group_id]
num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
if kv_cache_spec.storage_block_size != kv_cache_spec.block_size:
shape_block_size = kv_cache_spec.storage_block_size
else:
shape_block_size = kernel_block_size
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.storage_block_size,
kernel_num_blocks,
shape_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
Expand Down Expand Up @@ -273,12 +299,17 @@ def init_kv_cache(
forward_context: dict[str, Any],
kv_cache_config: KVCacheConfig,
attn_backends: dict[str, type[AttentionBackend]],
kernel_block_sizes: list[int],
device: torch.device,
cache_dtype: str,
) -> dict[str, Any]:
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(
kv_cache_config, kv_cache_raw_tensors, attn_backends, cache_dtype
kv_cache_config,
kv_cache_raw_tensors,
attn_backends,
kernel_block_sizes,
cache_dtype,
)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
return kv_caches
Expand Down
44 changes: 41 additions & 3 deletions vllm/v1/worker/gpu/block_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import numpy as np
import torch

from vllm.triton_utils import tl, triton
Expand All @@ -13,6 +14,7 @@ class BlockTables:
def __init__(
self,
block_sizes: list[int],
kernel_block_sizes: list[int],
max_num_reqs: int,
max_num_batched_tokens: int,
max_num_blocks_per_group: list[int],
Expand All @@ -21,7 +23,24 @@ def __init__(
cp_rank: int = 0,
cp_interleave: int = 1,
):
self.block_sizes = block_sizes
if len(kernel_block_sizes) != len(block_sizes):
raise ValueError(
f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
f"must match block_sizes length ({len(block_sizes)})"
)
self.block_sizes = kernel_block_sizes
self.blocks_per_kv_block: list[int] = []
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes):
if block_size % kernel_block_size != 0:
raise ValueError(
f"kernel_block_size {kernel_block_size} must divide "
f"kv_manager_block_size {block_size} evenly"
)
self.blocks_per_kv_block.append(block_size // kernel_block_size)
self._kernel_block_offsets: list[np.ndarray] = [
np.arange(blocks_per_kv_block, dtype=np.int32).reshape(1, -1)
for blocks_per_kv_block in self.blocks_per_kv_block
]
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.device = device
Expand All @@ -35,7 +54,7 @@ def __init__(
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[StagedWriteTensor] = []
for i in range(self.num_kv_cache_groups):
max_num_blocks = max_num_blocks_per_group[i]
max_num_blocks = max_num_blocks_per_group[i] * self.blocks_per_kv_block[i]
block_table = StagedWriteTensor(
(self.max_num_reqs, max_num_blocks),
dtype=torch.int32,
Expand Down Expand Up @@ -88,6 +107,21 @@ def init_block_table_layout_tensors(self) -> None:
)
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)

@staticmethod
def map_to_kernel_blocks(
kv_manager_block_ids: np.ndarray,
blocks_per_kv_block: int,
kernel_block_offsets: np.ndarray,
) -> np.ndarray:
"""Convert KV cache manager block IDs to kernel block IDs."""
if blocks_per_kv_block == 1:
return kv_manager_block_ids
kernel_block_ids = (
kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
+ kernel_block_offsets
)
return kernel_block_ids.reshape(-1)

def append_block_ids(
self,
req_index: int,
Expand All @@ -96,7 +130,11 @@ def append_block_ids(
) -> None:
for i in range(self.num_kv_cache_groups):
start = self.num_blocks.np[i, req_index] if not overwrite else 0
block_ids = new_block_ids[i]
block_ids = self.map_to_kernel_blocks(
np.array(new_block_ids[i]),
self.blocks_per_kv_block[i],
self._kernel_block_offsets[i],
).tolist()
self.block_tables[i].stage_write(req_index, start, block_ids)
self.num_blocks.np[i, req_index] = start + len(block_ids)

Expand Down
12 changes: 8 additions & 4 deletions vllm/v1/worker/gpu/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,15 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
) + spec.num_speculative_blocks
max_num_blocks_per_group.append(max_num_blocks)

(
self.attn_backends,
self.attn_groups,
attn_cg_support,
kernel_block_sizes,
) = init_attn_backend(self.kv_cache_config, self.vllm_config, self.device)
self.block_tables = BlockTables(
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
max_num_reqs=self.max_num_reqs,
max_num_batched_tokens=self.max_num_tokens,
max_num_blocks_per_group=max_num_blocks_per_group,
Expand All @@ -391,10 +398,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
cp_rank=self.dcp_rank,
cp_interleave=self.cp_interleave,
)

self.attn_backends, self.attn_groups, attn_cg_support = init_attn_backend(
self.kv_cache_config, self.vllm_config, self.device
)
initialize_mamba_ssu_backend(
self.vllm_config.mamba_config, self.kv_cache_config
)
Expand Down Expand Up @@ -430,6 +433,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
self.compilation_config.static_forward_context,
self.kv_cache_config,
self.attn_backends,
kernel_block_sizes,
self.device,
self.cache_config.cache_dtype,
)
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def set_attn(
) -> None:
self.model_state = model_state
self.kv_cache_config = kv_cache_config
_, self.attn_groups, _ = init_attn_backend(
_, self.attn_groups, _, _ = init_attn_backend(
kv_cache_config,
self.vllm_config,
self.device,
Expand Down
Loading