Skip to content

v1/engine: emit prefix-cache KV-events at hash_block_size granularity for hybrid Mamba+Attention models#43258

Open
vanshilshah97 wants to merge 2 commits into
vllm-project:mainfrom
vanshilshah97:vanshils/kv-events-hybrid-hash-block-size
Open

v1/engine: emit prefix-cache KV-events at hash_block_size granularity for hybrid Mamba+Attention models#43258
vanshilshah97 wants to merge 2 commits into
vllm-project:mainfrom
vanshilshah97:vanshils/kv-events-hybrid-hash-block-size

Conversation

@vanshilshah97
Copy link
Copy Markdown
Contributor

@vanshilshah97 vanshilshah97 commented May 20, 2026

Summary

Two changes that make vLLM emit prefix-cache KV-events at the user-specified
hash_block_size granularity for hybrid Mamba+Attention models, where
the physical attention block size gets inflated to satisfy the mamba page
size constraint.

Motivation

For hybrid Mamba+Attention models, vLLM inflates cache_config.block_size
to attn_block_size so the attention page size is ≥ the mamba page size.
On current main, that inflation also overrides the finer hash_block_size
the user requested via --block-size, so:

  • resolve_kv_cache_block_sizes() returns hash_block_size == block_size,
  • SingleTypeKVCacheManager.cache_blocks() only fires BlockStored events
    on full inflated physical blocks,
  • requests with a prompt shorter than the inflated block produce zero
    kv-events,
  • downstream kv-event consumers lose the prefix signal they would
    otherwise use for routing or observability.

What this PR changes

  1. Preserve hash_block_size before inflation. In both
    Platform.check_and_update_config (vllm/platforms/interface.py) and
    the EngineCore resolution path (vllm/v1/engine/core.py), capture the
    user-supplied block_size as hash_block_size before inflating
    block_size to attn_block_size.

  2. Emit sub-block events at hash_block_size granularity.
    SingleTypeKVCacheManager._maybe_emit_sub_block_events advances a
    per-request cursor over hash-block boundaries and appends BlockStored
    events to block_pool.kv_event_queue whenever
    block_size > hash_block_size. Mamba groups override this with a noop,
    so only attention groups emit (matches the existing convention that
    only attention groups participate in prefix-cache hashing).

  3. Surface hash_block_size in metadata.
    EngineCoreProc.get_kv_cache_group_metadata now reports
    hash_block_size (not spec.block_size) when sub-block emission is
    active, so downstream kv-event consumers use the right hashing
    granularity. Falls back to spec.block_size when hash_block_size is
    unset or equals the physical block size.

This is purely additive: the legacy event path (full physical blocks)
fires exactly as before. Behaviour is unchanged for any deployment that
does not configure --kv-events-config and does not pass --block-size
smaller than the model's effective attn_block_size.

Minimal reproducer

A small Python ZMQ subscriber is enough to see the new events. Start a
hybrid Mamba+Attention model with prefix caching and a kv-events publisher:

python -m vllm.entrypoints.openai.api_server \
  --model <hybrid-mamba-attention-model-id> \
  --tensor-parallel-size 8 \
  --block-size 64 \
  --enable-prefix-caching \
  --trust-remote-code \
  --kv-events-config '{"publisher":"zmq","topic":"kv-events",
                       "endpoint":"tcp://*:20080",
                       "enable_kv_cache_events":true}' \
  --port 8000

Run a subscriber alongside it:

# kv_subscriber.py
import json, sys, time, zmq
ctx = zmq.Context.instance()
sock = ctx.socket(zmq.SUB)
sock.connect("tcp://127.0.0.1:20080")
sock.setsockopt(zmq.SUBSCRIBE, b"kv-events")
out = open(sys.argv[1], "w", buffering=1)
while True:
    frames = sock.recv_multipart()
    out.write(json.dumps({"t": time.time(),
                          "frames": [f.decode("latin-1") for f in frames]}) + "\n")

Send a short request whose prompt is smaller than the inflated
attention block (e.g. ~1000 tokens when attn_block_size inflates to
2176):

curl -s http://127.0.0.1:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model":"<hybrid-mamba-attention-model-id>",
       "prompt":"...","max_tokens":16}'

On main today: the worker logs
Setting attention block size to <N> tokens to ensure that attention page size is >= mamba page size.
and the subscriber receives no BlockStored events for that
short-prompt request — the prefix signal is silently dropped.

With this PR: the same worker logs
Setting attention block size to <N> tokens to ensure that attention page size is >= mamba page size (hash granularity preserved at <H>).
and the subscriber receives one or more BlockStored events whose
block_size field equals the user-supplied hash_block_size rather than
the inflated attn_block_size. Each event carries one block_hash per
hash_block_size-token chunk of the prompt, with parent_block_hash
chaining the chunks together exactly the way the regular full-block path
already chains its events.

DCO

Both commits are signed-off-by Vanshil Shah <vanshils@nvidia.com>.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces sub-block BlockStored event emission for hybrid models, ensuring that KV-cache events are fired at the hash_block_size granularity even when physical block sizes are inflated. It adds a tracking cursor for emitted blocks and updates engine metadata to reflect the effective block size. Review feedback identifies a bug where incorrect attributes are accessed on LoRARequest objects and points out a potential issue with duplicate event emissions for hybrid models.

Comment on lines +329 to +333
lora_id=(request.lora_request.adapter_id
if request.lora_request else None),
medium=MEDIUM_GPU,
lora_name=(request.lora_request.name
if request.lora_request else None),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The fields adapter_id and name do not exist on the LoRARequest class used in vLLM. Accessing them will raise an AttributeError when processing requests with LoRA enabled. Based on the LoRARequest definition in vllm/lora/request.py and its usage in other parts of the codebase (e.g., vllm/v1/core/kv_cache_utils.py), these should be lora_int_id and lora_name respectively.

