diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index a7e26280c90d..05b32d889d06 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -839,7 +839,7 @@ steps:
num_gpus: 2
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+ - vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- vllm/v1/worker/kv_connector_model_runner_mixin.py
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
@@ -866,7 +866,7 @@ steps:
num_gpus: 4
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+ - vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
@@ -2341,7 +2341,7 @@ steps:
num_gpus: 4
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+ - vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
@@ -2377,7 +2377,7 @@ steps:
optional: true
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+ - vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
@@ -2391,7 +2391,7 @@ steps:
num_gpus: 4
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+ - vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
@@ -2405,7 +2405,7 @@ steps:
num_gpus: 4
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+ - vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
@@ -3353,7 +3353,7 @@ steps:
optional: true
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+ - vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- vllm/v1/worker/kv_connector_model_runner_mixin.py
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
@@ -3369,7 +3369,7 @@ steps:
optional: true
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+ - vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
@@ -3384,7 +3384,7 @@ steps:
optional: true
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+ - vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
diff --git a/benchmarks/attention_benchmarks/runner.py b/benchmarks/attention_benchmarks/runner.py
index aa636cd9cb53..cecba287bb4f 100644
--- a/benchmarks/attention_benchmarks/runner.py
+++ b/benchmarks/attention_benchmarks/runner.py
@@ -11,6 +11,7 @@
import logging
import types
from contextlib import contextmanager
+from math import prod
import numpy as np
import torch
@@ -30,10 +31,13 @@
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
- get_kv_cache_layout,
- set_kv_cache_layout,
+ resolve_kv_cache_layout,
+)
+from vllm.v1.kv_cache_interface import (
+ FullAttentionSpec,
+ compute_layer_kv_cache_shape_bytes,
+ reshape_kv_cache,
)
-from vllm.v1.kv_cache_interface import FullAttentionSpec
# ============================================================================
# Backend Configuration
@@ -324,52 +328,23 @@ def _create_input_tensors(
def _create_kv_cache(
config: BenchmarkConfig,
max_num_blocks: int,
- backend_class,
device: torch.device,
dtype: torch.dtype,
) -> list:
- """Create KV cache tensors for all layers using the backend's methods.
-
- Uses the backend's get_kv_cache_shape() and get_kv_cache_stride_order()
- to create the cache with the correct shape and memory layout.
- """
- # Get the logical shape from the backend
- cache_shape = backend_class.get_kv_cache_shape(
- num_blocks=max_num_blocks,
+ """Create KV cache tensors for all layers using the standard allocator."""
+ spec = FullAttentionSpec(
block_size=config.block_size,
num_kv_heads=config.num_kv_heads,
head_size=config.head_dim,
+ dtype=dtype,
)
-
- # Get the stride order for custom memory layout
- try:
- stride_order = backend_class.get_kv_cache_stride_order()
- assert len(stride_order) == len(cache_shape)
- except (AttributeError, NotImplementedError):
- stride_order = tuple(range(len(cache_shape)))
-
- # Permute shape to physical layout order
- physical_shape = tuple(cache_shape[i] for i in stride_order)
-
- # Compute inverse permutation to get back to logical view
- inv_order = [stride_order.index(i) for i in range(len(stride_order))]
-
- # Use fp8 dtype for cache when requested.
- cache_dtype = dtype
- if config.kv_cache_dtype == "fp8":
- from vllm.platforms import current_platform
-
- cache_dtype = current_platform.fp8_dtype()
-
- cache_list = []
- for _ in range(config.num_layers):
- # Allocate in physical layout order (contiguous in memory)
- cache = torch.zeros(*physical_shape, device=device, dtype=cache_dtype)
- # Permute to logical view
- cache = cache.permute(*inv_order)
- cache_list.append(cache)
-
- return cache_list
+ layout = resolve_kv_cache_layout()
+ total_bytes = (
+ prod(compute_layer_kv_cache_shape_bytes(spec, max_num_blocks))
+ * config.num_layers
+ )
+ buf = torch.zeros(total_bytes, device=device, dtype=torch.int8)
+ return reshape_kv_cache(buf, spec, max_num_blocks, config.num_layers, layout)
# ============================================================================
@@ -514,13 +489,6 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
backend_cfg, config, device, dtype
)
- # Set KV cache layout if the backend requires a specific one
- # (e.g., FlashInfer requires HND on SM100/Blackwell for TRTLLM attention)
- required_layout = backend_class.get_required_kv_cache_layout()
- if required_layout is not None:
- set_kv_cache_layout(required_layout)
- get_kv_cache_layout.cache_clear()
-
common_metadata = _build_common_attn_metadata(
q_lens, kv_lens, config.block_size, device
)
@@ -549,9 +517,7 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
config, total_q, device, dtype, quantize_query=quantize_query
)
- cache_list = _create_kv_cache(
- config, max_num_blocks, backend_class, device, dtype
- )
+ cache_list = _create_kv_cache(config, max_num_blocks, device, dtype)
times, mem_stats = _run_single_benchmark(
config,
diff --git a/docs/features/nixl_connector_compatibility.md b/docs/features/nixl_connector_compatibility.md
index 5541cd99bd80..84f564796155 100644
--- a/docs/features/nixl_connector_compatibility.md
+++ b/docs/features/nixl_connector_compatibility.md
@@ -59,7 +59,7 @@ th:not(:first-child) {
1 P and D instances must use the same speculation configuration.
-2 Requires `FLASH_ATTN` or `FLASHINFER` backend **and** `HND` KV cache layout. Enable via `--kv-transfer-config '{"kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'`.
+2 Cross-layer contiguity is achieved by using a `BLHNC` layout (set via `VLLM_KV_CACHE_LAYOUT=BLHNC` or `--enable-cross-layers`).
3 Supported only when HMA is **not** required (i.e., non-hybrid models). Block IDs are remapped automatically. Only P block size < D block size is supported.
diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md
index cb5a3dca035a..7717a57b2f62 100644
--- a/docs/features/nixl_connector_usage.md
+++ b/docs/features/nixl_connector_usage.md
@@ -389,15 +389,6 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con
--kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
```
-### Cross layers blocks
-
-By default, this feature is disabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred.
-To enable this feature:
-
-```bash
---kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'
-```
-
## Example Scripts/Code
Refer to these example scripts in the vLLM repository:
diff --git a/tests/compile/passes/test_fusion_attn.py b/tests/compile/passes/test_fusion_attn.py
index b776f6af98a1..1c60dad80e4d 100644
--- a/tests/compile/passes/test_fusion_attn.py
+++ b/tests/compile/passes/test_fusion_attn.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
+from math import prod
import pytest
import torch._dynamo
@@ -39,7 +40,12 @@
from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
-from vllm.v1.kv_cache_interface import AttentionSpec, get_kv_quant_mode
+from vllm.v1.attention.backends.utils import resolve_kv_cache_layout
+from vllm.v1.kv_cache_interface import (
+ AttentionSpec,
+ compute_layer_kv_cache_shape_bytes,
+ get_kv_quant_mode,
+)
DEVICE_TYPE = current_platform.device_type
FP8_DTYPE = current_platform.fp8_dtype()
@@ -108,29 +114,28 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
- # Fetch the attention backend and kv cache shape and stride order
- attn_backend = self.attn.attn_backend
- kv_cache_shape = attn_backend.get_kv_cache_shape(
- num_blocks, self.block_size, self.num_kv_heads, self.head_size
+ spec = AttentionSpec(
+ block_size=self.block_size,
+ num_kv_heads=self.num_kv_heads,
+ head_size=self.head_size,
+ dtype=self.attn.kv_cache_torch_dtype,
+ kv_quant_mode=get_kv_quant_mode(self.attn.kv_cache_dtype),
)
- try:
- kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
- except (AttributeError, NotImplementedError):
- kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
-
- kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
+ layout = resolve_kv_cache_layout()
+ kv_cache_shape = compute_layer_kv_cache_shape_bytes(spec, num_blocks)
+ kv_cache_stride_order = layout.layer_stride_order
+ physical_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]
- # Create dummy KV cache
raw_tensor = torch.zeros(
- 2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size,
- dtype=self.attn.kv_cache_torch_dtype,
+ prod(kv_cache_shape),
+ dtype=torch.int8,
device=self.device,
)
- raw_tensor = raw_tensor.view(kv_cache_shape)
- kv_cache = raw_tensor.permute(*inv_order)
+ raw_tensor = raw_tensor.view(physical_shape)
+ kv_cache = raw_tensor.permute(*inv_order).view(self.attn.kv_cache_torch_dtype)
self.attn.kv_cache = kv_cache
diff --git a/tests/compile/passes/test_mla_attn_quant_fusion.py b/tests/compile/passes/test_mla_attn_quant_fusion.py
index 0a38ffca483a..c1a5e714499e 100644
--- a/tests/compile/passes/test_mla_attn_quant_fusion.py
+++ b/tests/compile/passes/test_mla_attn_quant_fusion.py
@@ -150,27 +150,14 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
- # MLA KV cache is 3D: (num_blocks, block_size, head_size)
- attn_backend = self.mla_attn.attn_backend
- kv_cache_shape = attn_backend.get_kv_cache_shape(
- num_blocks, self.block_size, 1, self.head_size
- )
- try:
- kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
- except (AttributeError, NotImplementedError):
- kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
-
- ordered_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
- inv_order = [
- kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
- ]
-
- raw_tensor = torch.zeros(
- ordered_shape, dtype=self.kv_cache_dtype, device=self.device
+ # MLA KV cache is 4D: (num_blocks, num_heads=1, block_size, head_size)
+ kv_cache = torch.zeros(
+ (num_blocks, 1, self.block_size, self.head_size),
+ dtype=self.kv_cache_dtype,
+ device=self.device,
)
- kv_cache = raw_tensor.permute(*inv_order)
- self.mla_attn.kv_cache = kv_cache
+ self.mla_attn.bind_kv_cache(kv_cache)
self.attn_metadata = self.builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
diff --git a/tests/compile/passes/test_mla_rope_kvcache_cat_fusion.py b/tests/compile/passes/test_mla_rope_kvcache_cat_fusion.py
index cc3bfb7693ba..b1da94cb20b3 100644
--- a/tests/compile/passes/test_mla_rope_kvcache_cat_fusion.py
+++ b/tests/compile/passes/test_mla_rope_kvcache_cat_fusion.py
@@ -165,29 +165,15 @@ def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata:
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
- # Fetch the attention backend and kv cache shape and stride order
- kv_cache_shape = self.attn_backend.get_kv_cache_shape(
- num_blocks, self.block_size, self.num_kv_heads, self.head_size
- )
- try:
- kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order()
- except (AttributeError, NotImplementedError):
- kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
-
- kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
- inv_order = [
- kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
- ]
-
- raw_tensor = torch.zeros(
- num_blocks * self.block_size * self.num_kv_heads * self.head_size,
+ # MLA uses a 4D KV cache: (num_blocks, num_heads=1, block_size, head_size).
+ kv_cache_shape = (num_blocks, 1, self.block_size, self.head_size)
+ kv_cache = torch.zeros(
+ kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device,
)
- raw_tensor = raw_tensor.view(kv_cache_shape)
- kv_cache = raw_tensor.permute(*inv_order)
- self.mla_attn.kv_cache = kv_cache
+ self.mla_attn.bind_kv_cache(kv_cache)
# Build attn metadata
attn_metadata = self.builder.build(
diff --git a/tests/compile/passes/test_rope_kvcache_fusion.py b/tests/compile/passes/test_rope_kvcache_fusion.py
index b27adfc46f51..4f97f3181c51 100644
--- a/tests/compile/passes/test_rope_kvcache_fusion.py
+++ b/tests/compile/passes/test_rope_kvcache_fusion.py
@@ -34,6 +34,10 @@
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
+from vllm.v1.kv_cache_interface import (
+ FullAttentionSpec,
+ compute_layer_kv_cache_shape_bytes,
+)
INDEX_SELECT_OP = torch.ops.aten.index.Tensor
VLLM_UNIFIED_KV_CACHE_UPDATE_OP = torch.ops.vllm.unified_kv_cache_update
@@ -119,28 +123,21 @@ def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata:
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
- # Fetch the attention backend and kv cache shape and stride order
- kv_cache_shape = self.attn_backend.get_kv_cache_shape(
- num_blocks, self.block_size, self.num_kv_heads, self.head_size
+ kv_cache_shape = compute_layer_kv_cache_shape_bytes(
+ FullAttentionSpec(
+ block_size=self.block_size,
+ num_kv_heads=self.num_kv_heads,
+ head_size=self.head_size,
+ dtype=self.kv_cache_dtype,
+ ),
+ num_blocks,
)
- try:
- kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order()
- except (AttributeError, NotImplementedError):
- kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
-
- kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
- inv_order = [
- kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
- ]
- # Create dummy KV cache
- raw_tensor = torch.zeros(
- 2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size,
- dtype=self.kv_cache_dtype,
+ kv_cache = torch.zeros(
+ kv_cache_shape,
+ dtype=torch.int8,
device=self.device,
- )
- raw_tensor = raw_tensor.view(kv_cache_shape)
- kv_cache = raw_tensor.permute(*inv_order)
+ ).view(self.kv_cache_dtype)
self.attn.kv_cache = kv_cache
diff --git a/tests/distributed/test_kvlayout.py b/tests/distributed/test_kvlayout.py
index c8177f1c7c2f..11f4da241fa7 100644
--- a/tests/distributed/test_kvlayout.py
+++ b/tests/distributed/test_kvlayout.py
@@ -21,7 +21,7 @@ def test_get_kv_connector_cache_layout_without_kv_connector():
with set_current_vllm_config(vllm_config):
# Test with default settings
layout = get_kv_connector_cache_layout()
- assert layout == "NHD"
+ assert layout == "LBNHC"
def test_get_kv_connector_cache_layout_with_lmcache_connector():
@@ -35,7 +35,7 @@ def test_get_kv_connector_cache_layout_with_lmcache_connector():
with set_current_vllm_config(vllm_config):
# Test with default settings
layout = get_kv_connector_cache_layout()
- assert layout == "NHD"
+ assert layout == "LBNHC"
def test_get_kv_connector_cache_layout_with_nixl_connector():
@@ -52,7 +52,7 @@ def test_get_kv_connector_cache_layout_with_nixl_connector():
with set_current_vllm_config(vllm_config):
# Test with default settings
layout = get_kv_connector_cache_layout()
- assert layout == "HND"
+ assert layout == "LBHNC"
def test_get_kv_connector_cache_layout_with_multi_connector():
@@ -75,4 +75,4 @@ def test_get_kv_connector_cache_layout_with_multi_connector():
with set_current_vllm_config(vllm_config):
# Test with default settings
layout = get_kv_connector_cache_layout()
- assert layout == "HND"
+ assert layout == "LBHNC"
diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py
index 9b022a042c81..d9dc80792f94 100644
--- a/tests/kernels/attention/test_cache.py
+++ b/tests/kernels/attention/test_cache.py
@@ -10,7 +10,7 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
from vllm.platforms import current_platform
-from vllm.utils.torch_utils import nvfp4_kv_cache_split_views, set_random_seed
+from vllm.utils.torch_utils import nvfp4_split_data_scale, set_random_seed
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
DTYPES = [torch.bfloat16, torch.float]
@@ -19,7 +19,7 @@
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 256]
BLOCK_SIZES = [8, 16, 32]
-CACHE_LAYOUTS = ["NHD", "HND"]
+CACHE_LAYOUTS = ["LBNHC", "LBHNC"]
KV_SCALE_TYPES = ["tensor", "attn_head"]
# Parameters for MLA tests.
@@ -196,8 +196,8 @@ def test_reshape_and_cache_flash(
torch.set_default_device(device)
torch.accelerator.set_device_index(device)
assert implementation in ["cuda", "triton"]
- if implementation == "triton" and kv_cache_layout == "HND":
- pytest.skip("Triton implementation only supports NHD layout.")
+ if implementation == "triton" and kv_cache_layout == "LBHNC":
+ pytest.skip("Triton implementation only supports LBNHC layout.")
if kv_scale_type == "attn_head" and implementation != "cuda":
pytest.skip("Only CUDA implementation supports attn_head scaling.")
@@ -255,10 +255,8 @@ def test_reshape_and_cache_flash(
nvfp4_key_data = None
nvfp4_value_data = None
if kv_cache_dtype == "nvfp4":
- (nvfp4_key_data,), (key_scale_cache,) = nvfp4_kv_cache_split_views(key_cache)
- (nvfp4_value_data,), (value_scale_cache,) = nvfp4_kv_cache_split_views(
- value_cache
- )
+ nvfp4_key_data, key_scale_cache = nvfp4_split_data_scale(key_cache)
+ nvfp4_value_data, value_scale_cache = nvfp4_split_data_scale(value_cache)
if kv_cache_dtype == "nvfp4":
# Global scale = amax / 448 (per-tensor)
@@ -272,7 +270,7 @@ def test_reshape_and_cache_flash(
v_scale = (value.amax(dim=(0, 2)) / 64.0).to(torch.float32)
def permute_and_compact(x):
- y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
+ y = x if kv_cache_layout == "LBNHC" else x.permute(0, 2, 1, 3)
return y.contiguous()
if kv_cache_dtype != "nvfp4":
@@ -286,8 +284,8 @@ def convert_fp8_local(output, input, scale, kv_dtype):
fp8_input.flatten(0, 2), scale, group_shape=None, out_dtype=output.dtype
).reshape(*input.shape)
else: # per-head: broadcast scale along the head dimension
- # Original code uses dim 2 for NHD, dim 1 for HND
- if kv_cache_layout == "NHD":
+ # Original code uses dim 2 for LBNHC, dim 1 for LBHNC
+ if kv_cache_layout == "LBNHC":
result = fp8_input.to(output.dtype) * scale.view(1, 1, -1, 1)
else:
result = fp8_input.to(output.dtype) * scale.view(1, -1, 1, 1)
@@ -356,28 +354,29 @@ def convert_fp8_local(output, input, scale, kv_dtype):
dequant_nvfp4_kv_cache,
)
- def dequant_nvfp4_cache_nhd(data_cache, scale_cache, global_scale):
- # data_cache: [N, T, H, data_dim] NHD (contiguous inner dims)
- # scale_cache: [N, T, H, scale_dim] NHD (contiguous inner dims)
- # Permute to HND layout for the dequant utility.
- data_hnd = data_cache.permute(0, 2, 1, 3)
- scale_hnd = scale_cache.permute(0, 2, 1, 3)
- result_hnd = dequant_nvfp4_kv_cache(
- data_hnd, scale_hnd, global_scale, head_size, block_size
+ def dequant_nvfp4_cache_hnc(data_cache, scale_cache, global_scale):
+ # data_cache: [H, N, T, data_dim] HNC layout
+ # scale_cache: [H, N, T, scale_dim] HNC layout
+ return dequant_nvfp4_kv_cache(
+ data_cache, scale_cache, global_scale, head_size, block_size
)
- return result_hnd.permute(0, 2, 1, 3) # back to [N, T, H, D]
- result_key_cache = dequant_nvfp4_cache_nhd(
+ result_key_cache = dequant_nvfp4_cache_hnc(
nvfp4_key_data, key_scale_cache, k_scale.item()
)
- result_value_cache = dequant_nvfp4_cache_nhd(
+ result_value_cache = dequant_nvfp4_cache_hnc(
nvfp4_value_data, value_scale_cache, v_scale.item()
)
- # Flatten [num_blocks, block_size] → [num_slots] and index by slot_mapping.
+ # Result is HNC: (num_blocks, num_heads, block_size, head_size).
+ # Flatten to (num_slots, num_heads, head_size) for comparison.
num_slots = num_blocks * block_size
- result_key_flat = result_key_cache.reshape(num_slots, num_heads, head_size)
- result_value_flat = result_value_cache.reshape(num_slots, num_heads, head_size)
+ result_key_flat = result_key_cache.permute(0, 2, 1, 3).reshape(
+ num_slots, num_heads, head_size
+ )
+ result_value_flat = result_value_cache.permute(0, 2, 1, 3).reshape(
+ num_slots, num_heads, head_size
+ )
torch.testing.assert_close(
result_key_flat[slot_mapping], key.float(), atol=1.5, rtol=0.5
@@ -409,7 +408,7 @@ def dequant_nvfp4_cache_nhd(data_cache, scale_cache, global_scale):
for i in range(num_tokens):
block_idx = block_indices_lst[i]
block_offset = block_offsets_lst[i]
- if kv_cache_layout == "NHD":
+ if kv_cache_layout == "LBNHC":
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
else:
diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
index 87a12c2ff395..41b59e036446 100644
--- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py
+++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
@@ -13,7 +13,7 @@
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import (
nvfp4_kv_cache_full_dim,
- nvfp4_kv_cache_split_views,
+ nvfp4_split_data_scale,
set_random_seed,
)
@@ -74,17 +74,18 @@ def make_nvfp4_kv_cache(
kv_scale_val, dtype=torch.float32, device=kv_bf16_hnd.device
)
- # Allocate in HND physical order, permute to NHD logical order.
- # hnd_order swaps dims 2↔3; it is its own inverse.
+ # Production layout: 4D (B, 2*H, N, full_dim) where K heads occupy
+ # the first H heads and V heads occupy the second H heads.
full_dim = nvfp4_kv_cache_full_dim(head_size)
- hnd_order = (0, 1, 3, 2, 4)
- kv_cache = torch.zeros(
- (num_blocks, 2, num_kv_heads, block_size, full_dim),
+ kv_cache_hnd = torch.zeros(
+ (num_blocks, 2 * num_kv_heads, block_size, full_dim),
dtype=torch.uint8,
device=kv_bf16_hnd.device,
- ).permute(*hnd_order)
+ )
+ kv_cache_nhd = kv_cache_hnd.permute(0, 2, 1, 3)
+ k_view_nhd, v_view_nhd = kv_cache_nhd.split(num_kv_heads, dim=-2)
- # Flatten NHD [N, T, H, D] → token tensors [N*T, H, D] for the kernel.
+ # Flatten input KV → token tensors [B*N, H, head_size] for the kernel.
num_tokens = num_blocks * block_size
k_tokens = (
kv_bf16_hnd[:, 0]
@@ -98,22 +99,21 @@ def make_nvfp4_kv_cache(
)
slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_bf16_hnd.device)
- # reshape_and_cache_flash: kernel receives kv_cache[:, 0] and [:, 1]
- # (full K/V buffers containing both data and scale).
torch.ops._C_cache_ops.reshape_and_cache_flash(
k_tokens,
v_tokens,
- kv_cache[:, 0],
- kv_cache[:, 1],
+ k_view_nhd,
+ v_view_nhd,
slot_mapping,
"nvfp4",
kv_scale_tensor,
kv_scale_tensor,
)
- # Split in HND order for trtllm kernel (expects HND numTokensPerPage).
- kv_cache_hnd = kv_cache.permute(*hnd_order)
- (k_data, v_data), (k_scales, v_scales) = nvfp4_kv_cache_split_views(kv_cache_hnd)
+ # Split into data/scale views in HNC order for trtllm kernel.
+ k_cache_hnc, v_cache_hnc = kv_cache_hnd.split(num_kv_heads, dim=1)
+ k_data, k_scales = nvfp4_split_data_scale(k_cache_hnc)
+ v_data, v_scales = nvfp4_split_data_scale(v_cache_hnc)
# Dequantize for the FA2 reference baseline.
ref_k = dequant_nvfp4_kv_cache(
diff --git a/tests/kernels/attention/test_trtllm_kvfp8_dequant.py b/tests/kernels/attention/test_trtllm_kvfp8_dequant.py
index c49ceb03f5b1..b5193e0a8814 100644
--- a/tests/kernels/attention/test_trtllm_kvfp8_dequant.py
+++ b/tests/kernels/attention/test_trtllm_kvfp8_dequant.py
@@ -3,14 +3,14 @@
"""
Standalone unit tests for trtllm_prefill_attn_kvfp8_dequant.
-Tests both contiguous and non-contiguous (cross-layer unified) KV cache
-layouts against a pure-PyTorch reference implementation.
+Tests KV cache layouts against a pure-PyTorch reference implementation.
"""
import pytest
import torch
from vllm.platforms import current_platform
+from vllm.v1.kv_cache_interface import KVCacheLayout
if current_platform.is_rocm():
pytest.skip(
@@ -34,51 +34,31 @@ def to_float8(x, dtype=None):
return x_scl_sat.to(dtype), scale.float().reciprocal()
-def make_contiguous_kv_cache(num_blocks, num_kv_heads, block_size, head_size):
- """Create a standard contiguous fp8 KV cache (HND layout)."""
- raw = torch.randn(
- num_blocks,
- 2,
- num_kv_heads,
- block_size,
- head_size,
- dtype=torch.bfloat16,
- device="cuda",
- )
- kv_cache, scale = to_float8(raw)
- return kv_cache, scale
-
-
-def make_cross_layer_kv_cache(
- num_blocks,
- num_kv_heads,
- block_size,
- head_size,
- num_layers=4,
+def make_random_kv_cache(
+ num_blocks, num_kv_heads, block_size, head_size, layout=KVCacheLayout.LBHNC
):
- """
- Create a non-contiguous per-layer view mimicking cross-layer allocation.
+ """Create a random fp8 KV cache in 5D ``(B, 2, H, N, hs)`` format.
- Physical layout: (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
- Returned view: (num_blocks, 2, num_kv_heads, block_size, head_size)
- with non-contiguous strides on dims 0, 1, 2 (they skip over num_layers).
+ The 4D cache ``(B, H, N, 2*hs)`` is allocated with the physical stride
+ order from *layout*, then reshaped to 5D. Non-identity layouts produce
+ non-contiguous strides on inner dims, matching the actual forward path.
"""
- raw = torch.randn(
+ logical_4d = (num_blocks, num_kv_heads, block_size, 2 * head_size)
+ stride_order = layout.layer_stride_order
+ physical_4d = tuple(logical_4d[i] for i in stride_order)
+ inv_order = [stride_order.index(i) for i in range(4)]
+
+ raw_phys = torch.randn(*physical_4d, dtype=torch.bfloat16, device="cuda")
+ fp8_phys, scale = to_float8(raw_phys)
+ fp8_4d = fp8_phys.permute(*inv_order)
+ kv_5d = fp8_4d.view(
num_blocks,
- 2,
num_kv_heads,
- num_layers,
block_size,
+ 2,
head_size,
- dtype=torch.bfloat16,
- device="cuda",
- )
- fp8_full, scale = to_float8(raw)
- layer_view = fp8_full[:, :, :, 0, :, :]
- assert not layer_view.is_contiguous(), (
- f"Expected non-contiguous view, got strides {layer_view.stride()}"
- )
- return layer_view, scale
+ ).permute(0, 3, 1, 2, 4)
+ return kv_5d, scale
def ref_dequant(kv_cache, block_tables, k_scale, v_scale, dequant_dtype):
@@ -114,7 +94,7 @@ def ref_dequant(kv_cache, block_tables, k_scale, v_scale, dequant_dtype):
@pytest.mark.parametrize("block_size", [16, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("num_pages_per_seq", [3, 8])
-@pytest.mark.parametrize("contiguous", [True, False])
+@pytest.mark.parametrize("layout", [KVCacheLayout.LBHNC, KVCacheLayout.LBNHC])
@torch.inference_mode()
def test_trtllm_kvfp8_dequant(
num_kv_heads: int,
@@ -122,7 +102,7 @@ def test_trtllm_kvfp8_dequant(
block_size: int,
batch_size: int,
num_pages_per_seq: int,
- contiguous: bool,
+ layout: KVCacheLayout,
):
from vllm.v1.attention.backends.flashinfer import (
trtllm_prefill_attn_kvfp8_dequant,
@@ -130,20 +110,13 @@ def test_trtllm_kvfp8_dequant(
torch.set_default_device("cuda")
- if contiguous:
- kv_cache, scale = make_contiguous_kv_cache(
- NUM_BLOCKS,
- num_kv_heads,
- block_size,
- head_size,
- )
- else:
- kv_cache, scale = make_cross_layer_kv_cache(
- NUM_BLOCKS,
- num_kv_heads,
- block_size,
- head_size,
- )
+ kv_cache, scale = make_random_kv_cache(
+ NUM_BLOCKS,
+ num_kv_heads,
+ block_size,
+ head_size,
+ layout=layout,
+ )
k_scale = scale.clone()
v_scale = scale.clone()
@@ -187,7 +160,7 @@ def test_block_tables_with_zero_pages():
torch.set_default_device("cuda")
num_kv_heads, block_size, head_size = 8, 16, 64
- kv_cache, scale = make_contiguous_kv_cache(
+ kv_cache, scale = make_random_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
@@ -234,7 +207,7 @@ def test_all_zero_block_tables():
torch.set_default_device("cuda")
num_kv_heads, block_size, head_size = 4, 16, 64
- kv_cache, scale = make_contiguous_kv_cache(
+ kv_cache, scale = make_random_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
@@ -266,7 +239,7 @@ def test_different_k_v_scales():
torch.set_default_device("cuda")
num_kv_heads, block_size, head_size = 8, 16, 64
- kv_cache, _ = make_contiguous_kv_cache(
+ kv_cache, _ = make_random_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
@@ -299,7 +272,7 @@ def test_single_page_per_seq():
torch.set_default_device("cuda")
num_kv_heads, block_size, head_size = 8, 16, 128
- kv_cache, scale = make_contiguous_kv_cache(
+ kv_cache, scale = make_random_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
@@ -332,7 +305,7 @@ def test_large_page_indices():
num_kv_heads, block_size, head_size = 8, 16, 128
large_num_blocks = 32768
- kv_cache, scale = make_contiguous_kv_cache(
+ kv_cache, scale = make_random_kv_cache(
large_num_blocks,
num_kv_heads,
block_size,
@@ -369,7 +342,7 @@ def test_large_block_size():
torch.set_default_device("cuda")
num_kv_heads, block_size, head_size = 4, 64, 128
- kv_cache, scale = make_contiguous_kv_cache(
+ kv_cache, scale = make_random_kv_cache(
NUM_BLOCKS,
num_kv_heads,
block_size,
@@ -395,46 +368,3 @@ def test_large_block_size():
ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
-
-
-@torch.inference_mode()
-def test_cross_layer_many_layers():
- """
- Non-contiguous with 36 layers -- matches real gpt-oss-120b.
- Strides are far from contiguous (factor of 36 in the gaps).
- """
- from vllm.v1.attention.backends.flashinfer import (
- trtllm_prefill_attn_kvfp8_dequant,
- )
-
- torch.set_default_device("cuda")
- num_kv_heads, block_size, head_size = 8, 16, 64
- num_layers = 36
-
- kv_cache, scale = make_cross_layer_kv_cache(
- NUM_BLOCKS,
- num_kv_heads,
- block_size,
- head_size,
- num_layers=num_layers,
- )
- k_scale = v_scale = scale.clone()
-
- block_tables = torch.randint(
- 1,
- NUM_BLOCKS,
- (4, 6),
- dtype=torch.int32,
- device="cuda",
- )
-
- mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant(
- kv_cache,
- block_tables,
- k_scale,
- v_scale,
- torch.bfloat16,
- )
- ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
-
- torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3)
diff --git a/tests/test_attention_backend_registry.py b/tests/test_attention_backend_registry.py
index 034749874d7b..6c5e0d7fa54f 100644
--- a/tests/test_attention_backend_registry.py
+++ b/tests/test_attention_backend_registry.py
@@ -38,11 +38,6 @@ def get_builder_cls():
"""Mock builder class."""
return None
- @staticmethod
- def get_required_kv_cache_layout():
- """Mock KV cache layout."""
- return None
-
class CustomMambaAttentionImpl(AttentionImpl):
"""Mock custom mamba attention implementation for testing."""
@@ -71,11 +66,6 @@ def get_builder_cls():
"""Mock builder class."""
return None
- @staticmethod
- def get_required_kv_cache_layout():
- """Mock KV cache layout."""
- return None
-
def test_custom_is_not_alias_of_any_backend():
# Get all members of AttentionBackendEnum
diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py
index 62643032edb4..5b924c0ca11a 100644
--- a/tests/v1/attention/test_attention_backends.py
+++ b/tests/v1/attention/test_attention_backends.py
@@ -27,9 +27,10 @@
from vllm.v1.attention.backend import AttentionType, CommonAttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import (
+ resolve_kv_cache_layout,
set_kv_cache_layout,
)
-from vllm.v1.kv_cache_interface import FullAttentionSpec
+from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheLayout
BACKENDS_TO_TEST = [
AttentionBackendEnum.FLASH_ATTN,
@@ -105,26 +106,18 @@ def create_and_prepopulate_kv_cache(
device: torch.device,
num_blocks: int,
common_attn_metadata: CommonAttentionMetadata,
+ layout: KVCacheLayout,
randomize_blocks: bool = True,
) -> torch.Tensor:
"""Create and prepopulate a KV cache with context data.
- Args:
- k_contexts: List of key context tensors for each sequence
- v_contexts: List of value context tensors for each sequence
- seq_lens: List of sequence lengths
- block_size: Size of each block
- num_kv_heads: Number of KV heads
- head_size: Size of each head
- dtype: Data type for the cache
- device: Device to create the cache on
- num_blocks: Total number of blocks in the cache
- block_table: Block table tensor to populate
- randomize_blocks: Whether to randomly permute blocks
- or use sequential order
+ Mirrors production's ``reshape_kv_cache``: allocates a flat buffer in
+ the physical order dictated by *layout*, then permutes to the logical
+ ``[B, H, N, C]`` shape that every backend expects.
Returns:
- Tuple of (kv_cache, updated_block_table)
+ A 4D tensor in logical ``(num_blocks, num_kv_heads, block_size,
+ 2 * head_size)`` order with strides determined by *layout*.
"""
batch_size = len(k_contexts)
seq_lens = common_attn_metadata.seq_lens.cpu()
@@ -136,43 +129,44 @@ def create_and_prepopulate_kv_cache(
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
- # Create KV cache and populate in (2, num_blocks, ...) layout for easy
- # flat indexing, then transpose to (num_blocks, 2, ...) layout.
- kv_cache = torch.zeros(
- 2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device
- )
- kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)
-
- # Populate the cache with the context tokens
- # Start from block_id=1 since block_id=0 is considered the null block
- start_block_idx = 1
+ # --- allocate ---------------------------------------------------------
+ # Logical 4D shape is always [B, H, N, C].
+ logical_4d = (num_blocks, num_kv_heads, block_size, 2 * head_size)
+ stride_order = layout.layer_stride_order
+ physical_4d = tuple(logical_4d[i] for i in stride_order)
+ inv_order = [stride_order.index(i) for i in range(4)]
+
+ kv_cache_physical = torch.zeros(physical_4d, dtype=dtype, device=device)
+ # Permute to logical [B, H, N, C] — mirrors reshape_kv_cache.
+ kv_cache = kv_cache_physical.permute(*inv_order)
+
+ # --- populate ---------------------------------------------------------
+ # Write context tokens into the cache via the logical view:
+ # kv_cache[block, :, token_in_block, :] routes correctly regardless
+ # of physical layout.
+ start_block_idx = 1 # block 0 is the null block
for i in range(batch_size):
k_context, v_context = k_contexts[i], v_contexts[i]
- start = start_block_idx * block_size
- end = start + k_context.shape[0]
- kv_cache_flat[0, start:end, ...] = k_context
- kv_cache_flat[1, start:end, ...] = v_context
-
- # Stay block aligned and allocate enough blocks for the new tokens
+ for t in range(k_context.shape[0]):
+ blk = start_block_idx + t // block_size
+ off = t % block_size
+ kv_cache[blk, :, off, :head_size] = k_context[t]
+ kv_cache[blk, :, off, head_size:] = v_context[t]
start_block_idx += cdiv(int(seq_lens[i]), block_size)
- # Transpose to (num_blocks, 2, ...) layout
- kv_cache = kv_cache.transpose(0, 1).contiguous()
-
blocks_end = start_block_idx
# Permute the context blocks (excluding block 0 which is null)
if randomize_blocks:
- # Random permutation starting from block 1
perm = torch.randperm(blocks_end - 1) + 1
else:
- # Sequential order starting from block 1
perm = torch.arange(1, blocks_end)
inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
- # Add 1 to account for starting from block 1
inv_perm[1:] = torch.argsort(perm) + 1
- kv_cache[1:blocks_end, ...] = kv_cache[perm, ...]
+ # Shuffle on the physical tensor (dim-0 is always the block dim for
+ # layer-compact layouts).
+ kv_cache_physical[1:blocks_end, ...] = kv_cache_physical[perm, ...]
# Construct the right block table
# Start from block_id=1 since block_id=0 is considered the null block
@@ -460,6 +454,7 @@ def _test_backend_correctness(
common_attn_metadata.causal = causal
# 3. Simulate Paged KV Cache and a realistic slot_mapping
+ layout = resolve_kv_cache_layout()
kv_cache = create_and_prepopulate_kv_cache(
k_contexts=k_contexts,
v_contexts=v_contexts,
@@ -470,6 +465,7 @@ def _test_backend_correctness(
device=device,
num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000,
common_attn_metadata=common_attn_metadata,
+ layout=layout,
randomize_blocks=True,
)
@@ -477,54 +473,26 @@ def _test_backend_correctness(
# Note: flex_attention has known Triton kernel compatibility issues
# with test infrastructures
for backend_name in backend_to_test:
- reset_kv_cache_layout = False
-
- # Resolve backend class for both enum and string names.
- actual_backend = backend_name
- if backend_name == "FLEX_ATTENTION_SLOW":
- actual_backend = AttentionBackendEnum.FLEX_ATTENTION
- if hasattr(actual_backend, "get_class"):
- backend_cls = actual_backend.get_class()
- else:
- backend_cls = None
-
- if backend_name == AttentionBackendEnum.FLASHINFER:
- set_kv_cache_layout("HND")
- reset_kv_cache_layout = True
-
- # Apply stride order like runtime does in
- # _reshape_kv_cache (attn_utils.py:182-210): permute to physical
- # layout, make contiguous, then permute to logical layout.
kv_cache_for_backend = kv_cache
- if backend_cls is not None:
- try:
- stride_order = backend_cls.get_kv_cache_stride_order()
- except (AttributeError, NotImplementedError):
- stride_order = tuple(range(kv_cache.ndim))
- if stride_order != tuple(range(kv_cache.ndim)):
- inv_order = [stride_order.index(i) for i in range(len(stride_order))]
- kv_cache_for_backend = (
- kv_cache.permute(*stride_order).contiguous().permute(*inv_order)
- )
- try:
- backend_output = run_attention_backend(
- backend_name,
- kv_cache_spec,
- ["placeholder"],
- vllm_config,
- device,
- common_attn_metadata,
- query_vllm,
- key_vllm,
- value_vllm,
- kv_cache_for_backend,
- sliding_window=sliding_window,
- attn_type=attn_type,
- )
- finally:
- if reset_kv_cache_layout:
- set_kv_cache_layout(None)
+ # FlashInfer reads the layout at plan time; override to match
+ # the physical order of the test cache.
+ set_kv_cache_layout(layout.name)
+
+ backend_output = run_attention_backend(
+ backend_name,
+ kv_cache_spec,
+ ["placeholder"],
+ vllm_config,
+ device,
+ common_attn_metadata,
+ query_vllm,
+ key_vllm,
+ value_vllm,
+ kv_cache_for_backend,
+ sliding_window=sliding_window,
+ attn_type=attn_type,
+ )
# Check shape and dtype consistency
assert backend_output.shape == sdpa_output.shape, (
@@ -571,7 +539,10 @@ def error_msg(msg: str, backend_name: str):
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_causal_backend_correctness(
- default_vllm_config, batch_spec_name: str, model: str, tensor_parallel_size: int
+ default_vllm_config,
+ batch_spec_name: str,
+ model: str,
+ tensor_parallel_size: int,
):
"""Test backend's correctness with causal attention."""
@@ -656,7 +627,9 @@ def causal_mask_mod(
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_sliding_window_backend_correctness(
- batch_spec_name: str, model: str, tensor_parallel_size: int
+ batch_spec_name: str,
+ model: str,
+ tensor_parallel_size: int,
):
"""Test backend's correctness with sliding window attention."""
@@ -718,7 +691,9 @@ def sliding_window_mask_mod(
@pytest.mark.parametrize("model", ["google/embeddinggemma-300m"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_sliding_window_encoder_backend_correctness(
- batch_spec_name: str, model: str, tensor_parallel_size: int
+ batch_spec_name: str,
+ model: str,
+ tensor_parallel_size: int,
):
"""Test backend's correctness with sliding window attention."""
@@ -775,7 +750,9 @@ def bidi_sliding_window_mask_mod(
)
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_non_causal_backend_correctness(
- default_vllm_config, batch_spec_name: str, model: str
+ default_vllm_config,
+ batch_spec_name: str,
+ model: str,
):
"""Test backend's correctness with non-causal (bidirectional) decoder
attention, as used by DFlash speculative decoding."""
diff --git a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py
index 159bb8af3fb9..a7514381f0a2 100644
--- a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py
+++ b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py
@@ -17,13 +17,13 @@ def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_
"""
device = torch.device("cuda")
- # storage_block_size = block_size // compress_ratio = 256 // 4 = 64
+ # storage_block_size = block_size // tokens_per_state = 256 // 4 = 64
kv_cache_spec = MLAAttentionSpec(
block_size=256,
num_kv_heads=1,
head_size=128,
dtype=torch.bfloat16,
- compress_ratio=4,
+ tokens_per_state=4,
)
vllm_config = create_vllm_config(max_model_len=1024)
builder = DeepseekV32IndexerMetadataBuilder(
diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 109e56cb3838..c77f117e6f1f 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -205,8 +205,9 @@ def create_and_prepopulate_kv_cache(
else:
kv_entry_size = head_size
+ # Create MLA KV cache: (num_blocks, num_heads=1, block_size, kv_entry_size)
kv_cache = torch.zeros(
- num_blocks, block_size, kv_entry_size, dtype=torch.uint8, device=device
+ num_blocks, 1, block_size, kv_entry_size, dtype=torch.uint8, device=device
)
scale_tensor = (
scale
@@ -215,9 +216,9 @@ def create_and_prepopulate_kv_cache(
)
scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
else:
- # Create MLA KV cache: (num_blocks, block_size, head_size)
+ # Create MLA KV cache: (num_blocks, num_heads=1, block_size, head_size)
kv_cache = torch.zeros(
- num_blocks, block_size, head_size, dtype=dtype, device=device
+ num_blocks, 1, block_size, head_size, dtype=dtype, device=device
)
kv_cache_flat = kv_cache.view(-1, head_size)
@@ -238,7 +239,7 @@ def create_and_prepopulate_kv_cache(
ops.concat_and_cache_mla(
kv_c_context,
k_pe_context.squeeze(1),
- kv_cache,
+ kv_cache.squeeze(1),
slots,
kv_cache_dtype=kv_cache_dtype,
scale=scale_tensor,
@@ -361,7 +362,7 @@ def forward_impl(
ops.concat_and_cache_mla(
kv_c,
k_pe.squeeze(1),
- kv_cache,
+ kv_cache.squeeze(1),
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=kv_cache_dtype,
scale=self._k_scale,
@@ -497,7 +498,7 @@ def forward_impl(
ops.concat_and_cache_mla(
kv_c,
k_pe.squeeze(1),
- kv_cache,
+ kv_cache.squeeze(1),
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=kv_cache_dtype,
scale=self._k_scale,
diff --git a/tests/v1/attention/test_trtllm_attention_integration.py b/tests/v1/attention/test_trtllm_attention_integration.py
index 06c5844508f4..d7b65ebb2d1a 100644
--- a/tests/v1/attention/test_trtllm_attention_integration.py
+++ b/tests/v1/attention/test_trtllm_attention_integration.py
@@ -9,6 +9,7 @@
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+from tests.v1.attention.test_attention_backends import create_and_prepopulate_kv_cache
from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
@@ -16,14 +17,13 @@
)
from vllm.config import set_current_vllm_config
from vllm.platforms import current_platform
-from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import nvfp4_kv_cache_full_dim, set_random_seed
from vllm.v1.attention.backends.utils import (
PerLayerParameters,
- get_kv_cache_layout,
+ resolve_kv_cache_layout,
set_kv_cache_layout,
)
-from vllm.v1.kv_cache_interface import FullAttentionSpec, KVQuantMode
+from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheLayout, KVQuantMode
if not current_platform.is_device_capability_family(100):
pytest.skip(
@@ -86,90 +86,6 @@ def _mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
}
-def _create_hnd_kv_cache(
- k_contexts,
- v_contexts,
- block_size,
- num_kv_heads,
- head_size,
- dtype,
- device,
- num_blocks,
- common_attn_metadata,
-):
- """Create and populate a KV cache with HND-compatible strides.
-
- The returned tensor has logical shape
- (num_blocks, 2, block_size, num_kv_heads, head_size) but is physically
- laid out as (num_blocks, 2, num_kv_heads, block_size, head_size) so that
- ``kv_cache.permute(0, 1, 3, 2, 4)`` yields a contiguous HND view.
- """
- seq_lens = common_attn_metadata.seq_lens.cpu()
- query_lens = (
- common_attn_metadata.query_start_loc_cpu[1:]
- - common_attn_metadata.query_start_loc_cpu[:-1]
- )
- block_table = common_attn_metadata.block_table_tensor
- slot_mapping = common_attn_metadata.slot_mapping
- batch_size = len(k_contexts)
-
- # Build cache in (2, num_blocks, block_size, num_kv_heads, head_size)
- # then convert to HND format (same approach as test_attention_backends.py).
- kv_cache_raw = torch.zeros(
- 2,
- num_blocks,
- block_size,
- num_kv_heads,
- head_size,
- dtype=dtype,
- device=device,
- )
- kv_cache_flat = kv_cache_raw.view(2, -1, num_kv_heads, head_size)
-
- start_block_idx = 1
- for i in range(batch_size):
- k_ctx, v_ctx = k_contexts[i], v_contexts[i]
- start = start_block_idx * block_size
- end = start + k_ctx.shape[0]
- kv_cache_flat[0, start:end] = k_ctx
- kv_cache_flat[1, start:end] = v_ctx
- start_block_idx += cdiv(int(seq_lens[i]), block_size)
-
- blocks_end = start_block_idx
-
- # Randomly permute blocks (starting from block 1; block 0 is null).
- perm = torch.randperm(blocks_end - 1) + 1
- inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
- inv_perm[1:] = torch.argsort(perm) + 1
- kv_cache_raw[:, 1:blocks_end] = kv_cache_raw[:, perm]
-
- # Build block table.
- start_block_idx = 1
- for i in range(batch_size):
- n_blocks = cdiv(int(seq_lens[i]), block_size)
- block_table[i, :n_blocks] = inv_perm[
- start_block_idx : start_block_idx + n_blocks
- ]
- start_block_idx += n_blocks
-
- # Build slot mapping that is consistent with the block table.
- for i in range(batch_size):
- ctx_len = int(seq_lens[i]) - int(query_lens[i])
- token_offsets = torch.arange(int(query_lens[i])) + ctx_len
- block_indices = token_offsets // block_size
- intra_block_offsets = token_offsets % block_size
- start = common_attn_metadata.query_start_loc_cpu[i]
- end = common_attn_metadata.query_start_loc_cpu[i + 1]
- slot_mapping[start:end] = block_table[
- i, block_indices
- ] * block_size + intra_block_offsets.to(device)
-
- # Transpose to FlashInfer logical shape then make HND-strided.
- kv_cache = kv_cache_raw.transpose(0, 1)
- kv_cache = kv_cache.transpose(2, 3).contiguous().transpose(2, 3)
- return kv_cache
-
-
def _create_nvfp4_hnd_kv_cache(
k_contexts,
v_contexts,
@@ -182,43 +98,15 @@ def _create_nvfp4_hnd_kv_cache(
common_attn_metadata,
kv_scale_val,
):
- """Create an nvfp4 KV cache by quantizing bf16 context via
- reshape_and_cache_flash, using the same block-table layout as
- _create_hnd_kv_cache.
-
- The returned tensor is dtype ``uint8`` with shape
- ``(num_blocks, 2, block_size, num_kv_heads, full_dim)`` in logical
- (NHD) order, but physically permuted to HND layout via stride order
- ``(0, 1, 3, 2, 4)`` (i.e. ``num_kv_heads`` before ``block_size``).
-
- The last dimension ``full_dim = head_size // 2 + head_size // 16``
- packs two regions contiguously:
- - **FP4 data** (``head_size // 2`` bytes): pairs of E2M1 values,
- two per byte.
- - **FP8 block scales** (``head_size // 16`` bytes): one E4M3
- scale per 16-element block.
-
- Dimension 1 indexes K (``[:, 0]``) and V (``[:, 1]``).
-
- Args:
- k_contexts: List of key context tensors, one per sequence.
- v_contexts: List of value context tensors, one per sequence.
- block_size: Number of tokens per cache block.
- num_kv_heads: Number of key/value heads.
- head_size: Head dimension (must be divisible by 16).
- dtype: Source data type for the bf16 intermediate cache.
- device: Target device.
- num_blocks: Total number of blocks to allocate.
- common_attn_metadata: Metadata containing block tables and
- sequence lengths.
- kv_scale_val: Scalar float used as both k_scale and v_scale
- during quantization.
-
- Returns:
- ``torch.Tensor``: The nvfp4 kv_cache tensor (uint8, HND-strided).
+ """Create an nvfp4 KV cache with 2H head layout.
+
+ The returned tensor is dtype ``uint8`` with logical shape
+ ``(num_blocks, 2 * num_kv_heads, block_size, full_dim)`` where K occupies
+ the first H heads and V occupies the next H heads, and
+ ``full_dim = head_size // 2 + head_size // 16`` packs FP4 data and
+ FP8 block scales per head.
"""
- # First create a bf16 HND cache so block tables are populated.
- bf16_cache = _create_hnd_kv_cache(
+ bf16_cache = create_and_prepopulate_kv_cache(
k_contexts,
v_contexts,
block_size,
@@ -228,20 +116,18 @@ def _create_nvfp4_hnd_kv_cache(
device,
num_blocks,
common_attn_metadata,
+ layout=KVCacheLayout.LBHNC,
)
- # Allocate nvfp4 cache: same shape but with full_dim (data + scale).
full_dim = nvfp4_kv_cache_full_dim(head_size)
- hnd_order = (0, 1, 3, 2, 4)
nvfp4_cache = torch.zeros(
- (num_blocks, 2, num_kv_heads, block_size, full_dim),
+ (num_blocks, 2 * num_kv_heads, block_size, full_dim),
dtype=torch.uint8,
device=device,
- ).permute(*hnd_order)
+ )
+ k_cache = nvfp4_cache[:, :num_kv_heads]
+ v_cache = nvfp4_cache[:, num_kv_heads:]
- # Flatten bf16 context into tokens and quantize via reshape_and_cache_flash.
- # bf16_cache is (num_blocks, 2, block_size, num_kv_heads, head_size) logical
- # with HND physical strides.
block_table = common_attn_metadata.block_table_tensor
seq_lens = common_attn_metadata.seq_lens.cpu()
query_lens = (
@@ -254,22 +140,29 @@ def _create_nvfp4_hnd_kv_cache(
ctx_len = int(seq_lens[i]) - int(query_lens[i])
if ctx_len == 0:
continue
- # Gather context tokens from the bf16 cache using block table.
n_ctx_blocks = (ctx_len + block_size - 1) // block_size
blocks = block_table[i, :n_ctx_blocks]
- # bf16_cache[:, kv_idx] is (num_blocks, block_size, num_kv_heads, head_size)
- k_ctx = bf16_cache[blocks, 0].reshape(-1, num_kv_heads, head_size)[:ctx_len]
- v_ctx = bf16_cache[blocks, 1].reshape(-1, num_kv_heads, head_size)[:ctx_len]
- # Build slot mapping for these context tokens.
+ # bf16_cache is (B, H, N, 2*head_size); extract K and V from last dim.
+ k_ctx = (
+ bf16_cache[blocks, :, :, :head_size]
+ .transpose(1, 2)
+ .reshape(-1, num_kv_heads, head_size)[:ctx_len]
+ )
+ v_ctx = (
+ bf16_cache[blocks, :, :, head_size:]
+ .transpose(1, 2)
+ .reshape(-1, num_kv_heads, head_size)[:ctx_len]
+ )
token_offsets = torch.arange(ctx_len, device=device)
block_indices = token_offsets // block_size
intra_offsets = token_offsets % block_size
slots = block_table[i, block_indices] * block_size + intra_offsets
+ # reshape_and_cache_flash expects (B, N, H, D) cache views
torch.ops._C_cache_ops.reshape_and_cache_flash(
k_ctx,
v_ctx,
- nvfp4_cache[:, 0],
- nvfp4_cache[:, 1],
+ k_cache.transpose(1, 2),
+ v_cache.transpose(1, 2),
slots,
"nvfp4",
kv_scale_t,
@@ -358,7 +251,7 @@ def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len):
common_attn_metadata = create_common_attn_metadata(batch_spec, BLOCK_SIZE, device)
- # 2. Create HND KV cache
+ # 2. Create HNC KV cache
is_nvfp4 = kv_cache_dtype == "nvfp4"
if is_nvfp4:
# Compute a global scale from the context data.
@@ -378,7 +271,7 @@ def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len):
)
else:
kv_scale_val = 1.0
- kv_cache = _create_hnd_kv_cache(
+ kv_cache = create_and_prepopulate_kv_cache(
k_contexts,
v_contexts,
BLOCK_SIZE,
@@ -388,11 +281,12 @@ def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len):
device,
NUM_GPU_BLOCKS,
common_attn_metadata,
+ layout=KVCacheLayout.LBHNC,
)
# 3. Run through FlashInfer with TRTLLM enabled
- set_kv_cache_layout("HND")
- get_kv_cache_layout.cache_clear()
+ set_kv_cache_layout("LBHNC")
+ resolve_kv_cache_layout.cache_clear()
try:
is_nvfp4 = kv_cache_dtype == "nvfp4"
@@ -506,7 +400,7 @@ def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len):
finally:
set_kv_cache_layout(None)
- get_kv_cache_layout.cache_clear()
+ resolve_kv_cache_layout.cache_clear()
@pytest.mark.parametrize(
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index 68ad7bc42ef0..6174fb78e19a 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -788,36 +788,19 @@ def test_get_kv_cache_configs_multiple_workers():
ref_kv_cache_spec.page_size_bytes * 2 * 10,
],
)
- assert kv_cache_configs == [
- KVCacheConfig(
- num_blocks=10,
- kv_cache_tensors=[
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"]
- ),
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"]
- ),
- ],
- kv_cache_groups=[
- KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
- ],
- ),
- KVCacheConfig(
- num_blocks=10,
- kv_cache_tensors=[
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"]
- ),
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"]
- ),
- ],
- kv_cache_groups=[
- KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
- ],
- ),
- ]
+ expected = KVCacheConfig(
+ num_blocks=10,
+ kv_cache_tensors=[
+ KVCacheTensor(
+ size=ref_kv_cache_spec.page_size_bytes * 10 * 2,
+ shared_by=[["layer1"], ["layer2"]],
+ ),
+ ],
+ kv_cache_groups=[
+ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
+ ],
+ )
+ assert kv_cache_configs == [expected, expected]
# Different available memory. This is the case for TP.
# Use the smallest memory available.
@@ -829,36 +812,7 @@ def test_get_kv_cache_configs_multiple_workers():
ref_kv_cache_spec.page_size_bytes * 2 * 20,
],
)
- assert kv_cache_configs == [
- KVCacheConfig(
- num_blocks=10,
- kv_cache_tensors=[
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"]
- ),
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"]
- ),
- ],
- kv_cache_groups=[
- KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
- ],
- ),
- KVCacheConfig(
- num_blocks=10,
- kv_cache_tensors=[
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"]
- ),
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"]
- ),
- ],
- kv_cache_groups=[
- KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
- ],
- ),
- ]
+ assert kv_cache_configs == [expected, expected]
# Different KV cache specs. This is the case for PP.
different_layer_specs = [
@@ -885,7 +839,8 @@ def test_get_kv_cache_configs_multiple_workers():
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"]
+ size=ref_kv_cache_spec.page_size_bytes * 10,
+ shared_by=[["layer1"]],
),
],
kv_cache_groups=[
@@ -896,10 +851,8 @@ def test_get_kv_cache_configs_multiple_workers():
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"]
- ),
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"]
+ size=ref_kv_cache_spec.page_size_bytes * 10 * 2,
+ shared_by=[["layer2"], ["layer3"]],
),
],
kv_cache_groups=[
@@ -929,64 +882,37 @@ def test_get_kv_cache_configs_multiple_workers():
kv_cache_configs = get_kv_cache_configs(
vllm_config,
tp_pp_kv_cache_specs,
- [
- ref_kv_cache_spec.page_size_bytes * 2 * 10,
- ref_kv_cache_spec.page_size_bytes * 2 * 10,
- ref_kv_cache_spec.page_size_bytes * 2 * 10,
- ref_kv_cache_spec.page_size_bytes * 2 * 10,
+ [ref_kv_cache_spec.page_size_bytes * 2 * 10] * 4,
+ )
+ expected_12 = KVCacheConfig(
+ num_blocks=10,
+ kv_cache_tensors=[
+ KVCacheTensor(
+ size=ref_kv_cache_spec.page_size_bytes * 10 * 2,
+ shared_by=[["layer1"], ["layer2"]],
+ ),
+ ],
+ kv_cache_groups=[
+ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
+ ],
+ )
+ expected_3 = KVCacheConfig(
+ num_blocks=10,
+ kv_cache_tensors=[
+ KVCacheTensor(
+ size=ref_kv_cache_spec.page_size_bytes * 10,
+ shared_by=[["layer3"]],
+ ),
+ ],
+ kv_cache_groups=[
+ KVCacheGroupSpec(["layer3"], ref_kv_cache_spec),
],
)
assert kv_cache_configs == [
- KVCacheConfig(
- num_blocks=10,
- kv_cache_tensors=[
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"]
- ),
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"]
- ),
- ],
- kv_cache_groups=[
- KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
- ],
- ),
- KVCacheConfig(
- num_blocks=10,
- kv_cache_tensors=[
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"]
- ),
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"]
- ),
- ],
- kv_cache_groups=[
- KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
- ],
- ),
- KVCacheConfig(
- num_blocks=10,
- kv_cache_tensors=[
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"]
- ),
- ],
- kv_cache_groups=[
- KVCacheGroupSpec(["layer3"], ref_kv_cache_spec),
- ],
- ),
- KVCacheConfig(
- num_blocks=10,
- kv_cache_tensors=[
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"]
- ),
- ],
- kv_cache_groups=[
- KVCacheGroupSpec(["layer3"], ref_kv_cache_spec),
- ],
- ),
+ expected_12,
+ expected_12,
+ expected_3,
+ expected_3,
]
# Different workers have different types of layers. This is the case for
@@ -1014,10 +940,8 @@ def test_get_kv_cache_configs_multiple_workers():
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"]
- ),
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"]
+ size=ref_kv_cache_spec.page_size_bytes * 10 * 2,
+ shared_by=[["layer1"], ["layer2"]],
),
],
kv_cache_groups=[
@@ -1029,10 +953,8 @@ def test_get_kv_cache_configs_multiple_workers():
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"]
- ),
- KVCacheTensor(
- size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer4"]
+ size=ref_kv_cache_spec.page_size_bytes * 10 * 2,
+ shared_by=[["layer3"], ["layer4"]],
),
],
kv_cache_groups=[
@@ -1059,10 +981,7 @@ def test_get_kv_cache_configs_multiple_workers():
kv_cache_configs = get_kv_cache_configs(
vllm_config,
different_type_layer_specs,
- [
- ref_kv_cache_spec.page_size_bytes * 10,
- ref_kv_cache_spec.page_size_bytes * 10,
- ],
+ [ref_kv_cache_spec.page_size_bytes * 10] * 2,
)
assert kv_cache_configs == [
KVCacheConfig(
@@ -1070,7 +989,7 @@ def test_get_kv_cache_configs_multiple_workers():
kv_cache_tensors=[
KVCacheTensor(
size=ref_kv_cache_spec.page_size_bytes * 10,
- shared_by=["layer1", "layer2", "layer3"],
+ shared_by=[["layer1", "layer2", "layer3"]],
),
],
kv_cache_groups=[
@@ -1084,7 +1003,7 @@ def test_get_kv_cache_configs_multiple_workers():
kv_cache_tensors=[
KVCacheTensor(
size=ref_kv_cache_spec.page_size_bytes * 10,
- shared_by=["layer4", "layer5", "layer6"],
+ shared_by=[["layer4", "layer5", "layer6"]],
),
],
kv_cache_groups=[
@@ -1151,7 +1070,7 @@ def test_get_kv_cache_configs_pp_sharding(asymmetric_memory):
kv_cache_tensors=[
KVCacheTensor(
size=ref_kv_cache_spec.page_size_bytes * expected_num_blocks,
- shared_by=["layer1"],
+ shared_by=[["layer1"]],
),
],
kv_cache_groups=[KVCacheGroupSpec(["layer1"], ref_kv_cache_spec)],
@@ -1161,7 +1080,7 @@ def test_get_kv_cache_configs_pp_sharding(asymmetric_memory):
kv_cache_tensors=[
KVCacheTensor(
size=ref_kv_cache_spec.page_size_bytes * expected_num_blocks,
- shared_by=["layer2"],
+ shared_by=[["layer2"]],
),
],
kv_cache_groups=[KVCacheGroupSpec(["layer2"], ref_kv_cache_spec)],
@@ -1430,7 +1349,7 @@ def test_allocate_with_lookahead():
config = KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
- KVCacheTensor(size=100, shared_by=["layer1"]),
+ KVCacheTensor(size=100, shared_by=[["layer1"]]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec(block_size=block_size)),
@@ -1505,11 +1424,14 @@ def test_get_kv_cache_config_one_worker():
vllm_config, [kv_cache_specs_full], [mem_per_block_per_layer * 2 * 32]
)[0]
print(kv_cache_config_full)
+
assert kv_cache_config_full == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
- KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]),
- KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32 * 2,
+ shared_by=[["layer_1"], ["layer_2"]],
+ ),
],
kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())],
)
@@ -1525,8 +1447,10 @@ def test_get_kv_cache_config_one_worker():
assert kv_cache_config_sliding == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
- KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]),
- KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32 * 2,
+ shared_by=[["layer_1"], ["layer_2"]],
+ ),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_sliding_window_spec())
@@ -1545,8 +1469,10 @@ def test_get_kv_cache_config_one_worker():
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
- KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]),
- KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32 * 2,
+ shared_by=[["layer_1"], ["layer_2"]],
+ ),
],
kv_cache_groups=[
KVCacheGroupSpec(
@@ -1568,7 +1494,8 @@ def test_get_kv_cache_config_one_worker():
num_blocks=64,
kv_cache_tensors=[
KVCacheTensor(
- size=mem_per_block_per_layer * 64, shared_by=["layer_1", "layer_2"]
+ size=mem_per_block_per_layer * 64,
+ shared_by=[["layer_1", "layer_2"]],
),
],
kv_cache_groups=[
@@ -1593,12 +1520,11 @@ def test_get_kv_cache_config_one_worker():
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(
- size=mem_per_block_per_layer * 32,
- shared_by=["layer_1", "layer_3", "layer_4"],
- ),
- KVCacheTensor(
- size=mem_per_block_per_layer * 32,
- shared_by=["layer_2", "layer_5", "layer_6"],
+ size=mem_per_block_per_layer * 32 * 2,
+ shared_by=[
+ ["layer_1", "layer_3", "layer_4"],
+ ["layer_2", "layer_5", "layer_6"],
+ ],
),
],
kv_cache_groups=[
@@ -1628,15 +1554,12 @@ def test_get_kv_cache_config_one_worker():
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(
- size=mem_per_block_per_layer * 32,
- shared_by=["layer_1", "layer_4", "layer_5", "layer_6"],
- ),
- KVCacheTensor(
- size=mem_per_block_per_layer * 32,
- shared_by=["layer_2", "layer_7", "layer_8", "layer_9"],
- ),
- KVCacheTensor(
- size=mem_per_block_per_layer * 32, shared_by=["layer_3", "layer_10"]
+ size=mem_per_block_per_layer * 32 * 3,
+ shared_by=[
+ ["layer_1", "layer_4", "layer_5", "layer_6"],
+ ["layer_2", "layer_7", "layer_8", "layer_9"],
+ ["layer_3", "layer_10"],
+ ],
),
],
kv_cache_groups=[
@@ -1649,8 +1572,7 @@ def test_get_kv_cache_config_one_worker():
],
)
- # 6 full + 5 sliding, pad to 6 full + 6 sliding. This is a typical case for gpt-oss
- # eagle where there is only one more full attention layer than sliding window layers
+ # 6 full + 5 sliding
kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(),
"layer_2": new_kv_cache_spec(),
@@ -1668,33 +1590,19 @@ def test_get_kv_cache_config_one_worker():
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 6 * 32]
)[0]
- print(kv_cache_config_hybrid)
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(
- size=mem_per_block_per_layer * 32,
- shared_by=["layer_1", "layer_7"],
- ),
- KVCacheTensor(
- size=mem_per_block_per_layer * 32,
- shared_by=["layer_2", "layer_8"],
- ),
- KVCacheTensor(
- size=mem_per_block_per_layer * 32,
- shared_by=["layer_3", "layer_9"],
- ),
- KVCacheTensor(
- size=mem_per_block_per_layer * 32,
- shared_by=["layer_4", "layer_10"],
- ),
- KVCacheTensor(
- size=mem_per_block_per_layer * 32,
- shared_by=["layer_5", "layer_11"],
- ),
- KVCacheTensor(
- size=mem_per_block_per_layer * 32,
- shared_by=["layer_6"],
+ size=mem_per_block_per_layer * 32 * 6,
+ shared_by=[
+ ["layer_1", "layer_7"],
+ ["layer_2", "layer_8"],
+ ["layer_3", "layer_9"],
+ ["layer_4", "layer_10"],
+ ["layer_5", "layer_11"],
+ ["layer_6"],
+ ],
),
],
kv_cache_groups=[
@@ -1720,8 +1628,14 @@ def test_get_kv_cache_config_one_worker():
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
- KVCacheTensor(size=mem_per_block_per_layer * 32 * 2, shared_by=["layer_1"]),
- KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32 * 2,
+ shared_by=[["layer_1"]],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=[["layer_2"]],
+ ),
],
kv_cache_groups=[
KVCacheGroupSpec(
@@ -1745,7 +1659,8 @@ def test_get_kv_cache_config_one_worker():
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(
- size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"]
+ size=mem_per_block_per_layer * 32,
+ shared_by=[["layer_1", "layer_2"]],
),
],
kv_cache_groups=[
@@ -1775,8 +1690,10 @@ def test_get_kv_cache_config_one_worker():
assert kv_cache_config_override_blocks == KVCacheConfig(
num_blocks=16,
kv_cache_tensors=[
- KVCacheTensor(size=mem_per_block_per_layer * 16, shared_by=["layer_1"]),
- KVCacheTensor(size=mem_per_block_per_layer * 16, shared_by=["layer_2"]),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 16 * 2,
+ shared_by=[["layer_1"], ["layer_2"]],
+ ),
],
kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())],
)
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index 91c5f37b4179..c6ed619b4927 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -140,7 +140,7 @@ def make_kv_cache_config_hybrid_model(
elif second_spec_type == "mamba":
second_spec = MambaSpec(
block_size=block_size,
- shapes=(1, 1),
+ shapes=((1, 1),),
dtypes=(torch.float32,),
)
@@ -175,7 +175,7 @@ def make_kv_cache_config_three_types(
if third_spec_type == "mamba":
third_spec = MambaSpec(
block_size=block_size,
- shapes=(1, 1),
+ shapes=((1, 1),),
dtypes=(torch.float32,),
)
elif third_spec_type == "sliding_window":
@@ -755,12 +755,12 @@ def _make_hybrid_kv_cache_config(
),
"mamba": lambda: MambaSpec(
block_size=block_size,
- shapes=(1, 1),
+ shapes=((1, 1),),
dtypes=(torch.float32,),
),
"mamba_align": lambda: MambaSpec(
block_size=block_size,
- shapes=(1, 1),
+ shapes=((1, 1),),
dtypes=(torch.float32,),
mamba_cache_mode="align",
),
diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
index 040632249d34..76aaa5f0e5b4 100755
--- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
@@ -94,7 +94,6 @@ else
echo "running with default attention backend"
fi
-# Check if cross-layers is enabled (non-empty)
if [[ -n "${CROSS_LAYERS_BLOCKS:-}" ]]; then
echo "CROSS_LAYERS_BLOCKS is set, running with --enable-cross-layers"
label+=" - CROSS_LAYERS_BLOCKS enabled"
diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
index fc446a0e7658..1f26df012bbb 100755
--- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
@@ -4,7 +4,6 @@ set -xe
# Parse command line arguments
KV_BUFFER_DEVICE="cuda" # Default to cuda
ATTENTION_BACKEND="" # Default to empty (use vllm default)
-CROSS_LAYERS_BLOCKS="False"
ENABLE_HMA_VAR="" # Default to empty (HMA disabled by default for kv connector)
# Check for ENABLE_HMA_FLAG environment variable
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
@@ -22,12 +21,12 @@ while [[ $# -gt 0 ]]; do
shift 2
;;
--enable-cross-layers)
- CROSS_LAYERS_BLOCKS="True"
+ export VLLM_KV_CACHE_LAYOUT="BLHNC"
shift 1
;;
*)
echo "Unknown option $1"
- echo "Usage: $0 [--kv_buffer_device ] [--attention-backend ]"
+ echo "Usage: $0 [--kv_buffer_device ] [--attention-backend ] [--enable-cross-layers]"
exit 1
;;
esac
@@ -44,24 +43,19 @@ if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
echo "vLLM serve extra args: $VLLM_SERVE_EXTRA_ARGS"
fi
-DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
-if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
+PREFILLER_KV_LAYOUT=${VLLM_KV_CACHE_LAYOUT:-"LBHNC"}
+DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"$PREFILLER_KV_LAYOUT"}
+if [[ "$DECODER_KV_LAYOUT" == "LBNHC" ]]; then
KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
else
KV_CONFIG_HETERO_LAYOUT=''
fi
-if [[ "$CROSS_LAYERS_BLOCKS" == "True" ]]; then
- KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"enable_cross_layers_blocks": "True"}'
-else
- KV_EXTRA_CONFIG=''
-fi
-
# Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
- KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
+ KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
else
- KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
+ KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
fi
# Models to run
@@ -158,7 +152,7 @@ run_tests_for_model() {
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
- VLLM_KV_CACHE_LAYOUT='HND' \
+ VLLM_KV_CACHE_LAYOUT='$PREFILLER_KV_LAYOUT' \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $model_name \
diff --git a/tests/v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh
index 2e71858983e9..86e9b34e51fb 100755
--- a/tests/v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh
@@ -8,9 +8,9 @@
# wrapping NixlConnector and OffloadingConnector, then runs gsm8k accuracy via
# test_accuracy.py.
#
-# By default runs two configurations:
-# 1. Normal KV layout (NixlConnector without cross-layer blocks)
-# 2. Cross-layer KV layout (NixlConnector with enable_cross_layers_blocks)
+# Runs two configurations:
+# 1. Standard KV layout (LBHNC)
+# 2. Cross-layer KV layout (BLHNC) via VLLM_KV_CACHE_LAYOUT
#
# Usage:
# bash tests/v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh
@@ -41,8 +41,6 @@ SMI_BIN=$(which nvidia-smi || which rocm-smi || echo "")
# ── KV transfer configs ─────────────────────────────────────────────────
-# Normal layout: OffloadingConnector prefers cross-layer but NixlConnector
-# does not, so MultiConnector.prefer_cross_layer_blocks = False.
KV_CONFIG_NORMAL='{
"kv_connector":"MultiConnector",
"kv_role":"kv_both",
@@ -57,21 +55,6 @@ KV_CONFIG_NORMAL='{
# Remove whitespace for CLI safety
KV_CONFIG_NORMAL=$(echo "$KV_CONFIG_NORMAL" | tr -d '[:space:]')
-# Cross-layer layout: both connectors prefer cross-layer blocks.
-KV_CONFIG_CROSS_LAYERS='{
- "kv_connector":"MultiConnector",
- "kv_role":"kv_both",
- "kv_connector_extra_config":{
- "connectors":[
- {"kv_connector":"NixlConnector","kv_role":"kv_both",
- "kv_connector_extra_config":{"enable_cross_layers_blocks":"True"}},
- {"kv_connector":"OffloadingConnector","kv_role":"kv_both",
- "kv_connector_extra_config":{"cpu_bytes_to_use":1000000000}}
- ]
- }
-}'
-KV_CONFIG_CROSS_LAYERS=$(echo "$KV_CONFIG_CROSS_LAYERS" | tr -d '[:space:]')
-
# ── Helpers ──────────────────────────────────────────────────────────────
trap 'kill $(jobs -pr) 2>/dev/null' SIGINT SIGTERM EXIT
@@ -122,7 +105,7 @@ run_tests_for_model() {
# ── Start prefill instance ──
echo "Starting prefill instance on GPU $PREFILL_GPU, port $PREFILL_PORT"
BASE_CMD="CUDA_VISIBLE_DEVICES=$PREFILL_GPU \
- VLLM_KV_CACHE_LAYOUT='HND' \
+ VLLM_KV_CACHE_LAYOUT='${VLLM_KV_CACHE_LAYOUT:-LBHNC}' \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$PREFILL_SIDE_CHANNEL_PORT \
vllm serve $model_name \
@@ -144,7 +127,7 @@ run_tests_for_model() {
# ── Start decode instance ──
echo "Starting decode instance on GPU $DECODE_GPU, port $DECODE_PORT"
BASE_CMD="CUDA_VISIBLE_DEVICES=$DECODE_GPU \
- VLLM_KV_CACHE_LAYOUT='HND' \
+ VLLM_KV_CACHE_LAYOUT='${VLLM_KV_CACHE_LAYOUT:-LBHNC}' \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$DECODE_SIDE_CHANNEL_PORT \
vllm serve $model_name \
@@ -198,7 +181,8 @@ for model in "${MODELS[@]}"; do
fi
if [[ -z "${SKIP_CROSS_LAYERS:-}" ]]; then
- run_tests_for_model "$model" "$KV_CONFIG_CROSS_LAYERS" "MultiConnector cross-layer layout"
+ VLLM_KV_CACHE_LAYOUT=BLHNC \
+ run_tests_for_model "$model" "$KV_CONFIG_NORMAL" "MultiConnector cross-layer layout"
fi
done
diff --git a/tests/v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh b/tests/v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh
index a80950b34136..8eb86f51b3ff 100755
--- a/tests/v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh
@@ -92,7 +92,7 @@ run_tests_for_model() {
# ── Start prefill instance ──
echo "Starting prefill instance on GPU $PREFILL_GPU, port $PREFILL_PORT"
BASE_CMD="CUDA_VISIBLE_DEVICES=$PREFILL_GPU \
- VLLM_KV_CACHE_LAYOUT='HND' \
+ VLLM_KV_CACHE_LAYOUT='LBHNC' \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$PREFILL_SIDE_CHANNEL_PORT \
vllm serve \"$model_name\" \
@@ -115,7 +115,7 @@ run_tests_for_model() {
# ── Start decode instance ──
echo "Starting decode instance on GPU $DECODE_GPU, port $DECODE_PORT"
BASE_CMD="CUDA_VISIBLE_DEVICES=$DECODE_GPU \
- VLLM_KV_CACHE_LAYOUT='HND' \
+ VLLM_KV_CACHE_LAYOUT='LBHNC' \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$DECODE_SIDE_CHANNEL_PORT \
vllm serve \"$model_name\" \
diff --git a/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh b/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
index 2c5622a2f0e1..ef624ea9d8a6 100755
--- a/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
@@ -243,7 +243,7 @@ run_test_for_device() {
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
env \
${GPU_DEVICE_VAR}=$GPU_ID \
- VLLM_KV_CACHE_LAYOUT='HND' \
+ VLLM_KV_CACHE_LAYOUT='LBHNC' \
UCX_NET_DEVICES=all \
${VLLM_SSM_CONV_STATE_LAYOUT:+VLLM_SSM_CONV_STATE_LAYOUT=$VLLM_SSM_CONV_STATE_LAYOUT} \
VLLM_NIXL_SIDE_CHANNEL_HOST=$NIXL_SIDE_CHANNEL_HOST \
@@ -283,7 +283,7 @@ run_test_for_device() {
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
env \
${GPU_DEVICE_VAR}=$GPU_ID \
- VLLM_KV_CACHE_LAYOUT='HND' \
+ VLLM_KV_CACHE_LAYOUT='LBHNC' \
UCX_NET_DEVICES=all \
${VLLM_SSM_CONV_STATE_LAYOUT:+VLLM_SSM_CONV_STATE_LAYOUT=$VLLM_SSM_CONV_STATE_LAYOUT} \
VLLM_NIXL_SIDE_CHANNEL_HOST=$NIXL_SIDE_CHANNEL_HOST \
diff --git a/tests/v1/kv_connector/unit/offloading_connector/test_worker.py b/tests/v1/kv_connector/unit/offloading_connector/test_worker.py
index 36294632bb9f..6cb8373d9c11 100644
--- a/tests/v1/kv_connector/unit/offloading_connector/test_worker.py
+++ b/tests/v1/kv_connector/unit/offloading_connector/test_worker.py
@@ -9,7 +9,6 @@
from vllm.platforms import current_platform
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.attention.backends.registry import AttentionBackendEnum
-from vllm.v1.attention.backends.utils import set_kv_cache_layout
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
@@ -57,31 +56,25 @@ def _allocate_and_reshape_kv_caches(
Use the real GPUModelRunner allocation and reshape methods to produce
kv_caches, just like the model runner does during initialization.
"""
+ from vllm.v1.kv_cache_interface import KVCacheLayout
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
- # Some backends (e.g. FlashAttention) query the KV cache layout during
- # reshape, which ultimately calls get_current_vllm_config(). Setting
- # the layout override avoids needing a full VllmConfig context.
- set_kv_cache_layout("NHD")
- try:
- runner = object.__new__(GPUModelRunner)
- runner.device = device
- runner.runner_only_attn_layers = set()
- runner.attn_groups = attn_groups
- runner.kv_cache_config = kv_cache_config
- runner.cache_config = MagicMock(cache_dtype="auto")
- runner.shared_kv_cache_layers = {}
- runner.model_config = MagicMock()
- runner.model_config.hf_config.model_type = ""
- runner.compilation_config = MagicMock(
- static_forward_context=defaultdict(MagicMock)
- )
- runner.kv_caches = []
-
- kernel_block_sizes = [BLOCK_SIZE] * len(kv_cache_config.kv_cache_groups)
- return runner.initialize_kv_cache_tensors(kv_cache_config, kernel_block_sizes)
- finally:
- set_kv_cache_layout(None)
+ runner = object.__new__(GPUModelRunner)
+ runner.device = device
+ runner.runner_only_attn_layers = set()
+ runner.attn_groups = attn_groups
+ runner.kv_cache_config = kv_cache_config
+ runner.cache_config = MagicMock(cache_dtype="auto")
+ runner.shared_kv_cache_layers = {}
+ runner.model_config = MagicMock()
+ runner.model_config.hf_config.model_type = ""
+ runner.compilation_config = MagicMock(static_forward_context=defaultdict(MagicMock))
+ runner.kv_caches = []
+
+ kernel_block_sizes = [BLOCK_SIZE] * len(kv_cache_config.kv_cache_groups)
+ return runner._allocate_and_reshape_kv_cache(
+ kv_cache_config, kernel_block_sizes, layout=KVCacheLayout.LBNHC
+ )
def _make_worker(kv_cache_config: KVCacheConfig):
@@ -208,18 +201,19 @@ def test_register_kv_caches(backend):
aligned_mamba_layer_names,
]
- kv_cache_tensors: list[KVCacheTensor] = []
+ shared_by: list[list[str]] = []
for i in range(GROUP_SIZE):
- shared_by: list[str] = []
+ slot_layers: list[str] = []
for group_layer_names in layer_groups:
if len(group_layer_names) > i:
- shared_by.append(group_layer_names[i])
- kv_cache_tensors.append(
- KVCacheTensor(
- size=PAGE_SIZE_BYTES * NUM_BLOCKS,
- shared_by=shared_by,
- )
+ slot_layers.append(group_layer_names[i])
+ shared_by.append(slot_layers)
+ kv_cache_tensors: list[KVCacheTensor] = [
+ KVCacheTensor(
+ size=PAGE_SIZE_BYTES * NUM_BLOCKS * GROUP_SIZE,
+ shared_by=shared_by,
)
+ ]
kv_cache_groups = [
KVCacheGroupSpec(layer_names=attn_layer_names, kv_cache_spec=attn_spec),
@@ -378,11 +372,11 @@ def test_register_kv_caches_uniform_type(backend):
kv_cache_tensors=[
KVCacheTensor(
size=spec_a.page_size_bytes * NUM_BLOCKS,
- shared_by=[layer_a],
+ shared_by=[[layer_a]],
),
KVCacheTensor(
size=spec_b.page_size_bytes * NUM_BLOCKS,
- shared_by=[layer_b],
+ shared_by=[[layer_b]],
),
],
kv_cache_groups=[
diff --git a/tests/v1/kv_connector/unit/offloading_connector/utils.py b/tests/v1/kv_connector/unit/offloading_connector/utils.py
index 22d00b0c834b..90ba96db2723 100644
--- a/tests/v1/kv_connector/unit/offloading_connector/utils.py
+++ b/tests/v1/kv_connector/unit/offloading_connector/utils.py
@@ -250,7 +250,7 @@ def __init__(
)
# register worker kv_caches to enable OffloadingWorker creations
- # set_current_vllm_config is needed for get_kv_cache_layout() to work
+ # set_current_vllm_config is needed for resolve_kv_cache_layout() to work
kv_caches: dict[str, torch.Tensor] = {}
for group in kv_cache_groups:
spec = group.kv_cache_spec
diff --git a/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py b/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py
index dc76d61178d8..1277141347a0 100644
--- a/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py
+++ b/tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py
@@ -99,7 +99,7 @@ def _make_connector_with_fake_worker(
worker = connector.connector_worker
assert isinstance(worker.nixl_wrapper, FakeNixlWrapper)
worker.nixl_wrapper.set_cycles_before_xfer_done(cycles_before_done)
- worker.kv_cache_layout = "HND"
+ worker.kv_cache_layout = "LBHNC"
if do_handshake:
remote_agents = worker._nixl_handshake(
host="localhost",
diff --git a/tests/v1/kv_connector/unit/test_kv_cache_layout.py b/tests/v1/kv_connector/unit/test_kv_cache_layout.py
index 7f8028991703..bd36f6c6c848 100644
--- a/tests/v1/kv_connector/unit/test_kv_cache_layout.py
+++ b/tests/v1/kv_connector/unit/test_kv_cache_layout.py
@@ -1,36 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for reshape_kv_cache."""
+import pytest
+import torch
-def test_mla_backend_rejects_cross_layer_kv_cache():
- """MLA backends return identity permutation (layers dim first)
- to signal cross-layer KV cache is unsupported."""
- from vllm.model_executor.layers.attention.mla_attention import (
- MLACommonBackend,
- )
+from vllm.v1.kv_cache_interface import (
+ FullAttentionSpec,
+ KVCacheLayout,
+ compute_layer_kv_cache_shape_bytes,
+ reshape_kv_cache,
+)
+
+NUM_BLOCKS = 4
+BLOCK_SIZE = 4
+NUM_KV_HEADS = 2
+HEAD_SIZE = 8
+DTYPE = torch.bfloat16
- stride_order = MLACommonBackend.get_kv_cache_stride_order(
- include_num_layers_dimension=True
- )
- assert stride_order == (0, 1, 2, 3)
- assert stride_order[0] == 0 # layers dim first => no cross-layer
- assert MLACommonBackend.get_kv_cache_stride_order(
- include_num_layers_dimension=False
- ) == (0, 1, 2)
-
-
-def test_deepseek_v32_indexer_rejects_cross_layer_kv_cache():
- """DeepseekV32Indexer returns identity permutation (layers dim first)
- to signal cross-layer KV cache is unsupported."""
- from vllm.v1.attention.backends.mla.indexer import (
- DeepseekV32IndexerBackend,
- )
- stride_order = DeepseekV32IndexerBackend.get_kv_cache_stride_order(
- include_num_layers_dimension=True
+@pytest.mark.parametrize(
+ "layout", [layer for layer in KVCacheLayout if layer.is_layer_compact]
+)
+def test_reshape_kv_cache(layout):
+ spec = FullAttentionSpec(
+ block_size=BLOCK_SIZE,
+ num_kv_heads=NUM_KV_HEADS,
+ head_size=HEAD_SIZE,
+ dtype=DTYPE,
)
- assert stride_order == (0, 1, 2, 3)
- assert stride_order[0] == 0 # layers dim first => no cross-layer
- assert DeepseekV32IndexerBackend.get_kv_cache_stride_order(
- include_num_layers_dimension=False
- ) == (0, 1, 2)
+ num_slots = 2
+ total_bytes = spec.page_size_bytes * NUM_BLOCKS * num_slots
+ raw = torch.zeros(total_bytes, dtype=torch.int8, device="cuda")
+ views = reshape_kv_cache(raw, spec, NUM_BLOCKS, num_slots, layout)
+
+ byte_4d = compute_layer_kv_cache_shape_bytes(spec, NUM_BLOCKS)
+ dtype_size = torch.tensor([], dtype=spec.dtype).element_size()
+ expected_shape = (*byte_4d[:3], byte_4d[3] // dtype_size)
+
+ assert len(views) == num_slots
+ for v in views:
+ assert v.shape == expected_shape
+ assert v.dtype == spec.dtype
+
+ # The physical layout's innermost 3 dims (H, N, C minus the L dim)
+ # should match the layout's stride_order permutation: dims later in
+ # the physical order have smaller strides.
+ stride_order = layout.layer_stride_order
+ strides = views[0].stride()
+ for i in range(3):
+ for j in range(i + 1, 4):
+ if stride_order[i] < stride_order[j]:
+ assert strides[i] >= strides[j], (
+ f"layout {layout.name}: dim {i} (physical pos "
+ f"{stride_order[i]}) should have >= stride than "
+ f"dim {j} (physical pos {stride_order[j]}), got "
+ f"strides={strides}"
+ )
diff --git a/tests/v1/kv_connector/unit/test_mooncake_connector.py b/tests/v1/kv_connector/unit/test_mooncake_connector.py
index a10ae1f456ed..b5147124aebf 100644
--- a/tests/v1/kv_connector/unit/test_mooncake_connector.py
+++ b/tests/v1/kv_connector/unit/test_mooncake_connector.py
@@ -25,8 +25,11 @@
MooncakeBootstrapServer,
)
from vllm.utils.network_utils import get_open_port
-from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
-from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.kv_cache_interface import (
+ FullAttentionSpec,
+ KVCacheConfig,
+ compute_layer_kv_cache_shape_bytes,
+)
from vllm.v1.request import RequestStatus
from .utils import create_request, create_scheduler, create_vllm_config
@@ -369,26 +372,21 @@ async def test_kv_producer(monkeypatch):
with patch.object(
prefill_worker, "_send_blocks", return_value=0
) as mock_send_blocks:
- # With blocks-first layout, each block is virtually split
- # into K and V halves, producing non-coalesced transfers.
- kv_half = block_len // 2
-
- def expected_split_transfers(src_base, dst_base, src_blocks, dst_blocks):
- """Build expected (src_ptrs, dst_ptrs, lengths) for
- virtual-split K/V transfers."""
- src_ptrs, dst_ptrs, lengths = [], [], []
- for kv_offset in (0, kv_half):
- for sb, db in zip(src_blocks, dst_blocks):
- src_ptrs.append(src_base + sb * block_len + kv_offset)
- dst_ptrs.append(dst_base + db * block_len + kv_offset)
- lengths.append(kv_half)
- return src_ptrs, dst_ptrs, lengths
+ # Under the standardized blocks-first layout K and V are packed
+ # into a single contiguous region per block. Adjacent blocks are
+ # coalesced into a single larger transfer; all cases below pass
+ # a single contiguous run.
+ def expected_transfers(src_base, dst_base, src_blocks, dst_blocks):
+ n = len(src_blocks)
+ return (
+ [src_base + src_blocks[0] * block_len],
+ [dst_base + dst_blocks[0] * block_len],
+ [n * block_len],
+ )
# Normal case: 2 blocks to 2 blocks
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
- src, dst, lens = expected_split_transfers(
- 0x1000, 0x2000, [10, 11], [20, 21]
- )
+ src, dst, lens = expected_transfers(0x1000, 0x2000, [10, 11], [20, 21])
mock_send_blocks.assert_called_once_with(
"consumer-host:54321",
src,
@@ -420,7 +418,7 @@ def expected_split_transfers(src_base, dst_base, src_blocks, dst_blocks):
# Worker processes the consumer's request
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
# Verify transfer parameters are correct: 11 to 20
- src, dst, lens = expected_split_transfers(0x1000, 0x2000, [11], [20])
+ src, dst, lens = expected_transfers(0x1000, 0x2000, [11], [20])
mock_send_blocks.assert_called_once_with(
"consumer-host:54321",
src,
@@ -621,11 +619,14 @@ def test_register_kv_caches():
worker = connector.connector_worker
mock_thread.return_value.is_alive.return_value = False
- kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
- num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
+ shape = compute_layer_kv_cache_shape_bytes(
+ FullAttentionSpec(
+ block_size=16, num_kv_heads=4, head_size=64, dtype=torch.float16
+ ),
+ 2,
)
- tensor1 = torch.zeros(*kv_cache_shape, dtype=torch.float16)
- tensor2 = torch.zeros(*kv_cache_shape, dtype=torch.float16)
+ tensor1 = torch.zeros(*shape, dtype=torch.int8).view(torch.float16)
+ tensor2 = torch.zeros(*shape, dtype=torch.int8).view(torch.float16)
kv_caches = {"layer0": tensor1, "layer1": tensor2}
with patch.object(
@@ -642,7 +643,7 @@ def test_register_kv_caches():
# Verify block_len_per_layer is set correctly.
assert len(worker.block_len_per_layer) == len(registered_ptrs)
for bl in worker.block_len_per_layer:
- assert bl == tensor1.nbytes // tensor1.shape[0]
+ assert bl == tensor1.stride(0) * tensor1.element_size()
def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes():
@@ -806,47 +807,34 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
flat_remote = [b for g in remote_block_ids for b in g]
num_blocks = len(flat_local)
- # With blocks-first layout, virtual split halves block
- # lengths and doubles transfer regions (K + V).
- local_kv_block_len = local_block_len // 2
- remote_kv_block_len = remote_block_len // 2
+ # Under the standardized blocks-first layout K and V are
+ # already packed into a single contiguous region per block,
+ # so _expand_transfer_regions emits one region per layer.
+ assert len(src_ptrs) == num_blocks
+ assert len(dst_ptrs) == num_blocks
+ assert len(lengths) == num_blocks
- assert len(src_ptrs) == 2 * num_blocks
- assert len(dst_ptrs) == 2 * num_blocks
- assert len(lengths) == 2 * num_blocks
-
- # Compute expected offsets using kv_block_len
if d_tp_size <= P_TP_SIZE:
tp_ratio = P_TP_SIZE // d_tp_size
expected_src_off = 0
- expected_dst_off = (P_TP_RANK % tp_ratio) * local_kv_block_len
- expected_xfer_len = local_kv_block_len
+ expected_dst_off = (P_TP_RANK % tp_ratio) * local_block_len
+ expected_xfer_len = local_block_len
else:
ratio_abs = d_tp_size // P_TP_SIZE
- expected_src_off = (d_rank % ratio_abs) * remote_kv_block_len
+ expected_src_off = (d_rank % ratio_abs) * remote_block_len
expected_dst_off = 0
- expected_xfer_len = remote_kv_block_len
-
- # First num_blocks entries are K region,
- # next num_blocks are V region.
- for region_idx in range(2):
- local_region_base = 0x1000 + region_idx * local_kv_block_len
- remote_region_base = 0x2000 + region_idx * remote_kv_block_len
- for blk_idx, (lblk, rblk) in enumerate(
- zip(flat_local, flat_remote)
- ):
- idx = region_idx * num_blocks + blk_idx
- assert src_ptrs[idx] == (
- local_region_base
- + lblk * local_block_len
- + expected_src_off
- )
- assert dst_ptrs[idx] == (
- remote_region_base
- + rblk * remote_block_len
- + expected_dst_off
- )
- assert lengths[idx] == expected_xfer_len
+ expected_xfer_len = remote_block_len
+
+ local_region_base = 0x1000
+ remote_region_base = 0x2000
+ for blk_idx, (lblk, rblk) in enumerate(zip(flat_local, flat_remote)):
+ assert src_ptrs[blk_idx] == (
+ local_region_base + lblk * local_block_len + expected_src_off
+ )
+ assert dst_ptrs[blk_idx] == (
+ remote_region_base + rblk * remote_block_len + expected_dst_off
+ )
+ assert lengths[blk_idx] == expected_xfer_len
# Verify successful response sent back to consumer
mock_socket.send_multipart.assert_called_once()
diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py
index 2a5c96a46e5a..b9097865f12f 100644
--- a/tests/v1/kv_connector/unit/test_moriio_connector.py
+++ b/tests/v1/kv_connector/unit/test_moriio_connector.py
@@ -35,7 +35,11 @@
get_ip,
make_zmq_path,
)
-from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.kv_cache_interface import (
+ FullAttentionSpec,
+ KVCacheConfig,
+ compute_layer_kv_cache_shape_bytes,
+)
from .utils import create_request, create_scheduler
@@ -174,7 +178,7 @@ class FakeMoRIIOConnectorWorker(MoRIIOConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"
def __init__(
- self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
+ self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="LBHNC", **kwargs
):
super().__init__(*args, **kwargs)
@@ -415,16 +419,15 @@ def test_register_kv_caches(mock_parallel_groups):
DEFAULT_PORT = 6301
TP_RANK = 0
DP_RANK = 0
- from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
-
- backend_cls = AiterFlashAttentionBackend
-
- # Create test kv cache tensors using proper backend shape
- kv_cache_shape = backend_cls.get_kv_cache_shape(
- num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
+ # Create test kv cache tensors using KVCacheSpec layout
+ shape = compute_layer_kv_cache_shape_bytes(
+ FullAttentionSpec(
+ block_size=16, num_kv_heads=4, head_size=64, dtype=torch.float16
+ ),
+ 2,
)
- shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
- unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
+ shared_tensor = torch.zeros(*shape, dtype=torch.int8).view(torch.float16)
+ unique_tensor = torch.zeros(*shape, dtype=torch.int8).view(torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
@@ -511,16 +514,15 @@ def test_moriio_handshake_returns_metadata(mock_parallel_groups):
ROLE = "kv_consumer"
vllm_config = create_vllm_config(role=ROLE)
- from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
-
- backend_cls = AiterFlashAttentionBackend
-
- # Create test kv cache tensors using proper backend shape
- kv_cache_shape = backend_cls.get_kv_cache_shape(
- num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
+ # Create test kv cache tensors using KVCacheSpec layout
+ shape = compute_layer_kv_cache_shape_bytes(
+ FullAttentionSpec(
+ block_size=16, num_kv_heads=4, head_size=64, dtype=torch.float16
+ ),
+ 2,
)
- shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
- unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
+ shared_tensor = torch.zeros(*shape, dtype=torch.int8).view(torch.float16)
+ unique_tensor = torch.zeros(*shape, dtype=torch.int8).view(torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py
index d1f3a81ca969..15089599d294 100644
--- a/tests/v1/kv_connector/unit/test_multi_connector.py
+++ b/tests/v1/kv_connector/unit/test_multi_connector.py
@@ -844,16 +844,6 @@ def test_multi_connector_overrides_all_base_methods():
""")
-def test_multi_connector_prefer_cross_layer_blocks(mc):
- mc._connectors[0].prefer_cross_layer_blocks = False
- mc._connectors[1].prefer_cross_layer_blocks = True
- assert mc.prefer_cross_layer_blocks is False
-
- mc._connectors[0].prefer_cross_layer_blocks = True
- mc._connectors[1].prefer_cross_layer_blocks = True
- assert mc.prefer_cross_layer_blocks is True
-
-
def test_multi_connector_worker_metadata(mc):
class MockConnectorWorkerMetadata(KVConnectorWorkerMetadata):
def __init__(self, data: set[str]):
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index f07a8352e735..130a529d7e51 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -51,7 +51,6 @@
from vllm.platforms import current_platform
from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams
-from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.utils import set_kv_cache_layout
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
@@ -60,12 +59,10 @@
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
- KVCacheTensor,
+ compute_layer_kv_cache_shape_bytes,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
-from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
-from vllm.v1.worker.utils import AttentionGroup
from .utils import (
create_request,
@@ -355,7 +352,6 @@ def test_abort_immediately_remote_prefill_enqueues_empty_recv():
)
def test_kv_transfer_handshake(dist_init):
"""Unit test for basic NixlConnector interface functionality."""
- from vllm.config import set_current_vllm_config
# Test setup, we creates a scheduler that contains a NixlConnector
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
@@ -391,14 +387,11 @@ def test_kv_transfer_handshake(dist_init):
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
- kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
- num_blocks=kv_cache_config.num_blocks,
- block_size=kv_cache_spec.block_size,
- num_kv_heads=kv_cache_spec.num_kv_heads,
- head_size=kv_cache_spec.head_size,
+ shape = compute_layer_kv_cache_shape_bytes(
+ kv_cache_spec, kv_cache_config.num_blocks
)
- shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
- unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
+ shared_tensor = torch.zeros(*shape, dtype=torch.int8).view(kv_cache_spec.dtype)
+ unique_tensor = torch.zeros(*shape, dtype=torch.int8).view(kv_cache_spec.dtype)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
@@ -477,7 +470,7 @@ def __init__(
self,
*args,
hand_shake_latency: float = 1.8,
- kv_cache_layout="HND",
+ kv_cache_layout="LBHNC",
kv_cache_config=None,
**kwargs,
):
@@ -488,9 +481,8 @@ def __init__(
self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
- test_shape = self.attn_backends[0].get_kv_cache_shape(
- num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
- )
+ rep_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
+ test_shape = compute_layer_kv_cache_shape_bytes(rep_spec, 1)
self.transfer_topo = TransferTopology(
tp_rank=self.tp_rank,
tp_size=self.world_size,
@@ -504,7 +496,7 @@ def __init__(
)
self.compat_hash = compute_nixl_compatibility_hash(
- self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks
+ self.vllm_config, self.backend_name
)
def _nixl_handshake(
@@ -549,9 +541,8 @@ def _nixl_handshake(
device_id=remote_tp_rank,
num_blocks=1,
block_lens=remote_block_lens,
- # `self.kv_cache_layout` is only forced to HND when vllm engine
- # is started. We mock HND here.
- kv_cache_layout="HND",
+ block_strides=remote_block_lens,
+ kv_cache_layout="LBHNC",
block_size=self.block_size,
ssm_sizes=(0, 0),
attn_backend_name=self.backend_name,
@@ -603,7 +594,7 @@ def test_multi_xfer_one_engine(
worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
}
- worker.kv_cache_layout = "HND"
+ worker.kv_cache_layout = "LBHNC"
num_xfers = 4
while True:
# For the same request_id, initiate multiple xfers across different
@@ -996,7 +987,9 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
# Metadata with different kv_cache_layout than local worker
- mismatched_layout = "HND" if worker.kv_cache_layout != "HND" else "NHD"
+ mismatched_layout = (
+ "LBHNC" if worker.kv_cache_layout != "LBHNC" else "LBNHC"
+ )
meta = NixlAgentMetadata(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
@@ -1004,6 +997,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(
device_id=0,
num_blocks=1,
block_lens=worker.block_len_per_layer,
+ block_strides=worker.block_len_per_layer,
kv_cache_layout=mismatched_layout,
block_size=worker.block_size,
ssm_sizes=(0, 0),
@@ -1043,7 +1037,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
vllm_config,
connector.engine_id,
hand_shake_latency=0,
- kv_cache_layout="NHD",
+ kv_cache_layout="LBNHC",
)
worker = connector.connector_worker
@@ -1054,15 +1048,16 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
# Metadata with different kv_cache_layout than local worker
+ remote_block_lens = [i * 2 for i in worker.block_len_per_layer]
meta = NixlAgentMetadata(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
- # prefill TP=1, decode TP=2, remote block_lens is double to local
- block_lens=[i * 2 for i in worker.block_len_per_layer],
- kv_cache_layout="HND",
+ block_lens=remote_block_lens,
+ block_strides=remote_block_lens,
+ kv_cache_layout="LBHNC",
block_size=worker.block_size,
ssm_sizes=(0, 0),
attn_backend_name=worker.backend_name,
@@ -1490,7 +1485,6 @@ def req_id(outputs: list[RequestOutput]) -> str:
llm.llm_engine.engine_core.shutdown()
-@pytest.mark.parametrize("enable_cross_layers", ["False", "True"])
@pytest.mark.parametrize(
"attn_backend",
[
@@ -1511,9 +1505,7 @@ def req_id(outputs: list[RequestOutput]) -> str:
"TRITON_ATTN",
],
)
-def test_register_kv_caches(
- default_vllm_config, dist_init, attn_backend, enable_cross_layers
-):
+def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
@@ -1526,12 +1518,7 @@ def test_register_kv_caches(
"""
vllm_config = create_vllm_config(attention_backend=attn_backend)
-
- # Enable cross layers blocks
- vllm_config.kv_transfer_config.kv_connector_extra_config[
- "enable_cross_layers_blocks"
- ] = enable_cross_layers
- set_kv_cache_layout("HND")
+ set_kv_cache_layout("LBHNC")
# Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN":
@@ -1548,59 +1535,31 @@ def test_register_kv_caches(
backend_cls = TritonAttentionBackend
nixl_worker = "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker"
- nixl_connector = "vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector"
with (
patch(f"{nixl_worker}.NixlWrapper") as mock_nixl_wrapper,
patch(f"{nixl_worker}.threading.Event"),
patch(f"{nixl_worker}.threading.Thread") as mock_thread,
- patch(f"{nixl_connector}.get_current_attn_backend") as mock_get_attn_backend,
patch(f"{nixl_worker}.get_current_attn_backends") as mock_get_attn_backends,
):
- # Ensure get_attn_backend returns the correct value due to
- # _cached_get_attn_backend returning the backend from previous
- # test run if not mocking.
- mock_get_attn_backend.return_value = backend_cls
mock_get_attn_backends.return_value = [backend_cls]
- num_layers = 32
- block_size = 16
num_blocks = 8
+ block_size = 16
num_heads = 4
head_size = 16
- # TODO (NickLucche) the fact that connector depends on kv_cache_config for init
- # but cross-layer preference cant be inferred prior to creating kv_cache_config
- # is a bit awkward.
- dummy_connector = NixlConnector(
- vllm_config,
- KVConnectorRole.WORKER,
- make_kv_cache_config(block_size=block_size),
- )
kv_cache_spec = FullAttentionSpec(
block_size=block_size,
num_kv_heads=num_heads,
head_size=head_size,
dtype=torch.float16,
)
- if dummy_connector.prefer_cross_layer_blocks:
- kv_cache_config = KVCacheConfig(
- num_blocks=num_blocks,
- kv_cache_tensors=[
- KVCacheTensor(
- size=kv_cache_spec.page_size_bytes * num_blocks,
- shared_by=["all-layers"],
- )
- for _ in range(num_layers)
- ],
- kv_cache_groups=[KVCacheGroupSpec(["all-layers"], kv_cache_spec)],
- )
- else:
- kv_cache_config = KVCacheConfig(
- num_blocks=num_blocks,
- kv_cache_tensors=[],
- kv_cache_groups=[
- KVCacheGroupSpec(["layer0", "layer1", "layer2"], kv_cache_spec)
- ],
- )
+ kv_cache_config = KVCacheConfig(
+ num_blocks=num_blocks,
+ kv_cache_tensors=[],
+ kv_cache_groups=[
+ KVCacheGroupSpec(["layer0", "layer1", "layer2"], kv_cache_spec)
+ ],
+ )
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
connector.connector_worker = FakeNixlConnectorWorker(
@@ -1620,93 +1579,28 @@ def test_register_kv_caches(
# Reassure the shutdown() check that the thread is terminated
mock_thread.return_value.is_alive.return_value = False
- expected_tensor_size: int
- expected_base_addrs: list[int]
- expected_num_entries: int
- kv_caches: dict[str, torch.Tensor]
- if str(enable_cross_layers).lower() == "true":
- assert connector.prefer_cross_layer_blocks == (
- attn_backend in ("FLASH_ATTN", "FLASHINFER", "TRITON_ATTN")
- )
- else:
- assert not connector.prefer_cross_layer_blocks
-
- test_shape = backend_cls.get_kv_cache_shape(
- num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
+ # Create test kv cache tensors using proper backend shape
+ shape = compute_layer_kv_cache_shape_bytes(
+ kv_cache_spec, kv_cache_config.num_blocks
)
- is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
- virtually_split = is_blocks_first and not connector.prefer_cross_layer_blocks
-
- if connector.prefer_cross_layer_blocks:
- with set_current_vllm_config(vllm_config):
- _, cross_layers_kv_cache, _ = (
- KVConnectorModelRunnerMixin.allocate_uniform_kv_caches(
- kv_cache_config=kv_cache_config,
- attn_groups=[
- [
- AttentionGroup(
- backend=backend_cls,
- layer_names=[],
- kv_cache_spec=kv_cache_spec,
- kv_cache_group_id=0,
- )
- ]
- ],
- cache_dtype="bfloat16",
- device=torch.accelerator.current_device_index(),
- kernel_block_sizes=[block_size],
- )
- )
- # Store tensor info for validation
- expected_tensor_size = (
- cross_layers_kv_cache.element_size() * cross_layers_kv_cache.numel()
- )
- expected_base_addrs = [
- cross_layers_kv_cache.data_ptr(),
- ]
- expected_num_entries = 1
-
- expected_blocks_count = num_blocks * (2 if virtually_split else 1)
-
- kv_caches = {"all-layers": cross_layers_kv_cache}
- else:
- # Create test kv cache tensors using proper backend shape
- kv_cache_shape = backend_cls.get_kv_cache_shape(
- num_blocks=kv_cache_config.num_blocks,
- block_size=kv_cache_spec.block_size,
- num_kv_heads=kv_cache_spec.num_kv_heads,
- head_size=kv_cache_spec.head_size,
- )
- shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
- unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
- kv_caches = {
- "layer0": shared_tensor,
- "layer1": unique_tensor,
- "layer2": shared_tensor,
- }
+ shared_tensor = torch.zeros(*shape, dtype=torch.int8).view(kv_cache_spec.dtype)
+ unique_tensor = torch.zeros(*shape, dtype=torch.int8).view(kv_cache_spec.dtype)
+ kv_caches = {
+ "layer0": shared_tensor,
+ "layer1": unique_tensor,
+ "layer2": shared_tensor,
+ }
- # Store tensor info for validation
- if is_blocks_first:
- expected_tensor_size = (
- shared_tensor.element_size() * shared_tensor.numel()
- )
- expected_base_addrs = [
- shared_tensor.data_ptr(),
- unique_tensor.data_ptr(),
- ]
- expected_num_entries = 2
- else:
- expected_tensor_size = (
- shared_tensor[0].element_size() * shared_tensor[0].numel()
- )
- expected_base_addrs = [
- shared_tensor[0].data_ptr(),
- shared_tensor[1].data_ptr(),
- unique_tensor[0].data_ptr(),
- unique_tensor[1].data_ptr(),
- ]
- expected_num_entries = 4
- expected_blocks_count = kv_cache_config.num_blocks * 4
+ # Per-layer shape is always 4D (B, N, H, C) — no separate K/V dim.
+ expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
+ expected_base_addrs = [
+ shared_tensor.data_ptr(),
+ unique_tensor.data_ptr(),
+ ]
+ expected_num_entries = 2
+ # Packed KV layout: full-page descriptors (no K/V split).
+ # 2 unique regions × num_blocks.
+ expected_blocks_count = kv_cache_config.num_blocks * 2
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
@@ -1735,15 +1629,7 @@ def test_register_kv_caches(
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
)
- if connector.prefer_cross_layer_blocks:
- num_blocks = 8
- else:
- num_blocks = kv_cache_config.num_blocks
-
- if virtually_split:
- expected_block_len = expected_tensor_size // num_blocks // 2
- else:
- expected_block_len = expected_tensor_size // num_blocks
+ expected_block_len = expected_tensor_size // num_blocks
for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
@@ -2449,14 +2335,11 @@ def test_compatibility_hash_validation(
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
- kv_cache_shape = decode_worker.attn_backends[0].get_kv_cache_shape(
- num_blocks=kv_cache_config.num_blocks,
- block_size=kv_cache_spec.block_size,
- num_kv_heads=kv_cache_spec.num_kv_heads,
- head_size=kv_cache_spec.head_size,
+ shape = compute_layer_kv_cache_shape_bytes(
+ kv_cache_spec, kv_cache_config.num_blocks
)
- shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
- unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
+ shared_tensor = torch.zeros(*shape, dtype=torch.int8).view(kv_cache_spec.dtype)
+ unique_tensor = torch.zeros(*shape, dtype=torch.int8).view(kv_cache_spec.dtype)
# Build kv_caches from the actual layer names in kv_cache_config so that
# _layer_specs lookups in register_kv_caches always find a matching key.
layer_names = [
@@ -2491,18 +2374,19 @@ def test_compatibility_hash_validation(
remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config,
decode_worker.backend_name,
- decode_worker.transfer_topo.cross_layers_blocks,
)
prefill_block_size = config_overrides.get("block_size", 16)
+ prefill_block_lens = [4096 * prefill_block_size]
prefill_metadata = NixlAgentMetadata(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
- block_lens=[4096 * prefill_block_size], # slot_size * block_size
- kv_cache_layout="HND",
+ block_lens=prefill_block_lens,
+ block_strides=prefill_block_lens,
+ kv_cache_layout="LBHNC",
block_size=prefill_block_size,
ssm_sizes=(0, 0),
attn_backend_name=decode_worker.backend_name,
@@ -2576,9 +2460,8 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
decode_worker = decode_connector.connector_worker
backend = get_current_attn_backend(local_vllm_config)
- test_shape = backend.get_kv_cache_shape(
- num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
- )
+ probe_spec = decode_worker.kv_cache_config.kv_cache_groups[0].kv_cache_spec
+ test_shape = compute_layer_kv_cache_shape_bytes(probe_spec, 1)
decode_worker.transfer_topo = TransferTopology(
tp_rank=decode_worker.tp_rank,
tp_size=decode_worker.world_size,
@@ -2594,7 +2477,6 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
decode_worker.compat_hash = compute_nixl_compatibility_hash(
decode_worker.vllm_config,
decode_worker.backend_name,
- decode_worker.transfer_topo.cross_layers_blocks,
)
if error_scenario == "handshake_decode_error":
diff --git a/tests/v1/simple_kv_offload/test_scheduler.py b/tests/v1/simple_kv_offload/test_scheduler.py
index 970e16e52798..f65d66326492 100644
--- a/tests/v1/simple_kv_offload/test_scheduler.py
+++ b/tests/v1/simple_kv_offload/test_scheduler.py
@@ -84,7 +84,7 @@ def _make_kv_cache_config(
tensors.append(
KVCacheTensor(
size=_BYTES_PER_BLOCK * num_blocks,
- shared_by=layer_names,
+ shared_by=[layer_names],
)
)
return KVCacheConfig(
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 1a1352249c33..df61f8a36d6f 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -67,7 +67,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
kv_cache_config = KVCacheConfig(
num_blocks=NUM_BLOCKS,
kv_cache_tensors=[
- KVCacheTensor(size=tensor_size, shared_by=["layer.0"]),
+ KVCacheTensor(size=tensor_size, shared_by=[["layer.0"]]),
],
kv_cache_groups=[
KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
@@ -673,55 +673,6 @@ def test_update_states_pp_async_multi_request_keeps_rank_state_consistent(
)
-def test_kv_cache_stride_order(monkeypatch, model_runner):
- # This test checks if GPUModelRunner initializes correctly when an attention
- # backend enforces a non-default KV cache stride order.
- n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config)
- head_size = model_runner.model_config.get_head_size()
-
- # Get the expected shape from the backend's get_kv_cache_shape method
- # to ensure compatibility with different backends (triton vs flexattention)
- attn_backend = None
- for attn_group in model_runner._attn_group_iterator():
- attn_backend = attn_group.backend
- break
-
- assert attn_backend is not None, "No attention backend found"
- expected_kv_cache_shape = list(
- attn_backend.get_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size)
- )
-
- # TODO mla test
- default_stride = tuple(range(5))
- # Permutation that gets you back to expected kv shape
- for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)):
-
- def rnd_stride_order(
- include_num_layers_dimension: bool = False, test_stride=test_stride
- ):
- assert not include_num_layers_dimension
- return test_stride
-
- # Patch the attention backend class and re-trigger the KV cache creation
- for attn_group in model_runner._attn_group_iterator():
- attn_backend = attn_group.backend
- monkeypatch.setattr(
- attn_backend, "get_kv_cache_stride_order", rnd_stride_order
- )
-
- model_runner.attn_groups = []
- model_runner.kv_caches = []
- model_runner.initialize_kv_cache(model_runner.kv_cache_config)
-
- # Shape is unchanged, but layout may differ
- kv_cache_shape = model_runner.kv_caches[0].shape
- assert list(kv_cache_shape) == expected_kv_cache_shape
- if default_stride == test_stride:
- assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
- else:
- assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
-
-
def test_update_config(model_runner):
# Simple update
model_runner.update_config({"load_config": {"load_format": "dummy"}})
@@ -971,21 +922,22 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config):
vllm_config, [kv_cache_spec], [available_memory]
)[0]
assert kv_cache_config.num_blocks == num_expected_blocks
- assert len(kv_cache_config.kv_cache_tensors) == 2
- assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
- assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2
+ assert len(kv_cache_config.kv_cache_tensors) == 1
+ assert kv_cache_config.kv_cache_tensors[0].size == available_memory
max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == 1310720
# important: override tensor size to prevent large mem alloc during test
- # this will only allocate 2 block worth of memory (2 * 32kb)
+ # this will only allocate 1 block worth of memory per slot (2 slots * 32kb)
kv_cache_config.num_blocks = 1
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
- kv_cache_tensor.size = kv_cache_spec[
- kv_cache_tensor.shared_by[0]
- ].page_size_bytes
+ num_layer_slots = len(kv_cache_tensor.shared_by)
+ kv_cache_tensor.size = (
+ kv_cache_spec[kv_cache_tensor.shared_by[0][0]].page_size_bytes
+ * num_layer_slots
+ )
runner.initialize_kv_cache(kv_cache_config)
@@ -1079,7 +1031,7 @@ def test_hybrid_attention_mamba_tensor_shapes():
"""
The GPU model runner creates different views into the
KVCacheTensors for the attention and mamba layers
- (via _reshape_kv_cache_tensors function). This test verifies
+ (via _allocate_kv_caches). This test verifies
that the views are compatible: writing a mamba block
will not corrupt an attention block and vice versa
"""
@@ -1385,7 +1337,7 @@ def test_hybrid_cache_integration(default_vllm_config, dist_init):
kv_cache_config = KVCacheConfig(
num_blocks=NUM_BLOCKS,
kv_cache_tensors=[
- KVCacheTensor(size=tensor_size, shared_by=["layer.0"]),
+ KVCacheTensor(size=tensor_size, shared_by=[["layer.0"]]),
],
kv_cache_groups=[
KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
diff --git a/vllm/compilation/passes/fusion/mla_rope_kvcache_cat_fusion.py b/vllm/compilation/passes/fusion/mla_rope_kvcache_cat_fusion.py
index 5a493149a9d5..5998a0313840 100644
--- a/vllm/compilation/passes/fusion/mla_rope_kvcache_cat_fusion.py
+++ b/vllm/compilation/passes/fusion/mla_rope_kvcache_cat_fusion.py
@@ -45,7 +45,7 @@ def fused_rope_unified_mla_kv_cache_update_impl(
cos_sin_cache,
is_neox,
layer_slot_mapping,
- kv_cache,
+ kv_cache.squeeze(1),
kv_cache_dtype,
kv_cache_scale,
)
diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py
index b22af99f703f..38d626670436 100644
--- a/vllm/config/kv_transfer.py
+++ b/vllm/config/kv_transfer.py
@@ -65,7 +65,7 @@ class KVTransferConfig:
Only supported in V1."""
enable_permute_local_kv: bool = False
- """Experiment feature flag to enable HND to NHD KV Transfer"""
+ """Experiment feature flag to enable HNC to NHC KV Transfer"""
kv_load_failure_policy: Literal["recompute", "fail"] = "fail"
"""Policy for handling KV cache load failures.
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index d7a595716f08..593869b7886f 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -10,18 +10,20 @@
import torch
-from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
+from vllm.config import (
+ VllmConfig,
+ get_current_vllm_config_or_none,
+ get_layers_from_vllm_config,
+)
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionBackend
-from vllm.v1.kv_cache_interface import MambaSpec
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
- from vllm.v1.kv_cache_interface import KVCacheSpec
logger = init_logger(__name__)
@@ -32,19 +34,18 @@
def get_kv_connector_cache_layout():
- # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
- # used for faster transfer.
- vllm_config = get_current_vllm_config()
+ # NOTE (NickLucche) When running disaggregated PD with NIXL, LBHNC layout
+ # is used for faster transfer.
+ vllm_config = get_current_vllm_config_or_none()
+ if vllm_config is None:
+ return None
kv_config = vllm_config.kv_transfer_config
if kv_config is not None:
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
required_kvcache_layout = connector_cls.get_required_kvcache_layout(vllm_config)
if required_kvcache_layout is not None:
return required_kvcache_layout
- logger.info_once(
- "Connectors do not specify a kv cache layout, defaulting to NHD."
- )
- return "NHD"
+ return None
class KVOutputAggregator:
@@ -279,11 +280,11 @@ def kv_postprocess_layout_on_receive(cache, indices):
def kv_postprocess_blksize_and_layout_on_receive(cache, indices, block_size_ratio):
"""
- Transforms the layout of received KV cache to the local block_size and HND.
- (Only works for local blocksize > remote blocksize)
+ Transforms the layout of received KV cache to the local block_size
+ and LBHNC. (Only works for local blocksize > remote blocksize)
- prefill is HND, smaller block_size
- decode(local) is NHD, larger block_size
+ prefill is LBHNC, smaller block_size
+ decode(local) is LBNHC, larger block_size
"""
blocks_to_update = cache.index_select(0, indices)
@@ -408,43 +409,12 @@ def __post_init__(self):
self._engines: dict[EngineId, EngineTransferInfo] = {}
- # Figure out whether the first dimension of the cache is K/V
- # or num_blocks.
- attn_backend = self.attn_backends[0]
- if not self.is_mamba:
- _MOCK_BLOCK_SIZE = 16
- kv_cache_shape: tuple[int, ...] = attn_backend.get_kv_cache_shape(
- num_blocks=1,
- block_size=_MOCK_BLOCK_SIZE,
- num_kv_heads=1,
- head_size=1,
- )
- logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
- # Non-MLA backends caches have 5 dims [num_blocks, 2, H,N,D],
- # we just mock num_blocks to 1 for the dimension check below.
- # Hybrid SSM models assume a single blocks_first layout
- self._is_kv_layout_blocks_first = self.is_mamba or (
- len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
- )
-
- self._cross_layers_blocks = False
- if self.tensor_shape is not None:
- self._cross_layers_blocks = (
- len(self.tensor_shape) == len(kv_cache_shape) + 1
- )
+ # Cross-layer layouts (BLHNC) have B outermost, so all layers
+ # for a block are contiguous — transfers can coalesce multiple
+ # layers into one operation.
+ from vllm.v1.attention.backends.utils import resolve_kv_cache_layout
- if self._cross_layers_blocks:
- logger.debug("Using cross-layer KV cache")
- _MOCK_NUM_LAYERS = 80
- kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
- try:
- kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
- include_num_layers_dimension=self._cross_layers_blocks
- )
- except (AttributeError, NotImplementedError):
- assert self.tensor_shape is not None
- kv_cache_stride_order = tuple(range(len(self.tensor_shape)))
- kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
+ self._is_kv_layout_blocks_first = not resolve_kv_cache_layout().is_layer_compact
# ============================================================
# Engine registration
@@ -480,26 +450,6 @@ def get_engine_info(self, remote_engine_id: EngineId) -> EngineTransferInfo:
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
- @property
- def cross_layers_blocks(self) -> bool:
- return self._cross_layers_blocks
-
- @property
- def virtually_split_kv_in_blocks(self) -> bool:
- # Whether to logically split each block into K and V halves.
- # Applies when K/V are interleaved within each block (blocks-first),
- # but NOT when cross-layer blocks are used — cross-layer blocks have
- # per-layer K/V interleaving (L0_K, L0_V, L1_K, L1_V, ...) so a
- # simple half-split does not separate K from V.
- return self._is_kv_layout_blocks_first and not self._cross_layers_blocks
-
- @property
- def split_k_and_v(self) -> bool:
- # Whether to register regions for K and V separately (when present).
- return not (
- self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first
- )
-
# ============================================================
# Common methods
# ============================================================
@@ -571,31 +521,6 @@ def target_remote_ranks(self, remote_engine_id: EngineId) -> list[int]:
abs_ratio = -tp_ratio
return [self.tp_rank * abs_ratio + i for i in range(abs_ratio)]
- def get_transfer_cache_regions(
- self, cache: torch.Tensor, layer_spec: "KVCacheSpec"
- ) -> list[torch.Tensor] | torch.Tensor:
- """Return the cache tensor(s) to register as NIXL memory regions,
- also accounting for hybrid SSM models specificities.
- """
- if isinstance(layer_spec, MambaSpec):
- # Register the whole kv cache shared tensor, including
- # SSM/Conv.
- conv, ssm = cache
- return [conv]
-
- # Check may be hacky but it's matching
- # `_update_hybrid_attention_mamba_layout`.
- if self.is_mamba and cache.shape[0] == 2:
- # When MAMBA is present, all backends are blocks first, so
- # that blocks can be shared between attention layers and mamba
- # layers. Runner already adjusted strides for FlashAttn-like
- # backends so its num_blocks first.
- # Swap [2<>num_blocks] dims for hybrid SSM layout.
- cache = cache.transpose(0, 1)
-
- # Regular case: backends like FA register K/V in separate regions
- return cache if self.split_k_and_v else [cache]
-
def describe(self, remote_engine_id: EngineId) -> str:
"""One-line summary of transfer config for logging."""
info = self._engines[remote_engine_id]
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index fb5658da887a..146c44730d31 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -48,7 +48,7 @@
import torch
from vllm.logger import init_logger
-from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
+from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
@@ -173,14 +173,6 @@ class KVConnectorBase_V1(ABC):
Base class for KV connectors.
"""
- @property
- def prefer_cross_layer_blocks(self) -> bool:
- """
- Indicates whether this connector prefers KV blocks that hold KV data for all
- layers, which can speed up KV data transfers. Defaults to False.
- """
- return False
-
def __init__(
self,
vllm_config: "VllmConfig",
@@ -258,23 +250,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
return
- def register_cross_layers_kv_cache(
- self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
- ):
- """
- Initialize with a single KV cache tensor used by all layers.
- The first dimension should be num_layers.
- This function will only be called for models with uniform layers,
- and only if the prefers_cross_layer_blocks is set to True.
- Only one of the functions
- {register_kv_caches, register_cross_layers_kv_cache} will be called.
-
- Args:
- kv_cache: a cross-layers kv cache tensor
- attn_backend: The attention backend that corresponds to all layers
- """
- return
-
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
"""
Set the xPU-specific ops for copying KV between host and device.
@@ -577,7 +552,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
vllm_config (VllmConfig): the vllm config.
Returns:
- str: the required KV cache layout. e.g. HND, or NHD.
+ str: the required KV cache layout. e.g. HNC, or NHC.
None if the connector does not require a specific layout.
"""
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
index 3e4e6750858a..5a500dbfa5e9 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
@@ -38,8 +38,10 @@ def extract_from_kv_cache(
num_tokens: int,
) -> torch.Tensor:
"""Extract data from KV cache."""
- block_size = kv_cache.shape[1]
- return kv_cache[slot_mapping // block_size, slot_mapping % block_size][:num_tokens]
+ block_size = kv_cache.shape[2]
+ return kv_cache[slot_mapping // block_size, :, slot_mapping % block_size][
+ :num_tokens
+ ]
def load_hidden_states(path: str) -> dict[str, torch.Tensor]:
@@ -123,15 +125,6 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA):
Must be used in conjunction with the `extract_hidden_states` spec decoding method.
"""
- @property
- def prefer_cross_layer_blocks(self) -> bool:
- """
- Indicates whether this connector prefers KV blocks that hold KV data for all
- layers, which can speed up KV data transfers. Defaults to False.
- """
- # Must be False so that drafter kv cache isn't merged with verifier's
- return False
-
def __init__(
self,
vllm_config: "VllmConfig",
@@ -500,7 +493,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
vllm_config (VllmConfig): the vllm config.
Returns:
- str: the required KV cache layout. e.g. HND, or NHD.
+ str: the required KV cache layout. e.g. HNC, or NHC.
None if the connector does not require a specific layout.
"""
@@ -509,9 +502,9 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
"get_required_kvcache_layout should not be called "
"on the abstract base class"
)
- # NHD means we have (num_tokens, num_heads)
- # HND means we have (num_heads, num_tokens)
- # For now, we only support NHD layout since this keeps the
+ # LBNHC means we have (num_tokens, num_heads)
+ # LBHNC means we have (num_heads, num_tokens)
+ # For now, we only support LBNHC layout since this keeps the
# hidden states for each token together in memory.
- # HND is primarily used when sharding heads across devices.
- return "NHD"
+ # LBHNC is primarily used when sharding heads across devices.
+ return "LBNHC"
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
index 8786e91a5a14..8ecadaf2b952 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
@@ -974,7 +974,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
vllm_config (VllmConfig): the vllm config.
Returns:
- str: the required KV cache layout. e.g. HND, or NHD.
+ str: the required KV cache layout. e.g. HNC, or NHC.
None if the connector does not require a specific layout.
"""
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py
index ccb7257c88c4..9e342a1a3d6f 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py
@@ -51,7 +51,7 @@
from vllm.utils.math_utils import cdiv
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionMetadata
-from vllm.v1.attention.backends.utils import get_kv_cache_layout
+from vllm.v1.attention.backends.utils import resolve_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec
from vllm.v1.request import RequestStatus
@@ -109,7 +109,7 @@ def _get_tp_ratio(local_tp_size: int, remote_tp_size: int) -> int:
def _expand_transfer_regions(
base_addrs: list[int],
block_lens: list[int],
- is_kv_layout_blocks_first: bool,
+ is_kv_layout_blocks_first: bool, # kept for API compat, unused
) -> list[TransferRegion]:
"""Expand registered KV tensors into the regions transferred by Mooncake."""
assert len(base_addrs) == len(block_lens), (
@@ -118,22 +118,13 @@ def _expand_transfer_regions(
)
regions: list[TransferRegion] = []
for base_addr, block_len in zip(base_addrs, block_lens):
- kv_block_len = block_len // 2 if is_kv_layout_blocks_first else block_len
regions.append(
TransferRegion(
base_addr=base_addr,
block_len=block_len,
- kv_block_len=kv_block_len,
+ kv_block_len=block_len,
)
)
- if is_kv_layout_blocks_first:
- regions.append(
- TransferRegion(
- base_addr=base_addr + kv_block_len,
- block_len=block_len,
- kv_block_len=kv_block_len,
- )
- )
return regions
@@ -377,10 +368,10 @@ def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
if vllm_config.model_config.use_mla:
return None
logger.info_once(
- "MooncakeConnector setting KV cache layout to HND for "
+ "MooncakeConnector setting KV cache layout to LBHNC for "
"heterogeneous TP-safe KV transfer."
)
- return "HND"
+ return "LBHNC"
############################################################
# Scheduler Side Methods
@@ -852,7 +843,7 @@ def __init__(
# NOTE (NickLucche) models with multiple backends are not supported yet
backend = get_current_attn_backend(vllm_config)
self.backend_name = backend.get_name()
- self.kv_cache_layout = get_kv_cache_layout()
+ self.kv_cache_layout = resolve_kv_cache_layout().name
logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
@@ -1395,10 +1386,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
seen_base_addresses = []
self.block_len_per_layer = []
- split_k_and_v = self.transfer_topo.split_k_and_v
tensor_size_bytes = None
for layer_name, cache_or_caches in kv_caches.items():
- cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
+ cache_list = [cache_or_caches]
logger.debug(
"registering layer %s with %d cache tensor(s)",
layer_name,
@@ -1745,7 +1735,7 @@ def _get_transfer_regions(
return _expand_transfer_regions(
base_addrs=base_addrs,
block_lens=block_lens,
- is_kv_layout_blocks_first=self.transfer_topo.virtually_split_kv_in_blocks,
+ is_kv_layout_blocks_first=self.transfer_topo.is_kv_layout_blocks_first,
)
def _get_sender_transfer_plan(
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
index 73418104bea6..372ce6f1b929 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -27,7 +27,7 @@
PromMetricT,
)
from vllm.logger import init_logger
-from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
+from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
@@ -203,12 +203,6 @@ def __init__(
# Propagated from scheduler to worker side via the connector metadata.
self._extra_async_saves: dict[str, int] = {}
- @property
- def prefer_cross_layer_blocks(self) -> bool:
- if not self._connectors:
- return False
- return all(c.prefer_cross_layer_blocks for c in self._connectors)
-
@classmethod
def _get_connector_classes_and_configs(
cls, vllm_config: "VllmConfig"
@@ -235,13 +229,6 @@ def _get_connector_classes_and_configs(
)
return ret
- def register_cross_layers_kv_cache(
- self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
- ):
- # Register on all connectors
- for c in self._connectors:
- c.register_cross_layers_kv_cache(kv_cache, attn_backend)
-
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors:
c.register_kv_caches(kv_caches)
@@ -540,7 +527,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
vllm_config (VllmConfig): the vllm config.
Returns:
- str: the required KV cache layout. e.g. HND, or NHD.
+ str: the required KV cache layout. e.g. HNC, or NHC.
None if the connector does not require a specific layout.
"""
assert vllm_config.kv_transfer_config is not None
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
index 187322b4ae4e..37357adff352 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
@@ -9,7 +9,6 @@
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
- get_current_attn_backend,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
@@ -40,10 +39,8 @@
)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
-from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
-from vllm.v1.attention.backends.utils import get_kv_cache_layout
+from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import MambaSpec
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
@@ -55,36 +52,6 @@
class NixlConnector(KVConnectorBase_V1, SupportsHMA):
- @property
- def prefer_cross_layer_blocks(self) -> bool:
- if any(
- [
- isinstance(group.kv_cache_spec, MambaSpec)
- for group in self.kv_cache_config.kv_cache_groups
- ]
- ):
- # Hybrid SSM models do not yet support cross-layer layout
- return False
-
- backend = get_current_attn_backend(self._vllm_config)
- if backend.get_name() not in (
- "FLASH_ATTN",
- "FLASHINFER",
- "TRITON_ATTN",
- ):
- return False
-
- # For now there is no benefit to run cross layers when backend
- # does not support on HND
- if get_kv_cache_layout() != "HND":
- return False
-
- extra_config = self.kv_transfer_config.kv_connector_extra_config
- return (
- str(extra_config.get("enable_cross_layers_blocks", "False")).lower()
- == "true"
- )
-
def __init__(
self,
vllm_config: VllmConfig,
@@ -126,9 +93,10 @@ def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
# which fallback to the default behavior.
return None
logger.info_once(
- "NixlConnector setting KV cache layout to HND for better xfer performance."
+ "NixlConnector setting KV cache layout to LBHNC for "
+ "better xfer performance."
)
- return "HND"
+ return "LBHNC"
############################################################
# Scheduler Side Methods
@@ -200,12 +168,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
- def register_cross_layers_kv_cache(
- self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
- ):
- assert self.connector_worker is not None
- self.connector_worker.register_cross_layers_kv_caches(kv_cache)
-
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
assert self.connector_worker is not None
self.connector_worker.set_host_xfer_buffer_ops(copy_operation)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py
index b9e3436f5019..848a0451ac90 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/metadata.py
@@ -34,8 +34,9 @@
# 2: Add remote_request_id to kv_transfer_params
# 3: Add physical_blocks_per_logical_kv_block to NixlAgentMetadata
# 4: Add KV block lease renewal through heartbeats
+# 5: Add block_strides
#
-NIXL_CONNECTOR_VERSION: int = 4
+NIXL_CONNECTOR_VERSION: int = 5
@dataclass
@@ -46,6 +47,7 @@ class NixlAgentMetadata:
device_id: int
num_blocks: int
block_lens: list[int]
+ block_strides: list[int]
kv_cache_layout: str
block_size: int
ssm_sizes: tuple[int, int]
@@ -72,7 +74,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata):
def compute_nixl_compatibility_hash(
- vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool
+ vllm_config: VllmConfig, attn_backend_name: str
) -> str:
"""
Compute compatibility hash for NIXL KV transfer.
@@ -116,7 +118,6 @@ def compute_nixl_compatibility_hash(
# Attention backend and KV cache dtype affect memory layout
"attn_backend_name": attn_backend_name,
"cache_dtype": str(cache_config.cache_dtype),
- "cross_layers_blocks": cross_layers_blocks,
"is_hma_enabled": is_hma_enabled,
}
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
index 0d30d4a692ad..2d4907c0b071 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
@@ -67,9 +67,10 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_path
-from vllm.v1.attention.backends.utils import get_kv_cache_layout
+from vllm.v1.attention.backends.utils import resolve_kv_cache_layout
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
+ KVCacheLayout,
MambaSpec,
UniformTypeKVCacheSpecs,
)
@@ -235,7 +236,9 @@ def __init__(
for group in kv_cache_config.kv_cache_groups
for layer in group.layer_names
}
- self.hma_group_size = len(kv_cache_config.kv_cache_tensors)
+ self.hma_group_size = sum(
+ len(t.shared_by) for t in kv_cache_config.kv_cache_tensors
+ )
# ---- Model state (derived from model config) ----
mamba_ssm_size = (0, 0)
@@ -418,7 +421,7 @@ def __init__(
self.attn_backends = get_current_attn_backends(vllm_config)
self.backend_name = self.attn_backends[0].get_name()
- self.kv_cache_layout = get_kv_cache_layout()
+ self.kv_cache_layout = resolve_kv_cache_layout().name
self.host_buffer_kv_cache_layout = self.kv_cache_layout
logger.info(
"Detected attention backend(s) %s",
@@ -602,18 +605,18 @@ def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> Non
kv_dtype = kv_cache.dtype
permute_shape = False
if (
- self.kv_cache_layout == "NHD"
+ self.kv_cache_layout == "LBNHC"
and self.vllm_config.kv_transfer_config is not None
and self.vllm_config.kv_transfer_config.enable_permute_local_kv
):
logger.info_once(
"'enable_permute_local_kv' flag is enabled while "
- "device KV Layout is NHD. Init host buffer with"
- " HND to better support Decode/Prefill TP_ratio > 1."
+ "device KV Layout is LBNHC. Init host buffer with"
+ " LBHNC to better support Decode/Prefill TP_ratio > 1."
)
- # Since NHD will not support Decode/Prefill TP_ratio > 1,
+ # Since LBNHC will not support Decode/Prefill TP_ratio > 1,
# we can leverage host_buffer for permute
- self.host_buffer_kv_cache_layout = "HND"
+ self.host_buffer_kv_cache_layout = "LBHNC"
kv_shape = (
tuple(kv_shape[i] for i in inv_order)
if not self.use_mla
@@ -774,17 +777,6 @@ def request_ready(f: Future[Any], entry=(req_id, meta)):
fut.add_done_callback(request_ready)
- def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None:
- """Register a cross-layers KV cache tensor with NIXL.
-
- `use_uniform_kv_cache()` guarantees a single KV cache group whose
- layers all share the same `AttentionSpec`, so any layer name from
- `_layer_specs` yields the correct per-layer spec for `page_size_bytes`.
- """
- first_layer = next(iter(self._layer_specs))
- # Forwarding a real layer name rather than a synthetic key
- self.register_kv_caches({first_layer: kv_cache})
-
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
self.transfer_topo = TransferTopology(
@@ -802,7 +794,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
is_mamba=self._has_mamba,
)
self.compat_hash = compute_nixl_compatibility_hash(
- self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks
+ self.vllm_config, self.backend_name
)
if self.use_host_buffer:
@@ -844,6 +836,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
# Enable different block lengths for different layers *only* when MLA is used.
# This is not used for SSM layers, which use the counterpart `mamba_ssm_size`.
self.block_len_per_layer = list[int]()
+
+ # Per-block physical stride per layer (bytes). Read from each
+ # registered tensor's stride(0) so it stays correct under layouts
+ # that interleave layers within a block (BLHNC/BHLNC).
+ self.block_stride_per_layer = list[int]()
for layer_name, cache_or_caches in xfer_buffers.items():
# NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to
# that of FI, with block laid out as in `get_backend_aware_kv_block_len`.
@@ -854,24 +851,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
if isinstance(layer_spec, UniformTypeKVCacheSpecs):
# MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs
layer_spec = layer_spec.kv_cache_specs[layer_name]
- cache_list = self.transfer_topo.get_transfer_cache_regions(
- cache_or_caches, layer_spec
- )
- # `layer_spec.page_size_bytes` only accounts for logical page_size, that is
- # the page_size assuming constant `self._logical_num_blocks`.
physical_page_size = (
layer_spec.page_size_bytes
if isinstance(layer_spec, MambaSpec)
else layer_spec.page_size_bytes
// self._physical_blocks_per_logical_kv_block
)
- # For when registering multiple tensors eg K/V in separate regions.
- physical_page_size = physical_page_size // len(cache_list)
- if self.transfer_topo._cross_layers_blocks:
- # When cross-layers blocks are used, multiply by number of layers
- physical_page_size = physical_page_size * len(
- self.kv_cache_config.kv_cache_tensors
- )
num_blocks = (
self._logical_num_blocks
if isinstance(layer_spec, MambaSpec)
@@ -883,75 +868,63 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
- # TODO (NickLucche) we could eventually unify how we handle FA/FI regions,
- # registering a single tensor for both K/V and splitting logically like FI.
- for cache in cache_list:
- base_addr = cache.data_ptr()
- if base_addr in seen_base_addresses:
- # NOTE (NickLucche) HMA employs memory pooling to share tensors
- # across groups. This results in skipping all tensors but the ones
- # pointed to by group0. Also, generally we will have more blocks
- # per tensor but fewer regions.
- logger.debug("Skipping %s because it's already seen", layer_name)
- continue
- logger.debug(
- "Registering layer %s with cache shape: %s", layer_name, cache.shape
+ cache = cache_or_caches
+ base_addr = cache.data_ptr()
+ if base_addr in seen_base_addresses:
+ logger.debug("Skipping %s because it's already seen", layer_name)
+ continue
+ logger.debug(
+ "Registering layer %s with cache shape: %s",
+ layer_name,
+ cache.shape,
+ )
+ seen_base_addresses.append(base_addr)
+ if isinstance(layer_spec, MambaSpec):
+ self.block_len_per_layer.append(
+ physical_page_size // self._physical_blocks_per_logical_kv_block
+ )
+ self.block_stride_per_layer.append(self.block_len_per_layer[-1])
+ else:
+ self.block_len_per_layer.append(physical_page_size)
+ self.block_stride_per_layer.append(
+ cache.stride(0) * cache.element_size()
)
- seen_base_addresses.append(base_addr)
- # Only record non-Mamba page sizes.
- if isinstance(layer_spec, MambaSpec):
- self.block_len_per_layer.append(
- physical_page_size // self._physical_blocks_per_logical_kv_block
- )
- else:
- self.block_len_per_layer.append(physical_page_size)
-
- if cache.shape[0] != num_blocks:
- raise AssertionError(
- "All kv cache tensors must have the same number of "
- f"blocks; layer={layer_name}, "
- f"expected_num_blocks={num_blocks}, "
- f"cache_shape={tuple(cache.shape)}, "
- f"cache_stride={tuple(cache.stride())}, "
- f"layer_spec={type(layer_spec).__name__}, "
- f"backend={self.backend_name}, "
- "all_backends="
- f"{[backend.get_name() for backend in self.attn_backends]}, "
- f"kv_cache_layout={self.kv_cache_layout}, "
- "blocks_first="
- f"{self.transfer_topo.is_kv_layout_blocks_first}"
- )
- if not self.use_mla:
- # Different kv cache shape is not supported by HeteroTP.
- # This must also hold true for Mamba-like models.
- assert tensor_size_bytes == curr_tensor_size_bytes, (
- "All kv cache tensors must have the same size"
- )
- # Need to make sure the device ID is non-negative for NIXL,
- # Torch uses -1 to indicate CPU tensors.
- self.device_id = max(cache.get_device(), 0)
- caches_data.append(
- (base_addr, curr_tensor_size_bytes, self.device_id, "")
+ if cache.shape[0] != num_blocks:
+ raise AssertionError(
+ "All kv cache tensors must have the same number of "
+ f"blocks; layer={layer_name}, "
+ f"expected_num_blocks={num_blocks}, "
+ f"cache_shape={tuple(cache.shape)}, "
+ f"cache_stride={tuple(cache.stride())}, "
+ f"layer_spec={type(layer_spec).__name__}, "
+ f"backend={self.backend_name}, "
+ "all_backends="
+ f"{[backend.get_name() for backend in self.attn_backends]}, "
+ f"kv_cache_layout={self.kv_cache_layout}, "
+ "blocks_first="
+ f"{self.transfer_topo.is_kv_layout_blocks_first}"
)
+ if not self.use_mla:
+ assert tensor_size_bytes == curr_tensor_size_bytes, (
+ "All kv cache tensors must have the same size"
+ )
+ self.device_id = max(cache.get_device(), 0)
+ caches_data.append((base_addr, curr_tensor_size_bytes, self.device_id, ""))
+
logger.debug(
"Different block lengths collected: %s", set(self.block_len_per_layer)
)
assert len(self.block_len_per_layer) == len(seen_base_addresses)
+ assert len(self.block_stride_per_layer) == len(seen_base_addresses)
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data)
- if self.transfer_topo.virtually_split_kv_in_blocks:
- # NOTE (NickLucche) When FlashInfer is used, memory is registered
- # with joint KV for each block. This minimizes the overhead in
- # registerMem allowing faster descs queries. In order to be able to
- # split on kv_heads dim as required by heterogeneous TP, one must
- # be able to index K/V separately. Hence we double the number
- # of 'virtual' regions here and halve `block_len` below.
- # Similarly for Mamba layers, we register SSM+Conv as a single region and
- # then duplicate it logically to be able to index SSM/Conv separately.
+ if self.transfer_topo.is_kv_layout_blocks_first and self._has_mamba:
+ # For Mamba layers, SSM+Conv are registered as a single region
+ # and duplicated logically to index SSM/Conv separately.
self.num_regions *= 2
# Total local FA descriptors (boundary between FA and mamba descs).
@@ -993,6 +966,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer,
+ block_strides=self.block_stride_per_layer,
kv_cache_layout=self.kv_cache_layout
if not self.use_host_buffer
else self.host_buffer_kv_cache_layout,
@@ -1112,24 +1086,12 @@ def _build_fa_local(
)
// block_size_ratio
)
- page_stride = self.block_len_per_layer[i] // block_size_ratio
+ page_stride = self.block_stride_per_layer[i] // block_size_ratio
for block_id in range(num_blocks):
block_offset = block_id * page_stride
addr = base_addr + block_offset
result.append((addr, kv_block_len, self.device_id))
- if self.transfer_topo.virtually_split_kv_in_blocks:
- # Separate and interleave K/V regions to maintain the same
- # descs ordering. This is needed for selecting contiguous heads
- # when split across TP ranks.
- second_split = self.get_backend_aware_kv_block_len(
- layer_idx=i, first_split=False, mamba_view=False
- )
- for block_id in range(num_blocks):
- block_offset = block_id * page_stride
- addr = base_addr + block_offset
- v_addr = addr + kv_block_len
- result.append((v_addr, second_split, self.device_id))
return result
def _build_fa_remote(
@@ -1159,7 +1121,7 @@ def _build_fa_remote(
local_block_len = local_block_len // num_attn_reads
rank_offset = plan.rank_offset_factor * remote_kv_block_len
- page_size = nixl_agent_meta.block_lens[i]
+ page_size = nixl_agent_meta.block_strides[i]
for block_id in range(num_blocks):
block_offset = block_id * page_size
# For each block, grab the kv heads chunk belonging to current local
@@ -1167,18 +1129,6 @@ def _build_fa_remote(
addr = base_addr + block_offset + rank_offset
result.append((addr, local_block_len, nixl_agent_meta.device_id))
- if self.transfer_topo.virtually_split_kv_in_blocks:
- # With FlashInfer index V separately to allow head splitting.
- second_split = self.get_backend_aware_kv_block_len(
- layer_idx=i, first_split=False, mamba_view=False
- )
- second_split = second_split // num_attn_reads
- for block_id in range(num_blocks):
- block_offset = block_id * page_size
- addr = base_addr + block_offset + rank_offset
- # Hop over the first split of remote page, K, to read V.
- v_addr = addr + nixl_agent_meta.block_lens[i] // 2
- result.append((v_addr, second_split, nixl_agent_meta.device_id))
return result
def register_local_xfer_handler(
@@ -1261,7 +1211,7 @@ def add_remote_agent(
tp_ratio = 4 // 2 = 2
Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim]
- then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format.
+ then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "LBHNC" layout format.
Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio
first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split
along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0.
@@ -1465,10 +1415,10 @@ def _validate_remote_agent_handshake(
if not self.use_mla and nixl_agent_meta.kv_cache_layout != kv_cache_layout:
if (
self.kv_transfer_config.enable_permute_local_kv
- and nixl_agent_meta.kv_cache_layout == "HND"
+ and nixl_agent_meta.kv_cache_layout == "LBHNC"
):
logger.info(
- "Remote is HND and local is NHD, enabled additional permute "
+ "Remote is LBHNC and local is LBNHC, enabled additional permute "
"on local device KV."
)
assert not self._is_hma_required, (
@@ -1478,7 +1428,7 @@ def _validate_remote_agent_handshake(
else:
raise RuntimeError(
"Heterogeneous TP expects same kv_cache_layout. "
- "Or enable experimental feature to use HND to NHD support by "
+ "Or enable experimental feature to use HNC to NHC support by "
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
)
# if remote_agent used attn is not same as local,
@@ -1498,18 +1448,18 @@ def _validate_remote_agent_handshake(
self.enable_heterogeneous_attn_post_process = True
# Heterogeneous TP requires head-splitting, which only works with
- # HND layout. MLA and replicated-KV cases don't split on heads.
- # Mamba doesn't support heterogeneous TP.
+ # block-contiguous layouts (H before N, e.g. HNC / BHLNC).
+ # MLA and replicated-KV cases don't split on heads.
if (
abs(tp_ratio) != 1
and not self.use_mla
and not self.transfer_topo.is_kv_replicated(remote_engine_id)
- and kv_cache_layout != "HND"
+ and not KVCacheLayout[kv_cache_layout].is_block_contiguous
and not self.enable_permute_local_kv
):
raise RuntimeError(
"Heterogeneous TP head-dimension splitting requires contiguous heads. "
- "Use HND layout on the prefill side."
+ "Use HNC layout on the prefill side."
)
# Block len can only vary across layers when using MLA.
@@ -1620,11 +1570,11 @@ def post_process_device_kv_on_receive(
Post process device kv cache after receiving from remote.
3 types of post processing supported:
- * kv_cache_postprocess_layout => convert from HND to NHD
+ * kv_cache_postprocess_layout => convert from HNC to NHC
* kv_cache_postprocess_blksize => convert from small block size
to large block size
* kv_cache_postprocess_blksize_and_layout => convert from small
- block size to large block size and convert from HND to NHD
+ block size to large block size and convert from HNC to NHC
"""
if len(self.device_kv_caches) == 0:
@@ -1634,14 +1584,14 @@ def post_process_device_kv_on_receive(
if self.enable_permute_local_kv and block_size_ratio > 1:
logger.debug(
"Post-processing device kv cache on receive by converting "
- "block_size with %sx bigger and permuting layout from HND"
- " to NHD.",
+ "block_size with %sx bigger and permuting layout from HNC"
+ " to NHC.",
block_size_ratio,
)
elif self.enable_permute_local_kv:
logger.debug(
"Post-processing device kv cache on receive by permuting layout"
- "from HND to NHD."
+ "from HNC to NHC."
)
else:
logger.debug(
@@ -1650,24 +1600,18 @@ def post_process_device_kv_on_receive(
block_size_ratio,
)
- split_k_and_v = self.transfer_topo.split_k_and_v
-
for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
- for _, cache_or_caches in self.device_kv_caches.items():
- cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
- for cache in cache_list:
- if self.enable_permute_local_kv and block_size_ratio > 1:
- kv_postprocess_blksize_and_layout_on_receive(
- cache, indices, block_size_ratio
- )
- elif self.enable_permute_local_kv:
- kv_postprocess_layout_on_receive(cache, indices)
- else:
- kv_postprocess_blksize_on_receive(
- cache, indices, block_size_ratio
- )
+ for _, cache in self.device_kv_caches.items():
+ if self.enable_permute_local_kv and block_size_ratio > 1:
+ kv_postprocess_blksize_and_layout_on_receive(
+ cache, indices, block_size_ratio
+ )
+ elif self.enable_permute_local_kv:
+ kv_postprocess_layout_on_receive(cache, indices)
+ else:
+ kv_postprocess_blksize_on_receive(cache, indices, block_size_ratio)
def post_process_device_kv_on_receive_heterogeneous_attn(
self, block_ids: list[int]
@@ -2416,11 +2360,8 @@ def get_backend_aware_kv_block_len(
|1st_split-2nd_split| |1st_split-2nd_split |
"""
assert self.transfer_topo is not None
- if self.transfer_topo.virtually_split_kv_in_blocks:
- if mamba_view:
- block_len = self._mamba_ssm_size[not first_split]
- else:
- block_len = self.block_len_per_layer[layer_idx] // 2
+ if self.transfer_topo.is_kv_layout_blocks_first and mamba_view:
+ block_len = self._mamba_ssm_size[not first_split]
else:
block_len = self.block_len_per_layer[layer_idx]
return block_len
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
index 8957ce3445ae..a271aa425624 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
@@ -17,7 +17,6 @@
OffloadingConnectorStats,
)
from vllm.logger import init_logger
-from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.kv_cache_interface import (
AttentionSpec,
MambaSpec,
@@ -59,7 +58,8 @@ def register_kv_caches(
):
num_blocks = self.spec.kv_cache_config.num_blocks
- # layer_name -> (num_blocks, page_size_bytes) tensor
+ # layer_name -> (num_blocks, page_size_bytes) int8 view.
+ # Standardized layouts always have num_blocks as the leading dim.
tensors_per_block: dict[str, tuple[torch.Tensor, ...]] = {}
# layer_name -> size of (un-padded) page in bytes
unpadded_page_size_bytes: dict[str, int] = {}
@@ -76,98 +76,79 @@ def register_kv_caches(
layer_kv_cache_spec = per_layer_specs.get(
layer_name, group_kv_cache_spec
)
- if isinstance(layer_kv_cache_spec, AttentionSpec):
- layer_kv_cache = kv_caches[layer_name]
- assert isinstance(layer_kv_cache, torch.Tensor)
- assert layer_kv_cache.storage_offset() == 0
+ layer_kv_cache = kv_caches[layer_name]
+ # AttentionSpec yields a single tensor; MambaSpec yields a
+ # list of typed state tensors that share one underlying
+ # buffer. Either way, the first tensor's storage_offset
+ # marks the start of this layer's region.
+ ref = (
+ layer_kv_cache[0]
+ if isinstance(layer_kv_cache, list)
+ else layer_kv_cache
+ )
+ page = layer_kv_cache_spec.page_size_bytes
+ offset = ref.storage_offset() * ref.element_size()
+ tensors_per_block[layer_name] = (
+ torch.tensor([], dtype=torch.int8, device=ref.device)
+ .set_(ref.untyped_storage())
+ .view(-1)[offset : offset + num_blocks * page]
+ .view(num_blocks, page),
+ )
+ page_size_bytes[layer_name] = page
- storage = layer_kv_cache.untyped_storage()
- page = layer_kv_cache_spec.page_size_bytes
- tensors_per_block[layer_name] = (
- torch.tensor(
- [],
- dtype=torch.int8,
- device=layer_kv_cache.device,
- )
- .set_(storage)
- .view(num_blocks, page),
- )
- page_size_bytes[layer_name] = layer_kv_cache_spec.page_size_bytes
+ if isinstance(layer_kv_cache_spec, AttentionSpec):
unpadded_page_size_bytes[layer_name] = (
layer_kv_cache_spec.real_page_size_bytes
)
-
elif isinstance(layer_kv_cache_spec, MambaSpec):
- state_tensors = kv_caches[layer_name]
- assert isinstance(state_tensors, list)
-
- # re-construct the raw (num_blocks, page_size) tensor
- # from the first state tensor
- assert len(state_tensors) > 0
- first_state_tensor = state_tensors[0]
- assert first_state_tensor.storage_offset() == 0
- tensor = (
- torch.tensor(
- [],
- dtype=torch.int8,
- device=first_state_tensor.device,
- )
- .set_(first_state_tensor.untyped_storage())
- .view((num_blocks, layer_kv_cache_spec.page_size_bytes))
- )
- tensors_per_block[layer_name] = (tensor,)
-
- page_size_bytes[layer_name] = layer_kv_cache_spec.page_size_bytes
unpadded_page_size_bytes[layer_name] = replace(
layer_kv_cache_spec, page_size_padded=None
).page_size_bytes
-
else:
raise NotImplementedError
block_tensors: list[CanonicalKVCacheTensor] = []
block_data_refs: dict[str, list[CanonicalKVCacheRef]] = defaultdict(list)
for kv_cache_tensor in self.spec.kv_cache_config.kv_cache_tensors:
- # Filter to layers that were actually processed above.
- # _get_kv_cache_config_deepseek_v4 emits KVCacheTensor entries for
- # every (tuple_idx, page_size) slot; slots where no group has a
- # layer at that index produce an empty shared_by (reserved memory
- # with no corresponding model layer).
- tensor_layer_names = [
- n for n in kv_cache_tensor.shared_by if n in tensors_per_block
- ]
- if not tensor_layer_names:
- continue
-
- # verify all layers in the group reference the exact same tensors
- assert len({len(tensors_per_block[n]) for n in tensor_layer_names}) == 1
- assert (
- len({tensors_per_block[n][0].data_ptr() for n in tensor_layer_names})
- == 1
- )
- assert (
- len({tensors_per_block[n][0].stride() for n in tensor_layer_names}) == 1
- )
-
- # pick the first layer to represent the group
- first_layer_name = tensor_layer_names[0]
- for tensor in tensors_per_block[first_layer_name]:
- block_tensors.append(
- CanonicalKVCacheTensor(
- tensor=tensor,
- page_size_bytes=page_size_bytes[first_layer_name],
+ for slot_layers in kv_cache_tensor.shared_by:
+ # Filter to layers that were actually processed above.
+ # Some slots may have no corresponding model layer (reserved
+ # memory with no group layer at that index).
+ tensor_layer_names = [n for n in slot_layers if n in tensors_per_block]
+ if not tensor_layer_names:
+ continue
+
+ # Verify all layers in the slot reference the same tensors.
+ assert len({len(tensors_per_block[n]) for n in tensor_layer_names}) == 1
+ assert (
+ len(
+ {tensors_per_block[n][0].data_ptr() for n in tensor_layer_names}
)
+ == 1
+ )
+ assert (
+ len({tensors_per_block[n][0].stride() for n in tensor_layer_names})
+ == 1
)
- curr_tensor_idx = len(block_tensors) - 1
- for layer_name in tensor_layer_names:
- block_data_refs[layer_name].append(
- CanonicalKVCacheRef(
- tensor_idx=curr_tensor_idx,
- page_size_bytes=(unpadded_page_size_bytes[layer_name]),
+ first_layer_name = tensor_layer_names[0]
+ for tensor in tensors_per_block[first_layer_name]:
+ block_tensors.append(
+ CanonicalKVCacheTensor(
+ tensor=tensor,
+ page_size_bytes=page_size_bytes[first_layer_name],
)
)
+ curr_tensor_idx = len(block_tensors) - 1
+ for layer_name in tensor_layer_names:
+ block_data_refs[layer_name].append(
+ CanonicalKVCacheRef(
+ tensor_idx=curr_tensor_idx,
+ page_size_bytes=(unpadded_page_size_bytes[layer_name]),
+ )
+ )
+
group_data_refs: list[list[CanonicalKVCacheRef]] = []
for kv_cache_group in self.spec.kv_cache_config.kv_cache_groups:
group_refs: list[CanonicalKVCacheRef] = []
@@ -182,53 +163,6 @@ def register_kv_caches(
self._register_handlers(canonical_kv_caches)
- def register_cross_layers_kv_cache(
- self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
- ):
- # verify that num_blocks is at physical position 0 in the cross-layers
- # tensor layout.
- test_shape = attn_backend.get_kv_cache_shape(
- num_blocks=1234, block_size=16, num_kv_heads=1, head_size=256
- )
- num_blocks_logical_dim = test_shape.index(1234) + 1
- physical_to_logical = attn_backend.get_kv_cache_stride_order(
- include_num_layers_dimension=True
- )
- num_blocks_physical_dim = physical_to_logical.index(num_blocks_logical_dim)
- assert num_blocks_physical_dim == 0
-
- kv_cache_groups = self.spec.kv_cache_config.kv_cache_groups
- assert len(kv_cache_groups) == 1
- kv_cache_spec = kv_cache_groups[0].kv_cache_spec
- num_layers = len(kv_cache_groups[0].layer_names)
- page_size_bytes = kv_cache_spec.page_size_bytes * num_layers
-
- assert kv_cache.storage_offset() == 0
- storage = kv_cache.untyped_storage()
- assert len(storage) % page_size_bytes == 0
- num_blocks = len(storage) // page_size_bytes
- tensor = (
- torch.tensor(
- [],
- dtype=torch.int8,
- device=kv_cache.device,
- )
- .set_(storage)
- .view(num_blocks, page_size_bytes)
- )
- kv_cache_tensor = CanonicalKVCacheTensor(
- tensor=tensor, page_size_bytes=page_size_bytes
- )
- # in cross layers layout, there's currently only a single group
- kv_cache_data_ref = CanonicalKVCacheRef(
- tensor_idx=0, page_size_bytes=page_size_bytes
- )
- canonical_kv_caches = CanonicalKVCaches(
- tensors=[kv_cache_tensor], group_data_refs=[[kv_cache_data_ref]]
- )
-
- self._register_handlers(canonical_kv_caches)
-
def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
index 20888c71f84c..48a774dea9ea 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
@@ -34,7 +34,7 @@
OffloadingConnectorWorker,
)
from vllm.forward_context import ForwardContext
-from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
+from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -44,10 +44,6 @@
class OffloadingConnector(KVConnectorBase_V1, SupportsHMA):
- @property
- def prefer_cross_layer_blocks(self) -> bool:
- return True
-
def __init__(
self,
vllm_config: VllmConfig,
@@ -75,12 +71,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
- def register_cross_layers_kv_cache(
- self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
- ):
- assert self.connector_worker is not None
- self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
-
def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
assert self.connector_worker is not None
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
@@ -176,7 +166,7 @@ def take_events(self) -> Iterable[KVCacheEvent]:
@classmethod
def get_required_kvcache_layout(cls, vllm_config: VllmConfig) -> str | None:
- return "HND"
+ return "LBHNC"
def reset_cache(self) -> bool | None:
assert self.connector_scheduler is not None
diff --git a/vllm/envs.py b/vllm/envs.py
index dc11fbd224d9..fe31bf49375a 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -205,7 +205,9 @@
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
- VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None
+ VLLM_KV_CACHE_LAYOUT: (
+ Literal["LBNHC", "LBHNC", "NHC", "HNC", "NHD", "HND", "BLHNC", "BHLNC"] | None
+ ) = None
VLLM_SSM_CONV_STATE_LAYOUT: Literal["SD", "DS"] | None = None
VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
@@ -1617,18 +1619,20 @@ def _resolve_rust_frontend_path() -> str | None:
),
# KV Cache layout used throughout vllm.
# Some common values are:
- # - NHD
- # - HND
- # Where N=num_blocks, H=num_heads and D=head_size. The default value will
- # leave the layout choice to the backend. Mind that backends may only
+ # - LBNHC
+ # - LBHNC
+ # Where N=num_states, H=num_heads and C=state_content. The default value
+ # will leave the layout choice to the backend. Mind that backends may only
# implement and support a subset of all possible layouts.
"VLLM_KV_CACHE_LAYOUT": env_with_choices(
- "VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"]
+ "VLLM_KV_CACHE_LAYOUT",
+ None,
+ ["LBNHC", "LBHNC", "NHC", "HNC", "NHD", "HND", "BLHNC", "BHLNC"],
),
# SSM conv state layout used for Mamba models.
# - SD: (state_len, dim) — dim contiguous (default)
# - DS: (dim, state_len) — TP-sharded dim on dim1,
- # consistent with SSM temporal state and HND KV cache layout.
+ # consistent with SSM temporal state and LBHNC KV cache layout.
"VLLM_SSM_CONV_STATE_LAYOUT": env_with_choices(
"VLLM_SSM_CONV_STATE_LAYOUT", None, ["SD", "DS"]
),
diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index 140e071c7465..bd8e64a6bcdd 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -519,6 +519,10 @@ def __init__(
compile_native=True,
)
+ def bind_kv_cache(self, kv_cache: torch.Tensor) -> None:
+ # [B, H=1, N, C] -> [B, N, C]
+ self.kv_cache = kv_cache.squeeze(1)
+
@property
def chunked_prefill_workspace_size(self) -> int:
if self._chunked_prefill_workspace_size is None:
@@ -1172,27 +1176,6 @@ def get_name() -> str:
def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
return MLACommonMetadataBuilder
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int, # assumed to be 1 for MLA
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- return (num_blocks, block_size, head_size)
-
- @staticmethod
- def get_kv_cache_stride_order(
- include_num_layers_dimension: bool = False,
- ) -> tuple[int, ...]:
- if include_num_layers_dimension:
- # MLA kernels require contiguous per-layer KV cache views.
- # Identity permutation keeps num_layers first in physical
- # layout, signaling cross-layer allocation is unsupported.
- return (0, 1, 2, 3)
- return (0, 1, 2)
-
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [320, 576]
@@ -2054,7 +2037,7 @@ def _compute_prefill_context(
toks = prefill_metadata.chunked_context.seq_tot[i]
if not use_fp8_prefill:
ops.gather_and_maybe_dequant_cache(
- src_cache=kv_c_and_k_pe_cache,
+ src_cache=kv_c_and_k_pe_cache.squeeze(1),
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
@@ -2067,7 +2050,7 @@ def _compute_prefill_context(
else:
# FP8 path: gather cache without dequantization
ops.cp_gather_cache(
- src_cache=kv_c_and_k_pe_cache,
+ src_cache=kv_c_and_k_pe_cache.squeeze(1),
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
@@ -2162,7 +2145,7 @@ def _context_parallel_compute_prefill_context(
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
ops.cp_gather_cache(
- src_cache=kv_c_and_k_pe_cache,
+ src_cache=kv_c_and_k_pe_cache.squeeze(1),
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[
diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py
index 97395b641497..598870c92bcb 100644
--- a/vllm/model_executor/layers/attention_layer_base.py
+++ b/vllm/model_executor/layers/attention_layer_base.py
@@ -4,6 +4,8 @@
from abc import ABC, abstractmethod
+import torch
+
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend, AttentionImpl
from vllm.v1.kv_cache_interface import KVCacheSpec
@@ -20,6 +22,10 @@ class AttentionLayerBase(ABC):
impl: "AttentionImpl"
+ def bind_kv_cache(self, kv_cache: torch.Tensor) -> None:
+ """Bind a ``[B, H, N, C]`` cache view; override to reshape."""
+ self.kv_cache = kv_cache
+
@abstractmethod
def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this layer."""
diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py
index 8bbb21d7bc90..05ceb6ac48bc 100644
--- a/vllm/model_executor/layers/mamba/abstract.py
+++ b/vllm/model_executor/layers/mamba/abstract.py
@@ -2,11 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Iterable
+from math import prod
import torch
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum
from vllm.v1.attention.selector import get_mamba_attn_backend
@@ -23,6 +25,18 @@ class MambaBase(AttentionLayerBase):
# in the shape specified by `self.get_state_shape`.
kv_cache: tuple[torch.Tensor, ...]
+ def bind_kv_cache(self, kv_cache: torch.Tensor) -> None:
+ """Unpack a raw 4D ``[B, 1, 1, C]`` int8 view into per-state views."""
+ pages = kv_cache.squeeze(dim=(1, 2))
+ states: list[torch.Tensor] = []
+ offset = 0
+ for shape, dtype in zip(self.get_state_shape(), self.get_state_dtype()):
+ nbytes = prod(shape) * get_dtype_size(dtype)
+ state = pages[:, offset : offset + nbytes].view(dtype)
+ states.append(state.view(-1, *shape))
+ offset += nbytes
+ self.kv_cache = tuple(states)
+
@abstractmethod
def get_state_shape(self) -> Iterable[tuple[int, ...]]:
"""
diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py
index c1fd81e40e34..aaf579bce3cc 100644
--- a/vllm/model_executor/layers/mamba/mamba_utils.py
+++ b/vllm/model_executor/layers/mamba/mamba_utils.py
@@ -28,7 +28,7 @@ def get_conv_state_layout() -> ConvStateLayoutType:
"""Return the SSM conv state layout.
SD = (state_len, dim) — dim is the innermost contiguous dimension.
- DS = (dim, state_len) — TP-sharded dim is on dim-1 (like HND for KV
+ DS = (dim, state_len) — TP-sharded dim is on dim-1 (like HNC for KV
cache), consistent with SSM temporal state layout.
"""
layout: ConvStateLayoutType | None = envs.VLLM_SSM_CONV_STATE_LAYOUT
diff --git a/vllm/model_executor/models/extract_hidden_states.py b/vllm/model_executor/models/extract_hidden_states.py
index 8df4823b6973..8f08a055b07a 100644
--- a/vllm/model_executor/models/extract_hidden_states.py
+++ b/vllm/model_executor/models/extract_hidden_states.py
@@ -80,11 +80,11 @@ def dummy_attention(layer_name, _placeholder):
def basic_cache(
to_cache: torch.Tensor, # shape: [seq_len, num_heads, head_size]
- kv_cache: torch.Tensor, # shape: [num_blocks, block_size, num_heads, head_size]
+ kv_cache: torch.Tensor, # shape: [num_blocks, num_heads, block_size, head_size]
slot_mapping: torch.Tensor, # shape: [seq_len]
):
- block_size = kv_cache.shape[1]
- kv_cache[slot_mapping // block_size, slot_mapping % block_size] = to_cache
+ block_size = kv_cache.shape[2]
+ kv_cache[slot_mapping // block_size, :, slot_mapping % block_size] = to_cache
######### CacheOnlyAttentionBackend ########
@@ -120,18 +120,6 @@ def supports_mm_prefix(cls) -> bool:
def get_impl_cls() -> type["CacheOnlyAttentionImpl"]:
return CacheOnlyAttentionImpl
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- # We set `num_kv_heads = num_hidden_layers` and `head_size = hidden_size`
- # We also don't use a k/v (2) dim
- return (num_blocks, block_size, num_kv_heads, head_size)
-
@staticmethod
def get_builder_cls() -> type["CacheOnlyAttentionMetadataBuilder"]:
return CacheOnlyAttentionMetadataBuilder
diff --git a/vllm/model_executor/models/whisper_causal.py b/vllm/model_executor/models/whisper_causal.py
index dfbf69418a6c..07b496c2605d 100644
--- a/vllm/model_executor/models/whisper_causal.py
+++ b/vllm/model_executor/models/whisper_causal.py
@@ -125,12 +125,6 @@ def __init__(
vllm_config: VllmConfig,
device: torch.device,
):
- assert kv_cache_spec.num_kv_heads % block_pool_size == 0
- kv_cache_spec = replace(
- kv_cache_spec,
- block_size=kv_cache_spec.block_size * block_pool_size,
- num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
- )
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
# Override model_config-derived values with the actual
# encoder values from kv_cache_spec
@@ -247,18 +241,6 @@ def forward(
overrides={
"get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder,
"get_impl_cls": lambda: WhisperCausalAttentionWithBlockPoolingImpl,
- "get_kv_cache_shape": lambda num_blocks,
- block_size,
- num_kv_heads,
- head_size,
- cache_dtype_str: underlying_attn_backend.get_kv_cache_shape(
- num_blocks,
- # we stretch each block by `block_pool_size`
- block_size * block_pool_size,
- num_kv_heads // block_pool_size,
- head_size,
- cache_dtype_str,
- ),
"forward_includes_kv_cache_update": True,
},
)
@@ -325,9 +307,11 @@ def __init__(
def get_kv_cache_spec(self, vllm_config: VllmConfig):
kv_cache_spec = super().get_kv_cache_spec(vllm_config)
assert isinstance(kv_cache_spec, AttentionSpec)
+ assert kv_cache_spec.num_kv_heads % self.block_pool_size == 0
kv_cache_spec = replace(
kv_cache_spec,
- num_kv_heads=self.block_pool_size * kv_cache_spec.num_kv_heads,
+ block_size=kv_cache_spec.block_size * self.block_pool_size,
+ num_kv_heads=kv_cache_spec.num_kv_heads // self.block_pool_size,
)
return kv_cache_spec
diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py
index 55cb3d94ba67..0d51eca95fdc 100644
--- a/vllm/models/deepseek_v4/attention.py
+++ b/vllm/models/deepseek_v4/attention.py
@@ -594,6 +594,10 @@ def __init__(
self.kv_cache = torch.tensor([])
+ def bind_kv_cache(self, kv_cache: torch.Tensor) -> None:
+ # [B, H=1, N, C] -> [B, N, C]
+ self.kv_cache = kv_cache.squeeze(1)
+
def get_attn_backend(self) -> type[AttentionBackend]:
return self.backend_cls
@@ -607,7 +611,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
num_kv_heads=1,
head_size=self.head_dim,
dtype=torch.uint8,
- compress_ratio=self.compress_ratio,
+ tokens_per_state=self.compress_ratio,
cache_dtype_str=self.kv_cache_dtype,
alignment=576, # NOTE: FlashMLA requires 576B alignment
model_version="deepseek_v4",
@@ -644,15 +648,19 @@ def __init__(
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
+ def bind_kv_cache(self, kv_cache: torch.Tensor) -> None:
+ # [B, H=1, N, C] -> [B, N, C]
+ self.kv_cache = kv_cache.squeeze(1)
+
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# head_dim already carries the fp8 scale padding
- # compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same cache layout.
+ # tokens_per_state=1 for V3.2, >1 for DSV4; same cache layout.
return MLAAttentionSpec(
block_size=self.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_dim,
dtype=self.dtype,
- compress_ratio=self.compress_ratio,
+ tokens_per_state=self.compress_ratio,
# DeepseekV4 aligns indexer pages to FlashMLA's 576B so they can pack with
# the indexer's compressor state cache. V3.2 keeps the legacy layout.
alignment=576,
diff --git a/vllm/models/deepseek_v4/compressor.py b/vllm/models/deepseek_v4/compressor.py
index f36dc8f17629..6b4b7f79eab4 100644
--- a/vllm/models/deepseek_v4/compressor.py
+++ b/vllm/models/deepseek_v4/compressor.py
@@ -54,25 +54,6 @@ def get_supported_head_sizes(cls) -> list[int]:
def get_builder_cls() -> type["CompressorMetadataBuilder"]:
return CompressorMetadataBuilder
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- assert num_kv_heads == 1
- return (num_blocks, block_size, head_size)
-
- @staticmethod
- def get_kv_cache_stride_order(
- include_num_layers_dimension: bool = False,
- ) -> tuple[int, ...]:
- if include_num_layers_dimension:
- return (0, 1, 2, 3)
- return (0, 1, 2)
-
@dataclass
class CompressorMetadata:
@@ -154,6 +135,10 @@ def __init__(
else:
raise ValueError(f"Invalid compress ratio: {compress_ratio}")
+ def bind_kv_cache(self, kv_cache: torch.Tensor) -> None:
+ # [B, H=1, N, C] -> [B, N, C]
+ self.kv_cache = kv_cache.squeeze(1)
+
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
return SlidingWindowMLASpec( # only has one vector instead of K + V
block_size=self.block_size,
diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py
index 5c8b08d4c120..0c5a1605722a 100644
--- a/vllm/models/deepseek_v4/nvidia/flashmla.py
+++ b/vllm/models/deepseek_v4/nvidia/flashmla.py
@@ -95,21 +95,6 @@ def get_supported_head_sizes(cls) -> list[int]:
# V3.2 default of 576 from FlashMLASparseBackend).
return [512]
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- if cache_dtype_str == "fp8_ds_mla":
- # DeepseekV4 main MLA: 584B per token (448 NoPE + 128 RoPE + 8 fp8 scale).
- # head_size passed in is the semantic head_dim (512).
- return (num_blocks, block_size, 584)
- else:
- return (num_blocks, block_size, head_size)
-
class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
"""FlashMLA sparse MLA implementation for DeepSeek V4's custom MLA layer."""
diff --git a/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py b/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py
index ed16ca6d3b54..8ab5dbb5116b 100644
--- a/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py
+++ b/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py
@@ -1313,6 +1313,8 @@ def compress_norm_rope_store_cutedsl(
token_stride: int,
scale_dim: int,
) -> None:
+ # (B, H=1, N, C) -> (B, N, C)
+ kv_cache = kv_cache.squeeze(1)
if compress_ratio == 4:
# For C4A, the single fused kernel is faster than the two-kernel version.
fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl(
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 5947bff9b080..681674d3501c 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -55,10 +55,10 @@ def get_attn_backend_cls(
) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
- set_kv_cache_layout("NHD")
+ set_kv_cache_layout("LBNHC")
logger.info(
- "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
- "only NHD layout is supported by XPU attention kernels."
+ "Setting VLLM_KV_CACHE_LAYOUT to 'LBNHC' for XPU; "
+ "only LBNHC layout is supported by XPU attention kernels."
)
# TurboQuant KV cache: route directly to TQ backend
diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py
index 12ec5b0fcc66..3ba33ca634ab 100644
--- a/vllm/utils/torch_utils.py
+++ b/vllm/utils/torch_utils.py
@@ -417,26 +417,24 @@ def nvfp4_kv_cache_full_dim(head_size: int) -> int:
return head_size // 2 + head_size // 16
-def _nvfp4_split_data_scale(
+def nvfp4_split_data_scale(
kv_side: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
- """Split a single NVFP4 KV-side buffer into data and scale views.
+ """Split one side (K or V) of an NVFP4 KV cache into data and scale.
- The input is a 4D tensor for one KV side (K or V) whose last
- dimension is ``full_dim = data_dim + scale_dim``. The physical
- layout within each side is [data | scale], both packed contiguously.
+ The input is a 4D uint8 tensor whose last dimension is
+ ``full_dim = data_dim + scale_dim``. The physical layout within
+ each side is ``[data | scale]``, both packed contiguously.
+
+ The caller is responsible for slicing K and V from the combined
+ cache first (e.g. ``kv_cache.split(num_kv_heads, dim=1)``).
Args:
- kv_side: 4D uint8 tensor with shape
- ``(num_pages, dim_1, dim_2, full_dim)``.
- May be in any permutation order (NHD or HND).
+ kv_side: 4D uint8 tensor ``(B, H, N, full_dim)``.
Returns:
- ``(data, scale)`` where
- ``data`` is a uint8 view with shape
- ``(num_pages, dim_1, dim_2, data_dim)``.
- ``scale`` is a float8_e4m3fn view with shape
- ``(num_pages, dim_1, dim_2, scale_dim)``.
+ ``(data, scale)`` where *data* is uint8 and *scale* is
+ float8_e4m3fn, both views of the same storage.
"""
num_pages = kv_side.shape[0]
dim_1, dim_2 = kv_side.shape[1], kv_side.shape[2]
@@ -447,9 +445,6 @@ def _nvfp4_split_data_scale(
data_per_kv = dim_1 * dim_2 * data_dim
page_bytes = kv_side.stride(0)
- # Derive inner strides from the kv_side strides, scaling by the
- # ratio of the target dim to full_dim. This preserves the physical
- # layout (NHD vs HND) encoded in the input tensor's strides.
s1 = kv_side.stride(1) * data_dim // full_dim
s2 = kv_side.stride(2) * data_dim // full_dim
data_shape = (num_pages, dim_1, dim_2, data_dim)
@@ -463,44 +458,15 @@ def _nvfp4_split_data_scale(
base = kv_side.storage_offset()
data = torch.as_strided(kv_side, data_shape, data_strides, storage_offset=base)
scale = torch.as_strided(
- kv_side, scale_shape, scale_strides, storage_offset=base + data_per_kv
+ kv_side,
+ scale_shape,
+ scale_strides,
+ storage_offset=base + data_per_kv,
).view(torch.float8_e4m3fn)
return data, scale
-def nvfp4_kv_cache_split_views(kv_cache: torch.Tensor) -> tuple[tuple, tuple]:
- """Split an NVFP4 KV cache tensor into data and scale views.
-
- Accepts either a 5D tensor ``(num_pages, 2, dim_2, dim_3, full_dim)``
- or a 4D single-side tensor ``(num_pages, dim_2, dim_3, full_dim)``.
-
- Per-page layout: [K_data | K_scale | V_data | V_scale].
- Each KV side is self-contained (data followed by its scale), so the
- 5D case simply splits each side independently.
-
- The returned views are in the same dim order as the input (NHD or
- HND), so callers get views matching whichever order they passed in.
-
- Args:
- kv_cache: 5D or 4D uint8 tensor where the last dimension is
- ``full_dim = data_dim + scale_dim = 9 * head_size / 16``.
-
- Returns:
- For 5D input:
- ``(k_data, v_data), (k_scale, v_scale)``
- For 4D input (single KV side):
- ``(data,), (scale,)``
- """
- if kv_cache.dim() == 4:
- data, scale = _nvfp4_split_data_scale(kv_cache)
- return (data,), (scale,)
-
- k_data, k_scale = _nvfp4_split_data_scale(kv_cache[:, 0])
- v_data, v_scale = _nvfp4_split_data_scale(kv_cache[:, 1])
- return (k_data, v_data), (k_scale, v_scale)
-
-
def create_kv_caches_with_random_flash(
num_blocks: int,
block_size: int,
@@ -511,14 +477,14 @@ def create_kv_caches_with_random_flash(
model_dtype: str | torch.dtype | None = None,
seed: int | None = None,
device: str | None = "cuda",
- cache_layout: str | None = "NHD",
+ cache_layout: str | None = "LBNHC",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
set_random_seed(seed)
dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
- assert cache_layout in ("NHD", "HND")
- stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4)
+ assert cache_layout in ("LBNHC", "LBHNC")
+ stride_order = (0, 1, 2, 3, 4) if cache_layout == "LBNHC" else (0, 1, 3, 2, 4)
kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order)
scale = head_size**-0.5
diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py
index af58bfd31a57..6641f5dd6b1e 100644
--- a/vllm/v1/attention/backend.py
+++ b/vllm/v1/attention/backend.py
@@ -23,7 +23,6 @@
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms.interface import DeviceCapability
- from vllm.v1.attention.backends.utils import KVCacheLayoutType
from vllm.v1.kv_cache_interface import AttentionSpec, KVQuantMode
from vllm.v1.kv_cache_interface import get_kv_quant_mode
@@ -84,72 +83,14 @@ def get_impl_cls() -> type["AttentionImplBase"]:
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError
- @staticmethod
- @abstractmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- raise NotImplementedError
-
- @classmethod
- def get_kv_cache_block_dim(
- cls,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> int:
- """Discover which tensor dim is the block index, since different
- backends lay out dims differently."""
- _S = 1234567
- shape = cls.get_kv_cache_shape(
- _S,
- block_size,
- num_kv_heads,
- head_size,
- cache_dtype_str=cache_dtype_str,
- )
- return shape.index(_S)
-
- @staticmethod
- def get_kv_cache_stride_order(
- include_num_layers_dimension: bool = False,
- ) -> tuple[int, ...]:
- """
- Get the physical (memory layout) ordering of the kv cache dimensions.
- e.g. if the KV cache shape is
- [2, num_blocks, block_size, num_heads, head_size],
- and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
- ordering of dimensions is
- [num_blocks, num_heads, 2, block_size, head_size].
-
- If this function is unimplemented / raises NotImplementedError,
- the physical layout of the KV cache will match the logical shape.
-
- Args:
- include_num_layers_dimension: if True, includes an additional
- num_layers dimension, which is assumed to be prepended
- to the logical KV cache shape.
- With the above example, a return value (2, 4, 0, 1, 3, 5)
- corresponds to
- [num_blocks, num_heads, num_layers, 2, block_size, head_size].
-
- If an additional dimension is NOT included in the returned
- tuple, the physical layout will not include a layers dimension.
-
- Returns:
- A tuple of ints which is a permutation of range(len(shape)).
- """
- raise NotImplementedError
-
@classmethod
def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)
+ @classmethod
+ def get_required_kv_cache_layout(cls) -> str | None:
+ return None
+
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return []
@@ -340,10 +281,6 @@ def validate_configuration(
invalid_reasons.append(combination_reason)
return invalid_reasons
- @classmethod
- def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
- return None
-
@classmethod
def is_ssm(cls) -> bool:
return False
diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index 3519691a3c58..347c2ab761f8 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -25,7 +25,6 @@
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
- KVCacheLayoutType,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
@@ -85,20 +84,6 @@ def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
def get_builder_cls() -> type["CPUAttentionMetadataBuilder"]:
return CPUAttentionMetadataBuilder
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- return num_blocks, num_kv_heads, block_size, 2 * head_size
-
- @classmethod
- def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
- return "HND"
-
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@@ -308,7 +293,7 @@ def forward(
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
- [num_blocks, num_kv_heads, block_size, 2 * head_size]
+ [num_blocks, 2*num_kv_heads, block_size, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -337,13 +322,10 @@ def forward(
self.attn_type,
)
- # For decoder and cross-attention, use KV cache, size are
- # [num_blocks, num_kv_heads, block_size, 2 * head_size]
- # Make a view [num_blocks, num_kv_heads, block_size * 2, head_size]
- # Then slice KV at dim 2
- num_blocks, num_kv_heads, block_size, _ = kv_cache.size()
- kv_cache = kv_cache.view((num_blocks, num_kv_heads, block_size * 2, -1))
- key_cache, value_cache = kv_cache.chunk(2, dim=2)
+ # K and V are stored as separate head groups; slice them out as
+ # contiguous per-head tensors (see CPUModelRunner._allocate_kv_caches).
+ key_cache = kv_cache[:, : self.num_kv_heads]
+ value_cache = kv_cache[:, self.num_kv_heads :]
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index c56c4ee6e1ff..31e7c0b4413a 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -57,9 +57,6 @@
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
-from vllm.v1.attention.backends.utils import (
- get_kv_cache_layout,
-)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
@@ -136,39 +133,6 @@ def get_impl_cls() -> type["FlashAttentionImpl"]:
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- if block_size % 16 != 0:
- raise ValueError("Block size must be a multiple of 16.")
- return (num_blocks, 2, block_size, num_kv_heads, head_size)
-
- @staticmethod
- def get_kv_cache_stride_order(
- include_num_layers_dimension: bool = False,
- ) -> tuple[int, ...]:
- # `stride_order` indicates the permutation that gets
- # us from `get_kv_cache_shape` to the actual memory layout we want.
- cache_layout = get_kv_cache_layout()
- if cache_layout == "NHD" and include_num_layers_dimension:
- # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
- return (1, 0, 2, 3, 4, 5)
- elif cache_layout == "NHD":
- stride_order = (0, 1, 2, 3, 4)
- elif cache_layout == "HND" and include_num_layers_dimension:
- # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
- return (1, 4, 0, 2, 3, 5)
- elif cache_layout == "HND":
- stride_order = (0, 1, 3, 2, 4)
- else:
- raise ValueError(f"Unknown cache layout format {cache_layout}.")
- return stride_order
-
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
if head_size % 8 != 0:
@@ -683,7 +647,7 @@ def forward(
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
- [num_blocks, 2, block_size, num_kv_heads, head_size]
+ [num_blocks, num_kv_heads, block_size, 2*head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -730,8 +694,8 @@ def forward(
layer,
)
- # For decoder and cross-attention, use KV cache as before
- key_cache, value_cache = kv_cache.unbind(1)
+ # (B, H, N, 2*head_size) -> ((B, N, H, head_size), (B, N, H, head_size))
+ key_cache, value_cache = kv_cache.transpose(1, 2).split(self.head_size, dim=-1)
# Fix degenerate strides on size-1 dims (e.g. num_kv_heads=1 with TP).
# FA3/4 on H100+ uses TMA, which requires ≥16-byte stride alignment.
# See vllm.utils.torch_utils.canonicalize_singleton_dim_strides.
@@ -861,8 +825,7 @@ def do_kv_cache_update(
return
# Scatter write into the KV cache using slot_mapping indices.
- # No TMA kernel is invoked here, so stride canonicalization is not needed.
- key_cache, value_cache = kv_cache.unbind(1)
+ k_cache, v_cache = kv_cache.transpose(1, 2).split(self.head_size, dim=-1)
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
@@ -874,8 +837,8 @@ def do_kv_cache_update(
reshape_and_cache_flash(
key,
value,
- key_cache,
- value_cache,
+ k_cache,
+ v_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
@@ -1163,10 +1126,11 @@ def cascade_attention(
num_tokens = query.shape[0]
block_size = key_cache.shape[-3]
+ num_kv_heads = key_cache.shape[-2]
assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size
assert num_common_kv_blocks > 0
- descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
+ descale_shape = (cu_prefix_query_lens.shape[0] - 1, num_kv_heads)
# Process shared prefix.
prefix_output, prefix_lse = flash_attn_varlen_func(
@@ -1194,7 +1158,7 @@ def cascade_attention(
num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
)
- descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
+ descale_shape = (cu_query_lens.shape[0] - 1, num_kv_heads)
# Process suffix per query.
suffix_output, suffix_lse = flash_attn_varlen_func(
diff --git a/vllm/v1/attention/backends/flash_attn_diffkv.py b/vllm/v1/attention/backends/flash_attn_diffkv.py
index e788b0e3496f..d5a1da20c4d6 100644
--- a/vllm/v1/attention/backends/flash_attn_diffkv.py
+++ b/vllm/v1/attention/backends/flash_attn_diffkv.py
@@ -21,8 +21,6 @@
if is_flash_attn_varlen_func_available():
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
-from vllm.v1.attention.backends.utils import get_kv_cache_layout
-
from .flash_attn import (
FlashAttentionBackend,
FlashAttentionImpl,
@@ -49,48 +47,6 @@ def get_name() -> str:
def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionDiffKVImpl
- # Do not modify the interface of get_kv_cache_shape,
- # but consider head_size_v when returning result.
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- if block_size % 16 != 0:
- raise ValueError("Block size must be a multiple of 16.")
- return (
- num_blocks,
- block_size,
- num_kv_heads,
- head_size + FlashAttentionDiffKVBackend.head_size_v,
- )
-
- @staticmethod
- def get_kv_cache_stride_order(
- include_num_layers_dimension: bool = False,
- ) -> tuple[int, ...]:
- # `stride_order` indicates the permutation that gets
- # us from `get_kv_cache_shape` to the actual memory layout we want.
- cache_layout = get_kv_cache_layout()
- if cache_layout == "NHD" and include_num_layers_dimension:
- # (num_blocks, num_layers, block_size,
- # num_kv_heads, head_size + head_size_v)
- return (1, 0, 2, 3, 4)
- elif cache_layout == "NHD":
- stride_order = (0, 1, 2, 3)
- elif cache_layout == "HND" and include_num_layers_dimension:
- # (num_blocks, num_kv_heads, num_layers,
- # block_size, head_size + head_size_v)
- return (1, 3, 0, 2, 4)
- elif cache_layout == "HND":
- stride_order = (0, 2, 1, 3)
- else:
- raise ValueError(f"Unknown cache layout format {cache_layout}.")
- return stride_order
-
class FlashAttentionDiffKVImpl(FlashAttentionImpl):
vllm_flash_attn_version: int | None
@@ -120,21 +76,14 @@ def do_kv_cache_update(
# we use direct Q, K, V tensors without caching
return
- # Unlike standard FlashAttn which splits kv_cache via unbind(0),
# DiffKV packs K and V into a single tensor along the last dim:
# kv_cache shape: [num_blocks, block_size, num_kv_heads,
# head_size_k + head_size_v]
- # The triton kernel handles this combined layout directly.
- #
- # NOTE(woosuk): key and value are padded while slot_mapping is
- # not padded. However, we don't need to do key[:num_actual_tokens]
- # and value[:num_actual_tokens] because the reshape_and_cache_flash
- # op uses the slot_mapping's shape to determine the number of
- # actual tokens.
+ # (B, H, N, C) -> (B, N, H, C) for kernel compatibility.
triton_reshape_and_cache_flash_diffkv(
key,
value,
- kv_cache,
+ kv_cache.transpose(1, 2),
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
@@ -207,8 +156,8 @@ def forward(
layer,
)
- # For decoder and cross-attention, use KV cache as before
- # Different head_size for K and V
+ # (B, H, N, C) -> (B, N, H, C) for kernel compatibility.
+ kv_cache = kv_cache.transpose(1, 2)
key_cache = kv_cache[..., : self.head_size]
value_cache = kv_cache[..., self.head_size :]
# Fix degenerate strides on size-1 dims (e.g. num_kv_heads=1 with TP).
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 83e3072546f1..8d63cb87238f 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -46,8 +46,7 @@
canonicalize_singleton_dim_strides,
is_quantized_kv_cache,
is_strictly_contiguous,
- nvfp4_kv_cache_full_dim,
- nvfp4_kv_cache_split_views,
+ nvfp4_split_data_scale,
)
from vllm.v1.attention.backend import (
AttentionBackend,
@@ -59,12 +58,12 @@
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
- KVCacheLayoutType,
get_dcp_local_seq_lens,
- get_kv_cache_layout,
+ get_flashinfer_layout_string,
get_num_attention_heads_from_layers,
get_per_layer_parameters,
infer_global_hyperparameters,
+ resolve_kv_cache_layout,
split_decodes_and_prefills,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
@@ -107,9 +106,11 @@ def _trtllm_prefill_attn_kvfp8_dequant(
src_stride_page,
src_stride_kv,
src_stride_head,
+ src_stride_n,
DST_K_CACHE_STRIDE: tl.constexpr,
DST_KV_CACHE_STRIDE: tl.constexpr,
- HEAD_STRIDE: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ HEAD_SIZE: tl.constexpr,
NUM_KV_HEADS: tl.constexpr,
):
batch_idx = tl.program_id(0).to(tl.int64)
@@ -125,36 +126,42 @@ def _trtllm_prefill_attn_kvfp8_dequant(
v_scale_val = tl.load(v_scale_ptr)
mock_page_idx = batch_idx * block_table_stride + mock_block_table_idx + 1
- head_offsets = tl.arange(0, HEAD_STRIDE)
+ HEAD_STRIDE: tl.constexpr = PAGE_SIZE * HEAD_SIZE
+ # 2D indexing: source may have non-contiguous block_size stride.
+ n_idx = tl.arange(0, PAGE_SIZE)[:, None]
+ d_idx = tl.arange(0, HEAD_SIZE)[None, :]
+ dst_nd = n_idx * HEAD_SIZE + d_idx
for h in range(NUM_KV_HEADS):
h_off = tl.cast(h, tl.int64)
- # Read K from source (supports non-contiguous page/kv/head strides)
- src_k = orig_page_num * src_stride_page + h_off * src_stride_head + head_offsets
+ src_k = (
+ orig_page_num * src_stride_page
+ + h_off * src_stride_head
+ + n_idx * src_stride_n
+ + d_idx
+ )
fp8_k = tl.load(kv_cache_ptr + src_k)
dequant_k = (fp8_k.to(tl.float32) * k_scale_val).to(dequant_dtype)
- # Write K to contiguous mock cache
- dst_k = mock_page_idx * DST_KV_CACHE_STRIDE + h * HEAD_STRIDE + head_offsets
+ dst_k = mock_page_idx * DST_KV_CACHE_STRIDE + h * HEAD_STRIDE + dst_nd
tl.store(mock_kv_cache_ptr + dst_k, dequant_k)
- # Read V from source (offset by src_stride_kv for the V half)
src_v = (
orig_page_num * src_stride_page
+ src_stride_kv
+ h_off * src_stride_head
- + head_offsets
+ + n_idx * src_stride_n
+ + d_idx
)
fp8_v = tl.load(kv_cache_ptr + src_v)
dequant_v = (fp8_v.to(tl.float32) * v_scale_val).to(dequant_dtype)
- # Write V to contiguous mock cache
dst_v = (
mock_page_idx * DST_KV_CACHE_STRIDE
+ DST_K_CACHE_STRIDE
+ h * HEAD_STRIDE
- + head_offsets
+ + dst_nd
)
tl.store(mock_kv_cache_ptr + dst_v, dequant_v)
@@ -177,9 +184,8 @@ def trtllm_prefill_attn_kvfp8_dequant(
kv_cache_stride = k_cache_stride * s[1]
strides = kv_cache.stride()
- assert strides[3] == head_size and strides[4] == 1, (
- "For kv cache layouts, (block_size, head_size) "
- f"dimensions must be contiguous, got strides {strides}"
+ assert strides[4] == 1, (
+ f"The head_size dimension must be contiguous, got strides {strides}"
)
new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
@@ -203,9 +209,11 @@ def trtllm_prefill_attn_kvfp8_dequant(
strides[0],
strides[1],
strides[2],
+ strides[3],
k_cache_stride,
kv_cache_stride,
- head_stride,
+ block_size,
+ head_size,
num_kv_heads,
)
return mock_kv_cache, mock_block_table
@@ -222,7 +230,7 @@ def __init__(
else:
self._dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False)
self._context = BatchPrefillWithPagedKVCacheWrapper(
- workspace_buffer, get_kv_cache_layout()
+ workspace_buffer, get_flashinfer_layout_string()
)
self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer)
@@ -282,7 +290,7 @@ def run(
self,
layer: torch.nn.Module,
prefill_query: torch.Tensor,
- kv_cache_permute: torch.Tensor,
+ kv_cache_tuple: tuple[torch.Tensor, torch.Tensor],
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
@@ -292,7 +300,7 @@ def run(
)
output_context_tmp, lse_context_tmp = self._context.run(
prefill_query_across_dcp,
- kv_cache_permute,
+ kv_cache_tuple,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
return_lse=True,
@@ -353,41 +361,6 @@ def get_impl_cls() -> type["FlashInferImpl"]:
def get_builder_cls() -> type["FlashInferMetadataBuilder"]:
return FlashInferMetadataBuilder
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- if cache_dtype_str == "nvfp4":
- # Packed layout: fp4 data + fp8 block scales in last dim
- last_dim = nvfp4_kv_cache_full_dim(head_size)
- return (num_blocks, 2, block_size, num_kv_heads, last_dim)
- return (num_blocks, 2, block_size, num_kv_heads, head_size)
-
- @staticmethod
- def get_kv_cache_stride_order(
- include_num_layers_dimension: bool = False,
- ) -> tuple[int, ...]:
- # `stride_order` indicates the permutation that gets us from
- # `get_kv_cache_shape` to the actual memory layout we want.
- cache_layout = get_kv_cache_layout()
- if cache_layout == "NHD" and include_num_layers_dimension:
- # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
- return (1, 0, 2, 3, 4, 5)
- elif cache_layout == "NHD":
- stride_order = (0, 1, 2, 3, 4)
- elif cache_layout == "HND" and include_num_layers_dimension:
- # (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
- return (1, 2, 4, 0, 3, 5)
- elif cache_layout == "HND":
- stride_order = (0, 1, 3, 2, 4)
- else:
- raise ValueError(f"Unknown cache layout format {cache_layout}.")
- return stride_order
-
@staticmethod
def get_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
@@ -426,13 +399,6 @@ def supports_sink(cls) -> bool:
# Check if TRTLLM is supported on this platform
return supports_trtllm_attention()
- @classmethod
- def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
- capability = current_platform.get_device_capability()
- if capability is not None and capability.major == 10:
- return "HND"
- return None
-
forward_includes_kv_cache_update: bool = False
@@ -780,7 +746,7 @@ def _get_prefill_wrapper(
backend = "trtllm-gen" if self.is_kvcache_nvfp4 else "auto"
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
- get_kv_cache_layout(),
+ get_flashinfer_layout_string(),
backend=backend,
)
assert self._prefill_wrapper is not None
@@ -806,7 +772,7 @@ def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
backend = "trtllm-gen" if self.is_kvcache_nvfp4 else "auto"
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
- get_kv_cache_layout(),
+ get_flashinfer_layout_string(),
use_cuda_graph=use_cudagraph,
paged_kv_indptr_buffer=paged_kv_indptr,
paged_kv_indices_buffer=paged_kv_indices,
@@ -829,7 +795,7 @@ def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
def _get_cascade_wrapper(self):
if self._cascade_wrapper is None:
self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
- 2, self._get_workspace_buffer(), get_kv_cache_layout()
+ 2, self._get_workspace_buffer(), get_flashinfer_layout_string()
)
return self._cascade_wrapper
@@ -1383,9 +1349,7 @@ def forward(
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
- kv_cache: KV cache tensor with different possible shapes:
- - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
- - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
+ kv_cache: [num_blocks, num_kv_heads, block_size, 2*head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -1476,19 +1440,9 @@ def forward(
output_padded = output
output = output[:num_actual_tokens]
- if attn_metadata.use_cascade:
- # Cascade attention (rare case).
- assert attn_metadata.cascade_wrapper is not None
- output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
- return output
-
- # When using spec decoding, num_decodes can be < num_decode_tokens
- # because some decode requests may have more than one query token.
- num_decode_tokens = attn_metadata.num_decode_tokens
- num_prefill_tokens = attn_metadata.num_prefill_tokens
-
- stride_order = FlashInferBackend.get_kv_cache_stride_order()
- kv_cache_permute = kv_cache.permute(*stride_order) # HND and contiguous
+ # Permute to FlashInfer's expected layout (metadata-only).
+ stride_order = resolve_kv_cache_layout().layer_stride_order
+ kv_cache_permute = kv_cache.permute(*stride_order)
# Fix degenerate strides on any size-1 dimension (e.g. num_kv_heads=1
# with TP=8). PyTorch permits non-canonical strides on size-1 dims;
# CUDA TMA requires ≥16-byte alignment on all non-outermost strides.
@@ -1505,14 +1459,35 @@ def forward(
)
kv_cache_permute = fixed
- # For NVFP4, the kv_cache last dim is full_dim (data + scale packed).
- # Split into correctly-strided data and scale views.
+ # Split K/V — zero-copy views.
+ # For nvfp4, K and V are stored as separate head groups (2*H heads
+ # in dim 1); for other dtypes, K and V are packed in the content dim.
nvfp4_kv_data = None
nvfp4_kv_block_scales = None
if self.is_kvcache_nvfp4:
- nvfp4_kv_data, nvfp4_kv_block_scales = nvfp4_kv_cache_split_views(
- kv_cache_permute
- )
+ kv_cache_tuple = kv_cache_permute.split(self.num_kv_heads, dim=1)
+ k_data, k_sf = nvfp4_split_data_scale(kv_cache_tuple[0])
+ v_data, v_sf = nvfp4_split_data_scale(kv_cache_tuple[1])
+ nvfp4_kv_data = (k_data, v_data)
+ nvfp4_kv_block_scales = (k_sf, v_sf)
+ else:
+ kv_cache_tuple = kv_cache_permute.split(self.head_size, dim=-1)
+
+ flashinfer_layout = get_flashinfer_layout_string()
+ if flashinfer_layout == "HND":
+ trtllm_kv_cache = tuple(t.transpose(1, 2) for t in kv_cache_tuple)
+ trtllm_kv_layout = "NHD"
+ else:
+ trtllm_kv_cache = kv_cache_tuple
+ trtllm_kv_layout = flashinfer_layout
+
+ if attn_metadata.use_cascade:
+ assert attn_metadata.cascade_wrapper is not None
+ output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache_tuple))
+ return output
+
+ num_decode_tokens = attn_metadata.num_decode_tokens
+ num_prefill_tokens = attn_metadata.num_prefill_tokens
use_dcp = self.dcp_world_size > 1
@@ -1544,7 +1519,7 @@ def forward(
prefill_wrapper.run(
layer,
prefill_query,
- kv_cache_permute,
+ kv_cache_tuple,
key[num_decode_tokens:],
value[num_decode_tokens:],
out=output[num_decode_tokens:],
@@ -1561,7 +1536,9 @@ def forward(
assert prefill_wrapper._causal
if self.is_kvcache_nvfp4:
- kv_cache_permute = nvfp4_kv_data
+ kv_cache_for_fi = nvfp4_kv_data
+ else:
+ kv_cache_for_fi = kv_cache_tuple
kv_cache_sf = (
nvfp4_kv_block_scales if self.is_kvcache_nvfp4 else None
)
@@ -1579,7 +1556,7 @@ def forward(
prefill_wrapper.run(
prefill_query,
- kv_cache_permute,
+ kv_cache_for_fi,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=out_prefill,
@@ -1602,8 +1579,6 @@ def forward(
block_tables_prefill = attn_metadata.prefill.block_tables
seq_lens_prefill = attn_metadata.prefill.seq_lens
- # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
- assert get_kv_cache_layout() == "HND"
assert is_strictly_contiguous(prefill_query)
assert is_strictly_contiguous(workspace_buffer)
assert is_strictly_contiguous(block_tables_prefill)
@@ -1628,6 +1603,7 @@ def forward(
out = self._nvfp4_fp8_out[:num_prefill_tokens]
prefill_kv_block_scales = None
+ prefill_kv_layout = trtllm_kv_layout
if self.is_kvcache_nvfp4:
# NVFP4 trtllm-gen kernel requires FP8 query.
assert attn_metadata.q_data_type == FP8_DTYPE, (
@@ -1638,6 +1614,7 @@ def forward(
mock_kv_cache = nvfp4_kv_data
mock_block_table = block_tables_prefill
prefill_kv_block_scales = nvfp4_kv_block_scales
+ prefill_kv_layout = flashinfer_layout
elif (
attn_metadata.q_data_type != FP8_DTYPE
and self.kv_cache_dtype.startswith("fp8")
@@ -1645,28 +1622,22 @@ def forward(
# TRTLLM prefill attention does not support BF16 Q
# and fp8 kv cache. So to enable prefill attention
# with fp8 kv cache, we can construct a mock block
- # and mock kv cache with BF16 KV involved in the prefill
- #
- kv_cache_permute = canonicalize_singleton_dim_strides(
- kv_cache_permute
- )
- kv_strides = kv_cache_permute.stride()
- assert (
- kv_strides[-1] == 1
- and kv_strides[-2] == kv_cache_permute.shape[-1]
- ), (
- "KV cache inner dims (block_size, head_size) must be "
- f"contiguous, got strides {kv_strides}"
- )
+ # and mock kv cache with BF16 KV involved in the prefill.
+ B_kv, H_kv, N_kv = kv_cache_permute.shape[:3]
+ kv_cache_5d = kv_cache_permute.view(
+ B_kv, H_kv, N_kv, 2, self.head_size
+ ).permute(0, 3, 1, 2, 4)
mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
- kv_cache_permute,
+ kv_cache_5d,
block_tables_prefill,
layer._k_scale,
layer._v_scale,
attn_metadata.q_data_type,
)
+ if trtllm_kv_layout == "NHD":
+ mock_kv_cache = mock_kv_cache.transpose(2, 3)
else:
- mock_kv_cache = kv_cache_permute
+ mock_kv_cache = trtllm_kv_cache
mock_block_table = block_tables_prefill
trtllm_batch_context_with_kv_cache(
@@ -1687,6 +1658,7 @@ def forward(
o_sf_scale=self.o_sf_scale,
out=out,
kv_cache_sf=prefill_kv_block_scales,
+ kv_layout=prefill_kv_layout,
)
if needs_fp8_out:
@@ -1707,7 +1679,9 @@ def forward(
assert decode_wrapper._sm_scale == self.scale
if self.is_kvcache_nvfp4:
- kv_cache_permute = nvfp4_kv_data
+ kv_cache_for_fi = nvfp4_kv_data
+ else:
+ kv_cache_for_fi = kv_cache_tuple
kv_cache_sf = nvfp4_kv_block_scales if self.is_kvcache_nvfp4 else None
# NVFP4 kernel only supports FP8 output.
@@ -1730,7 +1704,7 @@ def forward(
)
decode_wrapper.run(
decode_query,
- kv_cache_permute,
+ kv_cache_for_fi,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output_tmp,
@@ -1746,7 +1720,7 @@ def forward(
else:
decode_wrapper.run(
decode_query,
- kv_cache_permute,
+ kv_cache_for_fi,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=out_decode,
@@ -1767,21 +1741,10 @@ def forward(
block_tables_decode = attn_metadata.decode.block_tables
seq_lens_decode = attn_metadata.decode.seq_lens
- # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
- assert get_kv_cache_layout() == "HND"
assert is_strictly_contiguous(decode_query)
assert is_strictly_contiguous(workspace_buffer)
assert is_strictly_contiguous(block_tables_decode)
assert is_strictly_contiguous(seq_lens_decode)
- kv_cache_permute = canonicalize_singleton_dim_strides(kv_cache_permute)
- kv_strides = kv_cache_permute.stride()
- assert (
- kv_strides[-1] == 1 and kv_strides[-2] == kv_cache_permute.shape[-1]
- ), (
- "KV cache inner dims (block_size, head_size) must be "
- f"contiguous, got strides {kv_strides}"
- )
-
if output.dtype == FP4_DTYPE:
assert self.o_sf_scale is not None
out = FP4Tensor(
@@ -1807,11 +1770,16 @@ def forward(
else:
q_len_per_req = num_decode_tokens // attn_metadata.num_decodes
+ if self.is_kvcache_nvfp4:
+ decode_kv_data = nvfp4_kv_data
+ decode_kv_layout = flashinfer_layout
+ else:
+ decode_kv_data = trtllm_kv_cache
+ decode_kv_layout = trtllm_kv_layout
+
trtllm_batch_decode_with_kv_cache(
query=decode_query,
- kv_cache=(
- nvfp4_kv_data if self.is_kvcache_nvfp4 else kv_cache_permute
- ),
+ kv_cache=decode_kv_data,
workspace_buffer=workspace_buffer,
block_tables=block_tables_decode,
seq_lens=seq_lens_decode,
@@ -1826,6 +1794,7 @@ def forward(
kv_cache_sf=(
nvfp4_kv_block_scales if self.is_kvcache_nvfp4 else None
),
+ kv_layout=decode_kv_layout,
)
if needs_fp8_out:
@@ -1848,8 +1817,11 @@ def do_kv_cache_update(
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
- k_cache = kv_cache[:, 0]
- v_cache = kv_cache[:, 1]
+ kv_cache = kv_cache.transpose(1, 2)
+ if self.is_kvcache_nvfp4:
+ k_cache, v_cache = kv_cache.split(self.num_kv_heads, dim=-2)
+ else:
+ k_cache, v_cache = kv_cache.split(self.head_size, dim=-1)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index b87014252018..f74891f11cac 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -117,24 +117,6 @@ def supports_mm_prefix(cls) -> bool:
def get_impl_cls() -> type["FlexAttentionImpl"]:
return FlexAttentionImpl
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- return (num_blocks, 2, block_size, num_kv_heads, head_size)
-
- @staticmethod
- def get_kv_cache_stride_order(
- include_num_layers_dimension: bool = False,
- ) -> tuple[int, ...]:
- if include_num_layers_dimension:
- return (1, 0, 3, 2, 4, 5)
- return (0, 2, 1, 3, 4)
-
@staticmethod
def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]:
return FlexAttentionMetadataBuilder
@@ -1063,7 +1045,8 @@ def do_kv_cache_update(
if self.attn_type == AttentionType.ENCODER_ONLY:
return
- key_cache, value_cache = kv_cache.unbind(1)
+ kv_cache = kv_cache.transpose(1, 2)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
@@ -1168,11 +1151,12 @@ def forward(
else:
assert self.attn_type == AttentionType.DECODER
- key_cache, value_cache = kv_cache.unbind(1)
+ kv_cache = kv_cache.transpose(1, 2)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
# Flatten (num_blocks, block_size) into a single token dim
- key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
- value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
+ key_cache = key_cache.reshape(-1, self.num_kv_heads, self.head_size)
+ value_cache = value_cache.reshape(-1, self.num_kv_heads, self.head_size)
query, key_tensor, value_tensor = map(
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
(query, key_cache, value_cache),
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index bd947296e8bc..e767952ebe19 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -326,6 +326,7 @@ def forward_mqa(
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
+ kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]
@@ -336,8 +337,8 @@ def forward_mqa(
attn_out = flash_attn_varlen_func(
q=q_pe,
- k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
- v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
+ k=k_pe_cache,
+ v=kv_c_cache,
q_v=q_nope,
max_seqlen_q=max_seqlen_q,
cu_seqlens_q=attn_metadata.decode.query_start_loc,
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index e98bee9d79b5..4750f0bc18c0 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -23,7 +23,6 @@
AttentionType,
MultipleOf,
)
-from vllm.v1.attention.backends.utils import KVCacheLayoutType
logger = init_logger(__name__)
@@ -91,10 +90,6 @@ def supports_combination(
)
return None
- @classmethod
- def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
- return "HND"
-
g_fi_workspace = torch.zeros(
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
@@ -189,7 +184,7 @@ def forward_mqa(
o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
- kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
+ kv_cache=kv_c_and_k_pe_cache,
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
index 842153f40396..0a77f53d7e8f 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
@@ -41,7 +41,6 @@
from vllm.v1.attention.backends.mla.sparse_utils import (
triton_convert_req_index_to_global_index,
)
-from vllm.v1.attention.backends.utils import KVCacheLayoutType
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
@@ -130,20 +129,6 @@ def supports_combination(
return "FlashInfer MLA Sparse requires model with index_topk config"
return None
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int, # assumed to be 1 for MLA
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- return (num_blocks, block_size, head_size)
-
- @classmethod
- def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
- return "HND"
-
@dataclass
class FlashInferMLASparseMetadata(AttentionMetadata):
@@ -350,7 +335,7 @@ def forward_mqa(
o = trtllm_batch_decode_with_kv_cache_mla(
query=q.unsqueeze(1),
- kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
+ kv_cache=kv_c_and_k_pe_cache,
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 2f6058d69aeb..760624dfbecb 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -305,7 +305,7 @@ def forward_mqa(
if is_quantized_kv_cache(self.kv_cache_dtype):
o, lse = flash_mla_with_kvcache_fp8(
q=q,
- k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
+ k_cache=kv_c_and_k_pe_cache.unsqueeze(2),
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
@@ -319,7 +319,7 @@ def forward_mqa(
else:
o, lse = flash_mla_with_kvcache(
q=q,
- k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
+ k_cache=kv_c_and_k_pe_cache.unsqueeze(2),
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 9140a6fccd55..8ec868be2b34 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -135,19 +135,15 @@ def is_sparse(cls) -> bool:
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major in [9, 10]
+
+class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend):
@staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int, # assumed to be 1 for MLA
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- if cache_dtype_str == "fp8_ds_mla":
- # V3.2 main MLA: 656-byte custom storage format. See module docstring.
- return (num_blocks, block_size, 656)
- else:
- return (num_blocks, block_size, head_size)
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [256]
+
+ @staticmethod
+ def get_name() -> str:
+ return "V4_FLASHMLA_SPARSE"
@dataclass
@@ -332,8 +328,7 @@ def __init__(
)
self.compress_ratio = 1
if self.is_deepseek_v4:
- assert hasattr(self.kv_cache_spec, "compress_ratio")
- self.compress_ratio = self.kv_cache_spec.compress_ratio
+ self.compress_ratio = self.kv_cache_spec.tokens_per_state
# Pre-allocate compressed slot mapping buffer for CUDA graph
# address stability when compress_ratio > 1.
if self.compress_ratio > 1:
@@ -942,7 +937,7 @@ def _fp8_flash_mla_kernel(
out, lse = flash_mla_with_kvcache(
q=q,
- k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
+ k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(2),
block_table=kernel_metadata.dummy_block_table,
head_dim_v=512,
cache_seqlens=kernel_metadata.cache_lens,
@@ -966,7 +961,7 @@ def _bf16_flash_mla_kernel(
) -> torch.Tensor:
num_tokens = q.shape[0]
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
- -1, 1, kv_c_and_k_pe_cache.shape[-1]
+ -1, 1, 1, kv_c_and_k_pe_cache.shape[-1]
)
# NOTE(Chen): kernel requires num_local_head to be a multiple of
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 2870ec9a15c0..a410b1dac441 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -132,28 +132,6 @@ def get_supported_head_sizes(cls) -> list[int]:
def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
return DeepseekV32IndexerMetadataBuilder
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- assert num_kv_heads == 1
- return (num_blocks, block_size, head_size)
-
- @staticmethod
- def get_kv_cache_stride_order(
- include_num_layers_dimension: bool = False,
- ) -> tuple[int, ...]:
- if include_num_layers_dimension:
- # DeepseekV32Indexer kernels do not support cross-layer
- # KV cache layout. Identity permutation keeps num_layers
- # first, signaling incompatibility.
- return (0, 1, 2, 3)
- return (0, 1, 2)
-
class DeepseekV4IndexerBackend(DeepseekV32IndexerBackend):
@staticmethod
@@ -326,7 +304,7 @@ def __init__(self, *args, **kwargs):
self.compress_ratio = 1
# Get compress_ratio for DeepseekV4 support
if isinstance(self.kv_cache_spec, MLAAttentionSpec):
- self.compress_ratio = self.kv_cache_spec.compress_ratio
+ self.compress_ratio = self.kv_cache_spec.tokens_per_state
# Pre-allocate buffers for CUDA graph compatibility when
if self.compress_ratio > 1:
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
index a58ecf2c651f..8dbe4b468f88 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@@ -291,16 +291,6 @@ def get_builder_cls() -> type["ROCMAiterMLASparseMetadataBuilder"]:
def get_impl_cls() -> type["ROCMAiterMLASparseImpl"]:
return ROCMAiterMLASparseImpl
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int, # assumed to be 1 for MLA
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- return (num_blocks, block_size, head_size)
-
@classmethod
def is_mla(cls) -> bool:
return True
diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py
index f0e444e493c4..05828c8e6d89 100644
--- a/vllm/v1/attention/backends/mla/sparse_swa.py
+++ b/vllm/v1/attention/backends/mla/sparse_swa.py
@@ -75,6 +75,10 @@ def __init__(
self.block_size = 64
assert self.dtype == torch.uint8
+ def bind_kv_cache(self, kv_cache: torch.Tensor) -> None:
+ # [B, H=1, N, C] -> [B, N, C]
+ self.kv_cache = kv_cache.squeeze(1)
+
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
return SlidingWindowMLASpec(
block_size=self.block_size,
@@ -120,30 +124,6 @@ def get_builder_cls() -> type["DeepseekSparseSWAMetadataBuilder"]:
return DeepseekV4ROCMAiterSparseSWAMetadataBuilder
return DeepseekSparseSWAMetadataBuilder
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- assert num_kv_heads == 1
- if cache_dtype_str == "fp8_ds_mla":
- # DeepseekV4 SWA: 584B per token (448 NoPE + 128 RoPE + 8 fp8 scale).
- # head_size passed in is the semantic head_dim (512).
- return (num_blocks, block_size, 584)
- else:
- return (num_blocks, block_size, head_size)
-
- @staticmethod
- def get_kv_cache_stride_order(
- include_num_layers_dimension: bool = False,
- ) -> tuple[int, ...]:
- if include_num_layers_dimension:
- return (0, 1, 2, 3)
- return (0, 1, 2)
-
@dataclass
class DeepseekSparseSWAMetadata:
@@ -206,7 +186,7 @@ def __init__(self, *args, **kwargs):
assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec)
mla_spec = cast(SlidingWindowMLASpec | MLAAttentionSpec, self.kv_cache_spec)
self.head_size = mla_spec.head_size # Already considered quantization.
- self.compress_ratio = mla_spec.compress_ratio
+ self.compress_ratio = mla_spec.tokens_per_state
self.block_size = mla_spec.block_size
# Handle MTP: adjust decode_threshold like the indexer does
diff --git a/vllm/v1/attention/backends/mla/tokenspeed_mla.py b/vllm/v1/attention/backends/mla/tokenspeed_mla.py
index 6c8dedd77f27..d70a05d7a1e7 100644
--- a/vllm/v1/attention/backends/mla/tokenspeed_mla.py
+++ b/vllm/v1/attention/backends/mla/tokenspeed_mla.py
@@ -23,7 +23,6 @@
AttentionType,
MultipleOf,
)
-from vllm.v1.attention.backends.utils import KVCacheLayoutType
logger = init_logger(__name__)
@@ -123,10 +122,6 @@ def supports_combination(
)
return None
- @classmethod
- def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
- return "HND"
-
class TokenspeedMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
@@ -254,8 +249,7 @@ def forward_mqa(
q.device, self.num_heads, self.kv_lora_rank
)
- # vLLM kv_c_and_k_pe_cache is already (num_blocks, block_size, head_size).
- # tokenspeed_mla_decode wants 3D — pass as-is (no unsqueeze, unlike trtllm).
+ # tokenspeed_mla_decode expects 3D (num_blocks, block_size, head_size).
o = tokenspeed_mla_decode(
query=q,
kv_cache=kv_c_and_k_pe_cache,
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index c2aa5edccb66..77b88c44b5d0 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -188,7 +188,6 @@ def forward_mqa(
device=q.device,
)
- # Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
diff --git a/vllm/v1/attention/backends/mla/xpu_mla_sparse.py b/vllm/v1/attention/backends/mla/xpu_mla_sparse.py
index 2fa91d018388..64cc9901b9ea 100644
--- a/vllm/v1/attention/backends/mla/xpu_mla_sparse.py
+++ b/vllm/v1/attention/backends/mla/xpu_mla_sparse.py
@@ -66,16 +66,6 @@ def is_mla(cls) -> bool:
def is_sparse(cls) -> bool:
return True
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int, # assumed to be 1 for MLA
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- return (num_blocks, block_size, head_size)
-
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@@ -207,7 +197,7 @@ def _forward_bf16_kv(
) -> torch.Tensor:
num_tokens = q.shape[0]
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
- -1, 1, kv_c_and_k_pe_cache.shape[-1]
+ -1, 1, 1, kv_c_and_k_pe_cache.shape[-1]
)
topk_indices = topk_indices.view(num_tokens, 1, -1)
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index a9fa45debcfc..b810c1c7a268 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -754,18 +754,6 @@ def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
return AiterFlashAttentionMetadataBuilder
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- if block_size % 16 != 0:
- raise ValueError("Block size must be a multiple of 16.")
- return (num_blocks, 2, block_size, num_kv_heads, head_size)
-
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
from vllm.platforms.rocm import on_mi3xx
@@ -1057,7 +1045,8 @@ def forward(
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
- key_cache, value_cache = kv_cache.unbind(1)
+ kv_cache = kv_cache.transpose(1, 2)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(current_platform.fp8_dtype())
@@ -1384,7 +1373,8 @@ def do_kv_cache_update(
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
- key_cache, value_cache = kv_cache.unbind(1)
+ kv_cache = kv_cache.transpose(1, 2)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
@@ -1452,7 +1442,8 @@ def do_rope_and_kv_cache_update(
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
- key_cache, value_cache = kv_cache.unbind(1)
+ kv_cache = kv_cache.transpose(1, 2)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
flash_layout = True
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
index 984fc20ecaff..bf2b934f6bd0 100644
--- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
@@ -64,18 +64,6 @@ def get_name() -> str:
def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
return RocmAiterUnifiedAttentionImpl
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- if block_size % 16 != 0:
- raise ValueError("Block size must be a multiple of 16.")
- return (num_blocks, 2, block_size, num_kv_heads, head_size)
-
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@@ -134,30 +122,6 @@ def __init__(
self.unified_attention = unified_attention
self.supports_quant_query_input = True
- def _split_kv_cache(
- self, kv_cache: torch.Tensor
- ) -> tuple[torch.Tensor, torch.Tensor]:
- if self.attn_type != AttentionType.ENCODER_DECODER:
- return kv_cache.unbind(1)
-
- # NOTE: Encoder-decoder layers can share the same raw KV allocation with
- # ROCM_ATTN decoder layers, whose physical layout is K/V first. Keep
- # this cross-attention path on that physical layout so block IDs do not
- # alias different bytes across the shared allocation.
- num_blocks, _, block_size, num_kv_heads, head_size = kv_cache.shape
- block_stride = block_size * num_kv_heads * head_size
- kv_cache = kv_cache.as_strided(
- (2, num_blocks, block_size, num_kv_heads, head_size),
- (
- num_blocks * block_stride,
- block_stride,
- num_kv_heads * head_size,
- head_size,
- 1,
- ),
- )
- return kv_cache.unbind(0)
-
def forward(
self,
layer: torch.nn.Module,
@@ -177,7 +141,7 @@ def forward(
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
- [num_blocks, 2, block_size, num_kv_heads, head_size]
+ [num_blocks, num_kv_heads, block_size, 2 * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -218,7 +182,8 @@ def forward(
layer,
)
- key_cache, value_cache = self._split_kv_cache(kv_cache)
+ kv_cache = kv_cache.transpose(1, 2)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
softmax_scale = self.scale
if is_quantized_kv_cache(self.kv_cache_dtype):
@@ -267,7 +232,8 @@ def do_kv_cache_update(
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
- key_cache, value_cache = self._split_kv_cache(kv_cache)
+ kv_cache = kv_cache.transpose(1, 2)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
# Reshape the input keys and values and store them in the cache.
ops.reshape_and_cache_flash(
@@ -300,7 +266,8 @@ def do_rope_and_kv_cache_update(
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
- key_cache, value_cache = self._split_kv_cache(kv_cache)
+ kv_cache = kv_cache.transpose(1, 2)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
flash_layout = True
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index 2f6c48e3df34..74007e73ee03 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -29,7 +29,6 @@
)
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode,
- has_native_kv_cache_layout,
)
from vllm.v1.attention.ops.paged_attn import PagedAttention
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
@@ -128,6 +127,7 @@ def build(
use_cascade = common_prefix_len > 0
+ prefix_scheduler_metadata = None
if use_cascade:
cu_prefix_query_lens = torch.tensor(
[0, num_actual_tokens], dtype=torch.int32, device=self.device
@@ -239,18 +239,6 @@ def supports_attn_type(cls, attn_type: str) -> bool:
AttentionType.ENCODER_ONLY,
)
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- if block_size % 16 != 0:
- raise ValueError("Block size must be a multiple of 16.")
- return (2, num_blocks, block_size, num_kv_heads, head_size)
-
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@@ -376,7 +364,7 @@ def forward(
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
- [2, num_blocks, block_size, num_kv_heads, head_size]
+ [num_blocks, num_kv_heads, block_size, 2 * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -467,44 +455,19 @@ def do_kv_cache_update(
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return
- key_cache, value_cache = PagedAttention.split_kv_cache(
- kv_cache, self.num_kv_heads, self.head_size
+ kv_cache_transposed = kv_cache.transpose(1, 2)
+ key_cache, value_cache = kv_cache_transposed.split(self.head_size, dim=-1)
+ triton_reshape_and_cache_flash(
+ key,
+ value,
+ key_cache,
+ value_cache,
+ slot_mapping,
+ self.kv_cache_dtype,
+ layer._k_scale,
+ layer._v_scale,
)
- # Reshape the input keys and values and store them in the cache.
- # Get the actual block_size from value_cache
- # value_cache shape: [num_blocks, num_heads, head_size, block_size]
- block_size = value_cache.shape[3]
- has_native_layout = has_native_kv_cache_layout(key_cache, value_cache)
-
- if block_size in (16, 32) and has_native_layout:
- # Normal 16, 32 with contiguous blocks: use vLLM native HIP C++ logic.
- PagedAttention.write_to_paged_cache(
- key,
- value,
- key_cache,
- value_cache,
- slot_mapping,
- self.kv_cache_dtype,
- layer._k_scale,
- layer._v_scale,
- )
- else:
- # Non-standard blocks and hybrid attention/Mamba layouts need the
- # stride-aware Triton writer. The native reshape_and_cache kernel
- # assumes contiguous block storage and writes to the wrong hybrid
- # cache blocks.
- triton_reshape_and_cache_flash(
- key,
- value,
- key_cache,
- value_cache,
- slot_mapping,
- self.kv_cache_dtype,
- layer._k_scale,
- layer._v_scale,
- )
-
def fused_rope_kvcache_supported(self):
return rocm_aiter_ops.is_enabled()
@@ -522,12 +485,12 @@ def do_rope_and_kv_cache_update(
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return
- key_cache, value_cache = PagedAttention.split_kv_cache(
- kv_cache,
- layer.num_kv_heads, # type: ignore[attr-defined]
- layer.head_size, # type: ignore[attr-defined]
+ kv_cache_transposed = kv_cache.transpose(1, 2)
+ key_cache, value_cache = kv_cache_transposed.split(
+ self.head_size,
+ dim=-1,
)
- flash_layout = False
+ flash_layout = True
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache:
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 008b74c9ff70..832ce0b5e4ea 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -31,7 +31,6 @@
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
- get_kv_cache_layout,
get_num_attention_heads_from_layers,
)
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
@@ -44,7 +43,6 @@
AttentionSpec,
KVQuantMode,
get_kv_quant_mode,
- kv_cache_uses_per_token_head_scales,
)
logger = init_logger(__name__)
@@ -309,51 +307,6 @@ def supports_batch_invariance(cls) -> bool:
def get_impl_cls() -> type["TritonAttentionImpl"]:
return TritonAttentionImpl
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- if block_size % 16 != 0:
- raise ValueError("Block size must be a multiple of 16.")
- if kv_cache_uses_per_token_head_scales(cache_dtype_str):
- # Pad head_size by sizeof(float32)/sizeof(cache_dtype) so
- # the per-head scale fits inline. The backend extracts
- # data[:head_size] and scale[head_size:] via typed views.
- from vllm.utils.torch_utils import (
- STR_DTYPE_TO_TORCH_DTYPE,
- get_dtype_size,
- )
-
- cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype_str]
- scale_pad = get_dtype_size(torch.float32) // get_dtype_size(cache_dtype)
- return (num_blocks, 2, block_size, num_kv_heads, head_size + scale_pad)
- return (num_blocks, 2, block_size, num_kv_heads, head_size)
-
- @staticmethod
- def get_kv_cache_stride_order(
- include_num_layers_dimension: bool = False,
- ) -> tuple[int, ...]:
- # `stride_order` indicates the permutation that gets
- # us from `get_kv_cache_shape` to the actual memory layout we want.
- cache_layout = get_kv_cache_layout()
- if cache_layout == "NHD" and include_num_layers_dimension:
- # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
- return (1, 0, 2, 3, 4, 5)
- elif cache_layout == "NHD":
- stride_order = (0, 1, 2, 3, 4)
- elif cache_layout == "HND" and include_num_layers_dimension:
- # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
- return (1, 4, 0, 2, 3, 5)
- elif cache_layout == "HND":
- stride_order = (0, 1, 3, 2, 4)
- else:
- raise ValueError(f"Unknown cache layout: {cache_layout}")
- return stride_order
-
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@@ -583,10 +536,11 @@ def forward(
layer,
)
- # Per-token-head quantized KV cache: use separate scale caches.
+ # (B, H, N, C) -> (B, N, H, C) for kernel compatibility.
+ kv_cache = kv_cache.transpose(1, 2)
if self._is_per_token_head_quant:
self._ensure_scale_caches(kv_cache)
- key_cache, value_cache = kv_cache.unbind(1)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
if key_cache.dtype == torch.uint8:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
@@ -597,7 +551,7 @@ def forward(
v_scale_cache = self._v_scale_cache
# FP8 per-tensor / auto path (original flow).
else:
- key_cache, value_cache = kv_cache.unbind(1)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
if (
is_quantized_kv_cache(self.kv_cache_dtype)
and key_cache.dtype != self.fp8_dtype
@@ -730,9 +684,10 @@ def do_kv_cache_update(
# we use direct Q, K, V tensors without caching
return
# Reshape the input keys and values and store them in the cache.
+ kv_cache = kv_cache.transpose(1, 2)
if self._is_per_token_head_quant:
self._ensure_scale_caches(kv_cache)
- key_cache, value_cache = kv_cache.unbind(1)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
if key_cache.dtype == torch.uint8:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
@@ -747,7 +702,7 @@ def do_kv_cache_update(
)
return
# For decoder and cross-attention, use KV cache as before.
- key_cache, value_cache = kv_cache.unbind(1)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
@@ -779,7 +734,8 @@ def do_rope_and_kv_cache_update(
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
- key_cache, value_cache = kv_cache.unbind(1)
+ kv_cache = kv_cache.transpose(1, 2)
+ key_cache, value_cache = kv_cache.split(self.head_size, dim=-1)
flash_layout = True
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
diff --git a/vllm/v1/attention/backends/turboquant_attn.py b/vllm/v1/attention/backends/turboquant_attn.py
index 3bf3b6b82482..18bd29723520 100644
--- a/vllm/v1/attention/backends/turboquant_attn.py
+++ b/vllm/v1/attention/backends/turboquant_attn.py
@@ -108,6 +108,10 @@ class TurboQuantAttentionBackend(AttentionBackend):
def get_name() -> str:
return "TURBOQUANT"
+ @classmethod
+ def get_required_kv_cache_layout(cls) -> str | None:
+ return "LBNHC"
+
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [16, 32, 64, 128]
@@ -128,38 +132,6 @@ def get_impl_cls() -> type["TurboQuantAttentionImpl"]:
def get_builder_cls() -> type["TurboQuantMetadataBuilder"]:
return TurboQuantMetadataBuilder
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "turboquant_4bit_nc",
- ) -> tuple[int, ...]:
- """Combined K+V cache shape — no leading 2 dimension.
-
- Standard attention backends use (2, num_blocks, block_size, num_kv_heads,
- head_dim) with a leading 2 to separate K and V. TurboQuant packs K+V
- into a single interleaved slot per head per position, so the cache is:
-
- (num_blocks, block_size, num_kv_heads, slot_size_aligned)
-
- Each slot = [key_packed | value_packed | padding].
- This is safe because TQ has its own get_kv_cache_shape override and
- never shares cache tensors with other backends. Layers that fall back
- to native dtype via kv_cache_dtype_skip_layers get their own
- standard-shaped cache allocation.
-
- head_size is the model's real head_dim. slot_size_aligned is computed
- from the TQ config to ensure correct cache allocation for all head dims.
- """
- from vllm.model_executor.layers.quantization.turboquant.config import (
- TurboQuantConfig,
- )
-
- tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype_str, head_size)
- return (num_blocks, block_size, num_kv_heads, tq_config.slot_size_aligned)
-
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
if kv_cache_dtype is None:
@@ -383,6 +355,8 @@ def do_kv_cache_update(
k = key[:N].view(N, self.num_kv_heads, self.head_size)
v = value[:N].view(N, self.num_kv_heads, self.head_size)
+ # (B, H, N, C) -> (B, N, H, C) for TQ kernels
+ kv_cache = kv_cache.transpose(1, 2)
self._store_kv(k, v, kv_cache, slot_mapping, layer)
def forward(
@@ -410,13 +384,15 @@ def forward(
if attn_metadata is None:
return output.fill_(0)
+ # (B, H, N, C) -> (B, N, H, C) for TQ kernels
+ kv_cache = kv_cache.transpose(1, 2)
+
# Slice to actual tokens
N = attn_metadata.num_actual_tokens
if N <= 0:
return output.fill_(0)
q = query[:N].view(N, self.num_heads, self.head_size)
-
# Get TQ buffers, ensure on device (one-time migration).
# Use Any-typed alias for dynamic _tq_* attrs set by _ensure_on_device.
tq_layer: Any = layer
@@ -538,7 +514,7 @@ def _store_kv(
self,
key: torch.Tensor, # (N, Hk, D)
value: torch.Tensor, # (N, Hk, D)
- kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size)
+ kv_cache: torch.Tensor, # (num_blocks, Hk, block_size, slot_size)
slot_mapping: torch.Tensor,
layer: Any,
):
@@ -564,7 +540,7 @@ def _prefill_attention(
query: torch.Tensor, # (N, Hq, D)
key: torch.Tensor, # (N, Hk, D)
value: torch.Tensor, # (N, Hk, D)
- kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size)
+ kv_cache: torch.Tensor, # (num_blocks, Hk, block_size, slot_size)
attn_metadata: TurboQuantMetadata,
Pi: torch.Tensor,
centroids: torch.Tensor,
@@ -715,7 +691,7 @@ def _continuation_prefill(
query: torch.Tensor, # (q_len, Hq, D)
key_chunk: torch.Tensor, # (q_len, Hk, D)
val_chunk: torch.Tensor, # (q_len, Hk, D)
- kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size)
+ kv_cache: torch.Tensor, # (num_blocks, Hk, block_size, slot_size)
block_table: torch.Tensor, # (1, max_num_blocks)
cached_len: int,
seq_len: int,
@@ -730,7 +706,7 @@ def _continuation_prefill(
q_len, Hq, D = query.shape
Hk = key_chunk.shape[1]
device = query.device
- block_size = kv_cache.shape[1]
+ block_size = kv_cache.shape[2]
BLOCK_D = triton.next_power_of_2(D)
mse_bytes = self._mse_bytes
@@ -767,8 +743,8 @@ def _continuation_prefill(
v_cached.stride(1),
v_cached.stride(2),
kv_cache.stride(0),
- kv_cache.stride(1),
kv_cache.stride(2),
+ kv_cache.stride(1),
block_table.stride(0),
HEAD_DIM=D,
BLOCK_SIZE=block_size,
@@ -858,7 +834,7 @@ def _continuation_prefill(
def _decode_attention(
self,
query: torch.Tensor, # (B, Hq, D)
- kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size)
+ kv_cache: torch.Tensor, # (num_blocks, Hk, block_size, slot_size)
attn_metadata: TurboQuantMetadata,
Pi: torch.Tensor,
centroids: torch.Tensor,
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index b73d17e8e5cc..54a8596a0ebc 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -8,7 +8,6 @@
Any,
Literal,
Protocol,
- get_args,
)
import numpy as np
@@ -17,7 +16,11 @@
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
-from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
+from vllm.v1.kv_cache_interface import (
+ KVCacheLayout,
+ KVCacheSpec,
+ MambaSpec,
+)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -38,51 +41,87 @@
)
logger = init_logger(__name__)
-KVCacheLayoutType = Literal["NHD", "HND"]
+
+# Deprecated: use resolve_kv_cache_layout() instead (RFC #42082).
+KVCacheLayoutType = Literal["LBNHC", "LBHNC", "BLHNC", "BHLNC"]
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None
PAD_SLOT_ID = -1
NULL_BLOCK_ID = 0
-def is_valid_kv_cache_layout(value: str) -> bool:
- return value in get_args(KVCacheLayoutType)
+_LAYOUT_COMPAT_ALIASES = {
+ "NHD": "LBNHC",
+ "HND": "LBHNC",
+ "NHC": "LBNHC",
+ "HNC": "LBHNC",
+}
+_FLASHINFER_LAYOUT_NAMES = {"LBNHC": "NHD", "LBHNC": "HND"}
-@functools.lru_cache
-def get_kv_cache_layout():
- # Format specified by the code.
- global _KV_CACHE_LAYOUT_OVERRIDE
+def is_valid_kv_cache_layout(value: str) -> bool:
+ return value in KVCacheLayout.__members__ or value in _LAYOUT_COMPAT_ALIASES
- cache_layout: Literal["NHD", "HND"] | None = None
- if _KV_CACHE_LAYOUT_OVERRIDE is not None:
- cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
- logger.info_once(
- "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. "
- "Setting KV cache layout to %s.",
- cache_layout,
- )
- return cache_layout
- # Format specified by the user.
- cache_layout = envs.VLLM_KV_CACHE_LAYOUT
- # When neither the user nor the override specified a layout, get default
- if cache_layout is None:
- cache_layout = get_kv_connector_cache_layout()
- else:
- assert is_valid_kv_cache_layout(cache_layout)
- logger.info_once(
- "`VLLM_KV_CACHE_LAYOUT` environment variable "
- "detected. Setting KV cache layout to %s.",
- cache_layout,
- )
- return cache_layout
+def get_flashinfer_layout_string() -> str:
+ """Return the layout name in FlashInfer's convention (NHD/HND)."""
+ name = resolve_kv_cache_layout().name
+ return _FLASHINFER_LAYOUT_NAMES.get(name, name)
-def set_kv_cache_layout(cache_layout: KVCacheLayoutType | None):
+def set_kv_cache_layout(cache_layout: "KVCacheLayoutType | None"):
+ """Override the KV cache layout (for tests and platform constraints)."""
global _KV_CACHE_LAYOUT_OVERRIDE
_KV_CACHE_LAYOUT_OVERRIDE = cache_layout
- get_kv_cache_layout.cache_clear()
+ resolve_kv_cache_layout.cache_clear()
+
+
+@functools.lru_cache
+def resolve_kv_cache_layout(
+ attn_backends: tuple[type[AttentionBackend], ...] | None = None,
+) -> KVCacheLayout:
+ """Resolve the physical KV cache layout from the config priority chain.
+
+ Priority:
+ 1. Runtime override (set_kv_cache_layout, used by tests)
+ 2. VLLM_KV_CACHE_LAYOUT env var (user override)
+ 3. Connector's get_required_kvcache_layout() preference
+ 4. "LBHNC" fallback
+ """
+ global _KV_CACHE_LAYOUT_OVERRIDE
+ layout_name: str | None
+ if _KV_CACHE_LAYOUT_OVERRIDE is not None:
+ layout_name = _KV_CACHE_LAYOUT_OVERRIDE
+ else:
+ layout_name = envs.VLLM_KV_CACHE_LAYOUT
+
+ if layout_name is None and attn_backends is not None:
+ required_layouts = set(
+ backend.get_required_kv_cache_layout() for backend in attn_backends
+ )
+ required_layouts.discard(None)
+ if len(required_layouts) > 1:
+ raise ValueError(
+ f"Multiple required KV cache layouts: {required_layouts}. "
+ f"All backends must use the same layout."
+ )
+ if len(required_layouts) == 1:
+ layout_name = required_layouts.pop()
+
+ if layout_name is None:
+ layout_name = get_kv_connector_cache_layout()
+
+ layout_name = layout_name or "LBHNC"
+ layout_name = _LAYOUT_COMPAT_ALIASES.get(layout_name, layout_name)
+ try:
+ layout = KVCacheLayout[layout_name]
+ except KeyError:
+ raise ValueError(
+ f"Unknown KV cache layout {layout_name!r}. "
+ f"Valid layouts: {[m.name for m in KVCacheLayout]}"
+ ) from None
+ logger.info_once("Resolved KV cache layout: %s", layout)
+ return layout
@dataclass
diff --git a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py
index 77eb3ac60b1f..68224b63a150 100644
--- a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py
+++ b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py
@@ -27,9 +27,10 @@ def has_native_kv_cache_layout(
) -> bool:
"""Return whether KV cache blocks can use the native ROCm pairing.
- The native reshape_and_cache writer assumes packed blocks. If cache update
- needs reshape_and_cache_flash for a stride-padded hybrid layout, decode
- should use the matching Triton path too.
+ The C++ ``ops.paged_attention_rocm`` custom kernel requires each block
+ to be contiguous in memory. Returns False for stride-padded hybrid
+ layouts and for the unified KV cache (RFC #42082, see
+ :meth:`PagedAttention.split_kv_cache`), routing them to Triton.
"""
return (
key_cache.stride(0) == key_cache.shape[1:].numel()
@@ -415,7 +416,7 @@ def chunked_prefill_paged_decode(
"Cannot use ROCm custom paged attention kernel,"
" falling back to Triton implementation."
)
- real_block_size = value_cache.shape[3]
+ real_block_size = value_cache.shape[2]
# The standard model directly uses the original block_size.
# Non-standard 544 uses 32 to accommodate integer division logic.
# Cap at 128 to avoid exceeding GPU shared memory limits
@@ -479,8 +480,8 @@ def chunked_prefill_paged_decode(
stride_k_cache_4=key_cache.stride(4),
stride_v_cache_0=value_cache.stride(0),
stride_v_cache_1=value_cache.stride(1),
- stride_v_cache_2=value_cache.stride(2),
- stride_v_cache_3=value_cache.stride(3),
+ stride_v_cache_2=value_cache.stride(3),
+ stride_v_cache_3=value_cache.stride(2),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
USE_SINKS=sinks is not None,
diff --git a/vllm/v1/attention/ops/paged_attn.py b/vllm/v1/attention/ops/paged_attn.py
index 896e929b5433..6e54294902f9 100644
--- a/vllm/v1/attention/ops/paged_attn.py
+++ b/vllm/v1/attention/ops/paged_attn.py
@@ -19,13 +19,16 @@ def split_kv_cache(
num_kv_heads: int,
head_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
+ # [B, H, N, 2*C] -> key [B, H, C//x, N, x], value [B, H, C, N]
x = 16 // kv_cache.element_size()
- num_blocks = kv_cache.shape[1]
-
- key_cache = kv_cache[0]
- key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x)
- value_cache = kv_cache[1]
- value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
+ key_slice = kv_cache[..., :head_size]
+ value_slice = kv_cache[..., head_size:]
+ key_cache = (
+ key_slice.permute(0, 1, 3, 2)
+ .unflatten(2, (head_size // x, x))
+ .transpose(3, 4)
+ )
+ value_cache = value_slice.permute(0, 1, 3, 2)
return key_cache, value_cache
@staticmethod
diff --git a/vllm/v1/attention/selector.py b/vllm/v1/attention/selector.py
index 5d63553a7b60..171002eb8cc6 100644
--- a/vllm/v1/attention/selector.py
+++ b/vllm/v1/attention/selector.py
@@ -8,15 +8,12 @@
import vllm.envs as envs
from vllm.config.cache import CacheDType
-from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.attention.backend import AttentionBackend, AttentionType
from vllm.v1.attention.backends.registry import (
MambaAttentionBackendEnum,
)
-logger = init_logger(__name__)
-
class AttentionSelectorConfig(NamedTuple):
head_size: int
@@ -128,19 +125,6 @@ def _cached_get_attn_backend(
f"Invalid attention backend for {current_platform.device_name}"
)
backend = resolve_obj_by_qualname(attention_cls)
-
- # Adjust kv cache layout if the selected backend requires a specific one
- required_layout = backend.get_required_kv_cache_layout()
- if required_layout is not None:
- from vllm.v1.attention.backends.utils import set_kv_cache_layout
-
- set_kv_cache_layout(required_layout)
- logger.info(
- "Using %s KV cache layout for %s backend.",
- required_layout,
- backend.get_name(),
- )
-
return backend
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 7f3a5e4fdf3f..44386046077d 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -10,7 +10,7 @@
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace
from functools import partial
-from typing import Any, NewType, TypeAlias, cast, overload
+from typing import Any, NewType, TypeAlias, overload
from vllm import envs
from vllm.config import VllmConfig
@@ -19,12 +19,15 @@
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.mem_utils import format_gib
from vllm.utils.torch_utils import get_dtype_size
+from vllm.v1.attention.backends.utils import resolve_kv_cache_layout
from vllm.v1.kv_cache_interface import (
+ AttentionSpec,
ChunkedLocalAttentionSpec,
FullAttentionSpec,
HiddenStateCacheSpec,
KVCacheConfig,
KVCacheGroupSpec,
+ KVCacheLayout,
KVCacheSpec,
KVCacheTensor,
MambaSpec,
@@ -912,24 +915,8 @@ def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int:
`available_memory` into `num_blocks`. Used to compute the effective KV cache
capacity once `num_gpu_blocks_override` is applied.
"""
- if len(kv_cache_groups) == 1 and isinstance(
- kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs
- ):
- return kv_cache_groups[0].kv_cache_spec.page_size_bytes
- if all(
- isinstance(g.kv_cache_spec, UniformTypeKVCacheSpecs) for g in kv_cache_groups
- ):
- # DeepseekV4: shared layout sized by the largest per-page-size bucket.
- full_mla_spec = cast(UniformTypeKVCacheSpecs, kv_cache_groups[0].kv_cache_spec)
- layer_tuple_page_bytes = sum(full_mla_spec.get_page_sizes())
- num_layer_tuples = max(
- cast(UniformTypeKVCacheSpecs, g.kv_cache_spec).get_num_layer_tuples()
- for g in kv_cache_groups
- )
- return layer_tuple_page_bytes * num_layer_tuples
- group_size = max(len(g.layer_names) for g in kv_cache_groups)
- page_size = get_uniform_page_size([g.kv_cache_spec for g in kv_cache_groups])
- return page_size * group_size
+ buckets = _bucket_layers_by_page_size(kv_cache_groups)
+ return sum(ps * len(slots) for ps, slots in buckets.items())
def get_num_blocks(
@@ -1176,61 +1163,79 @@ def _get_kv_cache_groups_uniform_page_size(
return create_kv_cache_group_specs(kv_cache_spec, grouped_layers)
-def _get_kv_cache_config_deepseek_v4(
- vllm_config: VllmConfig,
+def _bucket_layers_by_page_size(
kv_cache_groups: list[KVCacheGroupSpec],
- available_memory: int,
-) -> tuple[int, list[KVCacheTensor]]:
- """DeepseekV4 KV cache tensor layout planning.
-
- Precondition: kv_cache_groups[0] is the full-MLA group; its page sizes
- define the canonical bucket set. Non-full-MLA groups must have been
- page_size-padded upstream (see _get_kv_cache_groups_uniform_groups) so
- every layer's page_size matches one of the full-MLA bucket sizes.
-
- For each group, bucket its layers by page_size_bytes and place each
- layer at tuple_idx = position-within-bucket. Emit one KVCacheTensor
- per (tuple_idx, bucket) whose shared_by is the union of per-group
- layers at that slot.
- """
- full_mla_spec = kv_cache_groups[0].kv_cache_spec
- assert isinstance(full_mla_spec, UniformTypeKVCacheSpecs)
- page_sizes = sorted(full_mla_spec.get_page_sizes())
- layer_tuple_page_bytes = sum(page_sizes)
-
- # Pre-bucket each group's layers by page_size (registration order within
- # bucket). bucketed[g_idx][page_size] = [layer_name, ...].
- bucketed: list[dict[int, list[str]]] = []
+) -> dict[int, list[list[str]]]:
+ """Bucket layers by page size: ``result[ps][slot_idx] = [layer_names]``.
+
+ Layers from different groups at the same ``slot_idx`` share a block
+ (they have independent block tables so block-id namespaces never collide).
+ """
+ buckets: dict[int, list[list[str]]] = defaultdict(list)
for group in kv_cache_groups:
- assert isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs)
- specs = group.kv_cache_spec.kv_cache_specs
- b: dict[int, list[str]] = defaultdict(list)
- for name in group.layer_names:
- b[specs[name].page_size_bytes].append(name)
- bucketed.append(b)
-
- # num_layer_tuples = longest bucket list across all groups. For the
- # full-MLA group this equals the count of layers in the largest
- # per-page-size bucket (= get_num_layer_tuples()); for SWA sub-groups
- # this equals the sub-group size (each has a single page_size).
- num_layer_tuples = max(len(layers) for b in bucketed for layers in b.values())
-
- num_blocks = available_memory // (layer_tuple_page_bytes * num_layer_tuples)
- num_blocks = may_override_num_blocks(vllm_config, num_blocks)
+ spec = group.kv_cache_spec
+ slot_count: dict[int, int] = defaultdict(int)
+ for layer_name in group.layer_names:
+ if isinstance(spec, UniformTypeKVCacheSpecs):
+ ps = spec.kv_cache_specs[layer_name].page_size_bytes
+ else:
+ ps = spec.page_size_bytes
+ slot_idx = slot_count[ps]
+ slot_count[ps] += 1
+ if slot_idx == len(buckets[ps]):
+ buckets[ps].append([])
+ buckets[ps][slot_idx].append(layer_name)
+ return buckets
+
+
+def _get_per_layer_spec(
+ group: KVCacheGroupSpec,
+ layer_name: str,
+) -> KVCacheSpec:
+ """Return the concrete KVCacheSpec for a specific layer."""
+ spec = group.kv_cache_spec
+ if isinstance(spec, UniformTypeKVCacheSpecs):
+ return spec.kv_cache_specs[layer_name]
+ return spec
+
+
+def _validate_layout_compatibility(
+ kv_cache_tensors: list[KVCacheTensor],
+ kv_cache_groups: list[KVCacheGroupSpec],
+) -> None:
+ """Ensure [H, N, C] is contiguous when layers sharing a tensor differ."""
+ layout = resolve_kv_cache_layout()
- kv_cache_tensors: list[KVCacheTensor] = []
- for tuple_idx in range(num_layer_tuples):
- for ps in page_sizes:
- shared_by: list[str] = []
- for b in bucketed:
- bucket = b.get(ps)
- if bucket is not None and tuple_idx < len(bucket):
- shared_by.append(bucket[tuple_idx])
- kv_cache_tensors.append(
- KVCacheTensor(size=ps * num_blocks, shared_by=shared_by)
- )
+ # The per-layer-per-block content [H, N, C] is contiguous therefore
+ # these layouts are always valid.
+ if layout in (KVCacheLayout.LBHNC, KVCacheLayout.BLHNC):
+ return
- return num_blocks, kv_cache_tensors
+ layer_spec_map: dict[str, KVCacheSpec] = {}
+ for group in kv_cache_groups:
+ for layer_name in group.layer_names:
+ layer_spec_map[layer_name] = _get_per_layer_spec(group, layer_name)
+
+ for tensor in kv_cache_tensors:
+ all_layer_names = [name for slot in tensor.shared_by for name in slot]
+ if len(all_layer_names) <= 1:
+ continue
+ attn_specs = [
+ layer_spec_map[n]
+ for n in all_layer_names
+ if isinstance(layer_spec_map[n], AttentionSpec)
+ ]
+ if len(attn_specs) <= 1:
+ continue
+ shapes = {(s.num_heads, s.state_content_size_bytes) for s in attn_specs}
+ if len(shapes) > 1:
+ raise ValueError(
+ f"Groups {kv_cache_groups} share a KVCacheTensor but"
+ f" have different (num_heads, state_content_size_bytes)"
+ f" for layers {all_layer_names}. Use a layout where"
+ f" [H, N, C] is contiguous (e.g."
+ f" VLLM_KV_CACHE_LAYOUT=LBHNC or BLHNC)."
+ )
def get_kv_cache_config_from_groups(
@@ -1238,8 +1243,7 @@ def get_kv_cache_config_from_groups(
kv_cache_groups: list[KVCacheGroupSpec],
available_memory: int,
) -> KVCacheConfig:
- """
- Generate the KV cache configuration from the KV cache groups and spec
+ """Generate the KV cache configuration from the KV cache groups and spec
of each layer.
Args:
@@ -1258,61 +1262,24 @@ def get_kv_cache_config_from_groups(
kv_cache_groups=kv_cache_groups,
)
- # Determine how model runners should initialize the KV cache tensors.
- if len(kv_cache_groups) == 1 and isinstance(
- kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs
- ):
- # Special case: all layers have the same type of KV cache but with
- # different hidden sizes. Allocate different amount of memory for each
- # layer based on its hidden size.
- num_blocks = (
- available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes
- )
- num_blocks = may_override_num_blocks(vllm_config, num_blocks)
- per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs
- kv_cache_tensors = [
+ buckets = _bucket_layers_by_page_size(kv_cache_groups)
+
+ page_sizes = list(buckets.keys())
+ bytes_per_block = sum(ps * len(buckets[ps]) for ps in page_sizes)
+ num_blocks = available_memory // bytes_per_block
+ num_blocks = may_override_num_blocks(vllm_config, num_blocks)
+
+ kv_cache_tensors: list[KVCacheTensor] = []
+ for ps in page_sizes:
+ shared_by = buckets[ps]
+ kv_cache_tensors.append(
KVCacheTensor(
- size=per_layer_specs[layer_name].page_size_bytes * num_blocks,
- shared_by=[layer_name],
+ size=ps * num_blocks * len(shared_by),
+ shared_by=shared_by,
)
- for layer_name in kv_cache_groups[0].layer_names
- ]
- elif all(
- isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs)
- for group in kv_cache_groups
- ):
- # DeepseekV4: UniformTypeKVCacheSpecs but multiple groups.
- # Delegate to the DeepseekV4-specific allocator.
- num_blocks, kv_cache_tensors = _get_kv_cache_config_deepseek_v4(
- vllm_config, kv_cache_groups, available_memory
- )
- else:
- # General case:
- # We will have group_size memory pools, each is shared by one layer from
- # each group. As layers of different groups have different block table,
- # they will use different parts of the shared Tensor.
- # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2),
- # (sw.1, padding) will be: (group_size = 2)
- # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
- # full.1, sw.2: share another Tensor with size=available_memory//2
- group_size = max(len(group.layer_names) for group in kv_cache_groups)
-
- page_size = get_uniform_page_size(
- [group.kv_cache_spec for group in kv_cache_groups]
)
- assert group_size > 0, "group_size must be greater than 0"
- num_blocks = get_num_blocks(
- vllm_config, group_size, available_memory, page_size
- )
- kv_cache_tensors = []
- for i in range(group_size):
- shared_by = []
- for j in range(len(kv_cache_groups)):
- if i < len(kv_cache_groups[j].layer_names):
- shared_by.append(kv_cache_groups[j].layer_names[i])
- kv_cache_tensors.append(
- KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
- )
+
+ _validate_layout_compatibility(kv_cache_tensors, kv_cache_groups)
return KVCacheConfig(
num_blocks=num_blocks,
@@ -1383,7 +1350,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
page_size_padded=spec.page_size_padded,
cache_dtype_str=spec.cache_dtype_str,
alignment=spec.alignment,
- compress_ratio=spec.compress_ratio,
+ tokens_per_state=spec.tokens_per_state,
model_version=spec.model_version,
)
elif isinstance(spec, SlidingWindowSpec):
@@ -1745,60 +1712,30 @@ def _max_memory_usage_bytes_from_groups(
"""
Calculate maximum memory usage in bytes from KV cache groups.
- This correctly accounts for padding in hybrid models. For example, if a
- model has 8 full attention layers and 9 sliding window layers, they will
- be padded to 9 full + 9 sliding window for uniform group sizes.
+ A request needs blocks from *every* group simultaneously. The total
+ blocks consumed per request is the **sum** across groups (each group
+ independently claims blocks from the shared pool). Total memory is
+ ``bytes_per_block * sum_of_blocks``.
"""
if not kv_cache_groups:
return 0
- if len(kv_cache_groups) == 1 and isinstance(
- kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs
- ):
- # UniformTypeKVCacheSpecs special case (single group, per-layer specs)
- per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs
- return sum(
- spec.max_memory_usage_bytes(vllm_config)
- for spec in per_layer_specs.values()
- )
- elif all(
- isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs)
- for group in kv_cache_groups
- ):
- # Special case (only DeepseekV4 for now): all groups are
- # UniformTypeKVCacheSpecs.
- # They must already be page_size aligned and share a common padded
- # layer-tuple layout. Even groups with fewer actual tuples still reserve
- # the global number of tuple slots in the shared tensor layout.
- full_mla_spec = cast(UniformTypeKVCacheSpecs, kv_cache_groups[0].kv_cache_spec)
- layer_tuple_bytes = sum(full_mla_spec.get_page_sizes())
- num_layer_tuples = max(
- cast(UniformTypeKVCacheSpecs, group.kv_cache_spec).get_num_layer_tuples()
- for group in kv_cache_groups
- )
+ bytes_per_block = _pool_bytes_per_block(kv_cache_groups)
+ if bytes_per_block == 0:
+ return 0
- total_max_mem_usage_bytes = 0
- for group in kv_cache_groups:
- group_spec = cast(UniformTypeKVCacheSpecs, group.kv_cache_spec)
- g_max_mem_usage_pages = group_spec.max_memory_usage_pages(vllm_config)
- g_max_mem_usage_page_bytes = (
- num_layer_tuples * g_max_mem_usage_pages * layer_tuple_bytes
+ total_blocks = 0
+ for group in kv_cache_groups:
+ spec = group.kv_cache_spec
+ if isinstance(spec, UniformTypeKVCacheSpecs):
+ total_blocks += spec.max_memory_usage_pages(vllm_config)
+ else:
+ total_blocks += cdiv(
+ spec.max_memory_usage_bytes(vllm_config),
+ spec.page_size_bytes,
)
- total_max_mem_usage_bytes += g_max_mem_usage_page_bytes
- return total_max_mem_usage_bytes
-
- # General case: group_size pools, each shared by one layer per group
- # Memory = group_size * page_size * blocks_for_max_len
- group_size = max(len(group.layer_names) for group in kv_cache_groups)
- page_size = get_uniform_page_size(
- [group.kv_cache_spec for group in kv_cache_groups]
- )
- blocks_needed = sum(
- cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size)
- for group in kv_cache_groups
- )
- return group_size * page_size * blocks_needed
+ return bytes_per_block * total_blocks
def _estimate_max_model_len_from_groups(
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index 31ee89bc72aa..727c7f58e9eb 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -95,11 +95,27 @@ class KVCacheSpecKind(str, Enum):
class KVCacheSpec:
"""
A base class for specifying the KV cache format of one layer.
+
+ RFC #42082 standard vocabulary (properties, overridden by subclasses):
+ num_heads: int — heads (1 if headless, e.g. MLA)
+ tokens_per_state: int — -1 infinite (recurrent), 1 standard, N compressed
+ state_content_size_bytes: int — bytes per state per head
"""
- # number of tokens in a block
block_size: int
+ @property
+ def num_heads(self) -> int:
+ raise NotImplementedError
+
+ @property
+ def tokens_per_state(self) -> int:
+ raise NotImplementedError
+
+ @property
+ def state_content_size_bytes(self) -> int:
+ raise NotImplementedError
+
@property
def page_size_bytes(self) -> int:
"""
@@ -140,6 +156,104 @@ def merge(cls, specs: list[Self]) -> Self:
return copy.deepcopy(specs[0])
+# Logical dim indices in the 5D stride permutation [L, B, H, N, C] (see: RFC #42082).
+_DIM_L, _DIM_B, _DIM_H, _DIM_N, _DIM_C = 0, 1, 2, 3, 4
+
+
+class KVCacheLayout(Enum):
+ """Physical layout descriptor for a KV cache group.
+
+ The logical shape is always [L, B, H, N, ] (RFC #42082).
+ Each member's value is a stride permutation that maps logical axes
+ to physical (memory) order.
+ """
+
+ LBHNC = (0, 1, 2, 3, 4) # [L, B, H, N, C] (identity)
+ LBNHC = (0, 1, 3, 2, 4) # [L, B, N, H, C]
+ BLHNC = (1, 0, 2, 3, 4) # [B, L, H, N, C]
+ BHLNC = (1, 2, 0, 3, 4) # [B, H, L, N, C]
+
+ @property
+ def stride_order(self) -> tuple[int, ...]:
+ return self.value
+
+ @property
+ def layer_stride_order(self) -> tuple[int, ...]:
+ """4D permutation [B, H, N, C] for per-layer tensors."""
+ if not self.is_layer_compact:
+ compact = [m.name for m in KVCacheLayout if m.is_layer_compact]
+ raise ValueError(
+ f"KVCacheLayout.{self.name} cannot produce a 4D"
+ f" layer_stride_order because L is not outermost."
+ f" Use a layer-compact layout: {compact}"
+ )
+ return tuple(i - 1 for i in self.value if i != _DIM_L)
+
+ @property
+ def is_layer_compact(self) -> bool:
+ """True when the layer is compact; i.e. the L dimension is outermost."""
+ return self.value[_DIM_L] == 0
+
+ @property
+ def is_block_contiguous(self) -> bool:
+ """True when [H, N, C] is contiguous within a block."""
+ return self.value[-3:] == (_DIM_H, _DIM_N, _DIM_C)
+
+
+def num_states_for(block_size: int, tokens_per_state: int) -> int:
+ """Derive num_states at allocation time (not part of the spec)."""
+ if tokens_per_state == -1:
+ return 1 # recurrent: single state per block
+ return block_size // tokens_per_state
+
+
+def compute_layer_kv_cache_shape_bytes(
+ spec: KVCacheSpec,
+ num_blocks: int,
+ block_size: int | None = None,
+) -> tuple[int, ...]:
+ """Return the 4D logical shape ``(B, H, N, C)`` where C is in bytes."""
+ bs = block_size if block_size is not None else spec.storage_block_size
+ ns = num_states_for(bs, spec.tokens_per_state)
+ return (num_blocks, spec.num_heads, ns, spec.state_content_size_bytes)
+
+
+def reshape_kv_cache(
+ raw: torch.Tensor,
+ spec: KVCacheSpec,
+ num_blocks: int,
+ num_layer_slots: int,
+ layout: KVCacheLayout,
+ block_size: int | None = None,
+) -> list[torch.Tensor]:
+ """View a flat int8 buffer as 4D ``[B, H, N, C]`` per-slot views.
+
+ Works for all KVCacheSpec subclasses. Shapes as int8 via
+ compute_layer_kv_cache_shape_bytes, then reinterprets as spec.dtype.
+ """
+ dtype = getattr(spec, "dtype", None)
+ logical_shape_bytes = (
+ num_layer_slots,
+ *compute_layer_kv_cache_shape_bytes(spec, num_blocks, block_size),
+ )
+ stride_order = layout.stride_order
+ physical_shape_bytes = tuple(logical_shape_bytes[i] for i in stride_order)
+ inv_order = [stride_order.index(i) for i in range(5)]
+
+ if page_size_padded := getattr(spec, "page_size_padded", None):
+ strides = list(torch.empty(physical_shape_bytes, device="meta").stride())
+ strides[inv_order[_DIM_B]] = page_size_padded
+ cache = torch.as_strided(raw, size=physical_shape_bytes, stride=tuple(strides))
+ else:
+ cache = raw.view(physical_shape_bytes)
+ cache_logical = cache.permute(*inv_order)
+
+ if dtype is not None:
+ cache_logical = cache_logical.view(dtype)
+
+ return [cache_logical[i] for i in range(num_layer_slots)]
+
+
@dataclass(frozen=True, kw_only=True)
class AttentionSpec(KVCacheSpec):
num_kv_heads: int
@@ -147,6 +261,24 @@ class AttentionSpec(KVCacheSpec):
dtype: torch.dtype
kv_quant_mode: KVQuantMode = KVQuantMode.NONE
page_size_padded: int | None = None
+ tokens_per_state: int = 1
+
+ # NVFP4: K and V are stored as separate head groups [L, B, 2*H, N, dim]
+ # rather than interleaved in content [L, B, H, N, 2*dim].
+
+ @property
+ def num_heads(self) -> int:
+ if self.kv_quant_mode.is_nvfp4:
+ return 2 * self.num_kv_heads
+ return self.num_kv_heads
+
+ @property
+ def state_content_size_bytes(self) -> int:
+ hs = self.head_size
+ if self.kv_quant_mode.is_nvfp4:
+ hs = nvfp4_kv_cache_full_dim(hs)
+ return hs * get_dtype_size(self.dtype)
+ return 2 * hs * get_dtype_size(self.dtype)
@property
def page_size_bytes(self) -> int:
@@ -207,6 +339,19 @@ def __post_init__(self):
if self.head_size_v is None:
object.__setattr__(self, "head_size_v", self.head_size)
+ @property
+ def state_content_size_bytes(self) -> int:
+ hs_k = self.head_size
+ hs_v = self.head_size_v
+ if self.kv_quant_mode.is_nvfp4:
+ hs_k = nvfp4_kv_cache_full_dim(hs_k)
+ hs_v = nvfp4_kv_cache_full_dim(hs_v)
+ assert hs_k == hs_v, (
+ "nvfp4 with asymmetric K/V head sizes not yet supported"
+ )
+ return hs_k * get_dtype_size(self.dtype)
+ return (hs_k + hs_v) * get_dtype_size(self.dtype)
+
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
@@ -318,6 +463,12 @@ class TQFullAttentionSpec(FullAttentionSpec):
tq_slot_size: int = 0
+ @property
+ def state_content_size_bytes(self) -> int:
+ if self.tq_slot_size > 0:
+ return self.tq_slot_size
+ return super().state_content_size_bytes
+
@property
def real_page_size_bytes(self) -> int:
if self.tq_slot_size > 0:
@@ -339,16 +490,22 @@ class MLAAttentionSpec(FullAttentionSpec):
cache_dtype_str: str | None = None
# DeepseekV4 only fields. Non-DeepseekV4 MLA models leave these at defaults.
alignment: int | None = None # Default to None for no padding.
- compress_ratio: int = 1 # Default to 1 for no compression.
+ tokens_per_state: int = 1
model_version: str | None = None
def __post_init__(self):
super().__post_init__()
_apply_alignment_padding(self)
+ @property
+ def state_content_size_bytes(self) -> int:
+ if self.cache_dtype_str == "fp8_ds_mla":
+ return 584 if self.model_version == "deepseek_v4" else 656
+ return self.head_size * get_dtype_size(self.dtype)
+
@property
def storage_block_size(self) -> int:
- return self.block_size // self.compress_ratio
+ return self.block_size // self.tokens_per_state
@property
def real_page_size_bytes(self) -> int:
@@ -373,11 +530,11 @@ def merge(cls, specs: list[Self]) -> Self:
"All attention layers in the same KV cache group must be MLAAttentionSpec."
)
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
- compress_ratio_set = set(spec.compress_ratio for spec in specs)
+ tokens_per_state_set = set(spec.tokens_per_state for spec in specs)
model_version_set = set(spec.model_version for spec in specs)
assert (
len(cache_dtype_str_set) == 1
- and len(compress_ratio_set) == 1
+ and len(tokens_per_state_set) == 1
and len(model_version_set) == 1
), (
"All attention layers in the same KV cache group must use the same "
@@ -391,7 +548,7 @@ def merge(cls, specs: list[Self]) -> Self:
kv_quant_mode=specs[0].kv_quant_mode,
page_size_padded=specs[0].page_size_padded,
cache_dtype_str=cache_dtype_str_set.pop(),
- compress_ratio=compress_ratio_set.pop(),
+ tokens_per_state=tokens_per_state_set.pop(),
model_version=model_version_set.pop(),
)
@@ -501,15 +658,22 @@ class SlidingWindowMLASpec(SlidingWindowSpec):
cache_dtype_str: str | None = None
# DeepseekV4-only: see MLAAttentionSpec.model_version.
alignment: int | None = None # Default to None for no padding.
- compress_ratio: int = 1
+ tokens_per_state: int = 1
model_version: str | None = None
def __post_init__(self):
+ super().__post_init__()
_apply_alignment_padding(self)
+ @property
+ def state_content_size_bytes(self) -> int:
+ if self.model_version == "deepseek_v4":
+ return 584
+ return self.head_size * get_dtype_size(self.dtype)
+
@property
def storage_block_size(self) -> int:
- return self.block_size // self.compress_ratio
+ return self.block_size // self.tokens_per_state
@property
def real_page_size_bytes(self) -> int:
@@ -533,12 +697,12 @@ def merge(cls, specs: list[Self]) -> Self:
"SlidingWindowMLASpec."
)
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
- compress_ratio_set = set(spec.compress_ratio for spec in specs)
+ tokens_per_state_set = set(spec.tokens_per_state for spec in specs)
model_version_set = set(spec.model_version for spec in specs)
sliding_window_set = set(spec.sliding_window for spec in specs)
assert (
len(cache_dtype_str_set) == 1
- and len(compress_ratio_set) == 1
+ and len(tokens_per_state_set) == 1
and len(model_version_set) == 1
and len(sliding_window_set) == 1
), (
@@ -554,7 +718,7 @@ def merge(cls, specs: list[Self]) -> Self:
page_size_padded=specs[0].page_size_padded,
sliding_window=sliding_window_set.pop(),
cache_dtype_str=cache_dtype_str_set.pop(),
- compress_ratio=compress_ratio_set.pop(),
+ tokens_per_state=tokens_per_state_set.pop(),
model_version=model_version_set.pop(),
)
@@ -567,6 +731,15 @@ class MambaSpec(KVCacheSpec):
mamba_type: MambaAttentionBackendEnum = MambaAttentionBackendEnum.MAMBA2
mamba_cache_mode: str = "none"
num_speculative_blocks: int = 0
+ num_heads: int = 1
+ tokens_per_state: int = -1
+
+ @property
+ def state_content_size_bytes(self) -> int:
+ return sum(
+ prod(shape) * get_dtype_size(dtype)
+ for (shape, dtype) in zip(self.shapes, self.dtypes)
+ )
@property
def page_size_bytes(self) -> int:
@@ -811,12 +984,15 @@ def get_kv_cache_spec_sliding_window(kv_cache_spec: KVCacheSpec) -> int | None:
@dataclass
class KVCacheTensor:
- """
- A class for specifying how the workers should initialize the KV cache.
+ """One contiguous GPU allocation backing one or more layer slots.
+
+ ``shared_by[slot_idx]`` lists the layer names aliasing slot ``slot_idx``.
+ Layers in the same inner list belong to different groups (independent
+ block tables) so their block-id namespaces never collide.
"""
- size: int # size of the KV cache tensor in bytes
- shared_by: list[str] # layer names that share the same KV cache tensor
+ size: int # total size in bytes
+ shared_by: list[list[str]] # shared_by[slot_idx] = [layer_names]
@dataclass
diff --git a/vllm/v1/kv_offload/base.py b/vllm/v1/kv_offload/base.py
index 5f798f41eac8..00ec1fc55757 100644
--- a/vllm/v1/kv_offload/base.py
+++ b/vllm/v1/kv_offload/base.py
@@ -330,10 +330,8 @@ class CanonicalKVCacheTensor:
"""
A canonicalized KV cache tensor whose first dimension is num_blocks.
- For attention backends where the raw tensor has num_blocks at a
- non-leading physical dimension (e.g. FlashAttention's
- (2, num_blocks, ...) layout), the tensor is split so that each
- resulting CanonicalKVCacheTensor starts with (num_blocks, ...).
+ With standardized layouts (RFC #42082) num_blocks is always the
+ leading logical dimension.
"""
# The KV cache tensor with shape (num_blocks, ...)
diff --git a/vllm/v1/simple_kv_offload/manager.py b/vllm/v1/simple_kv_offload/manager.py
index f61c4320dffd..bf0dffa7346b 100644
--- a/vllm/v1/simple_kv_offload/manager.py
+++ b/vllm/v1/simple_kv_offload/manager.py
@@ -192,7 +192,7 @@ def _derive_cpu_config(
cpu_tensors = [
KVCacheTensor(
size=t.size // num_gpu_blocks * num_cpu_blocks,
- shared_by=list(t.shared_by),
+ shared_by=[list(slot) for slot in t.shared_by],
)
for t in gpu_config.kv_cache_tensors
]
diff --git a/vllm/v1/simple_kv_offload/worker.py b/vllm/v1/simple_kv_offload/worker.py
index c23b44f29173..09559be6cb9e 100644
--- a/vllm/v1/simple_kv_offload/worker.py
+++ b/vllm/v1/simple_kv_offload/worker.py
@@ -80,15 +80,8 @@ def register_kv_caches(
logger.warning("No KV caches to offload.")
return
- # Resolve each entry to a representative tensor for storage
- # deduplication. For attention layers the value is already a tensor;
- # for Mamba layers it is a list of tensors that all share the same
- # underlying raw storage, so we take the first one.
- def _repr_tensor(v: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
- assert isinstance(v, torch.Tensor | list)
- return v if isinstance(v, torch.Tensor) else v[0]
-
- any_tensor = _repr_tensor(next(iter(kv_caches.values())))
+ any_tensor = next(iter(kv_caches.values()))
+ assert isinstance(any_tensor, torch.Tensor)
self.device = any_tensor.device
assert self.kv_cache_config is not None
@@ -97,7 +90,8 @@ def _repr_tensor(v: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
# Deduplicate: multiple layers may share the same backing storage.
seen_ptrs: dict[int, tuple[str, torch.Tensor]] = {}
for name, value in kv_caches.items():
- tensor = _repr_tensor(value)
+ assert isinstance(value, torch.Tensor)
+ tensor = value
ptr = tensor.untyped_storage().data_ptr()
if ptr not in seen_ptrs:
seen_ptrs[ptr] = (name, tensor)
@@ -105,13 +99,11 @@ def _repr_tensor(v: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
# Build [num_blocks, block_bytes] int8 views from each unique
# storage so that stride(0) gives block_bytes for the copy op.
#
- # The physical layout varies across attention backends:
- # FlashAttn/ROCm: (2, num_blocks, ...) -> K/V outermost, 2 segments
- # FlashInfer/MLA: (num_blocks, ...) -> blocks outermost, 1 segment
- # We derive page_size_bytes = storage.nbytes() // num_blocks, then
- # classify dims: any dim whose byte-stride exceeds page_size_bytes
- # must be an outer segment dim (e.g. the K/V dim of size 2). A less
- # hacky way is to update the interface with the layout.
+ # With standardized layouts (RFC #42082) num_blocks is always the
+ # leading logical dim, but cross-layer physical layouts
+ # (e.g. BLHNC) interleave layers between blocks, inflating the
+ # block stride. Detect any such outer segment dims by comparing
+ # byte-strides against page_size_bytes.
unique_gpu_caches: dict[str, torch.Tensor] = {}
for name, tensor in seen_ptrs.values():
storage = tensor.untyped_storage()
diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py
index 6afffa424d42..f0b6529d3f84 100644
--- a/vllm/v1/worker/cpu_model_runner.py
+++ b/vllm/v1/worker/cpu_model_runner.py
@@ -12,7 +12,7 @@
from vllm.model_executor.model_loader import get_model
from vllm.tracing import instrument
from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheLayout
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@@ -124,6 +124,40 @@ def warming_up_model(self) -> None:
self.profile_run()
logger.info("Warming up done.")
+ def _allocate_and_reshape_kv_cache(
+ self,
+ kv_cache_config: KVCacheConfig,
+ kernel_block_sizes: list[int],
+ layout: KVCacheLayout = KVCacheLayout.LBHNC,
+ ) -> dict[str, torch.Tensor]:
+ # Re-view the interleaved K|V content dim as separate head groups so
+ # the backend can slice contiguous per-head K/V:
+ # [num_blocks, num_kv_heads, block_size, 2*head_size] ->
+ # [num_blocks, 2*num_kv_heads, block_size, head_size]. Assumes
+ # content == 2*head_size (true for all CPU attention specs; MLA, which
+ # differs, isn't supported).
+ kv_caches = super()._allocate_and_reshape_kv_cache(
+ kv_cache_config,
+ kernel_block_sizes,
+ layout=KVCacheLayout.LBHNC,
+ )
+ attn_layers = {
+ name
+ for group in self._kv_cache_spec_attn_group_iterator()
+ if isinstance(group.kv_cache_spec, AttentionSpec)
+ for name in group.layer_names
+ }
+ for name in attn_layers:
+ cache = kv_caches.get(name)
+ if cache is None:
+ continue
+ num_blocks, num_kv_heads, block_size, content = cache.shape
+ assert cache.is_contiguous() and content % 2 == 0
+ kv_caches[name] = cache.view(
+ num_blocks, 2 * num_kv_heads, block_size, content // 2
+ )
+ return kv_caches
+
def initialize_kv_cache(
self,
kv_cache_config: KVCacheConfig,
diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index 6fc55ee32030..9ccad49397ae 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable, Sequence
+from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, cast
@@ -9,17 +9,19 @@
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.attention.backend import (
+ AttentionBackend,
AttentionCGSupport,
CommonAttentionMetadata,
)
+from vllm.v1.attention.backends.utils import resolve_kv_cache_layout
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheConfig,
+ KVCacheLayout,
KVCacheSpec,
- MambaSpec,
UniformTypeKVCacheSpecs,
+ reshape_kv_cache,
)
from vllm.v1.worker.gpu.model_states.interface import ModelSpecificAttnMetadata
from vllm.v1.worker.utils import (
@@ -147,199 +149,70 @@ def init_attn_backend(
return attn_groups, attn_cg_support_info, kernel_block_sizes
-def _allocate_kv_cache(
- kv_cache_config: KVCacheConfig, shared_layers: dict[str, str], device: torch.device
-):
- kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
- for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
- tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
- for layer_name in kv_cache_tensor.shared_by:
- kv_cache_raw_tensors[layer_name] = tensor
+def _allocate_and_reshape_kv_cache(
+ kv_cache_config: KVCacheConfig,
+ device: torch.device,
+ layout: KVCacheLayout | None = None,
+ kernel_block_sizes: list[int] | None = None,
+ attn_backends: tuple[type[AttentionBackend], ...] | None = None,
+) -> dict[str, Any]:
+ if layout is None:
+ layout = resolve_kv_cache_layout(
+ tuple(attn_backends) if attn_backends else None
+ )
- layer_names = set()
- for group in kv_cache_config.kv_cache_groups:
+ layer_to_group: dict[str, tuple[KVCacheSpec, int]] = {}
+ for group_id, group in enumerate(kv_cache_config.kv_cache_groups):
+ spec = group.kv_cache_spec
for layer_name in group.layer_names:
- layer_names.add(layer_name)
- assert layer_names == (kv_cache_raw_tensors.keys() | shared_layers.keys()), (
- "Some layers are not correctly initialized"
- )
- return kv_cache_raw_tensors
-
+ if isinstance(spec, UniformTypeKVCacheSpecs):
+ layer_to_group[layer_name] = (spec.kv_cache_specs[layer_name], group_id)
+ else:
+ layer_to_group[layer_name] = (spec, group_id)
-def _reshape_kv_cache(
- attn_groups: Sequence[AttentionGroup],
- kv_cache_raw_tensors: dict[str, torch.Tensor],
- cache_dtype: str,
- kernel_block_sizes: list[int],
- shared_kv_cache_layers: dict[str, str],
-) -> dict[str, Any]:
kv_caches: dict[str, Any] = {}
- has_attn, has_mamba = False, False
-
- for group in attn_groups:
- if group.kv_cache_group_id >= len(kernel_block_sizes):
- continue
-
- kv_cache_spec = group.kv_cache_spec
- if kv_cache_spec.storage_block_size != kv_cache_spec.block_size:
- # use storage_block_size as the kernel block size for groups
- # that apply a compression on block size (eg. DeepSeek V4).
- kernel_block_size = kv_cache_spec.storage_block_size
- else:
- kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
-
- for layer_name in group.layer_names:
- if layer_name in shared_kv_cache_layers:
- # Shared layer — tensor will be aliased to its target later.
+ for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+ num_layer_slots = len(kv_cache_tensor.shared_by)
+ buf = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
+
+ layer_to_slot: dict[str, int] = {}
+ for slot_idx, slot_layers in enumerate(kv_cache_tensor.shared_by):
+ for layer_name in slot_layers:
+ layer_to_slot[layer_name] = slot_idx
+
+ tensor_layers = set(layer_to_slot)
+ slot_bytes = kv_cache_tensor.size // num_layer_slots
+ for group_id, group in enumerate(kv_cache_config.kv_cache_groups):
+ layer_names = [n for n in group.layer_names if n in tensor_layers]
+ if not layer_names:
continue
-
- kv_raw_tensor = kv_cache_raw_tensors[layer_name]
- assert kv_raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
- num_blocks = kv_raw_tensor.numel() // kv_cache_spec.page_size_bytes
-
- if isinstance(kv_cache_spec, AttentionSpec):
- has_attn = True
- # Use storage_block_size: it equals block_size for uncompressed
- # specs but is smaller for compressed ones (DeepSeek V4), which
- # store block_size tokens in block_size // compress_ratio slots.
- num_blocks_per_kv_block = (
- kv_cache_spec.storage_block_size // kernel_block_size
- )
- kernel_num_blocks = num_blocks * num_blocks_per_kv_block
- kv_cache_shape = group.backend.get_kv_cache_shape(
- kernel_num_blocks,
- kernel_block_size,
- kv_cache_spec.num_kv_heads,
- kv_cache_spec.head_size,
- cache_dtype_str=cache_dtype,
- )
-
- # FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
- try:
- kv_cache_stride_order = group.backend.get_kv_cache_stride_order()
- assert len(kv_cache_stride_order) == len(kv_cache_shape)
- except (AttributeError, NotImplementedError):
- kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
-
- kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
- inv_order = [
- kv_cache_stride_order.index(i)
- for i in range(len(kv_cache_stride_order))
- ]
-
- dtype = kv_cache_spec.dtype
- kv_tensor = kv_raw_tensor.view(dtype)
- if kv_cache_spec.page_size_padded is not None:
- # Use strided view to handle page_size_bytes that
- # include padding. This follows the same pattern as
- # MambaSpec handling in gpu_model_runner.py.
- # NOTE: This assumes kv_cache_shape[0] == num_blocks
- # (i.e. the first physical dimension is the block
- # index), which holds for all current backends
- # (MLA, FlashAttention, TritonAttention, etc.).
- dtype_size = get_dtype_size(dtype)
- page_stride = kv_cache_spec.page_size_bytes // dtype_size
- strides = list(torch.empty(kv_cache_shape).stride())
- strides[inv_order[0]] = page_stride
- kv_cache = torch.as_strided(
- kv_tensor,
- size=kv_cache_shape,
- stride=tuple(strides),
- )
- else:
- # No padding — safe to use a contiguous view.
- kv_cache = kv_tensor.view(kv_cache_shape)
- kv_caches[layer_name] = kv_cache.permute(*inv_order)
-
- elif isinstance(kv_cache_spec, MambaSpec):
- has_mamba = True
- state_tensors = []
- storage_offset_bytes = 0
- for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
- dtype_size = get_dtype_size(dtype)
- num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size
- target_shape = (num_blocks, *shape)
- stride = torch.empty(target_shape).stride()
- target_stride = (num_element_per_page, *stride[1:])
- assert storage_offset_bytes % dtype_size == 0
- tensor = torch.as_strided(
- kv_raw_tensor.view(dtype),
- size=target_shape,
- stride=target_stride,
- storage_offset=storage_offset_bytes // dtype_size,
- )
- state_tensors.append(tensor)
- storage_offset_bytes += stride[0] * dtype_size
- kv_caches[layer_name] = state_tensors
+ spec, _ = layer_to_group[layer_names[0]]
+ num_blocks = slot_bytes // spec.page_size_bytes
+
+ if (
+ kernel_block_sizes is not None
+ and isinstance(spec, AttentionSpec)
+ and group_id < len(kernel_block_sizes)
+ ):
+ kernel_block_size = kernel_block_sizes[group_id]
+ num_blocks = num_blocks * spec.storage_block_size // kernel_block_size
else:
- raise NotImplementedError(
- f"Unsupported KV cache spec type: {type(kv_cache_spec)}"
- )
-
- if has_attn and has_mamba:
- _update_hybrid_attention_layout(
- attn_groups=attn_groups,
- kv_caches=kv_caches,
- kernel_block_sizes=kernel_block_sizes,
- cache_dtype=cache_dtype,
- )
-
- # Map any sharing layers to their target layer's KV cache.
- for layer_name, target_layer_name in shared_kv_cache_layers.items():
- kv_caches[layer_name] = kv_caches[target_layer_name]
+ kernel_block_size = spec.storage_block_size
+
+ views = reshape_kv_cache(
+ buf,
+ spec,
+ num_blocks,
+ num_layer_slots=num_layer_slots,
+ layout=layout,
+ block_size=kernel_block_size,
+ )
+ for layer_name in layer_names:
+ kv_caches[layer_name] = views[layer_to_slot[layer_name]]
return kv_caches
-def _update_hybrid_attention_layout(
- attn_groups: Iterable[AttentionGroup],
- kv_caches: dict[str, Any],
- kernel_block_sizes: list[int],
- cache_dtype: str,
-) -> None:
- for group in attn_groups:
- if group.kv_cache_group_id >= len(kernel_block_sizes):
- continue
-
- kv_cache_spec = group.kv_cache_spec
- if not isinstance(kv_cache_spec, AttentionSpec):
- continue
- block_dim = group.backend.get_kv_cache_block_dim(
- kernel_block_sizes[group.kv_cache_group_id],
- kv_cache_spec.num_kv_heads,
- kv_cache_spec.head_size,
- cache_dtype_str=cache_dtype,
- )
- # if the first dim of the kvcache's layout is already num_blocks, continue
- if block_dim == 0:
- continue
-
- assert block_dim == 1, (
- "Expected the dim `num_blocks` at the second dim when updating"
- " the kvcache's layout of full attention layer"
- )
-
- for layer_name in group.layer_names:
- if layer_name not in kv_caches:
- # Shared layer — will be aliased to its target after this pass.
- continue
-
- kv_cache = kv_caches[layer_name]
- if kv_cache.shape[0] == 2:
- assert kv_cache.shape[1] != 2, (
- f"Cannot determine layout for tensor of shape {kv_cache.shape}"
- )
- hidden_size = kv_cache.shape[2:].numel()
- kv_cache.as_strided_(
- size=kv_cache.shape,
- stride=(
- hidden_size,
- 2 * hidden_size,
- *kv_cache.stride()[2:],
- ),
- )
-
-
def init_kv_cache(
runner_kv_caches: list[torch.Tensor | list[torch.Tensor]],
forward_context: dict[str, Any],
@@ -347,20 +220,17 @@ def init_kv_cache(
attn_groups: list[list[AttentionGroup]],
device: torch.device,
cache_dtype: str,
- kernel_block_sizes: list[int],
- vllm_config: VllmConfig,
+ kernel_block_sizes: list[int] | None = None,
+ layout: KVCacheLayout | None = None,
) -> dict[str, Any]:
- shared_kv_cache_layers = get_shared_kv_cache_layers(vllm_config)
- kv_cache_raw_tensors = _allocate_kv_cache(
- kv_cache_config, shared_kv_cache_layers, device
- )
- flattened_attn_groups = list(group for groups in attn_groups for group in groups)
- kv_caches = _reshape_kv_cache(
- attn_groups=flattened_attn_groups,
- kv_cache_raw_tensors=kv_cache_raw_tensors,
+ kv_caches = _allocate_and_reshape_kv_cache(
+ kv_cache_config,
+ device,
+ layout=layout,
kernel_block_sizes=kernel_block_sizes,
- cache_dtype=cache_dtype,
- shared_kv_cache_layers=shared_kv_cache_layers,
+ attn_backends=tuple(
+ group.backend for groups in attn_groups for group in groups
+ ),
)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
return kv_caches
diff --git a/vllm/v1/worker/gpu/kv_connector.py b/vllm/v1/worker/gpu/kv_connector.py
index cdacb36e5833..dcff3c3ddf2d 100644
--- a/vllm/v1/worker/gpu/kv_connector.py
+++ b/vllm/v1/worker/gpu/kv_connector.py
@@ -50,9 +50,7 @@ def __init__(
):
self.vllm_config = vllm_config
self.kv_connector = get_kv_transfer_group()
- # Register kv caches with KV Connector if applicable.
- # TODO: support cross_layers_kv_cache
- # (see https://github.com/vllm-project/vllm/pull/27743)
+ # Register kv caches with KV Connector.
self.kv_connector.register_kv_caches(kv_caches_dict)
self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks)
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 2e3133822fd4..cd15c7ed5cbe 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -472,7 +472,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
self.device,
self.cache_config.cache_dtype,
self.kernel_block_sizes,
- self.vllm_config,
)
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index dbbeafcbd40b..9b19827e1173 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -117,7 +117,6 @@
from vllm.utils.nvtx_pytorch_hooks import PytHooks
from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units
from vllm.utils.torch_utils import (
- get_dtype_size,
is_quantized_kv_cache,
kv_cache_dtype_str_to_dtype,
)
@@ -136,6 +135,7 @@
create_fast_prefill_custom_backend,
get_dcp_local_seq_lens,
reorder_batch_to_split_decodes_and_prefills,
+ resolve_kv_cache_layout,
)
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
@@ -147,10 +147,12 @@
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
+ KVCacheLayout,
KVCacheSpec,
MambaSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
+ reshape_kv_cache,
)
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
@@ -519,9 +521,6 @@ def __init__(
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
self.kv_caches: list[torch.Tensor] = []
- # Initialize in initialize_kv_cache_tensors
- self.cross_layers_kv_cache: torch.Tensor | None = None
- self.cross_layers_attn_backend: type[AttentionBackend] | None = None
# indexes: [kv_cache_group_id][attn_group]
self.attn_groups: list[list[AttentionGroup]] = []
# self.kv_cache_config: KVCacheConfig
@@ -4590,17 +4589,16 @@ def propose_draft_token_ids(sampled_token_ids):
with record_function_or_nullcontext(
"gpu_model_runner: AsyncGPUModelRunnerOutput"
):
- # Async path: produce a device-side snapshot that the async
- # copy stream can D2H later. Both tensors must be private
- # clones because:
+ # Async path: produce a device-side snapshot that the async copy
+ # stream can D2H later. Both tensors must be private clones
+ # because:
# - ``routing_data`` source is the shared capturer buffer,
- # which is ``clear_buffer()``-ed at the start of the
- # next step on the default stream.
+ # which is ``clear_buffer()``-ed at the start of the next
+ # step on the default stream.
# - ``slot_mapping`` source is our own
- # ``routed_experts_slot_mapping_device``, which the
- # next ``_prepare_inputs`` overwrites on the default
- # stream while the D2H is still pending on the copy
- # stream.
+ # ``routed_experts_slot_mapping_device``, which the next
+ # ``_prepare_inputs`` overwrites on the default stream
+ # while the D2H is still pending on the copy stream.
# Without clones, the copy stream would read torn data.
routed_experts_snapshot = None
if self.routed_experts_initialized:
@@ -6295,9 +6293,6 @@ def _cleanup_profiling_kv_cache(self) -> None:
for i in range(len(self.kv_caches)):
self.kv_caches[i] = None # type: ignore
self.kv_caches.clear()
- if hasattr(self, "cross_layers_kv_cache"):
- self.cross_layers_kv_cache = None
- self.cross_layers_attn_backend = None
if hasattr(self, "attn_groups"):
self.attn_groups.clear()
if hasattr(self, "kv_cache_config"):
@@ -6717,10 +6712,9 @@ def get_attn_backends_for_group(
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
# Non-Attention layer types (e.g. Mamba1, ShortConv) do not
# expose ``num_heads``; fall back to 0 so they cluster as
- # before. Such layers never coexist with Attention in a
- # single KV cache group (different KVCacheSpec), so the
- # fallback can never spuriously merge them with attention
- # layers.
+ # before. Such layers never coexist with Attention in a single
+ # KV cache group (different KVCacheSpec), so the fallback can
+ # never spuriously merge them with attention layers.
num_heads_q = getattr(layers[layer_name], "num_heads", 0)
key = (full_cls_name, layer_kv_cache_spec, num_heads_q)
attn_backends[key] = AttentionGroupKey(
@@ -6983,38 +6977,6 @@ def may_reinitialize_input_batch(
f"!= kv_cache kernel_block_sizes {kernel_block_sizes}"
)
- def _allocate_kv_cache_tensors(
- self, kv_cache_config: KVCacheConfig
- ) -> dict[str, torch.Tensor]:
- """
- Initializes the KV cache buffer with the correct size. The buffer needs
- to be reshaped to the desired shape before being used by the models.
-
- Args:
- kv_cache_config: The KV cache config
- Returns:
- dict[str, torch.Tensor]: A map between layer names to their
- corresponding memory buffer for KV cache.
- """
- kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
- for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
- tensor = torch.zeros(
- kv_cache_tensor.size, dtype=torch.int8, device=self.device
- )
- for layer_name in kv_cache_tensor.shared_by:
- kv_cache_raw_tensors[layer_name] = tensor
-
- layer_names = set()
- for group in kv_cache_config.kv_cache_groups:
- for layer_name in group.layer_names:
- if layer_name in self.runner_only_attn_layers:
- continue
- layer_names.add(layer_name)
- assert layer_names == set(kv_cache_raw_tensors.keys()), (
- "Some layers are not correctly initialized"
- )
- return kv_cache_raw_tensors
-
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
return itertools.chain.from_iterable(self.attn_groups)
@@ -7024,166 +6986,59 @@ def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
for attn_groups in self.attn_groups:
yield from attn_groups
- def _reshape_kv_cache_tensors(
+ def _allocate_and_reshape_kv_cache(
self,
- kv_cache_raw_tensors: dict[str, torch.Tensor],
+ kv_cache_config: KVCacheConfig,
kernel_block_sizes: list[int],
+ layout: KVCacheLayout,
) -> dict[str, torch.Tensor]:
- """
- Reshape the KV cache tensors to the desired shape and dtype.
+ """Allocate backing tensors and reshape into per-layer [B,H,N,C] views.
- Args:
- kv_cache_raw_tensors: The KV cache buffer of each layer, with
- correct size but uninitialized shape.
- kernel_block_sizes: The kernel block sizes for each KV cache group.
- Returns:
- Dict[str, torch.Tensor]: A map between layer names to their
- corresponding memory buffer for KV cache.
+ Returns: a dict mapping layer-name -> view of the backing tensor.
"""
kv_caches: dict[str, torch.Tensor] = {}
- has_attn, has_mamba = False, False
- for group in self._kv_cache_spec_attn_group_iterator():
- kv_cache_spec = group.kv_cache_spec
- attn_backend = group.backend
- if group.kv_cache_group_id == len(kernel_block_sizes):
- # There may be a last group for layers without kv cache.
- continue
- kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
- for layer_name in group.layer_names:
- if layer_name in self.runner_only_attn_layers:
- continue
- raw_tensor = kv_cache_raw_tensors[layer_name]
- assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
- num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
- if isinstance(kv_cache_spec, AttentionSpec):
- has_attn = True
- num_blocks_per_kv_block = (
- kv_cache_spec.block_size // kernel_block_size
- )
- kernel_num_blocks = num_blocks * num_blocks_per_kv_block
-
- # For MLA with compression, storage_block_size != block_size
- if kv_cache_spec.storage_block_size != kv_cache_spec.block_size:
- shape_block_size = kv_cache_spec.storage_block_size
- else:
- shape_block_size = kernel_block_size
-
- kv_cache_shape = attn_backend.get_kv_cache_shape(
- kernel_num_blocks,
- shape_block_size,
- kv_cache_spec.num_kv_heads,
- kv_cache_spec.head_size,
- cache_dtype_str=self.cache_config.cache_dtype,
- )
- dtype = kv_cache_spec.dtype
- try:
- kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
- assert len(kv_cache_stride_order) == len(kv_cache_shape)
- except (AttributeError, NotImplementedError):
- kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
- # The allocation respects the backend-defined stride order
- # to ensure the semantic remains consistent for each
- # backend. We first obtain the generic kv cache shape and
- # then permute it according to the stride order which could
- # result in a non-contiguous tensor.
- kv_cache_shape = tuple(
- kv_cache_shape[i] for i in kv_cache_stride_order
- )
- # Maintain original KV shape view.
- inv_order = [
- kv_cache_stride_order.index(i)
- for i in range(len(kv_cache_stride_order))
- ]
- raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
- if kv_cache_spec.page_size_padded is not None:
- # Use strided view to handle page_size_bytes that
- # include padding. This follows
- # the same pattern as MambaSpec handling below.
- # NOTE: This assumes kv_cache_shape[0] == num_blocks
- # (i.e. the first physical dimension is the block
- # index), which holds for MLA backends but NOT for
- # standard attention backends whose shape starts with
- # a K/V dimension of size 2.
- dtype_size = get_dtype_size(dtype)
- page_stride = kv_cache_spec.page_size_bytes // dtype_size
- strides = list(torch.empty(kv_cache_shape).stride())
- strides[inv_order[0]] = page_stride
- kv_cache = torch.as_strided(
- raw_tensor,
- size=kv_cache_shape,
- stride=tuple(strides),
- )
- else:
- # No padding — safe to use a contiguous view.
- kv_cache = raw_tensor.view(kv_cache_shape)
- kv_caches[layer_name] = kv_cache.permute(*inv_order)
-
- elif isinstance(kv_cache_spec, MambaSpec):
- has_mamba = True
- raw_tensor = kv_cache_raw_tensors[layer_name]
- state_tensors = []
- storage_offset_bytes = 0
- for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
- dtype_size = get_dtype_size(dtype)
- num_element_per_page = (
- kv_cache_spec.page_size_bytes // dtype_size
- )
- target_shape = (num_blocks, *shape)
- stride = torch.empty(target_shape).stride()
- target_stride = (num_element_per_page, *stride[1:])
- assert storage_offset_bytes % dtype_size == 0
- tensor = torch.as_strided(
- raw_tensor.view(dtype),
- size=target_shape,
- stride=target_stride,
- storage_offset=storage_offset_bytes // dtype_size,
- )
- state_tensors.append(tensor)
- storage_offset_bytes += stride[0] * dtype_size
-
- kv_caches[layer_name] = state_tensors
- else:
- raise NotImplementedError
-
- if has_attn and has_mamba:
- self._update_hybrid_attention_mamba_layout(kv_caches, kernel_block_sizes)
+ for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+ buf = torch.zeros(
+ kv_cache_tensor.size, dtype=torch.int8, device=self.device
+ )
- return kv_caches
+ layer_to_slot_idx: dict[str, int] = {}
- def _update_hybrid_attention_mamba_layout(
- self, kv_caches: dict[str, torch.Tensor], kernel_block_sizes: list[int]
- ) -> None:
- """
- Update the layout of attention layers from (2, num_blocks, ...) to
- (num_blocks, 2, ...).
+ for slot_idx, slot_layers in enumerate(kv_cache_tensor.shared_by):
+ for layer_name in slot_layers:
+ layer_to_slot_idx[layer_name] = slot_idx
- Args:
- kv_caches: The KV cache buffer of each layer.
- kernel_block_sizes: The kernel block sizes for each KV cache group.
- """
+ num_slots = len(kv_cache_tensor.shared_by)
+ bytes_per_slot = kv_cache_tensor.size // num_slots
- for group in self._kv_cache_spec_attn_group_iterator():
- kv_cache_spec = group.kv_cache_spec
- if not isinstance(kv_cache_spec, AttentionSpec):
- continue
- block_dim = group.backend.get_kv_cache_block_dim(
- kernel_block_sizes[group.kv_cache_group_id],
- kv_cache_spec.num_kv_heads,
- kv_cache_spec.head_size,
- cache_dtype_str=self.cache_config.cache_dtype,
- )
- # block_dim: 0 means (num_blocks, 2, ...); 1 means (2, num_blocks, ...).
- if block_dim == 0:
- continue
- assert block_dim == 1
- for layer_name in group.layer_names:
- kv_cache = kv_caches[layer_name]
- hidden_size = kv_cache.shape[2:].numel()
- kv_cache.as_strided_(
- size=kv_cache.shape,
- stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]),
+ layers_shared_by_tensor = set(layer_to_slot_idx.keys())
+ for g in self._kv_cache_spec_attn_group_iterator():
+ if g.kv_cache_group_id >= len(kernel_block_sizes):
+ continue
+ layer_names = [n for n in g.layer_names if n in layers_shared_by_tensor]
+ if not layer_names:
+ continue
+ spec = g.kv_cache_spec
+ kernel_block_size = kernel_block_sizes[g.kv_cache_group_id]
+
+ num_blocks = bytes_per_slot // spec.page_size_bytes
+ num_kernel_blocks_per_block = spec.block_size // kernel_block_size
+ kernel_num_blocks = num_blocks * num_kernel_blocks_per_block
+
+ multi_slot_kv_cache = reshape_kv_cache(
+ buf,
+ spec,
+ kernel_num_blocks,
+ num_layer_slots=num_slots,
+ layout=layout,
+ block_size=kernel_block_size,
)
+ for layer_name in layer_names:
+ layer_view = multi_slot_kv_cache[layer_to_slot_idx[layer_name]]
+ kv_caches[layer_name] = layer_view
+
+ return kv_caches
def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
@@ -7200,29 +7055,12 @@ def initialize_kv_cache_tensors(
corresponding memory buffer for KV cache.
"""
- # Try creating KV caches optimized for kv-connector transfers
- cache_dtype = self.cache_config.cache_dtype
- if self.use_uniform_kv_cache(self.attn_groups, cache_dtype):
- kv_caches, cross_layers_kv_cache, attn_backend = (
- self.allocate_uniform_kv_caches(
- kv_cache_config,
- self.attn_groups,
- cache_dtype,
- self.device,
- kernel_block_sizes,
- )
- )
- self.cross_layers_kv_cache = cross_layers_kv_cache
- self.cross_layers_attn_backend = attn_backend
- else:
- # Fallback to the general case
- # Initialize the memory buffer for KV cache
- kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
-
- # Change the memory buffer to the desired shape
- kv_caches = self._reshape_kv_cache_tensors(
- kv_cache_raw_tensors, kernel_block_sizes
- )
+ layout = resolve_kv_cache_layout(
+ tuple(g.backend for g in self._attn_group_iterator())
+ )
+ kv_caches = self._allocate_and_reshape_kv_cache(
+ kv_cache_config, kernel_block_sizes, layout=layout
+ )
# Set up cross-layer KV cache sharing
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
@@ -7318,13 +7156,7 @@ def initialize_kv_cache(
if has_kv_transfer_group() and not is_profiling:
kv_transfer_group = get_kv_transfer_group()
- if self.cross_layers_kv_cache is not None:
- assert self.cross_layers_attn_backend is not None
- kv_transfer_group.register_cross_layers_kv_cache(
- self.cross_layers_kv_cache, self.cross_layers_attn_backend
- )
- else:
- kv_transfer_group.register_kv_caches(kv_caches)
+ kv_transfer_group.register_kv_caches(kv_caches)
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
def _get_attention_kv_cache_gid(self) -> int:
diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py
index 797e59c02909..085d16771d46 100644
--- a/vllm/v1/worker/kv_connector_model_runner_mixin.py
+++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py
@@ -8,21 +8,15 @@
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import TYPE_CHECKING
-import torch
-
from vllm.config import VllmConfig
-from vllm.config.cache import CacheDType
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
-from vllm.v1.attention.backend import AttentionBackend
-from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
from vllm.v1.outputs import (
KVConnectorOutput,
ModelRunnerOutput,
)
-from vllm.v1.worker.utils import AttentionGroup
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -110,167 +104,3 @@ def _get_kv_connector_output(
if not defer_finalize:
kv_connector.clear_connector_metadata()
-
- @staticmethod
- def use_uniform_kv_cache(
- attn_groups: list[list[AttentionGroup]],
- cache_dtype: CacheDType,
- ) -> bool:
- """
- Determines whether a uniform KV layout should be used.
- A uniform layout means all layers KV caches will share the same
- underlying tensor, where for a given block number, the respective
- KV data for all layers will be contiguous.
- This will allow efficient KV transfer of per-block KV data for all
- layers at once.
- Note this layout will only be applied given 3 conditions:
- 1. The KV Cache config contains just a single group where all layers
- have the same page size.
- 2. A KV connector is configured, and the KV connector instance prefers
- to use this layout (prefer_cross_layer_blocks() returns True)
- 2. The flash attention backend supports this layout
- (get_kv_cache_stride_order(True) includes a placement for a
- num_layers dimension)
-
- Note that the actual placement of the num_layers dimensions
- in the unified layers tensors will be determined by the attention
- backend.
- Thus, the layers KV data may still not be contiguous per block
- if the attention backend does not support it.
-
- Args:
- attn_groups: The list of attention groups for this model
- cache_dtype: The KV cache dtype
- Returns:
- True if we should use a uniform KV cache layout.
- """
-
- if not has_kv_transfer_group():
- return False
- if not get_kv_transfer_group().prefer_cross_layer_blocks:
- return False
-
- if len(attn_groups) != 1 or len(attn_groups[0]) != 1:
- return False
-
- attn_group = attn_groups[0][0]
- kv_cache_spec = attn_group.kv_cache_spec
- if not isinstance(kv_cache_spec, AttentionSpec):
- return False
-
- attn_backend = attn_group.backend
- kv_cache_shape = attn_backend.get_kv_cache_shape(
- 1234,
- kv_cache_spec.block_size,
- kv_cache_spec.num_kv_heads,
- kv_cache_spec.head_size,
- cache_dtype_str=cache_dtype,
- )
-
- try:
- kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
- include_num_layers_dimension=True
- )
- except (AttributeError, NotImplementedError):
- return False
-
- # check that attention backend includes a layers dimension
- if len(kv_cache_stride_order) != len(kv_cache_shape) + 1:
- return False
-
- # stride_order[0] == 0 means num_layers stays first in physical
- # layout (identity permutation), so cross-layer is unsupported.
- return kv_cache_stride_order[0] != 0
-
- @staticmethod
- def allocate_uniform_kv_caches(
- kv_cache_config: KVCacheConfig,
- attn_groups: list[list[AttentionGroup]],
- cache_dtype: CacheDType,
- device: torch.device,
- kernel_block_sizes: list[int],
- ) -> tuple[dict[str, torch.Tensor], torch.Tensor, type[AttentionBackend]]:
- """
- Initializes and reshapes KV caches for the simple case where all
- layers have the same layout.
-
- This function assumes use_uniform_kv_cache() returned True.
-
- Args:
- kv_cache_config: The KV cache config
- attn_groups: The list of attention groups for this model
- cache_dtype: The KV cache dtype
- device: The torch device to allocate on.
- kernel_block_sizes: The kernel block sizes for each KV cache group.
- Returns:
- A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where:
- kv_caches is a dict mapping between layer names to their
- corresponding memory buffer for KV cache.
- cross_layers_kv_cache is the cross layers kv cache tensor
- attn_backend is the attention backend matching this tensor
- """
- attn_group = attn_groups[0][0]
- kv_cache_spec = attn_group.kv_cache_spec
- assert isinstance(kv_cache_spec, AttentionSpec)
-
- tensor_sizes = set(
- kv_cache_tensor.size for kv_cache_tensor in kv_cache_config.kv_cache_tensors
- )
- assert len(tensor_sizes) == 1
- tensor_size = tensor_sizes.pop()
-
- page_size = kv_cache_spec.page_size_bytes
- assert tensor_size % page_size == 0
- num_blocks = tensor_size // page_size
- num_layers = len(kv_cache_config.kv_cache_tensors)
- total_size = tensor_size * num_layers
-
- assert len(kernel_block_sizes) == 1
- kernel_block_size = kernel_block_sizes[0]
- num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
- kernel_num_blocks = num_blocks * num_blocks_per_kv_block
-
- attn_backend = attn_group.backend
- kv_cache_shape = attn_backend.get_kv_cache_shape(
- kernel_num_blocks,
- kernel_block_size,
- kv_cache_spec.num_kv_heads,
- kv_cache_spec.head_size,
- cache_dtype_str=cache_dtype,
- )
-
- # prepend a num_layers dimension into the shape
- kv_cache_shape = (num_layers,) + kv_cache_shape
-
- try:
- kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
- include_num_layers_dimension=True
- )
- assert len(kv_cache_stride_order) == len(kv_cache_shape)
- except (AttributeError, NotImplementedError):
- kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
-
- kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
-
- logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape)
-
- # allocate one contiguous buffer for all layers
- cross_layers_kv_cache = (
- torch.zeros(total_size, dtype=torch.int8, device=device)
- .view(kv_cache_spec.dtype)
- .view(kv_cache_shape)
- )
-
- # Maintain original KV shape view.
- inv_order = [
- kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
- ]
- permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order)
-
- kv_caches = {}
- for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
- tensor = permuted_kv_cache[i]
- for layer_name in kv_cache_tensor.shared_by:
- kv_caches[layer_name] = tensor
-
- return kv_caches, cross_layers_kv_cache, attn_backend
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index c0f44b6db0c3..abd4a7fcea8a 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -4,7 +4,6 @@
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass, field
-from itertools import product as iprod
from typing import Any
import torch
@@ -128,13 +127,6 @@ def __init__(
continue
kernel_bs = kernel_block_sizes[group.kv_cache_group_id]
ratio = spec.block_size // kernel_bs
- block_dim = group.backend.get_kv_cache_block_dim(
- kernel_bs,
- spec.num_kv_heads,
- spec.head_size,
- cache_dtype_str=cache_dtype,
- )
-
for layer_name in group.layer_names:
if layer_name in runner_only_attn_layers:
continue
@@ -147,7 +139,7 @@ def __init__(
seen_ptrs.add(dp)
el = kv.element_size()
- cur_bytes = kv.stride(block_dim) * el
+ cur_bytes = kv.stride(0) * el
assert cur_bytes % 4 == 0
kernel_block_el = cur_bytes // 4
cur_page_el = kernel_block_el * ratio
@@ -158,16 +150,7 @@ def __init__(
f"Non-uniform page sizes: {page_size_el} vs {cur_page_el}"
)
- block_stride_bytes = cur_bytes
- outer_dims = [
- d
- for d in range(block_dim)
- if kv.stride(d) * el > block_stride_bytes
- ]
- outer_strides = [kv.stride(d) * el for d in outer_dims]
- for outer in iprod(*(range(kv.shape[d]) for d in outer_dims)):
- off_bytes = sum(i * s for i, s in zip(outer, outer_strides))
- seg_addrs.append(dp + off_bytes)
+ seg_addrs.append(dp)
if not seg_addrs or page_size_el is None:
self._meta = None
@@ -515,7 +498,7 @@ def bind_kv_cache(
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
- forward_context[layer_name].kv_cache = kv_cache
+ forward_context[layer_name].bind_kv_cache(kv_cache)
def is_residual_scattered_for_sp(