-
-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[V1][Hybrid] Enable spec decode and optimize block-aligned split in mamba cache align mode #33024
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3acfc49
83bb817
5781011
2a90d6b
1920b70
74bc4b2
c7985f7
0463a56
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,7 @@ | |
| RoutedExpertsReader, | ||
| ) | ||
| from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry | ||
| from vllm.utils.math_utils import round_down | ||
| from vllm.v1.core.encoder_cache_manager import ( | ||
| EncoderCacheManager, | ||
| EncoderDecoderCacheManager, | ||
|
|
@@ -261,6 +262,14 @@ def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool: | |
| vllm_config=self.vllm_config, | ||
| ) | ||
|
|
||
| def _mamba_compute_cache_pos(self, num_tokens_to_cache: int) -> int: | ||
| block_size = self.cache_config.block_size | ||
| last_cache_position = num_tokens_to_cache - num_tokens_to_cache % block_size | ||
| # eagle prune | ||
| if self.use_eagle: | ||
| last_cache_position = max(last_cache_position - block_size, 0) | ||
| return last_cache_position | ||
|
|
||
| def _mamba_block_aligned_split( | ||
| self, | ||
| request: Request, | ||
|
|
@@ -271,31 +280,44 @@ def _mamba_block_aligned_split( | |
| assert num_external_computed_tokens == 0, ( | ||
| "External KV connector is not verified yet" | ||
| ) | ||
| # TODO: need check for resume requests | ||
| if request.num_output_tokens == 0: # prefill | ||
| num_computed_tokens = ( | ||
| request.num_computed_tokens | ||
| + num_new_local_computed_tokens | ||
| + num_external_computed_tokens | ||
| ) | ||
| # Perform block-aligned splitting at prefill phase, including: | ||
| # * non-resumed requests: num_computed_tokens < num_prompt_tokens + 0 | ||
| # * resumed requests: num_computed_tokens < ( | ||
| # num_prompt_tokens + num_output_tokens | ||
| # ) | ||
| if num_computed_tokens < request.num_tokens: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How will this behave for a normal decode where we have: ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh right, I totally forgot that the normal decode also follows this logic. While the normal decode isn't affected by this logic, it still introduces some redundant computations. We could maybe use |
||
| # To enable block-aligned caching of the Mamba state, `num_new_tokens` | ||
| # must be a multiple of `block_size`. | ||
| # As an exception, if `num_new_tokens` is less than `block_size`, the | ||
| # state is simply not cached, requiring no special handling. | ||
| # Additionally, when Eagle mode is enabled, FullAttn prunes the last | ||
| # matching block. To prevent this from causing a Mamba cache miss, the | ||
| # last chunk must be larger than `block_size`. | ||
| block_size = self.cache_config.block_size | ||
| last_cache_position = ( | ||
| request.num_prompt_tokens - request.num_prompt_tokens % block_size | ||
| ) | ||
| # eagle prune | ||
| if self.use_eagle: | ||
| last_cache_position = max(last_cache_position - block_size, 0) | ||
| num_computed_tokens = ( | ||
| request.num_computed_tokens | ||
| + num_new_local_computed_tokens | ||
| + num_external_computed_tokens | ||
| ) | ||
| # last chunk must not be smaller than `block_size`. | ||
| num_tokens_to_cache = request.num_tokens | ||
| if request.num_output_tokens > 0: # resumed requests | ||
| # Perform separate block-aligned splits for prompt and output tokens | ||
| # in resumed requests to maximize cache hits. | ||
| last_prompt_cache_position = self._mamba_compute_cache_pos( | ||
| request.num_prompt_tokens | ||
| ) | ||
| if num_computed_tokens < last_prompt_cache_position: | ||
| num_new_tokens = min( | ||
| num_new_tokens, request.num_prompt_tokens - num_computed_tokens | ||
| ) | ||
| num_tokens_to_cache = request.num_prompt_tokens | ||
|
|
||
| last_cache_position = self._mamba_compute_cache_pos(num_tokens_to_cache) | ||
| num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens | ||
| if num_computed_tokens_after_sched < last_cache_position: | ||
| # align to block_size | ||
| num_new_tokens = num_new_tokens // block_size * block_size | ||
| num_new_tokens = round_down( | ||
| num_new_tokens, self.cache_config.block_size | ||
| ) | ||
| elif ( | ||
| num_computed_tokens | ||
| < last_cache_position | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.