-
-
Notifications
You must be signed in to change notification settings - Fork 17.8k
[Bugfix][Model Runner v2] Fix MRV2 KV cache kernel block sizing. #42872
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import torch | ||
|
|
||
| from vllm.platforms import current_platform | ||
|
|
||
| DEVICE_TYPE = current_platform.device_type | ||
|
|
||
|
|
||
| def test_v2_block_tables_kernel_block_expansion(): | ||
| from vllm.v1.worker.gpu.block_table import BlockTables | ||
|
|
||
| block_tables = BlockTables( | ||
| block_sizes=[128], | ||
| kernel_block_sizes=[64], | ||
| max_num_reqs=4, | ||
| max_num_batched_tokens=256, | ||
| max_num_blocks_per_group=[10], | ||
| device=torch.device(DEVICE_TYPE), | ||
| ) | ||
|
|
||
| block_tables.append_block_ids(0, ([0, 1, 2],), overwrite=True) | ||
| block_tables.apply_staged_writes() | ||
|
|
||
| assert block_tables.blocks_per_kv_block == [2] | ||
| assert block_tables.block_tables[0].gpu[0, :6].cpu().tolist() == [ | ||
| 0, | ||
| 1, | ||
| 2, | ||
| 3, | ||
| 4, | ||
| 5, | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,11 @@ | |
| UniformTypeKVCacheSpecs, | ||
| ) | ||
| from vllm.v1.worker.gpu.model_states.interface import ModelSpecificAttnMetadata | ||
| from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache | ||
| from vllm.v1.worker.utils import ( | ||
| AttentionGroup, | ||
| bind_kv_cache, | ||
| prepare_kernel_block_sizes, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
|
|
@@ -51,6 +55,7 @@ def init_attn_backend( | |
| dict[str, type[AttentionBackend]], | ||
| list[list[AttentionGroup]], | ||
| AttentionCGSupportInfo, | ||
| list[int], | ||
| ]: | ||
| attn_backends: dict[str, type[AttentionBackend]] = {} | ||
| attn_groups: list[list[AttentionGroup]] = [] | ||
|
|
@@ -91,12 +96,21 @@ def init_attn_backend( | |
| else: | ||
| group_map[key].layer_names.append(layer_name) | ||
|
|
||
| groups = [group_map[key] for key in group_order] | ||
| attn_groups.append([group_map[key] for key in group_order]) | ||
|
|
||
| kernel_block_sizes = prepare_kernel_block_sizes(kv_cache_config, attn_groups) | ||
| for kv_cache_group_id, groups in enumerate(attn_groups): | ||
| kernel_block_size = ( | ||
| kernel_block_sizes[kv_cache_group_id] | ||
| if kv_cache_group_id < len(kernel_block_sizes) | ||
| else None | ||
| ) | ||
| kv_cache_group_spec = kv_cache_config.kv_cache_groups[kv_cache_group_id] | ||
| for group in groups: | ||
| group.create_metadata_builders( | ||
| vllm_config=vllm_config, | ||
| device=device, | ||
| kernel_block_size=None, | ||
| kernel_block_size=kernel_block_size, | ||
| num_metadata_builders=1, | ||
| ) | ||
| builder = group.get_metadata_builder(0) | ||
|
|
@@ -113,8 +127,7 @@ def init_attn_backend( | |
| ) | ||
| if cg_support.value < min_cg_support.value: | ||
| min_cg_support = cg_support | ||
| min_cg_attn_backend = attn_backend.__name__ | ||
| attn_groups.append(groups) | ||
| min_cg_attn_backend = group.backend.__name__ | ||
|
|
||
| return ( | ||
| attn_backends, | ||
|
|
@@ -123,6 +136,7 @@ def init_attn_backend( | |
| min_cg_support=min_cg_support, | ||
| min_cg_attn_backend=min_cg_attn_backend, | ||
| ), | ||
| kernel_block_sizes, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -147,11 +161,16 @@ def _reshape_kv_cache( | |
| kv_cache_config: KVCacheConfig, | ||
| kv_cache_raw_tensors: dict[str, torch.Tensor], | ||
| attn_backends: dict[str, type[AttentionBackend]], | ||
| kernel_block_sizes: list[int], | ||
| cache_dtype: str, | ||
| ) -> dict[str, Any]: | ||
| kv_caches: dict[str, Any] = {} | ||
| has_attn, has_mamba = False, False | ||
| for kv_cache_group_spec in kv_cache_config.kv_cache_groups: | ||
| for kv_cache_group_id, kv_cache_group_spec in enumerate( | ||
| kv_cache_config.kv_cache_groups | ||
| ): | ||
| if kv_cache_group_id >= len(kernel_block_sizes): | ||
| continue | ||
|
Comment on lines
+172
to
+173
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same underlying concern as above. Under the current KV cache config construction, the groups skipped by this guard are trailing encoder-only groups, which do not allocate KV cache. |
||
| for layer_name in kv_cache_group_spec.layer_names: | ||
| kv_cache_spec = kv_cache_group_spec.kv_cache_spec | ||
| if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): | ||
|
|
@@ -164,9 +183,16 @@ def _reshape_kv_cache( | |
| if isinstance(kv_cache_spec, AttentionSpec): | ||
| has_attn = True | ||
| attn_backend = attn_backends[layer_name] | ||
| kernel_block_size = kernel_block_sizes[kv_cache_group_id] | ||
| num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size | ||
| kernel_num_blocks = num_blocks * num_blocks_per_kv_block | ||
| if kv_cache_spec.storage_block_size != kv_cache_spec.block_size: | ||
| shape_block_size = kv_cache_spec.storage_block_size | ||
| else: | ||
| shape_block_size = kernel_block_size | ||
| kv_cache_shape = attn_backend.get_kv_cache_shape( | ||
| num_blocks, | ||
| kv_cache_spec.storage_block_size, | ||
| kernel_num_blocks, | ||
| shape_block_size, | ||
| kv_cache_spec.num_kv_heads, | ||
| kv_cache_spec.head_size, | ||
| cache_dtype_str=cache_dtype, | ||
|
|
@@ -273,12 +299,17 @@ def init_kv_cache( | |
| forward_context: dict[str, Any], | ||
| kv_cache_config: KVCacheConfig, | ||
| attn_backends: dict[str, type[AttentionBackend]], | ||
| kernel_block_sizes: list[int], | ||
| device: torch.device, | ||
| cache_dtype: str, | ||
| ) -> dict[str, Any]: | ||
| kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device) | ||
| kv_caches = _reshape_kv_cache( | ||
| kv_cache_config, kv_cache_raw_tensors, attn_backends, cache_dtype | ||
| kv_cache_config, | ||
| kv_cache_raw_tensors, | ||
| attn_backends, | ||
| kernel_block_sizes, | ||
| cache_dtype, | ||
| ) | ||
| bind_kv_cache(kv_caches, forward_context, runner_kv_caches) | ||
| return kv_caches | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The indexing of `kernel_block_sizes` by `kv_cache_group_id` is incorrect because `prepare_kernel_block_sizes` (in `vllm/v1/worker/utils.py`) skips `EncoderOnlyAttentionSpec` groups. This results in a length mismatch and a misaligned mapping between groups and their kernel block sizes. If an encoder-only group exists, subsequent groups will receive the wrong block size or `None`, causing them to fall back to logical block sizes and defeating the purpose of this fix. `prepare_kernel_block_sizes` should be updated to return a list of the same length as `kv_cache_groups` (e.g., by using `None` for skipped groups).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`prepare_kernel_block_sizes()` currently returns a compact list because `EncoderOnlyAttentionSpec` does not allocate KV cache. In the current KV cache config construction, encoder-only groups are appended after regular KV cache groups, so the compact list remains aligned for all non-encoder-only groups, and the trailing encoder-only group is skipped by the existing guard.
This PR keeps that existing behavior and focuses on the Qwen3 + FlashInfer + NIXL MRV2 regression. If MRV2 later allows encoder-only groups before regular attention groups, we should revisit this
indexing contract separately.