[V1] [Hybrid] Lighter Mamba Prefix Caching with standard memory layout #29272
peakcrosser7 wants to merge 67 commits into vllm-project:main from
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
@@ -57,9 +58,18 @@ class GDNAttentionMetadata:
    batch_ptr: torch.Tensor | None = None
    token_chunk_offset_ptr: torch.Tensor | None = None


def mamba_gather_indices(common_attn_metadata: CommonAttentionMetadata,
nit: Will it be faster & clearer to write a numba (cpu) / triton (gpu) kernel?
Yep, that's the plan. This is just a temporary helper function right now. It'll eventually be moved somewhere central so different Mamba variant metadata can all call it to get their state_indices.
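For illustration only, a minimal torch-based sketch of what such a shared state-indices helper could look like; the name, arguments, and block-table layout here are assumptions, not the PR's actual mamba_gather_indices:

import torch


def gather_state_indices(block_table: torch.Tensor,
                         num_computed_tokens: torch.Tensor,
                         block_size: int) -> torch.Tensor:
    """Pick, for each request, the block-table slot that holds its Mamba state."""
    # Block that the next token of each request falls into, clamped to the
    # last allocated column of the [num_reqs, max_blocks] block table.
    block_idx = (num_computed_tokens // block_size).clamp_(max=block_table.shape[1] - 1)
    return block_table.gather(1, block_idx.unsqueeze(1).long()).squeeze(1)


# Example: 2 requests, block_size = 4 -> state indices 8 and 3.
bt = torch.tensor([[7, 8, 9], [3, 4, 5]])
print(gather_state_indices(bt, torch.tensor([5, 0]), 4))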
)

# Schedule encoder inputs.
encoder_inputs_to_schedule = None
external_load_encoder_input: list[int] = []
new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs:
    (
        encoder_inputs_to_schedule,
        num_new_tokens,
reminder: num_new_tokens is updated here.
Thanks for the reminder! You're right, I missed the encoder case and will move the block-aligned logic after this section.
By the way, does this block-aligned logic conflict with the encoder input?
vllm/v1/core/sched/scheduler.py (Outdated)
        # Additionally, when Eagle mode is enabled, FullAttn prunes the last
        # matching block. To prevent this from causing a Mamba cache miss, the
        # last chunk must be larger than `block_size`.
        block_size = self.block_size
I can't understand this part of the code. I thought we only need something like:
if request.num_output_tokens == 0:  # prefill
    last_cache_position = request.num_prompt_tokens - request.num_prompt_tokens % block_size
    # eagle prune
    if self.use_eagle:
        last_cache_position = max(last_cache_position - block_size, 0)
    num_computed_tokens_after_prefill = request.num_computed_tokens + num_new_tokens
    if num_computed_tokens_after_prefill < last_cache_position:
        num_new_tokens = num_new_tokens // block_size * block_size  # align to block_size
    elif request.num_computed_tokens < last_cache_position < num_computed_tokens_after_prefill:
        num_new_tokens = last_cache_position - request.num_computed_tokens  # force to cache the last chunk
    else:
        pass  # prefill the last few tokens
`num_new_tokens = num_new_tokens // block_size * block_size` may not work if we don't force chunk alignment in this case:
https://github.com/vllm-project/vllm/pull/29272/files#r2555167588
> I can't understand this part of the code. I thought we only need something like: […]
Got it, your implementation is much more concise!
This part of your code should be executed after `num_new_tokens = min(num_new_tokens, token_budget)`.
> `num_new_tokens = num_new_tokens // block_size * block_size` may not work if we don't force chunk align in this case https://github.com/vllm-project/vllm/pull/29272/files#r2555167588
Yes, details in that comment.
vllm/v1/core/sched/scheduler.py (Outdated)
@@ -270,73 +288,58 @@ def schedule(self) -> SchedulerOutput:
            # its max_total_tokens or max_model_len.
            # 2. The encoder budget is exhausted.
            # 3. The encoder cache is exhausted.
            # 4. Insufficient budget for a block-aligned chunk in hybrid
            #    models with lighter mamba prefix caching.
In this case, should we allow the prefill of all scheduled tokens instead of forcing a block-aligned chunk?
We can't do that. For a single prompt, if any intermediate chunk is not block-aligned, we cannot bind the computed tokens to a block's hash in the subsequent chunks.
And I think trying to re-align by adjusting subsequent chunk sizes would make the logic overly complex.
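A toy illustration of the constraint with made-up numbers; only completely filled blocks can be bound to a content hash:

block_size = 16

def num_hashable_blocks(num_computed_tokens: int) -> int:
    # Only complete blocks have a stable content hash.
    return num_computed_tokens // block_size

# Block-aligned chunking of a 100-token prompt: 48 + 48 + 4.
print(num_hashable_blocks(48))   # 3 -> blocks 0..2 hashed after chunk 1
print(num_hashable_blocks(96))   # 6 -> blocks 3..5 hashed after chunk 2
# Misaligned first chunk of 40 tokens: tokens 32..39 sit in a half-filled
# block, so block 2 gets no hash and the next chunk starts mid-block,
# which breaks the hash chain for every block after it.
print(num_hashable_blocks(40))   # 2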
The aligned `num_new_tokens` can be computed with:
num_computed_tokens_after_prefill = num_computed_tokens_after_prefill // block_size * block_size
if num_computed_tokens_after_prefill > num_computed_tokens:
    num_new_tokens = num_computed_tokens_after_prefill - num_computed_tokens
else:
    # don't change
    pass
But I think it may also be fine to keep the current implementation.
vllm/v1/core/sched/scheduler.py (Outdated)
    and num_new_tokens > token_budget
):
    self.waiting.pop_request()
    skipped_waiting_requests.prepend_request(request)
    continue

num_new_tokens = min(num_new_tokens, token_budget)
if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE
make this a util function to avoid code duplication between the first prefill and chunked prefill?
Yep, I will do it
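One possible shape for that utility, assembled from the logic suggested earlier in this thread; all names are assumptions, it would run after `num_new_tokens = min(num_new_tokens, token_budget)` as noted above, and the plain-truncation branch is still subject to the linked caveat:

def align_chunk_for_mamba_cache(request, num_new_tokens: int,
                                block_size: int, use_eagle: bool) -> int:
    """Make each scheduled prefill chunk end on a block boundary so its
    blocks can be hashed for lighter Mamba prefix caching."""
    if request.num_output_tokens > 0:
        return num_new_tokens  # decode step, nothing to align
    # Last prompt position that can be bound to a fully filled block.
    last_cache_position = (request.num_prompt_tokens
                           - request.num_prompt_tokens % block_size)
    if use_eagle:
        # FullAttn prunes the last matching block under Eagle, so stop one
        # block earlier to avoid a Mamba cache miss.
        last_cache_position = max(last_cache_position - block_size, 0)
    after = request.num_computed_tokens + num_new_tokens
    if after < last_cache_position:
        # Intermediate chunk: truncate to a block-aligned size.
        return num_new_tokens // block_size * block_size
    if request.num_computed_tokens < last_cache_position < after:
        # Stop exactly at the last cacheable position.
        return last_cache_position - request.num_computed_tokens
    return num_new_tokens  # tail of the prompt, no alignment needed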
@@ -647,6 +599,28 @@ def find_longest_cache_hit(
        return computed_blocks

    def remove_skipped_blocks(self, request_id: str,
can you rebase the PR to include the recent changes like #25431?
ok, I will do it
I'm finding that the current design still needs `remove_skipped_blocks()` instead of just `get_num_skipped_tokens()`.
The reason is that in `_preprocess_mamba()`, we copy the latest immutable block into a newly allocated one, and that immutable block can only be freed in the next step.
My plan is to use a dict `_req_to_last_computed` to track `last_computed_tokens` for each request. However, `get_num_skipped_tokens()` doesn't accept a `req_id` parameter, which prevents this.
Is there a better solution here?
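To make that plan concrete, a rough sketch of the per-request bookkeeping; class and method names are assumptions, and the actual block-freeing logic is elided:

class LighterMambaManagerSketch:
    def __init__(self) -> None:
        # How many tokens each request had computed at the previous step,
        # so the immutable block copied in _preprocess_mamba() can be
        # freed one step later.
        self._req_to_last_computed: dict[str, int] = {}

    def remove_skipped_blocks(self, request_id: str,
                              num_computed_tokens: int) -> None:
        last = self._req_to_last_computed.get(request_id, 0)
        # ... free the blocks that became skippable between `last` and
        # `num_computed_tokens` ...
        self._req_to_last_computed[request_id] = num_computed_tokens

    def free(self, request_id: str) -> None:
        # Drop the entry once the request finishes.
        self._req_to_last_computed.pop(request_id, None)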
                request_id, num_tokens, new_computed_blocks
            )
        else:
            num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
Is it ok to always return `min(self.num_speculative_blocks + 1, super().get_num_blocks_to_allocate(...))`, or:

if is_prefill:  # I don't have a good idea on how to check is_prefill now
    return min(1, super().get_num_blocks_to_allocate(...))
else:
    return min(self.num_speculative_blocks + 1, super().get_num_blocks_to_allocate(...))
Let me think... If we can distinguish between prefill and decode, we might not need to deal with the complex logic of reusing blocks.
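A hedged sketch of the idea, with a stand-in base class so the snippet is self-contained; how to obtain `is_prefill` is the open question in this thread, so here it is simply passed in by the caller:

class _BaseManagerSketch:
    """Stand-in for the real single-type manager, illustration only."""
    block_size = 16
    num_speculative_blocks = 2

    def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int,
                                   new_computed_blocks: list) -> int:
        return -(-num_tokens // self.block_size)  # ceil division


class MambaAllocSketch(_BaseManagerSketch):
    def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int,
                                   new_computed_blocks: list, *,
                                   is_prefill: bool) -> int:
        base = super().get_num_blocks_to_allocate(
            request_id, num_tokens, new_computed_blocks)
        # Prefill only ever writes the single "current" Mamba block; decode
        # may also need one block per speculative token.
        cap = 1 if is_prefill else self.num_speculative_blocks + 1
        return min(cap, base)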
        return num_new_alloc_blocks + num_evictable_computed_blocks

    def save_new_computed_blocks(
remove this function?
My mistake, it should call `super().save_new_computed_blocks()`.
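For concreteness, a minimal sketch of that fix, with the signature assumed from the surrounding diff rather than verified against the base manager:

    def save_new_computed_blocks(self, request_id: str,
                                 new_computed_blocks: list) -> None:
        # No Mamba-specific bookkeeping is needed here; defer to the base
        # single-type manager implementation.
        super().save_new_computed_blocks(request_id, new_computed_blocks)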
        req_blocks.extend(new_blocks)
        return new_blocks

    def cache_blocks(self, request: Request, num_tokens: int) -> None:
remove this function?
My mistake, same as `save_new_computed_blocks()`.
Documentation preview: https://vllm--29272.org.readthedocs.build/en/29272/
Simplify & Bugfix for _preprocess_mamba
        # TODO(hhy): when LPS is enabled, parent_block maybe a null block
        parent_block = blocks[num_cached_blocks - 1]
        assert parent_block.block_hash is not None
        parent_block_hash = maybe_convert_block_hash(
ok!
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
todo: remove this file
Fantastic work :-) Do we know the timeline here?
@peakcrosser7 That is fantastic :-) I had your last version running but had some issues with guided generation. I will try out the new PR just now.
Closed because of #30877.
#28176 with standard memory layout
Purpose
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.