diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index 23097bf2a086..b05040ebe2a6 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -24,7 +24,7 @@ def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True): return SlidingWindowManager( sliding_window_spec, - block_pool, + block_pool=block_pool, enable_caching=enable_caching, kv_cache_group_id=0, ) @@ -35,7 +35,7 @@ def get_chunked_local_attention_manager( ): return ChunkedLocalAttentionManager( chunked_local_attention_spec, - block_pool, + block_pool=block_pool, enable_caching=enable_caching, kv_cache_group_id=0, ) @@ -342,11 +342,15 @@ def test_get_num_blocks_to_allocate(): ] assert ( - manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0) + manager.get_num_blocks_to_allocate( + "1", 20 * block_size, cached_blocks_1, 0, 20 * block_size + ) == 20 ) assert ( - manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0) + manager.get_num_blocks_to_allocate( + "2", 20 * block_size, cached_blocks_2, 0, 20 * block_size + ) == 15 ) @@ -375,6 +379,7 @@ def test_evictable_cached_blocks_not_double_allocated(): num_tokens=2 * block_size, new_computed_blocks=[evictable_block], total_computed_tokens=block_size, + num_tokens_main_model=2 * block_size, ) # Free capacity check should count evictable cached blocks, but allocation # should only allocate the truly new block. 
@@ -386,7 +391,9 @@ def test_evictable_cached_blocks_not_double_allocated(): num_local_computed_tokens=block_size, num_external_computed_tokens=0, ) - new_blocks = manager.allocate_new_blocks(request_id, num_tokens=4) + new_blocks = manager.allocate_new_blocks( + request_id, num_tokens=4, num_tokens_main_model=4 + ) assert len(new_blocks) == 1 assert len(manager.req_to_blocks[request_id]) == 2 @@ -411,10 +418,14 @@ def test_chunked_local_attention_get_num_blocks_to_allocate(): ] assert ( - manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0) + manager.get_num_blocks_to_allocate( + "1", 20 * block_size, cached_blocks_1, 0, 20 * block_size + ) == 20 ) assert ( - manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0) + manager.get_num_blocks_to_allocate( + "2", 20 * block_size, cached_blocks_2, 0, 20 * block_size + ) == 15 ) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py new file mode 100644 index 000000000000..7fe95366b9d5 --- /dev/null +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -0,0 +1,764 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import multiprocessing as mp +import os +import traceback +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import datasets +import pytest +import torch + +from vllm import LLM, SamplingParams, TokensPrompt +from vllm.config import CacheConfig +from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc +from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.engine.core_client import InprocClient +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.outputs import SamplerOutput +from 
vllm.v1.request import Request +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.worker import mamba_utils +from vllm.v1.worker.gpu_input_batch import CachedRequestState +from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch +from vllm.v1.worker.mamba_utils import get_mamba_groups + + +@dataclass +class StepAction: + num_computed_tokens_start: int + num_scheduled_tokens: int + kv_cache_block_ids: list[int] # [] to follow last step + preprocess_copy_idx: tuple[int, int] # -1, -1 for no copy + postprocess_copy_idx: tuple[int, int] # -1, -1 for no copy + + +num_speculative_tokens = 3 + +num_accepted_tokens = 1 +prompt_token_ids: list[int] = [] +MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8" +BLOCK_SIZE = 560 +NUM_HIDDEN_LAYERS = 1 +cur_step_action_idx = 0 +cur_step_action: StepAction | None = None +step_actions: list[StepAction] = [] + + +def get_fake_sample_fn() -> SamplerOutput: + def fake_sample_fn( + self: GPUModelRunner, + logits: torch.Tensor | None, + spec_decode_metadata: SpecDecodeMetadata | None, + ) -> SamplerOutput: + assert logits is not None + num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor + num_computed_tokens = num_computed_tokens_cpu_tensor[0].item() + if num_computed_tokens < self.input_batch.num_prompt_tokens[0].item(): + first_token_id_index = self.input_batch.num_prompt_tokens[0].item() + else: + first_token_id_index = num_computed_tokens + 1 + if spec_decode_metadata is None: + return SamplerOutput( + sampled_token_ids=torch.tensor( + [[prompt_token_ids[first_token_id_index]]], + device="cuda", + dtype=torch.int32, + ), + logprobs_tensors=None, + ) + num_sampled_tokens = spec_decode_metadata.cu_num_sampled_tokens[0].item() + 1 + accpeted_tokens = prompt_token_ids[ + first_token_id_index : first_token_id_index + + min(num_accepted_tokens, logits.shape[0]) + ] + 
sampled_token_ids = accpeted_tokens + [-1] * ( + num_sampled_tokens - len(accpeted_tokens) + ) + return SamplerOutput( + sampled_token_ids=torch.tensor( + [sampled_token_ids], device="cuda", dtype=torch.int32 + ), + logprobs_tensors=None, + ) + + return fake_sample_fn + + +def get_fake_propose_draft_token_ids_fn(): + def fake_propose_draft_token_ids_fn( + self: GPUModelRunner, + scheduler_output: SchedulerOutput, + sampled_token_ids: torch.Tensor | list[list[int]], + sampling_metadata: SamplingMetadata, + hidden_states: torch.Tensor, + sample_hidden_states: torch.Tensor, + aux_hidden_states: list[torch.Tensor] | None, + spec_decode_metadata: SpecDecodeMetadata | None, + common_attn_metadata: CommonAttentionMetadata, + ) -> list[list[int]]: + num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor + num_computed_tokens = num_computed_tokens_cpu_tensor[0].item() + if ( + self.input_batch.num_tokens_no_spec[0].item() + <= self.input_batch.num_prompt_tokens[0].item() + ): + first_token_id_index = self.input_batch.num_prompt_tokens[0].item() + else: + first_token_id_index = ( + num_computed_tokens + 1 + ) # bonus token isn't considered as computed + first_token_id_index += self.input_batch.num_accepted_tokens_cpu[0].item() + proposed_draft_token_ids = [ + prompt_token_ids[ + first_token_id_index : first_token_id_index + num_speculative_tokens + ] + ] + return proposed_draft_token_ids + + return fake_propose_draft_token_ids_fn + + +def get_fake_step_action_fn(original_step_action_fn: Callable): + def fake_get_output(self: InprocClient): + global cur_step_action_idx + global cur_step_action + if cur_step_action_idx < len(step_actions): + cur_step_action = step_actions[cur_step_action_idx] + cur_step_action_idx += 1 + else: + cur_step_action = None + print(f"cur_step_action: {cur_step_action_idx=} {cur_step_action=}") + return original_step_action_fn(self) + + return fake_get_output + + +def get_fake_allocate_slots_fn(original_allocate_slots_fn: 
Callable): + def fake_allocate_slots_fn( + self: KVCacheManager, + request: Request, + num_new_tokens: int, + num_new_computed_tokens: int = 0, + new_computed_blocks: KVCacheBlocks | None = None, + num_lookahead_tokens: int = 0, + num_external_computed_tokens: int = 0, + delay_cache_blocks: bool = False, + num_encoder_tokens: int = 0, + ): + ret = original_allocate_slots_fn( + self, + request, + num_new_tokens, + num_new_computed_tokens, + new_computed_blocks, + num_lookahead_tokens, + num_external_computed_tokens, + delay_cache_blocks, + num_encoder_tokens, + ) + if cur_step_action is not None: + cur_block_ids = self.coordinator.single_type_managers[0].req_to_blocks[ + request.request_id + ] + not_null_block_flags = [not block.is_null for block in cur_block_ids] + block_ids = [1 if block else 0 for block in not_null_block_flags] + assert block_ids == cur_step_action.kv_cache_block_ids + return ret + + return fake_allocate_slots_fn + + +mamba_kv_cache_dict = {} + + +def get_fake_execute_model_fn(original_execute_model_fn: Callable): + last_num_computed_tokens = 0 + + def fake_execute_model_fn( + self: GPUModelRunner, + scheduler_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors | None = None, + ): + if cur_step_action is not None: + num_scheduled_tokens = next( + iter(scheduler_output.num_scheduled_tokens.values()) + ) + assert num_scheduled_tokens == cur_step_action.num_scheduled_tokens + mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config) + mamba_group_id = mamba_group_ids[0] + mamba_layer_name = self.kv_cache_config.kv_cache_groups[ + mamba_group_id + ].layer_names[0] + nonlocal last_num_computed_tokens + if len(scheduler_output.scheduled_cached_reqs.req_ids) > 0: + num_computed_tokens = ( + scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] + ) + if ( + num_computed_tokens // BLOCK_SIZE + > last_num_computed_tokens // BLOCK_SIZE + ): + # generated a new aligned block in this step + block_idx = num_computed_tokens 
// mamba_spec.block_size - 1 + block_id = ( + self.input_batch.block_table.block_tables[mamba_group_id] + .block_table.cpu[0, block_idx] + .item() + ) + if block_id != 0: + kv_cache = self.compilation_config.static_forward_context[ + mamba_layer_name + ].kv_cache + mamba_kv_cache_dict[ + num_computed_tokens - num_computed_tokens % BLOCK_SIZE + ] = ( + kv_cache[0][0][block_id].clone(), + kv_cache[0][1][block_id].clone(), + ) + + last_num_computed_tokens = num_computed_tokens + else: + last_num_computed_tokens = 0 + + ret = original_execute_model_fn(self, scheduler_output, intermediate_tensors) + + if cur_step_action is not None: + assert ( + cur_step_action.num_computed_tokens_start + == self.input_batch.num_computed_tokens_cpu[0].item() + ) + + return ret + + return fake_execute_model_fn + + +def get_fake_process_mamba_fn( + original_preprocess_mamba_fn: Callable, + original_post_process_mamba_fn: Callable, + original_copy_fn: Callable, +): + copy_info: tuple[list[int], list[int], list[int]] | None = None + + def check_copy_info( + action: tuple[int, int], + kv_cache_config: KVCacheConfig, + forward_context: dict[str, Any], + input_batch: GPUInputBatch, + ): + assert copy_info is not None + if action == (-1, -1): + assert len(copy_info[0]) == len(copy_info[1]) == len(copy_info[2]) == 0 + else: + assert len(copy_info[0]) == len(copy_info[1]) == len(copy_info[2]) == 2 + mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config) + mamba_group_id = mamba_group_ids[0] + mamba_layer_name = kv_cache_config.kv_cache_groups[ + mamba_group_id + ].layer_names[0] + mamba_kv_cache = forward_context[mamba_layer_name].kv_cache[0][-1] + mamba_block_table = input_batch.block_table.block_tables[ + mamba_group_id + ].block_table.cpu[0] + expected_temporal_src = mamba_kv_cache[ + mamba_block_table[action[0]] + ].data_ptr() + expected_temporal_dest = mamba_kv_cache[ + mamba_block_table[action[1]] + ].data_ptr() + # -1 is qwen3-next's temporal. 
We skip checking conv as it is more complex. + assert copy_info[0][-1] == expected_temporal_src + assert copy_info[1][-1] == expected_temporal_dest + + def fake_preprocess_mamba_fn( + scheduler_output: SchedulerOutput, + kv_cache_config: KVCacheConfig, + cache_config: CacheConfig, + mamba_state_idx: dict[str, int], + input_batch: GPUInputBatch, + requests: dict[str, CachedRequestState], + forward_context: dict[str, Any], + mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], + ): + nonlocal copy_info + copy_info = None + ret = original_preprocess_mamba_fn( + scheduler_output, + kv_cache_config, + cache_config, + mamba_state_idx, + input_batch, + requests, + forward_context, + mamba_state_copy_funcs, + ) + if cur_step_action is not None: + check_copy_info( + cur_step_action.preprocess_copy_idx, + kv_cache_config, + forward_context, + input_batch, + ) + return ret + + def fake_post_process_mamba_fn( + scheduler_output: SchedulerOutput, + kv_cache_config: KVCacheConfig, + input_batch: GPUInputBatch, + requests: dict[str, CachedRequestState], + mamba_state_idx: dict[str, int], + forward_context: dict[str, Any], + mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], + ): + nonlocal copy_info + copy_info = None + ret = original_post_process_mamba_fn( + scheduler_output, + kv_cache_config, + input_batch, + requests, + mamba_state_idx, + forward_context, + mamba_state_copy_funcs, + ) + if cur_step_action is not None: + check_copy_info( + cur_step_action.postprocess_copy_idx, + kv_cache_config, + forward_context, + input_batch, + ) + return ret + + def fake_copy_fn( + src_state_list: list[int], + dest_state_list: list[int], + num_elements_list: list[int], + ): + nonlocal copy_info + assert copy_info is None + copy_info = (src_state_list, dest_state_list, num_elements_list) + return original_copy_fn( + src_state_list, + dest_state_list, + num_elements_list, + ) + + return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn + + +def 
run_ref_mamba_state_in_subprocess() -> None: + ctx = mp.get_context("spawn") + proc = ctx.Process(target=_run_ref_mamba_state_worker) + proc.start() + proc.join(timeout=600) + if proc.exitcode != 0: + raise RuntimeError(f"Ref mamba state process exited with code {proc.exitcode}.") + + +def _run_ref_mamba_state_worker(): + try: + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + num_generated_tokens = 8000 + num_prompt_tokens = 500 + sampling_params = SamplingParams( + temperature=0.0, max_tokens=num_generated_tokens + ) + prompt_dataset = datasets.load_dataset("heheda/a_long_article") + full_prompt = prompt_dataset["train"][0]["text"] + fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model) + GPUModelRunner.execute_model = fake_execute_model_fn + fake_sample_fn = get_fake_sample_fn() + GPUModelRunner._sample = fake_sample_fn + engine = LLM( + model=MODEL, + block_size=BLOCK_SIZE, + hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS}, + seed=42, + ) + global prompt_token_ids + prompt_token_ids = engine.get_tokenizer().encode(full_prompt) + print(f"Token IDs length: {len(prompt_token_ids)}") + + _outputs = engine.generate( + [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], + sampling_params, + ) + # ref_mamba_kv_cache_dict = torch.load("mamba_kv_cache_dict.pth") + # check_mamba_state_equal(ref_mamba_kv_cache_dict, mamba_kv_cache_dict) + # torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict.pth") + cpu_state_ref = { + key: tuple(tensor.detach().cpu() for tensor in tensors) + for key, tensors in mamba_kv_cache_dict.items() + } + torch.save(cpu_state_ref, "mamba_kv_cache_dict_ref.pth") + mamba_kv_cache_dict.clear() + except Exception: + traceback.print_exc() + raise + + +def check_mamba_state_equal( + mamba_state_ref: dict, mamba_state_new: dict, keys_to_check: list[int] +): + atol = 1e-2 + rtol = 1e-2 + for key in keys_to_check: + assert key in mamba_state_new + assert key in mamba_state_ref + # mamba state new is a 
subset of mamba state ref + for i, (ref, new) in enumerate(zip(mamba_state_ref[key], mamba_state_new[key])): + if ref.device != new.device: + new = new.to(ref.device) + new = new[: ref.shape[0]] + if not torch.allclose(ref, new, atol=atol, rtol=rtol): + diff_mask = ~torch.isclose(ref, new, atol=atol, rtol=rtol) + diff_idx = torch.nonzero(diff_mask) + if diff_idx.shape[0] * 100 < ref.numel(): + print( + f"[WARNING] found {diff_idx.shape[0] * 100 / ref.numel()}% of the elements are different" # noqa: E501 + ) + continue + raise ValueError( + f"Mamba state is not equal for key: {key} at index {i}" + ) + return True + + +@dataclass +class TestConfig: + num_prompt_tokens: int + num_generated_tokens: int + num_accepted_tokens: int + step_actions: list[StepAction] + + +def apply_patch(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + fake_sample_fn = get_fake_sample_fn() + monkeypatch.setattr(GPUModelRunner, "_sample", fake_sample_fn) + + fake_propose_draft_token_ids_fn = get_fake_propose_draft_token_ids_fn() + monkeypatch.setattr( + GPUModelRunner, "propose_draft_token_ids", fake_propose_draft_token_ids_fn + ) + + fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model) + monkeypatch.setattr(GPUModelRunner, "execute_model", fake_execute_model_fn) + + fake_step_action_fn = get_fake_step_action_fn(InprocClient.get_output) + monkeypatch.setattr(InprocClient, "get_output", fake_step_action_fn) + + fake_allocate_slots_fn = get_fake_allocate_slots_fn(KVCacheManager.allocate_slots) + monkeypatch.setattr(KVCacheManager, "allocate_slots", fake_allocate_slots_fn) + + fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn = ( + get_fake_process_mamba_fn( + mamba_utils.preprocess_mamba, + mamba_utils.postprocess_mamba, + mamba_utils.do_mamba_copy_block, + ) + ) + monkeypatch.setattr(mamba_utils, "preprocess_mamba", fake_preprocess_mamba_fn) + monkeypatch.setattr(mamba_utils, "postprocess_mamba", 
fake_post_process_mamba_fn) + monkeypatch.setattr(mamba_utils, "do_mamba_copy_block", fake_copy_fn) + + +@pytest.mark.skip( + reason="Skipping test_mamba_prefix_cache because it is based on spec " + "decode which is not allowed now." +) +def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): + run_ref_mamba_state_in_subprocess() + apply_patch(monkeypatch) + prompt_dataset = datasets.load_dataset("heheda/a_long_article") + full_prompt = prompt_dataset["train"][0]["text"] + tests = { + "accept_1": TestConfig( + num_prompt_tokens=554, + num_generated_tokens=20, + num_accepted_tokens=1, + step_actions=[ + StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(554, 4, [], (-1, -1), (-1, -1)), + StepAction(555, 4, [], (-1, -1), (-1, -1)), + StepAction(556, 4, [], (-1, -1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + StepAction(558, 4, [], (-1, -1), (-1, -1)), + StepAction(559, 4, [], (-1, -1), (1, 0)), + StepAction(560, 4, [], (-1, -1), (-1, -1)), + StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + # test case 2.1: no hit, accept 2 tokens + "accept_2_1": TestConfig( + num_prompt_tokens=554, + num_generated_tokens=20, + num_accepted_tokens=2, + step_actions=[ + StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(554, 4, [], (-1, -1), (-1, -1)), + StepAction(556, 4, [], (-1, -1), (-1, -1)), + StepAction(558, 4, [1, 1, 1, 1, 1], (1, 1), (2, 0)), + StepAction(560, 4, [], (-1, -1), (-1, -1)), + StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + # test case 2.2: no hit, accept 2 tokens + "accept_2_2": TestConfig( + num_prompt_tokens=555, + num_generated_tokens=20, + num_accepted_tokens=2, + step_actions=[ + StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(555, 4, [], (-1, -1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (1, 1), (-1, -1)), + StepAction(559, 4, [], (-1, -1), (1, 0)), + StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + 
"accept_3_1": TestConfig( + num_prompt_tokens=553, + num_generated_tokens=20, + num_accepted_tokens=3, + step_actions=[ + StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(553, 4, [], (-1, -1), (-1, -1)), + StepAction(556, 4, [], (-1, -1), (-1, -1)), + StepAction(559, 4, [1, 1, 1, 1, 1], (2, 1), (1, 0)), + StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_3_2": TestConfig( + num_prompt_tokens=554, + num_generated_tokens=20, + num_accepted_tokens=3, + step_actions=[ + StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(554, 4, [], (-1, -1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (2, 1), (3, 0)), + StepAction(560, 4, [], (-1, -1), (-1, -1)), + StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_3_3": TestConfig( + num_prompt_tokens=555, + num_generated_tokens=20, + num_accepted_tokens=3, + step_actions=[ + StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(555, 4, [], (-1, -1), (-1, -1)), + StepAction(558, 4, [1, 1, 1, 1, 1], (2, 1), (2, 0)), + StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_4_1": TestConfig( + num_prompt_tokens=553, + num_generated_tokens=20, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(553, 4, [], (-1, -1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (3, 1), (3, 0)), + StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(565, 4, [], (-1, -1), (-1, -1)), + ], + ), + "accept_4_2": TestConfig( + num_prompt_tokens=554, + num_generated_tokens=25, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(554, 4, [], (-1, -1), (-1, -1)), + StepAction(558, 4, [1, 1, 1, 1, 1], (3, 1), (2, 0)), + StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(566, 4, [], (-1, -1), (-1, -1)), + ], + ), + "accept_4_3": TestConfig( + num_prompt_tokens=555, + 
num_generated_tokens=25, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(555, 4, [], (-1, -1), (-1, -1)), + StepAction(559, 4, [1, 1, 1, 1, 1], (3, 1), (1, 0)), + StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_4_4": TestConfig( + num_prompt_tokens=556, + num_generated_tokens=25, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 556, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(556, 4, [], (-1, -1), (3, 0)), + StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "prompt_block_size": TestConfig( + num_prompt_tokens=560, + num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + ], + ), + "prompt_2_block_size": TestConfig( + num_prompt_tokens=560 * 2, + num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(560, 560, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + StepAction(560 * 2, 4, [0, 1, 1, 1, 1, 1], (1, 2), (-1, -1)), + ], + ), + "prompt_2_block_size_10": TestConfig( + num_prompt_tokens=560 * 2 + 10, + num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(560, 570, [1, 0, 1, 1, 1, 1], (0, 2), (-1, -1)), + StepAction(560 * 2 + 10, 4, [0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "prompt_3_block_size": TestConfig( + num_prompt_tokens=560 * 3, + num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(560 * 2, 560, [0, 1, 1, 1, 1, 1], (1, 2), (-1, -1)), + StepAction(560 * 3, 4, [0, 0, 1, 1, 1, 1, 1], (2, 3), (-1, -1)), + ], + ), + "prompt_3_block_size_10": TestConfig( + num_prompt_tokens=560 * 3 + 10, + 
num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(560 * 2, 570, [0, 1, 0, 1, 1, 1, 1], (1, 3), (-1, -1)), + StepAction(560 * 3 + 10, 4, [0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "prompt_10_block_size": TestConfig( + num_prompt_tokens=560 * 10, + num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction( + 560 * 5, + 560 * 4, + [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1], + (4, 8), + (-1, -1), + ), + StepAction( + 560 * 9, + 560, + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1], + (8, 9), + (-1, -1), + ), + StepAction( + 560 * 10, + 4, + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1], + (9, 10), + (-1, -1), + ), + ], + ), + "prompt_10_block_size_10": TestConfig( + num_prompt_tokens=560 * 10 + 10, + num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction( + 560 * 5, + 560 * 4, + [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1], + (4, 8), + (-1, -1), + ), + StepAction( + 560 * 9, + 560 + 10, + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1], + (8, 10), + (-1, -1), + ), + ], + ), + } + + engine = LLM( + model=MODEL, + enable_prefix_caching=True, + block_size=BLOCK_SIZE, + mamba_cache_mode="align", + speculative_config={ + "method": "qwen3_next_mtp", + "num_speculative_tokens": num_speculative_tokens, + }, + max_num_batched_tokens=3072, + hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS}, + seed=42, + ) + global prompt_token_ids + prompt_token_ids = engine.get_tokenizer().encode(full_prompt) + print(f"Token IDs length: {len(prompt_token_ids)}") + for test_case_name, test_config in tests.items(): + print(f"Running test case: {test_case_name}") + num_generated_tokens = test_config.num_generated_tokens + num_prompt_tokens = test_config.num_prompt_tokens + global num_accepted_tokens + num_accepted_tokens = 
test_config.num_accepted_tokens + sampling_params = SamplingParams( + temperature=0.0, max_tokens=num_generated_tokens + ) + global cur_step_action_idx + cur_step_action_idx = 0 + for step_action_prev, step_action_next in zip( + test_config.step_actions[:-1], test_config.step_actions[1:] + ): + if ( + step_action_next.kv_cache_block_ids is not None + and len(step_action_next.kv_cache_block_ids) == 0 + ): + prev_block_ids = step_action_prev.kv_cache_block_ids + if prev_block_ids is not None: + step_action_next.kv_cache_block_ids = prev_block_ids.copy() + global step_actions + step_actions = test_config.step_actions + _ = engine.generate( + [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], + sampling_params, + ) + assert engine.llm_engine.engine_core.engine_core.scheduler.reset_prefix_cache() + print(f"End test case: {test_case_name}") + keys_to_check = [ + (action.postprocess_copy_idx[1] + 1) * BLOCK_SIZE + for action in test_config.step_actions + if action.postprocess_copy_idx and action.postprocess_copy_idx[0] != -1 + ] + mamba_state_ref = torch.load("mamba_kv_cache_dict_ref.pth") + check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict, keys_to_check) + mamba_kv_cache_dict.clear() diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 839ea4780c87..abf10e21d408 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -31,6 +31,7 @@ "fp8_ds_mla", ] MambaDType = Literal["auto", "float32", "float16"] +MambaCacheMode = Literal["all", "align", "none"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"] KVOffloadingBackend = Literal["native", "lmcache"] @@ -123,6 +124,15 @@ class CacheConfig: """The data type to use for the Mamba cache (ssm state only, conv state will still be controlled by mamba_cache_dtype). If set to 'auto', the data type for the ssm state will be determined by mamba_cache_dtype.""" + mamba_cache_mode: MambaCacheMode = "none" + """The cache strategy for Mamba layers. 
+ - "none": set when prefix caching is disabled. + - "all": cache the mamba state of all tokens at position i * block_size. This is + the default behavior (for models that support it) when prefix caching is + enabled. + - "align": only cache the mamba state of the last token of each scheduler step and + when the token is at position i * block_size. + """ # Will be set after profiling. num_gpu_blocks: int | None = field(default=None, init=False) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 3d42205ca2fb..dacc30339995 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -996,6 +996,17 @@ def has_blocked_weights(): # Default to enable HMA if not explicitly disabled by user or logic above. self.scheduler_config.disable_hybrid_kv_cache_manager = False + if self.cache_config.mamba_cache_mode == "align": + if self.scheduler_config.long_prefill_token_threshold > 0: + assert ( + self.scheduler_config.long_prefill_token_threshold + >= self.cache_config.block_size + ) + assert not self.scheduler_config.disable_chunked_mm_input, ( + "Chunked MM input is required because we need the flexibility to " + "schedule a multiple of block_size tokens even if they are in the " + "middle of a mm input" + ) if self.compilation_config.debug_dump_path: self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path.absolute().expanduser() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0bbb5402962a..16e3dcc6708f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -60,6 +60,7 @@ BlockSize, CacheDType, KVOffloadingBackend, + MambaCacheMode, MambaDType, PrefixCachingHashAlgo, ) @@ -556,6 +557,7 @@ class EngineArgs: mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size") + mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode additional_config: 
dict[str, Any] = get_field(VllmConfig, "additional_config") @@ -939,6 +941,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: cache_group.add_argument( "--mamba-block-size", **cache_kwargs["mamba_block_size"] ) + cache_group.add_argument( + "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"] + ) cache_group.add_argument( "--kv-offloading-size", **cache_kwargs["kv_offloading_size"] ) @@ -1416,6 +1421,7 @@ def create_engine_config( mamba_cache_dtype=self.mamba_cache_dtype, mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, mamba_block_size=self.mamba_block_size, + mamba_cache_mode=self.mamba_cache_mode, kv_offloading_size=self.kv_offloading_size, kv_offloading_backend=self.kv_offloading_backend, ) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 4f45dd6caf35..f92ecb6b5b4e 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -56,6 +56,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: block_size=mamba_block_size, page_size_padded=page_size_padded, mamba_type=self.mamba_type, + mamba_cache_mode=vllm_config.cache_config.mamba_cache_mode, num_speculative_blocks=( vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index c22a309ce166..134e1dfd6283 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -255,7 +255,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size - prefix_caching_enabled = self.cache_config.enable_prefix_caching + is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all" if attn_metadata is not None: assert isinstance(attn_metadata, dict) @@ 
-304,7 +304,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d - if prefix_caching_enabled: + if is_mamba_cache_all: block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( torch.split( attn_metadata.block_idx_last_computed_token, @@ -380,7 +380,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): ssm_outputs.append(scan_out_p) if has_decode: - if prefix_caching_enabled: + if is_mamba_cache_all: state_indices_tensor_d_input = state_indices_tensor_d.gather( 1, block_idx_last_computed_token_d.unsqueeze(1) ).squeeze(1) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 74e4a34b4ae0..7af5e02c29d2 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -570,7 +570,7 @@ def conv_ssm_forward( assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size - prefix_caching_enabled = self.cache_config.enable_prefix_caching + is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all" if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] @@ -622,7 +622,7 @@ def conv_ssm_forward( dim=0, ) - if prefix_caching_enabled: + if is_mamba_cache_all: # If prefix caching is enabled, retrieve the relevant variables # for prefill and decode block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( @@ -701,7 +701,7 @@ def conv_ssm_forward( initial_states = None if has_initial_states_p is not None and prep_initial_states: kernel_ssm_indices = state_indices_tensor_p - if prefix_caching_enabled: + if is_mamba_cache_all: kernel_ssm_indices = state_indices_tensor_p.gather( 1, block_idx_last_computed_token_p.unsqueeze(1) ).squeeze(1) @@ -729,14 +729,14 @@ def 
conv_ssm_forward( cu_chunk_seqlens=cu_chunk_seqlen_p, last_chunk_indices=last_chunk_indices_p, initial_states=initial_states, - return_intermediate_states=prefix_caching_enabled, + return_intermediate_states=is_mamba_cache_all, dt_softplus=True, dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim), state_dtype=ssm_state.dtype, ) - if prefix_caching_enabled: + if is_mamba_cache_all: # The chunk_stride is the number of chunks per mamba block # e.g., if mamba_block_size = 512 and chunk_size = 256, # then chunk_stride = 2 @@ -815,7 +815,7 @@ def conv_ssm_forward( # Process decode requests if has_decode: - if prefix_caching_enabled: + if is_mamba_cache_all: state_indices_tensor_d_input = state_indices_tensor_d.gather( 1, block_idx_last_computed_token_d.unsqueeze(1) ).squeeze(1) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 831dab2fbb01..816f76bfa069 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -1,6 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable +from dataclasses import dataclass +from typing import TypeAlias + import torch from vllm.config.cache import MambaDType @@ -223,3 +227,94 @@ def kda_state_shape( conv_state_k_shape, recurrent_state_shape, ) + + +@dataclass +class MambaCopySpec: + """ + Data class specifying the memory-copy parameters for Mamba states used for + prefix caching in align mode. + + Attributes: + start_addr (int): Starting address for the memory copy operation. + num_elements (int): Number of elements to copy from the starting address. 
+ """ + + start_addr: int + num_elements: int + + +MambaStateCopyFunc: TypeAlias = Callable[ + [torch.Tensor, list[int], int, int], MambaCopySpec +] +""" +Type alias for a function that computes a MambaCopySpec for copying state slices. +Parameters: + state: torch.Tensor - the Mamba state tensor (e.g., conv or temporal states). + block_ids: list[int] - the list of block indices for the state to copy. + cur_block_idx: int - current block index within `block_ids` to copy from. + num_accepted_tokens: int - number of accepted tokens used to compute the copy offset. + Range: 1 .. 1 + num_speculative_tokens (inclusive). +""" + + +def get_conv_copy_spec( + state: torch.Tensor, + block_ids: list[int], + cur_block_idx: int, + num_accepted_tokens: int, +) -> MambaCopySpec: + """Return a MambaCopySpec for copying a convolutional state slice.""" + src_block_id = block_ids[cur_block_idx] + src_state = state[src_block_id, num_accepted_tokens - 1 :] + return MambaCopySpec( + start_addr=src_state.data_ptr(), num_elements=src_state.numel() + ) + + +def get_temporal_copy_spec( + state: torch.Tensor, + block_ids: list[int], + cur_block_idx: int, + num_accepted_tokens: int, +) -> MambaCopySpec: + """Return a MambaCopySpec for copying a temporal state slice.""" + src_block_id = block_ids[cur_block_idx + num_accepted_tokens - 1] + src_state = state[src_block_id] + return MambaCopySpec( + start_addr=src_state.data_ptr(), num_elements=src_state.numel() + ) + + +get_full_copy_spec = get_temporal_copy_spec + + +class MambaStateCopyFuncCalculator: + @classmethod + def linear_attention_state_copy_func(cls): + return (get_temporal_copy_spec,) + + @classmethod + def mamba1_state_copy_func(cls): + return (get_conv_copy_spec, get_temporal_copy_spec) + + @classmethod + def mamba2_state_copy_func(cls): + return get_conv_copy_spec, get_temporal_copy_spec + + @classmethod + def short_conv_state_copy_func(cls): + return (get_conv_copy_spec,) + + @classmethod + def gated_delta_net_state_copy_func(cls): 
+ return (get_conv_copy_spec, get_temporal_copy_spec) + + @classmethod + def kda_state_copy_func(cls): + return ( + get_conv_copy_spec, + get_conv_copy_spec, + get_conv_copy_spec, + get_temporal_copy_spec, + ) diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 22631bbc5489..a7de8e7cf349 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -24,6 +24,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -455,6 +457,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index e51a110ce0b3..cd462678b051 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -330,26 +330,54 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config if cache_config.enable_prefix_caching: - if model_config.supports_mamba_prefix_caching: - logger.info( - "Warning: Prefix caching is currently enabled. " - "Its support for Mamba layers is experimental. " - "Please report any issues you may observe." + if cache_config.mamba_cache_mode == "none": + cache_config.mamba_cache_mode = ( + "all" if model_config.supports_mamba_prefix_caching else "align" ) - # By default, mamba block size will be set to max_model_len (see - # below). 
When enabling prefix caching, we align mamba block size - # to the block size as the basic granularity for prefix caching. - if cache_config.mamba_block_size is None: - cache_config.mamba_block_size = cache_config.block_size - else: - logger.info( - "Hybrid or mamba-based model detected without " - "support for prefix caching: disabling." + logger.warning( + "Mamba cache mode is set to '%s' for %s by default " + "when prefix caching is enabled", + cache_config.mamba_cache_mode, + model_config.architecture, ) - cache_config.enable_prefix_caching = False - - if cache_config.mamba_block_size is None: - cache_config.mamba_block_size = model_config.max_model_len + if ( + cache_config.mamba_cache_mode == "all" + and not model_config.supports_mamba_prefix_caching + ): + cache_config.mamba_cache_mode = "align" + logger.warning( + "Hybrid or mamba-based model detected without support " + "for prefix caching with Mamba cache 'all' mode: " + "falling back to 'align' mode." + ) + if cache_config.mamba_cache_mode == "align": + assert vllm_config.scheduler_config.enable_chunked_prefill, ( + "Chunked prefill is required for mamba cache mode 'align'." + ) + assert not vllm_config.speculative_config, ( + "Mamba cache mode 'align' is currently not compatible " + "with speculative decoding." + ) + logger.info( + "Warning: Prefix caching in Mamba cache '%s' " + "mode is currently enabled. " + "Its support for Mamba layers is experimental. " + "Please report any issues you may observe.", + cache_config.mamba_cache_mode, + ) + # By default, mamba block size will be set to max_model_len (see + # below). When enabling prefix caching, we align mamba block size + # to the block size as the basic granularity for prefix caching. 
+ if cache_config.mamba_block_size is None: + cache_config.mamba_block_size = cache_config.block_size + else: + if cache_config.mamba_cache_mode != "none": + cache_config.mamba_cache_mode = "none" + logger.warning( + "Mamba cache mode is set to 'none' when prefix caching is disabled" + ) + if cache_config.mamba_block_size is None: + cache_config.mamba_block_size = model_config.max_model_len class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): @@ -426,7 +454,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: mamba_page_size = MambaSpec( shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), - block_size=model_config.max_model_len, + block_size=-1, # block_size doesn't matter for mamba page size ).page_size_bytes # Model may be marked as is_hybrid @@ -435,7 +463,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: if mamba_page_size == 0: return - if cache_config.enable_prefix_caching: + if cache_config.mamba_cache_mode == "all": # With prefix caching, select attention block size to # optimize for mamba kernel performance @@ -479,6 +507,13 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: attn_block_size, ) + # By default, mamba block size will be set to max_model_len. + # When enabling prefix caching and using align mamba cache + # mode, we align mamba block size to the block size as the + # basic granularity for prefix caching. 
+ if cache_config.mamba_cache_mode == "align": + cache_config.mamba_block_size = cache_config.block_size + # compute new attention page size attn_page_size = cache_config.block_size * attn_page_size_1_token diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index bfb6b1a1f160..49722b6d721f 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -24,6 +24,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -551,6 +553,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 3434716b8378..0b601b4b8941 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -19,6 +19,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -641,6 +643,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return 
MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 809395cf3a24..4e533d3f0fc0 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -24,6 +24,7 @@ from vllm.inputs import TokensPrompt from vllm.inputs.data import PromptType from vllm.logger import init_logger +from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.utils.func_utils import supports_kw @@ -735,6 +736,19 @@ def get_mamba_state_shape_from_config( """ ... + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]: + """Calculate copy-function callables for each Mamba state. + + Returns: + A tuple of MambaStateCopyFunc callables that correspond, in order, + to the Mamba states produced by the model. Each callable accepts + (state, block_ids, cur_block_idx, num_accepted_tokens) and returns + a MambaCopySpec describing the memory-copy parameters for prefix + caching in align mode. + """ + ... + @overload def is_hybrid(model: object) -> TypeIs[IsHybrid]: ... 
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 91b58a83e09a..eeca3cf78bab 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -24,6 +24,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -558,6 +560,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba1_state_copy_func() + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/kimi_linear.py b/vllm/model_executor/models/kimi_linear.py index d149c3642406..f3ec5b759047 100644 --- a/vllm/model_executor/models/kimi_linear.py +++ b/vllm/model_executor/models/kimi_linear.py @@ -26,6 +26,8 @@ ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -544,6 +546,14 @@ def get_mamba_state_shape_from_config( num_spec=num_spec, ) + @classmethod + def get_mamba_state_copy_func( + cls, + ) -> tuple[ + MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc + ]: + return MambaStateCopyFuncCalculator.kda_state_copy_func() + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index a4b9c747f872..629a72f39980 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -20,6 +20,8 @@ ) from vllm.model_executor.layers.logits_processor import 
LogitsProcessor from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -459,14 +461,19 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.conv_L_cache, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.short_conv_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config cache_config = vllm_config.cache_config - - assert not cache_config.enable_prefix_caching, ( - "Lfm2 currently does not support prefix caching" - ) + if cache_config.mamba_cache_mode == "all": + raise NotImplementedError( + "Lfm2 currently does not support 'all' prefix caching, " + "please use '--mamba-cache-mode=align' instead" + ) super().__init__() self.config = config diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py index 6b97e171c727..4704967b56f6 100644 --- a/vllm/model_executor/models/lfm2_moe.py +++ b/vllm/model_executor/models/lfm2_moe.py @@ -25,6 +25,8 @@ ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -640,6 +642,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.conv_L_cache, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.short_conv_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index aa16640a9427..85212feca529 100644 --- 
a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -16,6 +16,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -261,6 +263,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.conv_kernel, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba1_state_copy_func() + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 5fcfa9431230..ed363df21230 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -15,6 +15,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -228,6 +230,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.conv_kernel, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 955a73ff19ed..44417c98b1b2 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -35,6 
+35,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01LinearAttention from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -1006,3 +1008,7 @@ def get_mamba_state_shape_from_config( tp_size=parallel_config.tensor_parallel_size, head_dim=hf_config.head_dim, ) + + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.linear_attention_state_copy_func() diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 82f3f1362118..2f9497bfba19 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -2128,3 +2128,7 @@ def get_mamba_state_dtype_from_config(cls, vllm_config: "VllmConfig"): temp_vllm_config = copy.deepcopy(vllm_config) temp_vllm_config.model_config.hf_config = text_config return NemotronHForCausalLM.get_mamba_state_dtype_from_config(temp_vllm_config) + + @classmethod + def get_mamba_state_copy_func(cls): + return NemotronHForCausalLM.get_mamba_state_copy_func() diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 8655cf66d209..3984b524feae 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -45,6 +45,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -809,6 +811,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.conv_kernel, ) + @classmethod + def get_mamba_state_copy_func(cls) -> 
tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 45512d23d269..24df17963e09 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -27,6 +27,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -899,6 +901,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index e244e64740dc..bc70d2b54b12 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -48,6 +48,8 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -1205,9 +1207,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, ( - "Qwen3Next currently does not support prefix caching" - ) + if 
cache_config.mamba_cache_mode == "all": + raise NotImplementedError( + "Qwen3Next currently does not support 'all' prefix caching, " + "please use '--mamba-cache-mode=align' instead" + ) self.quant_config = vllm_config.quant_config super().__init__() @@ -1278,6 +1282,10 @@ def get_mamba_state_shape_from_config( num_spec, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func() + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index 565fd7d8f9b8..854d7f9a722a 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -234,9 +234,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config cache_config = vllm_config.cache_config - assert not cache_config.enable_prefix_caching, ( - "Qwen3NextMTP currently does not support prefix caching" - ) + if cache_config.mamba_cache_mode == "all": + raise NotImplementedError( + "Qwen3NextMTP currently does not support 'all' prefix caching, " + "please use '--mamba-cache-mode=align' instead" + ) self.quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index b5132cd86024..59a8520f7ff2 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -32,6 +32,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -891,6 +893,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, 
) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: """Initialize the Zamba2 model for causal language modeling. diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 426c17689ee0..280105c2915c 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -16,6 +16,7 @@ from vllm.v1.attention.backends.utils import ( PAD_SLOT_ID, compute_causal_conv1d_metadata, + mamba_get_block_table_tensor, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -153,6 +154,12 @@ def build( # type: ignore[override] query_start_loc = m.query_start_loc context_lens_tensor = m.compute_num_computed_tokens() nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + block_table_tensor = mamba_get_block_table_tensor( + m.block_table_tensor, + m.seq_lens, + self.kv_cache_spec, + self.vllm_config.cache_config.mamba_cache_mode, + ) if ( not self.use_spec_decode @@ -182,7 +189,7 @@ def build( # type: ignore[override] spec_token_indx = None non_spec_token_indx = None spec_state_indices_tensor = None - non_spec_state_indices_tensor = m.block_table_tensor[:, 0] + non_spec_state_indices_tensor = block_table_tensor[:, 0] spec_query_start_loc = None non_spec_query_start_loc = query_start_loc num_accepted_tokens = None @@ -211,7 +218,7 @@ def build( # type: ignore[override] non_spec_token_indx = torch.empty( 0, dtype=torch.int32, device=query_start_loc.device ) - spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1] + spec_state_indices_tensor = block_table_tensor[:, : self.num_spec + 1] non_spec_state_indices_tensor = None spec_query_start_loc = query_start_loc non_spec_query_start_loc = None @@ -224,10 +231,10 @@ def build( # type: ignore[override] 
non_spec_token_indx = index[:num_non_spec_tokens] spec_token_indx = index[num_non_spec_tokens:] - spec_state_indices_tensor = m.block_table_tensor[ + spec_state_indices_tensor = block_table_tensor[ spec_sequence_masks, : self.num_spec + 1 ] - non_spec_state_indices_tensor = m.block_table_tensor[ + non_spec_state_indices_tensor = block_table_tensor[ ~spec_sequence_masks, 0 ] diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index 4ef5656916dc..02551e704766 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -11,7 +11,10 @@ AttentionMetadataBuilder, CommonAttentionMetadata, ) -from vllm.v1.attention.backends.utils import split_decodes_and_prefills +from vllm.v1.attention.backends.utils import ( + mamba_get_block_table_tensor, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -57,7 +60,12 @@ def build( query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + state_indices_tensor = mamba_get_block_table_tensor( + common_attn_metadata.block_table_tensor, + common_attn_metadata.seq_lens, + self.kv_cache_spec, + self.vllm_config.cache_config.mamba_cache_mode, + )[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 0c55877a5675..e76981faa8ec 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -18,6 +18,7 @@ from vllm.v1.attention.backends.utils import ( PAD_SLOT_ID, compute_causal_conv1d_metadata, + mamba_get_block_table_tensor, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -41,11 +42,15 @@ class BaseMambaAttentionMetadata: state_indices_tensor: torch.Tensor - 
# The following tensors are only used for prefix caching and are None if disabled + # The following tensors are only used for prefix caching in all mode and + # are None if disabled block_idx_last_scheduled_token: torch.Tensor | None block_idx_first_scheduled_token_p: torch.Tensor | None block_idx_last_computed_token: torch.Tensor | None + # The following tensor is only used for prefix caching in align mode + seq_lens: torch.Tensor + # The following attributes are for triton implementation of causal_conv1d nums_dict: dict | None = None batch_ptr: torch.Tensor | None = None @@ -78,7 +83,7 @@ def __init__( self.compilation_config.max_cudagraph_capture_size, ) - if self.vllm_config.cache_config.enable_prefix_caching: + if self.vllm_config.cache_config.mamba_cache_mode == "all": self.state_indices_tensor = torch.empty( ( self.decode_cudagraph_max_bs, @@ -198,7 +203,7 @@ def _compute_common_metadata( # for causal_conv1d nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None - if self.vllm_config.cache_config.enable_prefix_caching: + if self.vllm_config.cache_config.mamba_cache_mode == "all": num_computed_tokens = common_attn_metadata.compute_num_computed_tokens() # Return a tensor of shape (#requests, #max blocks) @@ -214,7 +219,12 @@ def _compute_common_metadata( ) else: # Always return just a single block per each request: - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + state_indices_tensor = mamba_get_block_table_tensor( + common_attn_metadata.block_table_tensor, + common_attn_metadata.seq_lens, + self.kv_cache_spec, + self.vllm_config.cache_config.mamba_cache_mode, + )[:, 0] if num_prefills > 0: if num_computed_tokens is None: @@ -236,7 +246,7 @@ def _compute_common_metadata( compute_causal_conv1d_metadata(query_start_loc_p) ) - if self.vllm_config.cache_config.enable_prefix_caching: + if self.vllm_config.cache_config.mamba_cache_mode == "all": assert num_computed_tokens is not None num_computed_tokens_p = num_computed_tokens[ 
num_reqs - num_prefills : num_reqs @@ -255,7 +265,7 @@ def _compute_common_metadata( state_indices_tensor = self.state_indices_tensor[:num_decode_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID - if self.vllm_config.cache_config.enable_prefix_caching: + if self.vllm_config.cache_config.mamba_cache_mode == "all": self.block_idx_last_scheduled_token[:num_decodes].copy_( block_idx_last_scheduled_token, non_blocking=True ) @@ -283,6 +293,7 @@ def _compute_common_metadata( block_idx_last_computed_token=block_idx_last_computed_token, num_computed_tokens_p=num_computed_tokens_p, num_reqs=num_reqs, + seq_lens=common_attn_metadata.seq_lens, nums_dict=nums_dict, batch_ptr=batch_ptr, token_chunk_offset_ptr=token_chunk_offset_ptr, @@ -295,8 +306,16 @@ def update_block_table( slot_mapping: torch.Tensor, ) -> M: new_metadata = copy.copy(metadata) - prefix_caching = self.vllm_config.cache_config.enable_prefix_caching - state_indices_t = blk_table if prefix_caching else blk_table[:, 0] + state_indices_t = mamba_get_block_table_tensor( + blk_table, + metadata.seq_lens, + self.kv_cache_spec, + self.vllm_config.cache_config.mamba_cache_mode, + ) + if self.vllm_config.cache_config.mamba_cache_mode in ("none", "align"): + # Only needs the block that saves the running state + state_indices_t = state_indices_t[:, 0] + num_reqs = blk_table.shape[0] # For CUDA graphs, copy to persistent buffer diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 82321c0008b4..f22d54fefd98 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -17,6 +17,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils.math_utils import cdiv +from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -850,3 +851,40 @@ def extend_all_queries_by_1( slot_mapping=new_slot_mapping, ) return new_cad + + +def 
mamba_get_block_table_tensor( + block_table: torch.Tensor, + seq_lens: torch.Tensor, + kv_cache_spec: KVCacheSpec, + mamba_cache_mode: str, +) -> torch.Tensor: + """ + Get the block table tensor for mamba kernels from the input + common_attn_metadata.block_table_tensor given different mamba cache modes. + + - "all": input (#requests, cdiv(max_model_len, block_size)); + output (#requests, cdiv(max_model_len, block_size)). + + - "none": input (#requests, 1 + num_speculative_blocks); + output (#requests, 1 + num_speculative_blocks). + + - "align": input (#requests, cdiv(max_model_len, block_size)); + output (#requests, 1 + num_speculative_blocks), which are the last + 1 + num_speculative_blocks of each request. + """ + if mamba_cache_mode in ("all", "none"): + return block_table + else: + assert isinstance(kv_cache_spec, MambaSpec) + # NOTE: For 0-length requests in CUDA graph, use a start_index of 0 + # to handle the invalid block table. + start_indices = torch.clamp( + (seq_lens - 1) // kv_cache_spec.block_size, + min=0, + ) + offsets = torch.arange( + 1 + kv_cache_spec.num_speculative_blocks, device=block_table.device + ) + indices_to_gather = start_indices.unsqueeze(1) + offsets + return torch.gather(block_table, 1, indices_to_gather) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index cf93218a1873..ce7e396d8a9a 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -255,7 +255,8 @@ def cache_full_blocks( ) for i, blk in enumerate(new_full_blocks): # Some blocks may be null blocks when enabling sparse attention like - # sliding window attention. We skip null blocks here. + # sliding window attention, or Mamba models with prefix-caching in + # align mode. We skip null blocks here. 
if blk.is_null: continue assert blk.block_hash is None diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 4550e2b79562..c72fbb7be193 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -75,6 +75,7 @@ def get_num_blocks_to_allocate( new_computed_blocks: tuple[Sequence[KVCacheBlock], ...], num_encoder_tokens: int, total_computed_tokens: int, + num_tokens_main_model: int, ) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -88,6 +89,9 @@ def get_num_blocks_to_allocate( num_encoder_tokens: The number of encoder tokens for allocating blocks for cross-attention. total_computed_tokens: Include both local and external tokens. + num_tokens_main_model: The number of tokens for the main model (aka target + model in spec decode). w/o spec decode, it is num_tokens; + with spec decode, it is num_tokens - num_lookahead_tokens. Returns: The number of blocks to allocate. @@ -98,7 +102,7 @@ def get_num_blocks_to_allocate( # For cross-attention, we issue a single static allocation # of blocks based on the number of encoder input tokens. num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_encoder_tokens, [], 0 + request_id, num_encoder_tokens, [], 0, num_encoder_tokens ) else: num_blocks_to_allocate += manager.get_num_blocks_to_allocate( @@ -106,6 +110,7 @@ def get_num_blocks_to_allocate( num_tokens, new_computed_blocks[i], total_computed_tokens, + num_tokens_main_model, ) return num_blocks_to_allocate @@ -139,6 +144,7 @@ def allocate_new_blocks( self, request_id: str, num_tokens: int, + num_tokens_main_model: int, num_encoder_tokens: int = 0, ) -> tuple[list[KVCacheBlock], ...]: """ @@ -149,6 +155,9 @@ def allocate_new_blocks( request_id: The request ID. num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). 
+ num_tokens_main_model: The number of tokens for the main model (aka target + model in spec decode). w/o spec decode, it is num_tokens; + with spec decode, it is num_tokens - num_lookahead_tokens. num_encoder_tokens: The number of encoder tokens for allocating blocks for cross-attention. @@ -161,6 +170,7 @@ def allocate_new_blocks( num_encoder_tokens if isinstance(manager, CrossAttentionManager) else num_tokens, + num_tokens_main_model, ) for manager in self.single_type_managers ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 2197107c1fc6..2caed0493752 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -307,8 +307,9 @@ def allocate_slots( num_local_computed_tokens + num_external_computed_tokens, self.max_model_len, ) + num_tokens_main_model = total_computed_tokens + num_new_tokens num_tokens_need_slot = min( - total_computed_tokens + num_new_tokens + num_lookahead_tokens, + num_tokens_main_model + num_lookahead_tokens, self.max_model_len, ) @@ -329,6 +330,7 @@ def allocate_slots( num_encoder_tokens=num_encoder_tokens, total_computed_tokens=num_local_computed_tokens + num_external_computed_tokens, + num_tokens_main_model=num_tokens_main_model, ) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): @@ -349,7 +351,10 @@ def allocate_slots( ) new_blocks = self.coordinator.allocate_new_blocks( - request.request_id, num_tokens_need_slot, num_encoder_tokens + request.request_id, + num_tokens_need_slot, + num_tokens_main_model, + num_encoder_tokens, ) # P/D: delay caching blocks if we have to recv from diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 0cb65bd0f779..ec9a5d799021 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -47,7 +47,7 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import 
EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec from vllm.v1.metrics.perf import ModelMetrics, PerfStats from vllm.v1.metrics.stats import ( PrefixCacheStats, @@ -226,6 +226,17 @@ def __init__( ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER + + def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool: + return any( + isinstance(group_spec.kv_cache_spec, MambaSpec) + for group_spec in kv_cache_config.kv_cache_groups + ) + + self.has_mamba_layers = has_mamba_layers(kv_cache_config) + self.need_mamba_block_aligned_split = ( + self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align" + ) self.perf_metrics: ModelMetrics | None = None if self.log_stats and vllm_config.observability_config.enable_mfu_metrics: self.perf_metrics = ModelMetrics(vllm_config) @@ -250,6 +261,53 @@ def __init__( vllm_config=self.vllm_config, ) + def _mamba_block_aligned_split( + self, + request: Request, + num_new_tokens: int, + num_new_local_computed_tokens: int = 0, + num_external_computed_tokens: int = 0, + ) -> int: + assert num_external_computed_tokens == 0, ( + "External KV connector is not verified yet" + ) + # TODO: need check for resume requests + if request.num_output_tokens == 0: # prefill + # To enable block-aligned caching of the Mamba state, `num_new_tokens` + # must be a multiple of `block_size`. + # As an exception, if `num_new_tokens` is less than `block_size`, the + # state is simply not cached, requiring no special handling. + # Additionally, when Eagle mode is enabled, FullAttn prunes the last + # matching block. To prevent this from causing a Mamba cache miss, the + # last chunk must be larger than `block_size`. 
+ block_size = self.cache_config.block_size + last_cache_position = ( + request.num_prompt_tokens - request.num_prompt_tokens % block_size + ) + # eagle prune + if self.use_eagle: + last_cache_position = max(last_cache_position - block_size, 0) + num_computed_tokens = ( + request.num_computed_tokens + + num_new_local_computed_tokens + + num_external_computed_tokens + ) + num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens + if num_computed_tokens_after_sched < last_cache_position: + # align to block_size + num_new_tokens = num_new_tokens // block_size * block_size + elif ( + num_computed_tokens + < last_cache_position + < num_computed_tokens_after_sched + ): + # force to cache the last chunk + num_new_tokens = last_cache_position - num_computed_tokens + else: + # prefill the last few tokens + pass + return num_new_tokens + def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. @@ -340,6 +398,11 @@ def schedule(self) -> SchedulerOutput: shift_computed_tokens=1 if self.use_eagle else 0, ) + if self.need_mamba_block_aligned_split: + num_new_tokens = self._mamba_block_aligned_split( + request, num_new_tokens + ) + if num_new_tokens == 0: # The request cannot be scheduled because one of the following # reasons: @@ -350,6 +413,8 @@ def schedule(self) -> SchedulerOutput: # its max_total_tokens or max_model_len. # 2. The encoder budget is exhausted. # 3. The encoder cache is exhausted. + # 4. Insufficient budget for a block-aligned chunk in hybrid + # models with mamba cache mode \"align\". # NOTE(woosuk): Here, by doing `continue` instead of `break`, # we do not strictly follow the FCFS scheduling policy and # allow the lower-priority requests to be scheduled. @@ -608,6 +673,16 @@ def schedule(self) -> SchedulerOutput: # The request cannot be scheduled. 
break + if self.need_mamba_block_aligned_split: + num_new_tokens = self._mamba_block_aligned_split( + request, + num_new_tokens, + num_new_local_computed_tokens, + num_external_computed_tokens, + ) + if num_new_tokens == 0: + break + # Handles an edge case when P/D Disaggregation # is used with Spec Decoding where an # extra block gets allocated which diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index aed5c0580b28..9918d6ffd2d9 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -66,12 +66,17 @@ def __init__( self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block + @classmethod + def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]): + return sum(blk.ref_cnt == 0 and not blk.is_null for blk in blocks) + def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, new_computed_blocks: Sequence[KVCacheBlock], total_computed_tokens: int, + num_tokens_main_model: int, ) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -84,6 +89,9 @@ def get_num_blocks_to_allocate( prefix caching. total_computed_tokens: Include both local and external computed tokens. + num_tokens_main_model: The number of tokens for the main model (aka target + model in spec decode). w/o spec decode, it is num_tokens; + with spec decode, it is num_tokens - num_lookahead_tokens. Returns: The number of blocks to allocate. @@ -121,9 +129,8 @@ def get_num_blocks_to_allocate( # If a computed block is an eviction candidate (in the free queue and # ref_cnt == 0), it will be removed from the free queue when touched by # the allocated request, so we must count it in the free-capacity check. 
- num_evictable_blocks = sum( - blk.ref_cnt == 0 and not blk.is_null - for blk in new_computed_blocks[num_skipped_new_computed_blocks:] + num_evictable_blocks = self._get_num_evictable_blocks( + new_computed_blocks[num_skipped_new_computed_blocks:] ) return num_new_blocks + num_evictable_blocks @@ -199,7 +206,7 @@ def allocate_new_computed_blocks( req_blocks.extend(allocated_blocks) def allocate_new_blocks( - self, request_id: str, num_tokens: int + self, request_id: str, num_tokens: int, num_tokens_main_model: int ) -> list[KVCacheBlock]: """ Allocate new blocks for the request to give it at least `num_tokens` @@ -209,7 +216,9 @@ def allocate_new_blocks( request_id: The request ID. num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). - + num_tokens_main_model: The number of tokens for the main model (aka target + model in spec decode). w/o spec decode, it is num_tokens; + with spec decode, it is num_tokens - num_lookahead_tokens. Returns: The new allocated blocks. 
""" @@ -450,12 +459,9 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class SlidingWindowManager(SingleTypeKVCacheManager): - def __init__( - self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, **kwargs - ) -> None: - super().__init__(kv_cache_spec, block_pool, **kwargs) + def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None: + super().__init__(kv_cache_spec, **kwargs) self.sliding_window = kv_cache_spec.sliding_window - self._null_block = block_pool.null_block @classmethod def find_longest_cache_hit( @@ -586,12 +592,9 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): - def __init__( - self, kv_cache_spec: ChunkedLocalAttentionSpec, block_pool: BlockPool, **kwargs - ) -> None: - super().__init__(kv_cache_spec, block_pool, **kwargs) + def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None: + super().__init__(kv_cache_spec, **kwargs) self.attention_chunk_size = kv_cache_spec.attention_chunk_size - self._null_block = block_pool.null_block @classmethod def find_longest_cache_hit( @@ -739,6 +742,17 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class MambaManager(SingleTypeKVCacheManager): + def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None: + super().__init__(kv_cache_spec, **kwargs) + self.mamba_cache_mode = kv_cache_spec.mamba_cache_mode + self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks + if self.mamba_cache_mode == "align": + # Mapping from request ID to the index of the block + # allocated in the previous step + self.last_state_block_idx: dict[str, int] = {} + # The set of the requests that have been allocated blocks + self._allocated_block_reqs: set[str] = set() + @classmethod def find_longest_cache_hit( cls, @@ -787,6 +801,28 @@ def find_longest_cache_hit( return computed_blocks + def remove_skipped_blocks(self, request_id: str, 
num_computed_tokens: int) -> None: + assert isinstance(self.kv_cache_spec, MambaSpec) + super().remove_skipped_blocks(request_id, num_computed_tokens) + if self.mamba_cache_mode == "align": + # `last_state_block_idx` refers to the block index allocated two steps ago. + # The block allocated in the previous step is used to copy Mamba states + # into the block allocated in the current step; the earlier block is + # no longer needed and should be freed here. + last_state_block_idx = self.last_state_block_idx.get(request_id) + # Blocks allocated during prefill may be non-contiguous. Use + # `last_state_block_idx` to free the appropriate block and replace it + # with a null block. + if ( + last_state_block_idx is not None + and last_state_block_idx + < cdiv(num_computed_tokens, self.block_size) - 1 + ): + blocks = self.req_to_blocks[request_id] + if blocks[last_state_block_idx] != self._null_block: + self.block_pool.free_blocks([blocks[last_state_block_idx]]) + blocks[last_state_block_idx] = self._null_block + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: """ cascade attention is not supported by mamba @@ -799,31 +835,134 @@ def get_num_blocks_to_allocate( num_tokens: int, new_computed_blocks: Sequence[KVCacheBlock], total_computed_tokens: int, + num_tokens_main_model: int, ) -> int: - # Allocate extra `num_speculative_blocks` blocks for - # speculative decoding (MTP/EAGLE) with linear attention. assert isinstance(self.kv_cache_spec, MambaSpec) - if self.kv_cache_spec.num_speculative_blocks > 0: - num_tokens += ( - self.kv_cache_spec.block_size - * self.kv_cache_spec.num_speculative_blocks + if self.mamba_cache_mode != "align": + # Allocate extra `num_speculative_blocks` blocks for + # speculative decoding (MTP/EAGLE) with linear attention. 
+ if self.num_speculative_blocks > 0: + num_tokens += ( + self.kv_cache_spec.block_size * self.num_speculative_blocks + ) + return super().get_num_blocks_to_allocate( + request_id, + num_tokens, + new_computed_blocks, + total_computed_tokens, + num_tokens_main_model, ) - return super().get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks, total_computed_tokens - ) + else: + # We don't allocate blocks for lookahead tokens in align mode, because if + # x * block_size tokens are scheduled, num_tokens is + # x * block_size + num_lookahead_tokens and breaks the alignment. + # We can ignore lookahead tokens because current draft models don't have + # mamba layers. + num_tokens = num_tokens_main_model + num_required_blocks = ( + cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + ) + num_new_blocks = ( + num_required_blocks + - len(new_computed_blocks) + - len(self.req_to_blocks[request_id]) + ) + if num_new_blocks > 0: + if request_id in self._allocated_block_reqs: + # Old request. Needs at most 1 more blocks as we can reuse the + # speculative blocks in previous step. + num_new_blocks = 1 + else: + # First prefill. Allocate 1 block for running state and the + # speculative blocks. + num_new_blocks = 1 + self.num_speculative_blocks + + num_evictable_computed_blocks = self._get_num_evictable_blocks( + new_computed_blocks + ) + return num_new_blocks + num_evictable_computed_blocks def allocate_new_blocks( - self, request_id: str, num_tokens: int + self, request_id: str, num_tokens: int, num_tokens_main_model: int ) -> list[KVCacheBlock]: - # Allocate extra `num_speculative_blocks` blocks for - # speculative decoding (MTP/EAGLE) with linear attention. 
assert isinstance(self.kv_cache_spec, MambaSpec) - if self.kv_cache_spec.num_speculative_blocks > 0: - num_tokens += ( - self.kv_cache_spec.block_size - * self.kv_cache_spec.num_speculative_blocks + if self.mamba_cache_mode != "align": + # Allocate extra `num_speculative_blocks` blocks for + # speculative decoding (MTP/EAGLE) with linear attention. + if self.num_speculative_blocks > 0: + num_tokens += self.block_size * self.num_speculative_blocks + return super().allocate_new_blocks( + request_id, num_tokens, num_tokens_main_model ) - return super().allocate_new_blocks(request_id, num_tokens) + else: + # We don't allocate blocks for lookahead tokens in align mode, because if + # x * block_size tokens are scheduled, num_tokens is + # x * block_size + num_lookahead_tokens and breaks the alignment. + # We can ignore lookahead tokens because current draft models don't have + # mamba layers. + num_tokens = num_tokens_main_model + req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] + num_required_blocks = ( + cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + ) + if num_required_blocks == len(req_blocks): + return [] + else: + assert num_required_blocks > len(req_blocks), ( + "num_required_blocks " + f"{num_required_blocks} < len(req_blocks) {len(req_blocks)}" + ) + prev_block_len = len(req_blocks) + blocks_allocated = request_id in self._allocated_block_reqs + # Record the last state block + if blocks_allocated: + # We always save the running state at the last + # (1 + num_speculative_blocks) block + self.last_state_block_idx[request_id] = ( + prev_block_len - 1 - self.num_speculative_blocks + ) + elif prev_block_len > 0: + # When a new request hits the prefix cache, the last block + # saves the hit state. 
+ self.last_state_block_idx[request_id] = prev_block_len - 1 + + num_skipped_blocks = ( + num_required_blocks - self.num_speculative_blocks - 1 + ) + # null blocks + if prev_block_len < num_skipped_blocks: + req_blocks.extend( + [ + self._null_block + for _ in range(prev_block_len, num_skipped_blocks) + ] + ) + + if blocks_allocated: + # reuse previous speculative blocks in this step + for block_idx in range( + prev_block_len - self.num_speculative_blocks, prev_block_len + ): + if block_idx < num_skipped_blocks: + req_blocks.append(req_blocks[block_idx]) + req_blocks[block_idx] = self._null_block + else: + break + num_new_blocks = num_required_blocks - len(req_blocks) + if blocks_allocated: + assert num_new_blocks <= 1 + else: + assert num_new_blocks <= self.num_speculative_blocks + 1 + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) + req_blocks.extend(new_blocks) + self._allocated_block_reqs.add(request_id) + return req_blocks[prev_block_len:] + + def free(self, request_id: str) -> None: + if self.mamba_cache_mode == "align": + self._allocated_block_reqs.discard(request_id) + self.last_state_block_idx.pop(request_id, None) + super().free(request_id) def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: """ diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 5c9913bb095b..27c6f7da25f7 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -276,6 +276,7 @@ class MambaSpec(KVCacheSpec): dtypes: tuple[torch.dtype] page_size_padded: int | None = None mamba_type: str = "mamba2" + mamba_cache_mode: str = "none" num_speculative_blocks: int = 0 @property @@ -290,8 +291,13 @@ def page_size_bytes(self) -> int: return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - max_model_len = vllm_config.model_config.max_model_len - return cdiv(max_model_len, self.block_size) * self.page_size_bytes + if vllm_config.cache_config.mamba_cache_mode == "all": + max_model_len = 
vllm_config.model_config.max_model_len + return cdiv(max_model_len, self.block_size) * self.page_size_bytes + elif vllm_config.cache_config.mamba_cache_mode == "align": + return self.page_size_bytes * (2 + self.num_speculative_blocks) + else: + return self.page_size_bytes * (1 + self.num_speculative_blocks) @dataclass(frozen=True) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 37ec0fb97e06..591f49761a0e 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.utils.math_utils import cdiv from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.cp_utils import get_total_cp_world_size logger = init_logger(__name__) @@ -261,47 +262,45 @@ def __init__( device: torch.device, block_sizes: list[int], kernel_block_sizes: list[int], - num_speculative_tokens: int = 0, + max_num_blocks: list[int] | None = None, cp_kv_cache_interleave_size: int = 1, ) -> None: - # Note(hc): each dcp rank only store - # (max_model_len//dcp_world_size) tokens in kvcache, - # so the block_size which used for calc max_num_blocks_per_req - # must be multiplied by dcp_world_size. - try: - pcp_world_size = get_pcp_group().world_size - except AssertionError: - # PCP might not be initialized in testing - pcp_world_size = 1 - try: - dcp_world_size = get_dcp_group().world_size - except AssertionError: - # DCP might not be initialized in testing - dcp_world_size = 1 - if len(kernel_block_sizes) != len(block_sizes): raise ValueError( f"kernel_block_sizes length ({len(kernel_block_sizes)}) " f"must match block_sizes length ({len(block_sizes)})" ) - - total_cp_world_size = dcp_world_size * pcp_world_size + if max_num_blocks is None: + # Note(hc): each dcp rank only store + # (max_model_len//dcp_world_size) tokens in kvcache, + # so the block_size which used for calc max_num_blocks_per_req + # must be multiplied by dcp_world_size. 
+ total_cp_world_size = get_total_cp_world_size() + max_num_blocks = [ + cdiv(max_model_len, block_size * total_cp_world_size) + for block_size in block_sizes + ] + + if len(max_num_blocks) != len(block_sizes): + raise ValueError( + f"max_num_blocks length ({len(max_num_blocks)}) " + f"must match block_sizes length ({len(block_sizes)})" + ) self.block_tables = [ BlockTable( block_size, max_num_reqs, - max( - cdiv(max_model_len, block_size * total_cp_world_size), - 1 + num_speculative_tokens, - ), + max_num_blocks_per_req, max_num_batched_tokens, pin_memory, device, kernel_block_size, cp_kv_cache_interleave_size, ) - for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes) + for block_size, kernel_block_size, max_num_blocks_per_req in zip( + block_sizes, kernel_block_sizes, max_num_blocks + ) ] def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: diff --git a/vllm/v1/worker/cp_utils.py b/vllm/v1/worker/cp_utils.py index f666c739b0be..2c2e0b5cdbe2 100644 --- a/vllm/v1/worker/cp_utils.py +++ b/vllm/v1/worker/cp_utils.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, cast from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.distributed import get_dcp_group, get_pcp_group if TYPE_CHECKING: from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -40,3 +41,17 @@ def check_attention_cp_compatibility(vllm_config: VllmConfig) -> None: f"but the impl {layer_impl.__class__.__name__} " "does not support PCP." 
) + + +def get_total_cp_world_size(): + try: + pcp_world_size = get_pcp_group().world_size + except AssertionError: + # PCP might not be initialized in testing + pcp_world_size = 1 + try: + dcp_world_size = get_dcp_group().world_size + except AssertionError: + # DCP might not be initialized in testing + dcp_world_size = 1 + return dcp_world_size * pcp_world_size diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 662badeb5f1a..c70970fdc06e 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -89,11 +89,11 @@ def __init__( vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group kernel_block_sizes: list[int], + max_num_blocks_per_req: list[int] | None = None, logitsprocs: LogitsProcessors | None = None, logitsprocs_need_output_token_ids: bool = False, is_spec_decode: bool = False, is_pooling_model: bool = False, - num_speculative_tokens: int = 0, cp_kv_cache_interleave_size: int = 1, ): self.is_pooling_model = is_pooling_model @@ -146,7 +146,7 @@ def __init__( device=device, block_sizes=block_sizes, kernel_block_sizes=kernel_block_sizes, - num_speculative_tokens=num_speculative_tokens, + max_num_blocks=max_num_blocks_per_req, cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 982ae44c2def..29b4178630e5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -152,7 +152,11 @@ from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext -from vllm.v1.worker.cp_utils import check_attention_cp_compatibility +from vllm.v1.worker import mamba_utils +from vllm.v1.worker.cp_utils import ( + check_attention_cp_compatibility, + get_total_cp_world_size, +) from vllm.v1.worker.dp_utils import 
coordinate_batch_across_dp from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -688,6 +692,7 @@ def __init__( # Ephemeral state transferred between execute_model() and sample_tokens(). self.execute_model_state: ExecuteModelState | None = None self.kv_connector_output: KVConnectorOutput | None = None + self.mamba_state_idx: dict[str, int] = {} self.layerwise_nvtx_hooks_registered = False def update_max_model_len(self, max_model_len: int) -> None: @@ -1075,7 +1080,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_metadata() def _update_states_after_model_execute( - self, output_token_ids: torch.Tensor + self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput" ) -> None: """Update the cached states after model execution. @@ -1111,6 +1116,16 @@ def _update_states_after_model_execute( ) for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + if self.cache_config.mamba_cache_mode == "align": + mamba_utils.postprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.input_batch, + self.requests, + self.mamba_state_idx, + self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), + ) def _init_mrope_positions(self, req_state: CachedRequestState): model = self.get_model() @@ -2751,7 +2766,6 @@ def _sample( logits, sampling_metadata, ) - self._update_states_after_model_execute(sampler_output.sampled_token_ids) return sampler_output def _bookkeeping_sync( @@ -3237,6 +3251,18 @@ def execute_model( pad_attn = cudagraph_mode == CUDAGraphMode.FULL + if self.cache_config.mamba_cache_mode == "align": + mamba_utils.preprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.cache_config, + self.mamba_state_idx, + self.input_batch, + self.requests, + self.compilation_config.static_forward_context, + 
self.model.get_mamba_state_copy_func(), + ) + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices @@ -3423,6 +3449,10 @@ def sample_tokens( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) + self._update_states_after_model_execute( + sampler_output.sampled_token_ids, scheduler_output + ) + self._draft_token_ids = None self._draft_token_req_ids = None self.input_batch.prev_sampled_token_ids = None @@ -5322,6 +5352,24 @@ def may_reinitialize_input_batch( for kv_cache_group in kv_cache_config.kv_cache_groups if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) ] + max_num_blocks = [] + max_model_len = max(self.max_model_len, self.max_encoder_len) + for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): + continue + max_num_blocks_per_req = cdiv( + max_model_len, block_sizes[i] * get_total_cp_world_size() + ) + if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + mamba_blocks_per_req = ( + max_num_blocks_per_req + if self.cache_config.enable_prefix_caching + else 1 + ) + kv_cache_group.kv_cache_spec.num_speculative_blocks + max_num_blocks_per_req = max( + max_num_blocks_per_req, mamba_blocks_per_req + ) + max_num_blocks.append(max_num_blocks_per_req) if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ self.cache_config.block_size @@ -5333,18 +5381,18 @@ def may_reinitialize_input_batch( ) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=max(self.max_model_len, self.max_encoder_len), + max_model_len=max_model_len, max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, kernel_block_sizes=kernel_block_sizes, + 
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Per-request Mamba state-block bookkeeping for the v1 GPU worker.

Provides a batched raw-byte device memcpy (Triton) and two scheduler-step
hooks, ``preprocess_mamba`` / ``postprocess_mamba``, that relocate each
request's running Mamba state between KV-cache blocks so the running state
always sits at the expected block index (needed for prefix caching and
speculative decoding).
"""
import itertools
from typing import Any

import torch
import triton
import triton.language as tl

from vllm.config import CacheConfig
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch


@triton.jit
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
    """Copy ``sizes[pid]`` bytes from ``src_ptrs[pid]`` to ``dst_ptrs[pid]``.

    One program instance handles one (src, dst, size) triple; the region is
    copied byte-wise (uint8) in ``BLOCK_SIZE``-element chunks, with a mask
    covering the tail chunk.
    """
    pid = tl.program_id(0)

    # Each entry of src_ptrs/dst_ptrs is a raw integer device address.
    src_ptr = tl.load(src_ptrs + pid)
    dst_ptr = tl.load(dst_ptrs + pid)
    size = tl.load(sizes + pid)

    offsets = tl.arange(0, BLOCK_SIZE)
    for i in range(0, size, BLOCK_SIZE):
        mask = (i + offsets) < size

        # Reinterpret the integer addresses as uint8 pointers for the copy.
        curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
        curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8))

        data = tl.load(curr_src_ptr, mask=mask)
        tl.store(curr_dst_ptr, data, mask=mask)


def batch_memcpy(
    src_ptrs: torch.Tensor, dst_ptrs: torch.Tensor, sizes: torch.Tensor
) -> None:
    """Launch one kernel program per copy pair.

    Args:
        src_ptrs: int64 tensor of source device addresses, shape (batch,).
        dst_ptrs: int64 tensor of destination addresses, shape (batch,).
        sizes: per-pair copy sizes in bytes, shape (batch,).
    """
    batch = src_ptrs.shape[0]
    assert dst_ptrs.shape[0] == batch
    assert sizes.shape[0] == batch

    grid = (batch,)
    BLOCK_SIZE = 1024
    batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)


def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]:
    """Return the indices of all Mamba KV-cache groups and their shared spec.

    Asserts that at least one group has a ``MambaSpec`` and that all Mamba
    groups share an identical spec (so the first one can represent them all).
    """
    mamba_group_ids: list[int] = []
    mamba_specs: list[MambaSpec] = []
    for i in range(len(kv_cache_config.kv_cache_groups)):
        kv_cache_spec = kv_cache_config.kv_cache_groups[i].kv_cache_spec
        if isinstance(kv_cache_spec, MambaSpec):
            mamba_group_ids.append(i)
            mamba_specs.append(kv_cache_spec)
    assert len(mamba_group_ids) > 0, "no mamba layers in the model"
    assert all(mamba_specs[0] == spec for spec in mamba_specs)
    return mamba_group_ids, mamba_specs[0]


def collect_mamba_copy_meta(
    src_state_list: list[int],
    dest_state_list: list[int],
    num_elements_list: list[int],
    kv_cache_config: KVCacheConfig,
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    mamba_group_ids: list[int],
    src_block_idx: int,
    dest_block_idx: int,
    accept_token_bias: int,
    req_state: CachedRequestState,
    forward_context: dict[str, Any],
):
    """Append (src addr, dst addr, byte count) triples for one request's copy.

    For every Mamba group and every layer in that group, asks each state copy
    function where the source state lives (``copy_spec.start_addr``) and
    schedules a copy into the state tensor slice of the destination block.
    The three output lists are appended to in lockstep; nothing is copied
    here — ``do_mamba_copy_block`` performs the actual transfer.

    Args:
        accept_token_bias: offset (in tokens) into the source block's state;
            forwarded to the copy function as ``accept_token_bias + 1``.
            # NOTE(review): the +1 convention is defined by MambaStateCopyFunc
            # implementations not visible here — confirm against that API.
    """
    # Same block and no token offset: state is already where it belongs.
    if src_block_idx == dest_block_idx and accept_token_bias == 0:
        return

    for mamba_group_id in mamba_group_ids:
        block_ids = req_state.block_ids[mamba_group_id]
        dest_block_id = block_ids[dest_block_idx]
        layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
        for layer_name in layer_names:
            attention = forward_context[layer_name]
            kv_caches: list[torch.Tensor] = attention.kv_cache[0]
            # One copy func per state tensor (e.g. conv state, SSM state).
            for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs):
                copy_spec = state_copy_func(
                    state, block_ids, src_block_idx, accept_token_bias + 1
                )

                src_state_list.append(copy_spec.start_addr)
                dest_state_list.append(state[dest_block_id].data_ptr())
                # Sizes are recorded in bytes, as batch_memcpy expects.
                num_elements_list.append(copy_spec.num_elements * state.element_size())


def do_mamba_copy_block(
    src_state_list: list[int],
    dest_state_list: list[int],
    num_elements_list: list[int],
) -> None:
    """Execute all copies gathered by ``collect_mamba_copy_meta`` in one launch.

    No-op when there is nothing to copy. The address/size lists are uploaded
    to the GPU and handed to the batched memcpy kernel.
    """
    if len(src_state_list) == 0:
        return
    assert len(src_state_list) == len(dest_state_list)
    assert len(src_state_list) == len(num_elements_list)
    src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
    dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
    num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)

    batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)


def preprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    cache_config: CacheConfig,
    mamba_state_idx: dict[str, int],
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
) -> None:
    """
    Copy the mamba state of previous step to the last
    (1 + num_speculative_blocks) block.

    Also drops ``mamba_state_idx`` entries for requests that finished or were
    preempted in this step. ``mamba_state_idx`` maps req_id -> block index
    currently holding that request's running state.
    """
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    num_speculative_blocks = mamba_spec.num_speculative_blocks
    # TODO(Chen): we need to optimize this function a lot
    assert cache_config.enable_prefix_caching
    block_size = mamba_spec.block_size
    finished_req_ids = scheduler_output.finished_req_ids
    preempted_req_ids = scheduler_output.preempted_req_ids or set()
    for req_id in itertools.chain(finished_req_ids, preempted_req_ids):
        mamba_state_idx.pop(req_id, None)

    src_state_list: list[int] = []
    dest_state_list: list[int] = []
    num_elements_list: list[int] = []
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        prev_state_idx = mamba_state_idx.get(req_id)
        if prev_state_idx is None:
            # new / resumed request, no previous state
            # if num_computed_tokens is 0, prev_state_idx will be -1
            prev_state_idx = (req_state.num_computed_tokens - 1) // block_size

        num_blocks = len(req_state.block_ids[mamba_group_ids[0]])

        # We always save the current running state at the last
        # (1 + num_speculative_blocks) block.
        # A corner case worth mention here: assume we have block_size = 4 and
        # num_speculative_tokens = 2. The request is [A, B, C] and contains 2 draft
        # tokens [draft 1, draft 2]. Then we will have:
        # Block 0: [A, B, C, draft 1]
        # Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
        # Block 2: speculative block
        # Block 3: speculative block
        # And use block 1 to save the running state.
        curr_state_idx = num_blocks - 1 - num_speculative_blocks
        mamba_state_idx[req_id] = curr_state_idx
        if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
            collect_mamba_copy_meta(
                src_state_list,
                dest_state_list,
                num_elements_list,
                kv_cache_config,
                mamba_state_copy_funcs,
                mamba_group_ids,
                prev_state_idx,
                curr_state_idx,
                # Bias by accepted tokens from the previous (spec-decode) step.
                input_batch.num_accepted_tokens_cpu[i] - 1,
                req_state,
                forward_context,
            )
            # Reset once the accepted-token bias has been folded into the copy.
            # NOTE(review): reconstructed placement (inside the `if`) — mirrors
            # the guarded reset in postprocess_mamba; confirm against upstream.
            input_batch.num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)


def postprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    mamba_state_idx: dict[str, int],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
) -> None:
    """
    If a blocks is converted from partial block to full block in this step, copy the
    state from the block for running state to the new full block.
    """
    num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens
    scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens
    num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
    # NOTE: can be optimized as this function always returns the same result
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    src_state_list: list[int] = []
    dest_state_list: list[int] = []
    num_elements_list: list[int] = []
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        num_computed_tokens = req_state.num_computed_tokens
        num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, []))
        num_scheduled_tokens = num_scheduled_tokens_dict[req_id]
        num_accepted_tokens = num_accepted_tokens_cpu[i]
        # Tokens covered by the running state = everything scheduled this step
        # minus the draft tokens (whose acceptance is accounted separately).
        num_tokens_running_state = (
            num_computed_tokens + num_scheduled_tokens - num_draft_tokens
        )
        new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1
        # Round down to a block boundary: only full blocks can be cached.
        aligned_new_computed_tokens = (
            new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size
        )
        # TODO: how to ensure all blocks that cache_blocks called are cached here?
        if aligned_new_computed_tokens >= num_tokens_running_state:
            accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state
            src_block_idx = mamba_state_idx[req_id]
            dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
            collect_mamba_copy_meta(
                src_state_list,
                dest_state_list,
                num_elements_list,
                kv_cache_config,
                mamba_state_copy_funcs,
                mamba_group_ids,
                src_block_idx,
                dest_block_idx,
                accept_token_bias,
                req_state,
                forward_context,
            )
            # In-place copy within the same block: the bias is now consumed.
            if src_block_idx == dest_block_idx:
                num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)