def test_v2_block_tables_kernel_block_expansion():
    """Check that KV-manager block IDs are expanded into kernel block IDs.

    With a manager block size of 128 and a kernel block size of 64, every
    manager block ``b`` must occupy the two kernel blocks ``2 * b`` and
    ``2 * b + 1`` in the block table, and the per-request block count must
    be tracked in kernel-block units.
    """
    from vllm.v1.worker.gpu.block_table import BlockTables

    block_tables = BlockTables(
        block_sizes=[128],
        kernel_block_sizes=[64],
        max_num_reqs=4,
        max_num_batched_tokens=256,
        max_num_blocks_per_group=[10],
        device=torch.device(DEVICE_TYPE),
    )
    assert block_tables.blocks_per_kv_block == [2]

    # Contiguous IDs starting at zero: the expansion is simply 0..5.
    block_tables.append_block_ids(0, ([0, 1, 2],), overwrite=True)
    block_tables.apply_staged_writes()
    assert block_tables.block_tables[0].gpu[0, :6].cpu().tolist() == [
        0,
        1,
        2,
        3,
        4,
        5,
    ]
    # Three manager blocks were written, i.e. six kernel blocks.
    assert block_tables.num_blocks.np[0, 0] == 6

    # Non-contiguous, unordered IDs: only ``id * 2 + offset`` produces this
    # row, so an implementation that ignores the ID values cannot pass.
    block_tables.append_block_ids(1, ([3, 0, 7],), overwrite=True)
    block_tables.apply_staged_writes()
    assert block_tables.block_tables[0].gpu[1, :6].cpu().tolist() == [
        6,
        7,
        0,
        1,
        14,
        15,
    ]
    assert block_tables.num_blocks.np[0, 1] == 6
