diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index d7695027a284..26ff3a71b4e0 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1321,3 +1321,189 @@ def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks(): with pytest.raises(ValueError, match="max_num_seqs"): runner.initialize_kv_cache(kv_cache_config) + + +def test_v2_reshape_kv_cache_hybrid_attention_mamba(): + """Test V2 model runner's _reshape_kv_cache with mixed AttentionSpec + and MambaSpec, including virtual block splitting and hybrid layout + adjustment. + + This is a regression test for the V2 model runner's handling of hybrid + attention models (e.g. Qwen3.5) where attention layers produce + AttentionSpec and linear attention (Gated DeltaNet) layers produce + MambaSpec. + """ + from unittest.mock import MagicMock + + from vllm.v1.kv_cache_interface import ( + KVCacheTensor, + MambaSpec, + ) + from vllm.v1.worker.gpu.attn_utils import ( + _reshape_kv_cache, + _update_hybrid_attention_mamba_layout, + ) + + # Configuration + num_kv_heads = 4 + head_size = 8 + dtype = torch.float16 + # dtype_size = 2 (float16), not used directly but documents layout + + # KV manager block_size = 32 (larger than the kernel block size) + # Kernel block_size = 16 (FlashInfer-compatible) + kv_manager_block_size = 32 + kernel_block_size = 16 + num_blocks = 4 + + # --- Build AttentionSpec --- + attn_spec = FullAttentionSpec( + block_size=kv_manager_block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + ) + attn_page_size = attn_spec.page_size_bytes + # 2 * 32 * 4 * 8 * 2 = 4096 bytes + + # --- Build MambaSpec --- + # Two state tensors: conv (d_conv, d_inner) and ssm (d_state, d_inner) + conv_shape = (4, 16) # (d_conv, d_inner) + ssm_shape = (8, 16) # (d_state, d_inner) + mamba_spec = MambaSpec( + block_size=kv_manager_block_size, + shapes=(conv_shape, ssm_shape), + dtypes=(dtype, dtype), + # Pad 
to match attention page size for hybrid models + page_size_padded=attn_page_size, + ) + + # --- Build KVCacheConfig --- + attn_layer = "model.layers.0.self_attn.attn" + mamba_layer = "model.layers.1.mixer" + + attn_tensor_size = attn_page_size * num_blocks + mamba_tensor_size = mamba_spec.page_size_bytes * num_blocks + + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[ + KVCacheTensor(size=attn_tensor_size, shared_by=[attn_layer]), + KVCacheTensor(size=mamba_tensor_size, shared_by=[mamba_layer]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(layer_names=[attn_layer], kv_cache_spec=attn_spec), + KVCacheGroupSpec(layer_names=[mamba_layer], kv_cache_spec=mamba_spec), + ], + ) + + # --- Allocate raw tensors --- + device = torch.device("cpu") + kv_cache_raw_tensors = { + attn_layer: torch.zeros(attn_tensor_size, dtype=torch.int8, device=device), + mamba_layer: torch.zeros(mamba_tensor_size, dtype=torch.int8, device=device), + } + + # --- Mock attention backend --- + # FlashInfer-like: shape = (num_blocks, 2, block_size, num_kv_heads, + # head_size) + mock_backend = MagicMock() + + def mock_get_kv_cache_shape(n_blocks, block_sz, n_kv_heads, h_size, cache_dtype): + return (n_blocks, 2, block_sz, n_kv_heads, h_size) + + mock_backend.get_kv_cache_shape = mock_get_kv_cache_shape + mock_backend.get_kv_cache_stride_order.side_effect = AttributeError + + attn_backends = {attn_layer: mock_backend} + kernel_block_sizes = [kernel_block_size, kv_manager_block_size] + + # --- Call _reshape_kv_cache --- + kv_caches = _reshape_kv_cache( + kv_cache_config, + kv_cache_raw_tensors, + attn_backends, + "auto", + kernel_block_sizes, + ) + + # --- Verify attention layer --- + attn_cache = kv_caches[attn_layer] + assert isinstance(attn_cache, torch.Tensor) + + # Virtual block splitting: 4 KV blocks * (32/16) = 8 kernel blocks + expected_kernel_num_blocks = num_blocks * ( + kv_manager_block_size // kernel_block_size + ) + assert attn_cache.shape[0] == 
expected_kernel_num_blocks # 8 + assert attn_cache.shape[1] == 2 # K and V + assert attn_cache.shape[2] == kernel_block_size # 16 + assert attn_cache.shape[3] == num_kv_heads # 4 + assert attn_cache.shape[4] == head_size # 8 + + # --- Verify hybrid layout adjustment --- + # In hybrid mode, _update_hybrid_attention_mamba_layout should have been + # called. For FlashAttention-like backends where shape[0] == 2 (kv_dim), + # the stride adjustment converts (2, num_blocks, ...) to + # (num_blocks, 2, ...). + # Since our mock returns shape (8, 2, 16, 4, 8) where shape[0]=8 != 2, + # the hybrid layout adjustment should NOT fire (it only fires when + # shape[0] == 2). + + # --- Verify mamba layer --- + mamba_cache = kv_caches[mamba_layer] + assert isinstance(mamba_cache, list) + assert len(mamba_cache) == 2 # conv and ssm state tensors + + conv_tensor = mamba_cache[0] + ssm_tensor = mamba_cache[1] + + assert conv_tensor.shape == (num_blocks, *conv_shape) + assert ssm_tensor.shape == (num_blocks, *ssm_shape) + assert conv_tensor.dtype == dtype + assert ssm_tensor.dtype == dtype + + # --- Verify data isolation --- + # Writing to mamba blocks should not corrupt attention blocks + conv_tensor.fill_(1.0) + ssm_tensor.fill_(2.0) + + # Attention cache should still be zeros + assert torch.all(attn_cache == 0) + + # Writing to attention cache should not corrupt mamba + attn_cache.fill_(3.0) + assert torch.all(conv_tensor == 1.0) + assert torch.all(ssm_tensor == 2.0) + + # --- Directly test _update_hybrid_attention_mamba_layout --- + # Create a scenario where shape[0] == 2 (like FA backend with + # (2, num_blocks, block_size, num_kv_heads, head_size)) + small_num_blocks = 4 + raw = torch.arange( + 2 * small_num_blocks * kernel_block_size * num_kv_heads * head_size, + dtype=dtype, + ).reshape(2, small_num_blocks, kernel_block_size, num_kv_heads, head_size) + test_kv_caches: dict[str, torch.Tensor] = { + attn_layer: raw, + mamba_layer: [conv_tensor, ssm_tensor], # type: ignore[assignment] 
+ } + _update_hybrid_attention_mamba_layout(kv_cache_config, test_kv_caches) + + updated = test_kv_caches[attn_layer] + assert isinstance(updated, torch.Tensor) + # Shape should be unchanged + assert updated.shape == ( + 2, + small_num_blocks, + kernel_block_size, + num_kv_heads, + head_size, + ) + # But strides should be adjusted: dim0 stride should be + # hidden_size (= block_size * num_kv_heads * head_size) + hidden_size = kernel_block_size * num_kv_heads * head_size + assert updated.stride()[0] == hidden_size + assert updated.stride()[1] == 2 * hidden_size + # Data is the same (just reinterpreted via strides) + assert updated.data_ptr() == raw.data_ptr() diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 34089a67b3be..95ce490e8766 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -8,14 +8,20 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.utils.torch_utils import get_dtype_size from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata from vllm.v1.kv_cache_interface import ( AttentionSpec, KVCacheConfig, KVCacheSpec, + MambaSpec, UniformTypeKVCacheSpecs, ) -from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache +from vllm.v1.worker.utils import ( + AttentionGroup, + bind_kv_cache, + prepare_kernel_block_sizes, +) def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]: @@ -34,10 +40,11 @@ def init_attn_backend( vllm_config: VllmConfig, device: torch.device, active_layer_names: set[str] | None = None, -): +) -> tuple[dict[str, type[AttentionBackend]], list[list[AttentionGroup]], list[int]]: attn_backends: dict[str, type[AttentionBackend]] = {} attn_groups: list[list[AttentionGroup]] = [] - attn_backend_workspace: torch.Tensor | None = None + + # Phase 1: Build attention groups for all KV cache groups. 
for kv_cache_group_id, kv_cache_group_spec in enumerate( kv_cache_config.kv_cache_groups ): @@ -71,12 +78,24 @@ def init_attn_backend( else: group_map[key].layer_names.append(layer_name) - groups = [group_map[key] for key in group_order] + attn_groups.append([group_map[key] for key in group_order]) + + # Phase 2: Compute kernel block sizes for virtual block splitting. + kernel_block_sizes = prepare_kernel_block_sizes(kv_cache_config, attn_groups) + + # Phase 3: Create metadata builders with correct kernel block sizes. + attn_backend_workspace: torch.Tensor | None = None + for kv_cache_group_id, groups in enumerate(attn_groups): + kernel_block_size = ( + kernel_block_sizes[kv_cache_group_id] + if kv_cache_group_id < len(kernel_block_sizes) + else None + ) for group in groups: group.create_metadata_builders( vllm_config=vllm_config, device=device, - kernel_block_size=None, + kernel_block_size=kernel_block_size, num_metadata_builders=1, ) builder = group.get_metadata_builder(0) @@ -86,8 +105,8 @@ def init_attn_backend( else: if hasattr(builder, "set_workspace_buffer"): builder.set_workspace_buffer(attn_backend_workspace) - attn_groups.append(groups) - return attn_backends, attn_groups + + return attn_backends, attn_groups, kernel_block_sizes def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device): @@ -107,50 +126,128 @@ def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device): return kv_cache_raw_tensors +def _update_hybrid_attention_mamba_layout( + kv_cache_config: KVCacheConfig, + kv_caches: dict[str, torch.Tensor], +) -> None: + """Update the layout of attention layers from (2, num_blocks, ...) to + (num_blocks, 2, ...) when attention and Mamba layers coexist. + + This is necessary so that blocks can be shared between attention layers + and Mamba layers. The KV-transfer code assumes this adjustment has been + applied (see kv_connector/utils.py). 
+ """ + for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): + kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values())) + if not isinstance(kv_cache_spec, AttentionSpec): + continue + for layer_name in kv_cache_group_spec.layer_names: + kv_cache = kv_caches[layer_name] + if not isinstance(kv_cache, torch.Tensor): + continue + if kv_cache.shape[0] == 2: + assert kv_cache.shape[1] != 2, ( + "Fail to determine whether the layout is " + "(2, num_blocks, ...) or (num_blocks, 2, ...) for " + f"a tensor of shape {kv_cache.shape}" + ) + hidden_size = kv_cache.shape[2:].numel() + kv_cache.as_strided_( + size=kv_cache.shape, + stride=( + hidden_size, + 2 * hidden_size, + *kv_cache.stride()[2:], + ), + ) + + def _reshape_kv_cache( kv_cache_config: KVCacheConfig, kv_cache_raw_tensors: dict[str, torch.Tensor], - attn_backends: dict[str, AttentionBackend], + attn_backends: dict[str, type[AttentionBackend]], cache_dtype: str, + kernel_block_sizes: list[int], ) -> dict[str, torch.Tensor]: kv_caches: dict[str, torch.Tensor] = {} - for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + has_attn, has_mamba = False, False + for group_idx, kv_cache_group_spec in enumerate(kv_cache_config.kv_cache_groups): + kernel_block_size = kernel_block_sizes[group_idx] for layer_name in kv_cache_group_spec.layer_names: kv_cache_spec = kv_cache_group_spec.kv_cache_spec if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): kv_cache_spec = kv_cache_spec.kv_cache_specs[layer_name] - assert isinstance(kv_cache_spec, AttentionSpec) raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes - attn_backend = attn_backends[layer_name] - kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - 
kv_cache_spec.head_size, - cache_dtype, - ) + if isinstance(kv_cache_spec, AttentionSpec): + has_attn = True + # Virtual block splitting: split each KV manager block + # into smaller kernel blocks for backend compatibility. + num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size + kernel_num_blocks = num_blocks * num_blocks_per_kv_block + + attn_backend = attn_backends[layer_name] + kv_cache_shape = attn_backend.get_kv_cache_shape( + kernel_num_blocks, + kernel_block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype, + ) + + # FIXME(woosuk): Add kv_cache_stride_order to all attention + # backends. + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + + kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + inv_order = [ + kv_cache_stride_order.index(i) + for i in range(len(kv_cache_stride_order)) + ] + + dtype = kv_cache_spec.dtype + kv_caches[layer_name] = ( + raw_tensor.view(dtype).view(kv_cache_shape).permute(*inv_order) + ) + + elif isinstance(kv_cache_spec, MambaSpec): + has_mamba = True + state_tensors = [] + storage_offset_bytes = 0 + for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): + dtype_size = get_dtype_size(dtype) + num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size + target_shape = (num_blocks, *shape) + stride = torch.empty(target_shape).stride() + target_stride = (num_element_per_page, *stride[1:]) + assert storage_offset_bytes % dtype_size == 0 + tensor = torch.as_strided( + raw_tensor.view(dtype), + size=target_shape, + stride=target_stride, + storage_offset=storage_offset_bytes // dtype_size, + ) + state_tensors.append(tensor) + storage_offset_bytes += stride[0] * dtype_size + kv_caches[layer_name] = state_tensors # type: ignore[assignment] + + else: + raise 
NotImplementedError( + f"Unsupported KV cache spec type: {type(kv_cache_spec)}" + ) + + if has_attn and has_mamba: + _update_hybrid_attention_mamba_layout(kv_cache_config, kv_caches) - # FIXME(woosuk): Add kv_cache_stride_order to all attention backends. - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len(kv_cache_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) - - kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) - inv_order = [ - kv_cache_stride_order.index(i) - for i in range(len(kv_cache_stride_order)) - ] - - dtype = kv_cache_spec.dtype - raw_tensor = raw_tensor.view(dtype) - raw_tensor = raw_tensor.view(kv_cache_shape) - kv_caches[layer_name] = raw_tensor.permute(*inv_order) return kv_caches @@ -158,13 +255,18 @@ def init_kv_cache( runner_kv_caches: list[torch.Tensor], forward_context: dict[str, Any], kv_cache_config: KVCacheConfig, - attn_backends: dict[str, AttentionBackend], + attn_backends: dict[str, type[AttentionBackend]], device: torch.device, cache_dtype: str, + kernel_block_sizes: list[int], ) -> dict[str, torch.Tensor]: kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device) kv_caches = _reshape_kv_cache( - kv_cache_config, kv_cache_raw_tensors, attn_backends, cache_dtype + kv_cache_config, + kv_cache_raw_tensors, + attn_backends, + cache_dtype, + kernel_block_sizes, ) bind_kv_cache(kv_caches, forward_context, runner_kv_caches) return kv_caches diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index a2f83c52e951..0d43443d842d 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -362,8 +362,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: cp_interleave=self.cp_interleave, ) - self.attn_backends, self.attn_groups = init_attn_backend( - self.kv_cache_config, self.vllm_config, self.device + 
self.attn_backends, self.attn_groups, self.kernel_block_sizes = ( + init_attn_backend(self.kv_cache_config, self.vllm_config, self.device) ) check_attention_cp_compatibility(self.vllm_config) if self.speculator is not None: @@ -382,6 +382,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.attn_backends, self.device, self.cache_config.cache_dtype, + self.kernel_block_sizes, ) self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict) diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 887fd52794cb..5f3aeb23efe2 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -133,7 +133,7 @@ def set_attn( ) -> None: self.model_state = model_state self.kv_cache_config = kv_cache_config - _, self.attn_groups = init_attn_backend( + _, self.attn_groups, _ = init_attn_backend( kv_cache_config, self.vllm_config, self.device,