[V1][Hybrid] Enable spec decode and optimize block-aligned split in mamba cache align mode #33024
peakcrosser7 wants to merge 8 commits into vllm-project:main
Conversation
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Code Review
This pull request re-enables speculative decoding for Mamba with align cache mode and refactors the block-aligned splitting logic to better support resumed requests. The changes are logical and well-structured, particularly the introduction of the _mamba_compute_cache_pos helper function. However, I've identified a potential issue in the calculation of the last cacheable position for Eagle mode, which could lead to performance degradation due to cache misses under specific conditions.
heheda12345
left a comment
Is this for correctness or for a higher cache hit rate?
For the cache hit rate of resumed requests, I think we only need to ensure that the state of the last prefill chunk was computed before, so that it is already cached if external storage exists and we don't need to force a recompute at this position. Try to keep the code simple :)
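To make the block-alignment idea above concrete, here is a hypothetical sketch (this is not the actual `_mamba_compute_cache_pos` implementation from the PR; `block_size` and `last_aligned_pos` are illustrative names): the mamba state can only be cached at block boundaries, so the last cacheable position is the largest block-aligned offset not exceeding the computed-token count.

```python
# Hypothetical sketch of block-aligned cacheable positions; not the actual
# vLLM helper from this PR.
def last_aligned_pos(num_computed_tokens: int, block_size: int) -> int:
    # Largest multiple of block_size that does not exceed the number of
    # computed tokens, i.e. the last position at which state can be cached.
    return (num_computed_tokens // block_size) * block_size

assert last_aligned_pos(37, 16) == 32  # last full block boundary
assert last_aligned_pos(32, 16) == 32  # already aligned
assert last_aligned_pos(7, 16) == 0    # no complete block yet
```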
Added the `ready` label for the large unit tests.
The changes are for both correctness and cache hit rate.
# * resumed requests: num_computed_tokens < (
#     num_prompt_tokens + num_output_tokens
# )
if num_computed_tokens < request.num_tokens:
How will this behave for a normal decode where we have:
num_computed_tokens = request.num_tokens - 1?
Oh right, I totally forgot that normal decode also follows this logic. While normal decode isn't affected correctness-wise by it, it still introduces some redundant computation. We could maybe use num_computed_tokens < max(request.num_prompt_tokens, request.num_tokens - 1)? It looks a bit complicated, though. What do you think? The main point here is just to distinguish normal decode from the rest, since it doesn't require block-aligned processing.
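A small sketch of the condition discussed above (hypothetical names; `Request` and `needs_block_aligned_split` mirror the discussion, not the actual vLLM code). A normal decode step always has exactly one uncomputed token, so `num_computed_tokens == request.num_tokens - 1`, and the `max(...)` form excludes it while still catching resumed requests:

```python
# Hypothetical illustration of the proposed condition; not vLLM code.
from dataclasses import dataclass

@dataclass
class Request:
    num_prompt_tokens: int
    num_output_tokens: int

    @property
    def num_tokens(self) -> int:
        return self.num_prompt_tokens + self.num_output_tokens

def needs_block_aligned_split(req: Request, num_computed_tokens: int) -> bool:
    # Resumed/prefill requests lag behind by more than one token; a normal
    # decode has exactly one uncomputed token (the one being generated).
    return num_computed_tokens < max(req.num_prompt_tokens, req.num_tokens - 1)

# Normal decode: all but the newest token are computed -> no split needed.
decode_req = Request(num_prompt_tokens=8, num_output_tokens=3)  # num_tokens = 11
assert not needs_block_aligned_split(decode_req, num_computed_tokens=10)

# Resumed request: computed tokens fall short of the prompt -> split needed.
resumed_req = Request(num_prompt_tokens=8, num_output_tokens=3)
assert needs_block_aligned_split(resumed_req, num_computed_tokens=5)
```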
This pull request has merge conflicts that must be resolved before it can be merged.
Purpose
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.