Skip to content

[PD][Nixl] Mamba prefix caching mode support #42554

Merged
vllm-bot merged 6 commits into
vllm-project:mainfrom
NickLucche:mamba-prefix-caching-pd
Jun 4, 2026
Merged

[PD][Nixl] Mamba prefix caching mode support #42554
vllm-bot merged 6 commits into
vllm-project:mainfrom
NickLucche:mamba-prefix-caching-pd

Conversation

@NickLucche

@NickLucche NickLucche commented May 13, 2026

Copy link
Copy Markdown
Member

This PR adds support for PD Mamba setups to make use of prefix caching ("all", "align" as well as the upcoming #37898).
It merely adds the logic to handle the result of a prefix cache hit, so it is agnostic to the actual caching implementation.
Running without this PR with prefix caching enabled will run into this assertion

if _is_ssm_spec(self._group_spec_types[i]):
assert num_local_blocks == num_remote_blocks

as prefix caching in mamba will insert placeholder blocks (align-mode) to signal block-aligned checkpoints eg [0, 0, 14].

Unfortunately the underlying prefix caching logic (both all/align) needs a fix to correctly register cache hits with PD #42547.
I am pushing the fix in a separate PR as it is far more general and impactful and would therefore welcome any discussion there (this would be a nice use-case for stacked PRs btw).

Another known limitation is that we're not adding support for heterogeneous block_size and prefix-caching yet, so we leave the following assertion in place

raise RuntimeError(
"Prefix caching with heterogeneous physical_blocks_per_logical "
"is not supported for Mamba hybrid models. "
f"Local: {self._physical_blocks_per_logical_kv_block}, "
f"Remote: {remote_physical_per_logical}. "
"Disable prefix caching with --no-enable-prefix-caching."
)

Will follow up with a separate PR to address that case. I feel agreeing on fix #42547 is actually the most important part here.

Test with

--enable-prefix-caching --mamba-cache-mode align --no-disable-hybrid-kv-cache-manager

# D
 VLLM_NIXL_SIDE_CHANNEL_PORT=$(just port 5558) VLLM_SSM_CONV_STATE_LAYOUT=DS 
vllm serve nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 --port $(just port 8200) --enforce-eager --tensor-parallel-size 1 --gpu-memory-utilization 0.9 --trust-remote-code --max-model-len 131072 --block-size 128 --enable-prefix-caching --no-disable-hybrid-kv-cache-manager --mamba-cache-mode align --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

# P
VLLM_NIXL_SIDE_CHANNEL_PORT=$(just port 5557) vllm serve nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 --port $(just port 8100)  --gpu-memory-utilization 0.9 --trust-remote-code --enforce-eager --max-model-len 131072 --block-size 128 --enable-prefix-caching --no-disable-hybrid-kv-cache-manager --mamba-cache-mode align --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

# proxy
python vllm//tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port $(just port 8192) --prefiller-port $(just port 8100) --decoder-port $(just port 8200)

# After #42547
# P
(APIServer pid=2669040) INFO 05-13 17:41:28 [loggers.py:271] Engine 000: Avg prompt throughput: 1177.3 tokens/s, Avg generation throughput: 0.2 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.2%, Prefix cache hit rate: 26.4%, External prefix cache hit rate: 0.0%

# D
(APIServer pid=2668541) INFO 05-13 17:41:25 [loggers.py:271] Engine 000: Avg prompt throughput: 0.2 tokens/s, Avg generation throughput: 2.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 26.4%, External prefix cache hit rate: 100.0%

or run unit test with

pytest -v -s tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_apply_prefix_caching_ssm_prefix_cache_hit

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces prefix caching support for SSM and Full Attention (FA) groups in the NIXL KV connector. Key changes include updating the _apply_prefix_caching logic to handle SSM placeholder trimming and FA partial prefix hits, along with scheduler adjustments to support external KV connectors and skip Mamba block alignment during asynchronous KV loading. Review feedback identified critical issues in the slicing logic and assertions within _apply_prefix_caching that would cause failures during full prefix cache hits (when num_local_blocks is zero).

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
@markmc

markmc commented May 14, 2026

Copy link
Copy Markdown
Member

xref #42620 which proposes adding KVConnectorBase_V1.supports_mamba_external_kv()

@markmc

markmc commented May 14, 2026

Copy link
Copy Markdown
Member

#42524 is also in this territory

@NickLucche NickLucche added the ready ONLY add when PR is ready to merge/full CI is needed label May 19, 2026
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
@vllm-bot vllm-bot merged commit 68f5e56 into vllm-project:main Jun 4, 2026
68 of 70 checks passed
JisoLya pushed a commit to JisoLya/vllm that referenced this pull request Jun 5, 2026
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: JisoLya <523420504@qq.com>
knight0528 pushed a commit to knight0528/vllm that referenced this pull request Jun 8, 2026
Signed-off-by: NickLucche <nlucches@redhat.com>
waqahmed-amd-fi pushed a commit to waqahmed-amd-fi/vllm that referenced this pull request Jun 10, 2026
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Waqar Ahmed <waqar.ahmed@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants