
[HMA] [KVEvent] Enable GPU-side KV events for HMA #37688

Merged
orozery merged 16 commits into vllm-project:main from hickeyma:add-evicted-grps-blck-rem-prop
Apr 12, 2026
Conversation

@hickeyma
Contributor

@hickeyma hickeyma commented Mar 20, 2026

Purpose

Evicted-group information is important for Hybrid Model Architecture (HMA)-aware prefix-cache routing in distributed serving frameworks, because vLLM can evict Sliding Window Attention (SWA) blocks while retaining the full-attention blocks. Without this information, a router would assume complete eviction and miss valid routing paths.

This PR enables KV events for HMA on the GPU side.
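Schematically, the change extends the BlockRemoved KV cache event with a field recording which KV cache groups were actually evicted. The sketch below is illustrative only (the field name `evicted_groups` and the helper are assumptions, not the exact vLLM schema):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BlockRemoved:
    # Hashes of the removed prefix-cache blocks.
    block_hashes: list[int]
    # Illustrative new field: indices of the KV cache groups whose blocks
    # were evicted. With HMA, SWA groups can be evicted while full-attention
    # groups keep their blocks; None means the eviction applies to all groups.
    evicted_groups: Optional[list[int]] = None


def is_fully_evicted(event: BlockRemoved, num_groups: int) -> bool:
    # A router can use the new field to distinguish partial eviction
    # (only some groups dropped their blocks) from complete eviction.
    return event.evicted_groups is None or len(event.evicted_groups) == num_groups
```

With this shape, an HMA-aware router only discards a routing path when all groups for a block are gone, instead of on any BlockRemoved event.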

Test Plan

Run benchmarking with KV events enabled:

VLLM_LOG_STATS_INTERVAL=0.01 vllm bench throughput --model openai/gpt-oss-20b --num-prompts 1000 --kv-events-config '{"enable_kv_cache_events": "True", "publisher": "zmq", "topic": "kv-events"}'

or:

VLLM_LOG_STATS_INTERVAL=0.01 vllm bench throughput --model Qwen/Qwen3-14B --kv-offloading-size 10 --disable-hybrid-kv-cache-manager --num-prompts 1000 --kv-events-config '{"enable_kv_cache_events": "True", "publisher": "zmq", "topic": "kv-events"}'

Use the updated example client to receive the events:
https://github.com/hickeyma/vllm/blob/9ff97922c104117e1577cf9554902a73d2658391/examples/online_serving/kv_events_subscriber.py
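A client along the lines of the linked example can be sketched as below. This is a minimal illustration, not the actual kv_events_subscriber.py: the JSON payload shape, the `evicted_groups` field name, and the endpoint are assumptions (the real client uses vLLM's event schema and serialization):

```python
# Hedged sketch of a ZMQ subscriber for vLLM KV cache events.
import json


def decode_event(raw: bytes) -> dict:
    """Decode one published event frame (assumed JSON for this sketch)."""
    return json.loads(raw.decode("utf-8"))


def handle_event(event: dict) -> str:
    # BlockRemoved events may now carry the group indices that were
    # actually evicted (hypothetical field name: "evicted_groups").
    if event.get("type") == "BlockRemoved":
        groups = event.get("evicted_groups")
        return f"removed blocks={event['block_hashes']} groups={groups}"
    return f"stored blocks={event.get('block_hashes')}"


if __name__ == "__main__":
    import zmq  # pip install pyzmq

    ctx = zmq.Context()
    sock = ctx.socket(zmq.SUB)
    sock.connect("tcp://localhost:5557")  # endpoint is an assumption
    sock.setsockopt_string(zmq.SUBSCRIBE, "kv-events")  # topic from the test plan
    while True:
        _topic, payload = sock.recv_multipart()
        print(handle_event(decode_event(payload)))
```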

Test Result

Stored and removed events are received in the client with the new group fields.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces an evicted_groups field to the BlockRemoved KV cache event to support Hybrid Model Architecture (HMA) aware prefix-cache routing. The implementation correctly populates this field for GPU-side block evictions. However, for blocks evicted via offloading managers (ARCOffloadingManager and LRUOffloadingManager), the group information is not populated, which significantly limits the feature's effectiveness in scenarios involving KV cache offloading. I've added high-severity comments to highlight this functional gap.

Comment thread vllm/v1/kv_offload/arc_manager.py Outdated
Comment thread vllm/v1/kv_offload/lru_manager.py Outdated
@mergify
Contributor

mergify Bot commented Mar 24, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @hickeyma.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 24, 2026
@hickeyma hickeyma marked this pull request as draft March 27, 2026 17:08
@hickeyma
Contributor Author

Moving to draft for now, as this PR needs further updates.

@hickeyma hickeyma force-pushed the add-evicted-grps-blck-rem-prop branch 2 times, most recently from 9bc6b63 to 63a9bf9 Compare March 30, 2026 09:33
@mergify mergify Bot removed the needs-rebase label Mar 30, 2026
@hickeyma hickeyma force-pushed the add-evicted-grps-blck-rem-prop branch from 0001e88 to 4c231e7 Compare March 30, 2026 14:24
@hickeyma hickeyma changed the title [HMA] [KVEvent] Add evicted groups field to BlockRemoved KV event [HMA] [KVEvent] Enable GPU-side KV events for HMA Mar 30, 2026
@hickeyma hickeyma force-pushed the add-evicted-grps-blck-rem-prop branch from 4c231e7 to cb22782 Compare March 30, 2026 15:07
@mergify
Contributor

mergify Bot commented Mar 30, 2026

Documentation preview: https://vllm--37688.org.readthedocs.build/en/37688/

@mergify mergify Bot added the documentation Improvements or additions to documentation label Mar 30, 2026
@hickeyma hickeyma marked this pull request as ready for review March 30, 2026 18:58
@orozery orozery added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 3, 2026
hickeyma and others added 11 commits April 3, 2026 12:51
Evicted group information is important for routing HMA-aware prefix-cache
in distributed serving frameworks, as vLLM can evict SWA blocks while
retaining the full-attention blocks. Without this information it would
assume complete eviction and miss valid routing.

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Review comment:
- vllm-project#37688 (comment)

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Recommended by review, as more work is needed to enable HMA on the
CPU side; that will be done in
vllm-project#38453

Review comment:

- vllm-project#37688 (comment)
- vllm-project#37688 (comment)

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Review comment:

- vllm-project#37688 (comment)

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Using a value of: self.group_idx if self.group_idx else None
would return the same hash for group_idx == 0 and group_idx == None
because they are both falsy.
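The falsy collision described in this commit message can be shown directly. The key-normalization functions below are hypothetical stand-ins, not vLLM code:

```python
def normalized(group_idx):
    # Buggy pattern: a truthiness check collapses 0 and None,
    # so group 0 and "no group" produce the same hash input.
    return group_idx if group_idx else None


def normalized_fixed(group_idx):
    # Correct pattern: only None is treated as "no group";
    # group index 0 keeps its identity.
    return group_idx if group_idx is not None else None
```

With the buggy version, `normalized(0)` and `normalized(None)` are both `None`, so any hash built from them collides; the explicit `is not None` check keeps the two cases distinct.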

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Co-authored-by: Or Ozeri <or@ozery.com>
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Co-authored-by: Or Ozeri <or@ozery.com>
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
@hickeyma hickeyma force-pushed the add-evicted-grps-blck-rem-prop branch from b2dc540 to 8a7c373 Compare April 3, 2026 11:52
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
@hickeyma hickeyma force-pushed the add-evicted-grps-blck-rem-prop branch from 8a7c373 to 110733b Compare April 3, 2026 11:54
@hickeyma
Contributor Author

hickeyma commented Apr 3, 2026

The failure in https://buildkite.com/vllm/ci/builds/59637/steps/canvas?sid=019d5333-4316-40e3-85e8-ff134f99245d doesn't look like an error caused by this PR, but rather like an issue on the system:

[...]
[2026-04-03T12:13:55Z] [rank1]:   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/core/kv_cache_utils.py", line 644, in _check_enough_kv_cache_memory
[2026-04-03T12:13:55Z] [rank1]:     raise ValueError(
[2026-04-03T12:13:55Z] [rank1]: ValueError: To serve at least one request with the models's max seq len (4096), (0.5 GiB KV cache is needed, which is larger than the available KV cache memory (0.48 GiB). Based on the available memory, the estimated maximum model length is 3888. Try increasing `gpu_memory_utilization` or decreasing `max_model_len` when initializing the engine. See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ for more details.

@orozery orozery merged commit cc07dad into vllm-project:main Apr 12, 2026
67 checks passed
vllm-agent pushed a commit to vllm-agent/vllm that referenced this pull request Apr 13, 2026
@hickeyma hickeyma deleted the add-evicted-grps-blck-rem-prop branch April 13, 2026 08:34
wojciech-wais pushed a commit to wojciech-wais/vllm that referenced this pull request Apr 13, 2026
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Co-authored-by: Or Ozeri <or@ozery.com>
whk-lab pushed a commit to whk-lab/vllm that referenced this pull request Apr 23, 2026
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Co-authored-by: Or Ozeri <or@ozery.com>
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Co-authored-by: Or Ozeri <or@ozery.com>
Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Co-authored-by: Or Ozeri <or@ozery.com>

Labels

documentation Improvements or additions to documentation kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1

6 participants