
[BugFix] kv_offloading: Fix bug in loading of partial cpu blocks#28951

Merged
DarkLight1337 merged 2 commits into vllm-project:main from orozery:cpu-offloading-partial-block-load-bugfix
Nov 20, 2025

Conversation

Collaborator

@orozery orozery commented Nov 18, 2025

Fixes #28950

This PR fixes a bug that occurs when loading from the middle of a CPU block. This can happen when cpu_block_size > gpu_block_size and there is both a CPU and a GPU (prefix cache) hit, with the GPU hit ending in the middle of a CPU block. Before this commit, the code wrongly handled the opposite direction, as if storing to the middle of a CPU block. That case is impossible, since the offloading connector always stores full CPU blocks.
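To make the fixed direction concrete, here is a minimal sketch (not vLLM's actual code; all names and the helper itself are hypothetical) of building a CPU-to-GPU copy mapping that skips the leading sub-blocks on the source side when a GPU prefix-cache hit ends mid-CPU-block:

```python
# Illustrative sketch (not vLLM's actual code) of the fixed direction:
# when cpu_block_size > gpu_block_size and a GPU prefix-cache hit ends
# mid-CPU-block, the leading sub-blocks are skipped on the SOURCE side.
def build_load_mapping(
    cpu_blocks: list[int],     # hypothetical CPU block ids to load from
    gpu_blocks: list[int],     # hypothetical destination GPU block ids
    cpu_block_size: int,
    gpu_block_size: int,
    num_skip_sub_blocks: int,  # sub-blocks already covered by the GPU hit
) -> list[tuple[int, int]]:
    assert cpu_block_size % gpu_block_size == 0
    sub_per_cpu = cpu_block_size // gpu_block_size
    # Flatten CPU blocks into GPU-sized source sub-blocks.
    src_subs = [
        cpu * sub_per_cpu + i for cpu in cpu_blocks for i in range(sub_per_cpu)
    ]
    # The fix: skip on the source (CPU) side, never on the destination.
    src_subs = src_subs[num_skip_sub_blocks:]
    assert len(src_subs) == len(gpu_blocks)
    return list(zip(src_subs, gpu_blocks))
```

For example, with cpu_block_size=64, gpu_block_size=16, one CPU block (id 5), a GPU hit covering one sub-block, and destination GPU blocks [10, 11, 12], this yields [(21, 10), (22, 11), (23, 12)]: the first source sub-block (20) is skipped, not any destination block.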

@mergify mergify bot added the v1 label Nov 18, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug in loading partial CPU blocks for KV cache offloading. The core logic is adjusted to correctly skip sub-blocks from the source (CPU) rather than the destination, and the test suite is updated to validate this scenario. However, this change introduces a critical issue where the block mapping array src_to_dst is allocated with an incorrect size. This can lead to reading uninitialized memory and subsequent incorrect data transfers. I have provided a specific code suggestion to rectify this allocation bug.
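The sizing pitfall the review points at can be sketched as follows; this is a hypothetical illustration (the names src_to_dst and build_src_to_dst are illustrative, not vLLM's actual code) of allocating the mapping array to exactly the number of pairs transferred, so no uninitialized entries can be read:

```python
import numpy as np

# Hypothetical sketch of the sizing concern raised in the review.
def build_src_to_dst(src_sub_blocks: list[int],
                     dst_blocks: list[int]) -> np.ndarray:
    # Size the mapping by the number of pairs actually transferred,
    # not by the (larger) total sub-block count of the source blocks;
    # oversizing would leave trailing rows of np.empty uninitialized.
    n = len(dst_blocks)
    assert len(src_sub_blocks) == n
    src_to_dst = np.empty((n, 2), dtype=np.int64)
    src_to_dst[:, 0] = src_sub_blocks
    src_to_dst[:, 1] = dst_blocks
    return src_to_dst
```

Since np.empty does not zero its memory, any row beyond the filled range would hold arbitrary values, which is exactly the uninitialized-read risk described above.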


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@orozery orozery force-pushed the cpu-offloading-partial-block-load-bugfix branch from b7ad215 to ee67691 on November 18, 2025 17:24
This commit fixes a bug when trying to load from the middle of a CPU block.
This can happen if cpu_block_size > gpu_block_size,
and there's both a cpu and gpu (prefix cache) hit,
where the gpu hit ends in the middle of a cpu block.
Before this commit, the code tried to wrongfully address the other direction,
storing to the middle of a cpu block. But this is impossible since the offloading connector
always stores full CPU blocks.

Signed-off-by: Or Ozeri <oro@il.ibm.com>
@orozery orozery force-pushed the cpu-offloading-partial-block-load-bugfix branch from ee67691 to b5a0482 on November 18, 2025 17:45
Collaborator Author

orozery commented Nov 18, 2025

@ApostaC can you please have a look?
Thanks!

Collaborator

@ApostaC ApostaC left a comment


LGTM.

One quick question: is there any chance that the CPU block size is smaller than the GPU block size? In this case, we also need to skip dst (GPU) sub blocks when doing CPU -> GPU transfer.

Collaborator Author

orozery commented Nov 19, 2025

One quick question: is there any chance that the CPU block size is smaller than the GPU block size? In this case, we also need to skip dst (GPU) sub blocks when doing CPU -> GPU transfer.

It's impossible.
The offloading connector is built on the assumption that the offloaded block size is a multiple of the GPU block size.

See here:

assert self.offloaded_block_size % self.gpu_block_size == 0
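For illustration, the quoted invariant already rules out the reviewer's scenario: a positive multiple of the GPU block size is never smaller than the GPU block size. A minimal sketch, using hypothetical sizes chosen only for this example:

```python
# Hypothetical sizes, chosen only to illustrate the invariant quoted above.
offloaded_block_size = 256  # CPU block size in tokens (assumed)
gpu_block_size = 16         # GPU block size in tokens (assumed)

assert offloaded_block_size % gpu_block_size == 0
# A positive multiple is at least as large as the divisor, so CPU blocks
# are never smaller than GPU blocks and destination-side (GPU) sub-block
# skipping can never be required.
sub_blocks_per_cpu_block = offloaded_block_size // gpu_block_size
assert sub_blocks_per_cpu_block >= 1
```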

@njhill njhill changed the title kv_offloading: Fix bug in loading of partial cpu blocks [BugFix] kv_offloading: Fix bug in loading of partial cpu blocks Nov 19, 2025
@njhill njhill added ready ONLY add when PR is ready to merge/full CI is needed bug Something isn't working labels Nov 19, 2025
@DarkLight1337 DarkLight1337 merged commit c0c2dd1 into vllm-project:main Nov 20, 2025
42 checks passed
lpapavassiliou pushed a commit to lpapavassiliou/vllm that referenced this pull request Nov 24, 2025
…m-project#28951)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
RunkaiTao pushed a commit to RunkaiTao/vllm that referenced this pull request Nov 24, 2025
…m-project#28951)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: Runkai Tao <rt572@physics.rutgers.edu>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
…m-project#28951)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
…m-project#28951)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

Labels

  • bug — Something isn't working
  • ready — ONLY add when PR is ready to merge/full CI is needed
  • v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: assert when offloading to cpu

4 participants