diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 7fe95366b9d5..6796294839bd 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -103,6 +103,7 @@ def fake_propose_draft_token_ids_fn( aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, + slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, ) -> list[list[int]]: num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor num_computed_tokens = num_computed_tokens_cpu_tensor[0].item() @@ -473,10 +474,6 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(mamba_utils, "do_mamba_copy_block", fake_copy_fn) -@pytest.mark.skip( - reason="Skipping test_mamba_prefix_cache because it is based on spec " - "decode which is not allowed now." -) def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): run_ref_mamba_state_in_subprocess() apply_patch(monkeypatch) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index cd462678b051..efd594c9e3d0 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -354,10 +354,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: assert vllm_config.scheduler_config.enable_chunked_prefill, ( "Chunked prefill is required for mamba cache mode 'align'." ) - assert not vllm_config.speculative_config, ( - "Mamba cache mode 'align' is currently not compatible " - "with speculative decoding." - ) logger.info( "Warning: Prefix caching in Mamba cache '%s' " "mode is currently enabled. 
" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 20e9fced733c..5f745e39a16a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -30,6 +30,7 @@ RoutedExpertsReader, ) from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.utils.math_utils import round_down from vllm.v1.core.encoder_cache_manager import ( EncoderCacheManager, EncoderDecoderCacheManager, @@ -261,6 +262,14 @@ def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool: vllm_config=self.vllm_config, ) + def _mamba_compute_cache_pos(self, num_tokens_to_cache: int) -> int: + block_size = self.cache_config.block_size + last_cache_position = num_tokens_to_cache - num_tokens_to_cache % block_size + # eagle prune + if self.use_eagle: + last_cache_position = max(last_cache_position - block_size, 0) + return last_cache_position + def _mamba_block_aligned_split( self, request: Request, @@ -271,31 +280,44 @@ def _mamba_block_aligned_split( assert num_external_computed_tokens == 0, ( "External KV connector is not verified yet" ) - # TODO: need check for resume requests - if request.num_output_tokens == 0: # prefill + num_computed_tokens = ( + request.num_computed_tokens + + num_new_local_computed_tokens + + num_external_computed_tokens + ) + # Perform block-aligned splitting at prefill phase, including: + # * non-resumed requests: num_computed_tokens < num_prompt_tokens + 0 + # * resumed requests: num_computed_tokens < ( + # num_prompt_tokens + num_output_tokens + # ) + if num_computed_tokens < request.num_tokens: # To enable block-aligned caching of the Mamba state, `num_new_tokens` # must be a multiple of `block_size`. # As an exception, if `num_new_tokens` is less than `block_size`, the # state is simply not cached, requiring no special handling. # Additionally, when Eagle mode is enabled, FullAttn prunes the last # matching block. 
To prevent this from causing a Mamba cache miss, the - # last chunk must be larger than `block_size`. - block_size = self.cache_config.block_size - last_cache_position = ( - request.num_prompt_tokens - request.num_prompt_tokens % block_size - ) - # eagle prune - if self.use_eagle: - last_cache_position = max(last_cache_position - block_size, 0) - num_computed_tokens = ( - request.num_computed_tokens - + num_new_local_computed_tokens - + num_external_computed_tokens - ) + # last chunk must not be smaller than `block_size`. + num_tokens_to_cache = request.num_tokens + if request.num_output_tokens > 0: # resumed requests + # Perform separate block-aligned splits for prompt and output tokens + # in resumed requests to maximize cache hits. + last_prompt_cache_position = self._mamba_compute_cache_pos( + request.num_prompt_tokens + ) + if num_computed_tokens < last_prompt_cache_position: + num_new_tokens = min( + num_new_tokens, request.num_prompt_tokens - num_computed_tokens + ) + num_tokens_to_cache = request.num_prompt_tokens + + last_cache_position = self._mamba_compute_cache_pos(num_tokens_to_cache) num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens if num_computed_tokens_after_sched < last_cache_position: # align to block_size - num_new_tokens = num_new_tokens // block_size * block_size + num_new_tokens = round_down( + num_new_tokens, self.cache_config.block_size + ) elif ( num_computed_tokens < last_cache_position