[BugFix] Wait for compute before offloading KV to CPU#31341

Merged
njhill merged 2 commits intovllm-project:mainfrom
orozery:offloading-wait-stream
Jan 10, 2026

Conversation

@orozery (Collaborator) commented Dec 25, 2025

This PR fixes the OffloadingConnector to wait on the default (compute) stream before offloading KV data from the GPU to the CPU.
Additionally, we move the offloading to start at the beginning of the next engine step. This is to avoid contention with sample_tokens, which also involves GPU→CPU copies.
Lastly, we remove the use of stream priorities, as they don't affect DMA-based copies.
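The synchronization described above can be sketched as follows. This is a minimal illustration only, not the actual OffloadingConnector code; the helper name and the CPU-only fallback are assumptions for the example:

```python
import torch

def offload_blocks(src: torch.Tensor, dst_cpu: torch.Tensor,
                   offload_stream=None) -> None:
    """Copy KV blocks to CPU memory on a side stream, after waiting
    on the compute stream (hypothetical helper for illustration)."""
    if not torch.cuda.is_available():
        # CPU-only fallback so the sketch runs anywhere; the real
        # connector path is CUDA-only.
        dst_cpu.copy_(src)
        return
    offload_stream = offload_stream or torch.cuda.Stream()
    # Order the async copy after all work already queued on the
    # current (compute) stream, so the copy cannot read blocks that
    # the model forward pass is still writing.
    offload_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(offload_stream):
        dst_cpu.copy_(src, non_blocking=True)
```

The key line is `wait_stream`: without it, the copy engine may start reading KV blocks before the kernels that produce them have finished.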


Note

Implements safer, staged KV offloading and cleans up stream handling.

  • Worker: add start_kv_transfers and prepare_store_kv; queue store jobs in _unsubmitted_store_jobs and submit them at the next step before starting loads.
  • Connector: rename calls from start_load_kv → start_kv_transfers and start_store_kv → prepare_store_kv.
  • GPU/CPU handler: remove stream priority parameter; offload streams now wait_stream on the current (compute) stream for GPU→CPU copies; use default torch.cuda.Stream().
  • Tests: update expectations/bookkeeping to account for deferred store submission (unsubmitted_stores_count) and adjust assertions accordingly.

Written by Cursor Bugbot for commit 252e7205ce32fb041eda7fe61d2e49e73b237cc9. This will update automatically on new commits. Configure here.




@gemini-code-assist bot (Contributor) left a comment:

Code Review

This pull request correctly addresses a race condition between model computation and KV cache offloading by introducing a stream synchronization for GPU-to-CPU transfers. It also refactors the offloading logic to delay store operations to the next engine step, which should help reduce contention on the DMA engine. The removal of unused stream priorities is a good cleanup. The changes are well-implemented and the accompanying test modifications accurately reflect the new asynchronous behavior.

@LucasWilkinson (Collaborator) commented:

cc @NickLucche

@njhill (Member) left a comment:

Thanks @orozery, mostly looks good to me.

> Additionally, we move the offloading to start at the beginning of the next engine step. This is to avoid contention with sample_tokens which also involves gpu->cpu copies.

Actually with async scheduling (which is now default), the sampled token gpu->cpu copy would still be happening at this point.

It may be better to leave the transfer where it is, or possibly better have the offload stream also wait on the sampled_tokens stream (GPUModelRunner.async_output_copy_stream).

@orozery force-pushed the offloading-wait-stream branch from 4572662 to 252e720 (January 10, 2026 16:37)
@orozery (Collaborator, Author) commented Jan 10, 2026

> Actually with async scheduling (which is now default), the sampled token gpu->cpu copy would still be happening at this point.
>
> It may be better to leave the transfer where it is, or possibly better have the offload stream also wait on the sampled_tokens stream (GPUModelRunner.async_output_copy_stream).

I've been actually testing with both async_scheduling on and off.
Let me elaborate on what I've learned (from experimentation on H100 + online searching):

cudaMemcpyAsync calls, unlike CUDA kernels, are handled FIFO (in their order of submission).
There's a single copy engine per direction (e.g. CPU->GPU) which handles them.

My initial try was to start offloading right at the end of execute_model (adding a wait_stream on the compute stream).
This caused the transfer of sampled tokens to be blocked until all offloading finished.
If async scheduling is off, this was pretty bad: sample_tokens itself was blocking, and the core engine was blocked.
This harmed the latencies seen by the user (TTFT, ITL).
With async scheduling enabled, sample_tokens was no longer blocking, so results were much better.
Still, I prefer the offloading to start AFTER sample_tokens, as the sampled tokens are of higher priority than the offloading from the user's point of view.
Ideally, we would have a dedicated hook which allows us to submit the offloading cudaMemcpyAsync calls right after the async sample_tokens copies are submitted.
This is a bit tricky since only the output_rank worker actually calls sample_tokens, and I did not want to make too many invasive changes for now.

I think the best solution for now is what's currently implemented in this PR, i.e. moving the offloading before the start of the next execute_model.
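The FIFO behavior described above can be modeled with a toy simulation. This is pure Python of my own for illustration; the class and names are not from vLLM or CUDA:

```python
from collections import deque

class CopyEngine:
    """Toy model of a single per-direction DMA copy engine: requests
    are served strictly in submission (FIFO) order, regardless of
    which stream submitted them or any stream priority."""

    def __init__(self):
        self._queue = deque()
        self.completed = []

    def memcpy_async(self, name: str, priority: int = 0) -> None:
        # Priority is accepted but deliberately ignored, mirroring the
        # observation that stream priorities don't affect DMA copies.
        self._queue.append(name)

    def drain(self) -> list:
        # Serve all queued copies in submission order.
        while self._queue:
            self.completed.append(self._queue.popleft())
        return self.completed

engine = CopyEngine()  # the single GPU->CPU engine
# Offloading copies are submitted first...
engine.memcpy_async("offload_kv", priority=0)
# ...so even a higher-priority sampled-token copy runs after them.
engine.memcpy_async("sampled_tokens", priority=-1)
print(engine.drain())  # -> ['offload_kv', 'sampled_tokens']
```

This is why deferring the offload submission until after sample_tokens changes who "wins": submission order, not priority, decides.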

@njhill (Member) commented Jan 10, 2026

> cudaMemcpyAsync calls, unlike CUDA kernels, are handled FIFO (in their order of submission).
> There's a single copy engine per direction (e.g. CPU->GPU) which handles them.

Do you mean regardless of cuda streams? In which case the situation we will have here is that both the offloading stream and sampled token stream will be waiting for the execution stream to reach the same point. At which point they will commence their async copy operations. It's not clear to me though which one will "win" and start first - not obvious that it would be the one which started waiting first...

> This is a bit tricky since only the output_rank worker actually calls sample_tokens, and I did not want to make too many invasive changes for now.

This is only applicable to PP which is not compatible with async scheduling yet anyhow.

@njhill njhill self-assigned this Jan 10, 2026
@orozery (Collaborator, Author) commented Jan 10, 2026

> Do you mean regardless of cuda streams? In which case the situation we will have here is that both the offloading stream and sampled token stream will be waiting for the execution stream to reach the same point. At which point they will commence their async copy operations. It's not clear to me though which one will "win" and start first - not obvious that it would be the one which started waiting first...

Exactly. Both wait on the same stream, yet offloading wins as it was queued first.
I tried setting stream priorities, giving the sample_tokens stream a higher priority, but this made no difference.
Online sources claimed that priorities only affect kernel scheduling, excluding copies that use the copy engines.

@njhill (Member) commented Jan 10, 2026

> Do you mean regardless of cuda streams? In which case the situation we will have here is that both the offloading stream and sampled token stream will be waiting for the execution stream to reach the same point. At which point they will commence their async copy operations. It's not clear to me though which one will "win" and start first - not obvious that it would be the one which started waiting first...
>
> Exactly. Both wait on the same stream, yet offloading wins as it was queued first. I was trying to set stream priorities, giving the sample_tokens stream a higher priority, this made no difference. And then search claimed that priorities only effect kernels scheduling, excluding copies using the copy engines.

But it's not obvious that it's a queue, both of the streams are "woken up" at the same time, at which point they both proceed to call cudaMemcpyAsync. IIUC this is the point at which they are effectively queued, so it's still not clear to me that we'll be sure that the sampled token stream will always win.

I'm now not suggesting moving the hook, just that we should ideally change the stream.wait_stream(torch.cuda.current_stream()) in this PR to stream.wait_stream(async_output_copy_stream). But I know that the complication is that the connector doesn't have access to this rn.

@orozery (Collaborator, Author) commented Jan 10, 2026

> I'm now not suggesting moving the hook, just that we should ideally change the stream.wait_stream(torch.cuda.current_stream()) in this PR to stream.wait_stream(async_output_copy_stream). But I know that the complication is that the connector doesn't have access to this rn.

Nice idea (waiting on the sample_tokens stream).
I was actually thinking of adding a new connector API hook after sample_tokens.
But your idea is neater :)

@njhill (Member) commented Jan 10, 2026

Having a hook after sample_tokens would not be any different from doing it at the start of the next step as you have here, right?

Either way you would still need to wait on the output copy stream.

@orozery (Collaborator, Author) commented Jan 10, 2026

> Having hook after sample tokens would not be any different to doing it at the start of the next step as you have here.
>
> Either way you would still need to wait on the output copy stream.

I take that back.
We still need to call swap_blocks only after sample_tokens is called.
Waiting on the sample_tokens stream at execute_model will do nothing, since the sample_tokens stream only gets populated at sample_tokens.
So we either wait for the next step, or add a new hook after sample_tokens.
My feeling (based on experimentation) is that this is enough: even if sample_tokens loses the race, it is not a correctness issue.
But if the connector has access to the sample_tokens stream, I agree it's better to wait on it.

@njhill (Member) commented Jan 10, 2026

> But if the connector has access to the sample tokens stream I agree it's better to wait on it.

We could add the copy stream as a field in ForwardContext, wdyt? This could be set to either the async output stream or current stream if that's None (async sched disabled case)
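A minimal sketch of this suggestion, with the None fallback for the async-scheduling-disabled case. Field and function names are hypothetical, and streams are stand-in objects here:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ForwardContext:
    """Sketch of carrying the sampled-token output-copy stream in the
    forward context so connectors can wait on it (names assumed)."""
    # None when async scheduling is disabled.
    async_output_copy_stream: Optional[object] = None

def stream_to_wait_on(ctx: ForwardContext, current_stream: object) -> object:
    """Wait on the async output-copy stream if one is set, else fall
    back to the current (compute) stream."""
    return ctx.async_output_copy_stream or current_stream
```

With this, the connector's offload stream would call `wait_stream(stream_to_wait_on(ctx, torch.cuda.current_stream()))` instead of waiting on the compute stream directly.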

@orozery (Collaborator, Author) commented Jan 10, 2026

> We could add the copy stream as a field in ForwardContext, wdyt? This could be set to either the async output stream or current stream if that's None (async sched disabled case)

I'm not so familiar with ForwardContext.
I think it was the mechanism for passing the KV tensors to v0 connectors?
I also recall it being accessed via a global variable.
Frankly, I did not like this pattern, as it was hard to figure out for someone who doesn't know it.
I prefer more explicit APIs.

Also, the stream is constant throughout the process lifetime, so it feels more right to pass it once at connector bootstrap, like in register_kv_caches.
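The bootstrap-time alternative could look like the following. Class and parameter names are assumptions based on the discussion, and the KV caches and stream are stand-in objects:

```python
class OffloadingConnectorWorker:
    """Sketch of passing the output-copy stream once at connector
    bootstrap, since it is constant for the process lifetime
    (extended register_kv_caches signature is hypothetical)."""

    def __init__(self):
        self._kv_caches = None
        self._output_copy_stream = None

    def register_kv_caches(self, kv_caches, output_copy_stream=None):
        # Besides the KV cache tensors, record the stream that
        # sampled-token copies are issued on, so offload streams can
        # wait on it later without reaching into the model runner.
        self._kv_caches = kv_caches
        self._output_copy_stream = output_copy_stream
```

This keeps the dependency explicit at setup rather than threading it through a per-step context.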

@njhill (Member) commented Jan 10, 2026

> Also, the stream is constant throughout the process lifetime, so it feels more right to pass it once at connector bootstrap, like in register_kv_caches.

Good point, maybe we can extend that method. I guess we could consider that as a follow-on.

@@ -92,7 +92,7 @@ def save_kv_layer(
     def wait_for_save(self):
         assert self.connector_worker is not None
         assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
-        self.connector_worker.start_store_kv(self._connector_metadata)
+        self.connector_worker.prepare_store_kv(self._connector_metadata)

njhill (Member) commented on the diff:

Add a short comment here to explain why we are deferring?

orozery (Collaborator, Author) replied:

This class is just a redirection class.
I've added a comment instead in OffloadingConnectorWorker.prepare_store_kv where I think it makes more sense.
Let me know what you think.

njhill (Member) replied:

I guess I was thinking it would make more sense here, since the name of the method being called now is prepare_store_kv rather than "transfer" or "start", and the comment is about the hook itself, so it's clearer that this is in the load rather than the save hook.

But it's not that big deal :)

@orozery force-pushed the offloading-wait-stream branch from 252e720 to 369cb4b (January 10, 2026 20:10)

mergify bot commented Jan 10, 2026

Hi @orozery, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

This commit fixes the OffloadingConnector to wait on the default (compute) stream,
before offloading KV data from the GPU to the CPU.
Additionally, we move the offloading to start at the beginning of the next engine step.
This is to avoid contention with sample_tokens which also involves gpu->cpu copies.
Lastly, we remove the use of stream priorities as they don't affect DMA-based copies.

Signed-off-by: Or Ozeri <oro@il.ibm.com>
@orozery force-pushed the offloading-wait-stream branch from 369cb4b to f6640eb (January 10, 2026 20:18)
@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 10, 2026
@njhill njhill enabled auto-merge (squash) January 10, 2026 20:34
@njhill njhill added this to the v0.14.0 milestone Jan 10, 2026
@njhill njhill merged commit 2a4dbe2 into vllm-project:main Jan 10, 2026
53 of 54 checks passed
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…1341)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1
