
[Feat][v1] Simple yet General CPU KV Cache Offloading #37160

Merged
njhill merged 58 commits into vllm-project:main from ivanium:feat/simple-cpu-offload-cleanup
Apr 1, 2026

Conversation

@ivanium (Contributor) commented Mar 16, 2026

Purpose

SimpleCPUOffloadConnector is an alternative design for vLLM's CPU KV cache offloading path. Instead of maintaining a parallel block-management stack, it reuses vLLM's existing BlockPool and KVCacheCoordinator infrastructure directly. This gives us HMA support, prefix caching, and LRU eviction for free.

The new design is simpler (~1,400 lines of code), more general (it supports hybrid models and lazy offloading), and has lower per-step overhead.

Full design doc: https://docs.google.com/document/d/1TDY3eSjv7gsTXAcUjKEu15QTKSZpUpZqmnaKafywpgw/edit?usp=sharing
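
A minimal sketch of the reuse pattern (illustrative only, not the actual connector code: the class shape and the exact arguments of find_longest_cache_hit are assumptions; the call itself and its returned hit length mirror the diff excerpts quoted in the review threads below):

# Sketch: a CPU-side lookup built on the existing coordinator machinery,
# instead of a parallel block-management stack.
class SimpleCPUOffloadSketch:
    def __init__(self, cpu_coordinator, block_size: int):
        # The same KVCacheCoordinator/BlockPool machinery vLLM uses for GPU
        # blocks, instantiated over host (CPU) memory.
        self.cpu_coordinator = cpu_coordinator
        self.block_size = block_size

    def find_cpu_cache_hit(self, remaining_hashes):
        # Reuse the existing prefix-cache lookup; hit_length counts tokens.
        if not remaining_hashes:
            return 0, False
        max_hit_len = len(remaining_hashes) * self.block_size
        _, hit_length = self.cpu_coordinator.find_longest_cache_hit(
            remaining_hashes, max_hit_len
        )
        return (hit_length, True) if hit_length > 0 else (0, False)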

Note

This PR supports regular models and hybrid models with SWA, but not yet hybrid models with Mamba.
Supporting hybrid models with Mamba requires some scheduler-side fixes, which we will address in a follow-up PR.

Test Plan

# NOTE: export this env var to enable the new connector
# NOTE: kv-offloading-size is the total CPU memory in GB across all ranks
#       (GB_per_rank * world_size), to align with the existing offloading connector
# NOTE: enable lazy_offload by setting it to true in kv_connector_extra_config
export VLLM_USE_SIMPLE_KV_OFFLOAD=1
MODEL=openai/gpt-oss-20b  # alternatives: meta-llama/Llama-3.1-8B, Qwen3.5-4B
vllm serve $MODEL \
  --no-disable-hybrid-kv-cache-manager \
  --enable-prefix-caching \
  --no-enable-log-requests \
  --kv-offloading-size 80 \
  --kv-offloading-backend native \
  --kv-transfer-config '{"kv_connector_extra_config": {"lazy_offload": false}}'
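
A quick way to see the offload path end-to-end (a sketch, assuming the serve command above is running on the default port 8000): send the same long prompt twice; once the first request's KV blocks have been offloaded, the second should hit the CPU cache and show a lower TTFT.

# Illustrative smoke test against the OpenAI-compatible API started above.
import requests

prompt = "hello world " * 2000  # long enough to span many KV blocks

for attempt in range(2):
    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={"model": "openai/gpt-oss-20b", "prompt": prompt, "max_tokens": 16},
        timeout=300,
    )
    resp.raise_for_status()
    # Compare TTFT between the two attempts via server logs or metrics.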

Test Result

Overhead

Workload: random, Input:Output=8k:1k, 2 req/s

  • Llama-3.1-8B - GB200

| Metric | No Offload | Simple Offload | Delta |
|---|---|---|---|
| Output throughput (tok/s) | 2009.76 | 2009.80 | +0.00% |
| Request throughput (req/s) | 1.96 | 1.96 | +0.00% |
| Mean TTFT (ms) | 174.73 | 176.63 | +1.09% |
| Mean TPOT (ms) | 8.48 | 8.48 | +0.00% |
| P99 TPOT (ms) | 11.44 | 11.42 | -0.17% |
| Mean ITL (ms) | 8.48 | 8.48 | +0.00% |
| P99 ITL (ms) | 82.62 | 76.25 | -7.71% |

  • GPT-oss-20b - GB200

| Metric | No Offload | Simple Offload | Delta |
|---|---|---|---|
| Output throughput (tok/s) | 2020.81 | 2020.72 | -0.00% |
| Request throughput (req/s) | 1.97 | 1.97 | +0.00% |
| Mean TTFT (ms) | 541.71 | 541.42 | -0.05% |
| Mean TPOT (ms) | 5.15 | 5.18 | +0.58% |
| P99 TPOT (ms) | 9.52 | 9.51 | -0.10% |
| Mean ITL (ms) | 5.17 | 5.19 | +0.39% |
| P99 ITL (ms) | 58.04 | 58.70 | +1.14% |

Multi-turn with CPU KV cache hits (Llama-3.1-8B):

  • Llama-3.1-8B - GB200 - 400 GB host memory

| Metric | No Offload | Native | Simple Eager | Simple Lazy |
|---|---|---|---|---|
| Total throughput (tok/s) | 35881.52 | 41403.76 | 44908.33 | 45151.15 |
| Output throughput (tok/s) | 1504.77 | 1414.28 | 1451.68 | 1442.57 |
| Request throughput (req/s) | 2.94 | 2.77 | 2.84 | 2.82 |
| Mean TTFT (ms) | 6621.90 | 6509.47 | 6214.20 | 6462.97 |
| Mean TPOT (ms) | 32.13 | 24.01 | 21.48 | 21.93 |
| P99 TPOT (ms) | 59.22 | 62.49 | 49.01 | 50.49 |
| Mean ITL (ms) | 32.13 | 24.02 | 21.51 | 21.96 |
| P99 ITL (ms) | 166.22 | 179.63 | 170.18 | 176.08 |


@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces a significant and well-designed refactoring of the CPU KV cache offloading mechanism. The new SimpleCPUOffloadConnector simplifies the architecture by reusing existing components like BlockPool and KVCacheCoordinator, and it introduces efficient Triton-based copy operations. The code is generally well-structured and clear. My review identified two potential high-severity issues in the new scheduler manager related to state management during request preemption and CPU cache eviction, which could lead to a memory leak and incorrect behavior, respectively. These issues are noted with FIXMEs in the code, but I've provided detailed comments on their potential impact and suggestions for resolution.

Comment thread vllm/v1/simple_kv_offload/manager.py Outdated
Comment thread vllm/v1/simple_kv_offload/manager.py
@ivanium force-pushed the feat/simple-cpu-offload-cleanup branch 3 times, most recently from 8a59f1f to b2dbc9b on Mar 21, 2026
@ivanium marked this pull request as ready for review Mar 21, 2026
@mergify (Bot) commented Mar 21, 2026

Hi @ivanium, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

1 similar comment

Comment on lines +238 to +242
# kv_cache_manager is constructed so block_pool is available.
if self.connector is not None and hasattr(
    self.connector, "bind_gpu_block_pool"
):
    self.connector.bind_gpu_block_pool(self.kv_cache_manager.block_pool)
Collaborator:

I am not really happy with this sort of API change given the maturity we're trying to reach with the Connector interface contract.
Not to mention this won't work with MultiConnector.

Collaborator:

To clarify, editing the Connector interface is OK, but this is a hack.

Contributor (Author):

Yeah, this is intentional: it avoids changing the Connector interface for now and keeps this simple CPU offload backend an experimental feature without affecting the other connector backends. I know we have some ongoing plans for Connector API v2, and I think we can discuss/finalize API changes then.
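
To make the MultiConnector concern above concrete, a toy illustration (MultiConnectorSketch is a stand-in, not vLLM's actual MultiConnector): hasattr on the wrapper never sees a method defined only on a wrapped child connector.

# Toy illustration: duck typing on a wrapper misses child-only methods.
class ChildConnector:
    def bind_gpu_block_pool(self, block_pool):
        pass

class MultiConnectorSketch:
    def __init__(self, connectors):
        self._connectors = connectors  # delegates only known interface calls

wrapper = MultiConnectorSketch([ChildConnector()])
assert not hasattr(wrapper, "bind_gpu_block_pool")  # the scheduler's check fails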

Comment thread vllm/v1/core/sched/scheduler.py Outdated
if self.connector is not None and hasattr(
    self.connector, "has_pending_transfers"
):
    return self.connector.has_pending_transfers()
Collaborator:

ditto

@mergify (Bot) added the frontend and performance labels Mar 23, 2026
)

if hit_length > 0:
    return hit_length, True
Collaborator:

If we have n tokens, we can hit at most n-1 tokens; we need to recompute the last one to get the first logprob.

Collaborator:

That's fine, the scheduler will make the reduction when the load is done.

Collaborator:

Maybe I missed the code. Can you give me a code pointer?

And why can the scheduler make this reduction? For SWA with window size 100, a cache hit of 1000 tokens means tokens [900, 1000] are cached; it doesn't imply a cache hit of 999 tokens, which would need the KV of tokens [899, 999], because we don't know whether token 899 is cached.

Collaborator:

request.num_computed_tokens = request.num_tokens - 1

> For SWA with window size 100, a cache hit of 1000 tokens means tokens [900, 1000] are cached; it doesn't imply a cache hit of 999 tokens, which would need the KV of tokens [899, 999], because we don't know whether token 899 is cached.

Good point!
I think this code in the scheduler existed before we supported sliding windows.

Anyhow, I think the correct fix is possibly reducing max_hit_len by 1 BEFORE calling self.cpu_coordinator.find_longest_cache_hit. WDYT?

Contributor (Author):

Nice catch. I think I can change it to max_hit_len = request.num_tokens - 1 - num_computed_tokens?
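
For concreteness, a sketch of that proposed cap, wrapped as a standalone function (variable names follow the diff excerpts in this thread; the exact find_longest_cache_hit signature is an assumption):

# Sketch: never report a full-prompt hit, since the last token must be
# recomputed to produce the first logprob.
def capped_cpu_cache_hit(cpu_coordinator, remaining_hashes, block_size,
                         num_tokens, num_computed_tokens):
    max_hit_len = min(
        len(remaining_hashes) * block_size,
        num_tokens - 1 - num_computed_tokens,
    )
    _, hit_length = cpu_coordinator.find_longest_cache_hit(
        remaining_hashes, max_hit_len
    )
    return hit_length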

return 0, False

max_hit_len = len(remaining_hashes) * self.block_size
_, hit_length = self.cpu_coordinator.find_longest_cache_hit(
Collaborator:

Are you implementing lazy offloading? If yes, the common prefix of all prompts will never be offloaded, and you can get 0 cache hits in CPU.

@njhill (Member) left a comment

Thanks @ivanium for the great work and thanks @heheda12345 and @orozery for the really good reviews.

Comment thread vllm/v1/metrics/stats.py
Comment on lines +284 to +293
# FIXME(yifan): local_cache_hit can go negative after preemption.
# num_cached_tokens is a one-time snapshot from first scheduling and
# is never reset on preemption, while num_external_computed_tokens is
# overwritten on re-scheduling. If CPU offload finds more tokens on
# the second pass than the original total, the subtraction underflows.
# A fundamental fix is to track the first-time num_external_computed_tokens
# as a separate metric rather than reusing num_external_computed_tokens
# for the metric directly.
self.local_cache_hit += max(
    0, (num_cached_tokens + recomputed - num_external_computed_tokens)
)
Member (njhill):

I think this temporary hack fix is OK; it will hopefully be addressed properly soon via #37460.
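
For reference, a sketch of the fundamental fix the FIXME describes, i.e. snapshotting the external hit count once at first scheduling (first_external_computed_tokens is a hypothetical field for illustration; the real fix is tracked in #37460):

# Hypothetical fix sketch: snapshot once at first scheduling instead of
# reusing the mutable num_external_computed_tokens.
def record_local_cache_hit(stats, request, num_cached_tokens, recomputed,
                           num_external_computed_tokens):
    if request.first_external_computed_tokens is None:  # hypothetical field
        request.first_external_computed_tokens = num_external_computed_tokens
    stats.local_cache_hit += (
        num_cached_tokens + recomputed - request.first_external_computed_tokens
    )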

@github-project-automation (Bot) moved this to Ready in NVIDIA on Mar 31, 2026
@njhill merged commit 91e4521 into vllm-project:main Apr 1, 2026
67 checks passed
@github-project-automation (Bot) moved this from Ready to Done in NVIDIA on Apr 1, 2026
@khluu added this to the v0.19.0 cherry picks milestone Apr 1, 2026
khluu pushed a commit that referenced this pull request Apr 1, 2026
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
(cherry picked from commit 91e4521)
puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
…7160)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Rishi Puri <riship@nvidia.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
…7160)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
cquil11 added a commit to SemiAnalysisAI/InferenceX that referenced this pull request Apr 14, 2026
- Replace native offloading with SimpleCPUOffloadConnector
  (VLLM_USE_SIMPLE_KV_OFFLOAD=1 + --no-disable-hybrid-kv-cache-manager)
  for ~10% better throughput and TPOT per vllm-project/vllm#37160
- Remove local_cache_hit and scheduler.py monkey-patches (fixed in
  vLLM 0.19.0+), replace with version check warning
- Add AIPERF_SERVICE_PROFILE_CONFIGURE_TIMEOUT=1800 to H200 and B200
  (H100 already had it)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
cquil11 added a commit to SemiAnalysisAI/InferenceX that referenced this pull request Apr 15, 2026
VLLM_USE_SIMPLE_KV_OFFLOAD=1 routes to SimpleCPUOffloadConnector which
imports cuda.bindings (NVIDIA-only, PR vllm-project/vllm#37160). Remove
it from MI355X scripts so native offloading uses the ROCm-safe
OffloadingConnector. Also update H200 trace dir to use traces_neon with
env-var override to match the other trace replay scripts.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Labels

frontend, kv-connector, nvidia, performance, ready, v1
