[Bugfix] - Fix Mamba prefix caching corruption with chunked prefill #34587
Josephasafg wants to merge 2 commits into vllm-project:main
Conversation
Signed-off-by: Josephasafg <ajgard7@gmail.com>
Code Review
This pull request addresses a critical bug causing prefix caching corruption in Mamba models when chunked prefill is enabled. The root cause is that state blocks were being cached before they were completely filled, leading to incorrect state being loaded from the cache.

The fix introduces a mechanism to track the write position for each block within a request. This is achieved by adding `block_write_positions` to `MambaManager` and two new helper methods, `_update_block_write_positions` and `_get_num_cacheable_blocks`. The `cache_blocks` method is updated to use this tracking information to only cache blocks that are verified to be complete. The logic appears sound and correctly resolves the described caching issue by preventing incomplete blocks from being added to the prefix cache.
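A minimal sketch of the mechanism described above. The names `block_write_positions`, `_update_block_write_positions`, and `_get_num_cacheable_blocks` come from the PR description; everything else (the class shape, signatures, and data layout) is an assumption for illustration, not the actual vLLM implementation:

```python
# Hypothetical sketch of per-request block write-position tracking.
# Only the method/attribute names are taken from the PR text.
class MambaManagerSketch:
    def __init__(self, block_size: int):
        self.block_size = block_size
        # request_id -> {block_idx: write position after the last kernel write}
        self.block_write_positions: dict[str, dict[int, int]] = {}

    def _update_block_write_positions(self, req_id: str, block_idx: int,
                                      write_position: int) -> None:
        self.block_write_positions.setdefault(req_id, {})[block_idx] = write_position

    def _get_num_cacheable_blocks(self, req_id: str, num_blocks: int) -> int:
        # A block is cacheable only if the kernel wrote its state exactly
        # at the block boundary: write_position == (block_idx + 1) * block_size.
        positions = self.block_write_positions.get(req_id, {})
        num_cacheable = 0
        for block_idx in range(num_blocks):
            if positions.get(block_idx) == (block_idx + 1) * self.block_size:
                num_cacheable += 1
            else:
                break  # prefix caching needs a contiguous complete prefix
        return num_cacheable
```

With `block_size=2048`, a kernel write landing at position 1800 leaves block 0 uncacheable, while a later write landing exactly at 2048 makes it cacheable.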
@heheda12345 @tdoublep I think I managed to implement a similar chunk-alignment solution for mamba1 that handles alignment from the kernel perspective, which will make this PR obsolete. I'll update here in a bit.
I'm closing this PR in favor of #34798
Purpose
When using Mamba models with `mamba_cache_mode="all"` and chunked prefill enabled, Mamba state blocks can be cached before they contain complete state (e.g. if `mamba_block_size=2048` and chunked prefill splits the sequence into 1800- and 1100-token chunks), leading to incorrect output when subsequent requests hit the prefix cache.

Root Cause
The SSM kernel (`selective_scan_fwd`) writes state at chunk boundaries that don't necessarily align with block boundaries. When the scheduler splits a long prefill into multiple chunks, the cache manager assumes that a 3026-token sequence has `3026 // block_size` full blocks and so caches block 0. The kernel, however, writes state based on:

- `block_idx_first_scheduled + chunk_idx`
- `block_idx_last_scheduled`

This means earlier blocks may never be "completed" if subsequent scheduler calls skip over them.
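The misalignment can be made concrete with a small arithmetic sketch. The 1226-token second chunk is an assumption chosen so the total matches the 3026-token sequence above; the block size is the 2048 from the example:

```python
# Hypothetical illustration of the root cause: the scheduler's chunk
# boundaries do not line up with the mamba block boundary.
block_size = 2048          # mamba_block_size from the example above
chunks = [1800, 1226]      # assumed chunk split of a 3026-token prefill

write_positions = []       # positions at which the kernel writes state
write_position = 0
for chunk in chunks:
    write_position += chunk
    write_positions.append(write_position)

# Naive check: tokens // block_size says one "full" block exists after
# the second chunk, so block 0 gets cached...
naive_full_blocks = write_position // block_size

# ...but the kernel only wrote state at positions 1800 and 3026, never
# exactly at 2048, so the cached "block 0" state does not correspond to
# the state after token 2048.
block_0_boundary = 1 * block_size
block_0_state_written = block_0_boundary in write_positions
```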
Solution

Track exactly when each block was written by the kernel via a `block_write_positions` dict. Only cache blocks where `write_position == (block_idx + 1) * block_size` exactly.

Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.