Skip to content
25 changes: 14 additions & 11 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,27 +275,30 @@ def _mamba_block_aligned_split(
assert num_external_computed_tokens == 0, (
"External KV connector is not verified yet"
)
# TODO: need check for resume requests
if request.num_output_tokens == 0: # prefill
num_computed_tokens = (
request.num_computed_tokens
+ num_new_local_computed_tokens
+ num_external_computed_tokens
)
# Perform block-aligned splitting at prefill phase, including:
# * non-resumed requests: num_computed_tokens < num_prompt_tokens + 0
# * resumed requests: num_computed_tokens < (
# num_prompt_tokens + num_output_tokens
# )
# NOTE: Use `request.num_tokens - 1` to bypass normal decoding.
if num_computed_tokens < max(request.num_prompt_tokens, request.num_tokens - 1):
# To enable block-aligned caching of the Mamba state, `num_new_tokens`
# must be a multiple of `block_size`.
# As an exception, if `num_new_tokens` is less than `block_size`, the
# state is simply not cached, requiring no special handling.
# Additionally, when Eagle mode is enabled, FullAttn prunes the last
# matching block. To prevent this from causing a Mamba cache miss, the
# last chunk must be larger than `block_size`.
# last chunk must be not smaller than `block_size`.
block_size = self.cache_config.block_size
last_cache_position = (
request.num_prompt_tokens - request.num_prompt_tokens % block_size
)
last_cache_position = request.num_tokens - request.num_tokens % block_size
# eagle prune
if self.use_eagle:
last_cache_position = max(last_cache_position - block_size, 0)
num_computed_tokens = (
request.num_computed_tokens
+ num_new_local_computed_tokens
+ num_external_computed_tokens
)
num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens
if num_computed_tokens_after_sched < last_cache_position:
# align to block_size
Expand Down