
[kv_offload+HMA][1/N]: Support multiple KV groups in OffloadingSpec #36610

Merged
orozery merged 1 commit into vllm-project:main from orozery:kv-offload-spec-multiple-groups
Mar 13, 2026

Conversation

Collaborator

@orozery orozery commented Mar 10, 2026

This PR extends OffloadingSpec to support multiple KV cache groups,
each with a possibly different block size.
We now distinguish between:

  1. The block size of each KV cache group (determined by the KVCacheConfig).
  2. The block size used by vLLM for hashing request tokens (determined by cache_config.block_size).

This will allow the offloading connector to correctly map request tokens to:

  1. KVCacheBlocks (using the per-group block sizes from the KVCacheConfig)
  2. Request.block_hashes (using the hash block size, cache_config.block_size)

For now, the offloading connector keeps using the hash block size as its
block size. Later on, we will modify it to use the group-specific block sizes.
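
For readers unfamiliar with the two granularities, here is a minimal sketch (illustrative sizes and helper only, not code from this PR) of how the same token count maps to hash blocks versus per-group KV cache blocks:

```python
# Illustrative only: hash_block_size corresponds to cache_config.block_size
# (the granularity of Request.block_hashes), while each KV cache group has its
# own block size taken from the KVCacheConfig.

def num_blocks(num_tokens: int, block_size: int) -> int:
    """Number of blocks needed to cover num_tokens tokens."""
    return -(-num_tokens // block_size)  # ceiling division

hash_block_size = 16          # assumed cache_config.block_size
gpu_block_sizes = (16, 64)    # assumed per-group block sizes from the KVCacheConfig
num_tokens = 100

print(num_blocks(num_tokens, hash_block_size))               # 7 hash blocks
print([num_blocks(num_tokens, s) for s in gpu_block_sizes])  # [7, 2] KV cache blocks
```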

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request extends OffloadingSpec to support multiple KV cache groups, which is a valuable enhancement for flexible memory management. The changes primarily adapt the offloading mechanism to handle a tuple of GPU block sizes instead of a single one. However, I've identified a critical logical error in vllm/v1/kv_offload/spec.py concerning the calculation of block_size_factor when a custom offloaded block size is specified. This bug could lead to an incorrect offloaded block size being used. My review includes a suggested fix for this issue. I also noted a design limitation that currently prevents using custom offloaded block sizes when KV groups have different GPU block sizes.
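
For context, a hypothetical sketch of the block_size_factor relationship this review is concerned about (the function below is illustrative, not the PR's implementation):

```python
# Hypothetical: the offloaded block size should be a whole multiple of the GPU
# block size, and block_size_factor is that multiple. The PR currently limits
# custom offloaded block sizes to the case where all groups share one GPU
# block size.

def compute_block_size_factor(offloaded_block_size: int,
                              gpu_block_sizes: tuple[int, ...]) -> int:
    gpu_block_size = gpu_block_sizes[0]
    assert all(s == gpu_block_size for s in gpu_block_sizes), (
        "custom offloaded block sizes require a single GPU block size")
    assert offloaded_block_size % gpu_block_size == 0, (
        "offloaded block size must be a multiple of the GPU block size")
    return offloaded_block_size // gpu_block_size
```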

Comment thread vllm/v1/kv_offload/spec.py
@orozery orozery force-pushed the kv-offload-spec-multiple-groups branch 2 times, most recently from d2fdae9 to 8aee797 on March 10, 2026 09:35
Collaborator Author

orozery commented Mar 10, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request extends OffloadingSpec to support multiple KV cache groups, which is a good step towards more flexible KV cache management. The changes primarily involve refactoring OffloadingSpec and updating its consumers to handle the new structure. Overall, the changes are well-structured for this refactoring. However, I've identified a critical inconsistency in vllm/v1/kv_offload/cpu.py where the offloaded block size is calculated differently in get_manager and get_handlers. This could lead to serious issues if hash_block_size and gpu_block_size are not identical. Please see the detailed comment.
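
Illustrative numbers (not taken from the PR) for why deriving the offloaded block size from hash_block_size in one path and from gpu_block_size in another only agrees when the two are equal:

```python
hash_block_size = 16     # cache_config.block_size
gpu_block_size = 64      # a group's block size from the KVCacheConfig
block_size_factor = 4

size_from_hash = hash_block_size * block_size_factor  # 64 tokens per offloaded block
size_from_gpu = gpu_block_size * block_size_factor    # 256 tokens per offloaded block
# 64 != 256: the manager and the handlers would disagree on how many tokens one
# offloaded block covers whenever hash_block_size != gpu_block_size.
```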

Comment thread vllm/v1/kv_offload/cpu.py Outdated
@orozery orozery force-pushed the kv-offload-spec-multiple-groups branch from 8aee797 to e21470b on March 10, 2026 11:22
Collaborator Author

orozery commented Mar 10, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request extends OffloadingSpec to support multiple KV cache groups, each with its own block size. This is a preparatory step for more advanced offloading strategies. The changes introduce a distinction between the per-group gpu_block_size and a global hash_block_size. For now, functionality is limited to a single GPU block size across all groups via assertions, which is a reasonable approach for a multi-part feature. My review focuses on ensuring robustness against edge cases introduced by these changes. I've identified a potential issue with error handling when no KV cache groups are present, which could lead to confusing assertion failures.
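
A possible guard in the spirit of that suggestion, shown as a hypothetical sketch rather than the code actually adopted:

```python
def validate_kv_cache_groups(kv_cache_groups: list) -> None:
    # Fail with a descriptive error instead of an opaque assertion failure
    # when a model has no KV cache groups (e.g. an encoder-only model).
    if not kv_cache_groups:
        raise ValueError(
            "OffloadingSpec requires at least one KV cache group; "
            "KV offloading does not apply to models without a KV cache.")
```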

Comment thread vllm/v1/kv_offload/spec.py
@orozery orozery force-pushed the kv-offload-spec-multiple-groups branch from e21470b to a007a63 on March 11, 2026 07:20
Collaborator

@NickLucche NickLucche left a comment


LGTM

Comment on lines +56 to +58
"If 'block_size' is specified in kv_connector_extra_config, "
"there must be at least one KV cache group, "
"and all groups must have the same block size."
Collaborator


nit:

Suggested change
"If 'block_size' is specified in kv_connector_extra_config, "
"there must be at least one KV cache group, "
"and all groups must have the same block size."
"If 'block_size' is specified in kv_connector_extra_config "
"all groups must have the same block size."

Collaborator


it should never be empty regardless

Collaborator Author


you're basically reverting gemini's suggestion :)

Collaborator Author


But in theory it can be empty if we're running a model without KV cache (encoder model?)

@orozery orozery added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 12, 2026
@orozery orozery enabled auto-merge (squash) March 12, 2026 19:40
This commit extends OffloadingSpec to support multiple KV cache groups,
each with a possibly different block size.
We now distinguish between:
1. The block size of each KV cache group (determined by the KVCacheConfig).
2. The block size used by vLLM for hashing request tokens (determined by cache_config.block_size).

This will allow the offloading connector to correctly map request tokens to:
1. KVCacheBlocks (using the per-group block sizes from the KVCacheConfig)
2. Request.block_hashes (using the hash block size, cache_config.block_size)

For now, the offloading connector keeps using the hash block size as its
block size. Later on, we will modify it to use the group-specific block sizes.

Signed-off-by: Or Ozeri <oro@il.ibm.com>
@orozery orozery force-pushed the kv-offload-spec-multiple-groups branch from 96c433a to 6af4887 on March 13, 2026 05:30
@orozery orozery merged commit cfaf466 into vllm-project:main Mar 13, 2026
47 checks passed
whycoming pushed a commit to whycoming/vllm that referenced this pull request Mar 13, 2026
…llm-project#36610)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: whycoming <120623296@qq.com>
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Mar 18, 2026
### What this PR does / why we need it?

1.fix "TypeError: get_attn_backend() remove variable": [Refactor
`check_and_update_config`](vllm-project/vllm#35122)

2.fix [Rename `compile_ranges_split_points` to
`compile_ranges_endpoints`](vllm-project/vllm#36027)

3.fix "RuntimeError: device_allocator not a DeviceAllocator":[Replace
memory related torch.cuda
APIs"](vllm-project/vllm#37031)

4.fix [Support multiple KV groups in OffloadingSpec
](vllm-project/vllm#36610) removed
self.offloaded_block_size and changed self.gpu_block_size from a scalar
to a tuple of per-group block sizes, adding block_size_factor.

5.fix [Consolidate
SupportsEagle](vllm-project/vllm#36063) renamed
get_eagle3_aux_hidden_state_layers() to
get_eagle3_default_aux_hidden_state_layers() and added a
supports_eagle3() guard before calling it.
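
A minimal sketch of the adaptation item 4 calls for, assuming plugin code that previously read a scalar spec.gpu_block_size; the attribute names follow the commit message above, and the exact relationship between block_size_factor and the block sizes is an assumption of this sketch:

```python
def offloaded_block_size(spec) -> int:
    # gpu_block_size is now a tuple of per-group block sizes, not a scalar.
    gpu_block_sizes: tuple[int, ...] = spec.gpu_block_size
    # The upstream change still assumes a single GPU block size across groups.
    assert len(set(gpu_block_sizes)) == 1
    # self.offloaded_block_size was removed; derive it via block_size_factor.
    return gpu_block_sizes[0] * spec.block_size_factor
```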

### Does this PR introduce _any_ user-facing change?
NA
### How was this patch tested?
E2E


- vLLM version: v0.17.0
- vLLM main:
vllm-project/vllm@8a68046

---------

Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: Claude Code <noreply@anthropic.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Mar 25, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
lihaokun-2026 pushed a commit to lihaokun-2026/vllm-ascend that referenced this pull request Mar 29, 2026
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Apr 1, 2026
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026

Labels

kv-connector, ready (ONLY add when PR is ready to merge/full CI is needed), v1
