Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tests/v1/kv_connector/unit/test_offloading_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.core.kv_cache_utils import (
BlockHash,
get_request_block_hasher,
Expand Down Expand Up @@ -92,7 +93,7 @@ def get_manager(self) -> OffloadingManager:
return self.manager

def get_handlers(
self, _
self, _, __
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
Expand Down Expand Up @@ -138,7 +139,10 @@ def __init__(
self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER)

# register worker kv_caches to enable OffloadingWorker creations
self.worker_connector.register_kv_caches(kv_caches={"a": torch.empty(0)})
self.worker_connector.register_cross_layers_kv_cache(
kv_cache=torch.empty(0),
attn_backend=FlashAttentionBackend,
)

# extract connector of scheduler
scheduler_connector = self.scheduler.connector
Expand Down
145 changes: 93 additions & 52 deletions tests/v1/kv_offload/test_cpu_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.utils.system_utils import set_env_var

CPU_BLOCK_SIZES = [16, 48]
CPU_BLOCK_SIZES = [48]
ATTN_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]


class MockSubscriber:
Expand Down Expand Up @@ -63,8 +65,88 @@ def close(self):
self.sub.close()


def _latency_test(llm: LLM, subscriber: MockSubscriber):
    """Compare cold, GPU-prefix-hit, and CPU-offload-hit generation latency.

    For each iteration the same long prompt is generated three times:
    cold (first run, which also stores KV to CPU), a GPU prefix-cache hit,
    and—after resetting the GPU prefix cache—a run that must load KV from
    the CPU offload. Asserts the CPU hit beats the cold run in at least
    80% of iterations.

    Args:
        llm: an LLM configured with the CPU offloading connector.
        subscriber: receives KV events; used to confirm CPU stores happened.
    """
    sampling_params = SamplingParams(max_tokens=1)

    num_times_cpu_better_than_cold = 0
    num_tests = 10
    total_cold_time = 0.0
    total_gpu_hit_time = 0.0
    total_cpu_hit_time = 0.0
    # Long prompt (10001 tokens) so prefill cost dominates and cache hits
    # are clearly measurable.
    prompt_token_ids = [0] * 10001
    for i in tqdm(range(num_tests), desc="Running tests"):
        # Vary the first token so each iteration's prompt is unique and the
        # first generation cannot hit caches from previous iterations.
        prompt_token_ids[0] = i
        prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]

        # run generation - this should trigger saving KV cache
        start_time = time.time()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        cold_time = time.time() - start_time
        total_cold_time += cold_time

        # run generation again - should hit the GPU prefix cache
        start_time = time.time()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        gpu_hit_time = time.time() - start_time
        total_gpu_hit_time += gpu_hit_time

        # reset prefix cache to avoid GPU hit.
        llm.reset_prefix_cache()

        # the cold run above must have produced CPU "stored" events
        assert subscriber.get_new_cpu_stored_events()

        # run generation again - this should trigger loading from CPU
        start_time = time.time()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        cpu_hit_time = time.time() - start_time
        total_cpu_hit_time += cpu_hit_time

        if cpu_hit_time < cold_time:
            num_times_cpu_better_than_cold += 1

    print("Average times:")
    print(f"  Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
    print(f"  GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
    print(f"  CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")

    # CPU-offload hits should be faster than cold runs most of the time;
    # allow 20% slack for timing noise.
    assert num_times_cpu_better_than_cold >= 0.8 * num_tests


def _accuracy_test(llm: LLM, subscriber: MockSubscriber):
    """Check that generation served from CPU-offloaded KV remains accurate.

    Builds a counting prompt whose token length is a multiple of the CPU
    block size (so the whole prompt is offloadable), then repeatedly
    generates one token and expects " five" to dominate the outputs.

    Args:
        llm: an LLM configured with the CPU offloading connector.
        subscriber: receives KV events; used to confirm CPU stores happened.
    """
    sampling_params = SamplingParams(max_tokens=1)
    # CPU block size as configured on the connector; used for alignment below.
    cpu_block_size = (
        llm.llm_engine.vllm_config.kv_transfer_config.kv_connector_extra_config[
            "block_size"
        ]
    )

    # drain any events left over from earlier activity
    subscriber.get_new_cpu_stored_events()

    # prepend prompt to be cpu block aligned
    prompt = "Let's count to 10. One, two, three, four,"
    while (
        len(llm.generate(prompt, use_tqdm=False)[0].prompt_token_ids) % cpu_block_size
        != 0
    ):
        # pad from the front so the meaningful suffix (and its expected
        # continuation " five") is unchanged
        prompt = ". " + prompt

    assert subscriber.get_new_cpu_stored_events()

    test_count = 100
    success_count = 0
    for i in range(test_count):
        if (
            llm.generate(prompt, sampling_params, use_tqdm=False)[0].outputs[0].text
            == " five"
        ):
            success_count += 1

    # sampling is not greedy here, so require only a majority of " five"
    assert success_count >= 0.5 * test_count


@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
def test_cpu_offloading(cpu_block_size: int) -> None:
@pytest.mark.parametrize("attn_backend", ATTN_BACKENDS)
def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
"""
Tests OffloadingConnector with CPUOffloadingSpec.
"""
Expand Down Expand Up @@ -92,61 +174,20 @@ def test_cpu_offloading(cpu_block_size: int) -> None:
topic="test",
)

llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_events_config=kv_events_config,
kv_transfer_config=kv_transfer_config,
)

