diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 3da3d7e7bef7..293e511aa238 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -18,12 +18,14 @@ from vllm.utils.math_utils import cdiv from vllm.utils.mem_utils import format_gib from vllm.v1.kv_cache_interface import ( + AttentionSpec, ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, KVCacheTensor, + MambaSpec, SlidingWindowSpec, UniformTypeKVCacheSpecs, ) @@ -802,6 +804,35 @@ def get_max_concurrency_for_kv_cache_config( """ Get the maximum concurrency for the given KV cache configuration. """ + if _is_hybrid_kv_cache_groups(kv_cache_config.kv_cache_groups): + # For hybrid Mamba+Attention models, concurrency is limited by + # attention layers since Mamba state is O(1) per request while + # attention KV cache is O(n) per token. + attn_groups = [ + g for g in kv_cache_config.kv_cache_groups + if isinstance(g.kv_cache_spec, AttentionSpec) + ] + if not attn_groups: + return float('inf') + # Memory per request for attention layers at max sequence length. + # Each layer in a group gets its own tensor, so total = + # sum(num_layers * max_memory_for_one_layer). + max_mem_per_request = sum( + len(g.layer_names) + * g.kv_cache_spec.max_memory_usage_bytes(vllm_config) + for g in attn_groups + ) + # Memory consumed per block across all attention layers. + # Each layer gets page_size_bytes per block. + mem_per_block = sum( + len(g.layer_names) * g.kv_cache_spec.page_size_bytes + for g in attn_groups + ) + if mem_per_block == 0: + return float('inf') + blocks_per_request = cdiv(max_mem_per_request, mem_per_block) + return kv_cache_config.num_blocks / blocks_per_request + num_layer_per_group = max( len(group.layer_names) for group in kv_cache_config.kv_cache_groups ) @@ -953,6 +984,88 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo return not kv_cache_spec +def _is_hybrid_mamba_attention(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + """Check if model has both Mamba/GDN and attention layers (hybrid). + + Models like Qwen3.5 use GatedDeltaNet (Mamba) for most layers and full + attention for a small subset. These need special KV cache handling: + Mamba state is O(1) per request while attention KV cache is O(n) per + token, so treating them uniformly wastes ~7x memory. + """ + has_mamba = any( + isinstance(spec, MambaSpec) for spec in kv_cache_spec.values() + ) + has_attention = any( + isinstance(spec, AttentionSpec) for spec in kv_cache_spec.values() + ) + return has_mamba and has_attention + + +def _is_hybrid_kv_cache_groups( + kv_cache_groups: list[KVCacheGroupSpec], +) -> bool: + """Check if KV cache groups contain both Mamba and attention specs.""" + has_mamba = any( + isinstance(g.kv_cache_spec, MambaSpec) for g in kv_cache_groups + ) + has_attention = any( + isinstance(g.kv_cache_spec, AttentionSpec) for g in kv_cache_groups + ) + return has_mamba and has_attention + + +def _get_kv_cache_groups_hybrid_mamba_attention( + kv_cache_spec: dict[str, KVCacheSpec], +) -> list[KVCacheGroupSpec]: + """ + Generate KV cache groups for hybrid Mamba+Attention models, preserving + each group's natural page size instead of padding all groups to a uniform + page size. + + Why this matters: + Models like Qwen3.5-MoE use GatedDeltaNet (Mamba) for most layers + (e.g., 24 of 32) and full attention for a small subset (e.g., 8 of 32). + Mamba state is O(1) per request (~1.1 MiB natural page size) while + attention KV cache is O(n) per token (~3.2 MiB page size). + + The standard ``_get_kv_cache_groups_uniform_page_size`` pads ALL groups to + the largest page size, inflating Mamba's memory footprint by ~3x per layer. + Combined with the layer count ratio, this causes ~7x total overestimation. + + This function instead: + 1. Groups layers by spec type (Mamba vs each distinct attention spec). + 2. Keeps each group at its natural page size. + 3. Lets ``get_kv_cache_config_from_groups`` allocate per-group tensors at + their natural sizes. + + See https://github.com/vllm-project/vllm/issues/37121 + + Args: + kv_cache_spec: The KVCacheSpec of each layer in the model. + + Returns: + KV cache groups with natural (unpadded) page sizes. + """ + # Group layers by their KVCacheSpec (same spec object == same group). + # Mamba layers with identical specs go in one group, attention layers + # with identical specs go in another group, etc. + same_spec_layers: dict[KVCacheSpec, list[str]] = defaultdict(list) + for layer_name, layer_spec in kv_cache_spec.items(): + same_spec_layers[layer_spec].append(layer_name) + + # Create one KVCacheGroupSpec per distinct spec. + # Unlike _get_kv_cache_groups_uniform_page_size, we do NOT split into + # equal-sized sub-groups or pad to match group sizes. Each group keeps + # its natural layer count and page size. + groups: list[KVCacheGroupSpec] = [] + for spec, layer_names in same_spec_layers.items(): + layer_specs = [kv_cache_spec[ln] for ln in layer_names] + merged = layer_specs[0].merge(layer_specs) + groups.append(KVCacheGroupSpec(layer_names, merged)) + + return groups + + def _get_kv_cache_groups_uniform_page_size( kv_cache_spec: dict[str, KVCacheSpec], ) -> list[KVCacheGroupSpec]: @@ -1119,8 +1232,47 @@ def get_kv_cache_config_from_groups( ) for layer_name in kv_cache_groups[0].layer_names ] + elif _is_hybrid_kv_cache_groups(kv_cache_groups): + # Hybrid Mamba+Attention models: allocate per-layer tensors at each + # group's natural page size instead of padding to a uniform size. + # This avoids the ~7x memory overestimation documented in + # https://github.com/vllm-project/vllm/issues/37121 + # + # Unlike the general case where layers at the same position share a + # tensor, hybrid models give each layer its own tensor because: + # 1. Layers in the same group (e.g., all Mamba layers) share a block + # table, so they all use the same block IDs for the same request. + # 2. Each layer needs its own memory for its state (each Mamba layer + # has independent recurrent state; each attention layer has + # independent KV cache). + # 3. Different spec types have different page sizes, so tensors can't + # be shared across groups either. + # + # Memory budget per block: + # sum(num_layers_in_group * group_page_size for each group) + # This correctly accounts for Mamba's small O(1) state and + # attention's larger O(n) KV cache per layer. + total_bytes_per_block = sum( + len(group.layer_names) * group.kv_cache_spec.page_size_bytes + for group in kv_cache_groups + ) + assert total_bytes_per_block > 0 + num_blocks = int(available_memory // total_bytes_per_block) + num_blocks = max(num_blocks, 0) + num_blocks = may_override_num_blocks(vllm_config, num_blocks) + + kv_cache_tensors = [] + for group in kv_cache_groups: + page_size = group.kv_cache_spec.page_size_bytes + for layer_name in group.layer_names: + kv_cache_tensors.append( + KVCacheTensor( + size=page_size * num_blocks, + shared_by=[layer_name], + ) + ) else: - # General case: + # General case (uniform page sizes): # We will have group_size memory pools, each is shared by one layer from # each group. As layers of different groups have different block table, # they will use different parts of the shared Tensor. @@ -1229,6 +1381,18 @@ def get_kv_cache_groups( Returns: The generated KVCacheGroups """ + # Hybrid Mamba+Attention detection MUST run before unify_hybrid_kv_cache_specs + # because unification modifies specs in-place (converting SlidingWindow to + # FullAttention), and would raise ValueError on Mamba+Attention combos. + # Hybrid Mamba+Attention models (e.g., Qwen3.5 with GatedDeltaNet) need + # dedicated grouping that preserves per-group natural page sizes. + if _is_hybrid_mamba_attention(kv_cache_spec): + logger.info( + "Detected hybrid Mamba+Attention model. Using per-group natural " + "page sizes to avoid KV cache memory overestimation." + ) + return _get_kv_cache_groups_hybrid_mamba_attention(kv_cache_spec) + if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: unify_hybrid_kv_cache_specs(kv_cache_spec) @@ -1291,16 +1455,32 @@ def _report_kv_cache_config( vllm_config: The global VllmConfig kv_cache_config: The resolved KV cache configuration """ - min_block_size = min( - [group.kv_cache_spec.block_size for group in kv_cache_config.kv_cache_groups] - ) + if _is_hybrid_kv_cache_groups(kv_cache_config.kv_cache_groups): + # For hybrid Mamba+Attention models, report tokens based on + # attention groups only, since Mamba state is O(1) per request + # and doesn't scale with sequence length. + attn_groups = [ + g for g in kv_cache_config.kv_cache_groups + if isinstance(g.kv_cache_spec, AttentionSpec) + ] + if attn_groups: + attn_block_size = attn_groups[0].kv_cache_spec.block_size + num_tokens = kv_cache_config.num_blocks * attn_block_size + else: + num_tokens = kv_cache_config.num_blocks + else: + min_block_size = min( + [group.kv_cache_spec.block_size + for group in kv_cache_config.kv_cache_groups] + ) + num_tokens = ( + kv_cache_config.num_blocks + // len(kv_cache_config.kv_cache_groups) + * min_block_size + ) # Log the KV cache size and maximum concurrency. - num_tokens = ( - kv_cache_config.num_blocks - // len(kv_cache_config.kv_cache_groups) - * min_block_size - ) + # (num_tokens is calculated above, handling hybrid models correctly) dcp_size = vllm_config.parallel_config.decode_context_parallel_size pcp_size = vllm_config.parallel_config.prefill_context_parallel_size if pcp_size * dcp_size > 1: @@ -1350,6 +1530,20 @@ def _max_memory_usage_bytes_from_groups( for spec in per_layer_specs.values() ) + if _is_hybrid_kv_cache_groups(kv_cache_groups): + # Hybrid Mamba+Attention: sum per-group costs independently. + # Each layer in a group gets its own tensor, so total memory for a + # group = num_layers * page_size * blocks_for_one_layer. + total = 0 + for group in kv_cache_groups: + spec = group.kv_cache_spec + blocks_per_layer = cdiv( + spec.max_memory_usage_bytes(vllm_config), + spec.page_size_bytes, + ) + total += len(group.layer_names) * spec.page_size_bytes * blocks_per_layer + return total + # General case: group_size pools, each shared by one layer per group # Memory = group_size * page_size * blocks_for_max_len group_size = max(len(group.layer_names) for group in kv_cache_groups)