Suggested change
lora_id=(request.lora_request.adapter_id
if request.lora_request else None),
medium=MEDIUM_GPU,
lora_name=(request.lora_request.name
if request.lora_request else None),
lora_id=(request.lora_request.lora_int_id
if request.lora_request else None),
medium=MEDIUM_GPU,
lora_name=(request.lora_request.lora_name
if request.lora_request else None),

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review. adapter_id and name are @property aliases defined on LoRARequest at vllm/lora/request.py:41-46:

@property
def adapter_id(self):
    return self.lora_int_id

@property
def name(self):
    return self.lora_name

The existing emit path at vllm/v1/core/block_pool.py:310,314 uses these same property accessors, so this change follows the convention already in the file. Happy to switch to lora_int_id/lora_name if maintainers prefer the underlying field names everywhere — let me know.

self.num_cached_block[request.request_id] = num_full_blocks
# Also emit sub-block events so prompts smaller than the inflated
# physical block still produce a KV-event signal.
self._maybe_emit_sub_block_events(request, num_tokens)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This implementation introduces duplicate BlockStored events for hybrid models. When block_size > hash_block_size, the self.block_pool.cache_full_blocks call (lines 352-359) already emits BlockStored events for the newly cached physical blocks at hash_block_size granularity. By calling _maybe_emit_sub_block_events immediately after, the same hash blocks are emitted again (once as part of the physical block event and once as individual sub-block events).

This redundancy wastes bandwidth and can cause issues for downstream consumers that expect unique events per hash block. To fix this, you should ensure that _maybe_emit_sub_block_events only emits for tokens that are not yet covered by a full physical block, or coordinate with BlockPool to suppress its internal emission when sub-block tracking is active for a group.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for flagging this. The two emission paths actually fire at different granularities for the hybrid case where block_size > hash_block_size, so they end up complementary rather than duplicating each other:

  • block_pool.cache_full_blocks (vllm/v1/core/block_pool.py:241-256) rebuilds block_hashes with BlockHashListWithBlockSize(request.block_hashes, self.hash_block_size, block_size) when block_size != hash_block_size, and emits BlockStored with block_size=self.block_size (the inflated value, e.g. 2176). One coarse event per physical block, hashed over the full physical-block window.
  • _maybe_emit_sub_block_events (new) walks request.block_hashes directly and emits BlockStored with block_size=hash_block_size (e.g. 64) — many fine events per physical block, each hashed over a 64-token window.

Because the hash inputs are different (a hash over 2176 tokens vs. a hash over 64 tokens), the resulting block_hashes values are distinct between the two streams; downstream consumers see them as different cached entries at different granularities, not redundant events for the same hash.

Verified empirically against vllm/vllm-openai:nightly-bf610c2f5 with a hybrid Mamba+Attention model, --block-size 64, --enable-prefix-caching, and the ZMQ subscriber from the PR description: without this change the subscriber receives only block_size=2176 events; with this change those exact same coarse events still arrive and a new stream of block_size=64 events appears alongside them (with the right parent_block_hash chain).

Consumers that only need the fine granularity can filter on block_size == hash_block_size. If maintainers think the coarse stream should be suppressed for hybrid groups when sub-block emission is active, I can gate it behind a flag — happy to discuss the right behaviour here.

@mergify mergify Bot added the v1 label May 20, 2026
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 20, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vanshilshah97.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 20, 2026
@vanshilshah97 vanshilshah97 marked this pull request as ready for review May 21, 2026 06:36
@vanshilshah97 vanshilshah97 force-pushed the vanshils/kv-events-hybrid-hash-block-size branch from ae90ff8 to 74f7233 Compare May 21, 2026 06:51
@mergify mergify Bot removed the needs-rebase label May 21, 2026
…r hybrid models

For hybrid Mamba+Attention models vLLM inflates cache_config.block_size to
attn_block_size so the attention page size is >= the mamba page size.
Without this patch the inflation also wipes out the finer hash_block_size
the user supplied via --block-size, so:

  * resolve_kv_cache_block_sizes() returns hash_block_size == block_size,
  * SingleTypeKVCacheManager.cache_blocks() only fires BlockStored events
    on full inflated physical blocks,
  * prompts shorter than the inflated block produce zero kv-events.

This patch:

  * Preserves the user --block-size as hash_block_size on both the
    Platform.check_and_update_config and EngineCore resolution paths
    before the inflation runs.
  * Emits synthetic BlockStored events at hash-block granularity from
    SingleTypeKVCacheManager.cache_blocks() whenever block_size >
    hash_block_size (only attention groups; mamba groups are explicitly
    suppressed by overriding the emit hook to a noop).

Signed-off-by: Vanshil Shah <vanshils@nvidia.com>
Report hash_block_size (not spec.block_size) when sub-block emission is
active so downstream KV-event consumers use the right hashing granularity.
Falls back to spec.block_size when hash_block_size is unset or equals the
physical block size.

Pairs with the sub-block emit patch in this branch: the emit patch sends
events at hash_block_size granularity, this patch makes sure consumers
know that granularity.

Signed-off-by: Vanshil Shah <vanshils@nvidia.com>
@vanshilshah97 vanshilshah97 force-pushed the vanshils/kv-events-hybrid-hash-block-size branch from 74f7233 to e9c0793 Compare May 21, 2026 07:59
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Jun 3, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vanshilshah97.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 3, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant