Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
93 commits
Select commit Hold shift + click to select a range
c7d6b53
Standardize KV-cache layouts per RFC #42082
LucasWilkinson May 16, 2026
0022859
Fix pre-commit
MatthewBonanni May 18, 2026
380d4db
Remove _update_hybrid_attention_layout
MatthewBonanni May 18, 2026
1546c7b
Cache resolve_kv_cache_layout
MatthewBonanni May 18, 2026
10754b8
Fix tokenspeed mla
MatthewBonanni May 18, 2026
d995a9b
Add unit tests
MatthewBonanni May 19, 2026
2cc1951
Use LBHNC and LBNHC and introduce HNC and NHC aliases
MatthewBonanni May 19, 2026
cbb6dc2
Comment
MatthewBonanni May 19, 2026
484208a
Clean up
MatthewBonanni May 19, 2026
c6bb9f4
Merge branch 'main' into lwilkinson/kv-layout/core-standardize-and-re…
MatthewBonanni May 19, 2026
d3234d1
Fix pre-commit
MatthewBonanni May 19, 2026
4fd4e4a
Handle kernel_block_size < block_size
MatthewBonanni May 19, 2026
b6fea2c
Fix typo
MatthewBonanni May 19, 2026
dcf2a87
Fix CPU
MatthewBonanni May 19, 2026
640074e
Update test
MatthewBonanni May 20, 2026
c5d8a76
Cleanup
MatthewBonanni May 20, 2026
f72472b
Eliminate contiguous() in trtllm prefill pathway
MatthewBonanni May 20, 2026
9884008
Merge branch 'main' into lwilkinson/kv-layout/core-standardize-and-re…
MatthewBonanni May 20, 2026
446c7f6
Fix circular import
MatthewBonanni May 21, 2026
aa571aa
Fix turboquant
MatthewBonanni May 21, 2026
e018943
More indexing fixes
MatthewBonanni May 21, 2026
b4d53db
Fix CI failures
MatthewBonanni May 21, 2026
5fddcd4
Fix test
MatthewBonanni May 21, 2026
19f4f6d
Fix contiguous
MatthewBonanni May 21, 2026
43b9995
Merge branch 'main' into lwilkinson/kv-layout/core-standardize-and-re…
MatthewBonanni May 26, 2026
7da171e
Merge remote-tracking branch 'origin/main' into lwilkinson/kv-layout/…
LucasWilkinson May 27, 2026
b373c3d
Merge branch 'main' into lwilkinson/kv-layout/core-standardize-and-re…
MatthewBonanni May 27, 2026
1934fd6
Fix compressor
MatthewBonanni May 27, 2026
8fde376
Eliminate get_kv_cache_shape override
MatthewBonanni May 27, 2026
a6edcdb
Clean up
MatthewBonanni May 27, 2026
8e97274
Docstring
MatthewBonanni May 27, 2026
c91e0ce
Docstring
MatthewBonanni May 27, 2026
d5a7439
Deduplicate SinkFullAttentionSpec.merge
MatthewBonanni May 27, 2026
8851f7b
remove get_kv_cache_layout
MatthewBonanni May 27, 2026
e638a5e
Fix FA
MatthewBonanni May 27, 2026
80f6e82
Fix rocm
MatthewBonanni May 27, 2026
224be3f
Fix tests
MatthewBonanni May 27, 2026
0676cd8
Fix MLA
MatthewBonanni May 27, 2026
eee608a
Fix test
MatthewBonanni May 27, 2026
fd1b617
Fix mamba
MatthewBonanni May 27, 2026
f07cd79
Merge branch 'main' into lwilkinson/kv-layout/core-standardize-and-re…
MatthewBonanni May 27, 2026
66e01a7
Fix merge regressions and remove cosmetic refactors
LucasWilkinson May 28, 2026
13b7f26
Merge remote-tracking branch 'nm/lwilkinson/kv-layout/core-standardiz…
LucasWilkinson May 28, 2026
64c2618
Fix nvfp4 KV cache TMA alignment with 2H head layout
LucasWilkinson May 28, 2026
2dc36bb
Merge branch 'main' into lwilkinson/kv-layout/core-standardize-and-re…
MatthewBonanni May 28, 2026
4df2bcc
Merge branch 'main' into lwilkinson/kv-layout/core-standardize-and-re…
MatthewBonanni May 28, 2026
6b6adb1
Fix test
MatthewBonanni May 28, 2026
958cbf6
Fix tests
MatthewBonanni May 28, 2026
3fb69f9
Fix triton MLA
MatthewBonanni May 28, 2026
2853ee5
Fix NIXL connector
MatthewBonanni May 28, 2026
229fe95
Fix tests
MatthewBonanni May 28, 2026
24fd35c
Add missing block_strides to NixlAgentMetadata in tests
LucasWilkinson May 28, 2026
9d97a8a
Fix MLA tests
MatthewBonanni May 29, 2026
7a69728
Bucket by HNC
MatthewBonanni May 29, 2026
e8c6f98
Merge branch 'main' into lwilkinson/kv-layout/core-standardize-and-re…
MatthewBonanni May 29, 2026
ba123ba
Fix docstring
MatthewBonanni May 29, 2026
cfc0c6b
Always pass as 4D
MatthewBonanni May 29, 2026
488af5b
Fix bucket key
MatthewBonanni May 29, 2026
035ceb2
Merge branch 'main' into lwilkinson/kv-layout/core-standardize-and-re…
MatthewBonanni May 29, 2026
1fc7178
fix
LucasWilkinson May 31, 2026
1ed1b08
cleanup
LucasWilkinson Jun 1, 2026
5d089e0
Merge branch 'main' into lwilkinson/kv-layout/core-standardize-and-re…
MatthewBonanni Jun 1, 2026
2c63df9
Merge branch 'lwilkinson/kv-layout/core-standardize-and-remove-legacy…
LucasWilkinson Jun 1, 2026
1f25208
cleanup
LucasWilkinson Jun 1, 2026
d164bf9
cleanup
LucasWilkinson Jun 1, 2026
0894845
Eliminate cache duplication in CPU attn
MatthewBonanni Jun 1, 2026
a3dbf72
cleanup
LucasWilkinson Jun 1, 2026
ea8113a
Merge branch 'lwilkinson/kv-layout/core-standardize-and-remove-legacy…
LucasWilkinson Jun 1, 2026
cddf0c6
cleanup
LucasWilkinson Jun 1, 2026
25f5fef
cleanup
LucasWilkinson Jun 2, 2026
0d43676
cleanup
LucasWilkinson Jun 2, 2026
e9b3860
cleanup
LucasWilkinson Jun 2, 2026
2aa1698
Use meta tensors for stride computation to avoid OOM
LucasWilkinson Jun 2, 2026
92710f9
cleanup
LucasWilkinson Jun 2, 2026
f9ae140
cleanup
LucasWilkinson Jun 2, 2026
ecbd1a9
Fix default KV cache layout to LBHNC
LucasWilkinson Jun 2, 2026
3aeb6d6
Merge branch 'main' into lwilkinson/kv-layout/core-standardize-and-re…
MatthewBonanni Jun 2, 2026
31aae79
Merge branch 'lwilkinson/kv-layout/core-standardize-and-remove-legacy…
LucasWilkinson Jun 2, 2026
e767c7b
cleanup
LucasWilkinson Jun 2, 2026
2f01f38
cleanup
LucasWilkinson Jun 2, 2026
0393c06
cleanup
LucasWilkinson Jun 2, 2026
be85d37
cleanup
LucasWilkinson Jun 2, 2026
5b40dc4
cleanup
LucasWilkinson Jun 2, 2026
e021cd0
cleanup
LucasWilkinson Jun 2, 2026
3d1363c
cleanup
LucasWilkinson Jun 2, 2026
ed65016
cleanup
LucasWilkinson Jun 2, 2026
9316862
cleanup
LucasWilkinson Jun 2, 2026
8deb263
cleanup
LucasWilkinson Jun 3, 2026
50dfd0f
Fix ROCm attention KV cache layout mismatch
LucasWilkinson Jun 3, 2026
b2318c0
Fix NIXL Mamba cache unpack and chunked_prefill block_size
LucasWilkinson Jun 3, 2026
ba18262
Merge remote-tracking branch 'origin/main' into lwilkinson/kv-layout/…
LucasWilkinson Jun 3, 2026
5845525
cleanup
LucasWilkinson Jun 3, 2026
847b910
cleanup
LucasWilkinson Jun 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions .buildkite/test-amd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ steps:
num_gpus: 2
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- vllm/v1/worker/kv_connector_model_runner_mixin.py
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
Expand All @@ -866,7 +866,7 @@ steps:
num_gpus: 4
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
Expand Down Expand Up @@ -2341,7 +2341,7 @@ steps:
num_gpus: 4
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
Expand Down Expand Up @@ -2377,7 +2377,7 @@ steps:
optional: true
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
Expand All @@ -2391,7 +2391,7 @@ steps:
num_gpus: 4
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
Expand All @@ -2405,7 +2405,7 @@ steps:
num_gpus: 4
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
Expand Down Expand Up @@ -3353,7 +3353,7 @@ steps:
optional: true
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- vllm/v1/worker/kv_connector_model_runner_mixin.py
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
Expand All @@ -3369,7 +3369,7 @@ steps:
optional: true
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
Expand All @@ -3384,7 +3384,7 @@ steps:
optional: true
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
- vllm/platforms/rocm.py
commands:
Expand Down
70 changes: 18 additions & 52 deletions benchmarks/attention_benchmarks/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
import types
from contextlib import contextmanager
from math import prod

