diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index d8ecf28cbed1..1b180e28628d 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -34,6 +34,8 @@ init_none_hash, is_kv_cache_spec_uniform, make_block_hash_with_group_id, + merge_attn_layers_into_pack, + split_attn_layers_from_pack, tensor_data, ) from vllm.v1.kv_cache_interface import ( @@ -110,6 +112,7 @@ def new_kv_cache_spec( page_size_padded=None, sliding_window=None, attention_chunk_size=None, + pack_size=1, ): return FullAttentionSpec( block_size=block_size, @@ -119,6 +122,7 @@ def new_kv_cache_spec( page_size_padded=page_size_padded, sliding_window=sliding_window, attention_chunk_size=attention_chunk_size, + pack_size=pack_size, ) @@ -2137,3 +2141,194 @@ def test_unify_hybrid_kv_cache_specs(): with pytest.raises(ValueError): kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec) + + +def test_merge_attn_layers_into_pack(): + attn_pack_size = 2 + + hybrid_kv_cache_specs = { + "layer_1": new_mamba_spec(), + "layer_2": new_kv_cache_spec(head_size=32), + "layer_3": new_kv_cache_spec(head_size=32), + } + merged_kv_cache_specs = merge_attn_layers_into_pack( + attn_pack_size, hybrid_kv_cache_specs + ) + assert merged_kv_cache_specs == { + "layer_1": new_mamba_spec(), + "layer_2+layer_3": new_kv_cache_spec(head_size=32, pack_size=2), + } + + hybrid_kv_cache_specs = { + "layer_1": new_mamba_spec(), + "layer_2": new_kv_cache_spec(head_size=32), + "layer_3": new_kv_cache_spec(head_size=32), + "layer_4": new_kv_cache_spec(head_size=32), + } + merged_kv_cache_specs = merge_attn_layers_into_pack( + attn_pack_size, hybrid_kv_cache_specs + ) + assert merged_kv_cache_specs == { + "layer_1": new_mamba_spec(), + "layer_2+layer_3": new_kv_cache_spec(head_size=32, pack_size=attn_pack_size), + "layer_4": new_kv_cache_spec(head_size=32, pack_size=attn_pack_size), + } + + +def test_split_attn_layers_from_pack(): + attn_pack_size = 2 + expected_page_size = new_mamba_spec().page_size_bytes + + kv_cache_config = KVCacheConfig( + num_blocks=20, + kv_cache_tensors=[ + KVCacheTensor( + size=expected_page_size * 20, + shared_by=["layer_1", "layer_2+layer_3"], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_mamba_spec()), + KVCacheGroupSpec( + ["layer_2+layer_3"], + new_kv_cache_spec(head_size=32, pack_size=2), + ), + ], + ) + split_kv_cache_config = split_attn_layers_from_pack(attn_pack_size, kv_cache_config) + + assert split_kv_cache_config == KVCacheConfig( + num_blocks=20, + kv_cache_tensors=[ + KVCacheTensor( + size=expected_page_size * 20, + shared_by=["layer_1", "layer_2", "layer_3"], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_mamba_spec()), + KVCacheGroupSpec( + ["layer_2", "layer_3"], + new_kv_cache_spec(head_size=32, pack_size=attn_pack_size), + ), + ], + ) + + +def test_get_kv_cache_configs_with_mamba(): + model_config = ModelConfig(max_model_len=16) + vllm_config = VllmConfig(model_config=model_config) + + expected_page_size = new_mamba_spec().page_size_bytes + + # Test 1: Pure mamba model (2 layers with same spec) + kv_cache_specs = { + "layer_1": new_mamba_spec(), + "layer_2": new_mamba_spec(), + } + available_memory = expected_page_size * 2 * 10 + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_specs], [available_memory] + )[0] + + assert kv_cache_config == KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=expected_page_size * 10, shared_by=["layer_1"]), + 
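+            # Layers with identical specs share one group below, but each
+            # mamba layer still gets its own backing tensor.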
KVCacheTensor(size=expected_page_size * 10, shared_by=["layer_2"]), + ], + kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_mamba_spec())], + ) + + # Test 2: 1 mamba + 1 full + hybrid_kv_cache_specs = { + "layer_1": new_mamba_spec(), + "layer_2": new_kv_cache_spec(), + } + available_memory_hybrid = expected_page_size * 2 * 10 + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [hybrid_kv_cache_specs], [available_memory_hybrid] + )[0] + + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=20, + kv_cache_tensors=[ + KVCacheTensor( + size=expected_page_size * 20, shared_by=["layer_1", "layer_2"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_mamba_spec()), + KVCacheGroupSpec(["layer_2"], new_kv_cache_spec()), + ], + ) + + # Test 3: 1 mamba + 2 full attention with group size 2 + vllm_config.cache_config.attn_pack_size = 2 + hybrid_kv_cache_specs = { + "layer_1": new_mamba_spec(), + "layer_2": new_kv_cache_spec(head_size=32), + "layer_3": new_kv_cache_spec(head_size=32), + } + available_memory_hybrid = expected_page_size * 2 * 10 + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [hybrid_kv_cache_specs], [available_memory_hybrid] + )[0] + + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=20, + kv_cache_tensors=[ + KVCacheTensor( + size=expected_page_size * 20, + shared_by=["layer_1", "layer_2", "layer_3"], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_mamba_spec()), + KVCacheGroupSpec( + ["layer_2", "layer_3"], + new_kv_cache_spec(head_size=32, pack_size=2), + ), + ], + ) + + # Test 4: 2 mamba + 5 full (with 3 padding full) + vllm_config.cache_config.attn_pack_size = 2 + hybrid_kv_cache_specs = { + "layer_1": new_mamba_spec(), + "layer_2": new_mamba_spec(), + "layer_3": new_kv_cache_spec(head_size=32), + "layer_4": new_kv_cache_spec(head_size=32), + "layer_5": new_kv_cache_spec(head_size=32), + "layer_6": new_kv_cache_spec(head_size=32), + "layer_7": new_kv_cache_spec(head_size=32), + } + available_memory_hybrid = expected_page_size * 2 * 10 + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [hybrid_kv_cache_specs], [available_memory_hybrid] + )[0] + + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=expected_page_size * 10, + shared_by=["layer_1", "layer_3", "layer_4", "layer_5", "layer_6"], + ), + KVCacheTensor( + size=expected_page_size * 10, + shared_by=["layer_2", "layer_7"], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2"], new_mamba_spec()), + KVCacheGroupSpec( + ["layer_3", "layer_4", "layer_7"], + new_kv_cache_spec(head_size=32, pack_size=2), + ), + KVCacheGroupSpec( + ["layer_5", "layer_6"], + new_kv_cache_spec(head_size=32, pack_size=2), + ), + ], + ) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 0de443858c98..ecd5489815bf 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from math import prod from types import SimpleNamespace import numpy as np @@ -10,6 +11,7 @@ from vllm.config import ( AttentionConfig, CacheConfig, + LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig, @@ -26,7 +28,7 @@ from vllm.sampling_params import SamplingParams from vllm.utils.mem_constants import GiB_bytes from vllm.utils.system_utils import update_environment_variables -from 
vllm.utils.torch_utils import set_random_seed +from vllm.utils.torch_utils import get_dtype_size, set_random_seed from vllm.v1.attention.backend import MultipleOf from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs @@ -1321,3 +1323,218 @@ def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks(): with pytest.raises(ValueError, match="max_num_seqs"): runner.initialize_kv_cache(kv_cache_config) + + +@pytest.mark.skipif( + current_platform.is_rocm(), + reason="Attention backend FLASHINFER is not supported on ROCm.", +) +@pytest.mark.parametrize("pack_size", [1, 2, 4]) +def test_hybrid_attention_mamba_kv_cache_pack_size(pack_size: int): + """ + Test that KV cache layout for Attention layers in a hybrid (Attention+Mamba) + model correctly reflects pack_size settings (1, 2, 4) after calling + get_kv_cache_configs() and initialize_kv_cache(). + + Verifications: + 1. Attention layer kv_cache logical shape equals + (num_blocks, 2, block_size, num_kv_heads, head_size). + 2. Attention layer kv_cache stride at dim-0 (block dim) equals + base_stride * pack_size, reflecting the packed layout. + 3. When pack_size > 1, all Attention layers in a pack share the same + underlying storage; storage_offset of layer i equals + i * num_element_per_attn_pack. + 4. Writes to one packed Attention layer's region do NOT corrupt sibling + layers (write-isolation check). + 5. Mamba layer kv_cache block count is independent of pack_size. + """ + set_random_seed(42) + + update_environment_variables( + { + "RANK": "0", + "LOCAL_RANK": "0", + "WORLD_SIZE": "1", + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) + from tests.utils import ensure_current_vllm_config + + with ensure_current_vllm_config(): + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=1) + torch.set_default_dtype(torch.float16) + + model_config = ModelConfig( + model="ibm-granite/granite-4.0-tiny-preview", + dtype="float16", + ) + scheduler_config = SchedulerConfig( + max_num_seqs=10, + max_num_batched_tokens=512, + max_model_len=512, + is_encoder_decoder=model_config.is_encoder_decoder, + ) + cache_config = CacheConfig( + block_size=BLOCK_SIZE, + gpu_memory_utilization=0.9, + cache_dtype="auto", + attn_pack_size=pack_size, + ) + parallel_config = ParallelConfig() + vllm_config = VllmConfig( + model_config=model_config, + cache_config=cache_config, + scheduler_config=scheduler_config, + parallel_config=parallel_config, + attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASHINFER), + load_config=LoadConfig(load_format="dummy"), + ) + + attn_layer_names = [ + "model.layers.0.self_attn.attn", + "model.layers.1.self_attn.attn", + "model.layers.2.self_attn.attn", + "model.layers.3.self_attn.attn", + ] + mamba_layer_names = [ + "model.layers.4.mixer", + "model.layers.5.mixer", + "model.layers.6.mixer", + "model.layers.7.mixer", + ] + + with set_current_vllm_config(vllm_config): + hf_config = vllm_config.model_config.hf_config + for key in attn_layer_names: + Attention( + num_heads=model_config.get_num_attention_heads(parallel_config), + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + scale=1.0, + prefix=key, + ) + for key in mamba_layer_names: + MambaMixer2( + hidden_size=hf_config.hidden_size, + ssm_state_size=hf_config.mamba_d_state, + conv_kernel_size=hf_config.mamba_d_conv, + intermediate_size=hf_config.mamba_expand * 
hf_config.hidden_size, + use_conv_bias=hf_config.mamba_conv_bias, + use_bias=hf_config.mamba_proj_bias, + n_groups=hf_config.mamba_n_groups, + num_heads=hf_config.mamba_n_heads, + head_dim=hf_config.mamba_d_head, + rms_norm_eps=hf_config.rms_norm_eps, + activation=hf_config.hidden_act, + cache_config=cache_config, + model_config=model_config, + prefix=key, + ) + vllm_ctx = vllm_config.compilation_config.static_forward_context + + runner = GPUModelRunner(vllm_config, DEVICE_TYPE) + current_platform.update_block_size_for_backend(vllm_config) + kv_cache_spec = runner.get_kv_cache_spec() + + available_memory = 5 * GiB_bytes + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_spec], [available_memory] + )[0] + runner.initialize_kv_cache(kv_cache_config) + + num_blocks = kv_cache_config.num_blocks + + # After `get_kv_cache_configs()`, the merged AttentionSpec + # (with `pack_size` set to `attn_pack_size`) is stored in + # `kv_cache_config.kv_cache_groups[].kv_cache_spec`. + # Use it as the authoritative spec for layout calculations; + # `runner.get_kv_cache_spec()` always returns per-layer specs + # with pack_size=1 and is unaffected by the merge. + attn_group_spec = next( + g.kv_cache_spec + for g in kv_cache_config.kv_cache_groups + if isinstance(g.kv_cache_spec, AttentionSpec) + ) + assert attn_group_spec.pack_size == pack_size, ( + f"attn_group_spec.pack_size expected {pack_size}, " + f"got {attn_group_spec.pack_size}" + ) + + kv0 = vllm_ctx[attn_layer_names[0]].kv_cache + # FlashInfer logical shape: + # (kernel_num_blocks, 2, kernel_block_size, num_kv_heads, head_size) + expected_attn_shape = tuple(kv0.shape) + base_block_stride = prod(expected_attn_shape[1:]) + expected_block_stride = base_block_stride * pack_size + + dtype_size = get_dtype_size(attn_group_spec.dtype) + kernel_block_size = expected_attn_shape[2] # dim-2 of FlashInfer shape + num_blocks_per_kv_block = attn_group_spec.block_size // kernel_block_size + num_element_per_attn_pack = ( + attn_group_spec.page_size_bytes + // dtype_size + // num_blocks_per_kv_block + // attn_group_spec.pack_size + ) + + # --- 1 & 2. Verify shape and dim-0 stride for all Attention layers --- + for layer_name in attn_layer_names: + kv = vllm_ctx[layer_name].kv_cache + assert tuple(kv.shape) == expected_attn_shape, ( + f"pack_size={pack_size}, {layer_name}: " + f"expected shape {expected_attn_shape}, got {tuple(kv.shape)}" + ) + assert kv.stride(0) == expected_block_stride, ( + f"pack_size={pack_size}, {layer_name}: " + f"dim-0 stride expected {expected_block_stride}, got {kv.stride(0)}" + ) + + # --- 3 & 4. Verify storage sharing, offsets, and write isolation --- + for pack_start in range(0, len(attn_layer_names), pack_size): + pack_layers = attn_layer_names[pack_start : pack_start + pack_size] + kv_tensors = [vllm_ctx[ln].kv_cache for ln in pack_layers] + + if pack_size > 1: + # All layers in a pack share the same underlying storage. + base_ptr = kv_tensors[0].storage().data_ptr() + for i, (ln, kv) in enumerate(zip(pack_layers, kv_tensors)): + assert kv.storage().data_ptr() == base_ptr, ( + f"pack_size={pack_size}, {ln}: storage not shared with pack leader" + ) + # Layer at pack position i has offset i * num_element_per_attn_pack. + expected_offset = i * num_element_per_attn_pack + assert kv.storage_offset() == expected_offset, ( + f"pack_size={pack_size}, {ln} (pack_idx={i}): " + f"storage_offset expected {expected_offset}, " + f"got {kv.storage_offset()}" + ) + # Write-isolation: filling block-0 of layer i must not corrupt layer j. 
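+                # The packed layers are strided views of one storage, so a
+                # wrong stride or offset would make these fills alias each
+                # other and trip the min/max check below.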
+ for i, kv in enumerate(kv_tensors): + kv[0].fill_(float(i + 1)) + for i, (ln, kv) in enumerate(zip(pack_layers, kv_tensors)): + fill_val = float(i + 1) + assert ( + kv[0].min().item() == fill_val and kv[0].max().item() == fill_val + ), f"pack_size={pack_size}, {ln}: block 0 corrupted by sibling writes" + else: + # pack_size == 1: independent storage, offset must be 0. + for ln, kv in zip(pack_layers, kv_tensors): + assert kv.storage_offset() == 0, ( + f"pack_size=1, {ln}: storage_offset expected 0, " + f"got {kv.storage_offset()}" + ) + + # --- 5. Verify Mamba layer block count is independent of pack_size --- + for layer_name in mamba_layer_names: + conv_state = vllm_ctx[layer_name].kv_cache[0] + ssm_state = vllm_ctx[layer_name].kv_cache[1] + assert conv_state.shape[0] == num_blocks, ( + f"pack_size={pack_size}, {layer_name}: " + f"conv_state.shape[0] expected {num_blocks}, got {conv_state.shape[0]}" + ) + assert ssm_state.shape[0] == num_blocks, ( + f"pack_size={pack_size}, {layer_name}: " + f"ssm_state.shape[0] expected {num_blocks}, got {ssm_state.shape[0]}" + ) diff --git a/tests/v1/worker/test_hybrid_kv_cache_layout.py b/tests/v1/worker/test_hybrid_kv_cache_layout.py new file mode 100644 index 000000000000..fa5cb560e9ae --- /dev/null +++ b/tests/v1/worker/test_hybrid_kv_cache_layout.py @@ -0,0 +1,286 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from math import prod + +import pytest +import torch + +from vllm.utils.torch_utils import get_dtype_size +from vllm.v1.attention.backend import AttentionBackend +from vllm.v1.attention.backends.cpu_attn import CPUAttentionBackend +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.attention.backends.flash_attn_diffkv import FlashAttentionDiffKVBackend +from vllm.v1.attention.backends.flashinfer import FlashInferBackend +from vllm.v1.attention.backends.flex_attention import FlexAttentionBackend +from vllm.v1.attention.backends.tree_attn import TreeAttentionBackend +from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend +from vllm.v1.attention.backends.utils import ( + get_kv_cache_layout, + set_kv_cache_layout, +) +from vllm.v1.kv_cache_interface import AttentionSpec, FullAttentionSpec +from vllm.v1.worker import mamba_utils + + +def _build_full_attn_spec( + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + pack_size: int, +) -> FullAttentionSpec: + # Minimal valid FullAttentionSpec for layout tests. + return FullAttentionSpec( + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + head_size_v=head_size, + dtype=dtype, + pack_size=pack_size, + ) + + +def _compute_layout_ref( + backend: AttentionBackend, + kv_cache_spec: AttentionSpec, + layer_idx: int, + kernel_block_size: int, + num_blocks: int, + enable_hybrid_attn_mamba_layout: bool, +): + """ + Reference implementation that mirrors the pre-refactor logic: + - grouping layout in `_reshape_kv_cache_tensors` + - followed by `_update_hybrid_attention_mamba_layout` when enabled. + """ + dtype = kv_cache_spec.dtype + block_size = kv_cache_spec.block_size + attn_pack_size = kv_cache_spec.pack_size + + num_blocks_per_kv_block = block_size // kernel_block_size + kernel_num_blocks = num_blocks * num_blocks_per_kv_block + + # Match `_reshape_kv_cache_tensors`: use `kernel_num_blocks` and + # `kernel_block_size` when querying backend shape. 
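+    # (One KV-manager block of block_size tokens spans
+    # num_blocks_per_kv_block consecutive kernel blocks.)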
+    kv_cache_shape_logical = backend.get_kv_cache_shape(
+        kernel_num_blocks,
+        kernel_block_size,
+        kv_cache_spec.num_kv_heads,
+        kv_cache_spec.head_size,
+        cache_dtype_str="auto",
+    )
+
+    try:
+        kv_cache_stride_order = backend.get_kv_cache_stride_order()
+        assert len(kv_cache_stride_order) == len(kv_cache_shape_logical)
+    except (AttributeError, NotImplementedError):
+        kv_cache_stride_order = tuple(range(len(kv_cache_shape_logical)))
+
+    # Physical shape (the one used for allocation / as_strided).
+    kv_cache_shape = tuple(kv_cache_shape_logical[i] for i in kv_cache_stride_order)
+
+    inv_order = [
+        kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
+    ]
+
+    # Base contiguous strides in physical layout.
+    base_stride = list(torch.empty(kv_cache_shape).stride())
+
+    storage_offset = 0
+    block_dim = backend.get_kv_cache_block_dim(
+        kernel_block_size,
+        kv_cache_spec.num_kv_heads,
+        kv_cache_spec.head_size,
+        cache_dtype_str="auto",
+    )
+    if attn_pack_size > 1:
+        # Match the original `_reshape_kv_cache_tensors` logic.
+        base_stride[block_dim] *= attn_pack_size
+        dtype_size = get_dtype_size(dtype)
+        num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size
+        num_element_per_attn_pack = (
+            num_element_per_page // num_blocks_per_kv_block // attn_pack_size
+        )
+        attn_pack_idx = layer_idx % attn_pack_size
+        storage_offset = attn_pack_idx * num_element_per_attn_pack
+
+    # Logical KV tensor after the initial reshape.
+    kv = torch.empty_strided(
+        size=kv_cache_shape,
+        stride=tuple(base_stride),
+        dtype=dtype,
+    ).permute(*inv_order)
+
+    # Optional hybrid attention+mamba layout update.
+    # We analytically update the stride to match the pre-refactor
+    # `_update_hybrid_attention_mamba_layout` without actually changing the
+    # underlying storage.
+    if (
+        enable_hybrid_attn_mamba_layout
+        and isinstance(kv_cache_spec, AttentionSpec)
+        and block_dim == 1
+    ):
+        hidden_size = prod(kv.shape[2:])
+        attn_pack_size_for_layout = kv_cache_spec.pack_size
+        kv_stride = kv.stride()
+        kv_stride = (
+            hidden_size,
+            2 * hidden_size * attn_pack_size_for_layout,
+            *kv_stride[2:],
+        )
+        return kv.shape, kv_stride, storage_offset
+
+    return kv.shape, kv.stride(), storage_offset
+
+
+def _compute_layout_new(
+    backend: AttentionBackend,
+    kv_cache_spec: AttentionSpec,
+    layer_idx: int,
+    kernel_block_size: int,
+    num_blocks: int,
+    enable_hybrid_attn_mamba_layout: bool,
+):
+    """
+    Layout computed via the new helper
+    `mamba_utils.get_hybrid_attention_mamba_layout`.
+    """
+    dtype = kv_cache_spec.dtype
+    block_size = kv_cache_spec.block_size
+
+    num_blocks_per_kv_block = block_size // kernel_block_size
+    kernel_num_blocks = num_blocks * num_blocks_per_kv_block
+
+    kv_cache_shape_logical = backend.get_kv_cache_shape(
+        kernel_num_blocks,
+        kernel_block_size,
+        kv_cache_spec.num_kv_heads,
+        kv_cache_spec.head_size,
+        cache_dtype_str="auto",
+    )
+
+    try:
+        kv_cache_stride_order = backend.get_kv_cache_stride_order()
+        assert len(kv_cache_stride_order) == len(kv_cache_shape_logical)
+    except (AttributeError, NotImplementedError):
+        kv_cache_stride_order = tuple(range(len(kv_cache_shape_logical)))
+
+    kv_cache_shape = tuple(kv_cache_shape_logical[i] for i in kv_cache_stride_order)
+    inv_order = [
+        kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
+    ]
+
+    kv_cache_stride = tuple(torch.empty(kv_cache_shape).stride())
+    storage_offset = 0
+
+    if enable_hybrid_attn_mamba_layout:
+        block_dim = backend.get_kv_cache_block_dim(
+            kernel_block_size,
+            kv_cache_spec.num_kv_heads,
+            kv_cache_spec.head_size,
+            cache_dtype_str="auto",
+        )
+        kv_cache_stride, storage_offset = mamba_utils.get_hybrid_attention_mamba_layout(
+            kv_cache_shape=kv_cache_shape,
+            kv_cache_stride=kv_cache_stride,
+            kv_cache_spec=kv_cache_spec,
+            block_dim=block_dim,
+            layer_idx=layer_idx,
+            kernel_block_size=kernel_block_size,
+        )
+
+    kv = torch.empty_strided(
+        size=kv_cache_shape,
+        stride=kv_cache_stride,
+        dtype=dtype,
+    ).permute(*inv_order)
+
+    # Sanity: pack_size should not affect the logical shape.
+    assert kv.shape == tuple(
+        backend.get_kv_cache_shape(
+            num_blocks,
+            block_size,
+            kv_cache_spec.num_kv_heads,
+            kv_cache_spec.head_size,
+            cache_dtype_str="auto",
+        )
+    )
+
+    return kv.shape, kv.stride(), storage_offset
+
+
+@pytest.mark.parametrize(
+    "backend_cls",
+    [
+        CPUAttentionBackend,
+        FlashAttentionBackend,
+        FlashInferBackend,
+        TritonAttentionBackend,
+        FlashAttentionDiffKVBackend,
+        FlexAttentionBackend,
+        TreeAttentionBackend,
+    ],
+)
+@pytest.mark.parametrize("pack_size", [1, 2, 4])
+@pytest.mark.parametrize("enable_hybrid_attn_mamba_layout", [False, True])
+@pytest.mark.parametrize("cache_layout", ["NHD", "HND"])
+def test_hybrid_attention_mamba_layout_matches_reference(
+    backend_cls: type[AttentionBackend],
+    cache_layout: str,
+    pack_size: int,
+    enable_hybrid_attn_mamba_layout: bool,
+):
+    if (not enable_hybrid_attn_mamba_layout) and pack_size > 1:
+        pytest.skip("pack_size > 1 only occurs when hybrid attention+mamba is enabled")
+    # Explicitly test both cache layouts for backends that depend on it
+    # (FlashAttentionBackend / FlashInferBackend). Other backends ignore
+    # this setting, but it is harmless to apply globally.
+    set_kv_cache_layout(cache_layout)  # sets the override
+    # Invalidate the cached value in get_kv_cache_layout so the override
+    # takes effect for this test case.
+    get_kv_cache_layout.cache_clear()  # type: ignore[attr-defined]
+
+    block_size = 16
+    num_kv_heads = 2
+    head_size = 32
+    dtype = torch.float16
+
+    num_blocks = 100
+    kernel_block_size = 16
+    layer_idx = 1
+
+    kv_cache_spec = _build_full_attn_spec(
+        block_size=block_size,
+        num_kv_heads=num_kv_heads,
+        head_size=head_size,
+        dtype=dtype,
+        pack_size=pack_size,
+    )
+
+    backend = backend_cls
+
+    ref_shape, ref_stride, ref_offset = _compute_layout_ref(
+        backend=backend,
+        kv_cache_spec=kv_cache_spec,
+        layer_idx=layer_idx,
+        kernel_block_size=kernel_block_size,
+        num_blocks=num_blocks,
+        enable_hybrid_attn_mamba_layout=enable_hybrid_attn_mamba_layout,
+    )
+
+    new_shape, new_stride, new_offset = _compute_layout_new(
+        backend=backend,
+        kv_cache_spec=kv_cache_spec,
+        layer_idx=layer_idx,
+        kernel_block_size=kernel_block_size,
+        num_blocks=num_blocks,
+        enable_hybrid_attn_mamba_layout=enable_hybrid_attn_mamba_layout,
+    )
+
+    assert ref_shape == new_shape
+    assert ref_stride == new_stride
+    # storage_offset only differs from zero when pack_size > 1.
+    if pack_size > 1:
+        assert new_offset == ref_offset
+    else:
+        assert ref_offset == 0
+        assert new_offset == 0
diff --git a/vllm/config/cache.py b/vllm/config/cache.py
index cd1554590ea3..2d4e879e47a5 100644
--- a/vllm/config/cache.py
+++ b/vllm/config/cache.py
@@ -29,6 +29,7 @@
 ]
 MambaDType = Literal["auto", "float32", "float16"]
 MambaCacheMode = Literal["all", "align", "none"]
+AttnPackSize = Literal[1, 2, 4, 8]
 PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
 KVOffloadingBackend = Literal["native", "lmcache"]
 
@@ -131,6 +132,9 @@ class CacheConfig:
     """Number of Philox PRNG rounds for stochastic rounding random number
    generation. 0 uses the Triton default. Higher values improve randomness
    quality at the cost of compute."""
+    attn_pack_size: AttnPackSize = 1
+    """The number of full-attention layers packed into a single KV cache page.
+    This is only relevant for models that include Mamba layers."""
 
     # Will be set after profiling.
     num_gpu_blocks: int | None = field(default=None, init=False)
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 348ea40884a1..af5816962742 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -1287,6 +1287,12 @@ def has_blocked_weights():
         # Default to enable HMA if not explicitly disabled by user or logic above.
self.scheduler_config.disable_hybrid_kv_cache_manager = False + if self.cache_config.attn_pack_size > 1: + assert self.model_config.is_hybrid, ( + "Packing multiple FullAttention layers to a single page is only " + "supported for hybrid models" + ) + if self.compilation_config.debug_dump_path: self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path.absolute().expanduser() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c9b90848ff04..0e49a16b871c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -63,6 +63,7 @@ get_attr_docs, ) from vllm.config.cache import ( + AttnPackSize, CacheDType, KVOffloadingBackend, MambaCacheMode, @@ -614,6 +615,7 @@ class EngineArgs: CacheConfig.enable_mamba_cache_stochastic_rounding ) mamba_cache_philox_rounds: int = CacheConfig.mamba_cache_philox_rounds + attn_pack_size: AttnPackSize = CacheConfig.attn_pack_size additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") @@ -1057,6 +1059,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: cache_group.add_argument( "--mamba-cache-philox-rounds", **cache_kwargs["mamba_cache_philox_rounds"] ) + cache_group.add_argument("--attn-pack-size", **cache_kwargs["attn_pack_size"]) cache_group.add_argument( "--kv-offloading-size", **cache_kwargs["kv_offloading_size"] ) @@ -1627,6 +1630,7 @@ def create_engine_config( mamba_cache_mode=self.mamba_cache_mode, enable_mamba_cache_stochastic_rounding=self.enable_mamba_cache_stochastic_rounding, mamba_cache_philox_rounds=self.mamba_cache_philox_rounds, + attn_pack_size=self.attn_pack_size, kv_offloading_size=self.kv_offloading_size, kv_offloading_backend=self.kv_offloading_backend, ) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index c97c9118ac6e..97be9339a487 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -533,6 +533,14 @@ def _align_hybrid_block_size( kv_quant_mode = get_kv_quant_mode(cache_config.cache_dtype) + if cache_config.attn_pack_size > 1: + logger.info( + "Packing %d attention layers into one page, which will " + "reduce the block size to roughly 1/%d of the original.", + cache_config.attn_pack_size, + cache_config.attn_pack_size, + ) + # Compute attention page size for 1 token if model_config.use_mla: attn_page_size_1_token = MLAAttentionSpec( @@ -541,6 +549,7 @@ def _align_hybrid_block_size( head_size=model_config.get_head_size(), dtype=kv_cache_dtype, kv_quant_mode=kv_quant_mode, + pack_size=cache_config.attn_pack_size, ).page_size_bytes else: attn_page_size_1_token = FullAttentionSpec( @@ -549,6 +558,7 @@ def _align_hybrid_block_size( head_size=model_config.get_head_size(), dtype=kv_cache_dtype, kv_quant_mode=kv_quant_mode, + pack_size=cache_config.attn_pack_size, ).page_size_bytes # Compute mamba page size diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9ab5af0f6fb0..7528abab87ed 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -18,6 +18,7 @@ from vllm.utils.math_utils import cdiv from vllm.utils.mem_utils import format_gib from vllm.v1.kv_cache_interface import ( + AttentionSpec, ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, @@ -1505,6 +1506,104 @@ def _project_kv_cache_groups_to_worker( return projected_groups +def merge_attn_layers_into_pack( + attn_pack_size: int, + kv_cache_specs: dict[str, KVCacheSpec], +) -> dict[str, KVCacheSpec]: + """ + Merge attention layers into packs based on attn_pack_size. 
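+    For example, with attn_pack_size=2, specs
+    {"m": mamba, "a1": attn, "a2": attn, "a3": attn} become
+    {"m": mamba, "a1+a2": attn(pack_size=2), "a3": attn(pack_size=2)}.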
+
+    When attn_pack_size > 1, consecutive attention layers are packed
+    together to share a KV-block, with the block partitioned across layers.
+    This function packs every attn_pack_size consecutive attention layers
+    into a group, using "+" as a delimiter to join their layer names into a
+    new key.
+
+    Args:
+        attn_pack_size: The number of layers in each attention pack.
+        kv_cache_specs: A dictionary mapping layer names to their KV cache specs.
+
+    Returns:
+        A merged KV cache spec dictionary where consecutive attention layers
+        are packed together.
+    """
+    merged_kv_cache_specs: dict[str, KVCacheSpec] = {}
+    att_groups: defaultdict[AttentionSpec, tuple[list[str], list[AttentionSpec]]] = (
+        defaultdict(lambda: ([], []))
+    )
+    for layer_name, kv_spec in kv_cache_specs.items():
+        if isinstance(kv_spec, AttentionSpec):
+            att_groups[kv_spec][0].append(layer_name)
+            att_groups[kv_spec][1].append(kv_spec)
+        else:
+            merged_kv_cache_specs[layer_name] = kv_spec
+    for kv_spec, layers in att_groups.items():
+        layer_names, layer_specs = layers
+        num_layers = len(layer_names)
+        assert num_layers == len(layer_specs)
+        for start in range(0, num_layers, attn_pack_size):
+            end = min(num_layers, start + attn_pack_size)
+            grouped_layer_names = "+".join(layer_names[start:end])
+            try:
+                merged_kv_spec = layer_specs[start].merge_from_grouping(
+                    layer_specs[start:end], attn_pack_size
+                )
+            except AssertionError as e:
+                raise ValueError(
+                    "Failed to merge KV cache specs for layers "
+                    f"{layer_names[start:end]}"
+                ) from e
+            merged_kv_cache_specs[grouped_layer_names] = merged_kv_spec
+    return merged_kv_cache_specs
+
+
+def split_attn_layers_from_pack(
+    attn_pack_size: int,
+    kv_cache_config: KVCacheConfig,
+) -> KVCacheConfig:
+    """
+    Split attention layer packs back to individual layers.
+
+    This is the reverse operation of merge_attn_layers_into_pack. Once
+    the KV cache configuration is generated with packed layers, this
+    function splits them back to individual layer names so that each
+    physical layer can be properly initialized.
+
+    Args:
+        attn_pack_size: The number of layers in each attention pack (same as
+            used in merge_attn_layers_into_pack).
+        kv_cache_config: The KV cache configuration with packed layers.
+
+    Returns:
+        The KV cache configuration with layer packs split back to individual
+        layers.
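+
+    Example:
+        With attn_pack_size=2, a group with layer_names ["layer_2+layer_3"]
+        becomes a group with layer_names ["layer_2", "layer_3"], and each
+        KVCacheTensor.shared_by entry is expanded the same way.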
+ """ + grouped_layers: dict[str, list[str]] = {} + for group in kv_cache_config.kv_cache_groups: + if not isinstance(group.kv_cache_spec, AttentionSpec): + continue + split_layer_names = [] + for i, layer_name in enumerate(group.layer_names): + layer_names = layer_name.split("+") + assert len(layer_names) == attn_pack_size or ( + 0 < len(layer_names) < attn_pack_size + and i == len(group.layer_names) - 1 + ), f"Invalid layer name {layer_name} for attention grouping split" + grouped_layers[layer_name] = layer_names + split_layer_names.extend(layer_names) + group.layer_names = split_layer_names + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + split_shared_by = [] + for layer_name in kv_cache_tensor.shared_by: + if layer_name in grouped_layers: + split_shared_by.extend(grouped_layers.pop(layer_name)) + else: + split_shared_by.append(layer_name) + kv_cache_tensor.shared_by = split_shared_by + assert not grouped_layers + return kv_cache_config + + def get_kv_cache_configs( vllm_config: VllmConfig, kv_cache_specs: list[dict[str, KVCacheSpec]], @@ -1518,17 +1617,21 @@ def get_kv_cache_configs( workers may have different memory available, and different type of layers (when pipeline parallel is enabled). To handle the difference between workers, the current implementation is: - 1. Merge the KV cache specs of all workers to get the KVCacheSpecs for + 1. If attn_pack_size > 1 (for Mamba models), pack attention layers into + groups to share a KV-block before processing. + 2. Merge the KV cache specs of all workers to get the KVCacheSpecs for the whole model. - 2. Generate the KV cache groups based on the layer ratio of the whole model. + 3. Generate the KV cache groups based on the layer ratio of the whole model. This also handles spec unification for hybrid models. - 3. Handle auto-fit max_model_len and memory checks using per-worker + 4. Handle auto-fit max_model_len and memory checks using per-worker projected groups to account for PP sharding. - 4. Generate the KV cache configs for each worker based on the KV cache + 5. Generate the KV cache configs for each worker based on the KV cache grouping strategy. (This is reasonable because the layer ratio of different PP stages are similar.) - 5. Change the num_blocks of each worker to the smallest among all workers + 6. Change the num_blocks of each worker to the smallest among all workers and shrink tensor sizes proportionally to avoid allocating unused memory. + 7. If attn_pack_size > 1 (for Mamba models), split packed layers back + to individual layers after generating configs. Args: vllm_config: The global VllmConfig @@ -1540,6 +1643,16 @@ def get_kv_cache_configs( The generated KVCacheConfigs for each worker. """ + attn_pack_size = vllm_config.cache_config.attn_pack_size + # When attn_pack_size > 1 (for Mamba models), pack attention layers together + # to share a KV-block. + if attn_pack_size > 1: + for i in range(len(kv_cache_specs)): + kv_cache_specs[i] = merge_attn_layers_into_pack( + attn_pack_size, + kv_cache_specs[i], + ) + # Merge the KV cache specs of all workers. Different PP stages may have # different layer names, and different TP ranks of the same PP stage should # have the same KV cache spec. @@ -1614,6 +1727,15 @@ def get_kv_cache_configs( if len(kv_cache_config.kv_cache_groups) > 0: _report_kv_cache_config(vllm_config, kv_cache_config) + # When attn_pack_size > 1 (for Mamba models), split packed layers back + # to individual layers after generating configs. 
+    if attn_pack_size > 1:
+        for i in range(len(kv_cache_configs)):
+            kv_cache_configs[i] = split_attn_layers_from_pack(
+                attn_pack_size,
+                kv_cache_configs[i],
+            )
+
     return kv_cache_configs
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index 6f8ad8e7d8ef..71210f5a6617 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -117,6 +117,11 @@ class AttentionSpec(KVCacheSpec):
     dtype: torch.dtype
     kv_quant_mode: KVQuantMode = KVQuantMode.NONE
     page_size_padded: int | None = None
+    pack_size: int = 1
+    """
+    The number of attention layers that share one KV cache page.
+    Used by hybrid Mamba models to pack multiple full-attention layers into one page.
+    """
 
     @property
     def page_size_bytes(self) -> int:
@@ -130,8 +135,8 @@ def page_size_bytes(self) -> int:
             )
         if self.page_size_padded is not None:
             assert self.page_size_padded >= real_page_size
-            return self.page_size_padded
-        return real_page_size
+            return self.page_size_padded * self.pack_size
+        return real_page_size * self.pack_size
 
     @property
     def real_page_size_bytes(self) -> int:
@@ -143,6 +148,10 @@ def real_page_size_bytes(self) -> int:
             * get_dtype_size(self.dtype)
         )
 
+    @classmethod
+    def merge_from_grouping(cls, specs: list[Self], pack_size: int) -> Self:
+        return replace(cls.merge(specs), pack_size=pack_size)
+
 
 @dataclass(frozen=True, kw_only=True)
 class FullAttentionSpec(AttentionSpec):
@@ -218,6 +227,7 @@ def merge(cls, specs: list[Self]) -> Self:
             dtype=specs[0].dtype,
             kv_quant_mode=specs[0].kv_quant_mode,
             page_size_padded=specs[0].page_size_padded,
+            pack_size=specs[0].pack_size,
             sliding_window=cls.merge_window_sizes(sliding_window),
             attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
         )
@@ -280,6 +290,7 @@ def merge(cls, specs: list[Self]) -> Self:
             dtype=specs[0].dtype,
             kv_quant_mode=specs[0].kv_quant_mode,
             page_size_padded=specs[0].page_size_padded,
+            pack_size=specs[0].pack_size,
             cache_dtype_str=cache_dtype_str_set.pop(),
         )
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4f3e192772ff..6e2f18e0fe9a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -128,6 +128,10 @@
     get_dcp_local_seq_lens,
     reorder_batch_to_split_decodes_and_prefills,
 )
+from vllm.v1.core.kv_cache_utils import (
+    merge_attn_layers_into_pack,
+    split_attn_layers_from_pack,
+)
 from vllm.v1.core.sched.output import NewRequestData
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (
@@ -5832,6 +5836,14 @@ def _init_minimal_kv_cache_for_profiling(self) -> None:
         )
 
         kv_cache_spec = self.get_kv_cache_spec()
+        attn_pack_size = self.vllm_config.cache_config.attn_pack_size
+        # When attn_pack_size > 1 (for Mamba models), pack attention layers together
+        # to share a KV-block.
+        if attn_pack_size > 1:
+            kv_cache_spec = merge_attn_layers_into_pack(
+                attn_pack_size,
+                kv_cache_spec,
+            )
         kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec)
 
         min_blocks = self.compilation_config.max_cudagraph_capture_size or 1
@@ -5841,6 +5853,13 @@ def _init_minimal_kv_cache_for_profiling(self) -> None:
         minimal_config = get_kv_cache_config_from_groups(
             self.vllm_config, kv_cache_groups, available_memory=0
         )
+        # When attn_pack_size > 1 (for Mamba models), split packed layers back
+        # to individual layers after generating configs.
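+        # As in get_kv_cache_configs(), the profiling config must carry real
+        # per-layer names so initialize_kv_cache() below can find each layer.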
+ if attn_pack_size > 1: + minimal_config = split_attn_layers_from_pack( + attn_pack_size, + minimal_config, + ) self.cache_config.num_gpu_blocks_override = saved_override self.initialize_kv_cache(minimal_config, is_profiling=True) @@ -6531,7 +6550,14 @@ def _reshape_kv_cache_tensors( corresponding memory buffer for KV cache. """ kv_caches: dict[str, torch.Tensor] = {} - has_attn, has_mamba = False, False + + # Pre-scan to determine whether there are mamba layers. + has_mamba = False + for group in self._kv_cache_spec_attn_group_iterator(): + kv_cache_spec = group.kv_cache_spec + if isinstance(kv_cache_spec, MambaSpec): + has_mamba = True + for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec attn_backend = group.backend @@ -6539,14 +6565,13 @@ def _reshape_kv_cache_tensors( # There may be a last group for layers without kv cache. continue kernel_block_size = kernel_block_sizes[group.kv_cache_group_id] - for layer_name in group.layer_names: + for layer_idx, layer_name in enumerate(group.layer_names): if layer_name in self.runner_only_attn_layers: continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): - has_attn = True num_blocks_per_kv_block = ( kv_cache_spec.block_size // kernel_block_size ) @@ -6573,19 +6598,37 @@ def _reshape_kv_cache_tensors( kv_cache_shape = tuple( kv_cache_shape[i] for i in kv_cache_stride_order ) + kv_cache_stride = tuple(torch.empty(kv_cache_shape).stride()) + storage_offset = 0 # Maintain original KV shape view. inv_order = [ kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = ( - kv_cache_raw_tensors[layer_name] - .view(dtype) - .view(kv_cache_shape) - .permute(*inv_order) - ) + if has_mamba: + block_dim = group.backend.get_kv_cache_block_dim( + kernel_block_sizes[group.kv_cache_group_id], + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=self.cache_config.cache_dtype, + ) + kv_cache_stride, storage_offset = ( + mamba_utils.get_hybrid_attention_mamba_layout( + kv_cache_shape=kv_cache_shape, + kv_cache_stride=kv_cache_stride, + kv_cache_spec=kv_cache_spec, + block_dim=block_dim, + layer_idx=layer_idx, + kernel_block_size=kernel_block_size, + ) + ) + kv_caches[layer_name] = torch.as_strided( + raw_tensor.view(dtype), + size=kv_cache_shape, + stride=kv_cache_stride, + storage_offset=storage_offset, + ).permute(*inv_order) elif isinstance(kv_cache_spec, MambaSpec): - has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] state_tensors = [] storage_offset_bytes = 0 @@ -6611,45 +6654,8 @@ def _reshape_kv_cache_tensors( else: raise NotImplementedError - if has_attn and has_mamba: - self._update_hybrid_attention_mamba_layout(kv_caches, kernel_block_sizes) - return kv_caches - def _update_hybrid_attention_mamba_layout( - self, kv_caches: dict[str, torch.Tensor], kernel_block_sizes: list[int] - ) -> None: - """ - Update the layout of attention layers from (2, num_blocks, ...) to - (num_blocks, 2, ...). - - Args: - kv_caches: The KV cache buffer of each layer. - kernel_block_sizes: The kernel block sizes for each KV cache group. 
- """ - - for group in self._kv_cache_spec_attn_group_iterator(): - kv_cache_spec = group.kv_cache_spec - if not isinstance(kv_cache_spec, AttentionSpec): - continue - block_dim = group.backend.get_kv_cache_block_dim( - kernel_block_sizes[group.kv_cache_group_id], - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - cache_dtype_str=self.cache_config.cache_dtype, - ) - # block_dim: 0 means (num_blocks, 2, ...); 1 means (2, num_blocks, ...). - if block_dim == 0: - continue - assert block_dim == 1 - for layer_name in group.layer_names: - kv_cache = kv_caches[layer_name] - hidden_size = kv_cache.shape[2:].numel() - kv_cache.as_strided_( - size=kv_cache.shape, - stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]), - ) - def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] ) -> dict[str, torch.Tensor]: diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index c832389b1b0a..16fec327d7b2 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -3,6 +3,7 @@ import dataclasses import itertools from collections.abc import Callable +from math import prod from typing import Any import torch @@ -13,8 +14,9 @@ ) from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import get_dtype_size from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch @@ -271,3 +273,57 @@ def postprocess_mamba( if src_block_idx == dest_block_idx: num_accepted_tokens_cpu[i] = 1 do_mamba_copy_block(copy_bufs) + + +def get_hybrid_attention_mamba_layout( + kv_cache_shape: tuple[int, ...], + kv_cache_stride: tuple[int, ...], + kv_cache_spec: AttentionSpec, + block_dim: int, + layer_idx: int, + kernel_block_size: int, +) -> tuple[tuple[int, ...], int]: + """ + Compute the stride and storage offset for the hybrid attention+mamba layout. + + Args: + kv_cache_shape: The shape of the KV cache tensor. + kv_cache_stride: The stride of the KV cache tensor. + kv_cache_spec: The specification of the KV cache. + layer_idx: The index of the layer. + kernel_num_blocks: The number of kernel blocks. + kernel_block_size: The size of the kernel block. + Returns: + A tuple containing the target stride and storage offset. + """ + target_stride_list = list(kv_cache_stride) + storage_offset = 0 + + attn_pack_size = kv_cache_spec.pack_size + # block_dim: 0 means (num_blocks, 2, ...); 1 means (2, num_blocks, ...). + if block_dim != 0: + # Hybrid attention+mamba uses (2, num_blocks, ...) logical shape but + # (num_blocks, 2, ...) physical layout. + assert kv_cache_shape[0] == 2, ( + "Fail to determine whether the layout is " + "(2, num_blocks, ...) or (num_blocks, 2, ...) for " + f"a tensor of shape {kv_cache_shape}" + ) + assert block_dim == 1 + hidden_size = prod(kv_cache_shape[2:]) + target_stride_list[0] = hidden_size + target_stride_list[1] = 2 * hidden_size + # When multiple attention layers share one physical KV cache block + # (attn_pack_size > 1), scale the block-dim stride by attn_pack_size + # and compute this layer's element offset within the shared block. 
+ if attn_pack_size > 1: + target_stride_list[block_dim] *= attn_pack_size + dtype_size = get_dtype_size(kv_cache_spec.dtype) + num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size + num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size + num_element_per_attn_pack = ( + num_element_per_page // num_blocks_per_kv_block // attn_pack_size + ) + attn_pack_idx = layer_idx % attn_pack_size + storage_offset = attn_pack_idx * num_element_per_attn_pack + return tuple(target_stride_list), storage_offset