diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index a7e26280c90d..05b32d889d06 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -839,7 +839,7 @@ steps: num_gpus: 2 working_dir: "/vllm-workspace/tests" source_file_dependencies: - - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - vllm/distributed/kv_transfer/kv_connector/v1/nixl/ - vllm/v1/worker/kv_connector_model_runner_mixin.py - tests/v1/kv_connector/nixl_integration/ - vllm/platforms/rocm.py @@ -866,7 +866,7 @@ steps: num_gpus: 4 working_dir: "/vllm-workspace/tests" source_file_dependencies: - - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - vllm/distributed/kv_transfer/kv_connector/v1/nixl/ - tests/v1/kv_connector/nixl_integration/ - vllm/platforms/rocm.py commands: @@ -2341,7 +2341,7 @@ steps: num_gpus: 4 working_dir: "/vllm-workspace/tests" source_file_dependencies: - - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - vllm/distributed/kv_transfer/kv_connector/v1/nixl/ - tests/v1/kv_connector/nixl_integration/ - vllm/platforms/rocm.py commands: @@ -2377,7 +2377,7 @@ steps: optional: true working_dir: "/vllm-workspace/tests" source_file_dependencies: - - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - vllm/distributed/kv_transfer/kv_connector/v1/nixl/ - tests/v1/kv_connector/nixl_integration/ - vllm/platforms/rocm.py commands: @@ -2391,7 +2391,7 @@ steps: num_gpus: 4 working_dir: "/vllm-workspace/tests" source_file_dependencies: - - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - vllm/distributed/kv_transfer/kv_connector/v1/nixl/ - tests/v1/kv_connector/nixl_integration/ - vllm/platforms/rocm.py commands: @@ -2405,7 +2405,7 @@ steps: num_gpus: 4 working_dir: "/vllm-workspace/tests" source_file_dependencies: - - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - vllm/distributed/kv_transfer/kv_connector/v1/nixl/ - tests/v1/kv_connector/nixl_integration/ - vllm/platforms/rocm.py commands: @@ -3353,7 +3353,7 @@ steps: optional: true working_dir: "/vllm-workspace/tests" source_file_dependencies: - - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - vllm/distributed/kv_transfer/kv_connector/v1/nixl/ - vllm/v1/worker/kv_connector_model_runner_mixin.py - tests/v1/kv_connector/nixl_integration/ - vllm/platforms/rocm.py @@ -3369,7 +3369,7 @@ steps: optional: true working_dir: "/vllm-workspace/tests" source_file_dependencies: - - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - vllm/distributed/kv_transfer/kv_connector/v1/nixl/ - tests/v1/kv_connector/nixl_integration/ - vllm/platforms/rocm.py commands: @@ -3384,7 +3384,7 @@ steps: optional: true working_dir: "/vllm-workspace/tests" source_file_dependencies: - - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - vllm/distributed/kv_transfer/kv_connector/v1/nixl/ - tests/v1/kv_connector/nixl_integration/ - vllm/platforms/rocm.py commands: diff --git a/benchmarks/attention_benchmarks/runner.py b/benchmarks/attention_benchmarks/runner.py index aa636cd9cb53..cecba287bb4f 100644 --- a/benchmarks/attention_benchmarks/runner.py +++ b/benchmarks/attention_benchmarks/runner.py @@ -11,6 +11,7 @@ import logging import types from contextlib import contextmanager +from math import prod import numpy as np import torch @@ -30,10 +31,13 @@ ) from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, - get_kv_cache_layout, - set_kv_cache_layout, + resolve_kv_cache_layout, +) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + compute_layer_kv_cache_shape_bytes, + reshape_kv_cache, ) -from vllm.v1.kv_cache_interface import FullAttentionSpec # ============================================================================ # Backend Configuration @@ -324,52 +328,23 @@ def _create_input_tensors( def _create_kv_cache( config: BenchmarkConfig, max_num_blocks: int, - backend_class, device: torch.device, dtype: torch.dtype, ) -> list: - """Create KV cache tensors for all layers using the backend's methods. - - Uses the backend's get_kv_cache_shape() and get_kv_cache_stride_order() - to create the cache with the correct shape and memory layout. - """ - # Get the logical shape from the backend - cache_shape = backend_class.get_kv_cache_shape( - num_blocks=max_num_blocks, + """Create KV cache tensors for all layers using the standard allocator.""" + spec = FullAttentionSpec( block_size=config.block_size, num_kv_heads=config.num_kv_heads, head_size=config.head_dim, + dtype=dtype, ) - - # Get the stride order for custom memory layout - try: - stride_order = backend_class.get_kv_cache_stride_order() - assert len(stride_order) == len(cache_shape) - except (AttributeError, NotImplementedError): - stride_order = tuple(range(len(cache_shape))) - - # Permute shape to physical layout order - physical_shape = tuple(cache_shape[i] for i in stride_order) - - # Compute inverse permutation to get back to logical view - inv_order = [stride_order.index(i) for i in range(len(stride_order))] - - # Use fp8 dtype for cache when requested. - cache_dtype = dtype - if config.kv_cache_dtype == "fp8": - from vllm.platforms import current_platform - - cache_dtype = current_platform.fp8_dtype() - - cache_list = [] - for _ in range(config.num_layers): - # Allocate in physical layout order (contiguous in memory) - cache = torch.zeros(*physical_shape, device=device, dtype=cache_dtype) - # Permute to logical view - cache = cache.permute(*inv_order) - cache_list.append(cache) - - return cache_list + layout = resolve_kv_cache_layout() + total_bytes = ( + prod(compute_layer_kv_cache_shape_bytes(spec, max_num_blocks)) + * config.num_layers + ) + buf = torch.zeros(total_bytes, device=device, dtype=torch.int8) + return reshape_kv_cache(buf, spec, max_num_blocks, config.num_layers, layout) # ============================================================================ @@ -514,13 +489,6 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: backend_cfg, config, device, dtype ) - # Set KV cache layout if the backend requires a specific one - # (e.g., FlashInfer requires HND on SM100/Blackwell for TRTLLM attention) - required_layout = backend_class.get_required_kv_cache_layout() - if required_layout is not None: - set_kv_cache_layout(required_layout) - get_kv_cache_layout.cache_clear() - common_metadata = _build_common_attn_metadata( q_lens, kv_lens, config.block_size, device ) @@ -549,9 +517,7 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: config, total_q, device, dtype, quantize_query=quantize_query ) - cache_list = _create_kv_cache( - config, max_num_blocks, backend_class, device, dtype - ) + cache_list = _create_kv_cache(config, max_num_blocks, device, dtype) times, mem_stats = _run_single_benchmark( config, diff --git a/docs/features/nixl_connector_compatibility.md b/docs/features/nixl_connector_compatibility.md index 5541cd99bd80..84f564796155 100644 --- a/docs/features/nixl_connector_compatibility.md +++ b/docs/features/nixl_connector_compatibility.md @@ -59,7 +59,7 @@ th:not(:first-child) { 1 P and D instances must use the same speculation configuration. -2 Requires `FLASH_ATTN` or `FLASHINFER` backend **and** `HND` KV cache layout. Enable via `--kv-transfer-config '{"kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'`. +2 Cross-layer contiguity is achieved by using a `BLHNC` layout (set via `VLLM_KV_CACHE_LAYOUT=BLHNC` or `--enable-cross-layers`). 3 Supported only when HMA is **not** required (i.e., non-hybrid models). Block IDs are remapped automatically. Only P block size < D block size is supported. diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index cb5a3dca035a..7717a57b2f62 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -389,15 +389,6 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con --kv-transfer-config '{..., "enable_permute_local_kv":"True"}' ``` -### Cross layers blocks - -By default, this feature is disabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred. -To enable this feature: - -```bash ---kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}' -``` - ## Example Scripts/Code Refer to these example scripts in the vLLM repository: diff --git a/tests/compile/passes/test_fusion_attn.py b/tests/compile/passes/test_fusion_attn.py index b776f6af98a1..1c60dad80e4d 100644 --- a/tests/compile/passes/test_fusion_attn.py +++ b/tests/compile/passes/test_fusion_attn.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +from math import prod import pytest import torch._dynamo @@ -39,7 +40,12 @@ from vllm.utils.flashinfer import has_flashinfer from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.registry import AttentionBackendEnum -from vllm.v1.kv_cache_interface import AttentionSpec, get_kv_quant_mode +from vllm.v1.attention.backends.utils import resolve_kv_cache_layout +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + compute_layer_kv_cache_shape_bytes, + get_kv_quant_mode, +) DEVICE_TYPE = current_platform.device_type FP8_DTYPE = current_platform.fp8_dtype() @@ -108,29 +114,28 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size num_blocks = batch_size * max_blocks - # Fetch the attention backend and kv cache shape and stride order - attn_backend = self.attn.attn_backend - kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size + spec = AttentionSpec( + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.attn.kv_cache_torch_dtype, + kv_quant_mode=get_kv_quant_mode(self.attn.kv_cache_dtype), ) - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) - - kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + layout = resolve_kv_cache_layout() + kv_cache_shape = compute_layer_kv_cache_shape_bytes(spec, num_blocks) + kv_cache_stride_order = layout.layer_stride_order + physical_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) inv_order = [ kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - # Create dummy KV cache raw_tensor = torch.zeros( - 2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size, - dtype=self.attn.kv_cache_torch_dtype, + prod(kv_cache_shape), + dtype=torch.int8, device=self.device, ) - raw_tensor = raw_tensor.view(kv_cache_shape) - kv_cache = raw_tensor.permute(*inv_order) + raw_tensor = raw_tensor.view(physical_shape) + kv_cache = raw_tensor.permute(*inv_order).view(self.attn.kv_cache_torch_dtype) self.attn.kv_cache = kv_cache diff --git a/tests/compile/passes/test_mla_attn_quant_fusion.py b/tests/compile/passes/test_mla_attn_quant_fusion.py index 0a38ffca483a..c1a5e714499e 100644 --- a/tests/compile/passes/test_mla_attn_quant_fusion.py +++ b/tests/compile/passes/test_mla_attn_quant_fusion.py @@ -150,27 +150,14 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size num_blocks = batch_size * max_blocks - # MLA KV cache is 3D: (num_blocks, block_size, head_size) - attn_backend = self.mla_attn.attn_backend - kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, 1, self.head_size - ) - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) - - ordered_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) - inv_order = [ - kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) - ] - - raw_tensor = torch.zeros( - ordered_shape, dtype=self.kv_cache_dtype, device=self.device + # MLA KV cache is 4D: (num_blocks, num_heads=1, block_size, head_size) + kv_cache = torch.zeros( + (num_blocks, 1, self.block_size, self.head_size), + dtype=self.kv_cache_dtype, + device=self.device, ) - kv_cache = raw_tensor.permute(*inv_order) - self.mla_attn.kv_cache = kv_cache + self.mla_attn.bind_kv_cache(kv_cache) self.attn_metadata = self.builder.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata diff --git a/tests/compile/passes/test_mla_rope_kvcache_cat_fusion.py b/tests/compile/passes/test_mla_rope_kvcache_cat_fusion.py index cc3bfb7693ba..b1da94cb20b3 100644 --- a/tests/compile/passes/test_mla_rope_kvcache_cat_fusion.py +++ b/tests/compile/passes/test_mla_rope_kvcache_cat_fusion.py @@ -165,29 +165,15 @@ def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata: max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size num_blocks = batch_size * max_blocks - # Fetch the attention backend and kv cache shape and stride order - kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size - ) - try: - kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order() - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) - - kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) - inv_order = [ - kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) - ] - - raw_tensor = torch.zeros( - num_blocks * self.block_size * self.num_kv_heads * self.head_size, + # MLA uses a 4D KV cache: (num_blocks, num_heads=1, block_size, head_size). + kv_cache_shape = (num_blocks, 1, self.block_size, self.head_size) + kv_cache = torch.zeros( + kv_cache_shape, dtype=self.kv_cache_dtype, device=self.device, ) - raw_tensor = raw_tensor.view(kv_cache_shape) - kv_cache = raw_tensor.permute(*inv_order) - self.mla_attn.kv_cache = kv_cache + self.mla_attn.bind_kv_cache(kv_cache) # Build attn metadata attn_metadata = self.builder.build( diff --git a/tests/compile/passes/test_rope_kvcache_fusion.py b/tests/compile/passes/test_rope_kvcache_fusion.py index b27adfc46f51..4f97f3181c51 100644 --- a/tests/compile/passes/test_rope_kvcache_fusion.py +++ b/tests/compile/passes/test_rope_kvcache_fusion.py @@ -34,6 +34,10 @@ CommonAttentionMetadata, ) from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + compute_layer_kv_cache_shape_bytes, +) INDEX_SELECT_OP = torch.ops.aten.index.Tensor VLLM_UNIFIED_KV_CACHE_UPDATE_OP = torch.ops.vllm.unified_kv_cache_update @@ -119,28 +123,21 @@ def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata: max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size num_blocks = batch_size * max_blocks - # Fetch the attention backend and kv cache shape and stride order - kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size + kv_cache_shape = compute_layer_kv_cache_shape_bytes( + FullAttentionSpec( + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.kv_cache_dtype, + ), + num_blocks, ) - try: - kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order() - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) - - kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) - inv_order = [ - kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) - ] - # Create dummy KV cache - raw_tensor = torch.zeros( - 2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size, - dtype=self.kv_cache_dtype, + kv_cache = torch.zeros( + kv_cache_shape, + dtype=torch.int8, device=self.device, - ) - raw_tensor = raw_tensor.view(kv_cache_shape) - kv_cache = raw_tensor.permute(*inv_order) + ).view(self.kv_cache_dtype) self.attn.kv_cache = kv_cache diff --git a/tests/distributed/test_kvlayout.py b/tests/distributed/test_kvlayout.py index c8177f1c7c2f..11f4da241fa7 100644 --- a/tests/distributed/test_kvlayout.py +++ b/tests/distributed/test_kvlayout.py @@ -21,7 +21,7 @@ def test_get_kv_connector_cache_layout_without_kv_connector(): with set_current_vllm_config(vllm_config): # Test with default settings layout = get_kv_connector_cache_layout() - assert layout == "NHD" + assert layout == "LBNHC" def test_get_kv_connector_cache_layout_with_lmcache_connector(): @@ -35,7 +35,7 @@ def test_get_kv_connector_cache_layout_with_lmcache_connector(): with set_current_vllm_config(vllm_config): # Test with default settings layout = get_kv_connector_cache_layout() - assert layout == "NHD" + assert layout == "LBNHC" def test_get_kv_connector_cache_layout_with_nixl_connector(): @@ -52,7 +52,7 @@ def test_get_kv_connector_cache_layout_with_nixl_connector(): with set_current_vllm_config(vllm_config): # Test with default settings layout = get_kv_connector_cache_layout() - assert layout == "HND" + assert layout == "LBHNC" def test_get_kv_connector_cache_layout_with_multi_connector(): @@ -75,4 +75,4 @@ def test_get_kv_connector_cache_layout_with_multi_connector(): with set_current_vllm_config(vllm_config): # Test with default settings layout = get_kv_connector_cache_layout() - assert layout == "HND" + assert layout == "LBHNC" diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 9b022a042c81..d9dc80792f94 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -10,7 +10,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize from vllm.platforms import current_platform -from vllm.utils.torch_utils import nvfp4_kv_cache_split_views, set_random_seed +from vllm.utils.torch_utils import nvfp4_split_data_scale, set_random_seed COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")] DTYPES = [torch.bfloat16, torch.float] @@ -19,7 +19,7 @@ NUM_HEADS = [8] # Arbitrary values for testing HEAD_SIZES = [64, 80, 256] BLOCK_SIZES = [8, 16, 32] -CACHE_LAYOUTS = ["NHD", "HND"] +CACHE_LAYOUTS = ["LBNHC", "LBHNC"] KV_SCALE_TYPES = ["tensor", "attn_head"] # Parameters for MLA tests. @@ -196,8 +196,8 @@ def test_reshape_and_cache_flash( torch.set_default_device(device) torch.accelerator.set_device_index(device) assert implementation in ["cuda", "triton"] - if implementation == "triton" and kv_cache_layout == "HND": - pytest.skip("Triton implementation only supports NHD layout.") + if implementation == "triton" and kv_cache_layout == "LBHNC": + pytest.skip("Triton implementation only supports LBNHC layout.") if kv_scale_type == "attn_head" and implementation != "cuda": pytest.skip("Only CUDA implementation supports attn_head scaling.") @@ -255,10 +255,8 @@ def test_reshape_and_cache_flash( nvfp4_key_data = None nvfp4_value_data = None if kv_cache_dtype == "nvfp4": - (nvfp4_key_data,), (key_scale_cache,) = nvfp4_kv_cache_split_views(key_cache) - (nvfp4_value_data,), (value_scale_cache,) = nvfp4_kv_cache_split_views( - value_cache - ) + nvfp4_key_data, key_scale_cache = nvfp4_split_data_scale(key_cache) + nvfp4_value_data, value_scale_cache = nvfp4_split_data_scale(value_cache) if kv_cache_dtype == "nvfp4": # Global scale = amax / 448 (per-tensor) @@ -272,7 +270,7 @@ def test_reshape_and_cache_flash( v_scale = (value.amax(dim=(0, 2)) / 64.0).to(torch.float32) def permute_and_compact(x): - y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3) + y = x if kv_cache_layout == "LBNHC" else x.permute(0, 2, 1, 3) return y.contiguous() if kv_cache_dtype != "nvfp4": @@ -286,8 +284,8 @@ def convert_fp8_local(output, input, scale, kv_dtype): fp8_input.flatten(0, 2), scale, group_shape=None, out_dtype=output.dtype ).reshape(*input.shape) else: # per-head: broadcast scale along the head dimension - # Original code uses dim 2 for NHD, dim 1 for HND - if kv_cache_layout == "NHD": + # Original code uses dim 2 for LBNHC, dim 1 for LBHNC + if kv_cache_layout == "LBNHC": result = fp8_input.to(output.dtype) * scale.view(1, 1, -1, 1) else: result = fp8_input.to(output.dtype) * scale.view(1, -1, 1, 1) @@ -356,28 +354,29 @@ def convert_fp8_local(output, input, scale, kv_dtype): dequant_nvfp4_kv_cache, ) - def dequant_nvfp4_cache_nhd(data_cache, scale_cache, global_scale): - # data_cache: [N, T, H, data_dim] NHD (contiguous inner dims) - # scale_cache: [N, T, H, scale_dim] NHD (contiguous inner dims) - # Permute to HND layout for the dequant utility. - data_hnd = data_cache.permute(0, 2, 1, 3) - scale_hnd = scale_cache.permute(0, 2, 1, 3) - result_hnd = dequant_nvfp4_kv_cache( - data_hnd, scale_hnd, global_scale, head_size, block_size + def dequant_nvfp4_cache_hnc(data_cache, scale_cache, global_scale): + # data_cache: [H, N, T, data_dim] HNC layout + # scale_cache: [H, N, T, scale_dim] HNC layout + return dequant_nvfp4_kv_cache( + data_cache, scale_cache, global_scale, head_size, block_size ) - return result_hnd.permute(0, 2, 1, 3) # back to [N, T, H, D] - result_key_cache = dequant_nvfp4_cache_nhd( + result_key_cache = dequant_nvfp4_cache_hnc( nvfp4_key_data, key_scale_cache, k_scale.item() ) - result_value_cache = dequant_nvfp4_cache_nhd( + result_value_cache = dequant_nvfp4_cache_hnc( nvfp4_value_data, value_scale_cache, v_scale.item() ) - # Flatten [num_blocks, block_size] → [num_slots] and index by slot_mapping. + # Result is HNC: (num_blocks, num_heads, block_size, head_size). + # Flatten to (num_slots, num_heads, head_size) for comparison. num_slots = num_blocks * block_size - result_key_flat = result_key_cache.reshape(num_slots, num_heads, head_size) - result_value_flat = result_value_cache.reshape(num_slots, num_heads, head_size) + result_key_flat = result_key_cache.permute(0, 2, 1, 3).reshape( + num_slots, num_heads, head_size + ) + result_value_flat = result_value_cache.permute(0, 2, 1, 3).reshape( + num_slots, num_heads, head_size + ) torch.testing.assert_close( result_key_flat[slot_mapping], key.float(), atol=1.5, rtol=0.5 @@ -409,7 +408,7 @@ def dequant_nvfp4_cache_nhd(data_cache, scale_cache, global_scale): for i in range(num_tokens): block_idx = block_indices_lst[i] block_offset = block_offsets_lst[i] - if kv_cache_layout == "NHD": + if kv_cache_layout == "LBNHC": cloned_key_cache[block_idx, block_offset, :, :] = key[i] cloned_value_cache[block_idx, block_offset, :, :] = value[i] else: diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 87a12c2ff395..41b59e036446 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -13,7 +13,7 @@ from vllm.utils.math_utils import round_up from vllm.utils.torch_utils import ( nvfp4_kv_cache_full_dim, - nvfp4_kv_cache_split_views, + nvfp4_split_data_scale, set_random_seed, ) @@ -74,17 +74,18 @@ def make_nvfp4_kv_cache( kv_scale_val, dtype=torch.float32, device=kv_bf16_hnd.device ) - # Allocate in HND physical order, permute to NHD logical order. - # hnd_order swaps dims 2↔3; it is its own inverse. + # Production layout: 4D (B, 2*H, N, full_dim) where K heads occupy + # the first H heads and V heads occupy the second H heads. full_dim = nvfp4_kv_cache_full_dim(head_size) - hnd_order = (0, 1, 3, 2, 4) - kv_cache = torch.zeros( - (num_blocks, 2, num_kv_heads, block_size, full_dim), + kv_cache_hnd = torch.zeros( + (num_blocks, 2 * num_kv_heads, block_size, full_dim), dtype=torch.uint8, device=kv_bf16_hnd.device, - ).permute(*hnd_order) + ) + kv_cache_nhd = kv_cache_hnd.permute(0, 2, 1, 3) + k_view_nhd, v_view_nhd = kv_cache_nhd.split(num_kv_heads, dim=-2) - # Flatten NHD [N, T, H, D] → token tensors [N*T, H, D] for the kernel. + # Flatten input KV → token tensors [B*N, H, head_size] for the kernel. num_tokens = num_blocks * block_size k_tokens = ( kv_bf16_hnd[:, 0] @@ -98,22 +99,21 @@ def make_nvfp4_kv_cache( ) slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_bf16_hnd.device) - # reshape_and_cache_flash: kernel receives kv_cache[:, 0] and [:, 1] - # (full K/V buffers containing both data and scale). torch.ops._C_cache_ops.reshape_and_cache_flash( k_tokens, v_tokens, - kv_cache[:, 0], - kv_cache[:, 1], + k_view_nhd, + v_view_nhd, slot_mapping, "nvfp4", kv_scale_tensor, kv_scale_tensor, ) - # Split in HND order for trtllm kernel (expects HND numTokensPerPage). - kv_cache_hnd = kv_cache.permute(*hnd_order) - (k_data, v_data), (k_scales, v_scales) = nvfp4_kv_cache_split_views(kv_cache_hnd) + # Split into data/scale views in HNC order for trtllm kernel. + k_cache_hnc, v_cache_hnc = kv_cache_hnd.split(num_kv_heads, dim=1) + k_data, k_scales = nvfp4_split_data_scale(k_cache_hnc) + v_data, v_scales = nvfp4_split_data_scale(v_cache_hnc) # Dequantize for the FA2 reference baseline. ref_k = dequant_nvfp4_kv_cache( diff --git a/tests/kernels/attention/test_trtllm_kvfp8_dequant.py b/tests/kernels/attention/test_trtllm_kvfp8_dequant.py index c49ceb03f5b1..b5193e0a8814 100644 --- a/tests/kernels/attention/test_trtllm_kvfp8_dequant.py +++ b/tests/kernels/attention/test_trtllm_kvfp8_dequant.py @@ -3,14 +3,14 @@ """ Standalone unit tests for trtllm_prefill_attn_kvfp8_dequant. -Tests both contiguous and non-contiguous (cross-layer unified) KV cache -layouts against a pure-PyTorch reference implementation. +Tests KV cache layouts against a pure-PyTorch reference implementation. """ import pytest import torch from vllm.platforms import current_platform +from vllm.v1.kv_cache_interface import KVCacheLayout if current_platform.is_rocm(): pytest.skip( @@ -34,51 +34,31 @@ def to_float8(x, dtype=None): return x_scl_sat.to(dtype), scale.float().reciprocal() -def make_contiguous_kv_cache(num_blocks, num_kv_heads, block_size, head_size): - """Create a standard contiguous fp8 KV cache (HND layout).""" - raw = torch.randn( - num_blocks, - 2, - num_kv_heads, - block_size, - head_size, - dtype=torch.bfloat16, - device="cuda", - ) - kv_cache, scale = to_float8(raw) - return kv_cache, scale - - -def make_cross_layer_kv_cache( - num_blocks, - num_kv_heads, - block_size, - head_size, - num_layers=4, +def make_random_kv_cache( + num_blocks, num_kv_heads, block_size, head_size, layout=KVCacheLayout.LBHNC ): - """ - Create a non-contiguous per-layer view mimicking cross-layer allocation. + """Create a random fp8 KV cache in 5D ``(B, 2, H, N, hs)`` format. - Physical layout: (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size) - Returned view: (num_blocks, 2, num_kv_heads, block_size, head_size) - with non-contiguous strides on dims 0, 1, 2 (they skip over num_layers). + The 4D cache ``(B, H, N, 2*hs)`` is allocated with the physical stride + order from *layout*, then reshaped to 5D. Non-identity layouts produce + non-contiguous strides on inner dims, matching the actual forward path. """ - raw = torch.randn( + logical_4d = (num_blocks, num_kv_heads, block_size, 2 * head_size) + stride_order = layout.layer_stride_order + physical_4d = tuple(logical_4d[i] for i in stride_order) + inv_order = [stride_order.index(i) for i in range(4)] + + raw_phys = torch.randn(*physical_4d, dtype=torch.bfloat16, device="cuda") + fp8_phys, scale = to_float8(raw_phys) + fp8_4d = fp8_phys.permute(*inv_order) + kv_5d = fp8_4d.view( num_blocks, - 2, num_kv_heads, - num_layers, block_size, + 2, head_size, - dtype=torch.bfloat16, - device="cuda", - ) - fp8_full, scale = to_float8(raw) - layer_view = fp8_full[:, :, :, 0, :, :] - assert not layer_view.is_contiguous(), ( - f"Expected non-contiguous view, got strides {layer_view.stride()}" - ) - return layer_view, scale + ).permute(0, 3, 1, 2, 4) + return kv_5d, scale def ref_dequant(kv_cache, block_tables, k_scale, v_scale, dequant_dtype): @@ -114,7 +94,7 @@ def ref_dequant(kv_cache, block_tables, k_scale, v_scale, dequant_dtype): @pytest.mark.parametrize("block_size", [16, 32]) @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("num_pages_per_seq", [3, 8]) -@pytest.mark.parametrize("contiguous", [True, False]) +@pytest.mark.parametrize("layout", [KVCacheLayout.LBHNC, KVCacheLayout.LBNHC]) @torch.inference_mode() def test_trtllm_kvfp8_dequant( num_kv_heads: int, @@ -122,7 +102,7 @@ def test_trtllm_kvfp8_dequant( block_size: int, batch_size: int, num_pages_per_seq: int, - contiguous: bool, + layout: KVCacheLayout, ): from vllm.v1.attention.backends.flashinfer import ( trtllm_prefill_attn_kvfp8_dequant, @@ -130,20 +110,13 @@ def test_trtllm_kvfp8_dequant( torch.set_default_device("cuda") - if contiguous: - kv_cache, scale = make_contiguous_kv_cache( - NUM_BLOCKS, - num_kv_heads, - block_size, - head_size, - ) - else: - kv_cache, scale = make_cross_layer_kv_cache( - NUM_BLOCKS, - num_kv_heads, - block_size, - head_size, - ) + kv_cache, scale = make_random_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + layout=layout, + ) k_scale = scale.clone() v_scale = scale.clone() @@ -187,7 +160,7 @@ def test_block_tables_with_zero_pages(): torch.set_default_device("cuda") num_kv_heads, block_size, head_size = 8, 16, 64 - kv_cache, scale = make_contiguous_kv_cache( + kv_cache, scale = make_random_kv_cache( NUM_BLOCKS, num_kv_heads, block_size, @@ -234,7 +207,7 @@ def test_all_zero_block_tables(): torch.set_default_device("cuda") num_kv_heads, block_size, head_size = 4, 16, 64 - kv_cache, scale = make_contiguous_kv_cache( + kv_cache, scale = make_random_kv_cache( NUM_BLOCKS, num_kv_heads, block_size, @@ -266,7 +239,7 @@ def test_different_k_v_scales(): torch.set_default_device("cuda") num_kv_heads, block_size, head_size = 8, 16, 64 - kv_cache, _ = make_contiguous_kv_cache( + kv_cache, _ = make_random_kv_cache( NUM_BLOCKS, num_kv_heads, block_size, @@ -299,7 +272,7 @@ def test_single_page_per_seq(): torch.set_default_device("cuda") num_kv_heads, block_size, head_size = 8, 16, 128 - kv_cache, scale = make_contiguous_kv_cache( + kv_cache, scale = make_random_kv_cache( NUM_BLOCKS, num_kv_heads, block_size, @@ -332,7 +305,7 @@ def test_large_page_indices(): num_kv_heads, block_size, head_size = 8, 16, 128 large_num_blocks = 32768 - kv_cache, scale = make_contiguous_kv_cache( + kv_cache, scale = make_random_kv_cache( large_num_blocks, num_kv_heads, block_size, @@ -369,7 +342,7 @@ def test_large_block_size(): torch.set_default_device("cuda") num_kv_heads, block_size, head_size = 4, 64, 128 - kv_cache, scale = make_contiguous_kv_cache( + kv_cache, scale = make_random_kv_cache( NUM_BLOCKS, num_kv_heads, block_size, @@ -395,46 +368,3 @@ def test_large_block_size(): ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) - - -@torch.inference_mode() -def test_cross_layer_many_layers(): - """ - Non-contiguous with 36 layers -- matches real gpt-oss-120b. - Strides are far from contiguous (factor of 36 in the gaps). - """ - from vllm.v1.attention.backends.flashinfer import ( - trtllm_prefill_attn_kvfp8_dequant, - ) - - torch.set_default_device("cuda") - num_kv_heads, block_size, head_size = 8, 16, 64 - num_layers = 36 - - kv_cache, scale = make_cross_layer_kv_cache( - NUM_BLOCKS, - num_kv_heads, - block_size, - head_size, - num_layers=num_layers, - ) - k_scale = v_scale = scale.clone() - - block_tables = torch.randint( - 1, - NUM_BLOCKS, - (4, 6), - dtype=torch.int32, - device="cuda", - ) - - mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( - kv_cache, - block_tables, - k_scale, - v_scale, - torch.bfloat16, - ) - ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) - - torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) diff --git a/tests/test_attention_backend_registry.py b/tests/test_attention_backend_registry.py index 034749874d7b..6c5e0d7fa54f 100644 --- a/tests/test_attention_backend_registry.py +++ b/tests/test_attention_backend_registry.py @@ -38,11 +38,6 @@ def get_builder_cls(): """Mock builder class.""" return None - @staticmethod - def get_required_kv_cache_layout(): - """Mock KV cache layout.""" - return None - class CustomMambaAttentionImpl(AttentionImpl): """Mock custom mamba attention implementation for testing.""" @@ -71,11 +66,6 @@ def get_builder_cls(): """Mock builder class.""" return None - @staticmethod - def get_required_kv_cache_layout(): - """Mock KV cache layout.""" - return None - def test_custom_is_not_alias_of_any_backend(): # Get all members of AttentionBackendEnum diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 62643032edb4..5b924c0ca11a 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -27,9 +27,10 @@ from vllm.v1.attention.backend import AttentionType, CommonAttentionMetadata from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.utils import ( + resolve_kv_cache_layout, set_kv_cache_layout, ) -from vllm.v1.kv_cache_interface import FullAttentionSpec +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheLayout BACKENDS_TO_TEST = [ AttentionBackendEnum.FLASH_ATTN, @@ -105,26 +106,18 @@ def create_and_prepopulate_kv_cache( device: torch.device, num_blocks: int, common_attn_metadata: CommonAttentionMetadata, + layout: KVCacheLayout, randomize_blocks: bool = True, ) -> torch.Tensor: """Create and prepopulate a KV cache with context data. - Args: - k_contexts: List of key context tensors for each sequence - v_contexts: List of value context tensors for each sequence - seq_lens: List of sequence lengths - block_size: Size of each block - num_kv_heads: Number of KV heads - head_size: Size of each head - dtype: Data type for the cache - device: Device to create the cache on - num_blocks: Total number of blocks in the cache - block_table: Block table tensor to populate - randomize_blocks: Whether to randomly permute blocks - or use sequential order + Mirrors production's ``reshape_kv_cache``: allocates a flat buffer in + the physical order dictated by *layout*, then permutes to the logical + ``[B, H, N, C]`` shape that every backend expects. Returns: - Tuple of (kv_cache, updated_block_table) + A 4D tensor in logical ``(num_blocks, num_kv_heads, block_size, + 2 * head_size)`` order with strides determined by *layout*. """ batch_size = len(k_contexts) seq_lens = common_attn_metadata.seq_lens.cpu() @@ -136,43 +129,44 @@ def create_and_prepopulate_kv_cache( block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping - # Create KV cache and populate in (2, num_blocks, ...) layout for easy - # flat indexing, then transpose to (num_blocks, 2, ...) layout. - kv_cache = torch.zeros( - 2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device - ) - kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size) - - # Populate the cache with the context tokens - # Start from block_id=1 since block_id=0 is considered the null block - start_block_idx = 1 + # --- allocate --------------------------------------------------------- + # Logical 4D shape is always [B, H, N, C]. + logical_4d = (num_blocks, num_kv_heads, block_size, 2 * head_size) + stride_order = layout.layer_stride_order + physical_4d = tuple(logical_4d[i] for i in stride_order) + inv_order = [stride_order.index(i) for i in range(4)] + + kv_cache_physical = torch.zeros(physical_4d, dtype=dtype, device=device) + # Permute to logical [B, H, N, C] — mirrors reshape_kv_cache. + kv_cache = kv_cache_physical.permute(*inv_order) + + # --- populate --------------------------------------------------------- + # Write context tokens into the cache via the logical view: + # kv_cache[block, :, token_in_block, :] routes correctly regardless + # of physical layout. + start_block_idx = 1 # block 0 is the null block for i in range(batch_size): k_context, v_context = k_contexts[i], v_contexts[i] - start = start_block_idx * block_size - end = start + k_context.shape[0] - kv_cache_flat[0, start:end, ...] = k_context - kv_cache_flat[1, start:end, ...] = v_context - - # Stay block aligned and allocate enough blocks for the new tokens + for t in range(k_context.shape[0]): + blk = start_block_idx + t // block_size + off = t % block_size + kv_cache[blk, :, off, :head_size] = k_context[t] + kv_cache[blk, :, off, head_size:] = v_context[t] start_block_idx += cdiv(int(seq_lens[i]), block_size) - # Transpose to (num_blocks, 2, ...) layout - kv_cache = kv_cache.transpose(0, 1).contiguous() - blocks_end = start_block_idx # Permute the context blocks (excluding block 0 which is null) if randomize_blocks: - # Random permutation starting from block 1 perm = torch.randperm(blocks_end - 1) + 1 else: - # Sequential order starting from block 1 perm = torch.arange(1, blocks_end) inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) - # Add 1 to account for starting from block 1 inv_perm[1:] = torch.argsort(perm) + 1 - kv_cache[1:blocks_end, ...] = kv_cache[perm, ...] + # Shuffle on the physical tensor (dim-0 is always the block dim for + # layer-compact layouts). + kv_cache_physical[1:blocks_end, ...] = kv_cache_physical[perm, ...] # Construct the right block table # Start from block_id=1 since block_id=0 is considered the null block @@ -460,6 +454,7 @@ def _test_backend_correctness( common_attn_metadata.causal = causal # 3. Simulate Paged KV Cache and a realistic slot_mapping + layout = resolve_kv_cache_layout() kv_cache = create_and_prepopulate_kv_cache( k_contexts=k_contexts, v_contexts=v_contexts, @@ -470,6 +465,7 @@ def _test_backend_correctness( device=device, num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000, common_attn_metadata=common_attn_metadata, + layout=layout, randomize_blocks=True, ) @@ -477,54 +473,26 @@ def _test_backend_correctness( # Note: flex_attention has known Triton kernel compatibility issues # with test infrastructures for backend_name in backend_to_test: - reset_kv_cache_layout = False - - # Resolve backend class for both enum and string names. - actual_backend = backend_name - if backend_name == "FLEX_ATTENTION_SLOW": - actual_backend = AttentionBackendEnum.FLEX_ATTENTION - if hasattr(actual_backend, "get_class"): - backend_cls = actual_backend.get_class() - else: - backend_cls = None - - if backend_name == AttentionBackendEnum.FLASHINFER: - set_kv_cache_layout("HND") - reset_kv_cache_layout = True - - # Apply stride order like runtime does in - # _reshape_kv_cache (attn_utils.py:182-210): permute to physical - # layout, make contiguous, then permute to logical layout. kv_cache_for_backend = kv_cache - if backend_cls is not None: - try: - stride_order = backend_cls.get_kv_cache_stride_order() - except (AttributeError, NotImplementedError): - stride_order = tuple(range(kv_cache.ndim)) - if stride_order != tuple(range(kv_cache.ndim)): - inv_order = [stride_order.index(i) for i in range(len(stride_order))] - kv_cache_for_backend = ( - kv_cache.permute(*stride_order).contiguous().permute(*inv_order) - ) - try: - backend_output = run_attention_backend( - backend_name, - kv_cache_spec, - ["placeholder"], - vllm_config, - device, - common_attn_metadata, - query_vllm, - key_vllm, - value_vllm, - kv_cache_for_backend, - sliding_window=sliding_window, - attn_type=attn_type, - ) - finally: - if reset_kv_cache_layout: - set_kv_cache_layout(None) + # FlashInfer reads the layout at plan time; override to match + # the physical order of the test cache. + set_kv_cache_layout(layout.name) + + backend_output = run_attention_backend( + backend_name, + kv_cache_spec, + ["placeholder"], + vllm_config, + device, + common_attn_metadata, + query_vllm, + key_vllm, + value_vllm, + kv_cache_for_backend, + sliding_window=sliding_window, + attn_type=attn_type, + ) # Check shape and dtype consistency assert backend_output.shape == sdpa_output.shape, ( @@ -571,7 +539,10 @@ def error_msg(msg: str, backend_name: str): @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) def test_causal_backend_correctness( - default_vllm_config, batch_spec_name: str, model: str, tensor_parallel_size: int + default_vllm_config, + batch_spec_name: str, + model: str, + tensor_parallel_size: int, ): """Test backend's correctness with causal attention.""" @@ -656,7 +627,9 @@ def causal_mask_mod( @pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) def test_sliding_window_backend_correctness( - batch_spec_name: str, model: str, tensor_parallel_size: int + batch_spec_name: str, + model: str, + tensor_parallel_size: int, ): """Test backend's correctness with sliding window attention.""" @@ -718,7 +691,9 @@ def sliding_window_mask_mod( @pytest.mark.parametrize("model", ["google/embeddinggemma-300m"]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2]) def test_sliding_window_encoder_backend_correctness( - batch_spec_name: str, model: str, tensor_parallel_size: int + batch_spec_name: str, + model: str, + tensor_parallel_size: int, ): """Test backend's correctness with sliding window attention.""" @@ -775,7 +750,9 @@ def bidi_sliding_window_mask_mod( ) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) def test_non_causal_backend_correctness( - default_vllm_config, batch_spec_name: str, model: str + default_vllm_config, + batch_spec_name: str, + model: str, ): """Test backend's correctness with non-causal (bidirectional) decoder attention, as used by DFlash speculative decoding.""" diff --git a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py index 159bb8af3fb9..a7514381f0a2 100644 --- a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py +++ b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py @@ -17,13 +17,13 @@ def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_ """ device = torch.device("cuda") - # storage_block_size = block_size // compress_ratio = 256 // 4 = 64 + # storage_block_size = block_size // tokens_per_state = 256 // 4 = 64 kv_cache_spec = MLAAttentionSpec( block_size=256, num_kv_heads=1, head_size=128, dtype=torch.bfloat16, - compress_ratio=4, + tokens_per_state=4, ) vllm_config = create_vllm_config(max_model_len=1024) builder = DeepseekV32IndexerMetadataBuilder( diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 109e56cb3838..c77f117e6f1f 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -205,8 +205,9 @@ def create_and_prepopulate_kv_cache( else: kv_entry_size = head_size + # Create MLA KV cache: (num_blocks, num_heads=1, block_size, kv_entry_size) kv_cache = torch.zeros( - num_blocks, block_size, kv_entry_size, dtype=torch.uint8, device=device + num_blocks, 1, block_size, kv_entry_size, dtype=torch.uint8, device=device ) scale_tensor = ( scale @@ -215,9 +216,9 @@ def create_and_prepopulate_kv_cache( ) scale_tensor = scale_tensor.to(device=device, dtype=torch.float32) else: - # Create MLA KV cache: (num_blocks, block_size, head_size) + # Create MLA KV cache: (num_blocks, num_heads=1, block_size, head_size) kv_cache = torch.zeros( - num_blocks, block_size, head_size, dtype=dtype, device=device + num_blocks, 1, block_size, head_size, dtype=dtype, device=device ) kv_cache_flat = kv_cache.view(-1, head_size) @@ -238,7 +239,7 @@ def create_and_prepopulate_kv_cache( ops.concat_and_cache_mla( kv_c_context, k_pe_context.squeeze(1), - kv_cache, + kv_cache.squeeze(1), slots, kv_cache_dtype=kv_cache_dtype, scale=scale_tensor, @@ -361,7 +362,7 @@ def forward_impl( ops.concat_and_cache_mla( kv_c, k_pe.squeeze(1), - kv_cache, + kv_cache.squeeze(1), attn_metadata.slot_mapping.flatten(), kv_cache_dtype=kv_cache_dtype, scale=self._k_scale, @@ -497,7 +498,7 @@ def forward_impl( ops.concat_and_cache_mla( kv_c, k_pe.squeeze(1), - kv_cache, + kv_cache.squeeze(1), attn_metadata.slot_mapping.flatten(), kv_cache_dtype=kv_cache_dtype, scale=self._k_scale, diff --git a/tests/v1/attention/test_trtllm_attention_integration.py b/tests/v1/attention/test_trtllm_attention_integration.py index 06c5844508f4..d7b65ebb2d1a 100644 --- a/tests/v1/attention/test_trtllm_attention_integration.py +++ b/tests/v1/attention/test_trtllm_attention_integration.py @@ -9,6 +9,7 @@ import torch from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from tests.v1.attention.test_attention_backends import create_and_prepopulate_kv_cache from tests.v1.attention.utils import ( BatchSpec, create_common_attn_metadata, @@ -16,14 +17,13 @@ ) from vllm.config import set_current_vllm_config from vllm.platforms import current_platform -from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import nvfp4_kv_cache_full_dim, set_random_seed from vllm.v1.attention.backends.utils import ( PerLayerParameters, - get_kv_cache_layout, + resolve_kv_cache_layout, set_kv_cache_layout, ) -from vllm.v1.kv_cache_interface import FullAttentionSpec, KVQuantMode +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheLayout, KVQuantMode if not current_platform.is_device_capability_family(100): pytest.skip( @@ -86,90 +86,6 @@ def _mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): } -def _create_hnd_kv_cache( - k_contexts, - v_contexts, - block_size, - num_kv_heads, - head_size, - dtype, - device, - num_blocks, - common_attn_metadata, -): - """Create and populate a KV cache with HND-compatible strides. - - The returned tensor has logical shape - (num_blocks, 2, block_size, num_kv_heads, head_size) but is physically - laid out as (num_blocks, 2, num_kv_heads, block_size, head_size) so that - ``kv_cache.permute(0, 1, 3, 2, 4)`` yields a contiguous HND view. - """ - seq_lens = common_attn_metadata.seq_lens.cpu() - query_lens = ( - common_attn_metadata.query_start_loc_cpu[1:] - - common_attn_metadata.query_start_loc_cpu[:-1] - ) - block_table = common_attn_metadata.block_table_tensor - slot_mapping = common_attn_metadata.slot_mapping - batch_size = len(k_contexts) - - # Build cache in (2, num_blocks, block_size, num_kv_heads, head_size) - # then convert to HND format (same approach as test_attention_backends.py). - kv_cache_raw = torch.zeros( - 2, - num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype, - device=device, - ) - kv_cache_flat = kv_cache_raw.view(2, -1, num_kv_heads, head_size) - - start_block_idx = 1 - for i in range(batch_size): - k_ctx, v_ctx = k_contexts[i], v_contexts[i] - start = start_block_idx * block_size - end = start + k_ctx.shape[0] - kv_cache_flat[0, start:end] = k_ctx - kv_cache_flat[1, start:end] = v_ctx - start_block_idx += cdiv(int(seq_lens[i]), block_size) - - blocks_end = start_block_idx - - # Randomly permute blocks (starting from block 1; block 0 is null). - perm = torch.randperm(blocks_end - 1) + 1 - inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) - inv_perm[1:] = torch.argsort(perm) + 1 - kv_cache_raw[:, 1:blocks_end] = kv_cache_raw[:, perm] - - # Build block table. - start_block_idx = 1 - for i in range(batch_size): - n_blocks = cdiv(int(seq_lens[i]), block_size) - block_table[i, :n_blocks] = inv_perm[ - start_block_idx : start_block_idx + n_blocks - ] - start_block_idx += n_blocks - - # Build slot mapping that is consistent with the block table. - for i in range(batch_size): - ctx_len = int(seq_lens[i]) - int(query_lens[i]) - token_offsets = torch.arange(int(query_lens[i])) + ctx_len - block_indices = token_offsets // block_size - intra_block_offsets = token_offsets % block_size - start = common_attn_metadata.query_start_loc_cpu[i] - end = common_attn_metadata.query_start_loc_cpu[i + 1] - slot_mapping[start:end] = block_table[ - i, block_indices - ] * block_size + intra_block_offsets.to(device) - - # Transpose to FlashInfer logical shape then make HND-strided. - kv_cache = kv_cache_raw.transpose(0, 1) - kv_cache = kv_cache.transpose(2, 3).contiguous().transpose(2, 3) - return kv_cache - - def _create_nvfp4_hnd_kv_cache( k_contexts, v_contexts, @@ -182,43 +98,15 @@ def _create_nvfp4_hnd_kv_cache( common_attn_metadata, kv_scale_val, ): - """Create an nvfp4 KV cache by quantizing bf16 context via - reshape_and_cache_flash, using the same block-table layout as - _create_hnd_kv_cache. - - The returned tensor is dtype ``uint8`` with shape - ``(num_blocks, 2, block_size, num_kv_heads, full_dim)`` in logical - (NHD) order, but physically permuted to HND layout via stride order - ``(0, 1, 3, 2, 4)`` (i.e. ``num_kv_heads`` before ``block_size``). - - The last dimension ``full_dim = head_size // 2 + head_size // 16`` - packs two regions contiguously: - - **FP4 data** (``head_size // 2`` bytes): pairs of E2M1 values, - two per byte. - - **FP8 block scales** (``head_size // 16`` bytes): one E4M3 - scale per 16-element block. - - Dimension 1 indexes K (``[:, 0]``) and V (``[:, 1]``). - - Args: - k_contexts: List of key context tensors, one per sequence. - v_contexts: List of value context tensors, one per sequence. - block_size: Number of tokens per cache block. - num_kv_heads: Number of key/value heads. - head_size: Head dimension (must be divisible by 16). - dtype: Source data type for the bf16 intermediate cache. - device: Target device. - num_blocks: Total number of blocks to allocate. - common_attn_metadata: Metadata containing block tables and - sequence lengths. - kv_scale_val: Scalar float used as both k_scale and v_scale - during quantization. - - Returns: - ``torch.Tensor``: The nvfp4 kv_cache tensor (uint8, HND-strided). + """Create an nvfp4 KV cache with 2H head layout. + + The returned tensor is dtype ``uint8`` with logical shape + ``(num_blocks, 2 * num_kv_heads, block_size, full_dim)`` where K occupies + the first H heads and V occupies the next H heads, and + ``full_dim = head_size // 2 + head_size // 16`` packs FP4 data and + FP8 block scales per head. """ - # First create a bf16 HND cache so block tables are populated. - bf16_cache = _create_hnd_kv_cache( + bf16_cache = create_and_prepopulate_kv_cache( k_contexts, v_contexts, block_size, @@ -228,20 +116,18 @@ def _create_nvfp4_hnd_kv_cache( device, num_blocks, common_attn_metadata, + layout=KVCacheLayout.LBHNC, ) - # Allocate nvfp4 cache: same shape but with full_dim (data + scale). full_dim = nvfp4_kv_cache_full_dim(head_size) - hnd_order = (0, 1, 3, 2, 4) nvfp4_cache = torch.zeros( - (num_blocks, 2, num_kv_heads, block_size, full_dim), + (num_blocks, 2 * num_kv_heads, block_size, full_dim), dtype=torch.uint8, device=device, - ).permute(*hnd_order) + ) + k_cache = nvfp4_cache[:, :num_kv_heads] + v_cache = nvfp4_cache[:, num_kv_heads:] - # Flatten bf16 context into tokens and quantize via reshape_and_cache_flash. - # bf16_cache is (num_blocks, 2, block_size, num_kv_heads, head_size) logical - # with HND physical strides. block_table = common_attn_metadata.block_table_tensor seq_lens = common_attn_metadata.seq_lens.cpu() query_lens = ( @@ -254,22 +140,29 @@ def _create_nvfp4_hnd_kv_cache( ctx_len = int(seq_lens[i]) - int(query_lens[i]) if ctx_len == 0: continue - # Gather context tokens from the bf16 cache using block table. n_ctx_blocks = (ctx_len + block_size - 1) // block_size blocks = block_table[i, :n_ctx_blocks] - # bf16_cache[:, kv_idx] is (num_blocks, block_size, num_kv_heads, head_size) - k_ctx = bf16_cache[blocks, 0].reshape(-1, num_kv_heads, head_size)[:ctx_len] - v_ctx = bf16_cache[blocks, 1].reshape(-1, num_kv_heads, head_size)[:ctx_len] - # Build slot mapping for these context tokens. + # bf16_cache is (B, H, N, 2*head_size); extract K and V from last dim. + k_ctx = ( + bf16_cache[blocks, :, :, :head_size] + .transpose(1, 2) + .reshape(-1, num_kv_heads, head_size)[:ctx_len] + ) + v_ctx = ( + bf16_cache[blocks, :, :, head_size:] + .transpose(1, 2) + .reshape(-1, num_kv_heads, head_size)[:ctx_len] + ) token_offsets = torch.arange(ctx_len, device=device) block_indices = token_offsets // block_size intra_offsets = token_offsets % block_size slots = block_table[i, block_indices] * block_size + intra_offsets + # reshape_and_cache_flash expects (B, N, H, D) cache views torch.ops._C_cache_ops.reshape_and_cache_flash( k_ctx, v_ctx, - nvfp4_cache[:, 0], - nvfp4_cache[:, 1], + k_cache.transpose(1, 2), + v_cache.transpose(1, 2), slots, "nvfp4", kv_scale_t, @@ -358,7 +251,7 @@ def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len): common_attn_metadata = create_common_attn_metadata(batch_spec, BLOCK_SIZE, device) - # 2. Create HND KV cache + # 2. Create HNC KV cache is_nvfp4 = kv_cache_dtype == "nvfp4" if is_nvfp4: # Compute a global scale from the context data. @@ -378,7 +271,7 @@ def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len): ) else: kv_scale_val = 1.0 - kv_cache = _create_hnd_kv_cache( + kv_cache = create_and_prepopulate_kv_cache( k_contexts, v_contexts, BLOCK_SIZE, @@ -388,11 +281,12 @@ def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len): device, NUM_GPU_BLOCKS, common_attn_metadata, + layout=KVCacheLayout.LBHNC, ) # 3. Run through FlashInfer with TRTLLM enabled - set_kv_cache_layout("HND") - get_kv_cache_layout.cache_clear() + set_kv_cache_layout("LBHNC") + resolve_kv_cache_layout.cache_clear() try: is_nvfp4 = kv_cache_dtype == "nvfp4" @@ -506,7 +400,7 @@ def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len): finally: set_kv_cache_layout(None) - get_kv_cache_layout.cache_clear() + resolve_kv_cache_layout.cache_clear() @pytest.mark.parametrize( diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 68ad7bc42ef0..6174fb78e19a 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -788,36 +788,19 @@ def test_get_kv_cache_configs_multiple_workers(): ref_kv_cache_spec.page_size_bytes * 2 * 10, ], ) - assert kv_cache_configs == [ - KVCacheConfig( - num_blocks=10, - kv_cache_tensors=[ - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] - ), - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] - ), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), - ], - ), - KVCacheConfig( - num_blocks=10, - kv_cache_tensors=[ - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] - ), - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] - ), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), - ], - ), - ] + expected = KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10 * 2, + shared_by=[["layer1"], ["layer2"]], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), + ], + ) + assert kv_cache_configs == [expected, expected] # Different available memory. This is the case for TP. # Use the smallest memory available. @@ -829,36 +812,7 @@ def test_get_kv_cache_configs_multiple_workers(): ref_kv_cache_spec.page_size_bytes * 2 * 20, ], ) - assert kv_cache_configs == [ - KVCacheConfig( - num_blocks=10, - kv_cache_tensors=[ - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] - ), - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] - ), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), - ], - ), - KVCacheConfig( - num_blocks=10, - kv_cache_tensors=[ - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] - ), - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] - ), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), - ], - ), - ] + assert kv_cache_configs == [expected, expected] # Different KV cache specs. This is the case for PP. different_layer_specs = [ @@ -885,7 +839,8 @@ def test_get_kv_cache_configs_multiple_workers(): num_blocks=10, kv_cache_tensors=[ KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=[["layer1"]], ), ], kv_cache_groups=[ @@ -896,10 +851,8 @@ def test_get_kv_cache_configs_multiple_workers(): num_blocks=10, kv_cache_tensors=[ KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] - ), - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"] + size=ref_kv_cache_spec.page_size_bytes * 10 * 2, + shared_by=[["layer2"], ["layer3"]], ), ], kv_cache_groups=[ @@ -929,64 +882,37 @@ def test_get_kv_cache_configs_multiple_workers(): kv_cache_configs = get_kv_cache_configs( vllm_config, tp_pp_kv_cache_specs, - [ - ref_kv_cache_spec.page_size_bytes * 2 * 10, - ref_kv_cache_spec.page_size_bytes * 2 * 10, - ref_kv_cache_spec.page_size_bytes * 2 * 10, - ref_kv_cache_spec.page_size_bytes * 2 * 10, + [ref_kv_cache_spec.page_size_bytes * 2 * 10] * 4, + ) + expected_12 = KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10 * 2, + shared_by=[["layer1"], ["layer2"]], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), + ], + ) + expected_3 = KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=[["layer3"]], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer3"], ref_kv_cache_spec), ], ) assert kv_cache_configs == [ - KVCacheConfig( - num_blocks=10, - kv_cache_tensors=[ - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] - ), - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] - ), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), - ], - ), - KVCacheConfig( - num_blocks=10, - kv_cache_tensors=[ - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] - ), - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] - ), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), - ], - ), - KVCacheConfig( - num_blocks=10, - kv_cache_tensors=[ - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"] - ), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer3"], ref_kv_cache_spec), - ], - ), - KVCacheConfig( - num_blocks=10, - kv_cache_tensors=[ - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"] - ), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer3"], ref_kv_cache_spec), - ], - ), + expected_12, + expected_12, + expected_3, + expected_3, ] # Different workers have different types of layers. This is the case for @@ -1014,10 +940,8 @@ def test_get_kv_cache_configs_multiple_workers(): num_blocks=10, kv_cache_tensors=[ KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] - ), - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + size=ref_kv_cache_spec.page_size_bytes * 10 * 2, + shared_by=[["layer1"], ["layer2"]], ), ], kv_cache_groups=[ @@ -1029,10 +953,8 @@ def test_get_kv_cache_configs_multiple_workers(): num_blocks=10, kv_cache_tensors=[ KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"] - ), - KVCacheTensor( - size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer4"] + size=ref_kv_cache_spec.page_size_bytes * 10 * 2, + shared_by=[["layer3"], ["layer4"]], ), ], kv_cache_groups=[ @@ -1059,10 +981,7 @@ def test_get_kv_cache_configs_multiple_workers(): kv_cache_configs = get_kv_cache_configs( vllm_config, different_type_layer_specs, - [ - ref_kv_cache_spec.page_size_bytes * 10, - ref_kv_cache_spec.page_size_bytes * 10, - ], + [ref_kv_cache_spec.page_size_bytes * 10] * 2, ) assert kv_cache_configs == [ KVCacheConfig( @@ -1070,7 +989,7 @@ def test_get_kv_cache_configs_multiple_workers(): kv_cache_tensors=[ KVCacheTensor( size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1", "layer2", "layer3"], + shared_by=[["layer1", "layer2", "layer3"]], ), ], kv_cache_groups=[ @@ -1084,7 +1003,7 @@ def test_get_kv_cache_configs_multiple_workers(): kv_cache_tensors=[ KVCacheTensor( size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer4", "layer5", "layer6"], + shared_by=[["layer4", "layer5", "layer6"]], ), ], kv_cache_groups=[ @@ -1151,7 +1070,7 @@ def test_get_kv_cache_configs_pp_sharding(asymmetric_memory): kv_cache_tensors=[ KVCacheTensor( size=ref_kv_cache_spec.page_size_bytes * expected_num_blocks, - shared_by=["layer1"], + shared_by=[["layer1"]], ), ], kv_cache_groups=[KVCacheGroupSpec(["layer1"], ref_kv_cache_spec)], @@ -1161,7 +1080,7 @@ def test_get_kv_cache_configs_pp_sharding(asymmetric_memory): kv_cache_tensors=[ KVCacheTensor( size=ref_kv_cache_spec.page_size_bytes * expected_num_blocks, - shared_by=["layer2"], + shared_by=[["layer2"]], ), ], kv_cache_groups=[KVCacheGroupSpec(["layer2"], ref_kv_cache_spec)], @@ -1430,7 +1349,7 @@ def test_allocate_with_lookahead(): config = KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=[["layer1"]]), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec(block_size=block_size)), @@ -1505,11 +1424,14 @@ def test_get_kv_cache_config_one_worker(): vllm_config, [kv_cache_specs_full], [mem_per_block_per_layer * 2 * 32] )[0] print(kv_cache_config_full) + assert kv_cache_config_full == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), + KVCacheTensor( + size=mem_per_block_per_layer * 32 * 2, + shared_by=[["layer_1"], ["layer_2"]], + ), ], kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())], ) @@ -1525,8 +1447,10 @@ def test_get_kv_cache_config_one_worker(): assert kv_cache_config_sliding == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), + KVCacheTensor( + size=mem_per_block_per_layer * 32 * 2, + shared_by=[["layer_1"], ["layer_2"]], + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_sliding_window_spec()) @@ -1545,8 +1469,10 @@ def test_get_kv_cache_config_one_worker(): assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), + KVCacheTensor( + size=mem_per_block_per_layer * 32 * 2, + shared_by=[["layer_1"], ["layer_2"]], + ), ], kv_cache_groups=[ KVCacheGroupSpec( @@ -1568,7 +1494,8 @@ def test_get_kv_cache_config_one_worker(): num_blocks=64, kv_cache_tensors=[ KVCacheTensor( - size=mem_per_block_per_layer * 64, shared_by=["layer_1", "layer_2"] + size=mem_per_block_per_layer * 64, + shared_by=[["layer_1", "layer_2"]], ), ], kv_cache_groups=[ @@ -1593,12 +1520,11 @@ def test_get_kv_cache_config_one_worker(): num_blocks=32, kv_cache_tensors=[ KVCacheTensor( - size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_3", "layer_4"], - ), - KVCacheTensor( - size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_5", "layer_6"], + size=mem_per_block_per_layer * 32 * 2, + shared_by=[ + ["layer_1", "layer_3", "layer_4"], + ["layer_2", "layer_5", "layer_6"], + ], ), ], kv_cache_groups=[ @@ -1628,15 +1554,12 @@ def test_get_kv_cache_config_one_worker(): num_blocks=32, kv_cache_tensors=[ KVCacheTensor( - size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_4", "layer_5", "layer_6"], - ), - KVCacheTensor( - size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_7", "layer_8", "layer_9"], - ), - KVCacheTensor( - size=mem_per_block_per_layer * 32, shared_by=["layer_3", "layer_10"] + size=mem_per_block_per_layer * 32 * 3, + shared_by=[ + ["layer_1", "layer_4", "layer_5", "layer_6"], + ["layer_2", "layer_7", "layer_8", "layer_9"], + ["layer_3", "layer_10"], + ], ), ], kv_cache_groups=[ @@ -1649,8 +1572,7 @@ def test_get_kv_cache_config_one_worker(): ], ) - # 6 full + 5 sliding, pad to 6 full + 6 sliding. This is a typical case for gpt-oss - # eagle where there is only one more full attention layer than sliding window layers + # 6 full + 5 sliding kv_cache_specs_hybrid = { "layer_1": new_kv_cache_spec(), "layer_2": new_kv_cache_spec(), @@ -1668,33 +1590,19 @@ def test_get_kv_cache_config_one_worker(): kv_cache_config_hybrid = get_kv_cache_configs( vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 6 * 32] )[0] - print(kv_cache_config_hybrid) assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ KVCacheTensor( - size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_7"], - ), - KVCacheTensor( - size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_8"], - ), - KVCacheTensor( - size=mem_per_block_per_layer * 32, - shared_by=["layer_3", "layer_9"], - ), - KVCacheTensor( - size=mem_per_block_per_layer * 32, - shared_by=["layer_4", "layer_10"], - ), - KVCacheTensor( - size=mem_per_block_per_layer * 32, - shared_by=["layer_5", "layer_11"], - ), - KVCacheTensor( - size=mem_per_block_per_layer * 32, - shared_by=["layer_6"], + size=mem_per_block_per_layer * 32 * 6, + shared_by=[ + ["layer_1", "layer_7"], + ["layer_2", "layer_8"], + ["layer_3", "layer_9"], + ["layer_4", "layer_10"], + ["layer_5", "layer_11"], + ["layer_6"], + ], ), ], kv_cache_groups=[ @@ -1720,8 +1628,14 @@ def test_get_kv_cache_config_one_worker(): assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32 * 2, shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), + KVCacheTensor( + size=mem_per_block_per_layer * 32 * 2, + shared_by=[["layer_1"]], + ), + KVCacheTensor( + size=mem_per_block_per_layer * 32, + shared_by=[["layer_2"]], + ), ], kv_cache_groups=[ KVCacheGroupSpec( @@ -1745,7 +1659,8 @@ def test_get_kv_cache_config_one_worker(): num_blocks=32, kv_cache_tensors=[ KVCacheTensor( - size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"] + size=mem_per_block_per_layer * 32, + shared_by=[["layer_1", "layer_2"]], ), ], kv_cache_groups=[ @@ -1775,8 +1690,10 @@ def test_get_kv_cache_config_one_worker(): assert kv_cache_config_override_blocks == KVCacheConfig( num_blocks=16, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 16, shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 16, shared_by=["layer_2"]), + KVCacheTensor( + size=mem_per_block_per_layer * 16 * 2, + shared_by=[["layer_1"], ["layer_2"]], + ), ], kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())], ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 91c5f37b4179..c6ed619b4927 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -140,7 +140,7 @@ def make_kv_cache_config_hybrid_model( elif second_spec_type == "mamba": second_spec = MambaSpec( block_size=block_size, - shapes=(1, 1), + shapes=((1, 1),), dtypes=(torch.float32,), ) @@ -175,7 +175,7 @@ def make_kv_cache_config_three_types( if third_spec_type == "mamba": third_spec = MambaSpec( block_size=block_size, - shapes=(1, 1), + shapes=((1, 1),), dtypes=(torch.float32,), ) elif third_spec_type == "sliding_window": @@ -755,12 +755,12 @@ def _make_hybrid_kv_cache_config( ), "mamba": lambda: MambaSpec( block_size=block_size, - shapes=(1, 1), + shapes=((1, 1),), dtypes=(torch.float32,), ), "mamba_align": lambda: MambaSpec( block_size=block_size, - shapes=(1, 1), + shapes=((1, 1),), dtypes=(torch.float32,), mamba_cache_mode="align", ), diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh index 040632249d34..76aaa5f0e5b4 100755 --- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh @@ -94,7 +94,6 @@ else echo "running with default attention backend" fi -# Check if cross-layers is enabled (non-empty) if [[ -n "${CROSS_LAYERS_BLOCKS:-}" ]]; then echo "CROSS_LAYERS_BLOCKS is set, running with --enable-cross-layers" label+=" - CROSS_LAYERS_BLOCKS enabled" diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index fc446a0e7658..1f26df012bbb 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -4,7 +4,6 @@ set -xe # Parse command line arguments KV_BUFFER_DEVICE="cuda" # Default to cuda ATTENTION_BACKEND="" # Default to empty (use vllm default) -CROSS_LAYERS_BLOCKS="False" ENABLE_HMA_VAR="" # Default to empty (HMA disabled by default for kv connector) # Check for ENABLE_HMA_FLAG environment variable if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then @@ -22,12 +21,12 @@ while [[ $# -gt 0 ]]; do shift 2 ;; --enable-cross-layers) - CROSS_LAYERS_BLOCKS="True" + export VLLM_KV_CACHE_LAYOUT="BLHNC" shift 1 ;; *) echo "Unknown option $1" - echo "Usage: $0 [--kv_buffer_device ] [--attention-backend ]" + echo "Usage: $0 [--kv_buffer_device ] [--attention-backend ] [--enable-cross-layers]" exit 1 ;; esac @@ -44,24 +43,19 @@ if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then echo "vLLM serve extra args: $VLLM_SERVE_EXTRA_ARGS" fi -DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD -if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then +PREFILLER_KV_LAYOUT=${VLLM_KV_CACHE_LAYOUT:-"LBHNC"} +DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"$PREFILLER_KV_LAYOUT"} +if [[ "$DECODER_KV_LAYOUT" == "LBNHC" ]]; then KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"' else KV_CONFIG_HETERO_LAYOUT='' fi -if [[ "$CROSS_LAYERS_BLOCKS" == "True" ]]; then - KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"enable_cross_layers_blocks": "True"}' -else - KV_EXTRA_CONFIG='' -fi - # Build the kv-transfer-config once if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then - KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}' + KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}' else - KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}" + KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}" fi # Models to run @@ -158,7 +152,7 @@ run_tests_for_model() { # Build the command with or without model-specific args BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \ - VLLM_KV_CACHE_LAYOUT='HND' \ + VLLM_KV_CACHE_LAYOUT='$PREFILLER_KV_LAYOUT' \ UCX_NET_DEVICES=all \ VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \ vllm serve $model_name \ diff --git a/tests/v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh index 2e71858983e9..86e9b34e51fb 100755 --- a/tests/v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh @@ -8,9 +8,9 @@ # wrapping NixlConnector and OffloadingConnector, then runs gsm8k accuracy via # test_accuracy.py. # -# By default runs two configurations: -# 1. Normal KV layout (NixlConnector without cross-layer blocks) -# 2. Cross-layer KV layout (NixlConnector with enable_cross_layers_blocks) +# Runs two configurations: +# 1. Standard KV layout (LBHNC) +# 2. Cross-layer KV layout (BLHNC) via VLLM_KV_CACHE_LAYOUT # # Usage: # bash tests/v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh @@ -41,8 +41,6 @@ SMI_BIN=$(which nvidia-smi || which rocm-smi || echo "") # ── KV transfer configs ───────────────────────────────────────────────── -# Normal layout: OffloadingConnector prefers cross-layer but NixlConnector -# does not, so MultiConnector.prefer_cross_layer_blocks = False. KV_CONFIG_NORMAL='{ "kv_connector":"MultiConnector", "kv_role":"kv_both", @@ -57,21 +55,6 @@ KV_CONFIG_NORMAL='{ # Remove whitespace for CLI safety KV_CONFIG_NORMAL=$(echo "$KV_CONFIG_NORMAL" | tr -d '[:space:]') -# Cross-layer layout: both connectors prefer cross-layer blocks. -KV_CONFIG_CROSS_LAYERS='{ - "kv_connector":"MultiConnector", - "kv_role":"kv_both", - "kv_connector_extra_config":{ - "connectors":[ - {"kv_connector":"NixlConnector","kv_role":"kv_both", - "kv_connector_extra_config":{"enable_cross_layers_blocks":"True"}}, - {"kv_connector":"OffloadingConnector","kv_role":"kv_both", - "kv_connector_extra_config":{"cpu_bytes_to_use":1000000000}} - ] - } -}' -KV_CONFIG_CROSS_LAYERS=$(echo "$KV_CONFIG_CROSS_LAYERS" | tr -d '[:space:]') - # ── Helpers ────────────────────────────────────────────────────────────── trap 'kill $(jobs -pr) 2>/dev/null' SIGINT SIGTERM EXIT @@ -122,7 +105,7 @@ run_tests_for_model() { # ── Start prefill instance ── echo "Starting prefill instance on GPU $PREFILL_GPU, port $PREFILL_PORT" BASE_CMD="CUDA_VISIBLE_DEVICES=$PREFILL_GPU \ - VLLM_KV_CACHE_LAYOUT='HND' \ + VLLM_KV_CACHE_LAYOUT='${VLLM_KV_CACHE_LAYOUT:-LBHNC}' \ UCX_NET_DEVICES=all \ VLLM_NIXL_SIDE_CHANNEL_PORT=$PREFILL_SIDE_CHANNEL_PORT \ vllm serve $model_name \ @@ -144,7 +127,7 @@ run_tests_for_model() { # ── Start decode instance ── echo "Starting decode instance on GPU $DECODE_GPU, port $DECODE_PORT" BASE_CMD="CUDA_VISIBLE_DEVICES=$DECODE_GPU \ - VLLM_KV_CACHE_LAYOUT='HND' \ + VLLM_KV_CACHE_LAYOUT='${VLLM_KV_CACHE_LAYOUT:-LBHNC}' \ UCX_NET_DEVICES=all \ VLLM_NIXL_SIDE_CHANNEL_PORT=$DECODE_SIDE_CHANNEL_PORT \ vllm serve $model_name \ @@ -198,7 +181,8 @@ for model in "${MODELS[@]}"; do fi if [[ -z "${SKIP_CROSS_LAYERS:-}" ]]; then - run_tests_for_model "$model" "$KV_CONFIG_CROSS_LAYERS" "MultiConnector cross-layer layout" + VLLM_KV_CACHE_LAYOUT=BLHNC \ + run_tests_for_model "$model" "$KV_CONFIG_NORMAL" "MultiConnector cross-layer layout" fi done diff --git a/tests/v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh b/tests/v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh index a80950b34136..8eb86f51b3ff 100755 --- a/tests/v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh @@ -92,7 +92,7 @@ run_tests_for_model() { # ── Start prefill instance ── echo "Starting prefill instance on GPU $PREFILL_GPU, port $PREFILL_PORT" BASE_CMD="CUDA_VISIBLE_DEVICES=$PREFILL_GPU \ - VLLM_KV_CACHE_LAYOUT='HND' \ + VLLM_KV_CACHE_LAYOUT='LBHNC' \ UCX_NET_DEVICES=all \ VLLM_NIXL_SIDE_CHANNEL_PORT=$PREFILL_SIDE_CHANNEL_PORT \ vllm serve \"$model_name\" \ @@ -115,7 +115,7 @@ run_tests_for_model() { # ── Start decode instance ── echo "Starting decode instance on GPU $DECODE_GPU, port $DECODE_PORT" BASE_CMD="CUDA_VISIBLE_DEVICES=$DECODE_GPU \ - VLLM_KV_CACHE_LAYOUT='HND' \ + VLLM_KV_CACHE_LAYOUT='LBHNC' \ UCX_NET_DEVICES=all \ VLLM_NIXL_SIDE_CHANNEL_PORT=$DECODE_SIDE_CHANNEL_PORT \ vllm serve \"$model_name\" \ diff --git a/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh b/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh index 2c5622a2f0e1..ef624ea9d8a6 100755 --- a/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh +++ b/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh @@ -243,7 +243,7 @@ run_test_for_device() { echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" env \ ${GPU_DEVICE_VAR}=$GPU_ID \ - VLLM_KV_CACHE_LAYOUT='HND' \ + VLLM_KV_CACHE_LAYOUT='LBHNC' \ UCX_NET_DEVICES=all \ ${VLLM_SSM_CONV_STATE_LAYOUT:+VLLM_SSM_CONV_STATE_LAYOUT=$VLLM_SSM_CONV_STATE_LAYOUT} \ VLLM_NIXL_SIDE_CHANNEL_HOST=$NIXL_SIDE_CHANNEL_HOST \ @@ -283,7 +283,7 @@ run_test_for_device() { echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" env \ ${GPU_DEVICE_VAR}=$GPU_ID \ - VLLM_KV_CACHE_LAYOUT='HND' \ + VLLM_KV_CACHE_LAYOUT='LBHNC' \ UCX_NET_DEVICES=all \ ${VLLM_SSM_CONV_STATE_LAYOUT:+VLLM_SSM_CONV_STATE_LAYOUT=$VLLM_SSM_CONV_STATE_LAYOUT} \ VLLM_NIXL_SIDE_CHANNEL_HOST=$NIXL_SIDE_CHANNEL_HOST \ diff --git a/tests/v1/kv_connector/unit/offloading_connector/test_worker.py b/tests/v1/kv_connector/unit/offloading_connector/test_worker.py index 36294632bb9f..6cb8373d9c11 100644 --- a/tests/v1/kv_connector/unit/offloading_connector/test_worker.py +++ b/tests/v1/kv_connector/unit/offloading_connector/test_worker.py @@ -9,7 +9,6 @@ from vllm.platforms import current_platform from vllm.utils.torch_utils import get_dtype_size from vllm.v1.attention.backends.registry import AttentionBackendEnum -from vllm.v1.attention.backends.utils import set_kv_cache_layout from vllm.v1.kv_cache_interface import ( FullAttentionSpec, KVCacheConfig, @@ -57,31 +56,25 @@ def _allocate_and_reshape_kv_caches( Use the real GPUModelRunner allocation and reshape methods to produce kv_caches, just like the model runner does during initialization. """ + from vllm.v1.kv_cache_interface import KVCacheLayout from vllm.v1.worker.gpu_model_runner import GPUModelRunner - # Some backends (e.g. FlashAttention) query the KV cache layout during - # reshape, which ultimately calls get_current_vllm_config(). Setting - # the layout override avoids needing a full VllmConfig context. - set_kv_cache_layout("NHD") - try: - runner = object.__new__(GPUModelRunner) - runner.device = device - runner.runner_only_attn_layers = set() - runner.attn_groups = attn_groups - runner.kv_cache_config = kv_cache_config - runner.cache_config = MagicMock(cache_dtype="auto") - runner.shared_kv_cache_layers = {} - runner.model_config = MagicMock() - runner.model_config.hf_config.model_type = "" - runner.compilation_config = MagicMock( - static_forward_context=defaultdict(MagicMock) - ) - runner.kv_caches = [] - - kernel_block_sizes = [BLOCK_SIZE] * len(kv_cache_config.kv_cache_groups) - return runner.initialize_kv_cache_tensors(kv_cache_config, kernel_block_sizes) - finally: - set_kv_cache_layout(None) + runner = object.__new__(GPUModelRunner) + runner.device = device + runner.runner_only_attn_layers = set() + runner.attn_groups = attn_groups + runner.kv_cache_config = kv_cache_config + runner.cache_config = MagicMock(cache_dtype="auto") + runner.shared_kv_cache_layers = {} + runner.model_config = MagicMock() + runner.model_config.hf_config.model_type = "" + runner.compilation_config = MagicMock(static_forward_context=defaultdict(MagicMock)) + runner.kv_caches = [] + + kernel_block_sizes = [BLOCK_SIZE] * len(kv_cache_config.kv_cache_groups) + return runner._allocate_and_reshape_kv_cache( + kv_cache_config, kernel_block_sizes, layout=KVCacheLayout.LBNHC + ) def _make_worker(kv_cache_config: KVCacheConfig): @@ -208,18 +201,19 @@ def test_register_kv_caches(backend): aligned_mamba_layer_names, ] - kv_cache_tensors: list[KVCacheTensor] = [] + shared_by: list[list[str]] = [] for i in range(GROUP_SIZE): - shared_by: list[str] = [] + slot_layers: list[str] = [] for group_layer_names in layer_groups: if len(group_layer_names) > i: - shared_by.append(group_layer_names[i]) - kv_cache_tensors.append( - KVCacheTensor( - size=PAGE_SIZE_BYTES * NUM_BLOCKS, - shared_by=shared_by, - ) + slot_layers.append(group_layer_names[i]) + shared_by.append(slot_layers) + kv_cache_tensors: list[KVCacheTensor] = [ + KVCacheTensor( + size=PAGE_SIZE_BYTES * NUM_BLOCKS * GROUP_SIZE, + shared_by=shared_by, ) + ] kv_cache_groups = [ KVCacheGroupSpec(layer_names=attn_layer_names, kv_cache_spec=attn_spec), @@ -378,11 +372,11 @@ def test_register_kv_caches_uniform_type(backend): kv_cache_tensors=[ KVCacheTensor( size=spec_a.page_size_bytes * NUM_BLOCKS, - shared_by=[layer_a], + shared_by=[[layer_a]], ), KVCacheTensor( size=spec_b.page_size_bytes * NUM_BLOCKS, - shared_by=[layer_b], + shared_by=[[layer_b]], ), ], kv_cache_groups=[ diff --git a/tests/v1/kv_connector/unit/offloading_connector/utils.py b/tests/v1/kv_connector/unit/offloading_connector/utils.py index 22d00b0c834b..90ba96db2723 100644 --- a/tests/v1/kv_connector/unit/offloading_connector/utils.py +++ b/tests/v1/kv_connector/unit/offloading_connector/utils.py @@ -250,7 +250,7 @@ def __init__( ) # register worker kv_caches to enable OffloadingWorker creations - # set_current_vllm_config is needed for get_kv_cache_layout() to work + # set_current_vllm_config is needed for resolve_kv_cache_layout() to work kv_caches: dict[str, torch.Tensor] = {} for group in kv_cache_groups: spec = group.kv_cache_spec diff --git a/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py b/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py index dc76d61178d8..1277141347a0 100644 --- a/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py +++ b/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py @@ -99,7 +99,7 @@ def _make_connector_with_fake_worker( worker = connector.connector_worker assert isinstance(worker.nixl_wrapper, FakeNixlWrapper) worker.nixl_wrapper.set_cycles_before_xfer_done(cycles_before_done) - worker.kv_cache_layout = "HND" + worker.kv_cache_layout = "LBHNC" if do_handshake: remote_agents = worker._nixl_handshake( host="localhost", diff --git a/tests/v1/kv_connector/unit/test_kv_cache_layout.py b/tests/v1/kv_connector/unit/test_kv_cache_layout.py index 7f8028991703..bd36f6c6c848 100644 --- a/tests/v1/kv_connector/unit/test_kv_cache_layout.py +++ b/tests/v1/kv_connector/unit/test_kv_cache_layout.py @@ -1,36 +1,59 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for reshape_kv_cache.""" +import pytest +import torch -def test_mla_backend_rejects_cross_layer_kv_cache(): - """MLA backends return identity permutation (layers dim first) - to signal cross-layer KV cache is unsupported.""" - from vllm.model_executor.layers.attention.mla_attention import ( - MLACommonBackend, - ) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheLayout, + compute_layer_kv_cache_shape_bytes, + reshape_kv_cache, +) + +NUM_BLOCKS = 4 +BLOCK_SIZE = 4 +NUM_KV_HEADS = 2 +HEAD_SIZE = 8 +DTYPE = torch.bfloat16 - stride_order = MLACommonBackend.get_kv_cache_stride_order( - include_num_layers_dimension=True - ) - assert stride_order == (0, 1, 2, 3) - assert stride_order[0] == 0 # layers dim first => no cross-layer - assert MLACommonBackend.get_kv_cache_stride_order( - include_num_layers_dimension=False - ) == (0, 1, 2) - - -def test_deepseek_v32_indexer_rejects_cross_layer_kv_cache(): - """DeepseekV32Indexer returns identity permutation (layers dim first) - to signal cross-layer KV cache is unsupported.""" - from vllm.v1.attention.backends.mla.indexer import ( - DeepseekV32IndexerBackend, - ) - stride_order = DeepseekV32IndexerBackend.get_kv_cache_stride_order( - include_num_layers_dimension=True +@pytest.mark.parametrize( + "layout", [layer for layer in KVCacheLayout if layer.is_layer_compact] +) +def test_reshape_kv_cache(layout): + spec = FullAttentionSpec( + block_size=BLOCK_SIZE, + num_kv_heads=NUM_KV_HEADS, + head_size=HEAD_SIZE, + dtype=DTYPE, ) - assert stride_order == (0, 1, 2, 3) - assert stride_order[0] == 0 # layers dim first => no cross-layer - assert DeepseekV32IndexerBackend.get_kv_cache_stride_order( - include_num_layers_dimension=False - ) == (0, 1, 2) + num_slots = 2 + total_bytes = spec.page_size_bytes * NUM_BLOCKS * num_slots + raw = torch.zeros(total_bytes, dtype=torch.int8, device="cuda") + views = reshape_kv_cache(raw, spec, NUM_BLOCKS, num_slots, layout) + + byte_4d = compute_layer_kv_cache_shape_bytes(spec, NUM_BLOCKS) + dtype_size = torch.tensor([], dtype=spec.dtype).element_size() + expected_shape = (*byte_4d[:3], byte_4d[3] // dtype_size) + + assert len(views) == num_slots + for v in views: + assert v.shape == expected_shape + assert v.dtype == spec.dtype + + # The physical layout's innermost 3 dims (H, N, C minus the L dim) + # should match the layout's stride_order permutation: dims later in + # the physical order have smaller strides. + stride_order = layout.layer_stride_order + strides = views[0].stride() + for i in range(3): + for j in range(i + 1, 4): + if stride_order[i] < stride_order[j]: + assert strides[i] >= strides[j], ( + f"layout {layout.name}: dim {i} (physical pos " + f"{stride_order[i]}) should have >= stride than " + f"dim {j} (physical pos {stride_order[j]}), got " + f"strides={strides}" + ) diff --git a/tests/v1/kv_connector/unit/test_mooncake_connector.py b/tests/v1/kv_connector/unit/test_mooncake_connector.py index a10ae1f456ed..b5147124aebf 100644 --- a/tests/v1/kv_connector/unit/test_mooncake_connector.py +++ b/tests/v1/kv_connector/unit/test_mooncake_connector.py @@ -25,8 +25,11 @@ MooncakeBootstrapServer, ) from vllm.utils.network_utils import get_open_port -from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + compute_layer_kv_cache_shape_bytes, +) from vllm.v1.request import RequestStatus from .utils import create_request, create_scheduler, create_vllm_config @@ -369,26 +372,21 @@ async def test_kv_producer(monkeypatch): with patch.object( prefill_worker, "_send_blocks", return_value=0 ) as mock_send_blocks: - # With blocks-first layout, each block is virtually split - # into K and V halves, producing non-coalesced transfers. - kv_half = block_len // 2 - - def expected_split_transfers(src_base, dst_base, src_blocks, dst_blocks): - """Build expected (src_ptrs, dst_ptrs, lengths) for - virtual-split K/V transfers.""" - src_ptrs, dst_ptrs, lengths = [], [], [] - for kv_offset in (0, kv_half): - for sb, db in zip(src_blocks, dst_blocks): - src_ptrs.append(src_base + sb * block_len + kv_offset) - dst_ptrs.append(dst_base + db * block_len + kv_offset) - lengths.append(kv_half) - return src_ptrs, dst_ptrs, lengths + # Under the standardized blocks-first layout K and V are packed + # into a single contiguous region per block. Adjacent blocks are + # coalesced into a single larger transfer; all cases below pass + # a single contiguous run. + def expected_transfers(src_base, dst_base, src_blocks, dst_blocks): + n = len(src_blocks) + return ( + [src_base + src_blocks[0] * block_len], + [dst_base + dst_blocks[0] * block_len], + [n * block_len], + ) # Normal case: 2 blocks to 2 blocks await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) - src, dst, lens = expected_split_transfers( - 0x1000, 0x2000, [10, 11], [20, 21] - ) + src, dst, lens = expected_transfers(0x1000, 0x2000, [10, 11], [20, 21]) mock_send_blocks.assert_called_once_with( "consumer-host:54321", src, @@ -420,7 +418,7 @@ def expected_split_transfers(src_base, dst_base, src_blocks, dst_blocks): # Worker processes the consumer's request await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) # Verify transfer parameters are correct: 11 to 20 - src, dst, lens = expected_split_transfers(0x1000, 0x2000, [11], [20]) + src, dst, lens = expected_transfers(0x1000, 0x2000, [11], [20]) mock_send_blocks.assert_called_once_with( "consumer-host:54321", src, @@ -621,11 +619,14 @@ def test_register_kv_caches(): worker = connector.connector_worker mock_thread.return_value.is_alive.return_value = False - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( - num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + shape = compute_layer_kv_cache_shape_bytes( + FullAttentionSpec( + block_size=16, num_kv_heads=4, head_size=64, dtype=torch.float16 + ), + 2, ) - tensor1 = torch.zeros(*kv_cache_shape, dtype=torch.float16) - tensor2 = torch.zeros(*kv_cache_shape, dtype=torch.float16) + tensor1 = torch.zeros(*shape, dtype=torch.int8).view(torch.float16) + tensor2 = torch.zeros(*shape, dtype=torch.int8).view(torch.float16) kv_caches = {"layer0": tensor1, "layer1": tensor2} with patch.object( @@ -642,7 +643,7 @@ def test_register_kv_caches(): # Verify block_len_per_layer is set correctly. assert len(worker.block_len_per_layer) == len(registered_ptrs) for bl in worker.block_len_per_layer: - assert bl == tensor1.nbytes // tensor1.shape[0] + assert bl == tensor1.stride(0) * tensor1.element_size() def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes(): @@ -806,47 +807,34 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): flat_remote = [b for g in remote_block_ids for b in g] num_blocks = len(flat_local) - # With blocks-first layout, virtual split halves block - # lengths and doubles transfer regions (K + V). - local_kv_block_len = local_block_len // 2 - remote_kv_block_len = remote_block_len // 2 + # Under the standardized blocks-first layout K and V are + # already packed into a single contiguous region per block, + # so _expand_transfer_regions emits one region per layer. + assert len(src_ptrs) == num_blocks + assert len(dst_ptrs) == num_blocks + assert len(lengths) == num_blocks - assert len(src_ptrs) == 2 * num_blocks - assert len(dst_ptrs) == 2 * num_blocks - assert len(lengths) == 2 * num_blocks - - # Compute expected offsets using kv_block_len if d_tp_size <= P_TP_SIZE: tp_ratio = P_TP_SIZE // d_tp_size expected_src_off = 0 - expected_dst_off = (P_TP_RANK % tp_ratio) * local_kv_block_len - expected_xfer_len = local_kv_block_len + expected_dst_off = (P_TP_RANK % tp_ratio) * local_block_len + expected_xfer_len = local_block_len else: ratio_abs = d_tp_size // P_TP_SIZE - expected_src_off = (d_rank % ratio_abs) * remote_kv_block_len + expected_src_off = (d_rank % ratio_abs) * remote_block_len expected_dst_off = 0 - expected_xfer_len = remote_kv_block_len - - # First num_blocks entries are K region, - # next num_blocks are V region. - for region_idx in range(2): - local_region_base = 0x1000 + region_idx * local_kv_block_len - remote_region_base = 0x2000 + region_idx * remote_kv_block_len - for blk_idx, (lblk, rblk) in enumerate( - zip(flat_local, flat_remote) - ): - idx = region_idx * num_blocks + blk_idx - assert src_ptrs[idx] == ( - local_region_base - + lblk * local_block_len - + expected_src_off - ) - assert dst_ptrs[idx] == ( - remote_region_base - + rblk * remote_block_len - + expected_dst_off - ) - assert lengths[idx] == expected_xfer_len + expected_xfer_len = remote_block_len + + local_region_base = 0x1000 + remote_region_base = 0x2000 + for blk_idx, (lblk, rblk) in enumerate(zip(flat_local, flat_remote)): + assert src_ptrs[blk_idx] == ( + local_region_base + lblk * local_block_len + expected_src_off + ) + assert dst_ptrs[blk_idx] == ( + remote_region_base + rblk * remote_block_len + expected_dst_off + ) + assert lengths[blk_idx] == expected_xfer_len # Verify successful response sent back to consumer mock_socket.send_multipart.assert_called_once() diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index 2a5c96a46e5a..b9097865f12f 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -35,7 +35,11 @@ get_ip, make_zmq_path, ) -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + compute_layer_kv_cache_shape_bytes, +) from .utils import create_request, create_scheduler @@ -174,7 +178,7 @@ class FakeMoRIIOConnectorWorker(MoRIIOConnectorWorker): REMOTE_ENGINE_ID = "remote_engine" def __init__( - self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs + self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="LBHNC", **kwargs ): super().__init__(*args, **kwargs) @@ -415,16 +419,15 @@ def test_register_kv_caches(mock_parallel_groups): DEFAULT_PORT = 6301 TP_RANK = 0 DP_RANK = 0 - from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend - - backend_cls = AiterFlashAttentionBackend - - # Create test kv cache tensors using proper backend shape - kv_cache_shape = backend_cls.get_kv_cache_shape( - num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + # Create test kv cache tensors using KVCacheSpec layout + shape = compute_layer_kv_cache_shape_bytes( + FullAttentionSpec( + block_size=16, num_kv_heads=4, head_size=64, dtype=torch.float16 + ), + 2, ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + shared_tensor = torch.zeros(*shape, dtype=torch.int8).view(torch.float16) + unique_tensor = torch.zeros(*shape, dtype=torch.int8).view(torch.float16) kv_caches = { "layer0": shared_tensor, "layer1": unique_tensor, @@ -511,16 +514,15 @@ def test_moriio_handshake_returns_metadata(mock_parallel_groups): ROLE = "kv_consumer" vllm_config = create_vllm_config(role=ROLE) - from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend - - backend_cls = AiterFlashAttentionBackend - - # Create test kv cache tensors using proper backend shape - kv_cache_shape = backend_cls.get_kv_cache_shape( - num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + # Create test kv cache tensors using KVCacheSpec layout + shape = compute_layer_kv_cache_shape_bytes( + FullAttentionSpec( + block_size=16, num_kv_heads=4, head_size=64, dtype=torch.float16 + ), + 2, ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) - unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + shared_tensor = torch.zeros(*shape, dtype=torch.int8).view(torch.float16) + unique_tensor = torch.zeros(*shape, dtype=torch.int8).view(torch.float16) kv_caches = { "layer0": shared_tensor, "layer1": unique_tensor, diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index d1f3a81ca969..15089599d294 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -844,16 +844,6 @@ def test_multi_connector_overrides_all_base_methods(): """) -def test_multi_connector_prefer_cross_layer_blocks(mc): - mc._connectors[0].prefer_cross_layer_blocks = False - mc._connectors[1].prefer_cross_layer_blocks = True - assert mc.prefer_cross_layer_blocks is False - - mc._connectors[0].prefer_cross_layer_blocks = True - mc._connectors[1].prefer_cross_layer_blocks = True - assert mc.prefer_cross_layer_blocks is True - - def test_multi_connector_worker_metadata(mc): class MockConnectorWorkerMetadata(KVConnectorWorkerMetadata): def __init__(self, data: set[str]): diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index f07a8352e735..130a529d7e51 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -51,7 +51,6 @@ from vllm.platforms import current_platform from vllm.platforms.interface import Platform from vllm.sampling_params import SamplingParams -from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.utils import set_kv_cache_layout from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.output_processor import OutputProcessor @@ -60,12 +59,10 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, - KVCacheTensor, + compute_layer_kv_cache_shape_bytes, ) from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import RequestStatus -from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin -from vllm.v1.worker.utils import AttentionGroup from .utils import ( create_request, @@ -355,7 +352,6 @@ def test_abort_immediately_remote_prefill_enqueues_empty_recv(): ) def test_kv_transfer_handshake(dist_init): """Unit test for basic NixlConnector interface functionality.""" - from vllm.config import set_current_vllm_config # Test setup, we creates a scheduler that contains a NixlConnector # of role SCHEDULER, and expect it to be serving NixlAgentMetadata from @@ -391,14 +387,11 @@ def test_kv_transfer_handshake(dist_init): kv_cache_spec = cast( AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec ) - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( - num_blocks=kv_cache_config.num_blocks, - block_size=kv_cache_spec.block_size, - num_kv_heads=kv_cache_spec.num_kv_heads, - head_size=kv_cache_spec.head_size, + shape = compute_layer_kv_cache_shape_bytes( + kv_cache_spec, kv_cache_config.num_blocks ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype) - unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype) + shared_tensor = torch.zeros(*shape, dtype=torch.int8).view(kv_cache_spec.dtype) + unique_tensor = torch.zeros(*shape, dtype=torch.int8).view(kv_cache_spec.dtype) kv_caches = { "layer0": shared_tensor, "layer1": unique_tensor, @@ -477,7 +470,7 @@ def __init__( self, *args, hand_shake_latency: float = 1.8, - kv_cache_layout="HND", + kv_cache_layout="LBHNC", kv_cache_config=None, **kwargs, ): @@ -488,9 +481,8 @@ def __init__( self.kv_cache_layout = kv_cache_layout # Mock register_kv_caches attribute needed for tests that do not call it. self.src_xfer_handles_by_block_size = {self.block_size: 1} - test_shape = self.attn_backends[0].get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) + rep_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec + test_shape = compute_layer_kv_cache_shape_bytes(rep_spec, 1) self.transfer_topo = TransferTopology( tp_rank=self.tp_rank, tp_size=self.world_size, @@ -504,7 +496,7 @@ def __init__( ) self.compat_hash = compute_nixl_compatibility_hash( - self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks + self.vllm_config, self.backend_name ) def _nixl_handshake( @@ -549,9 +541,8 @@ def _nixl_handshake( device_id=remote_tp_rank, num_blocks=1, block_lens=remote_block_lens, - # `self.kv_cache_layout` is only forced to HND when vllm engine - # is started. We mock HND here. - kv_cache_layout="HND", + block_strides=remote_block_lens, + kv_cache_layout="LBHNC", block_size=self.block_size, ssm_sizes=(0, 0), attn_backend_name=self.backend_name, @@ -603,7 +594,7 @@ def test_multi_xfer_one_engine( worker.dst_xfer_side_handles = { FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1} } - worker.kv_cache_layout = "HND" + worker.kv_cache_layout = "LBHNC" num_xfers = 4 while True: # For the same request_id, initiate multiple xfers across different @@ -996,7 +987,9 @@ def test_handshake_fails_on_kv_cache_layout_mismatch( worker.dst_num_blocks[worker.engine_id] = worker.num_blocks # Metadata with different kv_cache_layout than local worker - mismatched_layout = "HND" if worker.kv_cache_layout != "HND" else "NHD" + mismatched_layout = ( + "LBHNC" if worker.kv_cache_layout != "LBHNC" else "LBNHC" + ) meta = NixlAgentMetadata( engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, @@ -1004,6 +997,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch( device_id=0, num_blocks=1, block_lens=worker.block_len_per_layer, + block_strides=worker.block_len_per_layer, kv_cache_layout=mismatched_layout, block_size=worker.block_size, ssm_sizes=(0, 0), @@ -1043,7 +1037,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( vllm_config, connector.engine_id, hand_shake_latency=0, - kv_cache_layout="NHD", + kv_cache_layout="LBNHC", ) worker = connector.connector_worker @@ -1054,15 +1048,16 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( worker.dst_num_blocks[worker.engine_id] = worker.num_blocks # Metadata with different kv_cache_layout than local worker + remote_block_lens = [i * 2 for i in worker.block_len_per_layer] meta = NixlAgentMetadata( engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], device_id=0, num_blocks=1, - # prefill TP=1, decode TP=2, remote block_lens is double to local - block_lens=[i * 2 for i in worker.block_len_per_layer], - kv_cache_layout="HND", + block_lens=remote_block_lens, + block_strides=remote_block_lens, + kv_cache_layout="LBHNC", block_size=worker.block_size, ssm_sizes=(0, 0), attn_backend_name=worker.backend_name, @@ -1490,7 +1485,6 @@ def req_id(outputs: list[RequestOutput]) -> str: llm.llm_engine.engine_core.shutdown() -@pytest.mark.parametrize("enable_cross_layers", ["False", "True"]) @pytest.mark.parametrize( "attn_backend", [ @@ -1511,9 +1505,7 @@ def req_id(outputs: list[RequestOutput]) -> str: "TRITON_ATTN", ], ) -def test_register_kv_caches( - default_vllm_config, dist_init, attn_backend, enable_cross_layers -): +def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): """ Test that register_kv_caches() properly calls nixl_wrapper methods with correct data. @@ -1526,12 +1518,7 @@ def test_register_kv_caches( """ vllm_config = create_vllm_config(attention_backend=attn_backend) - - # Enable cross layers blocks - vllm_config.kv_transfer_config.kv_connector_extra_config[ - "enable_cross_layers_blocks" - ] = enable_cross_layers - set_kv_cache_layout("HND") + set_kv_cache_layout("LBHNC") # Import the appropriate backend based on the parameter if attn_backend == "FLASH_ATTN": @@ -1548,59 +1535,31 @@ def test_register_kv_caches( backend_cls = TritonAttentionBackend nixl_worker = "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker" - nixl_connector = "vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector" with ( patch(f"{nixl_worker}.NixlWrapper") as mock_nixl_wrapper, patch(f"{nixl_worker}.threading.Event"), patch(f"{nixl_worker}.threading.Thread") as mock_thread, - patch(f"{nixl_connector}.get_current_attn_backend") as mock_get_attn_backend, patch(f"{nixl_worker}.get_current_attn_backends") as mock_get_attn_backends, ): - # Ensure get_attn_backend returns the correct value due to - # _cached_get_attn_backend returning the backend from previous - # test run if not mocking. - mock_get_attn_backend.return_value = backend_cls mock_get_attn_backends.return_value = [backend_cls] - num_layers = 32 - block_size = 16 num_blocks = 8 + block_size = 16 num_heads = 4 head_size = 16 - # TODO (NickLucche) the fact that connector depends on kv_cache_config for init - # but cross-layer preference cant be inferred prior to creating kv_cache_config - # is a bit awkward. - dummy_connector = NixlConnector( - vllm_config, - KVConnectorRole.WORKER, - make_kv_cache_config(block_size=block_size), - ) kv_cache_spec = FullAttentionSpec( block_size=block_size, num_kv_heads=num_heads, head_size=head_size, dtype=torch.float16, ) - if dummy_connector.prefer_cross_layer_blocks: - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, - kv_cache_tensors=[ - KVCacheTensor( - size=kv_cache_spec.page_size_bytes * num_blocks, - shared_by=["all-layers"], - ) - for _ in range(num_layers) - ], - kv_cache_groups=[KVCacheGroupSpec(["all-layers"], kv_cache_spec)], - ) - else: - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, - kv_cache_tensors=[], - kv_cache_groups=[ - KVCacheGroupSpec(["layer0", "layer1", "layer2"], kv_cache_spec) - ], - ) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(["layer0", "layer1", "layer2"], kv_cache_spec) + ], + ) # Create connector connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config) connector.connector_worker = FakeNixlConnectorWorker( @@ -1620,93 +1579,28 @@ def test_register_kv_caches( # Reassure the shutdown() check that the thread is terminated mock_thread.return_value.is_alive.return_value = False - expected_tensor_size: int - expected_base_addrs: list[int] - expected_num_entries: int - kv_caches: dict[str, torch.Tensor] - if str(enable_cross_layers).lower() == "true": - assert connector.prefer_cross_layer_blocks == ( - attn_backend in ("FLASH_ATTN", "FLASHINFER", "TRITON_ATTN") - ) - else: - assert not connector.prefer_cross_layer_blocks - - test_shape = backend_cls.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + # Create test kv cache tensors using proper backend shape + shape = compute_layer_kv_cache_shape_bytes( + kv_cache_spec, kv_cache_config.num_blocks ) - is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1 - virtually_split = is_blocks_first and not connector.prefer_cross_layer_blocks - - if connector.prefer_cross_layer_blocks: - with set_current_vllm_config(vllm_config): - _, cross_layers_kv_cache, _ = ( - KVConnectorModelRunnerMixin.allocate_uniform_kv_caches( - kv_cache_config=kv_cache_config, - attn_groups=[ - [ - AttentionGroup( - backend=backend_cls, - layer_names=[], - kv_cache_spec=kv_cache_spec, - kv_cache_group_id=0, - ) - ] - ], - cache_dtype="bfloat16", - device=torch.accelerator.current_device_index(), - kernel_block_sizes=[block_size], - ) - ) - # Store tensor info for validation - expected_tensor_size = ( - cross_layers_kv_cache.element_size() * cross_layers_kv_cache.numel() - ) - expected_base_addrs = [ - cross_layers_kv_cache.data_ptr(), - ] - expected_num_entries = 1 - - expected_blocks_count = num_blocks * (2 if virtually_split else 1) - - kv_caches = {"all-layers": cross_layers_kv_cache} - else: - # Create test kv cache tensors using proper backend shape - kv_cache_shape = backend_cls.get_kv_cache_shape( - num_blocks=kv_cache_config.num_blocks, - block_size=kv_cache_spec.block_size, - num_kv_heads=kv_cache_spec.num_kv_heads, - head_size=kv_cache_spec.head_size, - ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype) - unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype) - kv_caches = { - "layer0": shared_tensor, - "layer1": unique_tensor, - "layer2": shared_tensor, - } + shared_tensor = torch.zeros(*shape, dtype=torch.int8).view(kv_cache_spec.dtype) + unique_tensor = torch.zeros(*shape, dtype=torch.int8).view(kv_cache_spec.dtype) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } - # Store tensor info for validation - if is_blocks_first: - expected_tensor_size = ( - shared_tensor.element_size() * shared_tensor.numel() - ) - expected_base_addrs = [ - shared_tensor.data_ptr(), - unique_tensor.data_ptr(), - ] - expected_num_entries = 2 - else: - expected_tensor_size = ( - shared_tensor[0].element_size() * shared_tensor[0].numel() - ) - expected_base_addrs = [ - shared_tensor[0].data_ptr(), - shared_tensor[1].data_ptr(), - unique_tensor[0].data_ptr(), - unique_tensor[1].data_ptr(), - ] - expected_num_entries = 4 - expected_blocks_count = kv_cache_config.num_blocks * 4 + # Per-layer shape is always 4D (B, N, H, C) — no separate K/V dim. + expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel() + expected_base_addrs = [ + shared_tensor.data_ptr(), + unique_tensor.data_ptr(), + ] + expected_num_entries = 2 + # Packed KV layout: full-page descriptors (no K/V split). + # 2 unique regions × num_blocks. + expected_blocks_count = kv_cache_config.num_blocks * 2 # Execute register_kv_caches connector.register_kv_caches(kv_caches) @@ -1735,15 +1629,7 @@ def test_register_kv_caches( f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}" ) - if connector.prefer_cross_layer_blocks: - num_blocks = 8 - else: - num_blocks = kv_cache_config.num_blocks - - if virtually_split: - expected_block_len = expected_tensor_size // num_blocks // 2 - else: - expected_block_len = expected_tensor_size // num_blocks + expected_block_len = expected_tensor_size // num_blocks for i, block_entry in enumerate(blocks_data): block_start_addr, block_len, tp_rank = block_entry @@ -2449,14 +2335,11 @@ def test_compatibility_hash_validation( kv_cache_spec = cast( AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec ) - kv_cache_shape = decode_worker.attn_backends[0].get_kv_cache_shape( - num_blocks=kv_cache_config.num_blocks, - block_size=kv_cache_spec.block_size, - num_kv_heads=kv_cache_spec.num_kv_heads, - head_size=kv_cache_spec.head_size, + shape = compute_layer_kv_cache_shape_bytes( + kv_cache_spec, kv_cache_config.num_blocks ) - shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype) - unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype) + shared_tensor = torch.zeros(*shape, dtype=torch.int8).view(kv_cache_spec.dtype) + unique_tensor = torch.zeros(*shape, dtype=torch.int8).view(kv_cache_spec.dtype) # Build kv_caches from the actual layer names in kv_cache_config so that # _layer_specs lookups in register_kv_caches always find a matching key. layer_names = [ @@ -2491,18 +2374,19 @@ def test_compatibility_hash_validation( remote_hash = compute_nixl_compatibility_hash( remote_vllm_config, decode_worker.backend_name, - decode_worker.transfer_topo.cross_layers_blocks, ) prefill_block_size = config_overrides.get("block_size", 16) + prefill_block_lens = [4096 * prefill_block_size] prefill_metadata = NixlAgentMetadata( engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], device_id=0, num_blocks=1, - block_lens=[4096 * prefill_block_size], # slot_size * block_size - kv_cache_layout="HND", + block_lens=prefill_block_lens, + block_strides=prefill_block_lens, + kv_cache_layout="LBHNC", block_size=prefill_block_size, ssm_sizes=(0, 0), attn_backend_name=decode_worker.backend_name, @@ -2576,9 +2460,8 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) decode_worker = decode_connector.connector_worker backend = get_current_attn_backend(local_vllm_config) - test_shape = backend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) + probe_spec = decode_worker.kv_cache_config.kv_cache_groups[0].kv_cache_spec + test_shape = compute_layer_kv_cache_shape_bytes(probe_spec, 1) decode_worker.transfer_topo = TransferTopology( tp_rank=decode_worker.tp_rank, tp_size=decode_worker.world_size, @@ -2594,7 +2477,6 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) decode_worker.compat_hash = compute_nixl_compatibility_hash( decode_worker.vllm_config, decode_worker.backend_name, - decode_worker.transfer_topo.cross_layers_blocks, ) if error_scenario == "handshake_decode_error": diff --git a/tests/v1/simple_kv_offload/test_scheduler.py b/tests/v1/simple_kv_offload/test_scheduler.py index 970e16e52798..f65d66326492 100644 --- a/tests/v1/simple_kv_offload/test_scheduler.py +++ b/tests/v1/simple_kv_offload/test_scheduler.py @@ -84,7 +84,7 @@ def _make_kv_cache_config( tensors.append( KVCacheTensor( size=_BYTES_PER_BLOCK * num_blocks, - shared_by=layer_names, + shared_by=[layer_names], ) ) return KVCacheConfig( diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 1a1352249c33..df61f8a36d6f 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -67,7 +67,7 @@ def initialize_kv_cache(runner: GPUModelRunner): kv_cache_config = KVCacheConfig( num_blocks=NUM_BLOCKS, kv_cache_tensors=[ - KVCacheTensor(size=tensor_size, shared_by=["layer.0"]), + KVCacheTensor(size=tensor_size, shared_by=[["layer.0"]]), ], kv_cache_groups=[ KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec) @@ -673,55 +673,6 @@ def test_update_states_pp_async_multi_request_keeps_rank_state_consistent( ) -def test_kv_cache_stride_order(monkeypatch, model_runner): - # This test checks if GPUModelRunner initializes correctly when an attention - # backend enforces a non-default KV cache stride order. - n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config) - head_size = model_runner.model_config.get_head_size() - - # Get the expected shape from the backend's get_kv_cache_shape method - # to ensure compatibility with different backends (triton vs flexattention) - attn_backend = None - for attn_group in model_runner._attn_group_iterator(): - attn_backend = attn_group.backend - break - - assert attn_backend is not None, "No attention backend found" - expected_kv_cache_shape = list( - attn_backend.get_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size) - ) - - # TODO mla test - default_stride = tuple(range(5)) - # Permutation that gets you back to expected kv shape - for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)): - - def rnd_stride_order( - include_num_layers_dimension: bool = False, test_stride=test_stride - ): - assert not include_num_layers_dimension - return test_stride - - # Patch the attention backend class and re-trigger the KV cache creation - for attn_group in model_runner._attn_group_iterator(): - attn_backend = attn_group.backend - monkeypatch.setattr( - attn_backend, "get_kv_cache_stride_order", rnd_stride_order - ) - - model_runner.attn_groups = [] - model_runner.kv_caches = [] - model_runner.initialize_kv_cache(model_runner.kv_cache_config) - - # Shape is unchanged, but layout may differ - kv_cache_shape = model_runner.kv_caches[0].shape - assert list(kv_cache_shape) == expected_kv_cache_shape - if default_stride == test_stride: - assert all(kv.is_contiguous() for kv in model_runner.kv_caches) - else: - assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) - - def test_update_config(model_runner): # Simple update model_runner.update_config({"load_config": {"load_format": "dummy"}}) @@ -971,21 +922,22 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config): vllm_config, [kv_cache_spec], [available_memory] )[0] assert kv_cache_config.num_blocks == num_expected_blocks - assert len(kv_cache_config.kv_cache_tensors) == 2 - assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2 - assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2 + assert len(kv_cache_config.kv_cache_tensors) == 1 + assert kv_cache_config.kv_cache_tensors[0].size == available_memory max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) # max context len with KV sharing should be 2x as large as without assert max_context_len == 1310720 # important: override tensor size to prevent large mem alloc during test - # this will only allocate 2 block worth of memory (2 * 32kb) + # this will only allocate 1 block worth of memory per slot (2 slots * 32kb) kv_cache_config.num_blocks = 1 for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - kv_cache_tensor.size = kv_cache_spec[ - kv_cache_tensor.shared_by[0] - ].page_size_bytes + num_layer_slots = len(kv_cache_tensor.shared_by) + kv_cache_tensor.size = ( + kv_cache_spec[kv_cache_tensor.shared_by[0][0]].page_size_bytes + * num_layer_slots + ) runner.initialize_kv_cache(kv_cache_config) @@ -1079,7 +1031,7 @@ def test_hybrid_attention_mamba_tensor_shapes(): """ The GPU model runner creates different views into the KVCacheTensors for the attention and mamba layers - (via _reshape_kv_cache_tensors function). This test verifies + (via _allocate_kv_caches). This test verifies that the views are compatible: writing a mamba block will not corrupt an attention block and vice versa """ @@ -1385,7 +1337,7 @@ def test_hybrid_cache_integration(default_vllm_config, dist_init): kv_cache_config = KVCacheConfig( num_blocks=NUM_BLOCKS, kv_cache_tensors=[ - KVCacheTensor(size=tensor_size, shared_by=["layer.0"]), + KVCacheTensor(size=tensor_size, shared_by=[["layer.0"]]), ], kv_cache_groups=[ KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec) diff --git a/vllm/compilation/passes/fusion/mla_rope_kvcache_cat_fusion.py b/vllm/compilation/passes/fusion/mla_rope_kvcache_cat_fusion.py index 5a493149a9d5..5998a0313840 100644 --- a/vllm/compilation/passes/fusion/mla_rope_kvcache_cat_fusion.py +++ b/vllm/compilation/passes/fusion/mla_rope_kvcache_cat_fusion.py @@ -45,7 +45,7 @@ def fused_rope_unified_mla_kv_cache_update_impl( cos_sin_cache, is_neox, layer_slot_mapping, - kv_cache, + kv_cache.squeeze(1), kv_cache_dtype, kv_cache_scale, ) diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py index b22af99f703f..38d626670436 100644 --- a/vllm/config/kv_transfer.py +++ b/vllm/config/kv_transfer.py @@ -65,7 +65,7 @@ class KVTransferConfig: Only supported in V1.""" enable_permute_local_kv: bool = False - """Experiment feature flag to enable HND to NHD KV Transfer""" + """Experiment feature flag to enable HNC to NHC KV Transfer""" kv_load_failure_policy: Literal["recompute", "fail"] = "fail" """Policy for handling KV cache load failures. diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index d7a595716f08..593869b7886f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -10,18 +10,20 @@ import torch -from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config +from vllm.config import ( + VllmConfig, + get_current_vllm_config_or_none, + get_layers_from_vllm_config, +) from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.platforms import current_platform from vllm.v1.attention.backend import AttentionBackend -from vllm.v1.kv_cache_interface import MambaSpec from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase - from vllm.v1.kv_cache_interface import KVCacheSpec logger = init_logger(__name__) @@ -32,19 +34,18 @@ def get_kv_connector_cache_layout(): - # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is - # used for faster transfer. - vllm_config = get_current_vllm_config() + # NOTE (NickLucche) When running disaggregated PD with NIXL, LBHNC layout + # is used for faster transfer. + vllm_config = get_current_vllm_config_or_none() + if vllm_config is None: + return None kv_config = vllm_config.kv_transfer_config if kv_config is not None: connector_cls = KVConnectorFactory.get_connector_class(kv_config) required_kvcache_layout = connector_cls.get_required_kvcache_layout(vllm_config) if required_kvcache_layout is not None: return required_kvcache_layout - logger.info_once( - "Connectors do not specify a kv cache layout, defaulting to NHD." - ) - return "NHD" + return None class KVOutputAggregator: @@ -279,11 +280,11 @@ def kv_postprocess_layout_on_receive(cache, indices): def kv_postprocess_blksize_and_layout_on_receive(cache, indices, block_size_ratio): """ - Transforms the layout of received KV cache to the local block_size and HND. - (Only works for local blocksize > remote blocksize) + Transforms the layout of received KV cache to the local block_size + and LBHNC. (Only works for local blocksize > remote blocksize) - prefill is HND, smaller block_size - decode(local) is NHD, larger block_size + prefill is LBHNC, smaller block_size + decode(local) is LBNHC, larger block_size """ blocks_to_update = cache.index_select(0, indices) @@ -408,43 +409,12 @@ def __post_init__(self): self._engines: dict[EngineId, EngineTransferInfo] = {} - # Figure out whether the first dimension of the cache is K/V - # or num_blocks. - attn_backend = self.attn_backends[0] - if not self.is_mamba: - _MOCK_BLOCK_SIZE = 16 - kv_cache_shape: tuple[int, ...] = attn_backend.get_kv_cache_shape( - num_blocks=1, - block_size=_MOCK_BLOCK_SIZE, - num_kv_heads=1, - head_size=1, - ) - logger.debug("Test kv_cache_shape: %s", kv_cache_shape) - # Non-MLA backends caches have 5 dims [num_blocks, 2, H,N,D], - # we just mock num_blocks to 1 for the dimension check below. - # Hybrid SSM models assume a single blocks_first layout - self._is_kv_layout_blocks_first = self.is_mamba or ( - len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 - ) - - self._cross_layers_blocks = False - if self.tensor_shape is not None: - self._cross_layers_blocks = ( - len(self.tensor_shape) == len(kv_cache_shape) + 1 - ) + # Cross-layer layouts (BLHNC) have B outermost, so all layers + # for a block are contiguous — transfers can coalesce multiple + # layers into one operation. + from vllm.v1.attention.backends.utils import resolve_kv_cache_layout - if self._cross_layers_blocks: - logger.debug("Using cross-layer KV cache") - _MOCK_NUM_LAYERS = 80 - kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=self._cross_layers_blocks - ) - except (AttributeError, NotImplementedError): - assert self.tensor_shape is not None - kv_cache_stride_order = tuple(range(len(self.tensor_shape))) - kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + self._is_kv_layout_blocks_first = not resolve_kv_cache_layout().is_layer_compact # ============================================================ # Engine registration @@ -480,26 +450,6 @@ def get_engine_info(self, remote_engine_id: EngineId) -> EngineTransferInfo: def is_kv_layout_blocks_first(self) -> bool: return self._is_kv_layout_blocks_first - @property - def cross_layers_blocks(self) -> bool: - return self._cross_layers_blocks - - @property - def virtually_split_kv_in_blocks(self) -> bool: - # Whether to logically split each block into K and V halves. - # Applies when K/V are interleaved within each block (blocks-first), - # but NOT when cross-layer blocks are used — cross-layer blocks have - # per-layer K/V interleaving (L0_K, L0_V, L1_K, L1_V, ...) so a - # simple half-split does not separate K from V. - return self._is_kv_layout_blocks_first and not self._cross_layers_blocks - - @property - def split_k_and_v(self) -> bool: - # Whether to register regions for K and V separately (when present). - return not ( - self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first - ) - # ============================================================ # Common methods # ============================================================ @@ -571,31 +521,6 @@ def target_remote_ranks(self, remote_engine_id: EngineId) -> list[int]: abs_ratio = -tp_ratio return [self.tp_rank * abs_ratio + i for i in range(abs_ratio)] - def get_transfer_cache_regions( - self, cache: torch.Tensor, layer_spec: "KVCacheSpec" - ) -> list[torch.Tensor] | torch.Tensor: - """Return the cache tensor(s) to register as NIXL memory regions, - also accounting for hybrid SSM models specificities. - """ - if isinstance(layer_spec, MambaSpec): - # Register the whole kv cache shared tensor, including - # SSM/Conv. - conv, ssm = cache - return [conv] - - # Check may be hacky but it's matching - # `_update_hybrid_attention_mamba_layout`. - if self.is_mamba and cache.shape[0] == 2: - # When MAMBA is present, all backends are blocks first, so - # that blocks can be shared between attention layers and mamba - # layers. Runner already adjusted strides for FlashAttn-like - # backends so its num_blocks first. - # Swap [2<>num_blocks] dims for hybrid SSM layout. - cache = cache.transpose(0, 1) - - # Regular case: backends like FA register K/V in separate regions - return cache if self.split_k_and_v else [cache] - def describe(self, remote_engine_id: EngineId) -> str: """One-line summary of transfer config for logging.""" info = self._engines[remote_engine_id] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index fb5658da887a..146c44730d31 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -48,7 +48,7 @@ import torch from vllm.logger import init_logger -from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata +from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import KVConnectorOutput @@ -173,14 +173,6 @@ class KVConnectorBase_V1(ABC): Base class for KV connectors. """ - @property - def prefer_cross_layer_blocks(self) -> bool: - """ - Indicates whether this connector prefers KV blocks that hold KV data for all - layers, which can speed up KV data transfers. Defaults to False. - """ - return False - def __init__( self, vllm_config: "VllmConfig", @@ -258,23 +250,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ return - def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"] - ): - """ - Initialize with a single KV cache tensor used by all layers. - The first dimension should be num_layers. - This function will only be called for models with uniform layers, - and only if the prefers_cross_layer_blocks is set to True. - Only one of the functions - {register_kv_caches, register_cross_layers_kv_cache} will be called. - - Args: - kv_cache: a cross-layers kv cache tensor - attn_backend: The attention backend that corresponds to all layers - """ - return - def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): """ Set the xPU-specific ops for copying KV between host and device. @@ -577,7 +552,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None: vllm_config (VllmConfig): the vllm config. Returns: - str: the required KV cache layout. e.g. HND, or NHD. + str: the required KV cache layout. e.g. HNC, or NHC. None if the connector does not require a specific layout. """ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py index 3e4e6750858a..5a500dbfa5e9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py @@ -38,8 +38,10 @@ def extract_from_kv_cache( num_tokens: int, ) -> torch.Tensor: """Extract data from KV cache.""" - block_size = kv_cache.shape[1] - return kv_cache[slot_mapping // block_size, slot_mapping % block_size][:num_tokens] + block_size = kv_cache.shape[2] + return kv_cache[slot_mapping // block_size, :, slot_mapping % block_size][ + :num_tokens + ] def load_hidden_states(path: str) -> dict[str, torch.Tensor]: @@ -123,15 +125,6 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA): Must be used in conjunction with the `extract_hidden_states` spec decoding method. """ - @property - def prefer_cross_layer_blocks(self) -> bool: - """ - Indicates whether this connector prefers KV blocks that hold KV data for all - layers, which can speed up KV data transfers. Defaults to False. - """ - # Must be False so that drafter kv cache isn't merged with verifier's - return False - def __init__( self, vllm_config: "VllmConfig", @@ -500,7 +493,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None: vllm_config (VllmConfig): the vllm config. Returns: - str: the required KV cache layout. e.g. HND, or NHD. + str: the required KV cache layout. e.g. HNC, or NHC. None if the connector does not require a specific layout. """ @@ -509,9 +502,9 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None: "get_required_kvcache_layout should not be called " "on the abstract base class" ) - # NHD means we have (num_tokens, num_heads) - # HND means we have (num_heads, num_tokens) - # For now, we only support NHD layout since this keeps the + # LBNHC means we have (num_tokens, num_heads) + # LBHNC means we have (num_heads, num_tokens) + # For now, we only support LBNHC layout since this keeps the # hidden states for each token together in memory. - # HND is primarily used when sharding heads across devices. - return "NHD" + # LBHNC is primarily used when sharding heads across devices. + return "LBNHC" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py index 8786e91a5a14..8ecadaf2b952 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py @@ -974,7 +974,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None: vllm_config (VllmConfig): the vllm config. Returns: - str: the required KV cache layout. e.g. HND, or NHD. + str: the required KV cache layout. e.g. HNC, or NHC. None if the connector does not require a specific layout. """ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index ccb7257c88c4..9e342a1a3d6f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -51,7 +51,7 @@ from vllm.utils.math_utils import cdiv from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.attention.backend import AttentionMetadata -from vllm.v1.attention.backends.utils import get_kv_cache_layout +from vllm.v1.attention.backends.utils import resolve_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec from vllm.v1.request import RequestStatus @@ -109,7 +109,7 @@ def _get_tp_ratio(local_tp_size: int, remote_tp_size: int) -> int: def _expand_transfer_regions( base_addrs: list[int], block_lens: list[int], - is_kv_layout_blocks_first: bool, + is_kv_layout_blocks_first: bool, # kept for API compat, unused ) -> list[TransferRegion]: """Expand registered KV tensors into the regions transferred by Mooncake.""" assert len(base_addrs) == len(block_lens), ( @@ -118,22 +118,13 @@ def _expand_transfer_regions( ) regions: list[TransferRegion] = [] for base_addr, block_len in zip(base_addrs, block_lens): - kv_block_len = block_len // 2 if is_kv_layout_blocks_first else block_len regions.append( TransferRegion( base_addr=base_addr, block_len=block_len, - kv_block_len=kv_block_len, + kv_block_len=block_len, ) ) - if is_kv_layout_blocks_first: - regions.append( - TransferRegion( - base_addr=base_addr + kv_block_len, - block_len=block_len, - kv_block_len=kv_block_len, - ) - ) return regions @@ -377,10 +368,10 @@ def get_required_kvcache_layout(cls, vllm_config: VllmConfig): if vllm_config.model_config.use_mla: return None logger.info_once( - "MooncakeConnector setting KV cache layout to HND for " + "MooncakeConnector setting KV cache layout to LBHNC for " "heterogeneous TP-safe KV transfer." ) - return "HND" + return "LBHNC" ############################################################ # Scheduler Side Methods @@ -852,7 +843,7 @@ def __init__( # NOTE (NickLucche) models with multiple backends are not supported yet backend = get_current_attn_backend(vllm_config) self.backend_name = backend.get_name() - self.kv_cache_layout = get_kv_cache_layout() + self.kv_cache_layout = resolve_kv_cache_layout().name logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected kv cache layout %s", self.kv_cache_layout) @@ -1395,10 +1386,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): seen_base_addresses = [] self.block_len_per_layer = [] - split_k_and_v = self.transfer_topo.split_k_and_v tensor_size_bytes = None for layer_name, cache_or_caches in kv_caches.items(): - cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] + cache_list = [cache_or_caches] logger.debug( "registering layer %s with %d cache tensor(s)", layer_name, @@ -1745,7 +1735,7 @@ def _get_transfer_regions( return _expand_transfer_regions( base_addrs=base_addrs, block_lens=block_lens, - is_kv_layout_blocks_first=self.transfer_topo.virtually_split_kv_in_blocks, + is_kv_layout_blocks_first=self.transfer_topo.is_kv_layout_blocks_first, ) def _get_sender_transfer_plan( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 73418104bea6..372ce6f1b929 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -27,7 +27,7 @@ PromMetricT, ) from vllm.logger import init_logger -from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata +from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import KVConnectorOutput @@ -203,12 +203,6 @@ def __init__( # Propagated from scheduler to worker side via the connector metadata. self._extra_async_saves: dict[str, int] = {} - @property - def prefer_cross_layer_blocks(self) -> bool: - if not self._connectors: - return False - return all(c.prefer_cross_layer_blocks for c in self._connectors) - @classmethod def _get_connector_classes_and_configs( cls, vllm_config: "VllmConfig" @@ -235,13 +229,6 @@ def _get_connector_classes_and_configs( ) return ret - def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] - ): - # Register on all connectors - for c in self._connectors: - c.register_cross_layers_kv_cache(kv_cache, attn_backend) - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for c in self._connectors: c.register_kv_caches(kv_caches) @@ -540,7 +527,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None: vllm_config (VllmConfig): the vllm config. Returns: - str: the required KV cache layout. e.g. HND, or NHD. + str: the required KV cache layout. e.g. HNC, or NHC. None if the connector does not require a specific layout. """ assert vllm_config.kv_transfer_config is not None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py index 187322b4ae4e..37357adff352 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py @@ -9,7 +9,6 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import ( EngineId, - get_current_attn_backend, ) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( CopyBlocksOp, @@ -40,10 +39,8 @@ ) from vllm.forward_context import ForwardContext from vllm.logger import init_logger -from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata -from vllm.v1.attention.backends.utils import get_kv_cache_layout +from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import MambaSpec from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: @@ -55,36 +52,6 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA): - @property - def prefer_cross_layer_blocks(self) -> bool: - if any( - [ - isinstance(group.kv_cache_spec, MambaSpec) - for group in self.kv_cache_config.kv_cache_groups - ] - ): - # Hybrid SSM models do not yet support cross-layer layout - return False - - backend = get_current_attn_backend(self._vllm_config) - if backend.get_name() not in ( - "FLASH_ATTN", - "FLASHINFER", - "TRITON_ATTN", - ): - return False - - # For now there is no benefit to run cross layers when backend - # does not support on HND - if get_kv_cache_layout() != "HND": - return False - - extra_config = self.kv_transfer_config.kv_connector_extra_config - return ( - str(extra_config.get("enable_cross_layers_blocks", "False")).lower() - == "true" - ) - def __init__( self, vllm_config: VllmConfig, @@ -126,9 +93,10 @@ def get_required_kvcache_layout(cls, vllm_config: VllmConfig): # which fallback to the default behavior. return None logger.info_once( - "NixlConnector setting KV cache layout to HND for better xfer performance." + "NixlConnector setting KV cache layout to LBHNC for " + "better xfer performance." ) - return "HND" + return "LBHNC" ############################################################ # Scheduler Side Methods @@ -200,12 +168,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) - def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] - ): - assert self.connector_worker is not None - self.connector_worker.register_cross_layers_kv_caches(kv_cache) - def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): assert self.connector_worker is not None self.connector_worker.set_host_xfer_buffer_ops(copy_operation) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py index b9e3436f5019..848a0451ac90 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py @@ -34,8 +34,9 @@ # 2: Add remote_request_id to kv_transfer_params # 3: Add physical_blocks_per_logical_kv_block to NixlAgentMetadata # 4: Add KV block lease renewal through heartbeats +# 5: Add block_strides # -NIXL_CONNECTOR_VERSION: int = 4 +NIXL_CONNECTOR_VERSION: int = 5 @dataclass @@ -46,6 +47,7 @@ class NixlAgentMetadata: device_id: int num_blocks: int block_lens: list[int] + block_strides: list[int] kv_cache_layout: str block_size: int ssm_sizes: tuple[int, int] @@ -72,7 +74,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata): def compute_nixl_compatibility_hash( - vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool + vllm_config: VllmConfig, attn_backend_name: str ) -> str: """ Compute compatibility hash for NIXL KV transfer. @@ -116,7 +118,6 @@ def compute_nixl_compatibility_hash( # Attention backend and KV cache dtype affect memory layout "attn_backend_name": attn_backend_name, "cache_dtype": str(cache_config.cache_dtype), - "cross_layers_blocks": cross_layers_blocks, "is_hma_enabled": is_hma_enabled, } diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index 0d30d4a692ad..2d4907c0b071 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -67,9 +67,10 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path -from vllm.v1.attention.backends.utils import get_kv_cache_layout +from vllm.v1.attention.backends.utils import resolve_kv_cache_layout from vllm.v1.kv_cache_interface import ( FullAttentionSpec, + KVCacheLayout, MambaSpec, UniformTypeKVCacheSpecs, ) @@ -235,7 +236,9 @@ def __init__( for group in kv_cache_config.kv_cache_groups for layer in group.layer_names } - self.hma_group_size = len(kv_cache_config.kv_cache_tensors) + self.hma_group_size = sum( + len(t.shared_by) for t in kv_cache_config.kv_cache_tensors + ) # ---- Model state (derived from model config) ---- mamba_ssm_size = (0, 0) @@ -418,7 +421,7 @@ def __init__( self.attn_backends = get_current_attn_backends(vllm_config) self.backend_name = self.attn_backends[0].get_name() - self.kv_cache_layout = get_kv_cache_layout() + self.kv_cache_layout = resolve_kv_cache_layout().name self.host_buffer_kv_cache_layout = self.kv_cache_layout logger.info( "Detected attention backend(s) %s", @@ -602,18 +605,18 @@ def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> Non kv_dtype = kv_cache.dtype permute_shape = False if ( - self.kv_cache_layout == "NHD" + self.kv_cache_layout == "LBNHC" and self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.enable_permute_local_kv ): logger.info_once( "'enable_permute_local_kv' flag is enabled while " - "device KV Layout is NHD. Init host buffer with" - " HND to better support Decode/Prefill TP_ratio > 1." + "device KV Layout is LBNHC. Init host buffer with" + " LBHNC to better support Decode/Prefill TP_ratio > 1." ) - # Since NHD will not support Decode/Prefill TP_ratio > 1, + # Since LBNHC will not support Decode/Prefill TP_ratio > 1, # we can leverage host_buffer for permute - self.host_buffer_kv_cache_layout = "HND" + self.host_buffer_kv_cache_layout = "LBHNC" kv_shape = ( tuple(kv_shape[i] for i in inv_order) if not self.use_mla @@ -774,17 +777,6 @@ def request_ready(f: Future[Any], entry=(req_id, meta)): fut.add_done_callback(request_ready) - def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None: - """Register a cross-layers KV cache tensor with NIXL. - - `use_uniform_kv_cache()` guarantees a single KV cache group whose - layers all share the same `AttentionSpec`, so any layer name from - `_layer_specs` yields the correct per-layer spec for `page_size_bytes`. - """ - first_layer = next(iter(self._layer_specs)) - # Forwarding a real layer name rather than a synthetic key - self.register_kv_caches({first_layer: kv_cache}) - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" self.transfer_topo = TransferTopology( @@ -802,7 +794,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): is_mamba=self._has_mamba, ) self.compat_hash = compute_nixl_compatibility_hash( - self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks + self.vllm_config, self.backend_name ) if self.use_host_buffer: @@ -844,6 +836,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Enable different block lengths for different layers *only* when MLA is used. # This is not used for SSM layers, which use the counterpart `mamba_ssm_size`. self.block_len_per_layer = list[int]() + + # Per-block physical stride per layer (bytes). Read from each + # registered tensor's stride(0) so it stays correct under layouts + # that interleave layers within a block (BLHNC/BHLNC). + self.block_stride_per_layer = list[int]() for layer_name, cache_or_caches in xfer_buffers.items(): # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to # that of FI, with block laid out as in `get_backend_aware_kv_block_len`. @@ -854,24 +851,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if isinstance(layer_spec, UniformTypeKVCacheSpecs): # MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs layer_spec = layer_spec.kv_cache_specs[layer_name] - cache_list = self.transfer_topo.get_transfer_cache_regions( - cache_or_caches, layer_spec - ) - # `layer_spec.page_size_bytes` only accounts for logical page_size, that is - # the page_size assuming constant `self._logical_num_blocks`. physical_page_size = ( layer_spec.page_size_bytes if isinstance(layer_spec, MambaSpec) else layer_spec.page_size_bytes // self._physical_blocks_per_logical_kv_block ) - # For when registering multiple tensors eg K/V in separate regions. - physical_page_size = physical_page_size // len(cache_list) - if self.transfer_topo._cross_layers_blocks: - # When cross-layers blocks are used, multiply by number of layers - physical_page_size = physical_page_size * len( - self.kv_cache_config.kv_cache_tensors - ) num_blocks = ( self._logical_num_blocks if isinstance(layer_spec, MambaSpec) @@ -883,75 +868,63 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if tensor_size_bytes is None: tensor_size_bytes = curr_tensor_size_bytes - # TODO (NickLucche) we could eventually unify how we handle FA/FI regions, - # registering a single tensor for both K/V and splitting logically like FI. - for cache in cache_list: - base_addr = cache.data_ptr() - if base_addr in seen_base_addresses: - # NOTE (NickLucche) HMA employs memory pooling to share tensors - # across groups. This results in skipping all tensors but the ones - # pointed to by group0. Also, generally we will have more blocks - # per tensor but fewer regions. - logger.debug("Skipping %s because it's already seen", layer_name) - continue - logger.debug( - "Registering layer %s with cache shape: %s", layer_name, cache.shape + cache = cache_or_caches + base_addr = cache.data_ptr() + if base_addr in seen_base_addresses: + logger.debug("Skipping %s because it's already seen", layer_name) + continue + logger.debug( + "Registering layer %s with cache shape: %s", + layer_name, + cache.shape, + ) + seen_base_addresses.append(base_addr) + if isinstance(layer_spec, MambaSpec): + self.block_len_per_layer.append( + physical_page_size // self._physical_blocks_per_logical_kv_block + ) + self.block_stride_per_layer.append(self.block_len_per_layer[-1]) + else: + self.block_len_per_layer.append(physical_page_size) + self.block_stride_per_layer.append( + cache.stride(0) * cache.element_size() ) - seen_base_addresses.append(base_addr) - # Only record non-Mamba page sizes. - if isinstance(layer_spec, MambaSpec): - self.block_len_per_layer.append( - physical_page_size // self._physical_blocks_per_logical_kv_block - ) - else: - self.block_len_per_layer.append(physical_page_size) - - if cache.shape[0] != num_blocks: - raise AssertionError( - "All kv cache tensors must have the same number of " - f"blocks; layer={layer_name}, " - f"expected_num_blocks={num_blocks}, " - f"cache_shape={tuple(cache.shape)}, " - f"cache_stride={tuple(cache.stride())}, " - f"layer_spec={type(layer_spec).__name__}, " - f"backend={self.backend_name}, " - "all_backends=" - f"{[backend.get_name() for backend in self.attn_backends]}, " - f"kv_cache_layout={self.kv_cache_layout}, " - "blocks_first=" - f"{self.transfer_topo.is_kv_layout_blocks_first}" - ) - if not self.use_mla: - # Different kv cache shape is not supported by HeteroTP. - # This must also hold true for Mamba-like models. - assert tensor_size_bytes == curr_tensor_size_bytes, ( - "All kv cache tensors must have the same size" - ) - # Need to make sure the device ID is non-negative for NIXL, - # Torch uses -1 to indicate CPU tensors. - self.device_id = max(cache.get_device(), 0) - caches_data.append( - (base_addr, curr_tensor_size_bytes, self.device_id, "") + if cache.shape[0] != num_blocks: + raise AssertionError( + "All kv cache tensors must have the same number of " + f"blocks; layer={layer_name}, " + f"expected_num_blocks={num_blocks}, " + f"cache_shape={tuple(cache.shape)}, " + f"cache_stride={tuple(cache.stride())}, " + f"layer_spec={type(layer_spec).__name__}, " + f"backend={self.backend_name}, " + "all_backends=" + f"{[backend.get_name() for backend in self.attn_backends]}, " + f"kv_cache_layout={self.kv_cache_layout}, " + "blocks_first=" + f"{self.transfer_topo.is_kv_layout_blocks_first}" ) + if not self.use_mla: + assert tensor_size_bytes == curr_tensor_size_bytes, ( + "All kv cache tensors must have the same size" + ) + self.device_id = max(cache.get_device(), 0) + caches_data.append((base_addr, curr_tensor_size_bytes, self.device_id, "")) + logger.debug( "Different block lengths collected: %s", set(self.block_len_per_layer) ) assert len(self.block_len_per_layer) == len(seen_base_addresses) + assert len(self.block_stride_per_layer) == len(seen_base_addresses) self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses self.num_regions = len(caches_data) - if self.transfer_topo.virtually_split_kv_in_blocks: - # NOTE (NickLucche) When FlashInfer is used, memory is registered - # with joint KV for each block. This minimizes the overhead in - # registerMem allowing faster descs queries. In order to be able to - # split on kv_heads dim as required by heterogeneous TP, one must - # be able to index K/V separately. Hence we double the number - # of 'virtual' regions here and halve `block_len` below. - # Similarly for Mamba layers, we register SSM+Conv as a single region and - # then duplicate it logically to be able to index SSM/Conv separately. + if self.transfer_topo.is_kv_layout_blocks_first and self._has_mamba: + # For Mamba layers, SSM+Conv are registered as a single region + # and duplicated logically to index SSM/Conv separately. self.num_regions *= 2 # Total local FA descriptors (boundary between FA and mamba descs). @@ -993,6 +966,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank], num_blocks=self.num_blocks, block_lens=self.block_len_per_layer, + block_strides=self.block_stride_per_layer, kv_cache_layout=self.kv_cache_layout if not self.use_host_buffer else self.host_buffer_kv_cache_layout, @@ -1112,24 +1086,12 @@ def _build_fa_local( ) // block_size_ratio ) - page_stride = self.block_len_per_layer[i] // block_size_ratio + page_stride = self.block_stride_per_layer[i] // block_size_ratio for block_id in range(num_blocks): block_offset = block_id * page_stride addr = base_addr + block_offset result.append((addr, kv_block_len, self.device_id)) - if self.transfer_topo.virtually_split_kv_in_blocks: - # Separate and interleave K/V regions to maintain the same - # descs ordering. This is needed for selecting contiguous heads - # when split across TP ranks. - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=False - ) - for block_id in range(num_blocks): - block_offset = block_id * page_stride - addr = base_addr + block_offset - v_addr = addr + kv_block_len - result.append((v_addr, second_split, self.device_id)) return result def _build_fa_remote( @@ -1159,7 +1121,7 @@ def _build_fa_remote( local_block_len = local_block_len // num_attn_reads rank_offset = plan.rank_offset_factor * remote_kv_block_len - page_size = nixl_agent_meta.block_lens[i] + page_size = nixl_agent_meta.block_strides[i] for block_id in range(num_blocks): block_offset = block_id * page_size # For each block, grab the kv heads chunk belonging to current local @@ -1167,18 +1129,6 @@ def _build_fa_remote( addr = base_addr + block_offset + rank_offset result.append((addr, local_block_len, nixl_agent_meta.device_id)) - if self.transfer_topo.virtually_split_kv_in_blocks: - # With FlashInfer index V separately to allow head splitting. - second_split = self.get_backend_aware_kv_block_len( - layer_idx=i, first_split=False, mamba_view=False - ) - second_split = second_split // num_attn_reads - for block_id in range(num_blocks): - block_offset = block_id * page_size - addr = base_addr + block_offset + rank_offset - # Hop over the first split of remote page, K, to read V. - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - result.append((v_addr, second_split, nixl_agent_meta.device_id)) return result def register_local_xfer_handler( @@ -1261,7 +1211,7 @@ def add_remote_agent( tp_ratio = 4 // 2 = 2 Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim] - then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format. + then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "LBHNC" layout format. Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. @@ -1465,10 +1415,10 @@ def _validate_remote_agent_handshake( if not self.use_mla and nixl_agent_meta.kv_cache_layout != kv_cache_layout: if ( self.kv_transfer_config.enable_permute_local_kv - and nixl_agent_meta.kv_cache_layout == "HND" + and nixl_agent_meta.kv_cache_layout == "LBHNC" ): logger.info( - "Remote is HND and local is NHD, enabled additional permute " + "Remote is LBHNC and local is LBNHC, enabled additional permute " "on local device KV." ) assert not self._is_hma_required, ( @@ -1478,7 +1428,7 @@ def _validate_remote_agent_handshake( else: raise RuntimeError( "Heterogeneous TP expects same kv_cache_layout. " - "Or enable experimental feature to use HND to NHD support by " + "Or enable experimental feature to use HNC to NHC support by " "setting 'enable_permute_local_kv'=True in --kv-transfer-config." ) # if remote_agent used attn is not same as local, @@ -1498,18 +1448,18 @@ def _validate_remote_agent_handshake( self.enable_heterogeneous_attn_post_process = True # Heterogeneous TP requires head-splitting, which only works with - # HND layout. MLA and replicated-KV cases don't split on heads. - # Mamba doesn't support heterogeneous TP. + # block-contiguous layouts (H before N, e.g. HNC / BHLNC). + # MLA and replicated-KV cases don't split on heads. if ( abs(tp_ratio) != 1 and not self.use_mla and not self.transfer_topo.is_kv_replicated(remote_engine_id) - and kv_cache_layout != "HND" + and not KVCacheLayout[kv_cache_layout].is_block_contiguous and not self.enable_permute_local_kv ): raise RuntimeError( "Heterogeneous TP head-dimension splitting requires contiguous heads. " - "Use HND layout on the prefill side." + "Use HNC layout on the prefill side." ) # Block len can only vary across layers when using MLA. @@ -1620,11 +1570,11 @@ def post_process_device_kv_on_receive( Post process device kv cache after receiving from remote. 3 types of post processing supported: - * kv_cache_postprocess_layout => convert from HND to NHD + * kv_cache_postprocess_layout => convert from HNC to NHC * kv_cache_postprocess_blksize => convert from small block size to large block size * kv_cache_postprocess_blksize_and_layout => convert from small - block size to large block size and convert from HND to NHD + block size to large block size and convert from HNC to NHC """ if len(self.device_kv_caches) == 0: @@ -1634,14 +1584,14 @@ def post_process_device_kv_on_receive( if self.enable_permute_local_kv and block_size_ratio > 1: logger.debug( "Post-processing device kv cache on receive by converting " - "block_size with %sx bigger and permuting layout from HND" - " to NHD.", + "block_size with %sx bigger and permuting layout from HNC" + " to NHC.", block_size_ratio, ) elif self.enable_permute_local_kv: logger.debug( "Post-processing device kv cache on receive by permuting layout" - "from HND to NHD." + "from HNC to NHC." ) else: logger.debug( @@ -1650,24 +1600,18 @@ def post_process_device_kv_on_receive( block_size_ratio, ) - split_k_and_v = self.transfer_topo.split_k_and_v - for block_ids in block_ids_list: indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long) - for _, cache_or_caches in self.device_kv_caches.items(): - cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] - for cache in cache_list: - if self.enable_permute_local_kv and block_size_ratio > 1: - kv_postprocess_blksize_and_layout_on_receive( - cache, indices, block_size_ratio - ) - elif self.enable_permute_local_kv: - kv_postprocess_layout_on_receive(cache, indices) - else: - kv_postprocess_blksize_on_receive( - cache, indices, block_size_ratio - ) + for _, cache in self.device_kv_caches.items(): + if self.enable_permute_local_kv and block_size_ratio > 1: + kv_postprocess_blksize_and_layout_on_receive( + cache, indices, block_size_ratio + ) + elif self.enable_permute_local_kv: + kv_postprocess_layout_on_receive(cache, indices) + else: + kv_postprocess_blksize_on_receive(cache, indices, block_size_ratio) def post_process_device_kv_on_receive_heterogeneous_attn( self, block_ids: list[int] @@ -2416,11 +2360,8 @@ def get_backend_aware_kv_block_len( |1st_split-2nd_split| |1st_split-2nd_split | """ assert self.transfer_topo is not None - if self.transfer_topo.virtually_split_kv_in_blocks: - if mamba_view: - block_len = self._mamba_ssm_size[not first_split] - else: - block_len = self.block_len_per_layer[layer_idx] // 2 + if self.transfer_topo.is_kv_layout_blocks_first and mamba_view: + block_len = self._mamba_ssm_size[not first_split] else: block_len = self.block_len_per_layer[layer_idx] return block_len diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py index 8957ce3445ae..a271aa425624 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py @@ -17,7 +17,6 @@ OffloadingConnectorStats, ) from vllm.logger import init_logger -from vllm.v1.attention.backend import AttentionBackend from vllm.v1.kv_cache_interface import ( AttentionSpec, MambaSpec, @@ -59,7 +58,8 @@ def register_kv_caches( ): num_blocks = self.spec.kv_cache_config.num_blocks - # layer_name -> (num_blocks, page_size_bytes) tensor + # layer_name -> (num_blocks, page_size_bytes) int8 view. + # Standardized layouts always have num_blocks as the leading dim. tensors_per_block: dict[str, tuple[torch.Tensor, ...]] = {} # layer_name -> size of (un-padded) page in bytes unpadded_page_size_bytes: dict[str, int] = {} @@ -76,98 +76,79 @@ def register_kv_caches( layer_kv_cache_spec = per_layer_specs.get( layer_name, group_kv_cache_spec ) - if isinstance(layer_kv_cache_spec, AttentionSpec): - layer_kv_cache = kv_caches[layer_name] - assert isinstance(layer_kv_cache, torch.Tensor) - assert layer_kv_cache.storage_offset() == 0 + layer_kv_cache = kv_caches[layer_name] + # AttentionSpec yields a single tensor; MambaSpec yields a + # list of typed state tensors that share one underlying + # buffer. Either way, the first tensor's storage_offset + # marks the start of this layer's region. + ref = ( + layer_kv_cache[0] + if isinstance(layer_kv_cache, list) + else layer_kv_cache + ) + page = layer_kv_cache_spec.page_size_bytes + offset = ref.storage_offset() * ref.element_size() + tensors_per_block[layer_name] = ( + torch.tensor([], dtype=torch.int8, device=ref.device) + .set_(ref.untyped_storage()) + .view(-1)[offset : offset + num_blocks * page] + .view(num_blocks, page), + ) + page_size_bytes[layer_name] = page - storage = layer_kv_cache.untyped_storage() - page = layer_kv_cache_spec.page_size_bytes - tensors_per_block[layer_name] = ( - torch.tensor( - [], - dtype=torch.int8, - device=layer_kv_cache.device, - ) - .set_(storage) - .view(num_blocks, page), - ) - page_size_bytes[layer_name] = layer_kv_cache_spec.page_size_bytes + if isinstance(layer_kv_cache_spec, AttentionSpec): unpadded_page_size_bytes[layer_name] = ( layer_kv_cache_spec.real_page_size_bytes ) - elif isinstance(layer_kv_cache_spec, MambaSpec): - state_tensors = kv_caches[layer_name] - assert isinstance(state_tensors, list) - - # re-construct the raw (num_blocks, page_size) tensor - # from the first state tensor - assert len(state_tensors) > 0 - first_state_tensor = state_tensors[0] - assert first_state_tensor.storage_offset() == 0 - tensor = ( - torch.tensor( - [], - dtype=torch.int8, - device=first_state_tensor.device, - ) - .set_(first_state_tensor.untyped_storage()) - .view((num_blocks, layer_kv_cache_spec.page_size_bytes)) - ) - tensors_per_block[layer_name] = (tensor,) - - page_size_bytes[layer_name] = layer_kv_cache_spec.page_size_bytes unpadded_page_size_bytes[layer_name] = replace( layer_kv_cache_spec, page_size_padded=None ).page_size_bytes - else: raise NotImplementedError block_tensors: list[CanonicalKVCacheTensor] = [] block_data_refs: dict[str, list[CanonicalKVCacheRef]] = defaultdict(list) for kv_cache_tensor in self.spec.kv_cache_config.kv_cache_tensors: - # Filter to layers that were actually processed above. - # _get_kv_cache_config_deepseek_v4 emits KVCacheTensor entries for - # every (tuple_idx, page_size) slot; slots where no group has a - # layer at that index produce an empty shared_by (reserved memory - # with no corresponding model layer). - tensor_layer_names = [ - n for n in kv_cache_tensor.shared_by if n in tensors_per_block - ] - if not tensor_layer_names: - continue - - # verify all layers in the group reference the exact same tensors - assert len({len(tensors_per_block[n]) for n in tensor_layer_names}) == 1 - assert ( - len({tensors_per_block[n][0].data_ptr() for n in tensor_layer_names}) - == 1 - ) - assert ( - len({tensors_per_block[n][0].stride() for n in tensor_layer_names}) == 1 - ) - - # pick the first layer to represent the group - first_layer_name = tensor_layer_names[0] - for tensor in tensors_per_block[first_layer_name]: - block_tensors.append( - CanonicalKVCacheTensor( - tensor=tensor, - page_size_bytes=page_size_bytes[first_layer_name], + for slot_layers in kv_cache_tensor.shared_by: + # Filter to layers that were actually processed above. + # Some slots may have no corresponding model layer (reserved + # memory with no group layer at that index). + tensor_layer_names = [n for n in slot_layers if n in tensors_per_block] + if not tensor_layer_names: + continue + + # Verify all layers in the slot reference the same tensors. + assert len({len(tensors_per_block[n]) for n in tensor_layer_names}) == 1 + assert ( + len( + {tensors_per_block[n][0].data_ptr() for n in tensor_layer_names} ) + == 1 + ) + assert ( + len({tensors_per_block[n][0].stride() for n in tensor_layer_names}) + == 1 ) - curr_tensor_idx = len(block_tensors) - 1 - for layer_name in tensor_layer_names: - block_data_refs[layer_name].append( - CanonicalKVCacheRef( - tensor_idx=curr_tensor_idx, - page_size_bytes=(unpadded_page_size_bytes[layer_name]), + first_layer_name = tensor_layer_names[0] + for tensor in tensors_per_block[first_layer_name]: + block_tensors.append( + CanonicalKVCacheTensor( + tensor=tensor, + page_size_bytes=page_size_bytes[first_layer_name], ) ) + curr_tensor_idx = len(block_tensors) - 1 + for layer_name in tensor_layer_names: + block_data_refs[layer_name].append( + CanonicalKVCacheRef( + tensor_idx=curr_tensor_idx, + page_size_bytes=(unpadded_page_size_bytes[layer_name]), + ) + ) + group_data_refs: list[list[CanonicalKVCacheRef]] = [] for kv_cache_group in self.spec.kv_cache_config.kv_cache_groups: group_refs: list[CanonicalKVCacheRef] = [] @@ -182,53 +163,6 @@ def register_kv_caches( self._register_handlers(canonical_kv_caches) - def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] - ): - # verify that num_blocks is at physical position 0 in the cross-layers - # tensor layout. - test_shape = attn_backend.get_kv_cache_shape( - num_blocks=1234, block_size=16, num_kv_heads=1, head_size=256 - ) - num_blocks_logical_dim = test_shape.index(1234) + 1 - physical_to_logical = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=True - ) - num_blocks_physical_dim = physical_to_logical.index(num_blocks_logical_dim) - assert num_blocks_physical_dim == 0 - - kv_cache_groups = self.spec.kv_cache_config.kv_cache_groups - assert len(kv_cache_groups) == 1 - kv_cache_spec = kv_cache_groups[0].kv_cache_spec - num_layers = len(kv_cache_groups[0].layer_names) - page_size_bytes = kv_cache_spec.page_size_bytes * num_layers - - assert kv_cache.storage_offset() == 0 - storage = kv_cache.untyped_storage() - assert len(storage) % page_size_bytes == 0 - num_blocks = len(storage) // page_size_bytes - tensor = ( - torch.tensor( - [], - dtype=torch.int8, - device=kv_cache.device, - ) - .set_(storage) - .view(num_blocks, page_size_bytes) - ) - kv_cache_tensor = CanonicalKVCacheTensor( - tensor=tensor, page_size_bytes=page_size_bytes - ) - # in cross layers layout, there's currently only a single group - kv_cache_data_ref = CanonicalKVCacheRef( - tensor_idx=0, page_size_bytes=page_size_bytes - ) - canonical_kv_caches = CanonicalKVCaches( - tensors=[kv_cache_tensor], group_data_refs=[[kv_cache_data_ref]] - ) - - self._register_handlers(canonical_kv_caches) - def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata): for job_id, transfer_spec in self._unsubmitted_store_jobs: success = self.worker.transfer_async(job_id, transfer_spec) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 20888c71f84c..48a774dea9ea 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -34,7 +34,7 @@ OffloadingConnectorWorker, ) from vllm.forward_context import ForwardContext -from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata +from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig @@ -44,10 +44,6 @@ class OffloadingConnector(KVConnectorBase_V1, SupportsHMA): - @property - def prefer_cross_layer_blocks(self) -> bool: - return True - def __init__( self, vllm_config: VllmConfig, @@ -75,12 +71,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) - def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] - ): - assert self.connector_worker is not None - self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend) - def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata): assert self.connector_worker is not None assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata) @@ -176,7 +166,7 @@ def take_events(self) -> Iterable[KVCacheEvent]: @classmethod def get_required_kvcache_layout(cls, vllm_config: VllmConfig) -> str | None: - return "HND" + return "LBHNC" def reset_cache(self) -> bool | None: assert self.connector_scheduler is not None diff --git a/vllm/envs.py b/vllm/envs.py index dc11fbd224d9..fe31bf49375a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -205,7 +205,9 @@ VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 - VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None + VLLM_KV_CACHE_LAYOUT: ( + Literal["LBNHC", "LBHNC", "NHC", "HNC", "NHD", "HND", "BLHNC", "BHLNC"] | None + ) = None VLLM_SSM_CONV_STATE_LAYOUT: Literal["SD", "DS"] | None = None VLLM_COMPUTE_NANS_IN_LOGITS: bool = False VLLM_USE_NVFP4_CT_EMULATIONS: bool = False @@ -1617,18 +1619,20 @@ def _resolve_rust_frontend_path() -> str | None: ), # KV Cache layout used throughout vllm. # Some common values are: - # - NHD - # - HND - # Where N=num_blocks, H=num_heads and D=head_size. The default value will - # leave the layout choice to the backend. Mind that backends may only + # - LBNHC + # - LBHNC + # Where N=num_states, H=num_heads and C=state_content. The default value + # will leave the layout choice to the backend. Mind that backends may only # implement and support a subset of all possible layouts. "VLLM_KV_CACHE_LAYOUT": env_with_choices( - "VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"] + "VLLM_KV_CACHE_LAYOUT", + None, + ["LBNHC", "LBHNC", "NHC", "HNC", "NHD", "HND", "BLHNC", "BHLNC"], ), # SSM conv state layout used for Mamba models. # - SD: (state_len, dim) — dim contiguous (default) # - DS: (dim, state_len) — TP-sharded dim on dim1, - # consistent with SSM temporal state and HND KV cache layout. + # consistent with SSM temporal state and LBHNC KV cache layout. "VLLM_SSM_CONV_STATE_LAYOUT": env_with_choices( "VLLM_SSM_CONV_STATE_LAYOUT", None, ["SD", "DS"] ), diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 140e071c7465..bd8e64a6bcdd 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -519,6 +519,10 @@ def __init__( compile_native=True, ) + def bind_kv_cache(self, kv_cache: torch.Tensor) -> None: + # [B, H=1, N, C] -> [B, N, C] + self.kv_cache = kv_cache.squeeze(1) + @property def chunked_prefill_workspace_size(self) -> int: if self._chunked_prefill_workspace_size is None: @@ -1172,27 +1176,6 @@ def get_name() -> str: def get_builder_cls() -> type["MLACommonMetadataBuilder"]: return MLACommonMetadataBuilder - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - return (num_blocks, block_size, head_size) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - if include_num_layers_dimension: - # MLA kernels require contiguous per-layer KV cache views. - # Identity permutation keeps num_layers first in physical - # layout, signaling cross-layer allocation is unsupported. - return (0, 1, 2, 3) - return (0, 1, 2) - @classmethod def get_supported_head_sizes(cls) -> list[int]: return [320, 576] @@ -2054,7 +2037,7 @@ def _compute_prefill_context( toks = prefill_metadata.chunked_context.seq_tot[i] if not use_fp8_prefill: ops.gather_and_maybe_dequant_cache( - src_cache=kv_c_and_k_pe_cache, + src_cache=kv_c_and_k_pe_cache.squeeze(1), dst=workspace, block_table=prefill_metadata.block_table, cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], @@ -2067,7 +2050,7 @@ def _compute_prefill_context( else: # FP8 path: gather cache without dequantization ops.cp_gather_cache( - src_cache=kv_c_and_k_pe_cache, + src_cache=kv_c_and_k_pe_cache.squeeze(1), dst=workspace, block_table=prefill_metadata.block_table, cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], @@ -2162,7 +2145,7 @@ def _context_parallel_compute_prefill_context( for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] ops.cp_gather_cache( - src_cache=kv_c_and_k_pe_cache, + src_cache=kv_c_and_k_pe_cache.squeeze(1), dst=workspace, block_table=prefill_metadata.block_table, cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[ diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py index 97395b641497..598870c92bcb 100644 --- a/vllm/model_executor/layers/attention_layer_base.py +++ b/vllm/model_executor/layers/attention_layer_base.py @@ -4,6 +4,8 @@ from abc import ABC, abstractmethod +import torch + from vllm.config import VllmConfig from vllm.v1.attention.backend import AttentionBackend, AttentionImpl from vllm.v1.kv_cache_interface import KVCacheSpec @@ -20,6 +22,10 @@ class AttentionLayerBase(ABC): impl: "AttentionImpl" + def bind_kv_cache(self, kv_cache: torch.Tensor) -> None: + """Bind a ``[B, H, N, C]`` cache view; override to reshape.""" + self.kv_cache = kv_cache + @abstractmethod def get_attn_backend(self) -> type[AttentionBackend]: """Get the attention backend class for this layer.""" diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 8bbb21d7bc90..05ceb6ac48bc 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -2,11 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import abstractmethod from collections.abc import Iterable +from math import prod import torch from vllm.config import VllmConfig from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.utils.torch_utils import get_dtype_size from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum from vllm.v1.attention.selector import get_mamba_attn_backend @@ -23,6 +25,18 @@ class MambaBase(AttentionLayerBase): # in the shape specified by `self.get_state_shape`. kv_cache: tuple[torch.Tensor, ...] + def bind_kv_cache(self, kv_cache: torch.Tensor) -> None: + """Unpack a raw 4D ``[B, 1, 1, C]`` int8 view into per-state views.""" + pages = kv_cache.squeeze(dim=(1, 2)) + states: list[torch.Tensor] = [] + offset = 0 + for shape, dtype in zip(self.get_state_shape(), self.get_state_dtype()): + nbytes = prod(shape) * get_dtype_size(dtype) + state = pages[:, offset : offset + nbytes].view(dtype) + states.append(state.view(-1, *shape)) + offset += nbytes + self.kv_cache = tuple(states) + @abstractmethod def get_state_shape(self) -> Iterable[tuple[int, ...]]: """ diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index c1fd81e40e34..aaf579bce3cc 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -28,7 +28,7 @@ def get_conv_state_layout() -> ConvStateLayoutType: """Return the SSM conv state layout. SD = (state_len, dim) — dim is the innermost contiguous dimension. - DS = (dim, state_len) — TP-sharded dim is on dim-1 (like HND for KV + DS = (dim, state_len) — TP-sharded dim is on dim-1 (like HNC for KV cache), consistent with SSM temporal state layout. """ layout: ConvStateLayoutType | None = envs.VLLM_SSM_CONV_STATE_LAYOUT diff --git a/vllm/model_executor/models/extract_hidden_states.py b/vllm/model_executor/models/extract_hidden_states.py index 8df4823b6973..8f08a055b07a 100644 --- a/vllm/model_executor/models/extract_hidden_states.py +++ b/vllm/model_executor/models/extract_hidden_states.py @@ -80,11 +80,11 @@ def dummy_attention(layer_name, _placeholder): def basic_cache( to_cache: torch.Tensor, # shape: [seq_len, num_heads, head_size] - kv_cache: torch.Tensor, # shape: [num_blocks, block_size, num_heads, head_size] + kv_cache: torch.Tensor, # shape: [num_blocks, num_heads, block_size, head_size] slot_mapping: torch.Tensor, # shape: [seq_len] ): - block_size = kv_cache.shape[1] - kv_cache[slot_mapping // block_size, slot_mapping % block_size] = to_cache + block_size = kv_cache.shape[2] + kv_cache[slot_mapping // block_size, :, slot_mapping % block_size] = to_cache ######### CacheOnlyAttentionBackend ######## @@ -120,18 +120,6 @@ def supports_mm_prefix(cls) -> bool: def get_impl_cls() -> type["CacheOnlyAttentionImpl"]: return CacheOnlyAttentionImpl - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - # We set `num_kv_heads = num_hidden_layers` and `head_size = hidden_size` - # We also don't use a k/v (2) dim - return (num_blocks, block_size, num_kv_heads, head_size) - @staticmethod def get_builder_cls() -> type["CacheOnlyAttentionMetadataBuilder"]: return CacheOnlyAttentionMetadataBuilder diff --git a/vllm/model_executor/models/whisper_causal.py b/vllm/model_executor/models/whisper_causal.py index dfbf69418a6c..07b496c2605d 100644 --- a/vllm/model_executor/models/whisper_causal.py +++ b/vllm/model_executor/models/whisper_causal.py @@ -125,12 +125,6 @@ def __init__( vllm_config: VllmConfig, device: torch.device, ): - assert kv_cache_spec.num_kv_heads % block_pool_size == 0 - kv_cache_spec = replace( - kv_cache_spec, - block_size=kv_cache_spec.block_size * block_pool_size, - num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size, - ) super().__init__(kv_cache_spec, layer_names, vllm_config, device) # Override model_config-derived values with the actual # encoder values from kv_cache_spec @@ -247,18 +241,6 @@ def forward( overrides={ "get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder, "get_impl_cls": lambda: WhisperCausalAttentionWithBlockPoolingImpl, - "get_kv_cache_shape": lambda num_blocks, - block_size, - num_kv_heads, - head_size, - cache_dtype_str: underlying_attn_backend.get_kv_cache_shape( - num_blocks, - # we stretch each block by `block_pool_size` - block_size * block_pool_size, - num_kv_heads // block_pool_size, - head_size, - cache_dtype_str, - ), "forward_includes_kv_cache_update": True, }, ) @@ -325,9 +307,11 @@ def __init__( def get_kv_cache_spec(self, vllm_config: VllmConfig): kv_cache_spec = super().get_kv_cache_spec(vllm_config) assert isinstance(kv_cache_spec, AttentionSpec) + assert kv_cache_spec.num_kv_heads % self.block_pool_size == 0 kv_cache_spec = replace( kv_cache_spec, - num_kv_heads=self.block_pool_size * kv_cache_spec.num_kv_heads, + block_size=kv_cache_spec.block_size * self.block_pool_size, + num_kv_heads=kv_cache_spec.num_kv_heads // self.block_pool_size, ) return kv_cache_spec diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index 55cb3d94ba67..0d51eca95fdc 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -594,6 +594,10 @@ def __init__( self.kv_cache = torch.tensor([]) + def bind_kv_cache(self, kv_cache: torch.Tensor) -> None: + # [B, H=1, N, C] -> [B, N, C] + self.kv_cache = kv_cache.squeeze(1) + def get_attn_backend(self) -> type[AttentionBackend]: return self.backend_cls @@ -607,7 +611,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: num_kv_heads=1, head_size=self.head_dim, dtype=torch.uint8, - compress_ratio=self.compress_ratio, + tokens_per_state=self.compress_ratio, cache_dtype_str=self.kv_cache_dtype, alignment=576, # NOTE: FlashMLA requires 576B alignment model_version="deepseek_v4", @@ -644,15 +648,19 @@ def __init__( raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self + def bind_kv_cache(self, kv_cache: torch.Tensor) -> None: + # [B, H=1, N, C] -> [B, N, C] + self.kv_cache = kv_cache.squeeze(1) + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: # head_dim already carries the fp8 scale padding - # compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same cache layout. + # tokens_per_state=1 for V3.2, >1 for DSV4; same cache layout. return MLAAttentionSpec( block_size=self.cache_config.block_size, num_kv_heads=1, head_size=self.head_dim, dtype=self.dtype, - compress_ratio=self.compress_ratio, + tokens_per_state=self.compress_ratio, # DeepseekV4 aligns indexer pages to FlashMLA's 576B so they can pack with # the indexer's compressor state cache. V3.2 keeps the legacy layout. alignment=576, diff --git a/vllm/models/deepseek_v4/compressor.py b/vllm/models/deepseek_v4/compressor.py index f36dc8f17629..6b4b7f79eab4 100644 --- a/vllm/models/deepseek_v4/compressor.py +++ b/vllm/models/deepseek_v4/compressor.py @@ -54,25 +54,6 @@ def get_supported_head_sizes(cls) -> list[int]: def get_builder_cls() -> type["CompressorMetadataBuilder"]: return CompressorMetadataBuilder - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - assert num_kv_heads == 1 - return (num_blocks, block_size, head_size) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - if include_num_layers_dimension: - return (0, 1, 2, 3) - return (0, 1, 2) - @dataclass class CompressorMetadata: @@ -154,6 +135,10 @@ def __init__( else: raise ValueError(f"Invalid compress ratio: {compress_ratio}") + def bind_kv_cache(self, kv_cache: torch.Tensor) -> None: + # [B, H=1, N, C] -> [B, N, C] + self.kv_cache = kv_cache.squeeze(1) + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: return SlidingWindowMLASpec( # only has one vector instead of K + V block_size=self.block_size, diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index 5c8b08d4c120..0c5a1605722a 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -95,21 +95,6 @@ def get_supported_head_sizes(cls) -> list[int]: # V3.2 default of 576 from FlashMLASparseBackend). return [512] - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if cache_dtype_str == "fp8_ds_mla": - # DeepseekV4 main MLA: 584B per token (448 NoPE + 128 RoPE + 8 fp8 scale). - # head_size passed in is the semantic head_dim (512). - return (num_blocks, block_size, 584) - else: - return (num_blocks, block_size, head_size) - class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): """FlashMLA sparse MLA implementation for DeepSeek V4's custom MLA layer.""" diff --git a/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py b/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py index ed16ca6d3b54..8ab5dbb5116b 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py @@ -1313,6 +1313,8 @@ def compress_norm_rope_store_cutedsl( token_stride: int, scale_dim: int, ) -> None: + # (B, H=1, N, C) -> (B, N, C) + kv_cache = kv_cache.squeeze(1) if compress_ratio == 4: # For C4A, the single fused kernel is faster than the two-kernel version. fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl( diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 5947bff9b080..681674d3501c 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -55,10 +55,10 @@ def get_attn_backend_cls( ) -> str: from vllm.v1.attention.backends.utils import set_kv_cache_layout - set_kv_cache_layout("NHD") + set_kv_cache_layout("LBNHC") logger.info( - "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; " - "only NHD layout is supported by XPU attention kernels." + "Setting VLLM_KV_CACHE_LAYOUT to 'LBNHC' for XPU; " + "only LBNHC layout is supported by XPU attention kernels." ) # TurboQuant KV cache: route directly to TQ backend diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 12ec5b0fcc66..3ba33ca634ab 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -417,26 +417,24 @@ def nvfp4_kv_cache_full_dim(head_size: int) -> int: return head_size // 2 + head_size // 16 -def _nvfp4_split_data_scale( +def nvfp4_split_data_scale( kv_side: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - """Split a single NVFP4 KV-side buffer into data and scale views. + """Split one side (K or V) of an NVFP4 KV cache into data and scale. - The input is a 4D tensor for one KV side (K or V) whose last - dimension is ``full_dim = data_dim + scale_dim``. The physical - layout within each side is [data | scale], both packed contiguously. + The input is a 4D uint8 tensor whose last dimension is + ``full_dim = data_dim + scale_dim``. The physical layout within + each side is ``[data | scale]``, both packed contiguously. + + The caller is responsible for slicing K and V from the combined + cache first (e.g. ``kv_cache.split(num_kv_heads, dim=1)``). Args: - kv_side: 4D uint8 tensor with shape - ``(num_pages, dim_1, dim_2, full_dim)``. - May be in any permutation order (NHD or HND). + kv_side: 4D uint8 tensor ``(B, H, N, full_dim)``. Returns: - ``(data, scale)`` where - ``data`` is a uint8 view with shape - ``(num_pages, dim_1, dim_2, data_dim)``. - ``scale`` is a float8_e4m3fn view with shape - ``(num_pages, dim_1, dim_2, scale_dim)``. + ``(data, scale)`` where *data* is uint8 and *scale* is + float8_e4m3fn, both views of the same storage. """ num_pages = kv_side.shape[0] dim_1, dim_2 = kv_side.shape[1], kv_side.shape[2] @@ -447,9 +445,6 @@ def _nvfp4_split_data_scale( data_per_kv = dim_1 * dim_2 * data_dim page_bytes = kv_side.stride(0) - # Derive inner strides from the kv_side strides, scaling by the - # ratio of the target dim to full_dim. This preserves the physical - # layout (NHD vs HND) encoded in the input tensor's strides. s1 = kv_side.stride(1) * data_dim // full_dim s2 = kv_side.stride(2) * data_dim // full_dim data_shape = (num_pages, dim_1, dim_2, data_dim) @@ -463,44 +458,15 @@ def _nvfp4_split_data_scale( base = kv_side.storage_offset() data = torch.as_strided(kv_side, data_shape, data_strides, storage_offset=base) scale = torch.as_strided( - kv_side, scale_shape, scale_strides, storage_offset=base + data_per_kv + kv_side, + scale_shape, + scale_strides, + storage_offset=base + data_per_kv, ).view(torch.float8_e4m3fn) return data, scale -def nvfp4_kv_cache_split_views(kv_cache: torch.Tensor) -> tuple[tuple, tuple]: - """Split an NVFP4 KV cache tensor into data and scale views. - - Accepts either a 5D tensor ``(num_pages, 2, dim_2, dim_3, full_dim)`` - or a 4D single-side tensor ``(num_pages, dim_2, dim_3, full_dim)``. - - Per-page layout: [K_data | K_scale | V_data | V_scale]. - Each KV side is self-contained (data followed by its scale), so the - 5D case simply splits each side independently. - - The returned views are in the same dim order as the input (NHD or - HND), so callers get views matching whichever order they passed in. - - Args: - kv_cache: 5D or 4D uint8 tensor where the last dimension is - ``full_dim = data_dim + scale_dim = 9 * head_size / 16``. - - Returns: - For 5D input: - ``(k_data, v_data), (k_scale, v_scale)`` - For 4D input (single KV side): - ``(data,), (scale,)`` - """ - if kv_cache.dim() == 4: - data, scale = _nvfp4_split_data_scale(kv_cache) - return (data,), (scale,) - - k_data, k_scale = _nvfp4_split_data_scale(kv_cache[:, 0]) - v_data, v_scale = _nvfp4_split_data_scale(kv_cache[:, 1]) - return (k_data, v_data), (k_scale, v_scale) - - def create_kv_caches_with_random_flash( num_blocks: int, block_size: int, @@ -511,14 +477,14 @@ def create_kv_caches_with_random_flash( model_dtype: str | torch.dtype | None = None, seed: int | None = None, device: str | None = "cuda", - cache_layout: str | None = "NHD", + cache_layout: str | None = "LBNHC", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: set_random_seed(seed) dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) - assert cache_layout in ("NHD", "HND") - stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) + assert cache_layout in ("LBNHC", "LBHNC") + stride_order = (0, 1, 2, 3, 4) if cache_layout == "LBNHC" else (0, 1, 3, 2, 4) kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order) scale = head_size**-0.5 diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index af58bfd31a57..6641f5dd6b1e 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -23,7 +23,6 @@ from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.platforms.interface import DeviceCapability - from vllm.v1.attention.backends.utils import KVCacheLayoutType from vllm.v1.kv_cache_interface import AttentionSpec, KVQuantMode from vllm.v1.kv_cache_interface import get_kv_quant_mode @@ -84,72 +83,14 @@ def get_impl_cls() -> type["AttentionImplBase"]: def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]: raise NotImplementedError - @staticmethod - @abstractmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - raise NotImplementedError - - @classmethod - def get_kv_cache_block_dim( - cls, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> int: - """Discover which tensor dim is the block index, since different - backends lay out dims differently.""" - _S = 1234567 - shape = cls.get_kv_cache_shape( - _S, - block_size, - num_kv_heads, - head_size, - cache_dtype_str=cache_dtype_str, - ) - return shape.index(_S) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - """ - Get the physical (memory layout) ordering of the kv cache dimensions. - e.g. if the KV cache shape is - [2, num_blocks, block_size, num_heads, head_size], - and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical - ordering of dimensions is - [num_blocks, num_heads, 2, block_size, head_size]. - - If this function is unimplemented / raises NotImplementedError, - the physical layout of the KV cache will match the logical shape. - - Args: - include_num_layers_dimension: if True, includes an additional - num_layers dimension, which is assumed to be prepended - to the logical KV cache shape. - With the above example, a return value (2, 4, 0, 1, 3, 5) - corresponds to - [num_blocks, num_heads, num_layers, 2, block_size, head_size]. - - If an additional dimension is NOT included in the returned - tuple, the physical layout will not include a layers dimension. - - Returns: - A tuple of ints which is a permutation of range(len(shape)). - """ - raise NotImplementedError - @classmethod def full_cls_name(cls) -> tuple[str, str]: return (cls.__module__, cls.__qualname__) + @classmethod + def get_required_kv_cache_layout(cls) -> str | None: + return None + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [] @@ -340,10 +281,6 @@ def validate_configuration( invalid_reasons.append(combination_reason) return invalid_reasons - @classmethod - def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": - return None - @classmethod def is_ssm(cls) -> bool: return False diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 3519691a3c58..347c2ab761f8 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -25,7 +25,6 @@ MultipleOf, ) from vllm.v1.attention.backends.utils import ( - KVCacheLayoutType, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec @@ -85,20 +84,6 @@ def get_impl_cls() -> type["CPUAttentionBackendImpl"]: def get_builder_cls() -> type["CPUAttentionMetadataBuilder"]: return CPUAttentionMetadataBuilder - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - return num_blocks, num_kv_heads, block_size, 2 * head_size - - @classmethod - def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": - return "HND" - @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: return False @@ -308,7 +293,7 @@ def forward( key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] kv_cache: shape = - [num_blocks, num_kv_heads, block_size, 2 * head_size] + [num_blocks, 2*num_kv_heads, block_size, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -337,13 +322,10 @@ def forward( self.attn_type, ) - # For decoder and cross-attention, use KV cache, size are - # [num_blocks, num_kv_heads, block_size, 2 * head_size] - # Make a view [num_blocks, num_kv_heads, block_size * 2, head_size] - # Then slice KV at dim 2 - num_blocks, num_kv_heads, block_size, _ = kv_cache.size() - kv_cache = kv_cache.view((num_blocks, num_kv_heads, block_size * 2, -1)) - key_cache, value_cache = kv_cache.chunk(2, dim=2) + # K and V are stored as separate head groups; slice them out as + # contiguous per-head tensors (see CPUModelRunner._allocate_kv_caches). + key_cache = kv_cache[:, : self.num_kv_heads] + value_cache = kv_cache[:, self.num_kv_heads :] # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index c56c4ee6e1ff..31e7c0b4413a 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -57,9 +57,6 @@ AttentionMetadataBuilder, CommonAttentionMetadata, ) -from vllm.v1.attention.backends.utils import ( - get_kv_cache_layout, -) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -136,39 +133,6 @@ def get_impl_cls() -> type["FlashAttentionImpl"]: def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]: return FlashAttentionMetadataBuilder - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (num_blocks, 2, block_size, num_kv_heads, head_size) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - # `stride_order` indicates the permutation that gets - # us from `get_kv_cache_shape` to the actual memory layout we want. - cache_layout = get_kv_cache_layout() - if cache_layout == "NHD" and include_num_layers_dimension: - # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) - return (1, 0, 2, 3, 4, 5) - elif cache_layout == "NHD": - stride_order = (0, 1, 2, 3, 4) - elif cache_layout == "HND" and include_num_layers_dimension: - # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) - return (1, 4, 0, 2, 3, 5) - elif cache_layout == "HND": - stride_order = (0, 1, 3, 2, 4) - else: - raise ValueError(f"Unknown cache layout format {cache_layout}.") - return stride_order - @classmethod def supports_head_size(cls, head_size: int) -> bool: if head_size % 8 != 0: @@ -683,7 +647,7 @@ def forward( key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] kv_cache: shape = - [num_blocks, 2, block_size, num_kv_heads, head_size] + [num_blocks, num_kv_heads, block_size, 2*head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -730,8 +694,8 @@ def forward( layer, ) - # For decoder and cross-attention, use KV cache as before - key_cache, value_cache = kv_cache.unbind(1) + # (B, H, N, 2*head_size) -> ((B, N, H, head_size), (B, N, H, head_size)) + key_cache, value_cache = kv_cache.transpose(1, 2).split(self.head_size, dim=-1) # Fix degenerate strides on size-1 dims (e.g. num_kv_heads=1 with TP). # FA3/4 on H100+ uses TMA, which requires ≥16-byte stride alignment. # See vllm.utils.torch_utils.canonicalize_singleton_dim_strides. @@ -861,8 +825,7 @@ def do_kv_cache_update( return # Scatter write into the KV cache using slot_mapping indices. - # No TMA kernel is invoked here, so stride canonicalization is not needed. - key_cache, value_cache = kv_cache.unbind(1) + k_cache, v_cache = kv_cache.transpose(1, 2).split(self.head_size, dim=-1) # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. @@ -874,8 +837,8 @@ def do_kv_cache_update( reshape_and_cache_flash( key, value, - key_cache, - value_cache, + k_cache, + v_cache, slot_mapping, self.kv_cache_dtype, layer._k_scale, @@ -1163,10 +1126,11 @@ def cascade_attention( num_tokens = query.shape[0] block_size = key_cache.shape[-3] + num_kv_heads = key_cache.shape[-2] assert common_prefix_len % block_size == 0 num_common_kv_blocks = common_prefix_len // block_size assert num_common_kv_blocks > 0 - descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) + descale_shape = (cu_prefix_query_lens.shape[0] - 1, num_kv_heads) # Process shared prefix. prefix_output, prefix_lse = flash_attn_varlen_func( @@ -1194,7 +1158,7 @@ def cascade_attention( num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits, ) - descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) + descale_shape = (cu_query_lens.shape[0] - 1, num_kv_heads) # Process suffix per query. suffix_output, suffix_lse = flash_attn_varlen_func( diff --git a/vllm/v1/attention/backends/flash_attn_diffkv.py b/vllm/v1/attention/backends/flash_attn_diffkv.py index e788b0e3496f..d5a1da20c4d6 100644 --- a/vllm/v1/attention/backends/flash_attn_diffkv.py +++ b/vllm/v1/attention/backends/flash_attn_diffkv.py @@ -21,8 +21,6 @@ if is_flash_attn_varlen_func_available(): from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func -from vllm.v1.attention.backends.utils import get_kv_cache_layout - from .flash_attn import ( FlashAttentionBackend, FlashAttentionImpl, @@ -49,48 +47,6 @@ def get_name() -> str: def get_impl_cls() -> type["FlashAttentionImpl"]: return FlashAttentionDiffKVImpl - # Do not modify the interface of get_kv_cache_shape, - # but consider head_size_v when returning result. - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return ( - num_blocks, - block_size, - num_kv_heads, - head_size + FlashAttentionDiffKVBackend.head_size_v, - ) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - # `stride_order` indicates the permutation that gets - # us from `get_kv_cache_shape` to the actual memory layout we want. - cache_layout = get_kv_cache_layout() - if cache_layout == "NHD" and include_num_layers_dimension: - # (num_blocks, num_layers, block_size, - # num_kv_heads, head_size + head_size_v) - return (1, 0, 2, 3, 4) - elif cache_layout == "NHD": - stride_order = (0, 1, 2, 3) - elif cache_layout == "HND" and include_num_layers_dimension: - # (num_blocks, num_kv_heads, num_layers, - # block_size, head_size + head_size_v) - return (1, 3, 0, 2, 4) - elif cache_layout == "HND": - stride_order = (0, 2, 1, 3) - else: - raise ValueError(f"Unknown cache layout format {cache_layout}.") - return stride_order - class FlashAttentionDiffKVImpl(FlashAttentionImpl): vllm_flash_attn_version: int | None @@ -120,21 +76,14 @@ def do_kv_cache_update( # we use direct Q, K, V tensors without caching return - # Unlike standard FlashAttn which splits kv_cache via unbind(0), # DiffKV packs K and V into a single tensor along the last dim: # kv_cache shape: [num_blocks, block_size, num_kv_heads, # head_size_k + head_size_v] - # The triton kernel handles this combined layout directly. - # - # NOTE(woosuk): key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] - # and value[:num_actual_tokens] because the reshape_and_cache_flash - # op uses the slot_mapping's shape to determine the number of - # actual tokens. + # (B, H, N, C) -> (B, N, H, C) for kernel compatibility. triton_reshape_and_cache_flash_diffkv( key, value, - kv_cache, + kv_cache.transpose(1, 2), slot_mapping, self.kv_cache_dtype, layer._k_scale, @@ -207,8 +156,8 @@ def forward( layer, ) - # For decoder and cross-attention, use KV cache as before - # Different head_size for K and V + # (B, H, N, C) -> (B, N, H, C) for kernel compatibility. + kv_cache = kv_cache.transpose(1, 2) key_cache = kv_cache[..., : self.head_size] value_cache = kv_cache[..., self.head_size :] # Fix degenerate strides on size-1 dims (e.g. num_kv_heads=1 with TP). diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 83e3072546f1..8d63cb87238f 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -46,8 +46,7 @@ canonicalize_singleton_dim_strides, is_quantized_kv_cache, is_strictly_contiguous, - nvfp4_kv_cache_full_dim, - nvfp4_kv_cache_split_views, + nvfp4_split_data_scale, ) from vllm.v1.attention.backend import ( AttentionBackend, @@ -59,12 +58,12 @@ MultipleOf, ) from vllm.v1.attention.backends.utils import ( - KVCacheLayoutType, get_dcp_local_seq_lens, - get_kv_cache_layout, + get_flashinfer_layout_string, get_num_attention_heads_from_layers, get_per_layer_parameters, infer_global_hyperparameters, + resolve_kv_cache_layout, split_decodes_and_prefills, ) from vllm.v1.attention.ops.common import cp_lse_ag_out_rs @@ -107,9 +106,11 @@ def _trtllm_prefill_attn_kvfp8_dequant( src_stride_page, src_stride_kv, src_stride_head, + src_stride_n, DST_K_CACHE_STRIDE: tl.constexpr, DST_KV_CACHE_STRIDE: tl.constexpr, - HEAD_STRIDE: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_SIZE: tl.constexpr, NUM_KV_HEADS: tl.constexpr, ): batch_idx = tl.program_id(0).to(tl.int64) @@ -125,36 +126,42 @@ def _trtllm_prefill_attn_kvfp8_dequant( v_scale_val = tl.load(v_scale_ptr) mock_page_idx = batch_idx * block_table_stride + mock_block_table_idx + 1 - head_offsets = tl.arange(0, HEAD_STRIDE) + HEAD_STRIDE: tl.constexpr = PAGE_SIZE * HEAD_SIZE + # 2D indexing: source may have non-contiguous block_size stride. + n_idx = tl.arange(0, PAGE_SIZE)[:, None] + d_idx = tl.arange(0, HEAD_SIZE)[None, :] + dst_nd = n_idx * HEAD_SIZE + d_idx for h in range(NUM_KV_HEADS): h_off = tl.cast(h, tl.int64) - # Read K from source (supports non-contiguous page/kv/head strides) - src_k = orig_page_num * src_stride_page + h_off * src_stride_head + head_offsets + src_k = ( + orig_page_num * src_stride_page + + h_off * src_stride_head + + n_idx * src_stride_n + + d_idx + ) fp8_k = tl.load(kv_cache_ptr + src_k) dequant_k = (fp8_k.to(tl.float32) * k_scale_val).to(dequant_dtype) - # Write K to contiguous mock cache - dst_k = mock_page_idx * DST_KV_CACHE_STRIDE + h * HEAD_STRIDE + head_offsets + dst_k = mock_page_idx * DST_KV_CACHE_STRIDE + h * HEAD_STRIDE + dst_nd tl.store(mock_kv_cache_ptr + dst_k, dequant_k) - # Read V from source (offset by src_stride_kv for the V half) src_v = ( orig_page_num * src_stride_page + src_stride_kv + h_off * src_stride_head - + head_offsets + + n_idx * src_stride_n + + d_idx ) fp8_v = tl.load(kv_cache_ptr + src_v) dequant_v = (fp8_v.to(tl.float32) * v_scale_val).to(dequant_dtype) - # Write V to contiguous mock cache dst_v = ( mock_page_idx * DST_KV_CACHE_STRIDE + DST_K_CACHE_STRIDE + h * HEAD_STRIDE - + head_offsets + + dst_nd ) tl.store(mock_kv_cache_ptr + dst_v, dequant_v) @@ -177,9 +184,8 @@ def trtllm_prefill_attn_kvfp8_dequant( kv_cache_stride = k_cache_stride * s[1] strides = kv_cache.stride() - assert strides[3] == head_size and strides[4] == 1, ( - "For kv cache layouts, (block_size, head_size) " - f"dimensions must be contiguous, got strides {strides}" + assert strides[4] == 1, ( + f"The head_size dimension must be contiguous, got strides {strides}" ) new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4]) @@ -203,9 +209,11 @@ def trtllm_prefill_attn_kvfp8_dequant( strides[0], strides[1], strides[2], + strides[3], k_cache_stride, kv_cache_stride, - head_stride, + block_size, + head_size, num_kv_heads, ) return mock_kv_cache, mock_block_table @@ -222,7 +230,7 @@ def __init__( else: self._dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False) self._context = BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, get_kv_cache_layout() + workspace_buffer, get_flashinfer_layout_string() ) self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer) @@ -282,7 +290,7 @@ def run( self, layer: torch.nn.Module, prefill_query: torch.Tensor, - kv_cache_permute: torch.Tensor, + kv_cache_tuple: tuple[torch.Tensor, torch.Tensor], key: torch.Tensor, value: torch.Tensor, out: torch.Tensor, @@ -292,7 +300,7 @@ def run( ) output_context_tmp, lse_context_tmp = self._context.run( prefill_query_across_dcp, - kv_cache_permute, + kv_cache_tuple, k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, return_lse=True, @@ -353,41 +361,6 @@ def get_impl_cls() -> type["FlashInferImpl"]: def get_builder_cls() -> type["FlashInferMetadataBuilder"]: return FlashInferMetadataBuilder - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if cache_dtype_str == "nvfp4": - # Packed layout: fp4 data + fp8 block scales in last dim - last_dim = nvfp4_kv_cache_full_dim(head_size) - return (num_blocks, 2, block_size, num_kv_heads, last_dim) - return (num_blocks, 2, block_size, num_kv_heads, head_size) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - # `stride_order` indicates the permutation that gets us from - # `get_kv_cache_shape` to the actual memory layout we want. - cache_layout = get_kv_cache_layout() - if cache_layout == "NHD" and include_num_layers_dimension: - # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) - return (1, 0, 2, 3, 4, 5) - elif cache_layout == "NHD": - stride_order = (0, 1, 2, 3, 4) - elif cache_layout == "HND" and include_num_layers_dimension: - # (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size) - return (1, 2, 4, 0, 3, 5) - elif cache_layout == "HND": - stride_order = (0, 1, 3, 2, 4) - else: - raise ValueError(f"Unknown cache layout format {cache_layout}.") - return stride_order - @staticmethod def get_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: if kv_cache_dtype in ("fp8", "fp8_e4m3"): @@ -426,13 +399,6 @@ def supports_sink(cls) -> bool: # Check if TRTLLM is supported on this platform return supports_trtllm_attention() - @classmethod - def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None: - capability = current_platform.get_device_capability() - if capability is not None and capability.major == 10: - return "HND" - return None - forward_includes_kv_cache_update: bool = False @@ -780,7 +746,7 @@ def _get_prefill_wrapper( backend = "trtllm-gen" if self.is_kvcache_nvfp4 else "auto" self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( self._get_workspace_buffer(), - get_kv_cache_layout(), + get_flashinfer_layout_string(), backend=backend, ) assert self._prefill_wrapper is not None @@ -806,7 +772,7 @@ def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): backend = "trtllm-gen" if self.is_kvcache_nvfp4 else "auto" decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._get_workspace_buffer(), - get_kv_cache_layout(), + get_flashinfer_layout_string(), use_cuda_graph=use_cudagraph, paged_kv_indptr_buffer=paged_kv_indptr, paged_kv_indices_buffer=paged_kv_indices, @@ -829,7 +795,7 @@ def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): def _get_cascade_wrapper(self): if self._cascade_wrapper is None: self._cascade_wrapper = MultiLevelCascadeAttentionWrapper( - 2, self._get_workspace_buffer(), get_kv_cache_layout() + 2, self._get_workspace_buffer(), get_flashinfer_layout_string() ) return self._cascade_wrapper @@ -1383,9 +1349,7 @@ def forward( query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache: KV cache tensor with different possible shapes: - - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] - - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] + kv_cache: [num_blocks, num_kv_heads, block_size, 2*head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -1476,19 +1440,9 @@ def forward( output_padded = output output = output[:num_actual_tokens] - if attn_metadata.use_cascade: - # Cascade attention (rare case). - assert attn_metadata.cascade_wrapper is not None - output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) - return output - - # When using spec decoding, num_decodes can be < num_decode_tokens - # because some decode requests may have more than one query token. - num_decode_tokens = attn_metadata.num_decode_tokens - num_prefill_tokens = attn_metadata.num_prefill_tokens - - stride_order = FlashInferBackend.get_kv_cache_stride_order() - kv_cache_permute = kv_cache.permute(*stride_order) # HND and contiguous + # Permute to FlashInfer's expected layout (metadata-only). + stride_order = resolve_kv_cache_layout().layer_stride_order + kv_cache_permute = kv_cache.permute(*stride_order) # Fix degenerate strides on any size-1 dimension (e.g. num_kv_heads=1 # with TP=8). PyTorch permits non-canonical strides on size-1 dims; # CUDA TMA requires ≥16-byte alignment on all non-outermost strides. @@ -1505,14 +1459,35 @@ def forward( ) kv_cache_permute = fixed - # For NVFP4, the kv_cache last dim is full_dim (data + scale packed). - # Split into correctly-strided data and scale views. + # Split K/V — zero-copy views. + # For nvfp4, K and V are stored as separate head groups (2*H heads + # in dim 1); for other dtypes, K and V are packed in the content dim. nvfp4_kv_data = None nvfp4_kv_block_scales = None if self.is_kvcache_nvfp4: - nvfp4_kv_data, nvfp4_kv_block_scales = nvfp4_kv_cache_split_views( - kv_cache_permute - ) + kv_cache_tuple = kv_cache_permute.split(self.num_kv_heads, dim=1) + k_data, k_sf = nvfp4_split_data_scale(kv_cache_tuple[0]) + v_data, v_sf = nvfp4_split_data_scale(kv_cache_tuple[1]) + nvfp4_kv_data = (k_data, v_data) + nvfp4_kv_block_scales = (k_sf, v_sf) + else: + kv_cache_tuple = kv_cache_permute.split(self.head_size, dim=-1) + + flashinfer_layout = get_flashinfer_layout_string() + if flashinfer_layout == "HND": + trtllm_kv_cache = tuple(t.transpose(1, 2) for t in kv_cache_tuple) + trtllm_kv_layout = "NHD" + else: + trtllm_kv_cache = kv_cache_tuple + trtllm_kv_layout = flashinfer_layout + + if attn_metadata.use_cascade: + assert attn_metadata.cascade_wrapper is not None + output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache_tuple)) + return output + + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens use_dcp = self.dcp_world_size > 1 @@ -1544,7 +1519,7 @@ def forward( prefill_wrapper.run( layer, prefill_query, - kv_cache_permute, + kv_cache_tuple, key[num_decode_tokens:], value[num_decode_tokens:], out=output[num_decode_tokens:], @@ -1561,7 +1536,9 @@ def forward( assert prefill_wrapper._causal if self.is_kvcache_nvfp4: - kv_cache_permute = nvfp4_kv_data + kv_cache_for_fi = nvfp4_kv_data + else: + kv_cache_for_fi = kv_cache_tuple kv_cache_sf = ( nvfp4_kv_block_scales if self.is_kvcache_nvfp4 else None ) @@ -1579,7 +1556,7 @@ def forward( prefill_wrapper.run( prefill_query, - kv_cache_permute, + kv_cache_for_fi, k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, out=out_prefill, @@ -1602,8 +1579,6 @@ def forward( block_tables_prefill = attn_metadata.prefill.block_tables seq_lens_prefill = attn_metadata.prefill.seq_lens - # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND - assert get_kv_cache_layout() == "HND" assert is_strictly_contiguous(prefill_query) assert is_strictly_contiguous(workspace_buffer) assert is_strictly_contiguous(block_tables_prefill) @@ -1628,6 +1603,7 @@ def forward( out = self._nvfp4_fp8_out[:num_prefill_tokens] prefill_kv_block_scales = None + prefill_kv_layout = trtllm_kv_layout if self.is_kvcache_nvfp4: # NVFP4 trtllm-gen kernel requires FP8 query. assert attn_metadata.q_data_type == FP8_DTYPE, ( @@ -1638,6 +1614,7 @@ def forward( mock_kv_cache = nvfp4_kv_data mock_block_table = block_tables_prefill prefill_kv_block_scales = nvfp4_kv_block_scales + prefill_kv_layout = flashinfer_layout elif ( attn_metadata.q_data_type != FP8_DTYPE and self.kv_cache_dtype.startswith("fp8") @@ -1645,28 +1622,22 @@ def forward( # TRTLLM prefill attention does not support BF16 Q # and fp8 kv cache. So to enable prefill attention # with fp8 kv cache, we can construct a mock block - # and mock kv cache with BF16 KV involved in the prefill - # - kv_cache_permute = canonicalize_singleton_dim_strides( - kv_cache_permute - ) - kv_strides = kv_cache_permute.stride() - assert ( - kv_strides[-1] == 1 - and kv_strides[-2] == kv_cache_permute.shape[-1] - ), ( - "KV cache inner dims (block_size, head_size) must be " - f"contiguous, got strides {kv_strides}" - ) + # and mock kv cache with BF16 KV involved in the prefill. + B_kv, H_kv, N_kv = kv_cache_permute.shape[:3] + kv_cache_5d = kv_cache_permute.view( + B_kv, H_kv, N_kv, 2, self.head_size + ).permute(0, 3, 1, 2, 4) mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant( - kv_cache_permute, + kv_cache_5d, block_tables_prefill, layer._k_scale, layer._v_scale, attn_metadata.q_data_type, ) + if trtllm_kv_layout == "NHD": + mock_kv_cache = mock_kv_cache.transpose(2, 3) else: - mock_kv_cache = kv_cache_permute + mock_kv_cache = trtllm_kv_cache mock_block_table = block_tables_prefill trtllm_batch_context_with_kv_cache( @@ -1687,6 +1658,7 @@ def forward( o_sf_scale=self.o_sf_scale, out=out, kv_cache_sf=prefill_kv_block_scales, + kv_layout=prefill_kv_layout, ) if needs_fp8_out: @@ -1707,7 +1679,9 @@ def forward( assert decode_wrapper._sm_scale == self.scale if self.is_kvcache_nvfp4: - kv_cache_permute = nvfp4_kv_data + kv_cache_for_fi = nvfp4_kv_data + else: + kv_cache_for_fi = kv_cache_tuple kv_cache_sf = nvfp4_kv_block_scales if self.is_kvcache_nvfp4 else None # NVFP4 kernel only supports FP8 output. @@ -1730,7 +1704,7 @@ def forward( ) decode_wrapper.run( decode_query, - kv_cache_permute, + kv_cache_for_fi, k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, out=output_tmp, @@ -1746,7 +1720,7 @@ def forward( else: decode_wrapper.run( decode_query, - kv_cache_permute, + kv_cache_for_fi, k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, out=out_decode, @@ -1767,21 +1741,10 @@ def forward( block_tables_decode = attn_metadata.decode.block_tables seq_lens_decode = attn_metadata.decode.seq_lens - # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND - assert get_kv_cache_layout() == "HND" assert is_strictly_contiguous(decode_query) assert is_strictly_contiguous(workspace_buffer) assert is_strictly_contiguous(block_tables_decode) assert is_strictly_contiguous(seq_lens_decode) - kv_cache_permute = canonicalize_singleton_dim_strides(kv_cache_permute) - kv_strides = kv_cache_permute.stride() - assert ( - kv_strides[-1] == 1 and kv_strides[-2] == kv_cache_permute.shape[-1] - ), ( - "KV cache inner dims (block_size, head_size) must be " - f"contiguous, got strides {kv_strides}" - ) - if output.dtype == FP4_DTYPE: assert self.o_sf_scale is not None out = FP4Tensor( @@ -1807,11 +1770,16 @@ def forward( else: q_len_per_req = num_decode_tokens // attn_metadata.num_decodes + if self.is_kvcache_nvfp4: + decode_kv_data = nvfp4_kv_data + decode_kv_layout = flashinfer_layout + else: + decode_kv_data = trtllm_kv_cache + decode_kv_layout = trtllm_kv_layout + trtllm_batch_decode_with_kv_cache( query=decode_query, - kv_cache=( - nvfp4_kv_data if self.is_kvcache_nvfp4 else kv_cache_permute - ), + kv_cache=decode_kv_data, workspace_buffer=workspace_buffer, block_tables=block_tables_decode, seq_lens=seq_lens_decode, @@ -1826,6 +1794,7 @@ def forward( kv_cache_sf=( nvfp4_kv_block_scales if self.is_kvcache_nvfp4 else None ), + kv_layout=decode_kv_layout, ) if needs_fp8_out: @@ -1848,8 +1817,11 @@ def do_kv_cache_update( # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. - k_cache = kv_cache[:, 0] - v_cache = kv_cache[:, 1] + kv_cache = kv_cache.transpose(1, 2) + if self.is_kvcache_nvfp4: + k_cache, v_cache = kv_cache.split(self.num_kv_heads, dim=-2) + else: + k_cache, v_cache = kv_cache.split(self.head_size, dim=-1) torch.ops._C_cache_ops.reshape_and_cache_flash( key, value, diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index b87014252018..f74891f11cac 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -117,24 +117,6 @@ def supports_mm_prefix(cls) -> bool: def get_impl_cls() -> type["FlexAttentionImpl"]: return FlexAttentionImpl - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - return (num_blocks, 2, block_size, num_kv_heads, head_size) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - if include_num_layers_dimension: - return (1, 0, 3, 2, 4, 5) - return (0, 2, 1, 3, 4) - @staticmethod def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]: return FlexAttentionMetadataBuilder @@ -1063,7 +1045,8 @@ def do_kv_cache_update( if self.attn_type == AttentionType.ENCODER_ONLY: return - key_cache, value_cache = kv_cache.unbind(1) + kv_cache = kv_cache.transpose(1, 2) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) torch.ops._C_cache_ops.reshape_and_cache_flash( key, value, @@ -1168,11 +1151,12 @@ def forward( else: assert self.attn_type == AttentionType.DECODER - key_cache, value_cache = kv_cache.unbind(1) + kv_cache = kv_cache.transpose(1, 2) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) # Flatten (num_blocks, block_size) into a single token dim - key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size) - value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size) + key_cache = key_cache.reshape(-1, self.num_kv_heads, self.head_size) + value_cache = value_cache.reshape(-1, self.num_kv_heads, self.head_size) query, key_tensor, value_tensor = map( lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), (query, key_cache, value_cache), diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index bd947296e8bc..e767952ebe19 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -326,6 +326,7 @@ def forward_mqa( if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError("FP8 FlashAttention MLA not yet supported") + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :] @@ -336,8 +337,8 @@ def forward_mqa( attn_out = flash_attn_varlen_func( q=q_pe, - k=k_pe_cache.unsqueeze(-2), # Add head dim of 1 - v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 + k=k_pe_cache, + v=kv_c_cache, q_v=q_nope, max_seqlen_q=max_seqlen_q, cu_seqlens_q=attn_metadata.decode.query_start_loc, diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index e98bee9d79b5..4750f0bc18c0 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -23,7 +23,6 @@ AttentionType, MultipleOf, ) -from vllm.v1.attention.backends.utils import KVCacheLayoutType logger = init_logger(__name__) @@ -91,10 +90,6 @@ def supports_combination( ) return None - @classmethod - def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": - return "HND" - g_fi_workspace = torch.zeros( FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE, @@ -189,7 +184,7 @@ def forward_mqa( o = trtllm_batch_decode_with_kv_cache_mla( query=q, - kv_cache=kv_c_and_k_pe_cache.unsqueeze(1), + kv_cache=kv_c_and_k_pe_cache, workspace_buffer=self._workspace_buffer, qk_nope_head_dim=self.qk_nope_head_dim, kv_lora_rank=self.kv_lora_rank, diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py index 842153f40396..0a77f53d7e8f 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py @@ -41,7 +41,6 @@ from vllm.v1.attention.backends.mla.sparse_utils import ( triton_convert_req_index_to_global_index, ) -from vllm.v1.attention.backends.utils import KVCacheLayoutType from vllm.v1.kv_cache_interface import AttentionSpec if TYPE_CHECKING: @@ -130,20 +129,6 @@ def supports_combination( return "FlashInfer MLA Sparse requires model with index_topk config" return None - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - return (num_blocks, block_size, head_size) - - @classmethod - def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": - return "HND" - @dataclass class FlashInferMLASparseMetadata(AttentionMetadata): @@ -350,7 +335,7 @@ def forward_mqa( o = trtllm_batch_decode_with_kv_cache_mla( query=q.unsqueeze(1), - kv_cache=kv_c_and_k_pe_cache.unsqueeze(1), + kv_cache=kv_c_and_k_pe_cache, workspace_buffer=self._workspace_buffer, qk_nope_head_dim=self.qk_nope_head_dim, kv_lora_rank=self.kv_lora_rank, diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 2f6058d69aeb..760624dfbecb 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -305,7 +305,7 @@ def forward_mqa( if is_quantized_kv_cache(self.kv_cache_dtype): o, lse = flash_mla_with_kvcache_fp8( q=q, - k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + k_cache=kv_c_and_k_pe_cache.unsqueeze(2), block_table=attn_metadata.decode.block_table, cache_seqlens=attn_metadata.decode.seq_lens, head_dim_v=self.kv_lora_rank, @@ -319,7 +319,7 @@ def forward_mqa( else: o, lse = flash_mla_with_kvcache( q=q, - k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + k_cache=kv_c_and_k_pe_cache.unsqueeze(2), block_table=attn_metadata.decode.block_table, cache_seqlens=attn_metadata.decode.seq_lens, head_dim_v=self.kv_lora_rank, diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 9140a6fccd55..8ec868be2b34 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -135,19 +135,15 @@ def is_sparse(cls) -> bool: def supports_compute_capability(cls, capability: DeviceCapability) -> bool: return capability.major in [9, 10] + +class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend): @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if cache_dtype_str == "fp8_ds_mla": - # V3.2 main MLA: 656-byte custom storage format. See module docstring. - return (num_blocks, block_size, 656) - else: - return (num_blocks, block_size, head_size) + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [256] + + @staticmethod + def get_name() -> str: + return "V4_FLASHMLA_SPARSE" @dataclass @@ -332,8 +328,7 @@ def __init__( ) self.compress_ratio = 1 if self.is_deepseek_v4: - assert hasattr(self.kv_cache_spec, "compress_ratio") - self.compress_ratio = self.kv_cache_spec.compress_ratio + self.compress_ratio = self.kv_cache_spec.tokens_per_state # Pre-allocate compressed slot mapping buffer for CUDA graph # address stability when compress_ratio > 1. if self.compress_ratio > 1: @@ -942,7 +937,7 @@ def _fp8_flash_mla_kernel( out, lse = flash_mla_with_kvcache( q=q, - k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2), + k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(2), block_table=kernel_metadata.dummy_block_table, head_dim_v=512, cache_seqlens=kernel_metadata.cache_lens, @@ -966,7 +961,7 @@ def _bf16_flash_mla_kernel( ) -> torch.Tensor: num_tokens = q.shape[0] kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( - -1, 1, kv_c_and_k_pe_cache.shape[-1] + -1, 1, 1, kv_c_and_k_pe_cache.shape[-1] ) # NOTE(Chen): kernel requires num_local_head to be a multiple of diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 2870ec9a15c0..a410b1dac441 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -132,28 +132,6 @@ def get_supported_head_sizes(cls) -> list[int]: def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]: return DeepseekV32IndexerMetadataBuilder - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - assert num_kv_heads == 1 - return (num_blocks, block_size, head_size) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - if include_num_layers_dimension: - # DeepseekV32Indexer kernels do not support cross-layer - # KV cache layout. Identity permutation keeps num_layers - # first, signaling incompatibility. - return (0, 1, 2, 3) - return (0, 1, 2) - class DeepseekV4IndexerBackend(DeepseekV32IndexerBackend): @staticmethod @@ -326,7 +304,7 @@ def __init__(self, *args, **kwargs): self.compress_ratio = 1 # Get compress_ratio for DeepseekV4 support if isinstance(self.kv_cache_spec, MLAAttentionSpec): - self.compress_ratio = self.kv_cache_spec.compress_ratio + self.compress_ratio = self.kv_cache_spec.tokens_per_state # Pre-allocate buffers for CUDA graph compatibility when if self.compress_ratio > 1: diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index a58ecf2c651f..8dbe4b468f88 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -291,16 +291,6 @@ def get_builder_cls() -> type["ROCMAiterMLASparseMetadataBuilder"]: def get_impl_cls() -> type["ROCMAiterMLASparseImpl"]: return ROCMAiterMLASparseImpl - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - return (num_blocks, block_size, head_size) - @classmethod def is_mla(cls) -> bool: return True diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index f0e444e493c4..05828c8e6d89 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -75,6 +75,10 @@ def __init__( self.block_size = 64 assert self.dtype == torch.uint8 + def bind_kv_cache(self, kv_cache: torch.Tensor) -> None: + # [B, H=1, N, C] -> [B, N, C] + self.kv_cache = kv_cache.squeeze(1) + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: return SlidingWindowMLASpec( block_size=self.block_size, @@ -120,30 +124,6 @@ def get_builder_cls() -> type["DeepseekSparseSWAMetadataBuilder"]: return DeepseekV4ROCMAiterSparseSWAMetadataBuilder return DeepseekSparseSWAMetadataBuilder - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - assert num_kv_heads == 1 - if cache_dtype_str == "fp8_ds_mla": - # DeepseekV4 SWA: 584B per token (448 NoPE + 128 RoPE + 8 fp8 scale). - # head_size passed in is the semantic head_dim (512). - return (num_blocks, block_size, 584) - else: - return (num_blocks, block_size, head_size) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - if include_num_layers_dimension: - return (0, 1, 2, 3) - return (0, 1, 2) - @dataclass class DeepseekSparseSWAMetadata: @@ -206,7 +186,7 @@ def __init__(self, *args, **kwargs): assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec) mla_spec = cast(SlidingWindowMLASpec | MLAAttentionSpec, self.kv_cache_spec) self.head_size = mla_spec.head_size # Already considered quantization. - self.compress_ratio = mla_spec.compress_ratio + self.compress_ratio = mla_spec.tokens_per_state self.block_size = mla_spec.block_size # Handle MTP: adjust decode_threshold like the indexer does diff --git a/vllm/v1/attention/backends/mla/tokenspeed_mla.py b/vllm/v1/attention/backends/mla/tokenspeed_mla.py index 6c8dedd77f27..d70a05d7a1e7 100644 --- a/vllm/v1/attention/backends/mla/tokenspeed_mla.py +++ b/vllm/v1/attention/backends/mla/tokenspeed_mla.py @@ -23,7 +23,6 @@ AttentionType, MultipleOf, ) -from vllm.v1.attention.backends.utils import KVCacheLayoutType logger = init_logger(__name__) @@ -123,10 +122,6 @@ def supports_combination( ) return None - @classmethod - def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": - return "HND" - class TokenspeedMLAImpl(MLACommonImpl[MLACommonMetadata]): def __init__( @@ -254,8 +249,7 @@ def forward_mqa( q.device, self.num_heads, self.kv_lora_rank ) - # vLLM kv_c_and_k_pe_cache is already (num_blocks, block_size, head_size). - # tokenspeed_mla_decode wants 3D — pass as-is (no unsqueeze, unlike trtllm). + # tokenspeed_mla_decode expects 3D (num_blocks, block_size, head_size). o = tokenspeed_mla_decode( query=q, kv_cache=kv_c_and_k_pe_cache, diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index c2aa5edccb66..77b88c44b5d0 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -188,7 +188,6 @@ def forward_mqa( device=q.device, ) - # Add a head dim of 1 kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] PAGE_SIZE = kv_c_and_k_pe_cache.size(1) diff --git a/vllm/v1/attention/backends/mla/xpu_mla_sparse.py b/vllm/v1/attention/backends/mla/xpu_mla_sparse.py index 2fa91d018388..64cc9901b9ea 100644 --- a/vllm/v1/attention/backends/mla/xpu_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/xpu_mla_sparse.py @@ -66,16 +66,6 @@ def is_mla(cls) -> bool: def is_sparse(cls) -> bool: return True - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - return (num_blocks, block_size, head_size) - @classmethod def get_supported_head_sizes(cls) -> list[int]: return [576] @@ -207,7 +197,7 @@ def _forward_bf16_kv( ) -> torch.Tensor: num_tokens = q.shape[0] kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( - -1, 1, kv_c_and_k_pe_cache.shape[-1] + -1, 1, 1, kv_c_and_k_pe_cache.shape[-1] ) topk_indices = topk_indices.view(num_tokens, 1, -1) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index a9fa45debcfc..b810c1c7a268 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -754,18 +754,6 @@ def get_impl_cls() -> type["AiterFlashAttentionImpl"]: def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]: return AiterFlashAttentionMetadataBuilder - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (num_blocks, 2, block_size, num_kv_heads, head_size) - @classmethod def supports_compute_capability(cls, capability: DeviceCapability) -> bool: from vllm.platforms.rocm import on_mi3xx @@ -1057,7 +1045,8 @@ def forward( # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens - key_cache, value_cache = kv_cache.unbind(1) + kv_cache = kv_cache.transpose(1, 2) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) if is_quantized_kv_cache(self.kv_cache_dtype): key_cache = key_cache.view(current_platform.fp8_dtype()) @@ -1384,7 +1373,8 @@ def do_kv_cache_update( kv_cache: torch.Tensor, slot_mapping: torch.Tensor, ): - key_cache, value_cache = kv_cache.unbind(1) + kv_cache = kv_cache.transpose(1, 2) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached @@ -1452,7 +1442,8 @@ def do_rope_and_kv_cache_update( kv_cache: torch.Tensor, layer_slot_mapping: torch.Tensor, ): - key_cache, value_cache = kv_cache.unbind(1) + kv_cache = kv_cache.transpose(1, 2) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) flash_layout = True is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype) diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index 984fc20ecaff..bf2b934f6bd0 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -64,18 +64,6 @@ def get_name() -> str: def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]: return RocmAiterUnifiedAttentionImpl - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (num_blocks, 2, block_size, num_kv_heads, head_size) - @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: return False @@ -134,30 +122,6 @@ def __init__( self.unified_attention = unified_attention self.supports_quant_query_input = True - def _split_kv_cache( - self, kv_cache: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - if self.attn_type != AttentionType.ENCODER_DECODER: - return kv_cache.unbind(1) - - # NOTE: Encoder-decoder layers can share the same raw KV allocation with - # ROCM_ATTN decoder layers, whose physical layout is K/V first. Keep - # this cross-attention path on that physical layout so block IDs do not - # alias different bytes across the shared allocation. - num_blocks, _, block_size, num_kv_heads, head_size = kv_cache.shape - block_stride = block_size * num_kv_heads * head_size - kv_cache = kv_cache.as_strided( - (2, num_blocks, block_size, num_kv_heads, head_size), - ( - num_blocks * block_stride, - block_stride, - num_kv_heads * head_size, - head_size, - 1, - ), - ) - return kv_cache.unbind(0) - def forward( self, layer: torch.nn.Module, @@ -177,7 +141,7 @@ def forward( key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] kv_cache: shape = - [num_blocks, 2, block_size, num_kv_heads, head_size] + [num_blocks, num_kv_heads, block_size, 2 * head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -218,7 +182,8 @@ def forward( layer, ) - key_cache, value_cache = self._split_kv_cache(kv_cache) + kv_cache = kv_cache.transpose(1, 2) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) softmax_scale = self.scale if is_quantized_kv_cache(self.kv_cache_dtype): @@ -267,7 +232,8 @@ def do_kv_cache_update( # For encoder attention, # we use direct Q, K, V tensors without caching return - key_cache, value_cache = self._split_kv_cache(kv_cache) + kv_cache = kv_cache.transpose(1, 2) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) # Reshape the input keys and values and store them in the cache. ops.reshape_and_cache_flash( @@ -300,7 +266,8 @@ def do_rope_and_kv_cache_update( # For encoder attention, # we use direct Q, K, V tensors without caching return - key_cache, value_cache = self._split_kv_cache(kv_cache) + kv_cache = kv_cache.transpose(1, 2) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) flash_layout = True is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 2f6c48e3df34..74007e73ee03 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -29,7 +29,6 @@ ) from vllm.v1.attention.ops.chunked_prefill_paged_decode import ( chunked_prefill_paged_decode, - has_native_kv_cache_layout, ) from vllm.v1.attention.ops.paged_attn import PagedAttention from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( @@ -128,6 +127,7 @@ def build( use_cascade = common_prefix_len > 0 + prefix_scheduler_metadata = None if use_cascade: cu_prefix_query_lens = torch.tensor( [0, num_actual_tokens], dtype=torch.int32, device=self.device @@ -239,18 +239,6 @@ def supports_attn_type(cls, attn_type: str) -> bool: AttentionType.ENCODER_ONLY, ) - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) - @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: return False @@ -376,7 +364,7 @@ def forward( key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] kv_cache: shape = - [2, num_blocks, block_size, num_kv_heads, head_size] + [num_blocks, num_kv_heads, block_size, 2 * head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -467,44 +455,19 @@ def do_kv_cache_update( ): if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): return - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size + kv_cache_transposed = kv_cache.transpose(1, 2) + key_cache, value_cache = kv_cache_transposed.split(self.head_size, dim=-1) + triton_reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, ) - # Reshape the input keys and values and store them in the cache. - # Get the actual block_size from value_cache - # value_cache shape: [num_blocks, num_heads, head_size, block_size] - block_size = value_cache.shape[3] - has_native_layout = has_native_kv_cache_layout(key_cache, value_cache) - - if block_size in (16, 32) and has_native_layout: - # Normal 16, 32 with contiguous blocks: use vLLM native HIP C++ logic. - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - else: - # Non-standard blocks and hybrid attention/Mamba layouts need the - # stride-aware Triton writer. The native reshape_and_cache kernel - # assumes contiguous block storage and writes to the wrong hybrid - # cache blocks. - triton_reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - def fused_rope_kvcache_supported(self): return rocm_aiter_ops.is_enabled() @@ -522,12 +485,12 @@ def do_rope_and_kv_cache_update( ): if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): return - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, - layer.num_kv_heads, # type: ignore[attr-defined] - layer.head_size, # type: ignore[attr-defined] + kv_cache_transposed = kv_cache.transpose(1, 2) + key_cache, value_cache = kv_cache_transposed.split( + self.head_size, + dim=-1, ) - flash_layout = False + flash_layout = True is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype) if is_fp8_kv_cache: diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 008b74c9ff70..832ce0b5e4ea 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -31,7 +31,6 @@ MultipleOf, ) from vllm.v1.attention.backends.utils import ( - get_kv_cache_layout, get_num_attention_heads_from_layers, ) from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd @@ -44,7 +43,6 @@ AttentionSpec, KVQuantMode, get_kv_quant_mode, - kv_cache_uses_per_token_head_scales, ) logger = init_logger(__name__) @@ -309,51 +307,6 @@ def supports_batch_invariance(cls) -> bool: def get_impl_cls() -> type["TritonAttentionImpl"]: return TritonAttentionImpl - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - if kv_cache_uses_per_token_head_scales(cache_dtype_str): - # Pad head_size by sizeof(float32)/sizeof(cache_dtype) so - # the per-head scale fits inline. The backend extracts - # data[:head_size] and scale[head_size:] via typed views. - from vllm.utils.torch_utils import ( - STR_DTYPE_TO_TORCH_DTYPE, - get_dtype_size, - ) - - cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype_str] - scale_pad = get_dtype_size(torch.float32) // get_dtype_size(cache_dtype) - return (num_blocks, 2, block_size, num_kv_heads, head_size + scale_pad) - return (num_blocks, 2, block_size, num_kv_heads, head_size) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - # `stride_order` indicates the permutation that gets - # us from `get_kv_cache_shape` to the actual memory layout we want. - cache_layout = get_kv_cache_layout() - if cache_layout == "NHD" and include_num_layers_dimension: - # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) - return (1, 0, 2, 3, 4, 5) - elif cache_layout == "NHD": - stride_order = (0, 1, 2, 3, 4) - elif cache_layout == "HND" and include_num_layers_dimension: - # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) - return (1, 4, 0, 2, 3, 5) - elif cache_layout == "HND": - stride_order = (0, 1, 3, 2, 4) - else: - raise ValueError(f"Unknown cache layout: {cache_layout}") - return stride_order - @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: return False @@ -583,10 +536,11 @@ def forward( layer, ) - # Per-token-head quantized KV cache: use separate scale caches. + # (B, H, N, C) -> (B, N, H, C) for kernel compatibility. + kv_cache = kv_cache.transpose(1, 2) if self._is_per_token_head_quant: self._ensure_scale_caches(kv_cache) - key_cache, value_cache = kv_cache.unbind(1) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) if key_cache.dtype == torch.uint8: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) @@ -597,7 +551,7 @@ def forward( v_scale_cache = self._v_scale_cache # FP8 per-tensor / auto path (original flow). else: - key_cache, value_cache = kv_cache.unbind(1) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) if ( is_quantized_kv_cache(self.kv_cache_dtype) and key_cache.dtype != self.fp8_dtype @@ -730,9 +684,10 @@ def do_kv_cache_update( # we use direct Q, K, V tensors without caching return # Reshape the input keys and values and store them in the cache. + kv_cache = kv_cache.transpose(1, 2) if self._is_per_token_head_quant: self._ensure_scale_caches(kv_cache) - key_cache, value_cache = kv_cache.unbind(1) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) if key_cache.dtype == torch.uint8: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) @@ -747,7 +702,7 @@ def do_kv_cache_update( ) return # For decoder and cross-attention, use KV cache as before. - key_cache, value_cache = kv_cache.unbind(1) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) if is_quantized_kv_cache(self.kv_cache_dtype): key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) @@ -779,7 +734,8 @@ def do_rope_and_kv_cache_update( kv_cache: torch.Tensor, layer_slot_mapping: torch.Tensor, ): - key_cache, value_cache = kv_cache.unbind(1) + kv_cache = kv_cache.transpose(1, 2) + key_cache, value_cache = kv_cache.split(self.head_size, dim=-1) flash_layout = True is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype) diff --git a/vllm/v1/attention/backends/turboquant_attn.py b/vllm/v1/attention/backends/turboquant_attn.py index 3bf3b6b82482..18bd29723520 100644 --- a/vllm/v1/attention/backends/turboquant_attn.py +++ b/vllm/v1/attention/backends/turboquant_attn.py @@ -108,6 +108,10 @@ class TurboQuantAttentionBackend(AttentionBackend): def get_name() -> str: return "TURBOQUANT" + @classmethod + def get_required_kv_cache_layout(cls) -> str | None: + return "LBNHC" + @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: return [16, 32, 64, 128] @@ -128,38 +132,6 @@ def get_impl_cls() -> type["TurboQuantAttentionImpl"]: def get_builder_cls() -> type["TurboQuantMetadataBuilder"]: return TurboQuantMetadataBuilder - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "turboquant_4bit_nc", - ) -> tuple[int, ...]: - """Combined K+V cache shape — no leading 2 dimension. - - Standard attention backends use (2, num_blocks, block_size, num_kv_heads, - head_dim) with a leading 2 to separate K and V. TurboQuant packs K+V - into a single interleaved slot per head per position, so the cache is: - - (num_blocks, block_size, num_kv_heads, slot_size_aligned) - - Each slot = [key_packed | value_packed | padding]. - This is safe because TQ has its own get_kv_cache_shape override and - never shares cache tensors with other backends. Layers that fall back - to native dtype via kv_cache_dtype_skip_layers get their own - standard-shaped cache allocation. - - head_size is the model's real head_dim. slot_size_aligned is computed - from the TQ config to ensure correct cache allocation for all head dims. - """ - from vllm.model_executor.layers.quantization.turboquant.config import ( - TurboQuantConfig, - ) - - tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype_str, head_size) - return (num_blocks, block_size, num_kv_heads, tq_config.slot_size_aligned) - @classmethod def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: if kv_cache_dtype is None: @@ -383,6 +355,8 @@ def do_kv_cache_update( k = key[:N].view(N, self.num_kv_heads, self.head_size) v = value[:N].view(N, self.num_kv_heads, self.head_size) + # (B, H, N, C) -> (B, N, H, C) for TQ kernels + kv_cache = kv_cache.transpose(1, 2) self._store_kv(k, v, kv_cache, slot_mapping, layer) def forward( @@ -410,13 +384,15 @@ def forward( if attn_metadata is None: return output.fill_(0) + # (B, H, N, C) -> (B, N, H, C) for TQ kernels + kv_cache = kv_cache.transpose(1, 2) + # Slice to actual tokens N = attn_metadata.num_actual_tokens if N <= 0: return output.fill_(0) q = query[:N].view(N, self.num_heads, self.head_size) - # Get TQ buffers, ensure on device (one-time migration). # Use Any-typed alias for dynamic _tq_* attrs set by _ensure_on_device. tq_layer: Any = layer @@ -538,7 +514,7 @@ def _store_kv( self, key: torch.Tensor, # (N, Hk, D) value: torch.Tensor, # (N, Hk, D) - kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size) + kv_cache: torch.Tensor, # (num_blocks, Hk, block_size, slot_size) slot_mapping: torch.Tensor, layer: Any, ): @@ -564,7 +540,7 @@ def _prefill_attention( query: torch.Tensor, # (N, Hq, D) key: torch.Tensor, # (N, Hk, D) value: torch.Tensor, # (N, Hk, D) - kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size) + kv_cache: torch.Tensor, # (num_blocks, Hk, block_size, slot_size) attn_metadata: TurboQuantMetadata, Pi: torch.Tensor, centroids: torch.Tensor, @@ -715,7 +691,7 @@ def _continuation_prefill( query: torch.Tensor, # (q_len, Hq, D) key_chunk: torch.Tensor, # (q_len, Hk, D) val_chunk: torch.Tensor, # (q_len, Hk, D) - kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size) + kv_cache: torch.Tensor, # (num_blocks, Hk, block_size, slot_size) block_table: torch.Tensor, # (1, max_num_blocks) cached_len: int, seq_len: int, @@ -730,7 +706,7 @@ def _continuation_prefill( q_len, Hq, D = query.shape Hk = key_chunk.shape[1] device = query.device - block_size = kv_cache.shape[1] + block_size = kv_cache.shape[2] BLOCK_D = triton.next_power_of_2(D) mse_bytes = self._mse_bytes @@ -767,8 +743,8 @@ def _continuation_prefill( v_cached.stride(1), v_cached.stride(2), kv_cache.stride(0), - kv_cache.stride(1), kv_cache.stride(2), + kv_cache.stride(1), block_table.stride(0), HEAD_DIM=D, BLOCK_SIZE=block_size, @@ -858,7 +834,7 @@ def _continuation_prefill( def _decode_attention( self, query: torch.Tensor, # (B, Hq, D) - kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size) + kv_cache: torch.Tensor, # (num_blocks, Hk, block_size, slot_size) attn_metadata: TurboQuantMetadata, Pi: torch.Tensor, centroids: torch.Tensor, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index b73d17e8e5cc..54a8596a0ebc 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -8,7 +8,6 @@ Any, Literal, Protocol, - get_args, ) import numpy as np @@ -17,7 +16,11 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils.math_utils import cdiv -from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec +from vllm.v1.kv_cache_interface import ( + KVCacheLayout, + KVCacheSpec, + MambaSpec, +) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -38,51 +41,87 @@ ) logger = init_logger(__name__) -KVCacheLayoutType = Literal["NHD", "HND"] + +# Deprecated: use resolve_kv_cache_layout() instead (RFC #42082). +KVCacheLayoutType = Literal["LBNHC", "LBHNC", "BLHNC", "BHLNC"] _KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None PAD_SLOT_ID = -1 NULL_BLOCK_ID = 0 -def is_valid_kv_cache_layout(value: str) -> bool: - return value in get_args(KVCacheLayoutType) +_LAYOUT_COMPAT_ALIASES = { + "NHD": "LBNHC", + "HND": "LBHNC", + "NHC": "LBNHC", + "HNC": "LBHNC", +} +_FLASHINFER_LAYOUT_NAMES = {"LBNHC": "NHD", "LBHNC": "HND"} -@functools.lru_cache -def get_kv_cache_layout(): - # Format specified by the code. - global _KV_CACHE_LAYOUT_OVERRIDE +def is_valid_kv_cache_layout(value: str) -> bool: + return value in KVCacheLayout.__members__ or value in _LAYOUT_COMPAT_ALIASES - cache_layout: Literal["NHD", "HND"] | None = None - if _KV_CACHE_LAYOUT_OVERRIDE is not None: - cache_layout = _KV_CACHE_LAYOUT_OVERRIDE - logger.info_once( - "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " - "Setting KV cache layout to %s.", - cache_layout, - ) - return cache_layout - # Format specified by the user. - cache_layout = envs.VLLM_KV_CACHE_LAYOUT - # When neither the user nor the override specified a layout, get default - if cache_layout is None: - cache_layout = get_kv_connector_cache_layout() - else: - assert is_valid_kv_cache_layout(cache_layout) - logger.info_once( - "`VLLM_KV_CACHE_LAYOUT` environment variable " - "detected. Setting KV cache layout to %s.", - cache_layout, - ) - return cache_layout +def get_flashinfer_layout_string() -> str: + """Return the layout name in FlashInfer's convention (NHD/HND).""" + name = resolve_kv_cache_layout().name + return _FLASHINFER_LAYOUT_NAMES.get(name, name) -def set_kv_cache_layout(cache_layout: KVCacheLayoutType | None): +def set_kv_cache_layout(cache_layout: "KVCacheLayoutType | None"): + """Override the KV cache layout (for tests and platform constraints).""" global _KV_CACHE_LAYOUT_OVERRIDE _KV_CACHE_LAYOUT_OVERRIDE = cache_layout - get_kv_cache_layout.cache_clear() + resolve_kv_cache_layout.cache_clear() + + +@functools.lru_cache +def resolve_kv_cache_layout( + attn_backends: tuple[type[AttentionBackend], ...] | None = None, +) -> KVCacheLayout: + """Resolve the physical KV cache layout from the config priority chain. + + Priority: + 1. Runtime override (set_kv_cache_layout, used by tests) + 2. VLLM_KV_CACHE_LAYOUT env var (user override) + 3. Connector's get_required_kvcache_layout() preference + 4. "LBHNC" fallback + """ + global _KV_CACHE_LAYOUT_OVERRIDE + layout_name: str | None + if _KV_CACHE_LAYOUT_OVERRIDE is not None: + layout_name = _KV_CACHE_LAYOUT_OVERRIDE + else: + layout_name = envs.VLLM_KV_CACHE_LAYOUT + + if layout_name is None and attn_backends is not None: + required_layouts = set( + backend.get_required_kv_cache_layout() for backend in attn_backends + ) + required_layouts.discard(None) + if len(required_layouts) > 1: + raise ValueError( + f"Multiple required KV cache layouts: {required_layouts}. " + f"All backends must use the same layout." + ) + if len(required_layouts) == 1: + layout_name = required_layouts.pop() + + if layout_name is None: + layout_name = get_kv_connector_cache_layout() + + layout_name = layout_name or "LBHNC" + layout_name = _LAYOUT_COMPAT_ALIASES.get(layout_name, layout_name) + try: + layout = KVCacheLayout[layout_name] + except KeyError: + raise ValueError( + f"Unknown KV cache layout {layout_name!r}. " + f"Valid layouts: {[m.name for m in KVCacheLayout]}" + ) from None + logger.info_once("Resolved KV cache layout: %s", layout) + return layout @dataclass diff --git a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py index 77eb3ac60b1f..68224b63a150 100644 --- a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @@ -27,9 +27,10 @@ def has_native_kv_cache_layout( ) -> bool: """Return whether KV cache blocks can use the native ROCm pairing. - The native reshape_and_cache writer assumes packed blocks. If cache update - needs reshape_and_cache_flash for a stride-padded hybrid layout, decode - should use the matching Triton path too. + The C++ ``ops.paged_attention_rocm`` custom kernel requires each block + to be contiguous in memory. Returns False for stride-padded hybrid + layouts and for the unified KV cache (RFC #42082, see + :meth:`PagedAttention.split_kv_cache`), routing them to Triton. """ return ( key_cache.stride(0) == key_cache.shape[1:].numel() @@ -415,7 +416,7 @@ def chunked_prefill_paged_decode( "Cannot use ROCm custom paged attention kernel," " falling back to Triton implementation." ) - real_block_size = value_cache.shape[3] + real_block_size = value_cache.shape[2] # The standard model directly uses the original block_size. # Non-standard 544 uses 32 to accommodate integer division logic. # Cap at 128 to avoid exceeding GPU shared memory limits @@ -479,8 +480,8 @@ def chunked_prefill_paged_decode( stride_k_cache_4=key_cache.stride(4), stride_v_cache_0=value_cache.stride(0), stride_v_cache_1=value_cache.stride(1), - stride_v_cache_2=value_cache.stride(2), - stride_v_cache_3=value_cache.stride(3), + stride_v_cache_2=value_cache.stride(3), + stride_v_cache_3=value_cache.stride(2), filter_by_query_len=True, query_start_len_ptr=query_start_loc, USE_SINKS=sinks is not None, diff --git a/vllm/v1/attention/ops/paged_attn.py b/vllm/v1/attention/ops/paged_attn.py index 896e929b5433..6e54294902f9 100644 --- a/vllm/v1/attention/ops/paged_attn.py +++ b/vllm/v1/attention/ops/paged_attn.py @@ -19,13 +19,16 @@ def split_kv_cache( num_kv_heads: int, head_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: + # [B, H, N, 2*C] -> key [B, H, C//x, N, x], value [B, H, C, N] x = 16 // kv_cache.element_size() - num_blocks = kv_cache.shape[1] - - key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x) - value_cache = kv_cache[1] - value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + key_slice = kv_cache[..., :head_size] + value_slice = kv_cache[..., head_size:] + key_cache = ( + key_slice.permute(0, 1, 3, 2) + .unflatten(2, (head_size // x, x)) + .transpose(3, 4) + ) + value_cache = value_slice.permute(0, 1, 3, 2) return key_cache, value_cache @staticmethod diff --git a/vllm/v1/attention/selector.py b/vllm/v1/attention/selector.py index 5d63553a7b60..171002eb8cc6 100644 --- a/vllm/v1/attention/selector.py +++ b/vllm/v1/attention/selector.py @@ -8,15 +8,12 @@ import vllm.envs as envs from vllm.config.cache import CacheDType -from vllm.logger import init_logger from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.v1.attention.backend import AttentionBackend, AttentionType from vllm.v1.attention.backends.registry import ( MambaAttentionBackendEnum, ) -logger = init_logger(__name__) - class AttentionSelectorConfig(NamedTuple): head_size: int @@ -128,19 +125,6 @@ def _cached_get_attn_backend( f"Invalid attention backend for {current_platform.device_name}" ) backend = resolve_obj_by_qualname(attention_cls) - - # Adjust kv cache layout if the selected backend requires a specific one - required_layout = backend.get_required_kv_cache_layout() - if required_layout is not None: - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - set_kv_cache_layout(required_layout) - logger.info( - "Using %s KV cache layout for %s backend.", - required_layout, - backend.get_name(), - ) - return backend diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 7f3a5e4fdf3f..44386046077d 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -10,7 +10,7 @@ from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass, replace from functools import partial -from typing import Any, NewType, TypeAlias, cast, overload +from typing import Any, NewType, TypeAlias, overload from vllm import envs from vllm.config import VllmConfig @@ -19,12 +19,15 @@ from vllm.utils.math_utils import cdiv, round_up from vllm.utils.mem_utils import format_gib from vllm.utils.torch_utils import get_dtype_size +from vllm.v1.attention.backends.utils import resolve_kv_cache_layout from vllm.v1.kv_cache_interface import ( + AttentionSpec, ChunkedLocalAttentionSpec, FullAttentionSpec, HiddenStateCacheSpec, KVCacheConfig, KVCacheGroupSpec, + KVCacheLayout, KVCacheSpec, KVCacheTensor, MambaSpec, @@ -912,24 +915,8 @@ def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int: `available_memory` into `num_blocks`. Used to compute the effective KV cache capacity once `num_gpu_blocks_override` is applied. """ - if len(kv_cache_groups) == 1 and isinstance( - kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs - ): - return kv_cache_groups[0].kv_cache_spec.page_size_bytes - if all( - isinstance(g.kv_cache_spec, UniformTypeKVCacheSpecs) for g in kv_cache_groups - ): - # DeepseekV4: shared layout sized by the largest per-page-size bucket. - full_mla_spec = cast(UniformTypeKVCacheSpecs, kv_cache_groups[0].kv_cache_spec) - layer_tuple_page_bytes = sum(full_mla_spec.get_page_sizes()) - num_layer_tuples = max( - cast(UniformTypeKVCacheSpecs, g.kv_cache_spec).get_num_layer_tuples() - for g in kv_cache_groups - ) - return layer_tuple_page_bytes * num_layer_tuples - group_size = max(len(g.layer_names) for g in kv_cache_groups) - page_size = get_uniform_page_size([g.kv_cache_spec for g in kv_cache_groups]) - return page_size * group_size + buckets = _bucket_layers_by_page_size(kv_cache_groups) + return sum(ps * len(slots) for ps, slots in buckets.items()) def get_num_blocks( @@ -1176,61 +1163,79 @@ def _get_kv_cache_groups_uniform_page_size( return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) -def _get_kv_cache_config_deepseek_v4( - vllm_config: VllmConfig, +def _bucket_layers_by_page_size( kv_cache_groups: list[KVCacheGroupSpec], - available_memory: int, -) -> tuple[int, list[KVCacheTensor]]: - """DeepseekV4 KV cache tensor layout planning. - - Precondition: kv_cache_groups[0] is the full-MLA group; its page sizes - define the canonical bucket set. Non-full-MLA groups must have been - page_size-padded upstream (see _get_kv_cache_groups_uniform_groups) so - every layer's page_size matches one of the full-MLA bucket sizes. - - For each group, bucket its layers by page_size_bytes and place each - layer at tuple_idx = position-within-bucket. Emit one KVCacheTensor - per (tuple_idx, bucket) whose shared_by is the union of per-group - layers at that slot. - """ - full_mla_spec = kv_cache_groups[0].kv_cache_spec - assert isinstance(full_mla_spec, UniformTypeKVCacheSpecs) - page_sizes = sorted(full_mla_spec.get_page_sizes()) - layer_tuple_page_bytes = sum(page_sizes) - - # Pre-bucket each group's layers by page_size (registration order within - # bucket). bucketed[g_idx][page_size] = [layer_name, ...]. - bucketed: list[dict[int, list[str]]] = [] +) -> dict[int, list[list[str]]]: + """Bucket layers by page size: ``result[ps][slot_idx] = [layer_names]``. + + Layers from different groups at the same ``slot_idx`` share a block + (they have independent block tables so block-id namespaces never collide). + """ + buckets: dict[int, list[list[str]]] = defaultdict(list) for group in kv_cache_groups: - assert isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs) - specs = group.kv_cache_spec.kv_cache_specs - b: dict[int, list[str]] = defaultdict(list) - for name in group.layer_names: - b[specs[name].page_size_bytes].append(name) - bucketed.append(b) - - # num_layer_tuples = longest bucket list across all groups. For the - # full-MLA group this equals the count of layers in the largest - # per-page-size bucket (= get_num_layer_tuples()); for SWA sub-groups - # this equals the sub-group size (each has a single page_size). - num_layer_tuples = max(len(layers) for b in bucketed for layers in b.values()) - - num_blocks = available_memory // (layer_tuple_page_bytes * num_layer_tuples) - num_blocks = may_override_num_blocks(vllm_config, num_blocks) + spec = group.kv_cache_spec + slot_count: dict[int, int] = defaultdict(int) + for layer_name in group.layer_names: + if isinstance(spec, UniformTypeKVCacheSpecs): + ps = spec.kv_cache_specs[layer_name].page_size_bytes + else: + ps = spec.page_size_bytes + slot_idx = slot_count[ps] + slot_count[ps] += 1 + if slot_idx == len(buckets[ps]): + buckets[ps].append([]) + buckets[ps][slot_idx].append(layer_name) + return buckets + + +def _get_per_layer_spec( + group: KVCacheGroupSpec, + layer_name: str, +) -> KVCacheSpec: + """Return the concrete KVCacheSpec for a specific layer.""" + spec = group.kv_cache_spec + if isinstance(spec, UniformTypeKVCacheSpecs): + return spec.kv_cache_specs[layer_name] + return spec + + +def _validate_layout_compatibility( + kv_cache_tensors: list[KVCacheTensor], + kv_cache_groups: list[KVCacheGroupSpec], +) -> None: + """Ensure [H, N, C] is contiguous when layers sharing a tensor differ.""" + layout = resolve_kv_cache_layout() - kv_cache_tensors: list[KVCacheTensor] = [] - for tuple_idx in range(num_layer_tuples): - for ps in page_sizes: - shared_by: list[str] = [] - for b in bucketed: - bucket = b.get(ps) - if bucket is not None and tuple_idx < len(bucket): - shared_by.append(bucket[tuple_idx]) - kv_cache_tensors.append( - KVCacheTensor(size=ps * num_blocks, shared_by=shared_by) - ) + # The per-layer-per-block content [H, N, C] is contiguous therefore + # these layouts are always valid. + if layout in (KVCacheLayout.LBHNC, KVCacheLayout.BLHNC): + return - return num_blocks, kv_cache_tensors + layer_spec_map: dict[str, KVCacheSpec] = {} + for group in kv_cache_groups: + for layer_name in group.layer_names: + layer_spec_map[layer_name] = _get_per_layer_spec(group, layer_name) + + for tensor in kv_cache_tensors: + all_layer_names = [name for slot in tensor.shared_by for name in slot] + if len(all_layer_names) <= 1: + continue + attn_specs = [ + layer_spec_map[n] + for n in all_layer_names + if isinstance(layer_spec_map[n], AttentionSpec) + ] + if len(attn_specs) <= 1: + continue + shapes = {(s.num_heads, s.state_content_size_bytes) for s in attn_specs} + if len(shapes) > 1: + raise ValueError( + f"Groups {kv_cache_groups} share a KVCacheTensor but" + f" have different (num_heads, state_content_size_bytes)" + f" for layers {all_layer_names}. Use a layout where" + f" [H, N, C] is contiguous (e.g." + f" VLLM_KV_CACHE_LAYOUT=LBHNC or BLHNC)." + ) def get_kv_cache_config_from_groups( @@ -1238,8 +1243,7 @@ def get_kv_cache_config_from_groups( kv_cache_groups: list[KVCacheGroupSpec], available_memory: int, ) -> KVCacheConfig: - """ - Generate the KV cache configuration from the KV cache groups and spec + """Generate the KV cache configuration from the KV cache groups and spec of each layer. Args: @@ -1258,61 +1262,24 @@ def get_kv_cache_config_from_groups( kv_cache_groups=kv_cache_groups, ) - # Determine how model runners should initialize the KV cache tensors. - if len(kv_cache_groups) == 1 and isinstance( - kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs - ): - # Special case: all layers have the same type of KV cache but with - # different hidden sizes. Allocate different amount of memory for each - # layer based on its hidden size. - num_blocks = ( - available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes - ) - num_blocks = may_override_num_blocks(vllm_config, num_blocks) - per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs - kv_cache_tensors = [ + buckets = _bucket_layers_by_page_size(kv_cache_groups) + + page_sizes = list(buckets.keys()) + bytes_per_block = sum(ps * len(buckets[ps]) for ps in page_sizes) + num_blocks = available_memory // bytes_per_block + num_blocks = may_override_num_blocks(vllm_config, num_blocks) + + kv_cache_tensors: list[KVCacheTensor] = [] + for ps in page_sizes: + shared_by = buckets[ps] + kv_cache_tensors.append( KVCacheTensor( - size=per_layer_specs[layer_name].page_size_bytes * num_blocks, - shared_by=[layer_name], + size=ps * num_blocks * len(shared_by), + shared_by=shared_by, ) - for layer_name in kv_cache_groups[0].layer_names - ] - elif all( - isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs) - for group in kv_cache_groups - ): - # DeepseekV4: UniformTypeKVCacheSpecs but multiple groups. - # Delegate to the DeepseekV4-specific allocator. - num_blocks, kv_cache_tensors = _get_kv_cache_config_deepseek_v4( - vllm_config, kv_cache_groups, available_memory - ) - else: - # General case: - # We will have group_size memory pools, each is shared by one layer from - # each group. As layers of different groups have different block table, - # they will use different parts of the shared Tensor. - # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2), - # (sw.1, padding) will be: (group_size = 2) - # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 - # full.1, sw.2: share another Tensor with size=available_memory//2 - group_size = max(len(group.layer_names) for group in kv_cache_groups) - - page_size = get_uniform_page_size( - [group.kv_cache_spec for group in kv_cache_groups] ) - assert group_size > 0, "group_size must be greater than 0" - num_blocks = get_num_blocks( - vllm_config, group_size, available_memory, page_size - ) - kv_cache_tensors = [] - for i in range(group_size): - shared_by = [] - for j in range(len(kv_cache_groups)): - if i < len(kv_cache_groups[j].layer_names): - shared_by.append(kv_cache_groups[j].layer_names[i]) - kv_cache_tensors.append( - KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by) - ) + + _validate_layout_compatibility(kv_cache_tensors, kv_cache_groups) return KVCacheConfig( num_blocks=num_blocks, @@ -1383,7 +1350,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): page_size_padded=spec.page_size_padded, cache_dtype_str=spec.cache_dtype_str, alignment=spec.alignment, - compress_ratio=spec.compress_ratio, + tokens_per_state=spec.tokens_per_state, model_version=spec.model_version, ) elif isinstance(spec, SlidingWindowSpec): @@ -1745,60 +1712,30 @@ def _max_memory_usage_bytes_from_groups( """ Calculate maximum memory usage in bytes from KV cache groups. - This correctly accounts for padding in hybrid models. For example, if a - model has 8 full attention layers and 9 sliding window layers, they will - be padded to 9 full + 9 sliding window for uniform group sizes. + A request needs blocks from *every* group simultaneously. The total + blocks consumed per request is the **sum** across groups (each group + independently claims blocks from the shared pool). Total memory is + ``bytes_per_block * sum_of_blocks``. """ if not kv_cache_groups: return 0 - if len(kv_cache_groups) == 1 and isinstance( - kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs - ): - # UniformTypeKVCacheSpecs special case (single group, per-layer specs) - per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs - return sum( - spec.max_memory_usage_bytes(vllm_config) - for spec in per_layer_specs.values() - ) - elif all( - isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs) - for group in kv_cache_groups - ): - # Special case (only DeepseekV4 for now): all groups are - # UniformTypeKVCacheSpecs. - # They must already be page_size aligned and share a common padded - # layer-tuple layout. Even groups with fewer actual tuples still reserve - # the global number of tuple slots in the shared tensor layout. - full_mla_spec = cast(UniformTypeKVCacheSpecs, kv_cache_groups[0].kv_cache_spec) - layer_tuple_bytes = sum(full_mla_spec.get_page_sizes()) - num_layer_tuples = max( - cast(UniformTypeKVCacheSpecs, group.kv_cache_spec).get_num_layer_tuples() - for group in kv_cache_groups - ) + bytes_per_block = _pool_bytes_per_block(kv_cache_groups) + if bytes_per_block == 0: + return 0 - total_max_mem_usage_bytes = 0 - for group in kv_cache_groups: - group_spec = cast(UniformTypeKVCacheSpecs, group.kv_cache_spec) - g_max_mem_usage_pages = group_spec.max_memory_usage_pages(vllm_config) - g_max_mem_usage_page_bytes = ( - num_layer_tuples * g_max_mem_usage_pages * layer_tuple_bytes + total_blocks = 0 + for group in kv_cache_groups: + spec = group.kv_cache_spec + if isinstance(spec, UniformTypeKVCacheSpecs): + total_blocks += spec.max_memory_usage_pages(vllm_config) + else: + total_blocks += cdiv( + spec.max_memory_usage_bytes(vllm_config), + spec.page_size_bytes, ) - total_max_mem_usage_bytes += g_max_mem_usage_page_bytes - return total_max_mem_usage_bytes - - # General case: group_size pools, each shared by one layer per group - # Memory = group_size * page_size * blocks_for_max_len - group_size = max(len(group.layer_names) for group in kv_cache_groups) - page_size = get_uniform_page_size( - [group.kv_cache_spec for group in kv_cache_groups] - ) - blocks_needed = sum( - cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size) - for group in kv_cache_groups - ) - return group_size * page_size * blocks_needed + return bytes_per_block * total_blocks def _estimate_max_model_len_from_groups( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 31ee89bc72aa..727c7f58e9eb 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -95,11 +95,27 @@ class KVCacheSpecKind(str, Enum): class KVCacheSpec: """ A base class for specifying the KV cache format of one layer. + + RFC #42082 standard vocabulary (properties, overridden by subclasses): + num_heads: int — heads (1 if headless, e.g. MLA) + tokens_per_state: int — -1 infinite (recurrent), 1 standard, N compressed + state_content_size_bytes: int — bytes per state per head """ - # number of tokens in a block block_size: int + @property + def num_heads(self) -> int: + raise NotImplementedError + + @property + def tokens_per_state(self) -> int: + raise NotImplementedError + + @property + def state_content_size_bytes(self) -> int: + raise NotImplementedError + @property def page_size_bytes(self) -> int: """ @@ -140,6 +156,104 @@ def merge(cls, specs: list[Self]) -> Self: return copy.deepcopy(specs[0]) +# Logical dim indices in the 5D stride permutation [L, B, H, N, C] (see: RFC #42082). +_DIM_L, _DIM_B, _DIM_H, _DIM_N, _DIM_C = 0, 1, 2, 3, 4 + + +class KVCacheLayout(Enum): + """Physical layout descriptor for a KV cache group. + + The logical shape is always [L, B, H, N, ] (RFC #42082). + Each member's value is a stride permutation that maps logical axes + to physical (memory) order. + """ + + LBHNC = (0, 1, 2, 3, 4) # [L, B, H, N, C] (identity) + LBNHC = (0, 1, 3, 2, 4) # [L, B, N, H, C] + BLHNC = (1, 0, 2, 3, 4) # [B, L, H, N, C] + BHLNC = (1, 2, 0, 3, 4) # [B, H, L, N, C] + + @property + def stride_order(self) -> tuple[int, ...]: + return self.value + + @property + def layer_stride_order(self) -> tuple[int, ...]: + """4D permutation [B, H, N, C] for per-layer tensors.""" + if not self.is_layer_compact: + compact = [m.name for m in KVCacheLayout if m.is_layer_compact] + raise ValueError( + f"KVCacheLayout.{self.name} cannot produce a 4D" + f" layer_stride_order because L is not outermost." + f" Use a layer-compact layout: {compact}" + ) + return tuple(i - 1 for i in self.value if i != _DIM_L) + + @property + def is_layer_compact(self) -> bool: + """True when the layer is compact; i.e. the L dimension is outermost.""" + return self.value[_DIM_L] == 0 + + @property + def is_block_contiguous(self) -> bool: + """True when [H, N, C] is contiguous within a block.""" + return self.value[-3:] == (_DIM_H, _DIM_N, _DIM_C) + + +def num_states_for(block_size: int, tokens_per_state: int) -> int: + """Derive num_states at allocation time (not part of the spec).""" + if tokens_per_state == -1: + return 1 # recurrent: single state per block + return block_size // tokens_per_state + + +def compute_layer_kv_cache_shape_bytes( + spec: KVCacheSpec, + num_blocks: int, + block_size: int | None = None, +) -> tuple[int, ...]: + """Return the 4D logical shape ``(B, H, N, C)`` where C is in bytes.""" + bs = block_size if block_size is not None else spec.storage_block_size + ns = num_states_for(bs, spec.tokens_per_state) + return (num_blocks, spec.num_heads, ns, spec.state_content_size_bytes) + + +def reshape_kv_cache( + raw: torch.Tensor, + spec: KVCacheSpec, + num_blocks: int, + num_layer_slots: int, + layout: KVCacheLayout, + block_size: int | None = None, +) -> list[torch.Tensor]: + """View a flat int8 buffer as 4D ``[B, H, N, C]`` per-slot views. + + Works for all KVCacheSpec subclasses. Shapes as int8 via + compute_layer_kv_cache_shape_bytes, then reinterprets as spec.dtype. + """ + dtype = getattr(spec, "dtype", None) + logical_shape_bytes = ( + num_layer_slots, + *compute_layer_kv_cache_shape_bytes(spec, num_blocks, block_size), + ) + stride_order = layout.stride_order + physical_shape_bytes = tuple(logical_shape_bytes[i] for i in stride_order) + inv_order = [stride_order.index(i) for i in range(5)] + + if page_size_padded := getattr(spec, "page_size_padded", None): + strides = list(torch.empty(physical_shape_bytes, device="meta").stride()) + strides[inv_order[_DIM_B]] = page_size_padded + cache = torch.as_strided(raw, size=physical_shape_bytes, stride=tuple(strides)) + else: + cache = raw.view(physical_shape_bytes) + cache_logical = cache.permute(*inv_order) + + if dtype is not None: + cache_logical = cache_logical.view(dtype) + + return [cache_logical[i] for i in range(num_layer_slots)] + + @dataclass(frozen=True, kw_only=True) class AttentionSpec(KVCacheSpec): num_kv_heads: int @@ -147,6 +261,24 @@ class AttentionSpec(KVCacheSpec): dtype: torch.dtype kv_quant_mode: KVQuantMode = KVQuantMode.NONE page_size_padded: int | None = None + tokens_per_state: int = 1 + + # NVFP4: K and V are stored as separate head groups [L, B, 2*H, N, dim] + # rather than interleaved in content [L, B, H, N, 2*dim]. + + @property + def num_heads(self) -> int: + if self.kv_quant_mode.is_nvfp4: + return 2 * self.num_kv_heads + return self.num_kv_heads + + @property + def state_content_size_bytes(self) -> int: + hs = self.head_size + if self.kv_quant_mode.is_nvfp4: + hs = nvfp4_kv_cache_full_dim(hs) + return hs * get_dtype_size(self.dtype) + return 2 * hs * get_dtype_size(self.dtype) @property def page_size_bytes(self) -> int: @@ -207,6 +339,19 @@ def __post_init__(self): if self.head_size_v is None: object.__setattr__(self, "head_size_v", self.head_size) + @property + def state_content_size_bytes(self) -> int: + hs_k = self.head_size + hs_v = self.head_size_v + if self.kv_quant_mode.is_nvfp4: + hs_k = nvfp4_kv_cache_full_dim(hs_k) + hs_v = nvfp4_kv_cache_full_dim(hs_v) + assert hs_k == hs_v, ( + "nvfp4 with asymmetric K/V head sizes not yet supported" + ) + return hs_k * get_dtype_size(self.dtype) + return (hs_k + hs_v) * get_dtype_size(self.dtype) + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size @@ -318,6 +463,12 @@ class TQFullAttentionSpec(FullAttentionSpec): tq_slot_size: int = 0 + @property + def state_content_size_bytes(self) -> int: + if self.tq_slot_size > 0: + return self.tq_slot_size + return super().state_content_size_bytes + @property def real_page_size_bytes(self) -> int: if self.tq_slot_size > 0: @@ -339,16 +490,22 @@ class MLAAttentionSpec(FullAttentionSpec): cache_dtype_str: str | None = None # DeepseekV4 only fields. Non-DeepseekV4 MLA models leave these at defaults. alignment: int | None = None # Default to None for no padding. - compress_ratio: int = 1 # Default to 1 for no compression. + tokens_per_state: int = 1 model_version: str | None = None def __post_init__(self): super().__post_init__() _apply_alignment_padding(self) + @property + def state_content_size_bytes(self) -> int: + if self.cache_dtype_str == "fp8_ds_mla": + return 584 if self.model_version == "deepseek_v4" else 656 + return self.head_size * get_dtype_size(self.dtype) + @property def storage_block_size(self) -> int: - return self.block_size // self.compress_ratio + return self.block_size // self.tokens_per_state @property def real_page_size_bytes(self) -> int: @@ -373,11 +530,11 @@ def merge(cls, specs: list[Self]) -> Self: "All attention layers in the same KV cache group must be MLAAttentionSpec." ) cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) - compress_ratio_set = set(spec.compress_ratio for spec in specs) + tokens_per_state_set = set(spec.tokens_per_state for spec in specs) model_version_set = set(spec.model_version for spec in specs) assert ( len(cache_dtype_str_set) == 1 - and len(compress_ratio_set) == 1 + and len(tokens_per_state_set) == 1 and len(model_version_set) == 1 ), ( "All attention layers in the same KV cache group must use the same " @@ -391,7 +548,7 @@ def merge(cls, specs: list[Self]) -> Self: kv_quant_mode=specs[0].kv_quant_mode, page_size_padded=specs[0].page_size_padded, cache_dtype_str=cache_dtype_str_set.pop(), - compress_ratio=compress_ratio_set.pop(), + tokens_per_state=tokens_per_state_set.pop(), model_version=model_version_set.pop(), ) @@ -501,15 +658,22 @@ class SlidingWindowMLASpec(SlidingWindowSpec): cache_dtype_str: str | None = None # DeepseekV4-only: see MLAAttentionSpec.model_version. alignment: int | None = None # Default to None for no padding. - compress_ratio: int = 1 + tokens_per_state: int = 1 model_version: str | None = None def __post_init__(self): + super().__post_init__() _apply_alignment_padding(self) + @property + def state_content_size_bytes(self) -> int: + if self.model_version == "deepseek_v4": + return 584 + return self.head_size * get_dtype_size(self.dtype) + @property def storage_block_size(self) -> int: - return self.block_size // self.compress_ratio + return self.block_size // self.tokens_per_state @property def real_page_size_bytes(self) -> int: @@ -533,12 +697,12 @@ def merge(cls, specs: list[Self]) -> Self: "SlidingWindowMLASpec." ) cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) - compress_ratio_set = set(spec.compress_ratio for spec in specs) + tokens_per_state_set = set(spec.tokens_per_state for spec in specs) model_version_set = set(spec.model_version for spec in specs) sliding_window_set = set(spec.sliding_window for spec in specs) assert ( len(cache_dtype_str_set) == 1 - and len(compress_ratio_set) == 1 + and len(tokens_per_state_set) == 1 and len(model_version_set) == 1 and len(sliding_window_set) == 1 ), ( @@ -554,7 +718,7 @@ def merge(cls, specs: list[Self]) -> Self: page_size_padded=specs[0].page_size_padded, sliding_window=sliding_window_set.pop(), cache_dtype_str=cache_dtype_str_set.pop(), - compress_ratio=compress_ratio_set.pop(), + tokens_per_state=tokens_per_state_set.pop(), model_version=model_version_set.pop(), ) @@ -567,6 +731,15 @@ class MambaSpec(KVCacheSpec): mamba_type: MambaAttentionBackendEnum = MambaAttentionBackendEnum.MAMBA2 mamba_cache_mode: str = "none" num_speculative_blocks: int = 0 + num_heads: int = 1 + tokens_per_state: int = -1 + + @property + def state_content_size_bytes(self) -> int: + return sum( + prod(shape) * get_dtype_size(dtype) + for (shape, dtype) in zip(self.shapes, self.dtypes) + ) @property def page_size_bytes(self) -> int: @@ -811,12 +984,15 @@ def get_kv_cache_spec_sliding_window(kv_cache_spec: KVCacheSpec) -> int | None: @dataclass class KVCacheTensor: - """ - A class for specifying how the workers should initialize the KV cache. + """One contiguous GPU allocation backing one or more layer slots. + + ``shared_by[slot_idx]`` lists the layer names aliasing slot ``slot_idx``. + Layers in the same inner list belong to different groups (independent + block tables) so their block-id namespaces never collide. """ - size: int # size of the KV cache tensor in bytes - shared_by: list[str] # layer names that share the same KV cache tensor + size: int # total size in bytes + shared_by: list[list[str]] # shared_by[slot_idx] = [layer_names] @dataclass diff --git a/vllm/v1/kv_offload/base.py b/vllm/v1/kv_offload/base.py index 5f798f41eac8..00ec1fc55757 100644 --- a/vllm/v1/kv_offload/base.py +++ b/vllm/v1/kv_offload/base.py @@ -330,10 +330,8 @@ class CanonicalKVCacheTensor: """ A canonicalized KV cache tensor whose first dimension is num_blocks. - For attention backends where the raw tensor has num_blocks at a - non-leading physical dimension (e.g. FlashAttention's - (2, num_blocks, ...) layout), the tensor is split so that each - resulting CanonicalKVCacheTensor starts with (num_blocks, ...). + With standardized layouts (RFC #42082) num_blocks is always the + leading logical dimension. """ # The KV cache tensor with shape (num_blocks, ...) diff --git a/vllm/v1/simple_kv_offload/manager.py b/vllm/v1/simple_kv_offload/manager.py index f61c4320dffd..bf0dffa7346b 100644 --- a/vllm/v1/simple_kv_offload/manager.py +++ b/vllm/v1/simple_kv_offload/manager.py @@ -192,7 +192,7 @@ def _derive_cpu_config( cpu_tensors = [ KVCacheTensor( size=t.size // num_gpu_blocks * num_cpu_blocks, - shared_by=list(t.shared_by), + shared_by=[list(slot) for slot in t.shared_by], ) for t in gpu_config.kv_cache_tensors ] diff --git a/vllm/v1/simple_kv_offload/worker.py b/vllm/v1/simple_kv_offload/worker.py index c23b44f29173..09559be6cb9e 100644 --- a/vllm/v1/simple_kv_offload/worker.py +++ b/vllm/v1/simple_kv_offload/worker.py @@ -80,15 +80,8 @@ def register_kv_caches( logger.warning("No KV caches to offload.") return - # Resolve each entry to a representative tensor for storage - # deduplication. For attention layers the value is already a tensor; - # for Mamba layers it is a list of tensors that all share the same - # underlying raw storage, so we take the first one. - def _repr_tensor(v: torch.Tensor | list[torch.Tensor]) -> torch.Tensor: - assert isinstance(v, torch.Tensor | list) - return v if isinstance(v, torch.Tensor) else v[0] - - any_tensor = _repr_tensor(next(iter(kv_caches.values()))) + any_tensor = next(iter(kv_caches.values())) + assert isinstance(any_tensor, torch.Tensor) self.device = any_tensor.device assert self.kv_cache_config is not None @@ -97,7 +90,8 @@ def _repr_tensor(v: torch.Tensor | list[torch.Tensor]) -> torch.Tensor: # Deduplicate: multiple layers may share the same backing storage. seen_ptrs: dict[int, tuple[str, torch.Tensor]] = {} for name, value in kv_caches.items(): - tensor = _repr_tensor(value) + assert isinstance(value, torch.Tensor) + tensor = value ptr = tensor.untyped_storage().data_ptr() if ptr not in seen_ptrs: seen_ptrs[ptr] = (name, tensor) @@ -105,13 +99,11 @@ def _repr_tensor(v: torch.Tensor | list[torch.Tensor]) -> torch.Tensor: # Build [num_blocks, block_bytes] int8 views from each unique # storage so that stride(0) gives block_bytes for the copy op. # - # The physical layout varies across attention backends: - # FlashAttn/ROCm: (2, num_blocks, ...) -> K/V outermost, 2 segments - # FlashInfer/MLA: (num_blocks, ...) -> blocks outermost, 1 segment - # We derive page_size_bytes = storage.nbytes() // num_blocks, then - # classify dims: any dim whose byte-stride exceeds page_size_bytes - # must be an outer segment dim (e.g. the K/V dim of size 2). A less - # hacky way is to update the interface with the layout. + # With standardized layouts (RFC #42082) num_blocks is always the + # leading logical dim, but cross-layer physical layouts + # (e.g. BLHNC) interleave layers between blocks, inflating the + # block stride. Detect any such outer segment dims by comparing + # byte-strides against page_size_bytes. unique_gpu_caches: dict[str, torch.Tensor] = {} for name, tensor in seen_ptrs.values(): storage = tensor.untyped_storage() diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 6afffa424d42..f0b6529d3f84 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -12,7 +12,7 @@ from vllm.model_executor.model_loader import get_model from vllm.tracing import instrument from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheLayout from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -124,6 +124,40 @@ def warming_up_model(self) -> None: self.profile_run() logger.info("Warming up done.") + def _allocate_and_reshape_kv_cache( + self, + kv_cache_config: KVCacheConfig, + kernel_block_sizes: list[int], + layout: KVCacheLayout = KVCacheLayout.LBHNC, + ) -> dict[str, torch.Tensor]: + # Re-view the interleaved K|V content dim as separate head groups so + # the backend can slice contiguous per-head K/V: + # [num_blocks, num_kv_heads, block_size, 2*head_size] -> + # [num_blocks, 2*num_kv_heads, block_size, head_size]. Assumes + # content == 2*head_size (true for all CPU attention specs; MLA, which + # differs, isn't supported). + kv_caches = super()._allocate_and_reshape_kv_cache( + kv_cache_config, + kernel_block_sizes, + layout=KVCacheLayout.LBHNC, + ) + attn_layers = { + name + for group in self._kv_cache_spec_attn_group_iterator() + if isinstance(group.kv_cache_spec, AttentionSpec) + for name in group.layer_names + } + for name in attn_layers: + cache = kv_caches.get(name) + if cache is None: + continue + num_blocks, num_kv_heads, block_size, content = cache.shape + assert cache.is_contiguous() and content % 2 == 0 + kv_caches[name] = cache.view( + num_blocks, 2 * num_kv_heads, block_size, content // 2 + ) + return kv_caches + def initialize_kv_cache( self, kv_cache_config: KVCacheConfig, diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 6fc55ee32030..9ccad49397ae 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable, Sequence +from collections.abc import Sequence from dataclasses import dataclass from typing import Any, cast @@ -9,17 +9,19 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.utils.torch_utils import get_dtype_size from vllm.v1.attention.backend import ( + AttentionBackend, AttentionCGSupport, CommonAttentionMetadata, ) +from vllm.v1.attention.backends.utils import resolve_kv_cache_layout from vllm.v1.kv_cache_interface import ( AttentionSpec, KVCacheConfig, + KVCacheLayout, KVCacheSpec, - MambaSpec, UniformTypeKVCacheSpecs, + reshape_kv_cache, ) from vllm.v1.worker.gpu.model_states.interface import ModelSpecificAttnMetadata from vllm.v1.worker.utils import ( @@ -147,199 +149,70 @@ def init_attn_backend( return attn_groups, attn_cg_support_info, kernel_block_sizes -def _allocate_kv_cache( - kv_cache_config: KVCacheConfig, shared_layers: dict[str, str], device: torch.device -): - kv_cache_raw_tensors: dict[str, torch.Tensor] = {} - for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device) - for layer_name in kv_cache_tensor.shared_by: - kv_cache_raw_tensors[layer_name] = tensor +def _allocate_and_reshape_kv_cache( + kv_cache_config: KVCacheConfig, + device: torch.device, + layout: KVCacheLayout | None = None, + kernel_block_sizes: list[int] | None = None, + attn_backends: tuple[type[AttentionBackend], ...] | None = None, +) -> dict[str, Any]: + if layout is None: + layout = resolve_kv_cache_layout( + tuple(attn_backends) if attn_backends else None + ) - layer_names = set() - for group in kv_cache_config.kv_cache_groups: + layer_to_group: dict[str, tuple[KVCacheSpec, int]] = {} + for group_id, group in enumerate(kv_cache_config.kv_cache_groups): + spec = group.kv_cache_spec for layer_name in group.layer_names: - layer_names.add(layer_name) - assert layer_names == (kv_cache_raw_tensors.keys() | shared_layers.keys()), ( - "Some layers are not correctly initialized" - ) - return kv_cache_raw_tensors - + if isinstance(spec, UniformTypeKVCacheSpecs): + layer_to_group[layer_name] = (spec.kv_cache_specs[layer_name], group_id) + else: + layer_to_group[layer_name] = (spec, group_id) -def _reshape_kv_cache( - attn_groups: Sequence[AttentionGroup], - kv_cache_raw_tensors: dict[str, torch.Tensor], - cache_dtype: str, - kernel_block_sizes: list[int], - shared_kv_cache_layers: dict[str, str], -) -> dict[str, Any]: kv_caches: dict[str, Any] = {} - has_attn, has_mamba = False, False - - for group in attn_groups: - if group.kv_cache_group_id >= len(kernel_block_sizes): - continue - - kv_cache_spec = group.kv_cache_spec - if kv_cache_spec.storage_block_size != kv_cache_spec.block_size: - # use storage_block_size as the kernel block size for groups - # that apply a compression on block size (eg. DeepSeek V4). - kernel_block_size = kv_cache_spec.storage_block_size - else: - kernel_block_size = kernel_block_sizes[group.kv_cache_group_id] - - for layer_name in group.layer_names: - if layer_name in shared_kv_cache_layers: - # Shared layer — tensor will be aliased to its target later. + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + num_layer_slots = len(kv_cache_tensor.shared_by) + buf = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device) + + layer_to_slot: dict[str, int] = {} + for slot_idx, slot_layers in enumerate(kv_cache_tensor.shared_by): + for layer_name in slot_layers: + layer_to_slot[layer_name] = slot_idx + + tensor_layers = set(layer_to_slot) + slot_bytes = kv_cache_tensor.size // num_layer_slots + for group_id, group in enumerate(kv_cache_config.kv_cache_groups): + layer_names = [n for n in group.layer_names if n in tensor_layers] + if not layer_names: continue - - kv_raw_tensor = kv_cache_raw_tensors[layer_name] - assert kv_raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = kv_raw_tensor.numel() // kv_cache_spec.page_size_bytes - - if isinstance(kv_cache_spec, AttentionSpec): - has_attn = True - # Use storage_block_size: it equals block_size for uncompressed - # specs but is smaller for compressed ones (DeepSeek V4), which - # store block_size tokens in block_size // compress_ratio slots. - num_blocks_per_kv_block = ( - kv_cache_spec.storage_block_size // kernel_block_size - ) - kernel_num_blocks = num_blocks * num_blocks_per_kv_block - kv_cache_shape = group.backend.get_kv_cache_shape( - kernel_num_blocks, - kernel_block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - cache_dtype_str=cache_dtype, - ) - - # FIXME(woosuk): Add kv_cache_stride_order to all attention backends. - try: - kv_cache_stride_order = group.backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len(kv_cache_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) - - kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) - inv_order = [ - kv_cache_stride_order.index(i) - for i in range(len(kv_cache_stride_order)) - ] - - dtype = kv_cache_spec.dtype - kv_tensor = kv_raw_tensor.view(dtype) - if kv_cache_spec.page_size_padded is not None: - # Use strided view to handle page_size_bytes that - # include padding. This follows the same pattern as - # MambaSpec handling in gpu_model_runner.py. - # NOTE: This assumes kv_cache_shape[0] == num_blocks - # (i.e. the first physical dimension is the block - # index), which holds for all current backends - # (MLA, FlashAttention, TritonAttention, etc.). - dtype_size = get_dtype_size(dtype) - page_stride = kv_cache_spec.page_size_bytes // dtype_size - strides = list(torch.empty(kv_cache_shape).stride()) - strides[inv_order[0]] = page_stride - kv_cache = torch.as_strided( - kv_tensor, - size=kv_cache_shape, - stride=tuple(strides), - ) - else: - # No padding — safe to use a contiguous view. - kv_cache = kv_tensor.view(kv_cache_shape) - kv_caches[layer_name] = kv_cache.permute(*inv_order) - - elif isinstance(kv_cache_spec, MambaSpec): - has_mamba = True - state_tensors = [] - storage_offset_bytes = 0 - for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): - dtype_size = get_dtype_size(dtype) - num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size - target_shape = (num_blocks, *shape) - stride = torch.empty(target_shape).stride() - target_stride = (num_element_per_page, *stride[1:]) - assert storage_offset_bytes % dtype_size == 0 - tensor = torch.as_strided( - kv_raw_tensor.view(dtype), - size=target_shape, - stride=target_stride, - storage_offset=storage_offset_bytes // dtype_size, - ) - state_tensors.append(tensor) - storage_offset_bytes += stride[0] * dtype_size - kv_caches[layer_name] = state_tensors + spec, _ = layer_to_group[layer_names[0]] + num_blocks = slot_bytes // spec.page_size_bytes + + if ( + kernel_block_sizes is not None + and isinstance(spec, AttentionSpec) + and group_id < len(kernel_block_sizes) + ): + kernel_block_size = kernel_block_sizes[group_id] + num_blocks = num_blocks * spec.storage_block_size // kernel_block_size else: - raise NotImplementedError( - f"Unsupported KV cache spec type: {type(kv_cache_spec)}" - ) - - if has_attn and has_mamba: - _update_hybrid_attention_layout( - attn_groups=attn_groups, - kv_caches=kv_caches, - kernel_block_sizes=kernel_block_sizes, - cache_dtype=cache_dtype, - ) - - # Map any sharing layers to their target layer's KV cache. - for layer_name, target_layer_name in shared_kv_cache_layers.items(): - kv_caches[layer_name] = kv_caches[target_layer_name] + kernel_block_size = spec.storage_block_size + + views = reshape_kv_cache( + buf, + spec, + num_blocks, + num_layer_slots=num_layer_slots, + layout=layout, + block_size=kernel_block_size, + ) + for layer_name in layer_names: + kv_caches[layer_name] = views[layer_to_slot[layer_name]] return kv_caches -def _update_hybrid_attention_layout( - attn_groups: Iterable[AttentionGroup], - kv_caches: dict[str, Any], - kernel_block_sizes: list[int], - cache_dtype: str, -) -> None: - for group in attn_groups: - if group.kv_cache_group_id >= len(kernel_block_sizes): - continue - - kv_cache_spec = group.kv_cache_spec - if not isinstance(kv_cache_spec, AttentionSpec): - continue - block_dim = group.backend.get_kv_cache_block_dim( - kernel_block_sizes[group.kv_cache_group_id], - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - cache_dtype_str=cache_dtype, - ) - # if the first dim of the kvcache's layout is already num_blocks, continue - if block_dim == 0: - continue - - assert block_dim == 1, ( - "Expected the dim `num_blocks` at the second dim when updating" - " the kvcache's layout of full attention layer" - ) - - for layer_name in group.layer_names: - if layer_name not in kv_caches: - # Shared layer — will be aliased to its target after this pass. - continue - - kv_cache = kv_caches[layer_name] - if kv_cache.shape[0] == 2: - assert kv_cache.shape[1] != 2, ( - f"Cannot determine layout for tensor of shape {kv_cache.shape}" - ) - hidden_size = kv_cache.shape[2:].numel() - kv_cache.as_strided_( - size=kv_cache.shape, - stride=( - hidden_size, - 2 * hidden_size, - *kv_cache.stride()[2:], - ), - ) - - def init_kv_cache( runner_kv_caches: list[torch.Tensor | list[torch.Tensor]], forward_context: dict[str, Any], @@ -347,20 +220,17 @@ def init_kv_cache( attn_groups: list[list[AttentionGroup]], device: torch.device, cache_dtype: str, - kernel_block_sizes: list[int], - vllm_config: VllmConfig, + kernel_block_sizes: list[int] | None = None, + layout: KVCacheLayout | None = None, ) -> dict[str, Any]: - shared_kv_cache_layers = get_shared_kv_cache_layers(vllm_config) - kv_cache_raw_tensors = _allocate_kv_cache( - kv_cache_config, shared_kv_cache_layers, device - ) - flattened_attn_groups = list(group for groups in attn_groups for group in groups) - kv_caches = _reshape_kv_cache( - attn_groups=flattened_attn_groups, - kv_cache_raw_tensors=kv_cache_raw_tensors, + kv_caches = _allocate_and_reshape_kv_cache( + kv_cache_config, + device, + layout=layout, kernel_block_sizes=kernel_block_sizes, - cache_dtype=cache_dtype, - shared_kv_cache_layers=shared_kv_cache_layers, + attn_backends=tuple( + group.backend for groups in attn_groups for group in groups + ), ) bind_kv_cache(kv_caches, forward_context, runner_kv_caches) return kv_caches diff --git a/vllm/v1/worker/gpu/kv_connector.py b/vllm/v1/worker/gpu/kv_connector.py index cdacb36e5833..dcff3c3ddf2d 100644 --- a/vllm/v1/worker/gpu/kv_connector.py +++ b/vllm/v1/worker/gpu/kv_connector.py @@ -50,9 +50,7 @@ def __init__( ): self.vllm_config = vllm_config self.kv_connector = get_kv_transfer_group() - # Register kv caches with KV Connector if applicable. - # TODO: support cross_layers_kv_cache - # (see https://github.com/vllm-project/vllm/pull/27743) + # Register kv caches with KV Connector. self.kv_connector.register_kv_caches(kv_caches_dict) self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 2e3133822fd4..cd15c7ed5cbe 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -472,7 +472,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.device, self.cache_config.cache_dtype, self.kernel_block_sizes, - self.vllm_config, ) self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index dbbeafcbd40b..9b19827e1173 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -117,7 +117,6 @@ from vllm.utils.nvtx_pytorch_hooks import PytHooks from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units from vllm.utils.torch_utils import ( - get_dtype_size, is_quantized_kv_cache, kv_cache_dtype_str_to_dtype, ) @@ -136,6 +135,7 @@ create_fast_prefill_custom_backend, get_dcp_local_seq_lens, reorder_batch_to_split_decodes_and_prefills, + resolve_kv_cache_layout, ) from vllm.v1.core.sched.output import NewRequestData from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher @@ -147,10 +147,12 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, + KVCacheLayout, KVCacheSpec, MambaSpec, SlidingWindowSpec, UniformTypeKVCacheSpecs, + reshape_kv_cache, ) from vllm.v1.outputs import ( EMPTY_MODEL_RUNNER_OUTPUT, @@ -519,9 +521,6 @@ def __init__( # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] - # Initialize in initialize_kv_cache_tensors - self.cross_layers_kv_cache: torch.Tensor | None = None - self.cross_layers_attn_backend: type[AttentionBackend] | None = None # indexes: [kv_cache_group_id][attn_group] self.attn_groups: list[list[AttentionGroup]] = [] # self.kv_cache_config: KVCacheConfig @@ -4590,17 +4589,16 @@ def propose_draft_token_ids(sampled_token_ids): with record_function_or_nullcontext( "gpu_model_runner: AsyncGPUModelRunnerOutput" ): - # Async path: produce a device-side snapshot that the async - # copy stream can D2H later. Both tensors must be private - # clones because: + # Async path: produce a device-side snapshot that the async copy + # stream can D2H later. Both tensors must be private clones + # because: # - ``routing_data`` source is the shared capturer buffer, - # which is ``clear_buffer()``-ed at the start of the - # next step on the default stream. + # which is ``clear_buffer()``-ed at the start of the next + # step on the default stream. # - ``slot_mapping`` source is our own - # ``routed_experts_slot_mapping_device``, which the - # next ``_prepare_inputs`` overwrites on the default - # stream while the D2H is still pending on the copy - # stream. + # ``routed_experts_slot_mapping_device``, which the next + # ``_prepare_inputs`` overwrites on the default stream + # while the D2H is still pending on the copy stream. # Without clones, the copy stream would read torn data. routed_experts_snapshot = None if self.routed_experts_initialized: @@ -6295,9 +6293,6 @@ def _cleanup_profiling_kv_cache(self) -> None: for i in range(len(self.kv_caches)): self.kv_caches[i] = None # type: ignore self.kv_caches.clear() - if hasattr(self, "cross_layers_kv_cache"): - self.cross_layers_kv_cache = None - self.cross_layers_attn_backend = None if hasattr(self, "attn_groups"): self.attn_groups.clear() if hasattr(self, "kv_cache_config"): @@ -6717,10 +6712,9 @@ def get_attn_backends_for_group( layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] # Non-Attention layer types (e.g. Mamba1, ShortConv) do not # expose ``num_heads``; fall back to 0 so they cluster as - # before. Such layers never coexist with Attention in a - # single KV cache group (different KVCacheSpec), so the - # fallback can never spuriously merge them with attention - # layers. + # before. Such layers never coexist with Attention in a single + # KV cache group (different KVCacheSpec), so the fallback can + # never spuriously merge them with attention layers. num_heads_q = getattr(layers[layer_name], "num_heads", 0) key = (full_cls_name, layer_kv_cache_spec, num_heads_q) attn_backends[key] = AttentionGroupKey( @@ -6983,38 +6977,6 @@ def may_reinitialize_input_batch( f"!= kv_cache kernel_block_sizes {kernel_block_sizes}" ) - def _allocate_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig - ) -> dict[str, torch.Tensor]: - """ - Initializes the KV cache buffer with the correct size. The buffer needs - to be reshaped to the desired shape before being used by the models. - - Args: - kv_cache_config: The KV cache config - Returns: - dict[str, torch.Tensor]: A map between layer names to their - corresponding memory buffer for KV cache. - """ - kv_cache_raw_tensors: dict[str, torch.Tensor] = {} - for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - tensor = torch.zeros( - kv_cache_tensor.size, dtype=torch.int8, device=self.device - ) - for layer_name in kv_cache_tensor.shared_by: - kv_cache_raw_tensors[layer_name] = tensor - - layer_names = set() - for group in kv_cache_config.kv_cache_groups: - for layer_name in group.layer_names: - if layer_name in self.runner_only_attn_layers: - continue - layer_names.add(layer_name) - assert layer_names == set(kv_cache_raw_tensors.keys()), ( - "Some layers are not correctly initialized" - ) - return kv_cache_raw_tensors - def _attn_group_iterator(self) -> Iterator[AttentionGroup]: return itertools.chain.from_iterable(self.attn_groups) @@ -7024,166 +6986,59 @@ def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]: for attn_groups in self.attn_groups: yield from attn_groups - def _reshape_kv_cache_tensors( + def _allocate_and_reshape_kv_cache( self, - kv_cache_raw_tensors: dict[str, torch.Tensor], + kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int], + layout: KVCacheLayout, ) -> dict[str, torch.Tensor]: - """ - Reshape the KV cache tensors to the desired shape and dtype. + """Allocate backing tensors and reshape into per-layer [B,H,N,C] views. - Args: - kv_cache_raw_tensors: The KV cache buffer of each layer, with - correct size but uninitialized shape. - kernel_block_sizes: The kernel block sizes for each KV cache group. - Returns: - Dict[str, torch.Tensor]: A map between layer names to their - corresponding memory buffer for KV cache. + Returns: a dict mapping layer-name -> view of the backing tensor. """ kv_caches: dict[str, torch.Tensor] = {} - has_attn, has_mamba = False, False - for group in self._kv_cache_spec_attn_group_iterator(): - kv_cache_spec = group.kv_cache_spec - attn_backend = group.backend - if group.kv_cache_group_id == len(kernel_block_sizes): - # There may be a last group for layers without kv cache. - continue - kernel_block_size = kernel_block_sizes[group.kv_cache_group_id] - for layer_name in group.layer_names: - if layer_name in self.runner_only_attn_layers: - continue - raw_tensor = kv_cache_raw_tensors[layer_name] - assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes - if isinstance(kv_cache_spec, AttentionSpec): - has_attn = True - num_blocks_per_kv_block = ( - kv_cache_spec.block_size // kernel_block_size - ) - kernel_num_blocks = num_blocks * num_blocks_per_kv_block - - # For MLA with compression, storage_block_size != block_size - if kv_cache_spec.storage_block_size != kv_cache_spec.block_size: - shape_block_size = kv_cache_spec.storage_block_size - else: - shape_block_size = kernel_block_size - - kv_cache_shape = attn_backend.get_kv_cache_shape( - kernel_num_blocks, - shape_block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - cache_dtype_str=self.cache_config.cache_dtype, - ) - dtype = kv_cache_spec.dtype - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len(kv_cache_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) - # The allocation respects the backend-defined stride order - # to ensure the semantic remains consistent for each - # backend. We first obtain the generic kv cache shape and - # then permute it according to the stride order which could - # result in a non-contiguous tensor. - kv_cache_shape = tuple( - kv_cache_shape[i] for i in kv_cache_stride_order - ) - # Maintain original KV shape view. - inv_order = [ - kv_cache_stride_order.index(i) - for i in range(len(kv_cache_stride_order)) - ] - raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype) - if kv_cache_spec.page_size_padded is not None: - # Use strided view to handle page_size_bytes that - # include padding. This follows - # the same pattern as MambaSpec handling below. - # NOTE: This assumes kv_cache_shape[0] == num_blocks - # (i.e. the first physical dimension is the block - # index), which holds for MLA backends but NOT for - # standard attention backends whose shape starts with - # a K/V dimension of size 2. - dtype_size = get_dtype_size(dtype) - page_stride = kv_cache_spec.page_size_bytes // dtype_size - strides = list(torch.empty(kv_cache_shape).stride()) - strides[inv_order[0]] = page_stride - kv_cache = torch.as_strided( - raw_tensor, - size=kv_cache_shape, - stride=tuple(strides), - ) - else: - # No padding — safe to use a contiguous view. - kv_cache = raw_tensor.view(kv_cache_shape) - kv_caches[layer_name] = kv_cache.permute(*inv_order) - - elif isinstance(kv_cache_spec, MambaSpec): - has_mamba = True - raw_tensor = kv_cache_raw_tensors[layer_name] - state_tensors = [] - storage_offset_bytes = 0 - for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): - dtype_size = get_dtype_size(dtype) - num_element_per_page = ( - kv_cache_spec.page_size_bytes // dtype_size - ) - target_shape = (num_blocks, *shape) - stride = torch.empty(target_shape).stride() - target_stride = (num_element_per_page, *stride[1:]) - assert storage_offset_bytes % dtype_size == 0 - tensor = torch.as_strided( - raw_tensor.view(dtype), - size=target_shape, - stride=target_stride, - storage_offset=storage_offset_bytes // dtype_size, - ) - state_tensors.append(tensor) - storage_offset_bytes += stride[0] * dtype_size - - kv_caches[layer_name] = state_tensors - else: - raise NotImplementedError - - if has_attn and has_mamba: - self._update_hybrid_attention_mamba_layout(kv_caches, kernel_block_sizes) + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + buf = torch.zeros( + kv_cache_tensor.size, dtype=torch.int8, device=self.device + ) - return kv_caches + layer_to_slot_idx: dict[str, int] = {} - def _update_hybrid_attention_mamba_layout( - self, kv_caches: dict[str, torch.Tensor], kernel_block_sizes: list[int] - ) -> None: - """ - Update the layout of attention layers from (2, num_blocks, ...) to - (num_blocks, 2, ...). + for slot_idx, slot_layers in enumerate(kv_cache_tensor.shared_by): + for layer_name in slot_layers: + layer_to_slot_idx[layer_name] = slot_idx - Args: - kv_caches: The KV cache buffer of each layer. - kernel_block_sizes: The kernel block sizes for each KV cache group. - """ + num_slots = len(kv_cache_tensor.shared_by) + bytes_per_slot = kv_cache_tensor.size // num_slots - for group in self._kv_cache_spec_attn_group_iterator(): - kv_cache_spec = group.kv_cache_spec - if not isinstance(kv_cache_spec, AttentionSpec): - continue - block_dim = group.backend.get_kv_cache_block_dim( - kernel_block_sizes[group.kv_cache_group_id], - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - cache_dtype_str=self.cache_config.cache_dtype, - ) - # block_dim: 0 means (num_blocks, 2, ...); 1 means (2, num_blocks, ...). - if block_dim == 0: - continue - assert block_dim == 1 - for layer_name in group.layer_names: - kv_cache = kv_caches[layer_name] - hidden_size = kv_cache.shape[2:].numel() - kv_cache.as_strided_( - size=kv_cache.shape, - stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]), + layers_shared_by_tensor = set(layer_to_slot_idx.keys()) + for g in self._kv_cache_spec_attn_group_iterator(): + if g.kv_cache_group_id >= len(kernel_block_sizes): + continue + layer_names = [n for n in g.layer_names if n in layers_shared_by_tensor] + if not layer_names: + continue + spec = g.kv_cache_spec + kernel_block_size = kernel_block_sizes[g.kv_cache_group_id] + + num_blocks = bytes_per_slot // spec.page_size_bytes + num_kernel_blocks_per_block = spec.block_size // kernel_block_size + kernel_num_blocks = num_blocks * num_kernel_blocks_per_block + + multi_slot_kv_cache = reshape_kv_cache( + buf, + spec, + kernel_num_blocks, + num_layer_slots=num_slots, + layout=layout, + block_size=kernel_block_size, ) + for layer_name in layer_names: + layer_view = multi_slot_kv_cache[layer_to_slot_idx[layer_name]] + kv_caches[layer_name] = layer_view + + return kv_caches def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] @@ -7200,29 +7055,12 @@ def initialize_kv_cache_tensors( corresponding memory buffer for KV cache. """ - # Try creating KV caches optimized for kv-connector transfers - cache_dtype = self.cache_config.cache_dtype - if self.use_uniform_kv_cache(self.attn_groups, cache_dtype): - kv_caches, cross_layers_kv_cache, attn_backend = ( - self.allocate_uniform_kv_caches( - kv_cache_config, - self.attn_groups, - cache_dtype, - self.device, - kernel_block_sizes, - ) - ) - self.cross_layers_kv_cache = cross_layers_kv_cache - self.cross_layers_attn_backend = attn_backend - else: - # Fallback to the general case - # Initialize the memory buffer for KV cache - kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) - - # Change the memory buffer to the desired shape - kv_caches = self._reshape_kv_cache_tensors( - kv_cache_raw_tensors, kernel_block_sizes - ) + layout = resolve_kv_cache_layout( + tuple(g.backend for g in self._attn_group_iterator()) + ) + kv_caches = self._allocate_and_reshape_kv_cache( + kv_cache_config, kernel_block_sizes, layout=layout + ) # Set up cross-layer KV cache sharing for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): @@ -7318,13 +7156,7 @@ def initialize_kv_cache( if has_kv_transfer_group() and not is_profiling: kv_transfer_group = get_kv_transfer_group() - if self.cross_layers_kv_cache is not None: - assert self.cross_layers_attn_backend is not None - kv_transfer_group.register_cross_layers_kv_cache( - self.cross_layers_kv_cache, self.cross_layers_attn_backend - ) - else: - kv_transfer_group.register_kv_caches(kv_caches) + kv_transfer_group.register_kv_caches(kv_caches) kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) def _get_attention_kv_cache_gid(self) -> int: diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 797e59c02909..085d16771d46 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -8,21 +8,15 @@ from contextlib import AbstractContextManager, contextmanager, nullcontext from typing import TYPE_CHECKING -import torch - from vllm.config import VllmConfig -from vllm.config.cache import CacheDType from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger -from vllm.v1.attention.backend import AttentionBackend -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig from vllm.v1.outputs import ( KVConnectorOutput, ModelRunnerOutput, ) -from vllm.v1.worker.utils import AttentionGroup if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -110,167 +104,3 @@ def _get_kv_connector_output( if not defer_finalize: kv_connector.clear_connector_metadata() - - @staticmethod - def use_uniform_kv_cache( - attn_groups: list[list[AttentionGroup]], - cache_dtype: CacheDType, - ) -> bool: - """ - Determines whether a uniform KV layout should be used. - A uniform layout means all layers KV caches will share the same - underlying tensor, where for a given block number, the respective - KV data for all layers will be contiguous. - This will allow efficient KV transfer of per-block KV data for all - layers at once. - Note this layout will only be applied given 3 conditions: - 1. The KV Cache config contains just a single group where all layers - have the same page size. - 2. A KV connector is configured, and the KV connector instance prefers - to use this layout (prefer_cross_layer_blocks() returns True) - 2. The flash attention backend supports this layout - (get_kv_cache_stride_order(True) includes a placement for a - num_layers dimension) - - Note that the actual placement of the num_layers dimensions - in the unified layers tensors will be determined by the attention - backend. - Thus, the layers KV data may still not be contiguous per block - if the attention backend does not support it. - - Args: - attn_groups: The list of attention groups for this model - cache_dtype: The KV cache dtype - Returns: - True if we should use a uniform KV cache layout. - """ - - if not has_kv_transfer_group(): - return False - if not get_kv_transfer_group().prefer_cross_layer_blocks: - return False - - if len(attn_groups) != 1 or len(attn_groups[0]) != 1: - return False - - attn_group = attn_groups[0][0] - kv_cache_spec = attn_group.kv_cache_spec - if not isinstance(kv_cache_spec, AttentionSpec): - return False - - attn_backend = attn_group.backend - kv_cache_shape = attn_backend.get_kv_cache_shape( - 1234, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - cache_dtype_str=cache_dtype, - ) - - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=True - ) - except (AttributeError, NotImplementedError): - return False - - # check that attention backend includes a layers dimension - if len(kv_cache_stride_order) != len(kv_cache_shape) + 1: - return False - - # stride_order[0] == 0 means num_layers stays first in physical - # layout (identity permutation), so cross-layer is unsupported. - return kv_cache_stride_order[0] != 0 - - @staticmethod - def allocate_uniform_kv_caches( - kv_cache_config: KVCacheConfig, - attn_groups: list[list[AttentionGroup]], - cache_dtype: CacheDType, - device: torch.device, - kernel_block_sizes: list[int], - ) -> tuple[dict[str, torch.Tensor], torch.Tensor, type[AttentionBackend]]: - """ - Initializes and reshapes KV caches for the simple case where all - layers have the same layout. - - This function assumes use_uniform_kv_cache() returned True. - - Args: - kv_cache_config: The KV cache config - attn_groups: The list of attention groups for this model - cache_dtype: The KV cache dtype - device: The torch device to allocate on. - kernel_block_sizes: The kernel block sizes for each KV cache group. - Returns: - A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where: - kv_caches is a dict mapping between layer names to their - corresponding memory buffer for KV cache. - cross_layers_kv_cache is the cross layers kv cache tensor - attn_backend is the attention backend matching this tensor - """ - attn_group = attn_groups[0][0] - kv_cache_spec = attn_group.kv_cache_spec - assert isinstance(kv_cache_spec, AttentionSpec) - - tensor_sizes = set( - kv_cache_tensor.size for kv_cache_tensor in kv_cache_config.kv_cache_tensors - ) - assert len(tensor_sizes) == 1 - tensor_size = tensor_sizes.pop() - - page_size = kv_cache_spec.page_size_bytes - assert tensor_size % page_size == 0 - num_blocks = tensor_size // page_size - num_layers = len(kv_cache_config.kv_cache_tensors) - total_size = tensor_size * num_layers - - assert len(kernel_block_sizes) == 1 - kernel_block_size = kernel_block_sizes[0] - num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size - kernel_num_blocks = num_blocks * num_blocks_per_kv_block - - attn_backend = attn_group.backend - kv_cache_shape = attn_backend.get_kv_cache_shape( - kernel_num_blocks, - kernel_block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - cache_dtype_str=cache_dtype, - ) - - # prepend a num_layers dimension into the shape - kv_cache_shape = (num_layers,) + kv_cache_shape - - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=True - ) - assert len(kv_cache_stride_order) == len(kv_cache_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) - - kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) - - logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape) - - # allocate one contiguous buffer for all layers - cross_layers_kv_cache = ( - torch.zeros(total_size, dtype=torch.int8, device=device) - .view(kv_cache_spec.dtype) - .view(kv_cache_shape) - ) - - # Maintain original KV shape view. - inv_order = [ - kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) - ] - permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order) - - kv_caches = {} - for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors): - tensor = permuted_kv_cache[i] - for layer_name in kv_cache_tensor.shared_by: - kv_caches[layer_name] = tensor - - return kv_caches, cross_layers_kv_cache, attn_backend diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index c0f44b6db0c3..abd4a7fcea8a 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -4,7 +4,6 @@ from collections import defaultdict from collections.abc import Iterable from dataclasses import dataclass, field -from itertools import product as iprod from typing import Any import torch @@ -128,13 +127,6 @@ def __init__( continue kernel_bs = kernel_block_sizes[group.kv_cache_group_id] ratio = spec.block_size // kernel_bs - block_dim = group.backend.get_kv_cache_block_dim( - kernel_bs, - spec.num_kv_heads, - spec.head_size, - cache_dtype_str=cache_dtype, - ) - for layer_name in group.layer_names: if layer_name in runner_only_attn_layers: continue @@ -147,7 +139,7 @@ def __init__( seen_ptrs.add(dp) el = kv.element_size() - cur_bytes = kv.stride(block_dim) * el + cur_bytes = kv.stride(0) * el assert cur_bytes % 4 == 0 kernel_block_el = cur_bytes // 4 cur_page_el = kernel_block_el * ratio @@ -158,16 +150,7 @@ def __init__( f"Non-uniform page sizes: {page_size_el} vs {cur_page_el}" ) - block_stride_bytes = cur_bytes - outer_dims = [ - d - for d in range(block_dim) - if kv.stride(d) * el > block_stride_bytes - ] - outer_strides = [kv.stride(d) * el for d in outer_dims] - for outer in iprod(*(range(kv.shape[d]) for d in outer_dims)): - off_bytes = sum(i * s for i, s in zip(outer, outer_strides)) - seg_addrs.append(dp + off_bytes) + seg_addrs.append(dp) if not seg_addrs or page_size_el is None: self._meta = None @@ -515,7 +498,7 @@ def bind_kv_cache( # Bind kv_caches to forward context for layer_name, kv_cache in kv_caches.items(): - forward_context[layer_name].kv_cache = kv_cache + forward_context[layer_name].bind_kv_cache(kv_cache) def is_residual_scattered_for_sp(