[PD][Core] Fix Mamba prefix cache with PD #42547

Closed
NickLucche wants to merge 6 commits into vllm-project:main from NickLucche:pd-fix-apc

NickLucche (Member) commented May 13, 2026

Fix 0% prefix cache hit rate with Mamba in PD disaggregation (all/align).
Based on #42554, real diff here NickLucche/vllm@mamba-prefix-caching-pd...NickLucche:vllm:pd-fix-apc

Bug

Mamba prefix cache reports 0% hit rate on the Decode side in PD disaggregation.

This is PD-specific. In standalone mode, allocate_new_computed_blocks is
skipped entirely (num_external_computed_tokens = 0), and null blocks only
appear later during RUNNING via remove_skipped_blocks, by which time the real
blocks are already hashed.

In PD mode, allocate_new_computed_blocks runs with
num_external_computed_tokens > 0, which pads req_blocks with null blocks
via Mamba's get_num_skipped_tokens(N) = N-1. The old code then set:

self.num_cached_block[request_id] = len(req_blocks)  # counts nulls!

When _update_waiting_for_remote_kv later called cache_blocks(), it found
num_cached_block >= num_full_blocks and returned early. Nothing was ever
hashed into the block pool, so every subsequent find_longest_cache_hit missed.

allocate_new_computed_blocks (400 tokens, block_size=128):
  get_num_skipped_tokens(400) = 399 → num_skipped_blocks = 3
  req_blocks = [null, null, null, fresh]
  num_cached_block = 3 ← BUG: counts nulls

cache_blocks(400):
  num_full_blocks = 400 // 128 = 3
  3 >= 3 → EARLY RETURN → nothing hashed → 0% hit rate
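
The trace above can be replayed as plain arithmetic. This is a minimal sketch, not vLLM code: `mamba_num_skipped_tokens` and `cache_blocks_returns_early` are hypothetical stand-ins that assume only what the description states (Mamba skips N-1 tokens, and `cache_blocks` returns early when `num_cached_block >= num_full_blocks`). It also assumes a first-time request with no local prefix hit.

```python
BLOCK_SIZE = 128


def mamba_num_skipped_tokens(num_computed_tokens: int) -> int:
    """Stand-in for Mamba's get_num_skipped_tokens(N) = N - 1, per the description."""
    return max(num_computed_tokens - 1, 0)


def cache_blocks_returns_early(num_cached_block: int, num_tokens: int) -> bool:
    """Model of the early-return check: no work if the count already covers all full blocks."""
    num_full_blocks = num_tokens // BLOCK_SIZE
    return num_cached_block >= num_full_blocks


num_tokens = 400
num_skipped_blocks = mamba_num_skipped_tokens(num_tokens) // BLOCK_SIZE  # 399 // 128

# Old behaviour: the null padding is included in the cached-block count.
old_num_cached_block = num_skipped_blocks
# Assumption: if nulls are not counted, a first-time request has no cached blocks yet.
new_num_cached_block = 0

print(num_skipped_blocks)                                            # 3
print(cache_blocks_returns_early(old_num_cached_block, num_tokens))  # True
print(cache_blocks_returns_early(new_num_cached_block, num_tokens))  # False
```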

Fix

Two changes, both in single_type_kv_cache_manager.py:

  1. Don't count null blocks in num_cached_block

Capture len(new_computed_blocks) before the skip-slicing that strips
leading blocks. This counts only real prefix-hit blocks, not null padding:

  num_computed_blocks = len(new_computed_blocks)   # before slicing
  # ... slicing, padding, etc ...
  self.num_cached_block[request_id] = num_computed_blocks

This is a no-op for FullAttention (no skipping) and SWA (the null padding in
new_computed_blocks from find_longest_cache_hit exactly equals
num_skipped_blocks, so the count is unchanged).
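
To illustrate why capturing the count before slicing only changes the Mamba PD case, here is a toy model of the bookkeeping. The function name, the string block labels, and the three example inputs are assumptions made for illustration. The real method operates on block objects and also allocates fresh blocks afterwards.

```python
NULL = "null"


def num_cached_after_allocate(new_computed_blocks, num_skipped_blocks, count_before_slicing):
    """Toy model: return the value that would be stored in num_cached_block."""
    num_computed_blocks = len(new_computed_blocks)  # captured before slicing
    kept = new_computed_blocks[num_skipped_blocks:]
    req_blocks = [NULL] * num_skipped_blocks + kept
    return num_computed_blocks if count_before_slicing else len(req_blocks)


# (new_computed_blocks, num_skipped_blocks) for each manager type; values are illustrative.
cases = {
    "full_attention": (["b0", "b1", "b2"], 0),        # no skipping at all
    "sliding_window": ([NULL, NULL, "b2", "b3"], 2),  # nulls already present in the hit
    "mamba_pd": ([], 3),                              # no local hit, 3 skipped blocks
}

for name, (blocks, skipped) in cases.items():
    old = num_cached_after_allocate(blocks, skipped, count_before_slicing=False)
    new = num_cached_after_allocate(blocks, skipped, count_before_slicing=True)
    print(name, old, new)
# full_attention 3 3
# sliding_window 4 4
# mamba_pd 3 0
```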

  2. Register null-block hashes in MambaManager.cache_blocks

With fix 1, cache_blocks() no longer early-returns — it iterates the null
blocks. But BlockPool.cache_full_blocks skips them (blk.is_null → continue),
so their hashes never enter the hash map.

Mamba's find_longest_cache_hit searches right-to-left through block hashes.
If null-block positions aren't in the hash map, the search misses and
hit_length drops to 0, dragging the HMA coordinator's overall hit to 0.

MambaManager.cache_blocks now registers hash → null_block entries for null positions.
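
The effect of registering null positions can be sketched with a plain dictionary standing in for the block pool's hash map. Everything here is a simplified assumption: `cache_full_blocks` and `longest_hit_blocks` are hypothetical toy functions, hashes are strings, and the right-to-left search returns the number of blocks covered by the rightmost hash it finds.

```python
NULL = "null"


def cache_full_blocks(hash_map, req_blocks, block_hashes, register_nulls):
    """Toy caching step: one hash per full block; nulls are skipped unless requested."""
    for blk, block_hash in zip(req_blocks, block_hashes):
        if blk == NULL and not register_nulls:
            continue  # mirrors the described blk.is_null -> continue behaviour
        hash_map[block_hash] = blk


def longest_hit_blocks(hash_map, block_hashes):
    """Toy right-to-left search: number of blocks covered by the rightmost known hash."""
    for i in range(len(block_hashes) - 1, -1, -1):
        if block_hashes[i] in hash_map:
            return i + 1
    return 0


# 400 tokens with block_size=128: three full blocks (all null) plus one partial block.
req_blocks = [NULL, NULL, NULL, "fresh"]
block_hashes = ["h0", "h1", "h2"]  # only full blocks have hashes

skipping, registering = {}, {}
cache_full_blocks(skipping, req_blocks, block_hashes, register_nulls=False)
cache_full_blocks(registering, req_blocks, block_hashes, register_nulls=True)

print(longest_hit_blocks(skipping, block_hashes))     # 0
print(longest_hit_blocks(registering, block_hashes))  # 3
```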

Reproducer (PD disaggregation)

# D
VLLM_NIXL_SIDE_CHANNEL_PORT=$(just port 5558) VLLM_SSM_CONV_STATE_LAYOUT=DS \
vllm serve nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 --port $(just port 8200) --enforce-eager --tensor-parallel-size 1 --gpu-memory-utilization 0.9 --trust-remote-code --max-model-len 131072 --block-size 128 --enable-prefix-caching --no-disable-hybrid-kv-cache-manager --mamba-cache-mode align --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

# P
VLLM_NIXL_SIDE_CHANNEL_PORT=$(just port 5557) vllm serve nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 --port $(just port 8100)  --gpu-memory-utilization 0.9 --trust-remote-code --enforce-eager --max-model-len 131072 --block-size 128 --enable-prefix-caching --no-disable-hybrid-kv-cache-manager --mamba-cache-mode align --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

# proxy
python vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port $(just port 8192) --prefiller-port $(just port 8100) --decoder-port $(just port 8200)

# Send same request twice and observe D-side logs:

# D
(APIServer pid=2777847) INFO 05-13 18:03:41 [loggers.py:271] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, **Prefix cache hit rate: 0.0%,** External prefix cache hit rate: 100.0%

Test with

  pytest tests/v1/core/test_single_type_kv_cache_manager.py -k "mamba_align" -v
  pytest tests/v1/kv_connector/unit/test_nixl_connector_hma.py -k "ssm_prefix" -v

Benchmark

A simple scenario, PD TP1, H100, Nemotron3-Nano, ~8k/1k:

vllm bench serve --model nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 \
  --dataset-name prefix_repetition --num-prompts 1000 \
  --base-url http://localhost:55483 --ignore-eos --max-concurrency 100 \
  --prefix-repetition-prefix-len 6000 --prefix-repetition-suffix-len 2000 \
  --prefix-repetition-num-prefixes 100 --prefix-repetition-output-len 1000

# --no-enable-prefix-caching
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Maximum request concurrency:             100
Benchmark duration (s):                  539.67
Total input tokens:                      8000031
Total generated tokens:                  1000000
Request throughput (req/s):              1.85
Output token throughput (tok/s):         1852.97
Peak output token throughput (tok/s):    2100.00
Peak concurrent requests:                106.00
Total token throughput (tok/s):          16676.77
---------------Time to First Token----------------
Mean TTFT (ms):                          1632.73
Median TTFT (ms):                        579.53
P99 TTFT (ms):                           18875.09
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          51.70
Median TPOT (ms):                        52.04
P99 TPOT (ms):                           52.53
---------------Inter-token Latency----------------
Mean ITL (ms):                           51.70
Median ITL (ms):                         51.92
P99 ITL (ms):                            60.07
==================================================

# --enable-prefix-caching --mamba-cache-mode align
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Maximum request concurrency:             100
Benchmark duration (s):                  533.51
Total input tokens:                      8000031
Total generated tokens:                  1000000
Request throughput (req/s):              1.87
Output token throughput (tok/s):         1874.39
Peak output token throughput (tok/s):    2100.00
Peak concurrent requests:                112.00
Total token throughput (tok/s):          16869.59
---------------Time to First Token----------------
Mean TTFT (ms):                          1458.54
Median TTFT (ms):                        508.13
P99 TTFT (ms):                           16325.75
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          51.35
Median TPOT (ms):                        51.32
P99 TPOT (ms):                           53.01
---------------Inter-token Latency----------------
Mean ITL (ms):                           51.35
Median ITL (ms):                         51.49
P99 ITL (ms):                            61.80
==================================================
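
For a quick comparison of the two runs, the relative change of a few headline metrics can be computed directly from the tables above. The dictionary keys are labels chosen for this sketch, and the values are copied from the benchmark output.

```python
baseline = {  # --no-enable-prefix-caching
    "mean_ttft_ms": 1632.73,
    "median_ttft_ms": 579.53,
    "p99_ttft_ms": 18875.09,
    "mean_tpot_ms": 51.70,
    "total_tok_per_s": 16676.77,
}
prefix_cache = {  # --enable-prefix-caching --mamba-cache-mode align
    "mean_ttft_ms": 1458.54,
    "median_ttft_ms": 508.13,
    "p99_ttft_ms": 16325.75,
    "mean_tpot_ms": 51.35,
    "total_tok_per_s": 16869.59,
}


def pct_change(before: float, after: float) -> float:
    """Relative change in percent; negative means the value went down."""
    return (after - before) / before * 100.0


changes = {key: pct_change(baseline[key], prefix_cache[key]) for key in baseline}
for key, value in changes.items():
    print(f"{key}: {value:+.2f}%")
```

On these two runs, mean, median, and P99 TTFT drop by roughly 11%, 12%, and 13.5%, while TPOT and total throughput move by about 1%. Each configuration was run once here, so the deltas are indicative only.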

mergify Bot (Contributor) commented May 13, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 13, 2026

gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request implements support for KV transfer in Mamba hybrid models, specifically addressing challenges with heterogeneous Tensor Parallelism (TP) and prefix caching. Significant changes include updating BlockPool to register null block hashes, refining the NixlConnectorWorker to use remote physical block ratios for kernel block mapping, and introducing _apply_prefix_caching to manage block ID trimming. The PR also adds validation to disable prefix caching for Mamba hybrid models when physical block counts are heterogeneous. Review feedback highlights a design constraint in the SSM block handling where an assertion assumes a single local block, suggesting this should be better documented or handled with a descriptive error.

Comment on lines +2346 to +2355
if (
    _is_ssm_spec(self._group_spec_types[i])
    and num_local_blocks < num_remote_blocks
):
    # NOTE (NickLucche): With prefix caching on SSM, (remote) blocks
    # prior to the last one are placeholders (null blocks). Mind that
    # this doesn't really impact transfer, as we only still care about
    # the last "block", the full in-place state.
    assert num_local_blocks == 1, "SSM can only have one local block"
    remote_block_ids[i] = remote_group[-num_local_blocks:]

Contributor

high

The assertion assert num_local_blocks == 1 assumes that SSM groups can only have one local block. If this is a design constraint, it should be documented as such in the class or method docstring, or the assertion should be replaced with a more descriptive error message if it's a potential runtime failure point.
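
One way to act on this suggestion is to raise a descriptive error instead of asserting. The sketch below is hypothetical: `select_remote_ssm_blocks` is not a function in the PR, and the error message wording is an assumption. It only mirrors the slicing shown in the reviewed snippet.

```python
def select_remote_ssm_blocks(remote_group: list[int], num_local_blocks: int) -> list[int]:
    """Hypothetical helper: keep only the trailing remote block(s) that carry the SSM state."""
    if num_local_blocks != 1:
        # A ValueError survives `python -O`, unlike an assert, and explains the constraint.
        raise ValueError(
            "SSM groups are expected to hold exactly one local block "
            f"(the full in-place state), got {num_local_blocks}"
        )
    return remote_group[-num_local_blocks:]


print(select_remote_ssm_blocks([11, 12, 13], num_local_blocks=1))  # [13]
```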

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
# Only count non-null blocks as cached. Null blocks appear here from Mamba
# align-mode and SWA/chunked-local attention.
num_cached = sum(1 for b in req_blocks if not b.is_null)
self.num_cached_block[request_id] = num_cached

Collaborator

Agree that this line has a bug when delay_cache_block is True, but I think it should be set to the input len(new_computed_blocks), captured before the line new_computed_blocks = new_computed_blocks[num_skipped_blocks:]

Member Author

thanks @heheda12345 !
This isn't quite working even after reverting block_pool. Will investigate some more asap

Comment thread vllm/v1/core/block_pool.py Outdated
NickLucche changed the title from "[PD] Fix Mamba cache align mode with PD" to "[PD][Core] Fix Mamba prefix cache with PD" May 18, 2026
mergify Bot (Contributor) commented May 18, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 18, 2026
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Comment on lines 9 to 13
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (
    BlockHashList,
    BlockHashListWithBlockSize,
    BlockHashWithGroupId,

Member Author

Ignore everything that isn't this file.

@mergify mergify Bot removed the needs-rebase label May 20, 2026
NickLucche (Member Author)

@heheda12345 @tdoublep I pulled the changes to be confined to the MambaManager.cache_blocks, but I am still having to account for those null blocks to get a hit on all groups (we mostly care about FA for xfer side).
Do you see a cleaner way to fix this?

underfituu (Contributor)

Hi @NickLucche @heheda12345 @tdoublep,

Regarding the cache hit 0% issue discussed here, I've proposed a solution in my PR #42524 that might help address this.

Could you please take a look and see if it aligns with what you're trying to achieve for the mamba hybrid models? Would love to get your feedback!
