OffloadingConnector: Add cpu_bytes_to_use configuration#24498

Merged
NickLucche merged 1 commit into vllm-project:main from orozery:num-cpu-blocks
Jan 12, 2026

Conversation

@orozery
Collaborator

@orozery orozery commented Sep 9, 2025

This PR replaces the OffloadingConnector size configuration num_cpu_blocks with cpu_bytes_to_use.
This allows for a more intuitive space allocation (per vLLM instance, across workers).


Note

Modernizes KV offloading configuration and wiring.

  • Replace num_cpu_blocks/kv_bytes_per_rank with instance-wide cpu_bytes_to_use for OffloadingConnector (docs and configs)
  • CPUOffloadingSpec now derives num_blocks from KVCacheConfig (page size, tensors, world size) and block_size; requires passing kv_cache_config into OffloadingSpecFactory.create_spec(...) and OffloadingConnector
  • Update VllmConfig._post_init_kv_transfer_config to set cpu_bytes_to_use = kv_offloading_size * (1 << 30) (no per-rank split); LMCache path unchanged
  • Tests adjusted for new config keys and spec signatures (MockOffloadingSpec, unit/integration offloading tests)
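
The derivation described in these bullets can be sketched roughly as follows; the function and parameter names are illustrative stand-ins, not the actual vLLM API.

```python
# Hypothetical sketch of deriving the CPU block count from an
# instance-wide byte budget. All names here are illustrative.

def derive_num_cpu_blocks(
    cpu_bytes_to_use: int,      # instance-wide CPU budget in bytes
    page_size_bytes: int,       # bytes of one GPU-sized block per KV tensor
    num_kv_cache_tensors: int,  # e.g. len(kv_cache_config.kv_cache_tensors)
    world_size: int,            # number of ranks sharing the budget
) -> int:
    # Bytes one GPU-sized block occupies across the whole instance.
    kv_bytes_per_block = page_size_bytes * num_kv_cache_tensors * world_size
    return cpu_bytes_to_use // kv_bytes_per_block

# 1 GiB budget, 4 KiB pages, 32 KV tensors, 2 ranks -> 4096 blocks
print(derive_num_cpu_blocks(1 << 30, 4096, 32, 2))  # → 4096
```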

Written by Cursor Bugbot for commit 8208b41.


Collaborator

@ApostaC ApostaC left a comment


Otherwise LGTM.

Comment on lines +213 to +214
num_cpu_blocks = (int(vllm_config.cache_config.swap_space_bytes) //
                  kv_cache_configs[0].kv_bytes_per_block)
Collaborator


There is also a knob called offloaded_block_size in #22595. IIUC, it also impacts the calculation of num_cpu_blocks, right? (i.e., if we have larger CPU blocks, we should have fewer CPU blocks)
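
The inverse relationship raised here can be illustrated with a toy calculation (all names and numbers below are hypothetical, not from the PR):

```python
# With a fixed CPU byte budget, growing the offloaded block size
# shrinks the block count proportionally. Illustrative names only.
def num_cpu_blocks(total_bytes: int, kv_bytes_per_token: int,
                   offloaded_block_size: int) -> int:
    return total_bytes // (kv_bytes_per_token * offloaded_block_size)

budget = 1 << 30   # 1 GiB
per_token = 2048   # assumed KV bytes per token
print(num_cpu_blocks(budget, per_token, 16))   # → 32768 (16-token blocks)
print(num_cpu_blocks(budget, per_token, 256))  # → 2048 (16x fewer blocks)
```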

Collaborator Author


In v0, the offloading was part of the core.
My suggestion for v1 is to have the offloading as a connector.
I wanted to follow the convention for connectors, where all of their arguments are actually defined in their kv_connector_extra_config.

However, deriving num_cpu_blocks from some kind of a swap_space parameter requires knowledge of kv_bytes_per_block.
So basically, I need my connector (both scheduler-side and worker-side) to be aware of kv_bytes_per_block.
This requires changing things in core, so I tried to make minimal changes and came up with the approach here:

For the scheduler-side connector, report kv_bytes_per_block by setting the existing V0 field num_cpu_blocks.
For the worker-side connector, pass on kv_cache_configs via the register_kv_caches function (in a follow-up PR).

When the offloading connector gets this num_cpu_blocks (given in GPU block size), it can derive the actual num_cpu_blocks by dividing by block_size_factor.
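
The conversion described here amounts to a single integer division (a sketch with illustrative names):

```python
# num_cpu_blocks arrives measured in GPU-sized blocks; the connector's
# CPU blocks are block_size_factor GPU blocks each. Illustrative sketch.
def actual_num_cpu_blocks(num_gpu_sized_blocks: int,
                          block_size_factor: int) -> int:
    return num_gpu_sized_blocks // block_size_factor

print(actual_num_cpu_blocks(1000, 4))  # → 250 CPU-sized blocks
```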

To sum up, I'm trying to make minimal changes to the core.
This results in the actual offloading configuration parameters split between vllm_config.cache_config and kv_connector_extra_config.

I'm good with taking a different approach.
Your thoughts?
Perhaps we should ask other relevant folks on their opinion here?

Collaborator


Yeah. This is a good point. I think at a high level, there should be two parameters that can be configured by users: (1) total_cpu_buffer_size and (2) cpu_buffer_block_size (how many tokens in each CPU block).

For (1), it's also worth thinking about whether it's per rank or per vLLM instance (i.e., summed across all ranks). If it's per rank, it's probably better to pass it in the KV connector config, while a "global" cache size makes more sense when it's configured by global options like --swap-space.

For (2), I think it should definitely be put into the KV connector config as it's the current CPU-offloading-connector-specific configuration.

To sum up, I feel like putting all the configs into the KV connector config will probably be better and less confusing. WDYT?
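
A connector-config-only layout along these lines might look like the sketch below; the two parameter names are the ones proposed in this comment, not the final API (the PR ultimately landed on cpu_bytes_to_use):

```python
# Hypothetical all-in-connector configuration; key names follow the
# proposal above, not the merged API.
kv_transfer_config = {
    "kv_connector": "OffloadingConnector",
    "kv_role": "kv_both",
    "kv_connector_extra_config": {
        "total_cpu_buffer_size": 4 * (1 << 30),  # instance-wide, 4 GiB
        "cpu_buffer_block_size": 64,             # tokens per CPU block
    },
}
print(kv_transfer_config["kv_connector_extra_config"])
```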

@mergify

mergify bot commented Sep 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 16, 2025
@orozery
Collaborator Author

orozery commented Sep 18, 2025

I'm planning to change this PR next week to share the raw kv_bytes_per_block with the connectors (both scheduler- and worker-side), and to drop the use of the legacy swap_block_bytes.

@orozery orozery changed the title v1: Set num_cpu_blocks on VllmConfig OffloadingConnector: Add swap_space_bytes configuration Sep 30, 2025
@mergify mergify bot added documentation Improvements or additions to documentation tpu Related to Google TPUs kv-connector labels Sep 30, 2025
@mergify mergify bot removed the needs-rebase label Sep 30, 2025
@orozery
Collaborator Author

orozery commented Dec 17, 2025

> Thanks for doing this @orozery !
>
> Generally speaking, I am not a huge fan of the pattern we're using to update the config as a way to share global values that are runtime dependent.

Following your comment, I've spent some time thinking about it.
I still think we need the kv_bytes_per_block field added to KVCacheConfig (and @heheda12345 can surely confirm or prove me wrong, as it is due to hybrid models).
On the other hand, I was able to avoid introducing new lazy-initialized fields to the cache_config (by noticing that the connector API on both sides has access to KVCacheConfig via its init).

So the PR is now modified and much leaner.
Basically, the only "intrusive" part is in kv_cache_utils.py (and one line in kv_cache_interface.py).

@heheda12345
Collaborator

heheda12345 commented Dec 18, 2025

For passing in KVCacheConfig, we now have an optional kv_cache_config field in KVConnectorBase_V1. And for page size per block, you can get it from kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes. Now the page_size_bytes of all kv_cache_groups are the same and you can also add an assert for this assumption.

@orozery
Collaborator Author

orozery commented Dec 18, 2025

> For passing in KVCacheConfig, we now have an optional kv_cache_config field in KVConnectorBase_V1. And for page size per block, you can get it from kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes. Now the page_size_bytes of all kv_cache_groups are the same and you can also add an assert for this assumption.

Thanks!

This is what I do now:

        page_sizes = {
            kv_cache_group.kv_cache_spec.page_size_bytes
            for kv_cache_group in kv_cache_config.kv_cache_groups
        }
        assert len(page_sizes) == 1
        page_size_bytes = page_sizes.pop()
        kv_bytes_per_block = (
            page_size_bytes
            * len(kv_cache_config.kv_cache_tensors)
            * vllm_config.parallel_config.world_size
        )

@heheda12345 does this look ok to you?

@NickLucche I removed all intrusive changes, so hopefully we're good to go?

@heheda12345
Collaborator

Yes, sounds good!

Collaborator

@ApostaC ApostaC left a comment


Thanks for the update. Removing my previous "request changes"

@ApostaC ApostaC dismissed their stale review January 5, 2026 19:22

Dismiss my previous outdated reviews

@esmeetu
Member

esmeetu commented Jan 8, 2026

@ApostaC @heheda12345 Is this PR ready to be merged? Hopefully it can go into v0.14.0.

Collaborator

@NickLucche NickLucche left a comment


Leaner, thanks for the work @orozery!


```bash
# before
--kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 64, "num_cpu_blocks": 1000}}'
# after
--kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 64, "cpu_bytes_to_use": 1000000000}}'
```
Collaborator


nit: we should probably follow up with a more human-readable way of expressing the value

Collaborator Author


I'm open for suggestions :)

Collaborator

@NickLucche NickLucche Jan 12, 2026


Thinking about unifying with the max-model-len format.
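
One possible shape for such a human-readable format, purely as a sketch (this is not part of the PR, and vLLM's eventual syntax may differ):

```python
import re

# Parse sizes like "16GiB" or "512MB" into bytes. Hypothetical sketch;
# treats k/m/g/t as binary multiples for simplicity.
_UNITS = {"": 1, "k": 1 << 10, "m": 1 << 20, "g": 1 << 30, "t": 1 << 40}

def parse_size(text: str) -> int:
    m = re.fullmatch(r"(\d+(?:\.\d+)?)\s*([kmgt]?)i?b?", text.strip().lower())
    if m is None:
        raise ValueError(f"unrecognized size: {text!r}")
    return int(float(m.group(1)) * _UNITS[m.group(2)])

print(parse_size("1GiB"))   # → 1073741824
print(parse_size("512MB"))  # → 536870912
```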

@NickLucche NickLucche enabled auto-merge (squash) January 8, 2026 14:45
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 8, 2026
This commit replaces the OffloadingConnector size configuration from num_cpu_blocks to cpu_bytes_to_use.
This allows for a more intuitive space allocation (per vLLM instance, across workers).

Signed-off-by: Or Ozeri <oro@il.ibm.com>
auto-merge was automatically disabled January 10, 2026 17:49

Head branch was pushed to by a user without write access

@orozery
Collaborator Author

orozery commented Jan 10, 2026

@NickLucche rebased

@NickLucche NickLucche enabled auto-merge (squash) January 12, 2026 13:29
@NickLucche NickLucche merged commit 9cddbdb into vllm-project:main Jan 12, 2026
56 checks passed
TomerBN-Nvidia pushed a commit to TomerBN-Nvidia/vllm that referenced this pull request Jan 13, 2026
…#24498)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Jan 15, 2026
### What this PR does / why we need it?
Upgrade vllm commit to 0113 (11b6af5280d6d6dfb8953af16e67b25f819b3be9)

- Modify import paths due to the refactors
vllm-project/vllm#31916
vllm-project/vllm#32054

- Fix `TypeError: NPUOffloadingSpec.__init__() takes 2 positional
arguments but 3 were given` due to
vllm-project/vllm#24498

- Skip the async-scheduling tests in
`tests/e2e/multicard/4-cards/long_sequence/test_mtp.py`, which are never
verified
vllm-project/vllm#31998

- Skip some pooling tests, which are caused by
vllm-project/vllm#32148
where vllm is also failed
https://buildkite.com/vllm/ci/builds/46705/steps/canvas?jid=019bb329-3834-4685-862b-1613b8e0f5d4

We will reopen those tests when main2main reachs
vllm-project/vllm#32243

- Skip some cases in
`tests/e2e/multicard/4-cards/long_sequence/test_mtp.py`, which are
broken by
vllm-project/vllm#32118

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@2f4e654

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
aipaes pushed a commit to aipaes/vllm-ascend that referenced this pull request Jan 15, 2026
sammysun0711 pushed a commit to sammysun0711/vllm that referenced this pull request Jan 16, 2026