sampling_params = SamplingParams(temperature=0, max_tokens=1)
with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend):
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_events_config=kv_events_config,
kv_transfer_config=kv_transfer_config,
)

events_endpoint = events_endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)

try:
num_times_cpu_better_than_cold = 0
num_tests = 10
total_cold_time = 0.0
total_gpu_hit_time = 0.0
total_cpu_hit_time = 0.0
prompt_token_ids = [0] * 10001
for i in tqdm(range(num_tests), desc="Running tests"):
prompt_token_ids[0] = i
prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]

# run generation - this should trigger saving KV cache
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cold_time = time.time() - start_time
total_cold_time += cold_time

# run generation again - should hit the GPU prefix cache
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
gpu_hit_time = time.time() - start_time
total_gpu_hit_time += gpu_hit_time

# reset prefix cache to avoid GPU hit.
llm.reset_prefix_cache()

assert subscriber.get_new_cpu_stored_events()

# run generation again - this should trigger loading from CPU
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cpu_hit_time = time.time() - start_time
total_cpu_hit_time += cpu_hit_time

if cpu_hit_time < cold_time:
num_times_cpu_better_than_cold += 1

print("Average times:")
print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")

assert num_times_cpu_better_than_cold >= 0.8 * num_tests
_latency_test(llm, subscriber)
_accuracy_test(llm, subscriber)
finally:
subscriber.close()
del llm
5 changes: 4 additions & 1 deletion tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,10 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
# Permutation that gets you back to expected kv shape
for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)):

def rnd_stride_order(test_stride=test_stride):
def rnd_stride_order(
include_num_layers_dimension: bool = False, test_stride=test_stride
):
assert not include_num_layers_dimension
return test_stride

# Patch the attention backend class and re-trigger the KV cache creation
Expand Down
29 changes: 28 additions & 1 deletion vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,34 @@ def get_kv_cache_shape(
raise NotImplementedError

@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    """
    Get the physical (memory layout) ordering of the kv cache dimensions.
    e.g. if the KV cache shape is
    [2, num_blocks, block_size, num_heads, head_size],
    and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
    ordering of dimensions is
    [num_blocks, num_heads, 2, block_size, head_size].

    If this function is unimplemented / raises NotImplementedError,
    the physical layout of the KV cache will match the logical shape.

    Args:
        include_num_layers_dimension: if True, includes an additional
            num_layers dimension, which is assumed to be prepended
            to the logical KV cache shape.
            With the above example, a return value (2, 4, 0, 1, 3, 5)
            corresponds to
            [num_blocks, num_heads, num_layers, 2, block_size, head_size].

            If an additional dimension is NOT included in the returned
            tuple, the physical layout will not include a layers dimension.

    Returns:
        A tuple of ints which is a permutation of range(len(shape)).

    Raises:
        NotImplementedError: by default; backends override this method to
            opt in to a custom physical layout.
    """
    # Base class declares no custom stride order; callers fall back to the
    # logical shape when this is raised.
    raise NotImplementedError

@classmethod
Expand Down
33 changes: 31 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import enum
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, Literal, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional

import torch

Expand All @@ -47,7 +47,7 @@
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
Expand Down Expand Up @@ -142,6 +142,18 @@ class KVConnectorMetadata(ABC): # noqa: B024


class KVConnectorBase_V1(ABC):
"""
Base class for KV connectors.

Attributes:
prefer_cross_layer_blocks (bool): Indicates whether this connector
prefers KV blocks that hold KV data for all layers (for speeding
up KV data transfers).
Defaults to False.
"""

prefer_cross_layer_blocks: ClassVar[bool] = False
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick question: what will happen if the connector implementation sets it to true? Will that impact vLLM's initialization process?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily.

The conditions that effect whether this cross-layers layout will be used is documented in the use_uniform_kv_cache function:

        Note this layout will only be applied given 3 conditions:
        1. The KV Cache config contains just a single group where all layers
            have the same page size.
        2. A KV connector is configured, and the KV connector instance prefers
            to use this layout (prefer_cross_layer_blocks() returns True)
        3. The flash attention backend supports this layout
            (get_kv_cache_stride_order(True) includes a placement for a
            num_layers dimension)


def __init__(
self,
vllm_config: "VllmConfig",
Expand Down Expand Up @@ -226,6 +238,23 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
return

def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
):
Comment on lines +241 to +243
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to pass attn_backend here? If we have a pre-defined shape of the underlying kv cache tensor, it should be the same across all the attn backends, right?

If so, can we also merge this function with the existing register_kvcaches function? Otherwise, it is not clear that when those two functions will be called.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The single cross-layers kv-cache tensor always matches a single attention backend.
Of course you connector could re-figure this out but this makes it more easy and straight-forward for the connector.
We could avoid this function and re-use register_kvcaches. This was actually my first thought.
However, this requires defining a "dummy" layer name for the cross-layer tensor.
I find this is hacky, and prefer we define an explicit function like I defined here.

This new function may only be called if a connector set prefer_cross_layer_blocks to True, and only one of two register functions will be called. This is documented in the definition in the base class:

        This function will only be called for models with uniform layers,
        and only if the KV connector returns True on
        prefers_cross_layer_blocks().
        Only one of the functions
        {register_kv_caches, register_cross_layers_kv_cache} will be called.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The single cross-layers kv-cache tensor always matches a single attention backend.

Can't agree. This layout can be extended to full + swa / full + mamba cases in the future. And don't want to change connector API when we start to work on that.

Copy link
Copy Markdown
Collaborator

@heheda12345 heheda12345 Nov 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After a second thought, I feel more comfortable with the current API when supporting multiple attention backends. To use cross-layer KV blocks with multiple attention backends, we need to let all backends to make agreement on the underlying stride and it's OK to pass one attention backend that can represent all backends. And the logic for this agreement is needed for all connectors so I prefer to do it in model runner instead of doing it inside each connector.

Copy link
Copy Markdown
Collaborator

@ApostaC ApostaC Nov 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, this requires defining a "dummy" layer name for the cross-layer tensor.

Dumb question: If the goal is to have the layer name, why don't we just pass the layer name to the connector? Or is it impossible?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, this requires defining a "dummy" layer name for the cross-layer tensor.

Dumb question: If the goal is to have the layer name, why don't we just pass the layer name to the connector? Or is it impossible?

Not sure I understand your question.
The current API (register_kv_caches) passes a dict mapping layer names to kv cache tensors. Those layer names are real, and the connector can use those layer names to reference other kv cache related structs (like attention backends).
Had I tried to use this API to pass in a cross-layer kv cache tensor, the connector would have to implicitly detect as not to wrongfully correspond this tensor with that of a "true" layer.

"""
Initialize with a single KV cache tensor used by all layers.
The first dimension should be num_layers.
This function will only be called for models with uniform layers,
    and only if prefer_cross_layer_blocks is set to True.
Only one of the functions
{register_kv_caches, register_cross_layers_kv_cache} will be called.

Args:
kv_cache: a cross-layers kv cache tensor
attn_backend: The attention backend that corresponds to all layers
"""
return

def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
"""
Set the xPU-specific ops for copying KV between host and device.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from itertools import islice
from typing import Any
from typing import Any, ClassVar

import torch

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.attention import Attention, AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
Expand Down Expand Up @@ -42,6 +42,8 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):


class OffloadingConnector(KVConnectorBase_V1):
prefer_cross_layer_blocks: ClassVar[bool] = True

def __init__(
self,
vllm_config: VllmConfig,
Expand All @@ -63,6 +65,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)

def register_cross_layers_kv_cache(
    self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
    """Forward a single cross-layer KV cache tensor to the worker connector.

    Args:
        kv_cache: one tensor holding KV data for all layers
            (first dimension is num_layers).
        attn_backend: the attention backend shared by all layers.
    """
    # Only the worker role owns a connector_worker; the scheduler role
    # must never receive this call.
    assert self.connector_worker is not None
    self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)

def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
Expand Down Expand Up @@ -422,10 +430,35 @@ def _generate_job_id(self) -> int:
self._job_counter = job_id + 1
return job_id

def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for src_cls, dst_cls, handler in self.spec.get_handlers(kv_caches):
def _register_handlers(
    self,
    kv_caches: dict[str, torch.Tensor],
    attn_backends: dict[str, type[AttentionBackend]],
):
    """Register every load/store handler the offloading spec provides.

    Args:
        kv_caches: mapping of layer name -> KV cache tensor.
        attn_backends: mapping of layer name -> attention backend class.
    """
    handlers = self.spec.get_handlers(kv_caches, attn_backends)
    for src_spec_cls, dst_spec_cls, handler in handlers:
        self.worker.register_handler(src_spec_cls, dst_spec_cls, handler)

def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Register per-layer KV caches, resolving each layer's attention backend.

    Args:
        kv_caches: mapping of layer name -> KV cache tensor.
    """
    names = list(kv_caches)
    # Look up the Attention layer objects so we can ask each one for
    # its backend class.
    attn_layers = get_layers_from_vllm_config(
        self.spec.vllm_config, Attention, names
    )
    backends = {name: attn_layers[name].get_attn_backend() for name in names}
    self._register_handlers(kv_caches, backends)

def register_cross_layers_kv_cache(
    self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
    """Register handlers for a single cross-layer KV cache tensor.

    The tensor covers all layers, so it is registered under one synthetic
    name instead of per-layer entries.

    Args:
        kv_cache: cross-layer KV cache tensor (first dim is num_layers).
        attn_backend: the attention backend shared by all layers.
    """
    name = "ALL_LAYERS"
    self._register_handlers(
        kv_caches={name: kv_cache},
        attn_backends={name: attn_backend},
    )

def start_load_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_load.items():
job_id = self._generate_job_id()
Expand Down
Loading