Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions tests/v1/e2e/test_mamba_prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def fake_propose_draft_token_ids_fn(
aux_hidden_states: list[torch.Tensor] | None,
spec_decode_metadata: SpecDecodeMetadata | None,
common_attn_metadata: CommonAttentionMetadata,
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
) -> list[list[int]]:
num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor
num_computed_tokens = num_computed_tokens_cpu_tensor[0].item()
Expand Down Expand Up @@ -473,10 +474,6 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mamba_utils, "do_mamba_copy_block", fake_copy_fn)


@pytest.mark.skip(
reason="Skipping test_mamba_prefix_cache because it is based on spec "
"decode which is not allowed now."
)
def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
run_ref_mamba_state_in_subprocess()
apply_patch(monkeypatch)
Expand Down
4 changes: 0 additions & 4 deletions vllm/model_executor/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
assert vllm_config.scheduler_config.enable_chunked_prefill, (
"Chunked prefill is required for mamba cache mode 'align'."
)
assert not vllm_config.speculative_config, (
"Mamba cache mode 'align' is currently not compatible "
"with speculative decoding."
)
logger.info(
"Warning: Prefix caching in Mamba cache '%s' "
"mode is currently enabled. "
Expand Down
54 changes: 38 additions & 16 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
RoutedExpertsReader,
)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.utils.math_utils import round_down
from vllm.v1.core.encoder_cache_manager import (
EncoderCacheManager,
EncoderDecoderCacheManager,
Expand Down Expand Up @@ -261,6 +262,14 @@ def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
vllm_config=self.vllm_config,
)

def _mamba_compute_cache_pos(self, num_tokens_to_cache: int) -> int:
block_size = self.cache_config.block_size
last_cache_position = num_tokens_to_cache - num_tokens_to_cache % block_size
# eagle prune
if self.use_eagle:
last_cache_position = max(last_cache_position - block_size, 0)
return last_cache_position

def _mamba_block_aligned_split(
self,
request: Request,
Expand All @@ -271,31 +280,44 @@ def _mamba_block_aligned_split(
assert num_external_computed_tokens == 0, (
"External KV connector is not verified yet"
)
# TODO: need check for resume requests
if request.num_output_tokens == 0: # prefill
num_computed_tokens = (
request.num_computed_tokens
+ num_new_local_computed_tokens
+ num_external_computed_tokens
)
# Perform block-aligned splitting at prefill phase, including:
# * non-resumed requests: num_computed_tokens < num_prompt_tokens + 0
# * resumed requests: num_computed_tokens < (
# num_prompt_tokens + num_output_tokens
# )
if num_computed_tokens < request.num_tokens:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How will this behave for a normal decode where we have:

num_computed_tokens = request.num_tokens - 1

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right, I totally forgot that the normal decode also follows this logic. While the normal decode doesn't get affected by this logic, it still introduces some redundant computations. We could maybe use num_computed_tokens < max(request.num_prompt_tokens, request.num_tokens - 1)? It looks a bit complicated. What do you think? The main point here is just to distinguish the normal decode from the rest, since it doesn't require block-aligned processing.

# To enable block-aligned caching of the Mamba state, `num_new_tokens`
# must be a multiple of `block_size`.
# As an exception, if `num_new_tokens` is less than `block_size`, the
# state is simply not cached, requiring no special handling.
# Additionally, when Eagle mode is enabled, FullAttn prunes the last
# matching block. To prevent this from causing a Mamba cache miss, the
# last chunk must be larger than `block_size`.
block_size = self.cache_config.block_size
last_cache_position = (
request.num_prompt_tokens - request.num_prompt_tokens % block_size
)
# eagle prune
if self.use_eagle:
last_cache_position = max(last_cache_position - block_size, 0)
num_computed_tokens = (
request.num_computed_tokens
+ num_new_local_computed_tokens
+ num_external_computed_tokens
)
# last chunk must not be smaller than `block_size`.
num_tokens_to_cache = request.num_tokens
if request.num_output_tokens > 0: # resumed requests
# Perform separate block-aligned splits for prompt and output tokens
# in resumed requests to maximize cache hits.
last_prompt_cache_position = self._mamba_compute_cache_pos(
request.num_prompt_tokens
)
if num_computed_tokens < last_prompt_cache_position:
num_new_tokens = min(
num_new_tokens, request.num_prompt_tokens - num_computed_tokens
)
num_tokens_to_cache = request.num_prompt_tokens

last_cache_position = self._mamba_compute_cache_pos(num_tokens_to_cache)
num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens
if num_computed_tokens_after_sched < last_cache_position:
# align to block_size
num_new_tokens = num_new_tokens // block_size * block_size
num_new_tokens = round_down(
num_new_tokens, self.cache_config.block_size
)
elif (
num_computed_tokens
< last_cache_position
Expand Down