[PD][Core] Fix Mamba prefix cache with PD#42547
Conversation
|
This pull request has merge conflicts that must be resolved before it can be |
There was a problem hiding this comment.
Code Review
This pull request implements support for KV transfer in Mamba hybrid models, specifically addressing challenges with heterogeneous Tensor Parallelism (TP) and prefix caching. Significant changes include updating BlockPool to register null block hashes, refining the NixlConnectorWorker to use remote physical block ratios for kernel block mapping, and introducing _apply_prefix_caching to manage block ID trimming. The PR also adds validation to disable prefix caching for Mamba hybrid models when physical block counts are heterogeneous. Review feedback highlights a design constraint in the SSM block handling where an assertion assumes a single local block, suggesting this should be better documented or handled with a descriptive error.
| if ( | ||
| _is_ssm_spec(self._group_spec_types[i]) | ||
| and num_local_blocks < num_remote_blocks | ||
| ): | ||
| # NOTE (NickLucche): With prefix caching on SSM, (remote) blocks | ||
| # prior to the last one are placeholders (null blocks). Mind that | ||
| # this doesn't really impact transfer, as we only still care about | ||
| # the last "block", the full in-place state. | ||
| assert num_local_blocks == 1, "SSM can only have one local block" | ||
| remote_block_ids[i] = remote_group[-num_local_blocks:] |
There was a problem hiding this comment.
The assertion assert num_local_blocks == 1 assumes that SSM groups can only have one local block. If this is a design constraint, it should be documented as such in the class or method docstring, or the assertion should be replaced with a more descriptive error message if it's a potential runtime failure point.
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
| # Only count non-null blocks as cached. Null blocks appear here from Mamba | ||
| # align-mode and SWA/chunked-local attention. | ||
| num_cached = sum(1 for b in req_blocks if not b.is_null) | ||
| self.num_cached_block[request_id] = num_cached |
There was a problem hiding this comment.
agree that this line have bug when delay_cache_block is True. but I think it should be set to the input len(new_computed_blocks) before this line new_computed_blocks = new_computed_blocks[num_skipped_blocks:]
There was a problem hiding this comment.
thanks @heheda12345 !
This isn't quite working even after reverting block_pool. Will investigate some more asap
|
This pull request has merge conflicts that must be resolved before it can be |
| from vllm.v1.core.block_pool import BlockPool | ||
| from vllm.v1.core.kv_cache_utils import ( | ||
| BlockHashList, | ||
| BlockHashListWithBlockSize, | ||
| BlockHashWithGroupId, |
There was a problem hiding this comment.
ignore everything which isnt this file
|
@heheda12345 @tdoublep I pulled the changes to be confined to the |
Regarding the cache hit 0% issue discussed here, I've proposed a solution in my PR #42524 that might help address this. Could you please take a look and see if it aligns with what you're trying to achieve for the mamba hybrid models? Would love to get your feedback! |
Fix 0% prefix cache hit rate with Mamba in PD disaggregation (all/align).
Based on #42554, real diff here NickLucche/vllm@mamba-prefix-caching-pd...NickLucche:vllm:pd-fix-apc
Bug
Mamba prefix cache reports 0% hit rate on the Decode side in PD disaggregation.
This is PD-specific. In standalone mode,
allocate_new_computed_blocksisskipped entirely (
num_external_computed_tokens = 0), and null blocks onlyappear later during RUNNING via
remove_skipped_blocks, by which time the realblocks are already hashed.
In PD mode,
allocate_new_computed_blocksruns withnum_external_computed_tokens > 0, which padsreq_blockswith null blocksvia Mamba's
get_num_skipped_tokens(N) = N-1. The old code then set:Fix
Two changes, both in single_type_kv_cache_manager.py:
Capture len(new_computed_blocks) before the skip-slicing that strips
leading blocks. This counts only real prefix-hit blocks, not null padding:
This is a no-op for FullAttention (no skipping) and SWA (the null padding in
new_computed_blocks from find_longest_cache_hit exactly equals
num_skipped_blocks, so the count is unchanged).
With fix 1, cache_blocks() no longer early-returns — it iterates the null
blocks. But BlockPool.cache_full_blocks skips them (blk.is_null → continue),
so their hashes never enter the hash map.
Mamba's find_longest_cache_hit searches right-to-left through block hashes.
If null-block positions aren't in the hash map, the search misses and
hit_length drops to 0, dragging the HMA coordinator's overall hit to 0.
MambaManager.cache_blocks now registers hash → null_block entries for null positions.
Reproducer (PD disaggregation)
Test with
Benchmark
A simple scenario, PD TP1, H100, Nemotron3-Nano, ~8k/1k: