Merged
58 commits
fd910d6  feat: simple cpu offloader prototype (ivanium, Mar 2, 2026)
4714ac3  fix: support dsv3.2 style where kv_raw_tensors have different sizes (ivanium, Mar 10, 2026)
cc2176c  perf: submit kv block copy kernels after model forward to hide triton… (ivanium, Mar 11, 2026)
d845120  refactor (worker): merge copy_kernel into submit_copy (ivanium, Mar 12, 2026)
5a338f4  fix (worker): async load/store (ivanium, Mar 13, 2026)
bbde447  fix (manager): lazy clean preempted req to load (ivanium, Mar 13, 2026)
523c52a  fix: clamp negative values in kv cache stats (ivanium, Mar 11, 2026)
0727d73  fix (scheduler): clamp negative stats values (ivanium, Mar 15, 2026)
9f22120  refactor: revert register_kv_caches to still take only kv_caches and … (ivanium, Mar 12, 2026)
273cf21  fix (worker): record_stream to prevent caching allocator frees the te… (ivanium, Mar 14, 2026)
55404ca  fix (scheduler): handle invalid blocks for multiple kv groups (ivanium, Mar 15, 2026)
d84570a  chore: clean up (ivanium, Mar 16, 2026)
362d8ed  chore: rename sub-package (ivanium, Mar 16, 2026)
86fda6f  perf: cuMemcpyBatchAsync for block copy, reduce TTFT overhead to ~4ms (ivanium, Mar 16, 2026)
ddd11e5  chore: env var to enable simple native offload (ivanium, Mar 20, 2026)
51c0efe  fix (worker): need sync transfers upon preemption (ivanium, Mar 19, 2026)
15b2003  fix (worker): enqueue event after recording (ivanium, Mar 20, 2026)
a73992f  fix (worker): store blocks should wait for main stream (ivanium, Mar 21, 2026)
08e4371  fix (manager): keep engine stepping until in-flight transfer finishes (ivanium, Mar 20, 2026)
f8b6f88  fix (manager): remove `_req_local_computed` tracking (ivanium, Mar 21, 2026)
ffebd2a  feat: rewrite lazy offload with cursor-based single-pass scanning (ivanium, Mar 21, 2026)
a821002  test: unit and integration tests for simple_kv_offload scheduler (ivanium, Mar 19, 2026)
e83b63a  feat (connector): support cross_layer layout (ivanium, Mar 21, 2026)
1213722  feat: add CopyBackend ABC with kernel and dma implementations (ivanium, Mar 21, 2026)
c0b5688  fix: delay block store by one step and avoid block_mapping (ivanium, Mar 21, 2026)
98cd651  chore: disable cross_layer layout (ivanium, Mar 22, 2026)
dc2f524  fix: correctly handle pending transfers when all requests finish (ivanium, Mar 23, 2026)
589183c  test: revise integration tests (ivanium, Mar 23, 2026)
2af9919  chore: nits (ivanium, Mar 23, 2026)
cd9b584  fix: pre-commit (ivanium, Mar 23, 2026)
57992d3  refactor: remove copy_ops_kernel as DMA is always faster (ivanium, Mar 23, 2026)
28a17f2  fix (worker): sync store stream to compute stream upon flush (ivanium, Mar 24, 2026)
93803a0  fix (worker): store needs to wait for current stream for unknown reason (ivanium, Mar 24, 2026)
08b93fa  fix (manager): cap blocks to store with num_confirmed_computed_tokens (ivanium, Mar 24, 2026)
c23d27c  refactor (manager): revise and simplify (ivanium, Mar 25, 2026)
f27e544  chore (worker): nits (ivanium, Mar 25, 2026)
5204154  chore (connector): fix mypy error (ivanium, Mar 25, 2026)
081ef29  fix: various issues (ivanium, Mar 26, 2026)
fed0291  feat(cuda): pin cpu buffer with cudaHostRegister to bypass torch cach… (ivanium, Mar 26, 2026)
9a3c14c  fix (test): add preemption_ids (ivanium, Mar 26, 2026)
f58f259  fix (worker): minor cleanup issue (ivanium, Mar 26, 2026)
e383c6a  fix (mkdoc): nits (ivanium, Mar 26, 2026)
efa5234  fix (core/scheduler): revert changes to request.num_cached_tokens (ivanium, Mar 26, 2026)
2874062  chore (connector): align definition for `cpu_bytes_to_use` with the e… (ivanium, Mar 27, 2026)
2bb64be  fix (manager): dedup shared blocks during eager offloading and enlarg… (ivanium, Mar 27, 2026)
13c3e6a  chore (scheduler): revert defensive counter clamp (ivanium, Mar 29, 2026)
bb8aabf  chore (core): revert drain_pending_transfers and leave that for a fol… (ivanium, Mar 29, 2026)
29bc685  test: relax test cases to workaround reset_prefix_cache()'s sync issue (ivanium, Mar 29, 2026)
0f68ab8  chore (scheduler): revert reset num_external_computed_tokens (ivanium, Mar 29, 2026)
49f0340  chore (manager): revert defensive cleanup for preempted reqs (ivanium, Mar 30, 2026)
245f137  fix (worker): more robust way to detect kv cache layout (ivanium, Mar 30, 2026)
2157c7c  feat: use WorkerMetadata to report back store events (ivanium, Mar 30, 2026)
3e579ee  chore: revert step changes to vllm scheduler and core (ivanium, Mar 30, 2026)
ccf2bd6  test: make all integration test optional (ivanium, Mar 30, 2026)
ad3ab34  chore: error out reset_cache() due to not implemented (ivanium, Mar 30, 2026)
1f15811  doc: add comments on stats changes (ivanium, Mar 31, 2026)
859974a  Merge branch 'main' into feat/simple-cpu-offload-cleanup (njhill, Mar 31, 2026)
8bb9d8f  fix (manager): max_hit_len - 1 to enforce recomputing the last token (ivanium, Mar 31, 2026)
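
Commit ffebd2a ("rewrite lazy offload with cursor-based single-pass scanning") describes the core of the lazy offload path: rather than eagerly copying every stored block to CPU, a cursor walks the GPU free-block queue exactly once and enqueues device-to-host copies for blocks that are about to be recycled. A minimal sketch of the idea, assuming a hypothetical free-queue interface (the class and the enqueue_d2h callback are illustrative names, not the PR's actual code):

class LazyOffloadCursor:
    """Single-pass scan: each freed block is visited at most once."""

    def __init__(self, free_queue: list) -> None:
        self.free_queue = free_queue  # freed blocks, oldest (next to evict) first
        self.pos = 0  # everything before self.pos has already been scanned

    def step(self, enqueue_d2h) -> None:
        # Advance only over blocks appended since the previous step, so the
        # total scan cost stays linear in the number of freed blocks.
        while self.pos < len(self.free_queue):
            block = self.free_queue[self.pos]
            self.pos += 1
            if not block.offloaded:
                enqueue_d2h(block)  # async GPU -> CPU copy on a side stream
                block.offloaded = True

The integration test below exploits exactly this behavior: its _flush_gpu_cache helper generates enough filler tokens to push prior blocks through the free queue, giving the cursor a chance to offload them before eviction.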
tests/v1/simple_kv_offload/test_integration.py (193 additions, 0 deletions)
@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for SimpleCPUOffloadConnector with real models."""

import time

import pytest

from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVTransferConfig
from vllm.platforms import current_platform

if not current_platform.is_cuda():
    pytest.skip("Requires CUDA", allow_module_level=True)

# Small models for default CI / local runs (accuracy only).
SMALL_MODELS = [
    "meta-llama/Llama-3.2-1B-Instruct",
    "google/gemma-3-1b-it",
]

# Large models for optional perf runs only (slow to load and execute).
PERF_MODELS = [
    "meta-llama/Llama-3.1-8B",
    "openai/gpt-oss-20b",
]


def _make_llm(model: str, lazy: bool, cpu_bytes_to_use: int) -> LLM:
    kv_transfer_config = KVTransferConfig(
        kv_connector="SimpleCPUOffloadConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "cpu_bytes_to_use": cpu_bytes_to_use,
            "lazy_offload": lazy,
        },
    )
    return LLM(
        model=model,
        kv_cache_memory_bytes=40 << 30,  # 40 GiB
        disable_hybrid_kv_cache_manager=False,
        enable_prefix_caching=True,
        kv_transfer_config=kv_transfer_config,
    )
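
# Reading the config above: `cpu_bytes_to_use` caps the CPU buffer available
# for offloaded KV blocks (pinned via cudaHostRegister per commit fed0291),
# and `lazy_offload` selects between eager offload at store time and lazy
# offload as blocks cycle through the GPU free queue.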


def _flush_gpu_cache(llm: LLM, sampling_params: SamplingParams, seed: int = 0):
    """Generate enough filler requests to allocate the entire GPU KV cache.

    This pushes all prior blocks through the free queue so that the lazy
    cursor offloads them to CPU before they are evicted.
    """
    cache_config = llm.llm_engine.vllm_config.cache_config
    num_gpu_blocks = cache_config.num_gpu_blocks
    block_size = cache_config.block_size
    # Use 1.5x GPU capacity to give the lazy cursor enough scheduling steps
    # to walk past all target blocks near the tail of the free queue.
    total_tokens_needed = int(num_gpu_blocks * block_size * 1.5)

    # Use token-id prompts so each filler is unique (no prefix sharing).
    # Split into multiple requests to stay under max_model_len.
    max_tokens_per_req = 4096
    num_fillers = (total_tokens_needed + max_tokens_per_req - 1) // max_tokens_per_req
    batch_size = 10
    for i in range(0, num_fillers, batch_size):
        batch_end = min(i + batch_size, num_fillers)
        filler_prompts = []
        for j in range(i, batch_end):
            ids = [seed * num_fillers + j + 1] * max_tokens_per_req
            filler_prompts.append(TokensPrompt(prompt_token_ids=ids))
        llm.generate(filler_prompts, sampling_params, use_tqdm=False)
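
# Worked example of the sizing above, with illustrative numbers rather than
# values from a real run: num_gpu_blocks = 4096 and block_size = 16 give
# total_tokens_needed = int(4096 * 16 * 1.5) = 98304 tokens, hence
# num_fillers = ceil(98304 / 4096) = 24 filler prompts, issued as batches
# of 10, 10, and 4.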


def _accuracy_test(llm: LLM, lazy: bool = False):
    """Verify that CPU-loaded KV produces correct output."""
    sampling_params = SamplingParams(max_tokens=1, temperature=0)
    prompt = "hi " * 2000 + "Let's count to ten. One, two, three, "

    # Cold run: populate GPU cache and trigger CPU offload.
    cold_output = llm.generate(prompt, sampling_params, use_tqdm=False)[0]

    # CPU hit runs.
    test_count = 10
    success_count = 0
    expected = cold_output.outputs[0].text
    for i in range(test_count):
        if lazy:
            _flush_gpu_cache(llm, sampling_params, seed=i)
        time.sleep(2)  # let engine core drain pending transfers

        # Reset GPU prefix cache so the next run must load from CPU.
        if not llm.reset_prefix_cache():
            print(f"GPU prefix cache reset failed for iteration {i}")

        output = llm.generate(prompt, sampling_params, use_tqdm=False)[0]
        if output.outputs[0].text == expected:
            success_count += 1

    assert success_count >= 0.5 * test_count, (
        f"Accuracy too low: {success_count}/{test_count} matched '{expected}'"
    )
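
# The loose 0.5 success threshold is deliberate: per commit 29bc685, the
# tests were relaxed to work around a reset_prefix_cache() synchronization
# issue that can make individual iterations unreliable.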


def _latency_test(llm: LLM, lazy: bool = False):
    """Verify CPU cache hit is faster than cold compute."""
    sampling_params = SamplingParams(max_tokens=1, seed=42)
    prompt_token_ids = [0] * 10001

    num_times_cpu_better = 0
    num_tests = 10
    for i in range(num_tests):
        prompt_token_ids[0] = i
        prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]

        # Cold
        time.sleep(2)  # let engine core drain pending transfers
        if not llm.reset_prefix_cache():
            print(f"GPU prefix cache reset failed for iteration {i}")
        start = time.time()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        cold_time = time.time() - start

        if lazy:
            _flush_gpu_cache(llm, sampling_params, seed=i)
        else:
            # Eager mode: GPU hit ensures store completion is processed.
            llm.generate(prompts, sampling_params, use_tqdm=False)

        time.sleep(2)  # let engine core drain pending transfers
        if not llm.reset_prefix_cache():
            print(f"GPU prefix cache reset failed for iteration {i}")

        # CPU hit
        start = time.time()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        cpu_time = time.time() - start

        if cpu_time < cold_time:
            num_times_cpu_better += 1

    assert num_times_cpu_better >= 0.8 * num_tests, (
        f"CPU hit only faster {num_times_cpu_better}/{num_tests} times"
    )


@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", SMALL_MODELS)
def test_simple_cpu_offload_accuracy(model: str):
    """Store to CPU, reset GPU, load from CPU; verify output matches baseline."""
    llm = _make_llm(model, False, 1 << 30)  # 1 GiB
    try:
        _accuracy_test(llm, lazy=False)
    finally:
        del llm


@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", PERF_MODELS)
def test_simple_cpu_offload_perf_latency(model: str):
    """CPU KV hit should beat cold prefill on long context (large models only)."""
    llm = _make_llm(model, False, 10 << 30)  # 10 GiB
    try:
        _latency_test(llm, lazy=False)
    finally:
        del llm


@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", SMALL_MODELS)
def test_simple_cpu_offload_accuracy_lazy(model: str):
    """Lazy mode: flush GPU cache to trigger CPU offload, then verify hit."""
    # CPU must be larger than GPU KV cache to avoid evicting offloaded blocks.
    llm = _make_llm(model, True, 80 << 30)  # 80 GiB
    try:
        _accuracy_test(llm, lazy=True)
    finally:
        del llm


@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", PERF_MODELS)
def test_simple_cpu_offload_perf_latency_lazy(model: str):
    """Lazy mode: CPU KV hit should beat cold prefill (large models only)."""
    # CPU must be larger than GPU KV cache to avoid evicting offloaded blocks.
    llm = _make_llm(model, True, 80 << 30)  # 80 GiB
    try:
        _latency_test(llm, lazy=True)
    finally:
        del llm
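
Every test in this file carries the `optional` and `slow_test` markers (commit ccf2bd6 made all the integration tests optional), so a default run typically deselects them. One way to run them explicitly, assuming the markers are registered in the project's pytest configuration and deselection is marker-based, is from Python:

import pytest

pytest.main([
    "-m", "optional",
    "-v",
    "tests/v1/simple_kv_offload/test_integration.py",
])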