import numpy as np
import torch
Expand All @@ -30,10 +31,13 @@
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
get_kv_cache_layout,
set_kv_cache_layout,
resolve_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
compute_layer_kv_cache_shape_bytes,
reshape_kv_cache,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec

# ============================================================================
# Backend Configuration
Expand Down Expand Up @@ -324,52 +328,23 @@ def _create_input_tensors(
def _create_kv_cache(
config: BenchmarkConfig,
max_num_blocks: int,
backend_class,
device: torch.device,
dtype: torch.dtype,
) -> list:
"""Create KV cache tensors for all layers using the backend's methods.

Uses the backend's get_kv_cache_shape() and get_kv_cache_stride_order()
to create the cache with the correct shape and memory layout.
"""
# Get the logical shape from the backend
cache_shape = backend_class.get_kv_cache_shape(
num_blocks=max_num_blocks,
"""Create KV cache tensors for all layers using the standard allocator."""
spec = FullAttentionSpec(
block_size=config.block_size,
num_kv_heads=config.num_kv_heads,
head_size=config.head_dim,
dtype=dtype,
)

# Get the stride order for custom memory layout
try:
stride_order = backend_class.get_kv_cache_stride_order()
assert len(stride_order) == len(cache_shape)
except (AttributeError, NotImplementedError):
stride_order = tuple(range(len(cache_shape)))

# Permute shape to physical layout order
physical_shape = tuple(cache_shape[i] for i in stride_order)

# Compute inverse permutation to get back to logical view
inv_order = [stride_order.index(i) for i in range(len(stride_order))]

# Use fp8 dtype for cache when requested.
cache_dtype = dtype
if config.kv_cache_dtype == "fp8":
from vllm.platforms import current_platform

cache_dtype = current_platform.fp8_dtype()

cache_list = []
for _ in range(config.num_layers):
# Allocate in physical layout order (contiguous in memory)
cache = torch.zeros(*physical_shape, device=device, dtype=cache_dtype)
# Permute to logical view
cache = cache.permute(*inv_order)
cache_list.append(cache)

return cache_list
layout = resolve_kv_cache_layout()
total_bytes = (
prod(compute_layer_kv_cache_shape_bytes(spec, max_num_blocks))
* config.num_layers
)
buf = torch.zeros(total_bytes, device=device, dtype=torch.int8)
return reshape_kv_cache(buf, spec, max_num_blocks, config.num_layers, layout)


# ============================================================================
Expand Down Expand Up @@ -514,13 +489,6 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
backend_cfg, config, device, dtype
)

# Set KV cache layout if the backend requires a specific one
# (e.g., FlashInfer requires HND on SM100/Blackwell for TRTLLM attention)
required_layout = backend_class.get_required_kv_cache_layout()
if required_layout is not None:
set_kv_cache_layout(required_layout)
get_kv_cache_layout.cache_clear()

common_metadata = _build_common_attn_metadata(
q_lens, kv_lens, config.block_size, device
)
Expand Down Expand Up @@ -549,9 +517,7 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
config, total_q, device, dtype, quantize_query=quantize_query
)

