From 1787b87d8eb78aa9ffe7605f27bb08477493a49d Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 1 Mar 2026 17:05:45 +0000 Subject: [PATCH 01/34] map more fullattn layers to one block Signed-off-by: huanghaoyan.hhy --- vllm/config/cache.py | 4 ++ vllm/engine/arg_utils.py | 6 ++ .../layers/attention/attention.py | 3 + vllm/model_executor/models/config.py | 1 + vllm/v1/core/kv_cache_utils.py | 72 +++++++++++++++---- vllm/v1/kv_cache_interface.py | 21 ++++++ vllm/v1/worker/gpu_model_runner.py | 50 ++++++++++--- 7 files changed, 136 insertions(+), 21 deletions(-) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index daceaa6c2bb4..c9b57b283dc6 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -31,6 +31,7 @@ ] MambaDType = Literal["auto", "float32", "float16"] MambaCacheMode = Literal["all", "align", "none"] +MambaNumAttnPages = Literal[1, 2, 4, 8] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"] KVOffloadingBackend = Literal["native", "lmcache"] @@ -142,6 +143,9 @@ class CacheConfig: - "align": only cache the mamba state of the last token of each scheduler step and when the token is at position i * block_size. """ + mamba_num_attn_pages: MambaNumAttnPages = 1 + """The number of attention pages to allocate for Mamba layers. + This is only relevant for models that includes Mamba layers.""" # Will be set after profiling. num_gpu_blocks: int | None = field(default=None, init=False) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8ea96de4913e..4e343bd28455 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -64,6 +64,7 @@ KVOffloadingBackend, MambaCacheMode, MambaDType, + MambaNumAttnPages, PrefixCachingHashAlgo, ) from vllm.config.device import Device @@ -571,6 +572,7 @@ class EngineArgs: mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size") mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode + mamba_num_attn_pages: MambaNumAttnPages = CacheConfig.mamba_num_attn_pages additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") @@ -970,6 +972,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: cache_group.add_argument( "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"] ) + cache_group.add_argument( + "--mamba-num-attn-pages", **cache_kwargs["mamba_num_attn_pages"] + ) cache_group.add_argument( "--kv-offloading-size", **cache_kwargs["kv_offloading_size"] ) @@ -1474,6 +1479,7 @@ def create_engine_config( mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, mamba_block_size=self.mamba_block_size, mamba_cache_mode=self.mamba_cache_mode, + mamba_num_attn_pages=self.mamba_num_attn_pages, kv_offloading_size=self.kv_offloading_size, kv_offloading_backend=self.kv_offloading_backend, ) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 8c3ff3cc4df7..706891bac06e 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -223,10 +223,12 @@ def __init__( kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size calculate_kv_scales = cache_config.calculate_kv_scales + self.num_attn_pages = cache_config.mamba_num_attn_pages else: kv_cache_dtype = "auto" block_size = 16 calculate_kv_scales = False + self.num_attn_pages = 1 # llm-compressor mdls need to set cache_dtype to "fp8" manually. 
if getattr(quant_config, "kv_cache_scheme", None) is not None: @@ -531,6 +533,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: head_size=self.head_size, head_size_v=self.head_size_v, dtype=self.kv_cache_torch_dtype, + group_size=self.num_attn_pages, ) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 27cf3a7929fc..9c8a767ca3d5 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -439,6 +439,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), dtype=kv_cache_dtype, + group_size=cache_config.mamba_num_attn_pages, ).page_size_bytes model_cls, _ = ModelRegistry.resolve_model_cls( diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index cfaa37074d9e..a31ed493e329 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -765,8 +765,11 @@ def create_kv_cache_group_specs( kv_cache_spec[layer_name] for layer_name in layer_names_one_group ] merged_layer_spec = layer_specs[0].merge(layer_specs) + group_size = len(layer_names_one_group) + if isinstance(merged_layer_spec, FullAttentionSpec): + group_size = cdiv(group_size, merged_layer_spec.group_size) kv_cache_groups.append( - KVCacheGroupSpec(layer_names_one_group, merged_layer_spec) + KVCacheGroupSpec(layer_names_one_group, merged_layer_spec, group_size) ) return kv_cache_groups @@ -803,7 +806,7 @@ def get_max_concurrency_for_kv_cache_config( Get the maximum concurrency for the given KV cache configuration. """ num_layer_per_group = max( - len(group.layer_names) for group in kv_cache_config.kv_cache_groups + group.group_size for group in kv_cache_config.kv_cache_groups ) max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes( vllm_config, (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups) @@ -1025,6 +1028,22 @@ def _get_kv_cache_groups_uniform_page_size( for layer_name, layer_spec in kv_cache_spec.items(): same_type_layers[layer_spec].append(layer_name) + attn_group_size = next( + ( + kv_spec.group_size + for kv_spec in same_type_layers + if isinstance(kv_spec, FullAttentionSpec) + ), + 1, + ) + same_type_num_layers: dict[KVCacheSpec, int] = {} + for kv_spec, layers in same_type_layers.items(): + num_layers = len(layers) + # For FullAttn, group `attn_group_size` layers to 1 page + if isinstance(kv_spec, FullAttentionSpec): + num_layers = cdiv(num_layers, attn_group_size) + same_type_num_layers[kv_spec] = num_layers + # Split each group into smaller groups, to make the number of layers in each # group identical. Add padding to the last group of each type if necessary. # E.g., (full.0, full.1), (sw.0, sw.1, sw.2) @@ -1037,9 +1056,9 @@ def _get_kv_cache_groups_uniform_page_size( # is the minimum number of layers among all attention types. Need a better # strategy if we want to support more complex patterns (e.g., 20 full + 30 # sw, where the group size should be 10). 
- min_num_layers = min([len(layers) for layers in same_type_layers.values()]) + min_num_layers = min([num_layers for num_layers in same_type_num_layers.values()]) group_size = min_num_layers - max_num_layers = max([len(layers) for layers in same_type_layers.values()]) + max_num_layers = max([num_layers for num_layers in same_type_num_layers.values()]) if max_num_layers < min_num_layers * 1.25: # If the number of layers is not much larger than the minimum number of layers, # use the maximum number of layers as the group size to avoid too many padding @@ -1048,7 +1067,7 @@ def _get_kv_cache_groups_uniform_page_size( # magic number to avoid too many padding layers. group_size = max_num_layers grouped_layers = [] - for layers in same_type_layers.values(): + for kv_spec, layers in same_type_layers.items(): num_padding_layers = group_size - len(layers) % group_size if num_padding_layers != group_size: logger.warning( @@ -1056,7 +1075,7 @@ def _get_kv_cache_groups_uniform_page_size( num_padding_layers, num_padding_layers / len(layers) * 100, ) - num_groups = cdiv(len(layers), group_size) + num_groups = cdiv(same_type_num_layers[kv_spec], group_size) # In PP case, say if we have # - stage 0: full.0, sw.0, sw.1 # - stage 1: full.1, sw.2, sw.3 @@ -1068,8 +1087,22 @@ def _get_kv_cache_groups_uniform_page_size( # the same and will cause memory waste. # To avoid this, we assign layers[i::num_groups] to the i-th group # instead of layers[i * group_size: (i + 1) * group_size] - for i in range(num_groups): - grouped_layers.append(layers[i::num_groups]) + if isinstance(kv_spec, FullAttentionSpec) and attn_group_size > 1: + stride = num_groups * attn_group_size + for i in range(num_groups): + attn_group_layers = [] + for start in range(0, len(layers), stride): + attn_group_layers.extend( + layers[ + i * attn_group_size + start : (i + 1) * attn_group_size + + start + ] + ) + grouped_layers.append(attn_group_layers) + else: + for i in range(num_groups): + grouped_layers.append(layers[i::num_groups]) + return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) @@ -1126,7 +1159,7 @@ def get_kv_cache_config_from_groups( # (sw.1, padding) will be: (group_size = 2) # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 # full.1, sw.2: share another Tensor with size=available_memory//2 - group_size = max(len(group.layer_names) for group in kv_cache_groups) + group_size = max(group.group_size for group in kv_cache_groups) page_size = get_uniform_page_size( [group.kv_cache_spec for group in kv_cache_groups] @@ -1139,8 +1172,19 @@ def get_kv_cache_config_from_groups( for i in range(group_size): shared_by = [] for j in range(len(kv_cache_groups)): - if i < len(kv_cache_groups[j].layer_names): - shared_by.append(kv_cache_groups[j].layer_names[i]) + num_layers = len(kv_cache_groups[j].layer_names) + kv_cache_spec = kv_cache_groups[j].kv_cache_spec + if ( + isinstance(kv_cache_spec, FullAttentionSpec) + and (attn_group_size := kv_cache_spec.group_size) > 1 + ): + for k in range(attn_group_size): + idx = i * attn_group_size + k + if idx < num_layers: + shared_by.append(kv_cache_groups[j].layer_names[idx]) + else: + if i < num_layers: + shared_by.append(kv_cache_groups[j].layer_names[i]) kv_cache_tensors.append( KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by) ) @@ -1350,7 +1394,7 @@ def _max_memory_usage_bytes_from_groups( # General case: group_size pools, each shared by one layer per group # Memory = group_size * page_size * blocks_for_max_len - group_size = max(len(group.layer_names) for 
group in kv_cache_groups) + group_size = max(group.group_size for group in kv_cache_groups) page_size = get_uniform_page_size( [group.kv_cache_spec for group in kv_cache_groups] ) @@ -1494,7 +1538,9 @@ def _project_kv_cache_groups_to_worker( for layer_name in worker_layer_names }, ) - projected_groups.append(KVCacheGroupSpec(worker_layer_names, group_spec)) + projected_groups.append( + KVCacheGroupSpec(worker_layer_names, group_spec, group.group_size) + ) return projected_groups diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 4a1b16fc580c..4556492b92c6 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -106,10 +106,20 @@ class FullAttentionSpec(AttentionSpec): """ attention_chunk_size: int | None = None + group_size: int = 1 + """ + The size of a group of attention layers. + It's used for Mamba models to group more than one FullAttn layers to one page. + """ + def __post_init__(self): if self.head_size_v is None: object.__setattr__(self, "head_size_v", self.head_size) + @property + def page_size_bytes(self) -> int: + return super().page_size_bytes * self.group_size + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size @@ -162,6 +172,7 @@ def merge(cls, specs: list[Self]) -> Self: page_size_padded=specs[0].page_size_padded, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), + group_size=specs[0].group_size, ) for spec in specs: for f in fields(AttentionSpec): @@ -469,6 +480,16 @@ class KVCacheGroupSpec: layer_names: list[str] # The KV cache spec of this manager layer kv_cache_spec: KVCacheSpec + # The size of this group. + # Normally, it is the length of the layer names. + # For Mamba models, some FullAttn layers are grouped together and map + # to the same KV cache block, so the group size will be smaller than + # the number of layers. + group_size: int = -1 + + def __post_init__(self): + if self.group_size <= 0: + self.group_size = len(self.layer_names) @dataclass diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3a354b81864f..b5fd2d2960ea 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5872,7 +5872,7 @@ def _reshape_kv_cache_tensors( # There may be a last group for layers without kv cache. 
continue kernel_block_size = kernel_block_sizes[group.kv_cache_group_id] - for layer_name in group.layer_names: + for layer_idx, layer_name in enumerate(group.layer_names): if layer_name in self.runner_only_attn_layers: continue raw_tensor = kv_cache_raw_tensors[layer_name] @@ -5911,12 +5911,37 @@ def _reshape_kv_cache_tensors( kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = ( - kv_cache_raw_tensors[layer_name] - .view(dtype) - .view(kv_cache_shape) - .permute(*inv_order) - ) + if ( + isinstance(kv_cache_spec, FullAttentionSpec) + and (attn_group_size := kv_cache_spec.group_size) > 1 + ): + attn_group_idx = layer_idx % attn_group_size + ori_stride = list(torch.empty(kv_cache_shape).stride()) + kernel_blocks_idx = kv_cache_shape.index(kernel_num_blocks) + ori_stride[kernel_blocks_idx] *= attn_group_size + target_stride = tuple(ori_stride) + dtype_size = get_dtype_size(dtype) + num_element_per_page = ( + kv_cache_spec.page_size_bytes // dtype_size + ) + num_element_per_attn_group = ( + num_element_per_page + // num_blocks_per_kv_block + // attn_group_size + ) + kv_caches[layer_name] = torch.as_strided( + kv_cache_raw_tensors[layer_name].view(dtype), + size=kv_cache_shape, + stride=target_stride, + storage_offset=attn_group_idx * num_element_per_attn_group, + ).permute(*inv_order) + else: + kv_caches[layer_name] = ( + kv_cache_raw_tensors[layer_name] + .view(dtype) + .view(kv_cache_shape) + .permute(*inv_order) + ) elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] @@ -5971,9 +5996,18 @@ def _update_hybrid_attention_mamba_layout( f"a tensor of shape {kv_cache.shape}" ) hidden_size = kv_cache.shape[2:].numel() + attn_group_size = ( + kv_cache_spec.group_size + if isinstance(kv_cache_spec, FullAttentionSpec) + else 1 + ) kv_cache.as_strided_( size=kv_cache.shape, - stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]), + stride=( + hidden_size, + 2 * hidden_size * attn_group_size, + *kv_cache.stride()[2:], + ), ) def initialize_kv_cache_tensors( From 954db4ad24e8efdf00807037beea7156c899597b Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 2 Mar 2026 17:35:33 +0000 Subject: [PATCH 02/34] add test_get_kv_cache_configs_with_mamba Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_kv_cache_utils.py | 98 ++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index c609bc1b85ec..5894aebb3058 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -43,6 +43,7 @@ KVCacheGroupSpec, KVCacheSpec, KVCacheTensor, + MambaSpec, MLAAttentionSpec, SlidingWindowSpec, UniformTypeKVCacheSpecs, @@ -109,6 +110,7 @@ def new_kv_cache_spec( page_size_padded=None, sliding_window=None, attention_chunk_size=None, + group_size=1, ): return FullAttentionSpec( block_size=block_size, @@ -118,6 +120,7 @@ def new_kv_cache_spec( page_size_padded=page_size_padded, sliding_window=sliding_window, attention_chunk_size=attention_chunk_size, + group_size=group_size, ) @@ -157,6 +160,22 @@ def new_chunked_local_attention_spec( ) +def new_mamba_spec( + block_size=16, + shapes=((2, 512), (3, 32, 32)), + dtypes=(torch.float32, torch.float32), + num_speculative_blocks=2, + page_size_padded=None, +): + return MambaSpec( + block_size=block_size, + shapes=shapes, + dtypes=dtypes, + page_size_padded=page_size_padded, + num_speculative_blocks=num_speculative_blocks, + ) + + 
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_none_hash(monkeypatch, hash_fn): import vllm.v1.core.kv_cache_utils @@ -2082,3 +2101,82 @@ def test_unify_hybrid_kv_cache_specs(): with pytest.raises(ValueError): kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec) + + +def test_get_kv_cache_configs_with_mamba(): + model_config = ModelConfig(max_model_len=16) + vllm_config = VllmConfig(model_config=model_config) + + expected_page_size = new_mamba_spec().page_size_bytes + + # Test 1: Pure mamba model (2 layers with same spec) + kv_cache_specs = { + "layer_1": new_mamba_spec(), + "layer_2": new_mamba_spec(), + } + available_memory = expected_page_size * 2 * 10 + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_specs], [available_memory] + )[0] + + assert kv_cache_config == KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=expected_page_size * 10, shared_by=["layer_1"]), + KVCacheTensor(size=expected_page_size * 10, shared_by=["layer_2"]), + ], + kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_mamba_spec())], + ) + + # Test 2: 1 mamba + 1 full + hybrid_kv_cache_specs = { + "layer_1": new_mamba_spec(), + "layer_2": new_kv_cache_spec(), + } + available_memory_hybrid = expected_page_size * 2 * 10 + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [hybrid_kv_cache_specs], [available_memory_hybrid] + )[0] + + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=20, + kv_cache_tensors=[ + KVCacheTensor( + size=expected_page_size * 20, shared_by=["layer_1", "layer_2"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_mamba_spec()), + KVCacheGroupSpec(["layer_2"], new_kv_cache_spec()), + ], + ) + + # Test 3: 1 mamba + 2 full attention with group size 2 + hybrid_kv_cache_specs = { + "layer_1": new_mamba_spec(), + "layer_2": new_kv_cache_spec(head_size=32, group_size=2), + "layer_3": new_kv_cache_spec(head_size=32, group_size=2), + } + available_memory_hybrid = expected_page_size * 2 * 10 + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [hybrid_kv_cache_specs], [available_memory_hybrid] + )[0] + + print(kv_cache_config_hybrid) + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=20, + kv_cache_tensors=[ + KVCacheTensor( + size=expected_page_size * 20, + shared_by=["layer_1", "layer_2", "layer_3"], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_mamba_spec()), + KVCacheGroupSpec( + ["layer_2", "layer_3"], + new_kv_cache_spec(head_size=32, group_size=2), + group_size=1, + ), + ], + ) From 75fe78eca46476bf5d30ddc2446f8ec970c49be9 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 2 Mar 2026 17:54:09 +0000 Subject: [PATCH 03/34] fix padding layers Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/kv_cache_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index a31ed493e329..6a8a03524329 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1068,7 +1068,12 @@ def _get_kv_cache_groups_uniform_page_size( group_size = max_num_layers grouped_layers = [] for kv_spec, layers in same_type_layers.items(): - num_padding_layers = group_size - len(layers) % group_size + if isinstance(kv_spec, FullAttentionSpec) and attn_group_size > 1: + num_padding_layers = group_size * attn_group_size - len(layers) % ( + group_size * attn_group_size + ) + else: + num_padding_layers = group_size - len(layers) % group_size if num_padding_layers != 
group_size: logger.warning( "Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa From 5817d4d9f6149ae43402db03261f847cfaf85d7b Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 2 Mar 2026 18:07:00 +0000 Subject: [PATCH 04/34] fix padding layers Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/kv_cache_utils.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6a8a03524329..8fda83e78600 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1068,13 +1068,11 @@ def _get_kv_cache_groups_uniform_page_size( group_size = max_num_layers grouped_layers = [] for kv_spec, layers in same_type_layers.items(): - if isinstance(kv_spec, FullAttentionSpec) and attn_group_size > 1: - num_padding_layers = group_size * attn_group_size - len(layers) % ( - group_size * attn_group_size - ) - else: - num_padding_layers = group_size - len(layers) % group_size - if num_padding_layers != group_size: + padding_size = group_size + if isinstance(kv_spec, FullAttentionSpec): + padding_size *= attn_group_size + num_padding_layers = padding_size - len(layers) % padding_size + if num_padding_layers != padding_size: logger.warning( "Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa num_padding_layers, From 2d6eefc02821a3bbbcf136a488ad6d6461b592ba Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 2 Mar 2026 18:25:32 +0000 Subject: [PATCH 05/34] add test Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_kv_cache_utils.py | 43 +++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 5894aebb3058..f6f8b78eb57a 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2162,7 +2162,6 @@ def test_get_kv_cache_configs_with_mamba(): vllm_config, [hybrid_kv_cache_specs], [available_memory_hybrid] )[0] - print(kv_cache_config_hybrid) assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=20, kv_cache_tensors=[ @@ -2180,3 +2179,45 @@ def test_get_kv_cache_configs_with_mamba(): ), ], ) + + # Test 4: 2 mamba + 5 full (with 3 padding full) + hybrid_kv_cache_specs = { + "layer_1": new_mamba_spec(), + "layer_2": new_mamba_spec(), + "layer_3": new_kv_cache_spec(head_size=32, group_size=2), + "layer_4": new_kv_cache_spec(head_size=32, group_size=2), + "layer_5": new_kv_cache_spec(head_size=32, group_size=2), + "layer_6": new_kv_cache_spec(head_size=32, group_size=2), + "layer_7": new_kv_cache_spec(head_size=32, group_size=2), + } + available_memory_hybrid = expected_page_size * 2 * 10 + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [hybrid_kv_cache_specs], [available_memory_hybrid] + )[0] + + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=expected_page_size * 10, + shared_by=["layer_1", "layer_3", "layer_4", "layer_5", "layer_6"], + ), + KVCacheTensor( + size=expected_page_size * 10, + shared_by=["layer_2", "layer_7"], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2"], new_mamba_spec()), + KVCacheGroupSpec( + ["layer_3", "layer_4", "layer_7"], + new_kv_cache_spec(head_size=32, group_size=2), + group_size=2, + ), + KVCacheGroupSpec( + ["layer_5", "layer_6"], + new_kv_cache_spec(head_size=32, group_size=2), + group_size=1, + ), + ], + ) From a365a26f9fa324e6375cdf3f5a8b5cd373ab532a Mon Sep 17 
00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 3 Mar 2026 02:31:52 +0000 Subject: [PATCH 06/34] change group_size to property Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_kv_cache_utils.py | 2 -- vllm/v1/core/kv_cache_utils.py | 6 ++---- vllm/v1/kv_cache_interface.py | 22 +++++++++++++--------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index f6f8b78eb57a..d6b3552d9873 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2212,12 +2212,10 @@ def test_get_kv_cache_configs_with_mamba(): KVCacheGroupSpec( ["layer_3", "layer_4", "layer_7"], new_kv_cache_spec(head_size=32, group_size=2), - group_size=2, ), KVCacheGroupSpec( ["layer_5", "layer_6"], new_kv_cache_spec(head_size=32, group_size=2), - group_size=1, ), ], ) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 8fda83e78600..cf016373a604 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -769,7 +769,7 @@ def create_kv_cache_group_specs( if isinstance(merged_layer_spec, FullAttentionSpec): group_size = cdiv(group_size, merged_layer_spec.group_size) kv_cache_groups.append( - KVCacheGroupSpec(layer_names_one_group, merged_layer_spec, group_size) + KVCacheGroupSpec(layer_names_one_group, merged_layer_spec) ) return kv_cache_groups @@ -1541,9 +1541,7 @@ def _project_kv_cache_groups_to_worker( for layer_name in worker_layer_names }, ) - projected_groups.append( - KVCacheGroupSpec(worker_layer_names, group_spec, group.group_size) - ) + projected_groups.append(KVCacheGroupSpec(worker_layer_names, group_spec)) return projected_groups diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 4556492b92c6..86016f64b579 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -480,16 +480,20 @@ class KVCacheGroupSpec: layer_names: list[str] # The KV cache spec of this manager layer kv_cache_spec: KVCacheSpec - # The size of this group. - # Normally, it is the length of the layer names. - # For Mamba models, some FullAttn layers are grouped together and map - # to the same KV cache block, so the group size will be smaller than - # the number of layers. - group_size: int = -1 - def __post_init__(self): - if self.group_size <= 0: - self.group_size = len(self.layer_names) + @property + def group_size(self) -> int: + """ + The size of this group. + Normally, it is the length of the layer names. + For Mamba models, some FullAttn layers are grouped together and map + to the same KV cache block, so the group size will be smaller than + the number of layers. 
+ """ + if isinstance(self.kv_cache_spec, FullAttentionSpec): + return cdiv(len(self.layer_names), self.kv_cache_spec.group_size) + else: + return len(self.layer_names) @dataclass From c39fd7ff1fb66da29d77e4e981927265e76c1255 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 3 Mar 2026 07:26:43 +0000 Subject: [PATCH 07/34] fix Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_kv_cache_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index d6b3552d9873..96973ab26144 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2175,7 +2175,6 @@ def test_get_kv_cache_configs_with_mamba(): KVCacheGroupSpec( ["layer_2", "layer_3"], new_kv_cache_spec(head_size=32, group_size=2), - group_size=1, ), ], ) From 5e54f79b7ed3a4939ef165b5e92ce0066eaeb2c8 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 3 Mar 2026 17:06:06 +0000 Subject: [PATCH 08/34] restrict mamba_num_attn_pages to hybrid model only Signed-off-by: huanghaoyan.hhy --- vllm/config/vllm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 5db217b2203e..298b48392722 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1130,6 +1130,12 @@ def has_blocked_weights(): "schedule a multiple of block_size tokens even if they are in the " "middle of a mm input" ) + if self.cache_config.mamba_num_attn_pages > 1: + assert self.model_config.is_hybrid, ( + "Mapping multiple FullAttention layers to a single page is only " + "supported for hybrid models" + ) + if self.compilation_config.debug_dump_path: self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path.absolute().expanduser() From 4eae237b43147ce05824cc7cf90352c64d8fb714 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 4 Mar 2026 16:38:41 +0000 Subject: [PATCH 09/34] remove num_attn_pages in Attention Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/attention/attention.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 706891bac06e..9699c9c678c3 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -223,12 +223,10 @@ def __init__( kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size calculate_kv_scales = cache_config.calculate_kv_scales - self.num_attn_pages = cache_config.mamba_num_attn_pages else: kv_cache_dtype = "auto" block_size = 16 calculate_kv_scales = False - self.num_attn_pages = 1 # llm-compressor mdls need to set cache_dtype to "fp8" manually. if getattr(quant_config, "kv_cache_scheme", None) is not None: @@ -513,6 +511,7 @@ def get_attn_backend(self) -> type[AttentionBackend]: def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: # Block size may get updated after model loading, refresh it block_size = vllm_config.cache_config.block_size + attn_group_size = vllm_config.cache_config.mamba_num_attn_pages # Should not be called for enc-dec or encoder-only attention. 
assert self.attn_type == AttentionType.DECODER if self.sliding_window is not None: @@ -533,7 +532,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: head_size=self.head_size, head_size_v=self.head_size_v, dtype=self.kv_cache_torch_dtype, - group_size=self.num_attn_pages, + group_size=attn_group_size, ) From a96611eb7a979bdb3ea8d56f68c4ed17b17d5835 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 4 Mar 2026 17:08:41 +0000 Subject: [PATCH 10/34] move group_size to KVCacheSpec Signed-off-by: huanghaoyan.hhy --- vllm/v1/kv_cache_interface.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 86016f64b579..7264e4a68b67 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -16,7 +16,7 @@ logger = init_logger(__name__) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class KVCacheSpec: """ A base class for specifying the KV cache format of one layer. @@ -24,6 +24,12 @@ class KVCacheSpec: # number of tokens in a block block_size: int + group_size: int = 1 + """ + The size of a group of attention layers. + Currently, it's used for Mamba models to group more than one + FullAttn layers to one page. + """ @property def page_size_bytes(self) -> int: @@ -106,12 +112,6 @@ class FullAttentionSpec(AttentionSpec): """ attention_chunk_size: int | None = None - group_size: int = 1 - """ - The size of a group of attention layers. - It's used for Mamba models to group more than one FullAttn layers to one page. - """ - def __post_init__(self): if self.head_size_v is None: object.__setattr__(self, "head_size_v", self.head_size) @@ -165,6 +165,7 @@ def merge(cls, specs: list[Self]) -> Self: ) merged_spec = cls( block_size=specs[0].block_size, + group_size=specs[0].group_size, num_kv_heads=specs[0].num_kv_heads, head_size=specs[0].head_size, head_size_v=specs[0].head_size_v, @@ -172,7 +173,6 @@ def merge(cls, specs: list[Self]) -> Self: page_size_padded=specs[0].page_size_padded, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), - group_size=specs[0].group_size, ) for spec in specs: for f in fields(AttentionSpec): @@ -228,6 +228,7 @@ def merge(cls, specs: list[Self]) -> Self: ) return cls( block_size=specs[0].block_size, + group_size=specs[0].group_size, num_kv_heads=specs[0].num_kv_heads, head_size=specs[0].head_size, dtype=specs[0].dtype, @@ -281,7 +282,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class MambaSpec(KVCacheSpec): shapes: tuple[tuple[int, ...], ...] dtypes: tuple[torch.dtype] @@ -358,6 +359,7 @@ def merge(cls, specs: list[Self]) -> Self: ) merged_spec = cls( block_size=specs[0].block_size, + group_size=specs[0].group_size, num_kv_heads=specs[0].num_kv_heads, head_size=specs[0].head_size, head_size_v=specs[0].head_size_v, From f553c16a44a92a2097846767e2423ddaec91555d Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 7 Mar 2026 11:03:19 +0000 Subject: [PATCH 11/34] Revert "move group_size to KVCacheSpec" This reverts commit a96611eb7a979bdb3ea8d56f68c4ed17b17d5835. 
Signed-off-by: huanghaoyan.hhy --- vllm/v1/kv_cache_interface.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 7264e4a68b67..86016f64b579 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -16,7 +16,7 @@ logger = init_logger(__name__) -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class KVCacheSpec: """ A base class for specifying the KV cache format of one layer. @@ -24,12 +24,6 @@ class KVCacheSpec: # number of tokens in a block block_size: int - group_size: int = 1 - """ - The size of a group of attention layers. - Currently, it's used for Mamba models to group more than one - FullAttn layers to one page. - """ @property def page_size_bytes(self) -> int: @@ -112,6 +106,12 @@ class FullAttentionSpec(AttentionSpec): """ attention_chunk_size: int | None = None + group_size: int = 1 + """ + The size of a group of attention layers. + It's used for Mamba models to group more than one FullAttn layers to one page. + """ + def __post_init__(self): if self.head_size_v is None: object.__setattr__(self, "head_size_v", self.head_size) @@ -165,7 +165,6 @@ def merge(cls, specs: list[Self]) -> Self: ) merged_spec = cls( block_size=specs[0].block_size, - group_size=specs[0].group_size, num_kv_heads=specs[0].num_kv_heads, head_size=specs[0].head_size, head_size_v=specs[0].head_size_v, @@ -173,6 +172,7 @@ def merge(cls, specs: list[Self]) -> Self: page_size_padded=specs[0].page_size_padded, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), + group_size=specs[0].group_size, ) for spec in specs: for f in fields(AttentionSpec): @@ -228,7 +228,6 @@ def merge(cls, specs: list[Self]) -> Self: ) return cls( block_size=specs[0].block_size, - group_size=specs[0].group_size, num_kv_heads=specs[0].num_kv_heads, head_size=specs[0].head_size, dtype=specs[0].dtype, @@ -282,7 +281,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class MambaSpec(KVCacheSpec): shapes: tuple[tuple[int, ...], ...] dtypes: tuple[torch.dtype] @@ -359,7 +358,6 @@ def merge(cls, specs: list[Self]) -> Self: ) merged_spec = cls( block_size=specs[0].block_size, - group_size=specs[0].group_size, num_kv_heads=specs[0].num_kv_heads, head_size=specs[0].head_size, head_size_v=specs[0].head_size_v, From 09c429e397e30c37b101578048b16721a7e82780 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 7 Mar 2026 14:48:31 +0000 Subject: [PATCH 12/34] move group_size to AttentionSpec Signed-off-by: huanghaoyan.hhy --- vllm/v1/kv_cache_interface.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 86016f64b579..fbe2012a81be 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -67,14 +67,19 @@ class AttentionSpec(KVCacheSpec): head_size: int dtype: torch.dtype page_size_padded: int | None = None + group_size: int = 1 + """ + The size of a group of attention layers. + It's used for Mamba models to group more than one FullAttn layers to one page. 
+ """ @property def page_size_bytes(self) -> int: real_page_size = self.real_page_size_bytes if self.page_size_padded is not None: assert self.page_size_padded >= real_page_size - return self.page_size_padded - return real_page_size + return self.page_size_padded * self.group_size + return real_page_size * self.group_size @property def real_page_size_bytes(self) -> int: @@ -106,20 +111,10 @@ class FullAttentionSpec(AttentionSpec): """ attention_chunk_size: int | None = None - group_size: int = 1 - """ - The size of a group of attention layers. - It's used for Mamba models to group more than one FullAttn layers to one page. - """ - def __post_init__(self): if self.head_size_v is None: object.__setattr__(self, "head_size_v", self.head_size) - @property - def page_size_bytes(self) -> int: - return super().page_size_bytes * self.group_size - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size @@ -170,9 +165,9 @@ def merge(cls, specs: list[Self]) -> Self: head_size_v=specs[0].head_size_v, dtype=specs[0].dtype, page_size_padded=specs[0].page_size_padded, + group_size=specs[0].group_size, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), - group_size=specs[0].group_size, ) for spec in specs: for f in fields(AttentionSpec): @@ -232,6 +227,7 @@ def merge(cls, specs: list[Self]) -> Self: head_size=specs[0].head_size, dtype=specs[0].dtype, page_size_padded=specs[0].page_size_padded, + group_size=specs[0].group_size, cache_dtype_str=cache_dtype_str_set.pop(), ) From e1dc529a709455d199b7b1af455d6e62ba415883 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 8 Mar 2026 17:16:08 +0000 Subject: [PATCH 13/34] refactor with merge and split layers Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/kv_cache_utils.py | 153 ++++++++++++++++++++------------- vllm/v1/kv_cache_interface.py | 18 +--- 2 files changed, 98 insertions(+), 73 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index cf016373a604..cebb78e18e87 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -18,6 +18,7 @@ from vllm.utils.math_utils import cdiv from vllm.utils.mem_utils import format_gib from vllm.v1.kv_cache_interface import ( + AttentionSpec, ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, @@ -765,9 +766,6 @@ def create_kv_cache_group_specs( kv_cache_spec[layer_name] for layer_name in layer_names_one_group ] merged_layer_spec = layer_specs[0].merge(layer_specs) - group_size = len(layer_names_one_group) - if isinstance(merged_layer_spec, FullAttentionSpec): - group_size = cdiv(group_size, merged_layer_spec.group_size) kv_cache_groups.append( KVCacheGroupSpec(layer_names_one_group, merged_layer_spec) ) @@ -806,7 +804,7 @@ def get_max_concurrency_for_kv_cache_config( Get the maximum concurrency for the given KV cache configuration. 
""" num_layer_per_group = max( - group.group_size for group in kv_cache_config.kv_cache_groups + len(group.layer_names) for group in kv_cache_config.kv_cache_groups ) max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes( vllm_config, (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups) @@ -1028,22 +1026,6 @@ def _get_kv_cache_groups_uniform_page_size( for layer_name, layer_spec in kv_cache_spec.items(): same_type_layers[layer_spec].append(layer_name) - attn_group_size = next( - ( - kv_spec.group_size - for kv_spec in same_type_layers - if isinstance(kv_spec, FullAttentionSpec) - ), - 1, - ) - same_type_num_layers: dict[KVCacheSpec, int] = {} - for kv_spec, layers in same_type_layers.items(): - num_layers = len(layers) - # For FullAttn, group `attn_group_size` layers to 1 page - if isinstance(kv_spec, FullAttentionSpec): - num_layers = cdiv(num_layers, attn_group_size) - same_type_num_layers[kv_spec] = num_layers - # Split each group into smaller groups, to make the number of layers in each # group identical. Add padding to the last group of each type if necessary. # E.g., (full.0, full.1), (sw.0, sw.1, sw.2) @@ -1056,9 +1038,9 @@ def _get_kv_cache_groups_uniform_page_size( # is the minimum number of layers among all attention types. Need a better # strategy if we want to support more complex patterns (e.g., 20 full + 30 # sw, where the group size should be 10). - min_num_layers = min([num_layers for num_layers in same_type_num_layers.values()]) + min_num_layers = min([len(layers) for layers in same_type_layers.values()]) group_size = min_num_layers - max_num_layers = max([num_layers for num_layers in same_type_num_layers.values()]) + max_num_layers = max([len(layers) for layers in same_type_layers.values()]) if max_num_layers < min_num_layers * 1.25: # If the number of layers is not much larger than the minimum number of layers, # use the maximum number of layers as the group size to avoid too many padding @@ -1067,18 +1049,15 @@ def _get_kv_cache_groups_uniform_page_size( # magic number to avoid too many padding layers. group_size = max_num_layers grouped_layers = [] - for kv_spec, layers in same_type_layers.items(): - padding_size = group_size - if isinstance(kv_spec, FullAttentionSpec): - padding_size *= attn_group_size - num_padding_layers = padding_size - len(layers) % padding_size - if num_padding_layers != padding_size: + for layers in same_type_layers.values(): + num_padding_layers = group_size - len(layers) % group_size + if num_padding_layers != group_size: logger.warning( "Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa num_padding_layers, num_padding_layers / len(layers) * 100, ) - num_groups = cdiv(same_type_num_layers[kv_spec], group_size) + num_groups = cdiv(len(layers), group_size) # In PP case, say if we have # - stage 0: full.0, sw.0, sw.1 # - stage 1: full.1, sw.2, sw.3 @@ -1090,21 +1069,8 @@ def _get_kv_cache_groups_uniform_page_size( # the same and will cause memory waste. 
# To avoid this, we assign layers[i::num_groups] to the i-th group # instead of layers[i * group_size: (i + 1) * group_size] - if isinstance(kv_spec, FullAttentionSpec) and attn_group_size > 1: - stride = num_groups * attn_group_size - for i in range(num_groups): - attn_group_layers = [] - for start in range(0, len(layers), stride): - attn_group_layers.extend( - layers[ - i * attn_group_size + start : (i + 1) * attn_group_size - + start - ] - ) - grouped_layers.append(attn_group_layers) - else: - for i in range(num_groups): - grouped_layers.append(layers[i::num_groups]) + for i in range(num_groups): + grouped_layers.append(layers[i::num_groups]) return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) @@ -1162,7 +1128,7 @@ def get_kv_cache_config_from_groups( # (sw.1, padding) will be: (group_size = 2) # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 # full.1, sw.2: share another Tensor with size=available_memory//2 - group_size = max(group.group_size for group in kv_cache_groups) + group_size = max(len(group.layer_names) for group in kv_cache_groups) page_size = get_uniform_page_size( [group.kv_cache_spec for group in kv_cache_groups] @@ -1175,19 +1141,8 @@ def get_kv_cache_config_from_groups( for i in range(group_size): shared_by = [] for j in range(len(kv_cache_groups)): - num_layers = len(kv_cache_groups[j].layer_names) - kv_cache_spec = kv_cache_groups[j].kv_cache_spec - if ( - isinstance(kv_cache_spec, FullAttentionSpec) - and (attn_group_size := kv_cache_spec.group_size) > 1 - ): - for k in range(attn_group_size): - idx = i * attn_group_size + k - if idx < num_layers: - shared_by.append(kv_cache_groups[j].layer_names[idx]) - else: - if i < num_layers: - shared_by.append(kv_cache_groups[j].layer_names[i]) + if i < len(kv_cache_groups[j].layer_names): + shared_by.append(kv_cache_groups[j].layer_names[i]) kv_cache_tensors.append( KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by) ) @@ -1397,7 +1352,7 @@ def _max_memory_usage_bytes_from_groups( # General case: group_size pools, each shared by one layer per group # Memory = group_size * page_size * blocks_for_max_len - group_size = max(group.group_size for group in kv_cache_groups) + group_size = max(len(group.layer_names) for group in kv_cache_groups) page_size = get_uniform_page_size( [group.kv_cache_spec for group in kv_cache_groups] ) @@ -1545,6 +1500,70 @@ def _project_kv_cache_groups_to_worker( return projected_groups +def _merge_layers_from_attn_grouping( + attn_group_size: int, + kv_cache_specs: dict[str, KVCacheSpec], +) -> dict[str, KVCacheSpec]: + merged_kv_cache_specs: dict[str, KVCacheSpec] = {} + att_groups: defaultdict[AttentionSpec, tuple[list[str], list[AttentionSpec]]] = ( + defaultdict(lambda: ([], [])) + ) + for layer_name, kv_spec in kv_cache_specs.items(): + if isinstance(kv_spec, AttentionSpec): + att_groups[kv_spec][0].append(layer_name) + att_groups[kv_spec][1].append(kv_spec) + else: + merged_kv_cache_specs[layer_name] = kv_spec + for kv_spec, layers in att_groups.items(): + layer_names, layer_specs = layers + num_layers = len(layer_names) + assert num_layers == len(layer_specs) + for start in range(0, num_layers, attn_group_size): + end = min(num_layers, start + attn_group_size) + grouped_layer_names = "+".join(layer_names[start:end]) + try: + merged_kv_spec = layer_specs[start].merge_from_grouping( + layer_specs[start:end], attn_group_size + ) + except AssertionError as e: + raise ValueError( + "Failed to merge KV cache specs for layers " + f"{layer_names[start:end]}" + 
) from e + merged_kv_cache_specs[grouped_layer_names] = merged_kv_spec + return merged_kv_cache_specs + + +def _split_layers_from_attn_grouping( + attn_group_size: int, + kv_cache_config: KVCacheConfig, +) -> KVCacheConfig: + grouped_layers: dict[str, list[str]] = {} + for group in kv_cache_config.kv_cache_groups: + if not isinstance(group.kv_cache_spec, AttentionSpec): + continue + split_layer_names = [] + for i, layer_name in enumerate(group.layer_names): + layer_names = layer_name.split("+") + assert len(layer_names) == attn_group_size or ( + 0 < len(layer_names) < attn_group_size + and i == len(group.layer_names) - 1 + ), f"Invalid layer name {layer_name} for attention grouping split" + grouped_layers[layer_name] = layer_names + split_layer_names.extend(layer_names) + group.layer_names = split_layer_names + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + split_shared_by = [] + for layer_name in kv_cache_tensor.shared_by: + if layer_name in grouped_layers: + split_shared_by.extend(grouped_layers.pop(layer_name)) + else: + split_shared_by.append(layer_name) + kv_cache_tensor.shared_by = split_shared_by + assert not grouped_layers + return kv_cache_config + + def get_kv_cache_configs( vllm_config: VllmConfig, kv_cache_specs: list[dict[str, KVCacheSpec]], @@ -1580,6 +1599,15 @@ def get_kv_cache_configs( The generated KVCacheConfigs for each worker. """ + attn_group_size = vllm_config.cache_config.mamba_num_attn_pages + # TODO: add comments + if attn_group_size > 1: + for i in range(len(kv_cache_specs)): + kv_cache_specs[i] = _merge_layers_from_attn_grouping( + attn_group_size, + kv_cache_specs[i], + ) + # Merge the KV cache specs of all workers. Different PP stages may have # different layer names, and different TP ranks of the same PP stage should # have the same KV cache spec. @@ -1654,6 +1682,13 @@ def get_kv_cache_configs( if len(kv_cache_config.kv_cache_groups) > 0: _report_kv_cache_config(vllm_config, kv_cache_config) + if attn_group_size > 1: + for i in range(len(kv_cache_configs)): + kv_cache_configs[i] = _split_layers_from_attn_grouping( + attn_group_size, + kv_cache_configs[i], + ) + return kv_cache_configs diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index fbe2012a81be..1e8686e5d2b5 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -91,6 +91,10 @@ def real_page_size_bytes(self) -> int: * get_dtype_size(self.dtype) ) + @classmethod + def merge_from_grouping(cls, specs: list[Self], group_size: int) -> Self: + return replace(cls.merge(specs), group_size=group_size) + @dataclass(frozen=True, kw_only=True) class FullAttentionSpec(AttentionSpec): @@ -477,20 +481,6 @@ class KVCacheGroupSpec: # The KV cache spec of this manager layer kv_cache_spec: KVCacheSpec - @property - def group_size(self) -> int: - """ - The size of this group. - Normally, it is the length of the layer names. - For Mamba models, some FullAttn layers are grouped together and map - to the same KV cache block, so the group size will be smaller than - the number of layers. 
- """ - if isinstance(self.kv_cache_spec, FullAttentionSpec): - return cdiv(len(self.layer_names), self.kv_cache_spec.group_size) - else: - return len(self.layer_names) - @dataclass class KVCacheConfig: From 49c58f75e7c0e388e99554903fa480af9f963d0f Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 9 Mar 2026 15:32:14 +0000 Subject: [PATCH 14/34] remove group_size in attention Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/attention/attention.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 514db37cc45d..38f10998ec9e 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -519,7 +519,6 @@ def get_attn_backend(self) -> type[AttentionBackend]: def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: # Block size may get updated after model loading, refresh it block_size = vllm_config.cache_config.block_size - attn_group_size = vllm_config.cache_config.mamba_num_attn_pages # Should not be called for enc-dec or encoder-only attention. assert self.attn_type == AttentionType.DECODER if self.sliding_window is not None: @@ -540,7 +539,6 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: head_size=self.head_size, head_size_v=self.head_size_v, dtype=self.kv_cache_torch_dtype, - group_size=attn_group_size, ) From 632c5e934894bb6872d5d7dd351b4d1bd86813d9 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 9 Mar 2026 15:43:51 +0000 Subject: [PATCH 15/34] update test Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_kv_cache_utils.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 96973ab26144..d6c2dd51f166 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2152,10 +2152,11 @@ def test_get_kv_cache_configs_with_mamba(): ) # Test 3: 1 mamba + 2 full attention with group size 2 + vllm_config.cache_config.mamba_num_attn_pages = 2 hybrid_kv_cache_specs = { "layer_1": new_mamba_spec(), - "layer_2": new_kv_cache_spec(head_size=32, group_size=2), - "layer_3": new_kv_cache_spec(head_size=32, group_size=2), + "layer_2": new_kv_cache_spec(head_size=32), + "layer_3": new_kv_cache_spec(head_size=32), } available_memory_hybrid = expected_page_size * 2 * 10 kv_cache_config_hybrid = get_kv_cache_configs( @@ -2180,14 +2181,15 @@ def test_get_kv_cache_configs_with_mamba(): ) # Test 4: 2 mamba + 5 full (with 3 padding full) + vllm_config.cache_config.mamba_num_attn_pages = 2 hybrid_kv_cache_specs = { "layer_1": new_mamba_spec(), "layer_2": new_mamba_spec(), - "layer_3": new_kv_cache_spec(head_size=32, group_size=2), - "layer_4": new_kv_cache_spec(head_size=32, group_size=2), - "layer_5": new_kv_cache_spec(head_size=32, group_size=2), - "layer_6": new_kv_cache_spec(head_size=32, group_size=2), - "layer_7": new_kv_cache_spec(head_size=32, group_size=2), + "layer_3": new_kv_cache_spec(head_size=32), + "layer_4": new_kv_cache_spec(head_size=32), + "layer_5": new_kv_cache_spec(head_size=32), + "layer_6": new_kv_cache_spec(head_size=32), + "layer_7": new_kv_cache_spec(head_size=32), } available_memory_hybrid = expected_page_size * 2 * 10 kv_cache_config_hybrid = get_kv_cache_configs( From fa3dcc498ad295793e8fa003abf1db707902c76f Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 9 Mar 2026 15:57:43 +0000 Subject: 
[PATCH 16/34] add tests Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_kv_cache_utils.py | 76 ++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index d6c2dd51f166..1ec783c55dbc 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -24,6 +24,8 @@ BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, + _merge_layers_from_attn_grouping, + _split_layers_from_attn_grouping, estimate_max_model_len, generate_block_hash_extra_keys, generate_scheduler_kv_cache_config, @@ -2103,6 +2105,80 @@ def test_unify_hybrid_kv_cache_specs(): kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec) +def test_merge_layers_from_attn_grouping(): + attn_group_size = 2 + + hybrid_kv_cache_specs = { + "layer_1": new_mamba_spec(), + "layer_2": new_kv_cache_spec(head_size=32), + "layer_3": new_kv_cache_spec(head_size=32), + } + merged_kv_cache_specs = _merge_layers_from_attn_grouping( + attn_group_size, hybrid_kv_cache_specs + ) + assert merged_kv_cache_specs == { + "layer_1": new_mamba_spec(), + "layer_2+layer_3": new_kv_cache_spec(head_size=32, group_size=2), + } + + hybrid_kv_cache_specs = { + "layer_1": new_mamba_spec(), + "layer_2": new_kv_cache_spec(head_size=32), + "layer_3": new_kv_cache_spec(head_size=32), + "layer_4": new_kv_cache_spec(head_size=32), + } + merged_kv_cache_specs = _merge_layers_from_attn_grouping( + attn_group_size, hybrid_kv_cache_specs + ) + assert merged_kv_cache_specs == { + "layer_1": new_mamba_spec(), + "layer_2+layer_3": new_kv_cache_spec(head_size=32, group_size=2), + "layer_4": new_kv_cache_spec(head_size=32, group_size=2), + } + + +def test_split_layers_from_attn_grouping(): + attn_group_size = 2 + expected_page_size = new_mamba_spec().page_size_bytes + + kv_cache_config = KVCacheConfig( + num_blocks=20, + kv_cache_tensors=[ + KVCacheTensor( + size=expected_page_size * 20, + shared_by=["layer_1", "layer_2+layer_3"], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_mamba_spec()), + KVCacheGroupSpec( + ["layer_2+layer_3"], + new_kv_cache_spec(head_size=32, group_size=2), + ), + ], + ) + split_kv_cache_config = _split_layers_from_attn_grouping( + attn_group_size, kv_cache_config + ) + + assert split_kv_cache_config == KVCacheConfig( + num_blocks=20, + kv_cache_tensors=[ + KVCacheTensor( + size=expected_page_size * 20, + shared_by=["layer_1", "layer_2", "layer_3"], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_mamba_spec()), + KVCacheGroupSpec( + ["layer_2", "layer_3"], + new_kv_cache_spec(head_size=32, group_size=2), + ), + ], + ) + + def test_get_kv_cache_configs_with_mamba(): model_config = ModelConfig(max_model_len=16) vllm_config = VllmConfig(model_config=model_config) From 5bf8e6751c75fda1670707ab4a7bccae5adbead4 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 10 Mar 2026 17:34:21 +0000 Subject: [PATCH 17/34] refactor attn layout reshape Signed-off-by: huanghaoyan.hhy --- .../v1/worker/test_hybrid_kv_cache_layout.py | 277 ++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 111 +++++-- 2 files changed, 356 insertions(+), 32 deletions(-) create mode 100644 tests/v1/worker/test_hybrid_kv_cache_layout.py diff --git a/tests/v1/worker/test_hybrid_kv_cache_layout.py b/tests/v1/worker/test_hybrid_kv_cache_layout.py new file mode 100644 index 000000000000..21ebc026cf12 --- /dev/null +++ b/tests/v1/worker/test_hybrid_kv_cache_layout.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from math import prod + +import pytest +import torch + +from vllm.utils.torch_utils import get_dtype_size +from vllm.v1.attention.backend import AttentionBackend +from vllm.v1.attention.backends.cpu_attn import CPUAttentionBackend +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.attention.backends.flash_attn_diffkv import FlashAttentionDiffKVBackend +from vllm.v1.attention.backends.flashinfer import FlashInferBackend +from vllm.v1.attention.backends.flex_attention import FlexAttentionBackend +from vllm.v1.attention.backends.tree_attn import TreeAttentionBackend +from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend +from vllm.v1.attention.backends.utils import ( + get_kv_cache_layout, + set_kv_cache_layout, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + + +def _build_full_attn_spec( + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + group_size: int, +) -> FullAttentionSpec: + # Minimal valid FullAttentionSpec for layout tests. + return FullAttentionSpec( + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + head_size_v=head_size, + dtype=dtype, + group_size=group_size, + ) + + +def _compute_layout_ref( + backend: AttentionBackend, + kv_cache_spec: FullAttentionSpec, + layer_idx: int, + kernel_block_size: int, + num_blocks: int, + enable_hybrid_attn_mamba_layout: bool, +): + """ + Reference implementation that mirrors the pre-refactor logic: + - grouping layout in `_reshape_kv_cache_tensors` + - followed by `_update_hybrid_attention_mamba_layout` when enabled. + """ + dtype = kv_cache_spec.dtype + block_size = kv_cache_spec.block_size + attn_group_size = kv_cache_spec.group_size + + num_blocks_per_kv_block = block_size // kernel_block_size + kernel_num_blocks = num_blocks * num_blocks_per_kv_block + + # Match `_reshape_kv_cache_tensors`: use `kernel_num_blocks` and + # `kernel_block_size` when querying backend shape. + kv_cache_shape_logical = backend.get_kv_cache_shape( + kernel_num_blocks, + kernel_block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str="auto", + ) + + try: + kv_cache_stride_order = backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape_logical) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_shape_logical))) + + # Physical shape (the one used for allocation / as_strided). + kv_cache_shape = tuple(kv_cache_shape_logical[i] for i in kv_cache_stride_order) + + inv_order = [ + kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) + ] + + # Base contiguous strides in physical layout. + base_stride = list(torch.empty(kv_cache_shape).stride()) + + storage_offset = 0 + if attn_group_size > 1: + # Match the original `_reshape_kv_cache_tensors` logic. + kernel_blocks_idx = kv_cache_shape.index(kernel_num_blocks) + base_stride[kernel_blocks_idx] *= attn_group_size + dtype_size = get_dtype_size(dtype) + num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size + num_element_per_attn_group = ( + num_element_per_page // num_blocks_per_kv_block // attn_group_size + ) + attn_group_idx = layer_idx % attn_group_size + storage_offset = attn_group_idx * num_element_per_attn_group + + # Logical KV tensor after the initial reshape. 
+ kv = torch.empty_strided( + size=kv_cache_shape, + stride=tuple(base_stride), + dtype=dtype, + ).permute(*inv_order) + + # Optional hybrid attention+mamba layout update. + # We analytically update the stride to match `_update_hybrid_attention_mamba_layout` + # without actually changing the underlying storage. + if ( + enable_hybrid_attn_mamba_layout + and isinstance(kv_cache_spec, FullAttentionSpec) + and kv.shape[0] == 2 + ): + hidden_size = prod(kv.shape[2:]) + attn_group_size_for_layout = kv_cache_spec.group_size + kv_stride = kv.stride() + kv_stride = ( + hidden_size, + 2 * hidden_size * attn_group_size_for_layout, + *kv_stride[2:], + ) + return kv.shape, kv_stride, storage_offset + + return kv.shape, kv.stride(), storage_offset + + +def _compute_layout_new( + backend: AttentionBackend, + kv_cache_spec: FullAttentionSpec, + layer_idx: int, + kernel_block_size: int, + num_blocks: int, + enable_hybrid_attn_mamba_layout: bool, +): + """ + Layout computed via the new helper `_get_hybrid_attention_mamba_layout`. + """ + dtype = kv_cache_spec.dtype + block_size = kv_cache_spec.block_size + + num_blocks_per_kv_block = block_size // kernel_block_size + kernel_num_blocks = num_blocks * num_blocks_per_kv_block + + kv_cache_shape_logical = backend.get_kv_cache_shape( + kernel_num_blocks, + kernel_block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str="auto", + ) + + try: + kv_cache_stride_order = backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape_logical) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_shape_logical))) + + kv_cache_shape = tuple(kv_cache_shape_logical[i] for i in kv_cache_stride_order) + inv_order = [ + kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) + ] + + kv_cache_stride = tuple(torch.empty(kv_cache_shape).stride()) + storage_offset = 0 + + # We only need the method body; GPUModelRunner state is unused. + runner = object.__new__(GPUModelRunner) + if enable_hybrid_attn_mamba_layout: + kv_cache_stride, storage_offset = runner._get_hybrid_attention_mamba_layout( + kv_cache_shape=kv_cache_shape, + kv_cache_stride=kv_cache_stride, + kv_cache_spec=kv_cache_spec, + layer_idx=layer_idx, + kernel_num_blocks=kernel_num_blocks, + kernel_block_size=kernel_block_size, + ) + + kv = torch.empty_strided( + size=kv_cache_shape, + stride=kv_cache_stride, + dtype=dtype, + ).permute(*inv_order) + + # Sanity: group_size should not affect logical shape. 
+ assert kv.shape == tuple( + backend.get_kv_cache_shape( + num_blocks, + block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str="auto", + ) + ) + + return kv.shape, kv.stride(), storage_offset + + +@pytest.mark.parametrize( + "backend_cls", + [ + CPUAttentionBackend, + FlashAttentionBackend, + FlashInferBackend, + TritonAttentionBackend, + FlashAttentionDiffKVBackend, + FlexAttentionBackend, + TreeAttentionBackend, + ], +) +@pytest.mark.parametrize("group_size", [1, 2, 4]) +@pytest.mark.parametrize("enable_hybrid_attn_mamba_layout", [False, True]) +@pytest.mark.parametrize("cache_layout", ["NHD", "HND"]) +def test_hybrid_attention_mamba_layout_matches_reference( + backend_cls: type[AttentionBackend], + cache_layout: str, + group_size: int, + enable_hybrid_attn_mamba_layout: bool, +): + if (not enable_hybrid_attn_mamba_layout) and group_size > 1: + pytest.skip("group_size > 1 only occurs when hybrid attention+mamba is enabled") + # Explicitly test both cache layouts for backends that depend on it + # (FlashAttentionBackend / FlashInferBackend). Other backends ignore + # this setting, but it is harmless to apply globally. + set_kv_cache_layout(cache_layout) # sets the override + # Invalidate the cached value in get_kv_cache_layout so the override + # takes effect for this test case. + get_kv_cache_layout.cache_clear() # type: ignore[attr-defined] + + block_size = 16 + num_kv_heads = 2 + head_size = 32 + dtype = torch.float16 + + num_blocks = 100 + kernel_block_size = 16 + layer_idx = 1 + + kv_cache_spec = _build_full_attn_spec( + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + group_size=group_size, + ) + + backend = backend_cls + + ref_shape, ref_stride, ref_offset = _compute_layout_ref( + backend=backend, + kv_cache_spec=kv_cache_spec, + layer_idx=layer_idx, + kernel_block_size=kernel_block_size, + num_blocks=num_blocks, + enable_hybrid_attn_mamba_layout=enable_hybrid_attn_mamba_layout, + ) + + new_shape, new_stride, new_offset = _compute_layout_new( + backend=backend, + kv_cache_spec=kv_cache_spec, + layer_idx=layer_idx, + kernel_block_size=kernel_block_size, + num_blocks=num_blocks, + enable_hybrid_attn_mamba_layout=enable_hybrid_attn_mamba_layout, + ) + + assert ref_shape == new_shape + assert ref_stride == new_stride + # storage_offset only differs from zero when group_size > 1. + if group_size > 1: + assert new_offset == ref_offset + else: + assert ref_offset == 0 + assert new_offset == 0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f74f99cf7503..7559c85667d8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,6 +12,7 @@ from copy import copy, deepcopy from dataclasses import dataclass from functools import reduce +from math import prod from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast import numpy as np @@ -5866,7 +5867,16 @@ def _reshape_kv_cache_tensors( corresponding memory buffer for KV cache. """ kv_caches: dict[str, torch.Tensor] = {} + + # Pre-scan to determine whether there are attn / mamba layers. 
has_attn, has_mamba = False, False + for group in self._kv_cache_spec_attn_group_iterator(): + kv_cache_spec = group.kv_cache_spec + if isinstance(kv_cache_spec, AttentionSpec): + has_attn = True + elif isinstance(kv_cache_spec, MambaSpec): + has_mamba = True + for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec attn_backend = group.backend @@ -5881,7 +5891,6 @@ def _reshape_kv_cache_tensors( assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): - has_attn = True num_blocks_per_kv_block = ( kv_cache_spec.block_size // kernel_block_size ) @@ -5908,44 +5917,31 @@ def _reshape_kv_cache_tensors( kv_cache_shape = tuple( kv_cache_shape[i] for i in kv_cache_stride_order ) + kv_cache_stride = tuple(torch.empty(kv_cache_shape).stride()) + storage_offset = 0 # Maintain original KV shape view. inv_order = [ kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - if ( - isinstance(kv_cache_spec, FullAttentionSpec) - and (attn_group_size := kv_cache_spec.group_size) > 1 - ): - attn_group_idx = layer_idx % attn_group_size - ori_stride = list(torch.empty(kv_cache_shape).stride()) - kernel_blocks_idx = kv_cache_shape.index(kernel_num_blocks) - ori_stride[kernel_blocks_idx] *= attn_group_size - target_stride = tuple(ori_stride) - dtype_size = get_dtype_size(dtype) - num_element_per_page = ( - kv_cache_spec.page_size_bytes // dtype_size - ) - num_element_per_attn_group = ( - num_element_per_page - // num_blocks_per_kv_block - // attn_group_size - ) - kv_caches[layer_name] = torch.as_strided( - kv_cache_raw_tensors[layer_name].view(dtype), - size=kv_cache_shape, - stride=target_stride, - storage_offset=attn_group_idx * num_element_per_attn_group, - ).permute(*inv_order) - else: - kv_caches[layer_name] = ( - kv_cache_raw_tensors[layer_name] - .view(dtype) - .view(kv_cache_shape) - .permute(*inv_order) + if has_mamba: + kv_cache_stride, storage_offset = ( + self._get_hybrid_attention_mamba_layout( + kv_cache_shape=kv_cache_shape, + kv_cache_stride=kv_cache_stride, + kv_cache_spec=kv_cache_spec, + layer_idx=layer_idx, + kernel_num_blocks=kernel_num_blocks, + kernel_block_size=kernel_block_size, + ) ) + kv_caches[layer_name] = torch.as_strided( + raw_tensor.view(dtype), + size=kv_cache_shape, + stride=kv_cache_stride, + storage_offset=storage_offset, + ).permute(*inv_order) elif isinstance(kv_cache_spec, MambaSpec): - has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] state_tensors = [] storage_offset_bytes = 0 @@ -5976,6 +5972,57 @@ def _reshape_kv_cache_tensors( return kv_caches + def _get_hybrid_attention_mamba_layout( + self, + kv_cache_shape: tuple[int, ...], + kv_cache_stride: tuple[int, ...], + kv_cache_spec: AttentionSpec, + layer_idx: int, + kernel_num_blocks: int, + kernel_block_size: int, + ) -> tuple[tuple[int, ...], int]: + """ + Compute the stride and storage offset for the hybrid attention+mamba layout. + + Args: + kv_cache_shape: The shape of the KV cache tensor. + kv_cache_stride: The stride of the KV cache tensor. + kv_cache_spec: The specification of the KV cache. + layer_idx: The index of the layer. + kernel_num_blocks: The number of kernel blocks. + kernel_block_size: The size of the kernel block. + Returns: + A tuple containing the target stride and storage offset. 
+ """ + target_stride_list = list(kv_cache_stride) + storage_offset = 0 + + attn_group_size = kv_cache_spec.group_size + kernel_blocks_idx = kv_cache_shape.index(kernel_num_blocks) + if kv_cache_shape[0] == 2: + # Hybrid attention+mamba uses (2, num_blocks, ...) logical shape but + # (num_blocks, 2, ...) physical layout. + assert kv_cache_shape[1] != 2, ( + "Fail to determine whether the layout is " + "(2, num_blocks, ...) or (num_blocks, 2, ...) for " + f"a tensor of shape {kv_cache_shape}" + ) + assert kernel_blocks_idx == 1 + hidden_size = prod(kv_cache_shape[2:]) + target_stride_list[0] = hidden_size + target_stride_list[1] = 2 * hidden_size + if attn_group_size > 1: + target_stride_list[kernel_blocks_idx] *= attn_group_size + dtype_size = get_dtype_size(kv_cache_spec.dtype) + num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size + num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size + num_element_per_attn_group = ( + num_element_per_page // num_blocks_per_kv_block // attn_group_size + ) + attn_group_idx = layer_idx % attn_group_size + storage_offset = attn_group_idx * num_element_per_attn_group + return tuple(target_stride_list), storage_offset + def _update_hybrid_attention_mamba_layout( self, kv_caches: dict[str, torch.Tensor] ) -> None: From 372fc18f02a278c56aa7743bae388f9ae977b917 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 15 Mar 2026 08:28:08 +0000 Subject: [PATCH 18/34] update after merged Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/models/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index b76168281380..28561fb224ce 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -142,6 +142,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: kernel_block_alignment_size = 128 if use_cutlass_mla else 64 attn_page_size_1_token = MLAAttentionSpec( block_size=1, + group_size=cache_config.mamba_num_attn_pages, num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), dtype=kv_cache_dtype, @@ -161,6 +162,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: kernel_block_alignment_size = 32 attn_page_size_1_token = FullAttentionSpec( block_size=1, + group_size=cache_config.mamba_num_attn_pages, num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), dtype=kv_cache_dtype, From 4ebdc5a7a01341e6b9c3c057bdf55b93ccfe3cab Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 21 Mar 2026 09:23:08 +0000 Subject: [PATCH 19/34] remove _update_hybrid_attention_mamba_layout Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/gpu_model_runner.py | 45 ++---------------------------- 1 file changed, 2 insertions(+), 43 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9b3a2f146f8e..9e3b10feca4f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6285,12 +6285,10 @@ def _reshape_kv_cache_tensors( kv_caches: dict[str, torch.Tensor] = {} # Pre-scan to determine whether there are attn / mamba layers. 
- has_attn, has_mamba = False, False + has_mamba = False for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec - if isinstance(kv_cache_spec, AttentionSpec): - has_attn = True - elif isinstance(kv_cache_spec, MambaSpec): + if isinstance(kv_cache_spec, MambaSpec): has_mamba = True for group in self._kv_cache_spec_attn_group_iterator(): @@ -6383,9 +6381,6 @@ def _reshape_kv_cache_tensors( else: raise NotImplementedError - if has_attn and has_mamba: - self._update_hybrid_attention_mamba_layout(kv_caches) - return kv_caches def _get_hybrid_attention_mamba_layout( @@ -6439,42 +6434,6 @@ def _get_hybrid_attention_mamba_layout( storage_offset = attn_group_idx * num_element_per_attn_group return tuple(target_stride_list), storage_offset - def _update_hybrid_attention_mamba_layout( - self, kv_caches: dict[str, torch.Tensor] - ) -> None: - """ - Update the layout of attention layers from (2, num_blocks, ...) to - (num_blocks, 2, ...). - - Args: - kv_caches: The KV cache buffer of each layer. - """ - - for group in self._kv_cache_spec_attn_group_iterator(): - kv_cache_spec = group.kv_cache_spec - for layer_name in group.layer_names: - kv_cache = kv_caches[layer_name] - if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2: - assert kv_cache.shape[1] != 2, ( - "Fail to determine whether the layout is " - "(2, num_blocks, ...) or (num_blocks, 2, ...) for " - f"a tensor of shape {kv_cache.shape}" - ) - hidden_size = kv_cache.shape[2:].numel() - attn_group_size = ( - kv_cache_spec.group_size - if isinstance(kv_cache_spec, FullAttentionSpec) - else 1 - ) - kv_cache.as_strided_( - size=kv_cache.shape, - stride=( - hidden_size, - 2 * hidden_size * attn_group_size, - *kv_cache.stride()[2:], - ), - ) - def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] ) -> dict[str, torch.Tensor]: From f307aa4b54f7d871e305ece87421ccc6f6f13b08 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 21 Mar 2026 09:34:17 +0000 Subject: [PATCH 20/34] move get_hybrid_attention_mamba_layout to mamba_utils Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/gpu_model_runner.py | 54 +----------------------------- vllm/v1/worker/mamba_utils.py | 54 ++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 53 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 39b2bb4fd18e..b7df09330ee8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,7 +12,6 @@ from copy import copy, deepcopy from dataclasses import dataclass, replace from functools import reduce -from math import prod from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast import numpy as np @@ -6345,7 +6344,7 @@ def _reshape_kv_cache_tensors( ] if has_mamba: kv_cache_stride, storage_offset = ( - self._get_hybrid_attention_mamba_layout( + mamba_utils.get_hybrid_attention_mamba_layout( kv_cache_shape=kv_cache_shape, kv_cache_stride=kv_cache_stride, kv_cache_spec=kv_cache_spec, @@ -6388,57 +6387,6 @@ def _reshape_kv_cache_tensors( return kv_caches - def _get_hybrid_attention_mamba_layout( - self, - kv_cache_shape: tuple[int, ...], - kv_cache_stride: tuple[int, ...], - kv_cache_spec: AttentionSpec, - layer_idx: int, - kernel_num_blocks: int, - kernel_block_size: int, - ) -> tuple[tuple[int, ...], int]: - """ - Compute the stride and storage offset for the hybrid attention+mamba layout. - - Args: - kv_cache_shape: The shape of the KV cache tensor. 
- kv_cache_stride: The stride of the KV cache tensor. - kv_cache_spec: The specification of the KV cache. - layer_idx: The index of the layer. - kernel_num_blocks: The number of kernel blocks. - kernel_block_size: The size of the kernel block. - Returns: - A tuple containing the target stride and storage offset. - """ - target_stride_list = list(kv_cache_stride) - storage_offset = 0 - - attn_group_size = kv_cache_spec.group_size - kernel_blocks_idx = kv_cache_shape.index(kernel_num_blocks) - if kv_cache_shape[0] == 2: - # Hybrid attention+mamba uses (2, num_blocks, ...) logical shape but - # (num_blocks, 2, ...) physical layout. - assert kv_cache_shape[1] != 2, ( - "Fail to determine whether the layout is " - "(2, num_blocks, ...) or (num_blocks, 2, ...) for " - f"a tensor of shape {kv_cache_shape}" - ) - assert kernel_blocks_idx == 1 - hidden_size = prod(kv_cache_shape[2:]) - target_stride_list[0] = hidden_size - target_stride_list[1] = 2 * hidden_size - if attn_group_size > 1: - target_stride_list[kernel_blocks_idx] *= attn_group_size - dtype_size = get_dtype_size(kv_cache_spec.dtype) - num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size - num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size - num_element_per_attn_group = ( - num_element_per_page // num_blocks_per_kv_block // attn_group_size - ) - attn_group_idx = layer_idx % attn_group_size - storage_offset = attn_group_idx * num_element_per_attn_group - return tuple(target_stride_list), storage_offset - def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] ) -> dict[str, torch.Tensor]: diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 2bd5d2b3fea8..d3be4cdbda62 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -3,6 +3,7 @@ import dataclasses import itertools from collections.abc import Callable +from math import prod from typing import Any import torch @@ -13,6 +14,8 @@ ) from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import get_dtype_size +from vllm.v1.attention.backend import AttentionSpec from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec from vllm.v1.utils import CpuGpuBuffer @@ -266,3 +269,54 @@ def postprocess_mamba( if src_block_idx == dest_block_idx: num_accepted_tokens_cpu[i] = 1 do_mamba_copy_block(copy_bufs) + + +def get_hybrid_attention_mamba_layout( + kv_cache_shape: tuple[int, ...], + kv_cache_stride: tuple[int, ...], + kv_cache_spec: AttentionSpec, + layer_idx: int, + kernel_num_blocks: int, + kernel_block_size: int, +) -> tuple[tuple[int, ...], int]: + """ + Compute the stride and storage offset for the hybrid attention+mamba layout. + + Args: + kv_cache_shape: The shape of the KV cache tensor. + kv_cache_stride: The stride of the KV cache tensor. + kv_cache_spec: The specification of the KV cache. + layer_idx: The index of the layer. + kernel_num_blocks: The number of kernel blocks. + kernel_block_size: The size of the kernel block. + Returns: + A tuple containing the target stride and storage offset. + """ + target_stride_list = list(kv_cache_stride) + storage_offset = 0 + + attn_group_size = kv_cache_spec.group_size + kernel_blocks_idx = kv_cache_shape.index(kernel_num_blocks) + if kv_cache_shape[0] == 2: + # Hybrid attention+mamba uses (2, num_blocks, ...) logical shape but + # (num_blocks, 2, ...) physical layout. 
+ assert kv_cache_shape[1] != 2, ( + "Fail to determine whether the layout is " + "(2, num_blocks, ...) or (num_blocks, 2, ...) for " + f"a tensor of shape {kv_cache_shape}" + ) + assert kernel_blocks_idx == 1 + hidden_size = prod(kv_cache_shape[2:]) + target_stride_list[0] = hidden_size + target_stride_list[1] = 2 * hidden_size + if attn_group_size > 1: + target_stride_list[kernel_blocks_idx] *= attn_group_size + dtype_size = get_dtype_size(kv_cache_spec.dtype) + num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size + num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size + num_element_per_attn_group = ( + num_element_per_page // num_blocks_per_kv_block // attn_group_size + ) + attn_group_idx = layer_idx % attn_group_size + storage_offset = attn_group_idx * num_element_per_attn_group + return tuple(target_stride_list), storage_offset From 6114cd8c3f7a9e8685f2422e4a9920f02f5dcd04 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 21 Mar 2026 09:36:59 +0000 Subject: [PATCH 21/34] revert blank line Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/kv_cache_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index e9213ca3cd38..586a95614d26 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1076,7 +1076,6 @@ def _get_kv_cache_groups_uniform_page_size( # instead of layers[i * group_size: (i + 1) * group_size] for i in range(num_groups): grouped_layers.append(layers[i::num_groups]) - return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) From 938c092bec18c7e94ad4f64d3c4ec1e38df620e0 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 21 Mar 2026 10:24:19 +0000 Subject: [PATCH 22/34] rename group_size to pack_size Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_kv_cache_utils.py | 30 ++++++++-------- .../v1/worker/test_hybrid_kv_cache_layout.py | 34 +++++++++---------- vllm/model_executor/models/config.py | 4 +-- vllm/v1/core/kv_cache_utils.py | 24 ++++++------- vllm/v1/kv_cache_interface.py | 14 ++++---- vllm/v1/worker/mamba_utils.py | 14 ++++---- 6 files changed, 60 insertions(+), 60 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index b9a9f4bf2553..97738229af12 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -112,7 +112,7 @@ def new_kv_cache_spec( page_size_padded=None, sliding_window=None, attention_chunk_size=None, - group_size=1, + pack_size=1, ): return FullAttentionSpec( block_size=block_size, @@ -122,7 +122,7 @@ def new_kv_cache_spec( page_size_padded=page_size_padded, sliding_window=sliding_window, attention_chunk_size=attention_chunk_size, - group_size=group_size, + pack_size=pack_size, ) @@ -2144,7 +2144,7 @@ def test_unify_hybrid_kv_cache_specs(): def test_merge_layers_from_attn_grouping(): - attn_group_size = 2 + attn_pack_size = 2 hybrid_kv_cache_specs = { "layer_1": new_mamba_spec(), @@ -2152,11 +2152,11 @@ def test_merge_layers_from_attn_grouping(): "layer_3": new_kv_cache_spec(head_size=32), } merged_kv_cache_specs = _merge_layers_from_attn_grouping( - attn_group_size, hybrid_kv_cache_specs + attn_pack_size, hybrid_kv_cache_specs ) assert merged_kv_cache_specs == { "layer_1": new_mamba_spec(), - "layer_2+layer_3": new_kv_cache_spec(head_size=32, group_size=2), + "layer_2+layer_3": new_kv_cache_spec(head_size=32, pack_size=2), } hybrid_kv_cache_specs = { @@ -2166,17 +2166,17 @@ def 
test_merge_layers_from_attn_grouping(): "layer_4": new_kv_cache_spec(head_size=32), } merged_kv_cache_specs = _merge_layers_from_attn_grouping( - attn_group_size, hybrid_kv_cache_specs + attn_pack_size, hybrid_kv_cache_specs ) assert merged_kv_cache_specs == { "layer_1": new_mamba_spec(), - "layer_2+layer_3": new_kv_cache_spec(head_size=32, group_size=2), - "layer_4": new_kv_cache_spec(head_size=32, group_size=2), + "layer_2+layer_3": new_kv_cache_spec(head_size=32, pack_size=attn_pack_size), + "layer_4": new_kv_cache_spec(head_size=32, pack_size=attn_pack_size), } def test_split_layers_from_attn_grouping(): - attn_group_size = 2 + attn_pack_size = 2 expected_page_size = new_mamba_spec().page_size_bytes kv_cache_config = KVCacheConfig( @@ -2191,12 +2191,12 @@ def test_split_layers_from_attn_grouping(): KVCacheGroupSpec(["layer_1"], new_mamba_spec()), KVCacheGroupSpec( ["layer_2+layer_3"], - new_kv_cache_spec(head_size=32, group_size=2), + new_kv_cache_spec(head_size=32, pack_size=2), ), ], ) split_kv_cache_config = _split_layers_from_attn_grouping( - attn_group_size, kv_cache_config + attn_pack_size, kv_cache_config ) assert split_kv_cache_config == KVCacheConfig( @@ -2211,7 +2211,7 @@ def test_split_layers_from_attn_grouping(): KVCacheGroupSpec(["layer_1"], new_mamba_spec()), KVCacheGroupSpec( ["layer_2", "layer_3"], - new_kv_cache_spec(head_size=32, group_size=2), + new_kv_cache_spec(head_size=32, pack_size=attn_pack_size), ), ], ) @@ -2289,7 +2289,7 @@ def test_get_kv_cache_configs_with_mamba(): KVCacheGroupSpec(["layer_1"], new_mamba_spec()), KVCacheGroupSpec( ["layer_2", "layer_3"], - new_kv_cache_spec(head_size=32, group_size=2), + new_kv_cache_spec(head_size=32, pack_size=2), ), ], ) @@ -2326,11 +2326,11 @@ def test_get_kv_cache_configs_with_mamba(): KVCacheGroupSpec(["layer_1", "layer_2"], new_mamba_spec()), KVCacheGroupSpec( ["layer_3", "layer_4", "layer_7"], - new_kv_cache_spec(head_size=32, group_size=2), + new_kv_cache_spec(head_size=32, pack_size=2), ), KVCacheGroupSpec( ["layer_5", "layer_6"], - new_kv_cache_spec(head_size=32, group_size=2), + new_kv_cache_spec(head_size=32, pack_size=2), ), ], ) diff --git a/tests/v1/worker/test_hybrid_kv_cache_layout.py b/tests/v1/worker/test_hybrid_kv_cache_layout.py index 21ebc026cf12..436d59862c58 100644 --- a/tests/v1/worker/test_hybrid_kv_cache_layout.py +++ b/tests/v1/worker/test_hybrid_kv_cache_layout.py @@ -28,7 +28,7 @@ def _build_full_attn_spec( num_kv_heads: int, head_size: int, dtype: torch.dtype, - group_size: int, + pack_size: int, ) -> FullAttentionSpec: # Minimal valid FullAttentionSpec for layout tests. return FullAttentionSpec( @@ -37,7 +37,7 @@ def _build_full_attn_spec( head_size=head_size, head_size_v=head_size, dtype=dtype, - group_size=group_size, + pack_size=pack_size, ) @@ -56,7 +56,7 @@ def _compute_layout_ref( """ dtype = kv_cache_spec.dtype block_size = kv_cache_spec.block_size - attn_group_size = kv_cache_spec.group_size + attn_pack_size = kv_cache_spec.pack_size num_blocks_per_kv_block = block_size // kernel_block_size kernel_num_blocks = num_blocks * num_blocks_per_kv_block @@ -88,17 +88,17 @@ def _compute_layout_ref( base_stride = list(torch.empty(kv_cache_shape).stride()) storage_offset = 0 - if attn_group_size > 1: + if attn_pack_size > 1: # Match the original `_reshape_kv_cache_tensors` logic. 
kernel_blocks_idx = kv_cache_shape.index(kernel_num_blocks) - base_stride[kernel_blocks_idx] *= attn_group_size + base_stride[kernel_blocks_idx] *= attn_pack_size dtype_size = get_dtype_size(dtype) num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size - num_element_per_attn_group = ( - num_element_per_page // num_blocks_per_kv_block // attn_group_size + num_element_per_attn_pack = ( + num_element_per_page // num_blocks_per_kv_block // attn_pack_size ) - attn_group_idx = layer_idx % attn_group_size - storage_offset = attn_group_idx * num_element_per_attn_group + attn_pack_idx = layer_idx % attn_pack_size + storage_offset = attn_pack_idx * num_element_per_attn_pack # Logical KV tensor after the initial reshape. kv = torch.empty_strided( @@ -116,11 +116,11 @@ def _compute_layout_ref( and kv.shape[0] == 2 ): hidden_size = prod(kv.shape[2:]) - attn_group_size_for_layout = kv_cache_spec.group_size + attn_pack_size_for_layout = kv_cache_spec.pack_size kv_stride = kv.stride() kv_stride = ( hidden_size, - 2 * hidden_size * attn_group_size_for_layout, + 2 * hidden_size * attn_pack_size_for_layout, *kv_stride[2:], ) return kv.shape, kv_stride, storage_offset @@ -211,17 +211,17 @@ def _compute_layout_new( TreeAttentionBackend, ], ) -@pytest.mark.parametrize("group_size", [1, 2, 4]) +@pytest.mark.parametrize("pack_size", [1, 2, 4]) @pytest.mark.parametrize("enable_hybrid_attn_mamba_layout", [False, True]) @pytest.mark.parametrize("cache_layout", ["NHD", "HND"]) def test_hybrid_attention_mamba_layout_matches_reference( backend_cls: type[AttentionBackend], cache_layout: str, - group_size: int, + pack_size: int, enable_hybrid_attn_mamba_layout: bool, ): - if (not enable_hybrid_attn_mamba_layout) and group_size > 1: - pytest.skip("group_size > 1 only occurs when hybrid attention+mamba is enabled") + if (not enable_hybrid_attn_mamba_layout) and pack_size > 1: + pytest.skip("pack_size > 1 only occurs when hybrid attention+mamba is enabled") # Explicitly test both cache layouts for backends that depend on it # (FlashAttentionBackend / FlashInferBackend). Other backends ignore # this setting, but it is harmless to apply globally. @@ -244,7 +244,7 @@ def test_hybrid_attention_mamba_layout_matches_reference( num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, - group_size=group_size, + pack_size=pack_size, ) backend = backend_cls @@ -270,7 +270,7 @@ def test_hybrid_attention_mamba_layout_matches_reference( assert ref_shape == new_shape assert ref_stride == new_stride # storage_offset only differs from zero when group_size > 1. 
- if group_size > 1: + if pack_size > 1: assert new_offset == ref_offset else: assert ref_offset == 0 diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 930d29ee1bbb..653256390c26 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -157,7 +157,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: kernel_block_alignment_size = 128 if use_cutlass_mla else 64 attn_page_size_1_token = MLAAttentionSpec( block_size=1, - group_size=cache_config.mamba_num_attn_pages, + pack_size=cache_config.mamba_num_attn_pages, num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), dtype=kv_cache_dtype, @@ -166,7 +166,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: kernel_block_alignment_size = 16 attn_page_size_1_token = FullAttentionSpec( block_size=1, - group_size=cache_config.mamba_num_attn_pages, + pack_size=cache_config.mamba_num_attn_pages, num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), dtype=kv_cache_dtype, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 586a95614d26..ac3a477b16fa 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1507,7 +1507,7 @@ def _project_kv_cache_groups_to_worker( def _merge_layers_from_attn_grouping( - attn_group_size: int, + attn_pack_size: int, kv_cache_specs: dict[str, KVCacheSpec], ) -> dict[str, KVCacheSpec]: merged_kv_cache_specs: dict[str, KVCacheSpec] = {} @@ -1524,12 +1524,12 @@ def _merge_layers_from_attn_grouping( layer_names, layer_specs = layers num_layers = len(layer_names) assert num_layers == len(layer_specs) - for start in range(0, num_layers, attn_group_size): - end = min(num_layers, start + attn_group_size) + for start in range(0, num_layers, attn_pack_size): + end = min(num_layers, start + attn_pack_size) grouped_layer_names = "+".join(layer_names[start:end]) try: merged_kv_spec = layer_specs[start].merge_from_grouping( - layer_specs[start:end], attn_group_size + layer_specs[start:end], attn_pack_size ) except AssertionError as e: raise ValueError( @@ -1541,7 +1541,7 @@ def _merge_layers_from_attn_grouping( def _split_layers_from_attn_grouping( - attn_group_size: int, + attn_pack_size: int, kv_cache_config: KVCacheConfig, ) -> KVCacheConfig: grouped_layers: dict[str, list[str]] = {} @@ -1551,8 +1551,8 @@ def _split_layers_from_attn_grouping( split_layer_names = [] for i, layer_name in enumerate(group.layer_names): layer_names = layer_name.split("+") - assert len(layer_names) == attn_group_size or ( - 0 < len(layer_names) < attn_group_size + assert len(layer_names) == attn_pack_size or ( + 0 < len(layer_names) < attn_pack_size and i == len(group.layer_names) - 1 ), f"Invalid layer name {layer_name} for attention grouping split" grouped_layers[layer_name] = layer_names @@ -1605,12 +1605,12 @@ def get_kv_cache_configs( The generated KVCacheConfigs for each worker. 
""" - attn_group_size = vllm_config.cache_config.mamba_num_attn_pages + attn_pack_size = vllm_config.cache_config.mamba_num_attn_pages # TODO: add comments - if attn_group_size > 1: + if attn_pack_size > 1: for i in range(len(kv_cache_specs)): kv_cache_specs[i] = _merge_layers_from_attn_grouping( - attn_group_size, + attn_pack_size, kv_cache_specs[i], ) @@ -1688,10 +1688,10 @@ def get_kv_cache_configs( if len(kv_cache_config.kv_cache_groups) > 0: _report_kv_cache_config(vllm_config, kv_cache_config) - if attn_group_size > 1: + if attn_pack_size > 1: for i in range(len(kv_cache_configs)): kv_cache_configs[i] = _split_layers_from_attn_grouping( - attn_group_size, + attn_pack_size, kv_cache_configs[i], ) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index eb3a76e87634..c2c2ca2f29a9 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -67,7 +67,7 @@ class AttentionSpec(KVCacheSpec): head_size: int dtype: torch.dtype page_size_padded: int | None = None - group_size: int = 1 + pack_size: int = 1 """ The size of a group of attention layers. It's used for Mamba models to group more than one FullAttn layers to one page. @@ -78,8 +78,8 @@ def page_size_bytes(self) -> int: real_page_size = self.real_page_size_bytes if self.page_size_padded is not None: assert self.page_size_padded >= real_page_size - return self.page_size_padded * self.group_size - return real_page_size * self.group_size + return self.page_size_padded * self.pack_size + return real_page_size * self.pack_size @property def real_page_size_bytes(self) -> int: @@ -92,8 +92,8 @@ def real_page_size_bytes(self) -> int: ) @classmethod - def merge_from_grouping(cls, specs: list[Self], group_size: int) -> Self: - return replace(cls.merge(specs), group_size=group_size) + def merge_from_grouping(cls, specs: list[Self], pack_size: int) -> Self: + return replace(cls.merge(specs), pack_size=pack_size) @dataclass(frozen=True, kw_only=True) @@ -169,7 +169,7 @@ def merge(cls, specs: list[Self]) -> Self: head_size_v=specs[0].head_size_v, dtype=specs[0].dtype, page_size_padded=specs[0].page_size_padded, - group_size=specs[0].group_size, + pack_size=specs[0].pack_size, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), ) @@ -231,7 +231,7 @@ def merge(cls, specs: list[Self]) -> Self: head_size=specs[0].head_size, dtype=specs[0].dtype, page_size_padded=specs[0].page_size_padded, - group_size=specs[0].group_size, + pack_size=specs[0].pack_size, cache_dtype_str=cache_dtype_str_set.pop(), ) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index d3be4cdbda62..4670a8e08d79 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -295,7 +295,7 @@ def get_hybrid_attention_mamba_layout( target_stride_list = list(kv_cache_stride) storage_offset = 0 - attn_group_size = kv_cache_spec.group_size + attn_pack_size = kv_cache_spec.pack_size kernel_blocks_idx = kv_cache_shape.index(kernel_num_blocks) if kv_cache_shape[0] == 2: # Hybrid attention+mamba uses (2, num_blocks, ...) 
logical shape but @@ -309,14 +309,14 @@ def get_hybrid_attention_mamba_layout( hidden_size = prod(kv_cache_shape[2:]) target_stride_list[0] = hidden_size target_stride_list[1] = 2 * hidden_size - if attn_group_size > 1: - target_stride_list[kernel_blocks_idx] *= attn_group_size + if attn_pack_size > 1: + target_stride_list[kernel_blocks_idx] *= attn_pack_size dtype_size = get_dtype_size(kv_cache_spec.dtype) num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size - num_element_per_attn_group = ( - num_element_per_page // num_blocks_per_kv_block // attn_group_size + num_element_per_attn_pack = ( + num_element_per_page // num_blocks_per_kv_block // attn_pack_size ) - attn_group_idx = layer_idx % attn_group_size - storage_offset = attn_group_idx * num_element_per_attn_group + attn_pack_idx = layer_idx % attn_pack_size + storage_offset = attn_pack_idx * num_element_per_attn_pack return tuple(target_stride_list), storage_offset From 263363a5e31ee55b36f44848cf148f0a80867150 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 21 Mar 2026 10:57:00 +0000 Subject: [PATCH 23/34] fix the test Signed-off-by: huanghaoyan.hhy --- tests/v1/worker/test_hybrid_kv_cache_layout.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/v1/worker/test_hybrid_kv_cache_layout.py b/tests/v1/worker/test_hybrid_kv_cache_layout.py index 436d59862c58..5ad7150f2921 100644 --- a/tests/v1/worker/test_hybrid_kv_cache_layout.py +++ b/tests/v1/worker/test_hybrid_kv_cache_layout.py @@ -19,8 +19,8 @@ get_kv_cache_layout, set_kv_cache_layout, ) -from vllm.v1.kv_cache_interface import FullAttentionSpec -from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.kv_cache_interface import AttentionSpec, FullAttentionSpec +from vllm.v1.worker import mamba_utils def _build_full_attn_spec( @@ -43,7 +43,7 @@ def _build_full_attn_spec( def _compute_layout_ref( backend: AttentionBackend, - kv_cache_spec: FullAttentionSpec, + kv_cache_spec: AttentionSpec, layer_idx: int, kernel_block_size: int, num_blocks: int, @@ -112,7 +112,7 @@ def _compute_layout_ref( # without actually changing the underlying storage. if ( enable_hybrid_attn_mamba_layout - and isinstance(kv_cache_spec, FullAttentionSpec) + and isinstance(kv_cache_spec, AttentionSpec) and kv.shape[0] == 2 ): hidden_size = prod(kv.shape[2:]) @@ -130,7 +130,7 @@ def _compute_layout_ref( def _compute_layout_new( backend: AttentionBackend, - kv_cache_spec: FullAttentionSpec, + kv_cache_spec: AttentionSpec, layer_idx: int, kernel_block_size: int, num_blocks: int, @@ -167,10 +167,8 @@ def _compute_layout_new( kv_cache_stride = tuple(torch.empty(kv_cache_shape).stride()) storage_offset = 0 - # We only need the method body; GPUModelRunner state is unused. 
- runner = object.__new__(GPUModelRunner) if enable_hybrid_attn_mamba_layout: - kv_cache_stride, storage_offset = runner._get_hybrid_attention_mamba_layout( + kv_cache_stride, storage_offset = mamba_utils.get_hybrid_attention_mamba_layout( kv_cache_shape=kv_cache_shape, kv_cache_stride=kv_cache_stride, kv_cache_spec=kv_cache_spec, From fea9f3ee332fbf4da54080ad0066bb265f117085 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 21 Mar 2026 15:50:15 +0000 Subject: [PATCH 24/34] add comments Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/kv_cache_utils.py | 61 ++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 10 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index ac3a477b16fa..83c8700debd2 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1506,10 +1506,27 @@ def _project_kv_cache_groups_to_worker( return projected_groups -def _merge_layers_from_attn_grouping( +def _merge_attn_layers_into_pack( attn_pack_size: int, kv_cache_specs: dict[str, KVCacheSpec], ) -> dict[str, KVCacheSpec]: + """ + Merge attention layers into packs based on attn_pack_size. + + When mamba_num_attn_pages > 1, consecutive attention layers are packed + together to share a KV-block, with the block partitioned across layers. + This function packs every attn_pack_size consecutive attention layers + into a group, using "+" as a delimiter to join their layer names into a + new key. + + Args: + attn_pack_size: The number of layers in each attention pack. + kv_cache_specs: A dictionary mapping layer names to their KV cache specs. + + Returns: + A merged KV cache spec dictionary where consecutive attention layers + are packed together. + """ merged_kv_cache_specs: dict[str, KVCacheSpec] = {} att_groups: defaultdict[AttentionSpec, tuple[list[str], list[AttentionSpec]]] = ( defaultdict(lambda: ([], [])) @@ -1540,10 +1557,27 @@ def _merge_layers_from_attn_grouping( return merged_kv_cache_specs -def _split_layers_from_attn_grouping( +def _split_attn_layers_from_pack( attn_pack_size: int, kv_cache_config: KVCacheConfig, ) -> KVCacheConfig: + """ + Split attention layer packs back to individual layers. + + This is the reverse operation of _merge_attn_layers_into_pack. Once + the KV cache configuration is generated with packed layers, this + function splits them back to individual layer names so that each + physical layer can be properly initialized. + + Args: + attn_pack_size: The number of layers in each attention pack (same as + used in _merge_attn_layers_into_pack). + kv_cache_config: The KV cache configuration with packed layers. + + Returns: + The KV cache configuration with layer packs split back to individual + layers. + """ grouped_layers: dict[str, list[str]] = {} for group in kv_cache_config.kv_cache_groups: if not isinstance(group.kv_cache_spec, AttentionSpec): @@ -1583,17 +1617,21 @@ def get_kv_cache_configs( workers may have different memory available, and different type of layers (when pipeline parallel is enabled). To handle the difference between workers, the current implementation is: - 1. Merge the KV cache specs of all workers to get the KVCacheSpecs for + 1. If attn_pack_size > 1 (for Mamba models), pack attention layers into + groups to share a KV-block before processing. + 2. Merge the KV cache specs of all workers to get the KVCacheSpecs for the whole model. - 2. Generate the KV cache groups based on the layer ratio of the whole model. + 3. 
Generate the KV cache groups based on the layer ratio of the whole model. This also handles spec unification for hybrid models. - 3. Handle auto-fit max_model_len and memory checks using per-worker + 4. Handle auto-fit max_model_len and memory checks using per-worker projected groups to account for PP sharding. - 4. Generate the KV cache configs for each worker based on the KV cache + 5. Generate the KV cache configs for each worker based on the KV cache grouping strategy. (This is reasonable because the layer ratio of different PP stages are similar.) - 5. Change the num_blocks of each worker to the smallest among all workers + 6. Change the num_blocks of each worker to the smallest among all workers and shrink tensor sizes proportionally to avoid allocating unused memory. + 7. If attn_pack_size > 1 (for Mamba models), split packed layers back + to individual layers after generating configs. Args: vllm_config: The global VllmConfig @@ -1606,10 +1644,11 @@ def get_kv_cache_configs( """ attn_pack_size = vllm_config.cache_config.mamba_num_attn_pages - # TODO: add comments + # When attn_pack_size > 1 (for Mamba models), pack attention layers together + # to share a KV-block. if attn_pack_size > 1: for i in range(len(kv_cache_specs)): - kv_cache_specs[i] = _merge_layers_from_attn_grouping( + kv_cache_specs[i] = _merge_attn_layers_into_pack( attn_pack_size, kv_cache_specs[i], ) @@ -1688,9 +1727,11 @@ def get_kv_cache_configs( if len(kv_cache_config.kv_cache_groups) > 0: _report_kv_cache_config(vllm_config, kv_cache_config) + # When attn_pack_size > 1 (for Mamba models), split packed layers back + # to individual layers after generating configs. if attn_pack_size > 1: for i in range(len(kv_cache_configs)): - kv_cache_configs[i] = _split_layers_from_attn_grouping( + kv_cache_configs[i] = _split_attn_layers_from_pack( attn_pack_size, kv_cache_configs[i], ) From 920ecc18201772615182af010e079b1281ae166c Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 21 Mar 2026 15:51:06 +0000 Subject: [PATCH 25/34] fix Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/mamba_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 4670a8e08d79..82ae1b1b778a 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -15,9 +15,8 @@ from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import get_dtype_size -from vllm.v1.attention.backend import AttentionSpec from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch From 54eccaeacd303606e68afd18e2adfe72093fa392 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 21 Mar 2026 15:59:58 +0000 Subject: [PATCH 26/34] fix the name of tests Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_kv_cache_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 97738229af12..6a70eb5c0335 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -24,8 +24,8 @@ BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, - _merge_layers_from_attn_grouping, - 
_split_layers_from_attn_grouping, + _merge_attn_layers_into_pack, + _split_attn_layers_from_pack, estimate_max_model_len, generate_block_hash_extra_keys, generate_scheduler_kv_cache_config, @@ -2143,7 +2143,7 @@ def test_unify_hybrid_kv_cache_specs(): kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec) -def test_merge_layers_from_attn_grouping(): +def test_merge_attn_layers_into_pack(): attn_pack_size = 2 hybrid_kv_cache_specs = { @@ -2151,7 +2151,7 @@ def test_merge_layers_from_attn_grouping(): "layer_2": new_kv_cache_spec(head_size=32), "layer_3": new_kv_cache_spec(head_size=32), } - merged_kv_cache_specs = _merge_layers_from_attn_grouping( + merged_kv_cache_specs = _merge_attn_layers_into_pack( attn_pack_size, hybrid_kv_cache_specs ) assert merged_kv_cache_specs == { @@ -2165,7 +2165,7 @@ def test_merge_layers_from_attn_grouping(): "layer_3": new_kv_cache_spec(head_size=32), "layer_4": new_kv_cache_spec(head_size=32), } - merged_kv_cache_specs = _merge_layers_from_attn_grouping( + merged_kv_cache_specs = _merge_attn_layers_into_pack( attn_pack_size, hybrid_kv_cache_specs ) assert merged_kv_cache_specs == { @@ -2175,7 +2175,7 @@ def test_merge_layers_from_attn_grouping(): } -def test_split_layers_from_attn_grouping(): +def test_split_attn_layers_from_pack(): attn_pack_size = 2 expected_page_size = new_mamba_spec().page_size_bytes @@ -2195,7 +2195,7 @@ def test_split_layers_from_attn_grouping(): ), ], ) - split_kv_cache_config = _split_layers_from_attn_grouping( + split_kv_cache_config = _split_attn_layers_from_pack( attn_pack_size, kv_cache_config ) From cda00e285ede166d88624710c926c01a6d90cc89 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 21 Mar 2026 16:02:19 +0000 Subject: [PATCH 27/34] udpate comments Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b7df09330ee8..7e1fff4789a1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6288,7 +6288,7 @@ def _reshape_kv_cache_tensors( """ kv_caches: dict[str, torch.Tensor] = {} - # Pre-scan to determine whether there are attn / mamba layers. + # Pre-scan to determine whether there are mamba layers. 
has_mamba = False for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec From dc1b08b49b1b56885eb2a3918065310da2b64dc6 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 24 Mar 2026 17:49:06 +0000 Subject: [PATCH 28/34] add test_hybrid_attention_mamba_kv_cache_pack_size Signed-off-by: huanghaoyan.hhy --- tests/v1/worker/test_gpu_model_runner.py | 220 ++++++++++++++++++++++- 1 file changed, 219 insertions(+), 1 deletion(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index dd23d9dfaf64..4a73cf95359b 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from math import prod + import numpy as np import pytest import torch @@ -8,6 +10,7 @@ from vllm.config import ( AttentionConfig, CacheConfig, + LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig, @@ -24,12 +27,13 @@ from vllm.sampling_params import SamplingParams from vllm.utils.mem_constants import GiB_bytes from vllm.utils.system_utils import update_environment_variables -from vllm.utils.torch_utils import set_random_seed +from vllm.utils.torch_utils import get_dtype_size, set_random_seed from vllm.v1.attention.backend import MultipleOf from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.kv_cache_interface import ( + AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, @@ -1304,3 +1308,217 @@ def test_cudagraph_sizes_capped_for_mamba_cache(): compilation_config.cudagraph_capture_sizes[-1] == compilation_config.max_cudagraph_capture_size ) + + +@pytest.mark.skipif( + current_platform.is_rocm(), + reason="Attention backend FLASHINFER is not supported on ROCm.", +) +@pytest.mark.parametrize("pack_size", [1, 2, 4]) +def test_hybrid_attention_mamba_kv_cache_pack_size(pack_size: int): + """ + Test that KV cache layout for Attention layers in a hybrid (Attention+Mamba) + model correctly reflects pack_size settings (1, 2, 4) after calling + get_kv_cache_configs() and initialize_kv_cache(). + + Verifications: + 1. Attention layer kv_cache logical shape equals + (num_blocks, 2, block_size, num_kv_heads, head_size). + 2. Attention layer kv_cache stride at dim-0 (block dim) equals + base_stride * pack_size, reflecting the packed layout. + 3. When pack_size > 1, all Attention layers in a pack share the same + underlying storage; storage_offset of layer i equals + i * num_element_per_attn_pack. + 4. Writes to one packed Attention layer's region do NOT corrupt sibling + layers (write-isolation check). + 5. Mamba layer kv_cache block count is independent of pack_size. 
+ """ + set_random_seed(42) + + update_environment_variables( + { + "RANK": "0", + "LOCAL_RANK": "0", + "WORLD_SIZE": "1", + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) + from tests.utils import ensure_current_vllm_config + + with ensure_current_vllm_config(): + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=1) + torch.set_default_dtype(torch.float16) + + model_config = ModelConfig( + model="ibm-granite/granite-4.0-tiny-preview", + dtype="float16", + ) + scheduler_config = SchedulerConfig( + max_num_seqs=10, + max_num_batched_tokens=512, + max_model_len=512, + is_encoder_decoder=model_config.is_encoder_decoder, + ) + cache_config = CacheConfig( + block_size=BLOCK_SIZE, + gpu_memory_utilization=0.9, + cache_dtype="auto", + mamba_num_attn_pages=pack_size, + ) + parallel_config = ParallelConfig() + vllm_config = VllmConfig( + model_config=model_config, + cache_config=cache_config, + scheduler_config=scheduler_config, + parallel_config=parallel_config, + attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASHINFER), + load_config=LoadConfig(load_format="dummy"), + ) + + attn_layer_names = [ + "model.layers.0.self_attn.attn", + "model.layers.1.self_attn.attn", + "model.layers.2.self_attn.attn", + "model.layers.3.self_attn.attn", + ] + mamba_layer_names = [ + "model.layers.4.mixer", + "model.layers.5.mixer", + "model.layers.6.mixer", + "model.layers.7.mixer", + ] + + with set_current_vllm_config(vllm_config): + hf_config = vllm_config.model_config.hf_config + for key in attn_layer_names: + Attention( + num_heads=model_config.get_num_attention_heads(parallel_config), + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + scale=1.0, + prefix=key, + ) + for key in mamba_layer_names: + MambaMixer2( + hidden_size=hf_config.hidden_size, + ssm_state_size=hf_config.mamba_d_state, + conv_kernel_size=hf_config.mamba_d_conv, + intermediate_size=hf_config.mamba_expand * hf_config.hidden_size, + use_conv_bias=hf_config.mamba_conv_bias, + use_bias=hf_config.mamba_proj_bias, + n_groups=hf_config.mamba_n_groups, + num_heads=hf_config.mamba_n_heads, + head_dim=hf_config.mamba_d_head, + rms_norm_eps=hf_config.rms_norm_eps, + activation=hf_config.hidden_act, + cache_config=cache_config, + model_config=model_config, + prefix=key, + ) + vllm_ctx = vllm_config.compilation_config.static_forward_context + + runner = GPUModelRunner(vllm_config, DEVICE) + kv_cache_spec = runner.get_kv_cache_spec() + + available_memory = 5 * GiB_bytes + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_spec], [available_memory] + )[0] + runner.initialize_kv_cache(kv_cache_config) + + num_blocks = kv_cache_config.num_blocks + + # After `get_kv_cache_configs()`, the merged AttentionSpec + # (with `pack_size` set to `attn_pack_size`) is stored in + # `kv_cache_config.kv_cache_groups[].kv_cache_spec`. + # Use it as the authoritative spec for layout calculations; + # `runner.get_kv_cache_spec()` always returns per-layer specs + # with pack_size=1 and is unaffected by the merge. 
+ attn_group_spec = next( + g.kv_cache_spec + for g in kv_cache_config.kv_cache_groups + if isinstance(g.kv_cache_spec, AttentionSpec) + ) + assert attn_group_spec.pack_size == pack_size, ( + f"attn_group_spec.pack_size expected {pack_size}, " + f"got {attn_group_spec.pack_size}" + ) + + kv0 = vllm_ctx[attn_layer_names[0]].kv_cache[0] + # FlashInfer logical shape: + # (kernel_num_blocks, 2, kernel_block_size, num_kv_heads, head_size) + expected_attn_shape = tuple(kv0.shape) + base_block_stride = prod(expected_attn_shape[1:]) + expected_block_stride = base_block_stride * pack_size + + dtype_size = get_dtype_size(attn_group_spec.dtype) + kernel_block_size = expected_attn_shape[2] # dim-2 of FlashInfer shape + num_blocks_per_kv_block = attn_group_spec.block_size // kernel_block_size + num_element_per_attn_pack = ( + attn_group_spec.page_size_bytes + // dtype_size + // num_blocks_per_kv_block + // attn_group_spec.pack_size + ) + + # --- 1 & 2. Verify shape and dim-0 stride for all Attention layers --- + for layer_name in attn_layer_names: + kv = vllm_ctx[layer_name].kv_cache[0] + assert tuple(kv.shape) == expected_attn_shape, ( + f"pack_size={pack_size}, {layer_name}: " + f"expected shape {expected_attn_shape}, got {tuple(kv.shape)}" + ) + assert kv.stride(0) == expected_block_stride, ( + f"pack_size={pack_size}, {layer_name}: " + f"dim-0 stride expected {expected_block_stride}, got {kv.stride(0)}" + ) + + # --- 3 & 4. Verify storage sharing, offsets, and write isolation --- + for pack_start in range(0, len(attn_layer_names), pack_size): + pack_layers = attn_layer_names[pack_start : pack_start + pack_size] + kv_tensors = [vllm_ctx[ln].kv_cache[0] for ln in pack_layers] + + if pack_size > 1: + # All layers in a pack share the same underlying storage. + base_ptr = kv_tensors[0].storage().data_ptr() + for i, (ln, kv) in enumerate(zip(pack_layers, kv_tensors)): + assert kv.storage().data_ptr() == base_ptr, ( + f"pack_size={pack_size}, {ln}: storage not shared with pack leader" + ) + # Layer at pack position i has offset i * num_element_per_attn_pack. + expected_offset = i * num_element_per_attn_pack + assert kv.storage_offset() == expected_offset, ( + f"pack_size={pack_size}, {ln} (pack_idx={i}): " + f"storage_offset expected {expected_offset}, " + f"got {kv.storage_offset()}" + ) + # Write-isolation: filling block-0 of layer i must not corrupt layer j. + for i, kv in enumerate(kv_tensors): + kv[0].fill_(float(i + 1)) + for i, (ln, kv) in enumerate(zip(pack_layers, kv_tensors)): + fill_val = float(i + 1) + assert ( + kv[0].min().item() == fill_val and kv[0].max().item() == fill_val + ), f"pack_size={pack_size}, {ln}: block 0 corrupted by sibling writes" + else: + # pack_size == 1: independent storage, offset must be 0. + for ln, kv in zip(pack_layers, kv_tensors): + assert kv.storage_offset() == 0, ( + f"pack_size=1, {ln}: storage_offset expected 0, " + f"got {kv.storage_offset()}" + ) + + # --- 5. 
Verify Mamba layer block count is independent of pack_size --- + for layer_name in mamba_layer_names: + conv_state = vllm_ctx[layer_name].kv_cache[0][0] + ssm_state = vllm_ctx[layer_name].kv_cache[0][1] + assert conv_state.shape[0] == num_blocks, ( + f"pack_size={pack_size}, {layer_name}: " + f"conv_state.shape[0] expected {num_blocks}, got {conv_state.shape[0]}" + ) + assert ssm_state.shape[0] == num_blocks, ( + f"pack_size={pack_size}, {layer_name}: " + f"ssm_state.shape[0] expected {num_blocks}, got {ssm_state.shape[0]}" + ) From 6b1104d89dc67029beb20dc2e40be117c6cd1a65 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 29 Mar 2026 15:37:29 +0000 Subject: [PATCH 29/34] fix block_size without packed in profiling Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_kv_cache_utils.py | 12 +++++------- vllm/v1/core/kv_cache_utils.py | 8 ++++---- vllm/v1/worker/gpu_model_runner.py | 19 +++++++++++++++++++ 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 6a70eb5c0335..c83920a0547c 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -24,8 +24,6 @@ BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, - _merge_attn_layers_into_pack, - _split_attn_layers_from_pack, estimate_max_model_len, generate_block_hash_extra_keys, generate_scheduler_kv_cache_config, @@ -36,6 +34,8 @@ init_none_hash, is_kv_cache_spec_uniform, make_block_hash_with_group_id, + merge_attn_layers_into_pack, + split_attn_layers_from_pack, tensor_data, ) from vllm.v1.kv_cache_interface import ( @@ -2151,7 +2151,7 @@ def test_merge_attn_layers_into_pack(): "layer_2": new_kv_cache_spec(head_size=32), "layer_3": new_kv_cache_spec(head_size=32), } - merged_kv_cache_specs = _merge_attn_layers_into_pack( + merged_kv_cache_specs = merge_attn_layers_into_pack( attn_pack_size, hybrid_kv_cache_specs ) assert merged_kv_cache_specs == { @@ -2165,7 +2165,7 @@ def test_merge_attn_layers_into_pack(): "layer_3": new_kv_cache_spec(head_size=32), "layer_4": new_kv_cache_spec(head_size=32), } - merged_kv_cache_specs = _merge_attn_layers_into_pack( + merged_kv_cache_specs = merge_attn_layers_into_pack( attn_pack_size, hybrid_kv_cache_specs ) assert merged_kv_cache_specs == { @@ -2195,9 +2195,7 @@ def test_split_attn_layers_from_pack(): ), ], ) - split_kv_cache_config = _split_attn_layers_from_pack( - attn_pack_size, kv_cache_config - ) + split_kv_cache_config = split_attn_layers_from_pack(attn_pack_size, kv_cache_config) assert split_kv_cache_config == KVCacheConfig( num_blocks=20, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 83c8700debd2..ace8d51d6617 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1506,7 +1506,7 @@ def _project_kv_cache_groups_to_worker( return projected_groups -def _merge_attn_layers_into_pack( +def merge_attn_layers_into_pack( attn_pack_size: int, kv_cache_specs: dict[str, KVCacheSpec], ) -> dict[str, KVCacheSpec]: @@ -1557,7 +1557,7 @@ def _merge_attn_layers_into_pack( return merged_kv_cache_specs -def _split_attn_layers_from_pack( +def split_attn_layers_from_pack( attn_pack_size: int, kv_cache_config: KVCacheConfig, ) -> KVCacheConfig: @@ -1648,7 +1648,7 @@ def get_kv_cache_configs( # to share a KV-block. 
if attn_pack_size > 1: for i in range(len(kv_cache_specs)): - kv_cache_specs[i] = _merge_attn_layers_into_pack( + kv_cache_specs[i] = merge_attn_layers_into_pack( attn_pack_size, kv_cache_specs[i], ) @@ -1731,7 +1731,7 @@ def get_kv_cache_configs( # to individual layers after generating configs. if attn_pack_size > 1: for i in range(len(kv_cache_configs)): - kv_cache_configs[i] = _split_attn_layers_from_pack( + kv_cache_configs[i] = split_attn_layers_from_pack( attn_pack_size, kv_cache_configs[i], ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 60ef7795e0ac..5b51fc9a6c66 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -126,6 +126,10 @@ get_dcp_local_seq_lens, reorder_batch_to_split_decodes_and_prefills, ) +from vllm.v1.core.kv_cache_utils import ( + merge_attn_layers_into_pack, + split_attn_layers_from_pack, +) from vllm.v1.core.sched.output import NewRequestData from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import ( @@ -5792,6 +5796,14 @@ def _init_minimal_kv_cache_for_profiling(self) -> None: ) kv_cache_spec = self.get_kv_cache_spec() + attn_pack_size = self.vllm_config.cache_config.mamba_num_attn_pages + # When attn_pack_size > 1 (for Mamba models), pack attention layers together + # to share a KV-block. + if attn_pack_size > 1: + kv_cache_spec = merge_attn_layers_into_pack( + attn_pack_size, + kv_cache_spec, + ) kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec) min_blocks = self.compilation_config.max_cudagraph_capture_size or 1 @@ -5801,6 +5813,13 @@ def _init_minimal_kv_cache_for_profiling(self) -> None: minimal_config = get_kv_cache_config_from_groups( self.vllm_config, kv_cache_groups, available_memory=0 ) + # When attn_pack_size > 1 (for Mamba models), split packed layers back + # to individual layers after generating configs. 
+ if attn_pack_size > 1: + minimal_config = split_attn_layers_from_pack( + attn_pack_size, + minimal_config, + ) self.cache_config.num_gpu_blocks_override = saved_override self.initialize_kv_cache(minimal_config) From e8db555d827597aa8cae83e4a44bdb80ab711e47 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 4 Apr 2026 12:18:33 +0000 Subject: [PATCH 30/34] rename mamba_num_attn_pages to attn_pack_size Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_kv_cache_utils.py | 4 ++-- tests/v1/worker/test_gpu_model_runner.py | 2 +- vllm/config/cache.py | 4 ++-- vllm/config/vllm.py | 4 ++-- vllm/engine/arg_utils.py | 10 ++++------ vllm/model_executor/models/config.py | 4 ++-- vllm/v1/core/kv_cache_utils.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 2 +- 8 files changed, 16 insertions(+), 18 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index c83920a0547c..1b180e28628d 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2264,7 +2264,7 @@ def test_get_kv_cache_configs_with_mamba(): ) # Test 3: 1 mamba + 2 full attention with group size 2 - vllm_config.cache_config.mamba_num_attn_pages = 2 + vllm_config.cache_config.attn_pack_size = 2 hybrid_kv_cache_specs = { "layer_1": new_mamba_spec(), "layer_2": new_kv_cache_spec(head_size=32), @@ -2293,7 +2293,7 @@ def test_get_kv_cache_configs_with_mamba(): ) # Test 4: 2 mamba + 5 full (with 3 padding full) - vllm_config.cache_config.mamba_num_attn_pages = 2 + vllm_config.cache_config.attn_pack_size = 2 hybrid_kv_cache_specs = { "layer_1": new_mamba_spec(), "layer_2": new_mamba_spec(), diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 58d306d21daf..87e1f0adb833 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1365,7 +1365,7 @@ def test_hybrid_attention_mamba_kv_cache_pack_size(pack_size: int): block_size=BLOCK_SIZE, gpu_memory_utilization=0.9, cache_dtype="auto", - mamba_num_attn_pages=pack_size, + attn_pack_size=pack_size, ) parallel_config = ParallelConfig() vllm_config = VllmConfig( diff --git a/vllm/config/cache.py b/vllm/config/cache.py index ac0fdac5a1aa..f3475037427b 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -23,7 +23,7 @@ ] MambaDType = Literal["auto", "float32", "float16"] MambaCacheMode = Literal["all", "align", "none"] -MambaNumAttnPages = Literal[1, 2, 4, 8] +AttnPackSize = Literal[1, 2, 4, 8] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"] KVOffloadingBackend = Literal["native", "lmcache"] @@ -114,7 +114,7 @@ class CacheConfig: - "align": only cache the mamba state of the last token of each scheduler step and when the token is at position i * block_size. """ - mamba_num_attn_pages: MambaNumAttnPages = 1 + attn_pack_size: AttnPackSize = 1 """The number of attention pages to allocate for Mamba layers. This is only relevant for models that includes Mamba layers.""" diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index c34cccc0eb3d..3bd0245ebf36 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1233,9 +1233,9 @@ def has_blocked_weights(): # Default to enable HMA if not explicitly disabled by user or logic above. 
self.scheduler_config.disable_hybrid_kv_cache_manager = False - if self.cache_config.mamba_num_attn_pages > 1: + if self.cache_config.attn_pack_size > 1: assert self.model_config.is_hybrid, ( - "Mapping multiple FullAttention layers to a single page is only " + "Packing multiple FullAttention layers to a single page is only " "supported for hybrid models" ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 02d1582c61e7..8262a03918e5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -63,11 +63,11 @@ get_attr_docs, ) from vllm.config.cache import ( + AttnPackSize, CacheDType, KVOffloadingBackend, MambaCacheMode, MambaDType, - MambaNumAttnPages, PrefixCachingHashAlgo, ) from vllm.config.device import Device @@ -601,7 +601,7 @@ class EngineArgs: mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size") mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode - mamba_num_attn_pages: MambaNumAttnPages = CacheConfig.mamba_num_attn_pages + attn_pack_size: AttnPackSize = CacheConfig.attn_pack_size additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") @@ -1015,9 +1015,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: cache_group.add_argument( "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"] ) - cache_group.add_argument( - "--mamba-num-attn-pages", **cache_kwargs["mamba_num_attn_pages"] - ) + cache_group.add_argument("--attn-pack-size", **cache_kwargs["attn_pack_size"]) cache_group.add_argument( "--kv-offloading-size", **cache_kwargs["kv_offloading_size"] ) @@ -1582,7 +1580,7 @@ def create_engine_config( mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, mamba_block_size=self.mamba_block_size, mamba_cache_mode=self.mamba_cache_mode, - mamba_num_attn_pages=self.mamba_num_attn_pages, + attn_pack_size=self.attn_pack_size, kv_offloading_size=self.kv_offloading_size, kv_offloading_backend=self.kv_offloading_backend, ) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 653256390c26..04e97a84dd4d 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -157,7 +157,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: kernel_block_alignment_size = 128 if use_cutlass_mla else 64 attn_page_size_1_token = MLAAttentionSpec( block_size=1, - pack_size=cache_config.mamba_num_attn_pages, + pack_size=cache_config.attn_pack_size, num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), dtype=kv_cache_dtype, @@ -166,7 +166,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: kernel_block_alignment_size = 16 attn_page_size_1_token = FullAttentionSpec( block_size=1, - pack_size=cache_config.mamba_num_attn_pages, + pack_size=cache_config.attn_pack_size, num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), dtype=kv_cache_dtype, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index ace8d51d6617..7528abab87ed 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1513,7 +1513,7 @@ def merge_attn_layers_into_pack( """ Merge attention layers into packs based on attn_pack_size. 
- When mamba_num_attn_pages > 1, consecutive attention layers are packed + When attn_pack_size > 1, consecutive attention layers are packed together to share a KV-block, with the block partitioned across layers. This function packs every attn_pack_size consecutive attention layers into a group, using "+" as a delimiter to join their layer names into a @@ -1643,7 +1643,7 @@ def get_kv_cache_configs( The generated KVCacheConfigs for each worker. """ - attn_pack_size = vllm_config.cache_config.mamba_num_attn_pages + attn_pack_size = vllm_config.cache_config.attn_pack_size # When attn_pack_size > 1 (for Mamba models), pack attention layers together # to share a KV-block. if attn_pack_size > 1: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5b51fc9a6c66..01232f5725c4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5796,7 +5796,7 @@ def _init_minimal_kv_cache_for_profiling(self) -> None: ) kv_cache_spec = self.get_kv_cache_spec() - attn_pack_size = self.vllm_config.cache_config.mamba_num_attn_pages + attn_pack_size = self.vllm_config.cache_config.attn_pack_size # When attn_pack_size > 1 (for Mamba models), pack attention layers together # to share a KV-block. if attn_pack_size > 1: From 7a0e1135a181b0a212aa70cbf0dff2a180e8b8ce Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 7 Apr 2026 18:09:01 +0000 Subject: [PATCH 31/34] update after merge Signed-off-by: huanghaoyan.hhy --- tests/v1/worker/test_gpu_model_runner.py | 16 ++++++++-------- .../v1/worker/test_hybrid_kv_cache_layout.py | 19 +++++++++++++++---- vllm/platforms/interface.py | 2 ++ vllm/v1/worker/gpu_model_runner.py | 8 +++++++- vllm/v1/worker/mamba_utils.py | 12 ++++++------ 5 files changed, 38 insertions(+), 19 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 8e6ea23b90bf..ecd5489815bf 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from types import SimpleNamespace - from math import prod +from types import SimpleNamespace import numpy as np import pytest @@ -1435,7 +1434,8 @@ def test_hybrid_attention_mamba_kv_cache_pack_size(pack_size: int): ) vllm_ctx = vllm_config.compilation_config.static_forward_context - runner = GPUModelRunner(vllm_config, DEVICE) + runner = GPUModelRunner(vllm_config, DEVICE_TYPE) + current_platform.update_block_size_for_backend(vllm_config) kv_cache_spec = runner.get_kv_cache_spec() available_memory = 5 * GiB_bytes @@ -1462,7 +1462,7 @@ def test_hybrid_attention_mamba_kv_cache_pack_size(pack_size: int): f"got {attn_group_spec.pack_size}" ) - kv0 = vllm_ctx[attn_layer_names[0]].kv_cache[0] + kv0 = vllm_ctx[attn_layer_names[0]].kv_cache # FlashInfer logical shape: # (kernel_num_blocks, 2, kernel_block_size, num_kv_heads, head_size) expected_attn_shape = tuple(kv0.shape) @@ -1481,7 +1481,7 @@ def test_hybrid_attention_mamba_kv_cache_pack_size(pack_size: int): # --- 1 & 2. 
Verify shape and dim-0 stride for all Attention layers --- for layer_name in attn_layer_names: - kv = vllm_ctx[layer_name].kv_cache[0] + kv = vllm_ctx[layer_name].kv_cache assert tuple(kv.shape) == expected_attn_shape, ( f"pack_size={pack_size}, {layer_name}: " f"expected shape {expected_attn_shape}, got {tuple(kv.shape)}" @@ -1494,7 +1494,7 @@ def test_hybrid_attention_mamba_kv_cache_pack_size(pack_size: int): # --- 3 & 4. Verify storage sharing, offsets, and write isolation --- for pack_start in range(0, len(attn_layer_names), pack_size): pack_layers = attn_layer_names[pack_start : pack_start + pack_size] - kv_tensors = [vllm_ctx[ln].kv_cache[0] for ln in pack_layers] + kv_tensors = [vllm_ctx[ln].kv_cache for ln in pack_layers] if pack_size > 1: # All layers in a pack share the same underlying storage. @@ -1528,8 +1528,8 @@ def test_hybrid_attention_mamba_kv_cache_pack_size(pack_size: int): # --- 5. Verify Mamba layer block count is independent of pack_size --- for layer_name in mamba_layer_names: - conv_state = vllm_ctx[layer_name].kv_cache[0][0] - ssm_state = vllm_ctx[layer_name].kv_cache[0][1] + conv_state = vllm_ctx[layer_name].kv_cache[0] + ssm_state = vllm_ctx[layer_name].kv_cache[1] assert conv_state.shape[0] == num_blocks, ( f"pack_size={pack_size}, {layer_name}: " f"conv_state.shape[0] expected {num_blocks}, got {conv_state.shape[0]}" diff --git a/tests/v1/worker/test_hybrid_kv_cache_layout.py b/tests/v1/worker/test_hybrid_kv_cache_layout.py index 5ad7150f2921..fa5cb560e9ae 100644 --- a/tests/v1/worker/test_hybrid_kv_cache_layout.py +++ b/tests/v1/worker/test_hybrid_kv_cache_layout.py @@ -88,10 +88,15 @@ def _compute_layout_ref( base_stride = list(torch.empty(kv_cache_shape).stride()) storage_offset = 0 + block_dim = backend.get_kv_cache_block_dim( + kernel_block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str="auto", + ) if attn_pack_size > 1: # Match the original `_reshape_kv_cache_tensors` logic. 
- kernel_blocks_idx = kv_cache_shape.index(kernel_num_blocks) - base_stride[kernel_blocks_idx] *= attn_pack_size + base_stride[block_dim] *= attn_pack_size dtype_size = get_dtype_size(dtype) num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size num_element_per_attn_pack = ( @@ -113,7 +118,7 @@ def _compute_layout_ref( if ( enable_hybrid_attn_mamba_layout and isinstance(kv_cache_spec, AttentionSpec) - and kv.shape[0] == 2 + and block_dim == 1 ): hidden_size = prod(kv.shape[2:]) attn_pack_size_for_layout = kv_cache_spec.pack_size @@ -168,12 +173,18 @@ def _compute_layout_new( storage_offset = 0 if enable_hybrid_attn_mamba_layout: + block_dim = backend.get_kv_cache_block_dim( + kernel_block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str="auto", + ) kv_cache_stride, storage_offset = mamba_utils.get_hybrid_attention_mamba_layout( kv_cache_shape=kv_cache_shape, kv_cache_stride=kv_cache_stride, kv_cache_spec=kv_cache_spec, + block_dim=block_dim, layer_idx=layer_idx, - kernel_num_blocks=kernel_num_blocks, kernel_block_size=kernel_block_size, ) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 72b714df1ba3..5f5a709239a0 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -536,6 +536,7 @@ def _align_hybrid_block_size( head_size=model_config.get_head_size(), dtype=kv_cache_dtype, kv_quant_mode=kv_quant_mode, + pack_size=cache_config.attn_pack_size, ).page_size_bytes else: attn_page_size_1_token = FullAttentionSpec( @@ -544,6 +545,7 @@ def _align_hybrid_block_size( head_size=model_config.get_head_size(), dtype=kv_cache_dtype, kv_quant_mode=kv_quant_mode, + pack_size=cache_config.attn_pack_size, ).page_size_bytes # Compute mamba page size diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6fb0120bf088..9e70655b86ef 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6674,13 +6674,19 @@ def _reshape_kv_cache_tensors( for i in range(len(kv_cache_stride_order)) ] if has_mamba: + block_dim = group.backend.get_kv_cache_block_dim( + kernel_block_sizes[group.kv_cache_group_id], + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=self.cache_config.cache_dtype, + ) kv_cache_stride, storage_offset = ( mamba_utils.get_hybrid_attention_mamba_layout( kv_cache_shape=kv_cache_shape, kv_cache_stride=kv_cache_stride, kv_cache_spec=kv_cache_spec, + block_dim=block_dim, layer_idx=layer_idx, - kernel_num_blocks=kernel_num_blocks, kernel_block_size=kernel_block_size, ) ) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index d97298fe7b43..8242b4aa823c 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -279,8 +279,8 @@ def get_hybrid_attention_mamba_layout( kv_cache_shape: tuple[int, ...], kv_cache_stride: tuple[int, ...], kv_cache_spec: AttentionSpec, + block_dim: int, layer_idx: int, - kernel_num_blocks: int, kernel_block_size: int, ) -> tuple[tuple[int, ...], int]: """ @@ -300,21 +300,21 @@ def get_hybrid_attention_mamba_layout( storage_offset = 0 attn_pack_size = kv_cache_spec.pack_size - kernel_blocks_idx = kv_cache_shape.index(kernel_num_blocks) - if kv_cache_shape[0] == 2: + # block_dim: 0 means (num_blocks, 2, ...); 1 means (2, num_blocks, ...). + if block_dim != 0: # Hybrid attention+mamba uses (2, num_blocks, ...) logical shape but # (num_blocks, 2, ...) physical layout. 
- assert kv_cache_shape[1] != 2, ( + assert kv_cache_shape[0] == 2, ( "Fail to determine whether the layout is " "(2, num_blocks, ...) or (num_blocks, 2, ...) for " f"a tensor of shape {kv_cache_shape}" ) - assert kernel_blocks_idx == 1 + assert block_dim == 1 hidden_size = prod(kv_cache_shape[2:]) target_stride_list[0] = hidden_size target_stride_list[1] = 2 * hidden_size if attn_pack_size > 1: - target_stride_list[kernel_blocks_idx] *= attn_pack_size + target_stride_list[block_dim] *= attn_pack_size dtype_size = get_dtype_size(kv_cache_spec.dtype) num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size From 5303953a5476b3ac8de0f1c6c76f4ad44753462b Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Thu, 9 Apr 2026 16:05:09 +0000 Subject: [PATCH 32/34] rm _update_hybrid_attention_mamba_layout Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/gpu_model_runner.py | 34 ------------------------------ 1 file changed, 34 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9e70655b86ef..13e79e7eefcc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6724,40 +6724,6 @@ def _reshape_kv_cache_tensors( return kv_caches - def _update_hybrid_attention_mamba_layout( - self, kv_caches: dict[str, torch.Tensor], kernel_block_sizes: list[int] - ) -> None: - """ - Update the layout of attention layers from (2, num_blocks, ...) to - (num_blocks, 2, ...). - - Args: - kv_caches: The KV cache buffer of each layer. - kernel_block_sizes: The kernel block sizes for each KV cache group. - """ - - for group in self._kv_cache_spec_attn_group_iterator(): - kv_cache_spec = group.kv_cache_spec - if not isinstance(kv_cache_spec, AttentionSpec): - continue - block_dim = group.backend.get_kv_cache_block_dim( - kernel_block_sizes[group.kv_cache_group_id], - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - cache_dtype_str=self.cache_config.cache_dtype, - ) - # block_dim: 0 means (num_blocks, 2, ...); 1 means (2, num_blocks, ...). - if block_dim == 0: - continue - assert block_dim == 1 - for layer_name in group.layer_names: - kv_cache = kv_caches[layer_name] - hidden_size = kv_cache.shape[2:].numel() - kv_cache.as_strided_( - size=kv_cache.shape, - stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]), - ) - def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] ) -> dict[str, torch.Tensor]: From 144b3a30fb8c0b8e8aeb7343f9c5e25379eaa71c Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Thu, 9 Apr 2026 16:29:34 +0000 Subject: [PATCH 33/34] add comment Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/mamba_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 8242b4aa823c..16fec327d7b2 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -313,6 +313,9 @@ def get_hybrid_attention_mamba_layout( hidden_size = prod(kv_cache_shape[2:]) target_stride_list[0] = hidden_size target_stride_list[1] = 2 * hidden_size + # When multiple attention layers share one physical KV cache block + # (attn_pack_size > 1), scale the block-dim stride by attn_pack_size + # and compute this layer's element offset within the shared block. 
if attn_pack_size > 1: target_stride_list[block_dim] *= attn_pack_size dtype_size = get_dtype_size(kv_cache_spec.dtype) From 024c60024542848c3235b5d1c86878ddb3aa26b0 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Thu, 9 Apr 2026 18:28:10 +0000 Subject: [PATCH 34/34] add log Signed-off-by: huanghaoyan.hhy --- vllm/platforms/interface.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 5f5a709239a0..85e96e513ae5 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -528,6 +528,14 @@ def _align_hybrid_block_size( kv_quant_mode = get_kv_quant_mode(cache_config.cache_dtype) + if cache_config.attn_pack_size > 1: + logger.info( + "Packing %d attention layers into one page, which will " + "reduce the block size to roughly 1/%d of the original.", + cache_config.attn_pack_size, + cache_config.attn_pack_size, + ) + # Compute attention page size for 1 token if model_config.use_mla: attn_page_size_1_token = MLAAttentionSpec(
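
The packing path above keeps two mechanisms separate: grouping attention layers by name before KV-cache configs are generated, and laying the grouped layers out in shared memory afterwards. As a standalone illustration of the first step, here is a minimal sketch of how every attn_pack_size consecutive full-attention layer names can be joined into one pack name with "+" as the delimiter, as described in the merge_attn_layers_into_pack docstring; pack_attn_layer_names is an illustrative helper, not the vLLM implementation, and padding of a trailing partial pack is not shown here.

def pack_attn_layer_names(attn_pack_size: int, layer_names: list[str]) -> list[str]:
    # Every `attn_pack_size` consecutive attention layers form one pack whose
    # name is the member layer names joined with "+".
    return [
        "+".join(layer_names[i : i + attn_pack_size])
        for i in range(0, len(layer_names), attn_pack_size)
    ]

assert pack_attn_layer_names(
    2, ["layer_2", "layer_3", "layer_4", "layer_5"]
) == ["layer_2+layer_3", "layer_4+layer_5"]

split_attn_layers_from_pack then restores the per-layer names once the configs are generated, which is how PATCH 29 wraps the merge/split pair around profiling.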
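
On the memory side, get_hybrid_attention_mamba_layout and the assertions in test_hybrid_attention_mamba_kv_cache_pack_size describe the same layout: every layer in a pack is a view over one shared buffer whose block-dim stride is multiplied by attn_pack_size, and each layer starts at its own element offset inside the shared block, so sibling layers never overlap. A self-contained sketch of that stride/offset scheme over a flat buffer (names such as build_packed_views and block_hidden are illustrative assumptions, not the actual vLLM tensor shapes):

import torch

def build_packed_views(
    pack_size: int, num_blocks: int, block_hidden: int,
    dtype: torch.dtype = torch.float16,
) -> list[torch.Tensor]:
    # One physical buffer backs `pack_size` layers; each physical block is
    # split into `pack_size` equal slices, one slice per layer in the pack.
    elems_per_block = block_hidden * pack_size
    buffer = torch.zeros(num_blocks * elems_per_block, dtype=dtype)
    return [
        buffer.as_strided(
            size=(num_blocks, block_hidden),
            # Block-dim stride scaled by pack_size ...
            stride=(elems_per_block, 1),
            # ... and a per-layer element offset inside the shared block.
            storage_offset=i * block_hidden,
        )
        for i in range(pack_size)
    ]

views = build_packed_views(pack_size=2, num_blocks=4, block_hidden=8)
views[0][0].fill_(1.0)
views[1][0].fill_(2.0)
# Write isolation: filling block 0 of one layer leaves its pack sibling intact.
assert views[0][0].eq(1.0).all() and views[1][0].eq(2.0).all()
# Same underlying storage, different per-layer storage offsets.
assert views[0].untyped_storage().data_ptr() == views[1].untyped_storage().data_ptr()
assert views[0].storage_offset() == 0 and views[1].storage_offset() == 8

This mirrors the checks in the test: a shared storage pointer within a pack, a per-layer storage_offset of i * num_element_per_attn_pack, and a dim-0 stride equal to the unpacked stride times pack_size.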