[Mamba1] - Kernel Level Chunk Alignment for Prefix Caching#34798
DarkLight1337 merged 16 commits into vllm-project:main
Conversation
Signed-off-by: Josephasafg <ajgard7@gmail.com>
Code Review
This pull request introduces kernel-level chunk alignment for prefix caching in Mamba1 to address a bug where state was not written at block boundaries. The changes are primarily within the selective_scan_fwd.cu kernel, with supporting modifications to plumb the new chunk_start_offsets parameter through the C++ and Python layers. The new logic dynamically calculates chunk sizes to align with block boundaries, ensuring correct state writing for prefix caching. The implementation appears robust and correctly reflects the logic outlined in the pull request description. I have reviewed the new chunking calculations, pointer arithmetic, and state-writing logic and found no issues.
divakar-amd
left a comment
Thanks for this PR; I was working on a similar fix for the kernel. Requesting a few changes to make it compatible with ROCm.
Created #34977 to support the fix for
@tdoublep Can you please take a look? Thanks!
Hi @Josephasafg, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.
@tdoublep The chunk metadata computation (
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 44419ac to 3939010
tdoublep
left a comment
Thanks a lot for addressing the review comments. It looks good - just had one question for my understanding.
):
    cu_chunk_seqlen_p, _, last_chunk_indices_p = (
        self._build_chunk_metadata_tensors(
            self.kv_cache_spec.block_size,
It seems like for mamba1 we are forcing the kernel-level chunk size to be equal to block size, whereas for mamba2 the kernel-level chunk size is something kind of fixed by the model (it is typically equal to 256). Is there a reason we couldn't keep the kernel-level chunk size to whatever it was before prefix caching was introduced - is it just that 2048 is too big? Just want to make sure I understand this difference between mamba1 vs. mamba2.
@tdoublep Good question.
The old mamba1 "chunk size" (kChunkSize = kNThreads * kNItems) was purely a hardware detail - how many tokens fit in one thread block's registers per loop iteration. mamba1 does not have a block size as a model parameter.
For APC we need state snapshots at block boundaries, so we use block_size as the chunk size for _compute_chunk_metadata: each iteration produces one state write to one cache block. The default mamba_block_size for mamba1 is 2048 (chosen to match the non-APC behavior, though it can also be decreased), which equals kNThreads * kNItems for the largest kernel config, so for full-block chunks the effective iteration size is the same as before.
With cu_chunk_seqlen, the kernel reads whatever chunk sizes the metadata builder provides. This is what enables correct handling of partial first chunks during chunked prefill - something the old fixed kChunkSize didn't handle. In the non-APC fallback, the kernel chunks by 2048 same as before.
Does that answer your question?
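To make the metadata side concrete, here is a rough Python sketch of how a cu_chunk_seqlen-style array could be built for one request, with the first chunk sized to realign with block boundaries. The function name and signature are illustrative, not the actual vLLM helpers.

```python
# Illustrative sketch only: builds cumulative chunk boundaries for one request,
# where the first chunk is shrunk so that every later chunk ends exactly on a
# cache-block boundary. Names here are assumptions, not the real vLLM code.

def build_cu_chunk_seqlen(num_computed: int, seqlen: int, block_size: int):
    """Return cumulative token offsets (within this step) of each kernel chunk."""
    cu = [0]
    pos = num_computed                # absolute position within the sequence
    end = num_computed + seqlen
    while pos < end:
        # Each chunk ends at the next block boundary, or at the end of input.
        next_boundary = (pos // block_size + 1) * block_size
        step = min(end, next_boundary) - pos
        pos += step
        cu.append(cu[-1] + step)
    return cu

# The partial-first-chunk case from the discussion above:
print(build_cu_chunk_seqlen(1866, 1100, 2048))   # [0, 182, 1100]
# A fresh sequence simply chunks by block_size (2048), same as before:
print(build_cu_chunk_seqlen(0, 4096, 2048))      # [0, 2048, 4096]
```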
Force-pushed from b3b3abd to f718db0
Purpose
The selective_scan_fn kernel processed tokens in fixed-size chunks of kChunkSize = kNThreads * kNItems (typically 2048), regardless of where the sequence started within a block. When chunked prefill split a request across scheduler iterations, the kernel would write state at positions that didn't align with block boundaries.

Example of the bug:

Block size: 2048, seqlen: 3966
Iteration 1: Process 1866 tokens → state written at position 1866 (partial block 0)
Iteration 2: Process 1100 tokens → state for block 0's boundary (position 2048) is never written
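The misalignment can be sketched in a few lines of Python — a toy model of the old behavior, not the kernel itself:

```python
# Toy model of the old behavior: with fixed-size chunking, the position of each
# state write depends only on where the scheduler iteration happens to end.

BLOCK_SIZE = 2048

def state_write_positions(iteration_lens):
    """Final absolute position at which each scheduler iteration writes state."""
    pos, writes = 0, []
    for n in iteration_lens:
        pos += n
        writes.append(pos)    # old kernel: state written wherever the run ends
    return writes

writes = state_write_positions([1866, 1100])   # seqlen 3966 from the example
print(writes)                  # [1866, 2966]
print(BLOCK_SIZE in writes)    # False: block 0's boundary state is never written
```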
Solution

Implement kernel-level chunk alignment (similar to PR #24683 for Mamba2) that dynamically sizes the kernel's chunks so each one ends on a block boundary, ensuring state is written for every completed block.
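The alignment logic can be sketched as plain Python (names are illustrative; the real implementation lives in selective_scan_fwd.cu and the metadata builder):

```python
# Illustrative sketch of the aligned chunk loop: the first chunk is sized to
# finish the partially filled block; later chunks cover one block each.

def aligned_chunk_walk(num_computed: int, seqlen: int, block_size: int):
    """Yield (chunk_size, start, end, block_idx) for each kernel chunk."""
    pos, end, steps = num_computed, num_computed + seqlen, []
    while pos < end:
        next_boundary = (pos // block_size + 1) * block_size
        chunk_end = min(end, next_boundary)
        # Cache block that this chunk's final token position falls into.
        block_idx = (chunk_end - 1) // block_size
        steps.append((chunk_end - pos, pos, chunk_end, block_idx))
        pos = chunk_end
    return steps

for size, start, stop, blk in aligned_chunk_walk(1866, 1100, 2048):
    print(f"process {size} tokens: {start} -> {stop}, write state to block {blk}")
# process 182 tokens: 1866 -> 2048, write state to block 0
# process 918 tokens: 2048 -> 2966, write state to block 1
```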
Dry Run:

For num_computed=1866, seqlen=1100, block_size=2048:

chunk_start_offset = 1866 % 2048 = 1866
first_chunk_size = min(1100, 2048 - 1866) = 182  // Tokens to complete block 0
remaining = 1100 - 182 = 918
n_chunks = 2

Loop iteration 1 (chunk 0): Process 182 tokens
  position: 1866 → 2048
  block_idx_completed = (2048 - 1) / 2048 = 0
  → Write state to block 0 (now complete)

Loop iteration 2 (chunk 1): Process 918 tokens
  position: 2048 → 2966
  → Write state to block 1 (last_scheduled)

Test Plan

All prefix caching unittests for Mamba1/Jamba models pass
No degradation in performance, and quality has improved
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.