CPU KV Offloading: Use more CUDA streams#29013

Merged
njhill merged 2 commits into vllm-project:main from orozery:cpu-kv-offloading-streams on Dec 14, 2025
Conversation

@orozery (Collaborator) commented Nov 19, 2025

Prior to this PR, the CPU offloading connector used 2 global CUDA streams, one for loads and one for stores.
When there are a lot of concurrent loads, the stream queue (which has a capacity of about 1000 on A100) can get full.
When this happens, submitting more operations to the stream becomes blocking, which is harmful to model execution.
To overcome this, this PR changes the CPU KV offloading to use a unique CUDA stream per transfer.

I verified this is a real issue by timing cudaMemcpyAsync calls: the call latency increases from 2µs to 180µs when the stream queue gets full.
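The per-transfer pooling approach can be sketched as follows. This is a minimal illustration with hypothetical names, not the actual vLLM implementation; in practice the factory would be something like `lambda: torch.cuda.Stream(priority=...)`:

```python
class StreamPool:
    """Hands out a dedicated stream per transfer and reuses finished ones,
    so no single stream's work queue (~1000 pending ops on A100) fills up."""

    def __init__(self, factory):
        # factory creates a new stream when the pool is empty,
        # e.g. lambda: torch.cuda.Stream(priority=priority)
        self._factory = factory
        self._pool = []

    def get(self):
        # Reuse a returned stream if available; otherwise create a new one.
        return self._pool.pop() if self._pool else self._factory()

    def put(self, stream):
        # Return the stream for reuse once its transfer has completed.
        self._pool.append(stream)
```

Because a stream is only returned to the pool after its transfer completes, the pool size naturally tracks the peak number of in-flight transfers rather than growing without bound.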

mergify bot commented Nov 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the CPU KV offloading to use a unique CUDA stream per transfer, which is a good approach to prevent stream queue saturation and potential blocking during concurrent loads. The implementation correctly uses pooling for streams and events to improve efficiency. However, I've identified a critical resource leak issue where streams and events are not returned to the pool if an error occurs during the transfer. I've provided a code suggestion to fix this. Additionally, it appears the test file tests/v1/kv_offload/test_cpu_gpu.py has not been updated to reflect the changes in this PR (e.g., renaming transfer_events to transfers), which will cause tests to fail. Please ensure the tests are updated accordingly.
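The resource leak the review describes has a standard fix: acquire pooled resources, then return them in a `finally` block so they come back to the pool even when the transfer raises. A hedged sketch, with hypothetical names and plain lists standing in for the stream and event pools:

```python
def do_transfer(stream_pool, event_pool, transfer_fn,
                make_stream=object, make_event=object):
    """Acquire a pooled stream and event, run the transfer, and always
    return both resources to their pools -- even on an error path."""
    stream = stream_pool.pop() if stream_pool else make_stream()
    event = event_pool.pop() if event_pool else make_event()
    try:
        return transfer_fn(stream, event)
    finally:
        # Returning in `finally` is what fixes the leak: without it, an
        # exception in transfer_fn would drop the stream and event forever.
        stream_pool.append(stream)
        event_pool.append(event)
```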


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@orozery orozery force-pushed the cpu-kv-offloading-streams branch 2 times, most recently from 82346cf to ec780a6 on November 19, 2025 13:31
@mergify mergify bot removed the needs-rebase label Nov 19, 2025
@orozery (Collaborator, Author) commented Nov 19, 2025

@ApostaC

@orozery orozery force-pushed the cpu-kv-offloading-streams branch 3 times, most recently from a135689 to 25d1378 on November 27, 2025 06:22

@njhill njhill left a comment


Thanks @orozery, looks great just a couple of minor comments.

Comment on lines +133 to +137:

        stream = (
            self._stream_pool.pop()
            if self._stream_pool
            else torch.cuda.Stream(priority=self.priority)
        )
@njhill (Member):

Should we keep a bound on the max number of streams (start to share after this)? Or is the number of concurrent transfers already bounded somehow?

@orozery (Collaborator, Author):

I don't think there's a reason to limit.
On the GPU side, there seems to be no limit on the number of streams.
Also, note that the streams do not run in parallel: each transfer stream waits for the previous transfer to finish.
The only purpose of creating multiple streams is to overcome the 1024-ops-per-stream limit.
I tested this PR with 100,000 concurrent small requests loading from CPU, and it handled them well.
It does not actually create 100,000 streams: the GPU KV cache fills up, so only around 1,500 requests are allocated at any point in time, meaning the test effectively used about 1,500 streams with no issues.

I am planning to limit the number of requests being asynchronously loaded so that they count towards max_num_seqs (which equals 1024 for H100); see draft PR #29877.
Once that lands, the number of streams will be bounded by roughly max_num_seqs.
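The serialization described above — per-transfer streams chained so that each waits for the previous transfer to finish — can be modeled on the CPU with `threading.Event` standing in for CUDA events. This is an illustrative analogue of the chaining pattern, not the vLLM code; in CUDA terms each wait would be a `stream.wait_event(prev_event)` followed by an `event.record()`:

```python
import threading

def run_chained_transfers(transfer_fns):
    """Run each transfer on its own thread ("stream"), but chain them with
    events so transfer N only starts after transfer N-1 completes."""
    prev_event = None
    threads = []
    order = []  # records completion order to show serialization
    for i, fn in enumerate(transfer_fns):
        done = threading.Event()
        def worker(fn=fn, prev=prev_event, done=done, i=i):
            if prev is not None:
                prev.wait()   # wait on the previous transfer's event
            fn()              # the transfer itself (e.g. an async memcpy)
            order.append(i)
            done.set()        # signal completion to the next transfer
        t = threading.Thread(target=worker)
        t.start()
        threads.append(t)
        prev_event = done
    for t in threads:
        t.join()
    return order

print(run_chained_transfers([lambda: None] * 5))  # → [0, 1, 2, 3, 4]
```

Even though five separate threads run, the event chain forces strictly sequential completion — mirroring how the per-transfer CUDA streams here add capacity without adding parallelism.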

@njhill (Member) commented Dec 13, 2025

Also pinging @ApostaC again in case he wants to check

Prior to this commit, the CPU offloading connector used 2 global CUDA streams, one for loads and one for stores.
When there are a lot of concurrent loads, the stream queue (which has a capacity of about 1000 on A100) can get full.
When this happens, submitting more operations to the stream becomes blocking, which is harmful to model execution.
To overcome this, this commit changes the CPU KV offloading to use a unique CUDA stream per transfer.
Additionally, we add code that detects the kernel block size and fails if it differs from the user-set block size.
Support for general kernel block sizes will be added in a follow-up PR.

Signed-off-by: Or Ozeri <oro@il.ibm.com>
@orozery orozery force-pushed the cpu-kv-offloading-streams branch from 25d1378 to a2048af on December 13, 2025 18:27
@njhill njhill added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) Dec 14, 2025
@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Dec 14, 2025
@njhill njhill enabled auto-merge (squash) December 14, 2025 02:04
@njhill njhill merged commit 174e39e into vllm-project:main Dec 14, 2025
46 of 47 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Dec 14, 2025
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Dec 15, 2025
joa-stdn pushed a commit to joa-stdn/vllm that referenced this pull request Dec 15, 2025
Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: Joachim Studnia <joachim@mistral.ai>
teddygood pushed a commit to teddygood/vllm that referenced this pull request Dec 16, 2025
Comment on the same snippet:

        stream = (
            self._stream_pool.pop()
            if self._stream_pool
            else torch.cuda.Stream(priority=self.priority)
        )
        event = self._event_pool.pop() if self._event_pool else torch.Event()
@lixiaobai09:

Hi, there seems to be a problem.

In the swapping-out situation, this workflow does not guarantee that the swap operation occurs after the forward computation, potentially resulting in the incorrect KV cache being swapped out.

So maybe we can insert lines like these?

        current_forward_stream = torch.cuda.current_stream()
        if (src_spec.medium() == "gpu") and (dst_spec.medium() == "cpu"):
            stream.wait_stream(current_forward_stream)

@orozery (Collaborator, Author):

Even before this PR, the offloading connector used a separate CUDA stream for executing the transfers.
This stream was not synced with the model execution stream.
This is also the case for other connectors.
Looking at the connector API itself, specifically save_kv_layer and wait_for_save, my assumption was that when these functions are called we can be sure the KV cache is safe to read (without syncing on the default stream).
@NickLucche @njhill can you confirm?

@orozery (Collaborator, Author):

Digging more into the code, it looks like you're right, @lixiaobai09.
This is indeed a bug, though not introduced by this PR; it has existed since the beginning of the OffloadingConnector.
I just opened #31341 to fix it.
@lixiaobai09 thank you so much for pointing this out!

Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026

Labels

nvidia, ready (ONLY add when PR is ready to merge/full CI is needed), v1

Projects

Status: Done

3 participants