kernel_block_sizes = prepare_kernel_block_sizes(kv_cache_config, attn_groups) + for kv_cache_group_id, groups in enumerate(attn_groups): + kernel_block_size = ( + kernel_block_sizes[kv_cache_group_id] + if kv_cache_group_id < len(kernel_block_sizes) + else None + ) + kv_cache_group_spec = kv_cache_config.kv_cache_groups[kv_cache_group_id] for group in groups: group.create_metadata_builders( vllm_config=vllm_config, device=device, - kernel_block_size=None, + kernel_block_size=kernel_block_size, num_metadata_builders=1, ) builder = group.get_metadata_builder(0) @@ -113,8 +127,7 @@ def init_attn_backend( ) if cg_support.value < min_cg_support.value: min_cg_support = cg_support - min_cg_attn_backend = attn_backend.__name__ - attn_groups.append(groups) + min_cg_attn_backend = group.backend.__name__ return ( attn_backends, @@ -123,6 +136,7 @@ def init_attn_backend( min_cg_support=min_cg_support, min_cg_attn_backend=min_cg_attn_backend, ), + kernel_block_sizes, ) @@ -147,11 +161,16 @@ def _reshape_kv_cache( kv_cache_config: KVCacheConfig, kv_cache_raw_tensors: dict[str, torch.Tensor], attn_backends: dict[str, type[AttentionBackend]], + kernel_block_sizes: list[int], cache_dtype: str, ) -> dict[str, Any]: kv_caches: dict[str, Any] = {} has_attn, has_mamba = False, False - for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + for kv_cache_group_id, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups + ): + if kv_cache_group_id >= len(kernel_block_sizes): + continue for layer_name in kv_cache_group_spec.layer_names: kv_cache_spec = kv_cache_group_spec.kv_cache_spec if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): @@ -164,9 +183,16 @@ def _reshape_kv_cache( if isinstance(kv_cache_spec, AttentionSpec): has_attn = True attn_backend = attn_backends[layer_name] + kernel_block_size = kernel_block_sizes[kv_cache_group_id] + num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size + kernel_num_blocks = num_blocks * num_blocks_per_kv_block 
+ if kv_cache_spec.storage_block_size != kv_cache_spec.block_size: + shape_block_size = kv_cache_spec.storage_block_size + else: + shape_block_size = kernel_block_size kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, - kv_cache_spec.storage_block_size, + kernel_num_blocks, + shape_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=cache_dtype, @@ -273,12 +299,17 @@ def init_kv_cache( forward_context: dict[str, Any], kv_cache_config: KVCacheConfig, attn_backends: dict[str, type[AttentionBackend]], + kernel_block_sizes: list[int], device: torch.device, cache_dtype: str, ) -> dict[str, Any]: kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device) kv_caches = _reshape_kv_cache( - kv_cache_config, kv_cache_raw_tensors, attn_backends, cache_dtype + kv_cache_config, + kv_cache_raw_tensors, + attn_backends, + kernel_block_sizes, + cache_dtype, ) bind_kv_cache(kv_caches, forward_context, runner_kv_caches) return kv_caches diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py index 62454f70304d..8872a57422d5 100644 --- a/vllm/v1/worker/gpu/block_table.py +++ b/vllm/v1/worker/gpu/block_table.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable +import numpy as np import torch from vllm.triton_utils import tl, triton @@ -13,6 +14,7 @@ class BlockTables: def __init__( self, block_sizes: list[int], + kernel_block_sizes: list[int], max_num_reqs: int, max_num_batched_tokens: int, max_num_blocks_per_group: list[int], @@ -21,7 +23,24 @@ def __init__( cp_rank: int = 0, cp_interleave: int = 1, ): - self.block_sizes = block_sizes + if len(kernel_block_sizes) != len(block_sizes): + raise ValueError( + f"kernel_block_sizes length ({len(kernel_block_sizes)}) " + f"must match block_sizes length ({len(block_sizes)})" + ) + self.block_sizes = kernel_block_sizes + self.blocks_per_kv_block: list[int] = [] + for block_size, 
@staticmethod
def map_to_kernel_blocks(
    kv_manager_block_ids: np.ndarray,
    blocks_per_kv_block: int,
    kernel_block_offsets: np.ndarray,
) -> np.ndarray:
    """Convert KV cache manager block IDs to kernel block IDs.

    Each manager block ``b`` is split into ``blocks_per_kv_block`` kernel
    blocks with IDs ``b * blocks_per_kv_block + kernel_block_offsets``.
    For example, IDs ``[0, 1, 2]`` with ``blocks_per_kv_block=2`` and
    offsets ``[[0, 1]]`` become ``[0, 1, 2, 3, 4, 5]``.

    Args:
        kv_manager_block_ids: Block IDs from the KV cache manager. Any
            integer array-like is accepted; it is flattened.
        blocks_per_kv_block: Number of kernel blocks per manager block.
        kernel_block_offsets: Offsets added to each scaled ID; it must
            broadcast against a ``(num_ids, 1)`` column (the constructor
            builds it as ``arange(blocks_per_kv_block)`` with shape
            ``(1, blocks_per_kv_block)``).

    Returns:
        A 1-D array of kernel block IDs. When ``blocks_per_kv_block`` is 1
        the (normalised) input is returned without expansion.
    """
    block_ids = np.asarray(kv_manager_block_ids)
    if block_ids.size == 0 and block_ids.dtype.kind not in "iu":
        # ``np.array([])`` defaults to float64; block IDs must stay
        # integral even when a request has no new blocks.
        block_ids = block_ids.astype(np.int64)
    if blocks_per_kv_block == 1:
        return block_ids
    # Column of scaled IDs broadcast against the row of offsets gives one
    # row of kernel block IDs per manager block.
    kernel_block_ids = (
        block_ids.reshape(-1, 1) * blocks_per_kv_block + kernel_block_offsets
    )
    return kernel_block_ids.reshape(-1)
self.num_blocks.np[i, req_index] = start + len(block_ids) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 091138d9ab5f..2e19b507386e 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -381,8 +381,15 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: ) + spec.num_speculative_blocks max_num_blocks_per_group.append(max_num_blocks) + ( + self.attn_backends, + self.attn_groups, + attn_cg_support, + kernel_block_sizes, + ) = init_attn_backend(self.kv_cache_config, self.vllm_config, self.device) self.block_tables = BlockTables( block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, max_num_reqs=self.max_num_reqs, max_num_batched_tokens=self.max_num_tokens, max_num_blocks_per_group=max_num_blocks_per_group, @@ -391,10 +398,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: cp_rank=self.dcp_rank, cp_interleave=self.cp_interleave, ) - - self.attn_backends, self.attn_groups, attn_cg_support = init_attn_backend( - self.kv_cache_config, self.vllm_config, self.device - ) initialize_mamba_ssu_backend( self.vllm_config.mamba_config, self.kv_cache_config ) @@ -430,6 +433,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.compilation_config.static_forward_context, self.kv_cache_config, self.attn_backends, + kernel_block_sizes, self.device, self.cache_config.cache_dtype, ) diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 00777cbd81d1..af1c3608da88 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -173,7 +173,7 @@ def set_attn( ) -> None: self.model_state = model_state self.kv_cache_config = kv_cache_config - _, self.attn_groups, _ = init_attn_backend( + _, self.attn_groups, _, _ = init_attn_backend( kv_cache_config, self.vllm_config, self.device,