cache_list = _create_kv_cache(
config, max_num_blocks, backend_class, device, dtype
)
cache_list = _create_kv_cache(config, max_num_blocks, device, dtype)

times, mem_stats = _run_single_benchmark(
config,
Expand Down
2 changes: 1 addition & 1 deletion docs/features/nixl_connector_compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ th:not(:first-child) {

<sup>1</sup> P and D instances must use the same speculation configuration.

<sup>2</sup> Requires `FLASH_ATTN` or `FLASHINFER` backend **and** `HND` KV cache layout. Enable via `--kv-transfer-config '{"kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'`.
<sup>2</sup> Cross-layer contiguity is achieved by using a `BLHNC` layout (set via `VLLM_KV_CACHE_LAYOUT=BLHNC` or `--enable-cross-layers`).

<sup>3</sup> Supported only when HMA is **not** required (i.e., non-hybrid models). Block IDs are remapped automatically. Only P block size < D block size is supported.

Expand Down
9 changes: 0 additions & 9 deletions docs/features/nixl_connector_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -389,15 +389,6 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con
--kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
```

### Cross layers blocks

By default, this feature is disabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred.
To enable this feature:

```bash
--kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'
```

## Example Scripts/Code

Refer to these example scripts in the vLLM repository:
Expand Down
37 changes: 21 additions & 16 deletions tests/compile/passes/test_fusion_attn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from math import prod

import pytest
import torch._dynamo
Expand Down Expand Up @@ -39,7 +40,12 @@
from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import AttentionSpec, get_kv_quant_mode
from vllm.v1.attention.backends.utils import resolve_kv_cache_layout
from vllm.v1.kv_cache_interface import (
AttentionSpec,
compute_layer_kv_cache_shape_bytes,
get_kv_quant_mode,
)

DEVICE_TYPE = current_platform.device_type
FP8_DTYPE = current_platform.fp8_dtype()
Expand Down Expand Up @@ -108,29 +114,28 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks

# Fetch the attention backend and kv cache shape and stride order
attn_backend = self.attn.attn_backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size
spec = AttentionSpec(
block_size=self.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.attn.kv_cache_torch_dtype,
kv_quant_mode=get_kv_quant_mode(self.attn.kv_cache_dtype),
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
layout = resolve_kv_cache_layout()
kv_cache_shape = compute_layer_kv_cache_shape_bytes(spec, num_blocks)
kv_cache_stride_order = layout.layer_stride_order
physical_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]

# Create dummy KV cache
raw_tensor = torch.zeros(
2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size,
dtype=self.attn.kv_cache_torch_dtype,
prod(kv_cache_shape),
dtype=torch.int8,
device=self.device,
)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_cache = raw_tensor.permute(*inv_order)
raw_tensor = raw_tensor.view(physical_shape)
kv_cache = raw_tensor.permute(*inv_order).view(self.attn.kv_cache_torch_dtype)

self.attn.kv_cache = kv_cache

Expand Down
25 changes: 6 additions & 19 deletions tests/compile/passes/test_mla_attn_quant_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,27 +150,14 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks

# MLA KV cache is 3D: (num_blocks, block_size, head_size)
attn_backend = self.mla_attn.attn_backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, 1, self.head_size
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

ordered_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]

raw_tensor = torch.zeros(
ordered_shape, dtype=self.kv_cache_dtype, device=self.device
# MLA KV cache is 4D: (num_blocks, num_heads=1, block_size, head_size)
kv_cache = torch.zeros(
(num_blocks, 1, self.block_size, self.head_size),
dtype=self.kv_cache_dtype,
device=self.device,
)
kv_cache = raw_tensor.permute(*inv_order)

self.mla_attn.kv_cache = kv_cache
self.mla_attn.bind_kv_cache(kv_cache)

self.attn_metadata = self.builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
Expand Down
24 changes: 5 additions & 19 deletions tests/compile/passes/test_mla_rope_kvcache_cat_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,29 +165,15 @@ def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata:
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks

# Fetch the attention backend and kv cache shape and stride order
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size
)
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order()
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]

raw_tensor = torch.zeros(
num_blocks * self.block_size * self.num_kv_heads * self.head_size,
# MLA uses a 4D KV cache: (num_blocks, num_heads=1, block_size, head_size).
kv_cache_shape = (num_blocks, 1, self.block_size, self.head_size)
kv_cache = torch.zeros(
kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device,
)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_cache = raw_tensor.permute(*inv_order)

self.mla_attn.kv_cache = kv_cache
self.mla_attn.bind_kv_cache(kv_cache)

# Build attn metadata
attn_metadata = self.builder.build(
Expand Down
Loading
Loading