From aa27d8f07a51977e8f1af4eae9490b7dc65e876c Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 28 Oct 2025 03:34:23 +0000 Subject: [PATCH 001/130] Implementation of lighter mamba prefix caching on V1 Signed-off-by: huanghaoyan.hhy --- vllm/engine/arg_utils.py | 9 ++ vllm/envs.py | 6 +- vllm/model_executor/models/config.py | 3 +- vllm/model_executor/models/qwen3_next.py | 5 + vllm/model_executor/models/qwen3_next_mtp.py | 8 +- vllm/v1/attention/backends/gdn_attn.py | 11 ++ vllm/v1/attention/backends/linear_attn.py | 3 + vllm/v1/attention/backends/mamba1_attn.py | 15 +- vllm/v1/attention/backends/mamba2_attn.py | 18 ++- vllm/v1/core/block_pool.py | 45 +++++- vllm/v1/core/sched/scheduler.py | 79 ++++++++-- vllm/v1/core/single_type_kv_cache_manager.py | 158 +++++++++++++++++-- vllm/v1/kv_cache_interface.py | 17 +- vllm/v1/worker/gpu_model_runner.py | 93 ++++++++++- 14 files changed, 432 insertions(+), 38 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b7c8f56e18c5..7a658ff95cd2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1430,6 +1430,15 @@ def create_engine_config( f"dcp_size={self.decode_context_parallel_size}." 
) + if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.enable_prefix_caching + and model_config.is_hybrid + ): + assert envs.VLLM_USE_V1, ( + "Prefix caching for hybrid models requires the V1 engine.") + assert self.enable_chunked_prefill, ( + "Prefix caching for hybrid models requires chunked prefill.") + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, diff --git a/vllm/envs.py b/vllm/envs.py index 56558548d398..fff3c15edaab 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -232,7 +232,7 @@ VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False - + VLLM_USE_LIGHTER_MAMBA_CACHE: bool = False def get_default_cache_root(): return os.getenv( @@ -1526,6 +1526,9 @@ def get_vllm_port() -> int | None: "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) ), + "VLLM_USE_LIGHTER_MAMBA_CACHE": lambda: bool( + int(os.getenv("VLLM_USE_LIGHTER_MAMBA_CACHE", "0")) + ), } # --8<-- [end:env-vars-definition] @@ -1637,6 +1640,7 @@ def compile_factors() -> dict[str, object]: "LOCAL_RANK", "CUDA_VISIBLE_DEVICES", "NO_COLOR", + "VLLM_USE_LIGHTER_MAMBA_CACHE", } from vllm.config.utils import normalize_value diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 3cf4bf991e66..0df8ef4c4561 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -384,7 +384,8 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: mamba_page_size = MambaSpec( shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), - block_size=model_config.max_model_len, + block_size=model_config.max_model_len if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE else cache_config.block_size, + enable_caching=cache_config.enable_prefix_caching, ).page_size_bytes # 
Model may be marked as is_hybrid diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index bfed64728305..fecfbf762e39 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -10,6 +10,7 @@ from torch import nn from transformers.activations import ACT2FN +from vllm import envs from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import ( @@ -1186,6 +1187,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config scheduler_config = vllm_config.scheduler_config + if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + assert not cache_config.enable_prefix_caching, ( + "Qwen3NextMTP currently does not support prefix caching" + ) assert not cache_config.enable_prefix_caching, ( "Qwen3Next currently does not support prefix caching" ) diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index 83694caa5248..17ee3a9792d8 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -7,6 +7,7 @@ import torch from torch import nn +from vllm import envs from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed.parallel_state import get_pp_group @@ -234,9 +235,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config cache_config = vllm_config.cache_config - assert not cache_config.enable_prefix_caching, ( - "Qwen3NextMTP currently does not support prefix caching" - ) + if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + assert not cache_config.enable_prefix_caching, ( + "Qwen3NextMTP currently does not support prefix caching" + ) self.quant_config = vllm_config.quant_config diff --git a/vllm/v1/attention/backends/gdn_attn.py 
b/vllm/v1/attention/backends/gdn_attn.py index 69b5a6fb4856..a0aa1ee344c7 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -6,6 +6,7 @@ import torch +from vllm import envs from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig @@ -337,6 +338,16 @@ def build( # type: ignore[override] non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1] non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens) + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + # NOTE: With Mamba prefix-caching support, a request can consist of + # multiple blocks. This makes the state_indices non-contiguous, so + # we must explicitly make them contiguous here. + if spec_state_indices_tensor is not None: + spec_state_indices_tensor = spec_state_indices_tensor.contiguous() + if non_spec_state_indices_tensor is not None: + non_spec_state_indices_tensor = \ + non_spec_state_indices_tensor.contiguous() + attn_metadata = GDNAttentionMetadata( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index 1900c50849ec..2fd36233006c 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -4,6 +4,7 @@ import torch +from vllm import envs from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig from vllm.v1.attention.backends.utils import ( @@ -55,6 +56,8 @@ def build( seq_lens = common_attn_metadata.seq_lens state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + state_indices_tensor = state_indices_tensor.contiguous() num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( diff --git a/vllm/v1/attention/backends/mamba1_attn.py 
b/vllm/v1/attention/backends/mamba1_attn.py index 8e949e53330c..5ebd61eeca55 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -5,6 +5,7 @@ import torch +from vllm import envs from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig @@ -75,7 +76,9 @@ def build( # TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here. # We should consolidate this code - if self.vllm_config.cache_config.enable_prefix_caching: + if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.vllm_config.cache_config.enable_prefix_caching + ): # Return a tensor of shape (#requests, #max blocks) state_indices_tensor = common_attn_metadata.block_table_tensor mamba_block_size = self.kv_cache_spec.block_size @@ -92,6 +95,8 @@ def build( else: # Always return just a single block per each request: state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + state_indices_tensor = state_indices_tensor.contiguous() block_idx_last_scheduled_token = None block_idx_last_computed_token = None @@ -110,7 +115,9 @@ def build( common_attn_metadata.query_start_loc.device ) - if self.vllm_config.cache_config.enable_prefix_caching: + if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.vllm_config.cache_config.enable_prefix_caching + ): assert num_computed_tokens is not None num_computed_tokens_p = num_computed_tokens[ num_reqs - num_prefills : num_reqs @@ -132,7 +139,9 @@ def build( state_indices_tensor = self.state_indices_tensor[:padded_decodes] state_indices_tensor[num_decodes:] = PAD_SLOT_ID - if self.vllm_config.cache_config.enable_prefix_caching: + if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.vllm_config.cache_config.enable_prefix_caching + ): self.block_idx_last_scheduled_token[:num_decodes].copy_( block_idx_last_scheduled_token, non_blocking=True ) diff --git 
a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 888734e5d2b6..9d12bd1d4b06 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -5,6 +5,7 @@ import torch +from vllm import envs from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv @@ -173,7 +174,9 @@ def build( block_idx_first_scheduled_token = None block_idx_first_scheduled_token_p = None - if self.vllm_config.cache_config.enable_prefix_caching: + if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.vllm_config.cache_config.enable_prefix_caching + ): # Return a tensor of shape (#requests, #max blocks) state_indices_tensor = common_attn_metadata.block_table_tensor # Additional cache-related varaiables: @@ -191,6 +194,11 @@ def build( else: # Always return just a single block per each request: state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + # NOTE: With Mamba prefix-caching support, a request can consist of + # multiple blocks. This makes the state_indices non-contiguous, so + # we must explicitly make them contiguous here. 
+ state_indices_tensor = state_indices_tensor.contiguous() # Additional cache-related varaiables: block_idx_last_scheduled_token = None block_idx_last_computed_token = None @@ -220,7 +228,9 @@ def build( - num_decode_tokens ) - if self.vllm_config.cache_config.enable_prefix_caching: + if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.vllm_config.cache_config.enable_prefix_caching + ): assert num_computed_tokens is not None num_computed_tokens_p = num_computed_tokens[ num_reqs - num_prefills : num_reqs @@ -312,7 +322,9 @@ def build( state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID - if self.vllm_config.cache_config.enable_prefix_caching: + if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.vllm_config.cache_config.enable_prefix_caching + ): self.block_idx_last_scheduled_token[:num_decodes].copy_( block_idx_last_scheduled_token, non_blocking=True ) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 55710ad5cc69..cef6d7975e7d 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Sequence -from typing import Any +from typing import Any, Optional from vllm.distributed.kv_events import ( MEDIUM_GPU, @@ -193,6 +193,49 @@ def get_cached_block( cached_blocks.append(block) return cached_blocks + def cache_full_block( + self, + request: Request, + block: KVCacheBlock, + cached_block_index: int, + block_size: int, + kv_cache_group_id: int, + ) -> None: + """Cache a full block for prefix caching. 
+ """ + + assert cached_block_index >= 0 + assert len(request.block_hashes) > cached_block_index + new_block_hash: BlockHash = request.block_hashes[cached_block_index] + new_hashes: Optional[list[ExternalBlockHash]] = ( + [] if self.enable_kv_cache_events else None) + assert block.block_hash is None + + # Update and added the full block to the cache. + block_hash_with_group_id: BlockHashWithGroupId = make_block_hash_with_group_id( + new_block_hash, kv_cache_group_id) + block.block_hash = block_hash_with_group_id + self.cached_block_hash_to_block[block_hash_with_group_id][ + block.block_id] = block + if new_hashes is not None: + new_hashes.append(maybe_convert_block_hash(new_block_hash)) + + if self.enable_kv_cache_events: + parent_block_hash: Optional[ExternalBlockHash] = None + + self.kv_event_queue.append( + BlockStored( + block_hashes=new_hashes, + parent_block_hash=parent_block_hash, + token_ids=request. + all_token_ids[cached_block_index * block_size: + (cached_block_index+1) * block_size], + block_size=block_size, + lora_id=request.lora_request.id + if request.lora_request else None, + medium=MEDIUM_GPU, + )) + def cache_full_blocks( self, request: Request, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a7ec0de37263..b12faa0b8956 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -39,7 +39,7 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -190,6 +190,12 @@ def __init__( self.use_pp = 
self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER + def _has_mamba_spec(self) -> bool: + has_mamba: bool = any(isinstance(spec.kv_cache_spec, MambaSpec) + for spec in self.kv_cache_config.kv_cache_groups) + assert not has_mamba or self.vllm_config.model_config.is_hybrid + return has_mamba + def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. @@ -224,14 +230,47 @@ def schedule(self) -> SchedulerOutput: while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - num_new_tokens = ( - request.num_tokens_with_spec - + request.num_output_placeholders - - request.num_computed_tokens - ) - if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: - num_new_tokens = self.scheduler_config.long_prefill_token_threshold - num_new_tokens = min(num_new_tokens, token_budget) + # Ensure new tokens for a request in the prefill phase do not contain + # sps tokens, especially in the last prefill chunk. For a hybrid-model, + # extra sps tokens would corrupt the generated Mamba state. + # TODO: This logic does not yet handle resumed requests. + if request.num_computed_tokens < request.num_prompt_tokens: + num_new_tokens = min(request.num_tokens_with_spec + + request.num_output_placeholders, + request.num_prompt_tokens) - request.num_computed_tokens + else: + num_new_tokens = (request.num_tokens_with_spec + + request.num_output_placeholders - + request.num_computed_tokens) + + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + + if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.cache_config.enable_prefix_caching + and self._has_mamba_spec()): + # To enable block-aligned caching of the Mamba state, `num_new_tokens` + # must be a multiple of `block_size`. 
+ # As an exception, if `num_new_tokens` is less than `block_size`, the + # state is simply not cached, requiring no special handling. + # Additionally, when Eagle mode is enabled, FullAttn prunes the last + # matching block. To prevent this from causing a Mamba cache miss, the + # last chunk must be larger than `block_size`. + block_size = self.block_size + max_last_chunk = block_size * (2 if self.use_eagle else 1) + if num_new_tokens < max_last_chunk: + num_new_tokens = min(num_new_tokens, token_budget) + else: + ori_num_new_tokens = num_new_tokens + num_new_tokens = min(num_new_tokens, token_budget) + num_new_tokens = num_new_tokens // block_size * block_size + if self.use_eagle and ori_num_new_tokens - num_new_tokens < block_size: + assert num_new_tokens >= block_size + num_new_tokens -= block_size + else: + num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len or # request's max_tokens. @@ -270,6 +309,8 @@ def schedule(self) -> SchedulerOutput: # its max_total_tokens or max_model_len. # 2. The encoder budget is exhausted. # 3. The encoder cache is exhausted. + # 4. Insufficient budget for a block-aligned chunk in hybrid + # models with lighter mamba prefix caching. # NOTE(woosuk): Here, by doing `continue` instead of `break`, # we do not strictly follow the FCFS scheduling policy and # allow the lower-priority requests to be scheduled. 
@@ -512,7 +553,25 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.prepend_request(request) continue - num_new_tokens = min(num_new_tokens, token_budget) + if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.cache_config.enable_prefix_caching + and self._has_mamba_spec()): + block_size = self.block_size + max_last_chunk = block_size * (2 if self.use_eagle else 1) + if num_new_tokens < max_last_chunk: + num_new_tokens = min(num_new_tokens, token_budget) + else: + ori_num_new_tokens = num_new_tokens + num_new_tokens = min(num_new_tokens, token_budget) + num_new_tokens = num_new_tokens // block_size * block_size + if self.use_eagle and ori_num_new_tokens - num_new_tokens < block_size: + assert num_new_tokens >= block_size + num_new_tokens -= block_size + if num_new_tokens == 0: + token_budget = 0 + break + else: + num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 # Schedule encoder inputs. diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index d90ec550f766..a43caa79392c 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -5,6 +5,7 @@ from collections import defaultdict from collections.abc import Sequence +from vllm import envs from vllm.utils.math_utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock @@ -609,6 +610,13 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class MambaManager(SingleTypeKVCacheManager): + + def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None: + super().__init__(kv_cache_spec, **kwargs) + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + self._req_to_computed_tokens: dict[str, int] = {} + self._req_to_new_tokens: dict[str, int] = {} + @classmethod def find_longest_cache_hit( cls, @@ -647,6 +655,39 @@ def find_longest_cache_hit( return computed_blocks + def remove_skipped_blocks(self, 
request_id: str, + num_computed_tokens: int) -> None: + if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + # Here unused blocks may be freed up for running requests. + # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2 + # (for which find_longest_cache_hit returns block_pool.null_block) + pass + else: + assert isinstance(self.kv_cache_spec, MambaSpec) + self._req_to_computed_tokens[request_id] = num_computed_tokens + # Each request will always have 1 block at this moment, so no need to + # remove blocks. + if not self.kv_cache_spec.enable_caching: + return + blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] + num_blocks = len(blocks) + if num_blocks > 0: + prefix_block: KVCacheBlock = blocks[0] + assert self._req_to_computed_tokens[request_id] != 0 + if not prefix_block.is_null: + self.block_pool.free_blocks([prefix_block]) + blocks[0] = self.block_pool.null_block + + if num_blocks > 2 + self.kv_cache_spec.num_speculative_blocks: + assert num_blocks == 3 + self.kv_cache_spec.num_speculative_blocks + last_new_tokens = self._req_to_new_tokens[request_id] + assert last_new_tokens % self.block_size == 0 + assert last_new_tokens >= self.block_size + self.block_pool.free_blocks(blocks[-1:]) + blocks.pop() + else: + assert num_blocks == 0 or num_blocks == 2 + self.kv_cache_spec.num_speculative_blocks + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: """ cascade attention is not supported by mamba @@ -662,14 +703,55 @@ def get_num_blocks_to_allocate( # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. 
assert isinstance(self.kv_cache_spec, MambaSpec) - if self.kv_cache_spec.num_speculative_blocks > 0: - num_tokens += ( - self.kv_cache_spec.block_size - * self.kv_cache_spec.num_speculative_blocks + if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if self.kv_cache_spec.num_speculative_blocks > 0: + num_tokens += ( + self.kv_cache_spec.block_size + * self.kv_cache_spec.num_speculative_blocks + ) + return super().get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks ) - return super().get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks - ) + else: + num_computed_tokens = self._req_to_computed_tokens.get(request_id, 0) + if num_computed_tokens == 0: + num_new_blocks = 1 + (self.kv_cache_spec.num_speculative_blocks + if self.kv_cache_spec.num_speculative_blocks > 0 else 0) + else: + num_new_blocks = 0 + + if self.kv_cache_spec.enable_caching: + num_new_tokens = num_tokens - num_computed_tokens - len(new_computed_blocks) * self.block_size + if num_computed_tokens != 0: + num_new_tokens -= self.kv_cache_spec.num_speculative_blocks + self._req_to_new_tokens[request_id] = num_new_tokens + # NOTE: last chunk may larger than block_size when using eagle. + if num_new_tokens >= self.block_size and num_new_tokens % self.block_size == 0: + num_new_blocks += 1 + + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it will be changed from a free block + # to a computed block when the request is allocated, so we also count + # it as needed to be allocated. 
+ num_evictable_computed_blocks = sum( + blk.ref_cnt == 0 and not blk.is_null + for blk in new_computed_blocks) + return num_new_blocks + num_evictable_computed_blocks + + def save_new_computed_blocks( + self, request_id: str, + new_computed_blocks: list[KVCacheBlock]) -> None: + assert isinstance(self.kv_cache_spec, MambaSpec) + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if not self.kv_cache_spec.enable_caching: + return + if request_id not in self.num_cached_block: + if new_computed_blocks: + assert len(new_computed_blocks) == 1 or new_computed_blocks[-2].is_null + new_computed_blocks = new_computed_blocks[-1:] + else: + new_computed_blocks = [self.block_pool.null_block] + super().save_new_computed_blocks(request_id, new_computed_blocks) def allocate_new_blocks( self, request_id: str, num_tokens: int @@ -677,13 +759,63 @@ def allocate_new_blocks( # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. assert isinstance(self.kv_cache_spec, MambaSpec) - if self.kv_cache_spec.num_speculative_blocks > 0: - num_tokens += ( - self.kv_cache_spec.block_size - * self.kv_cache_spec.num_speculative_blocks - ) - return super().allocate_new_blocks(request_id, num_tokens) + if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if self.kv_cache_spec.num_speculative_blocks > 0: + num_tokens += ( + self.kv_cache_spec.block_size + * self.kv_cache_spec.num_speculative_blocks + ) + return super().allocate_new_blocks(request_id, num_tokens) + else: + num_computed_tokens = self._req_to_computed_tokens.get(request_id, 0) + if num_computed_tokens == 0: + num_new_blocks = 1 + (self.kv_cache_spec.num_speculative_blocks + if self.kv_cache_spec.num_speculative_blocks > 0 else 0) + else: + assert num_tokens >= num_computed_tokens + num_new_blocks = 0 + + if self.kv_cache_spec.enable_caching: + num_new_tokens = self._req_to_new_tokens[request_id] + if num_new_tokens >= self.block_size and num_new_tokens % self.block_size == 0: + num_new_blocks += 1 + 
+ req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] + if num_new_blocks <= 0: + return [] + else: + new_blocks: list[KVCacheBlock] = self.block_pool.get_new_blocks(num_new_blocks) + req_blocks.extend(new_blocks) + return new_blocks + + def cache_blocks(self, request: Request, num_tokens: int) -> None: + assert isinstance(self.kv_cache_spec, MambaSpec) + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + num_computed_tokens = request.num_computed_tokens + num_new_tokens = self._req_to_new_tokens[request.request_id] + # NOTE:For sps, an extra block may be allocated but not cached + if (num_new_tokens >= self.block_size + and num_new_tokens % self.block_size == 0 + and num_tokens % self.block_size == 0): + assert num_new_tokens % self.block_size == 0 + assert num_computed_tokens % self.block_size == 0 + assert len(self.req_to_blocks[request.request_id]) == 3 + self.kv_cache_spec.num_speculative_blocks + self.block_pool.cache_full_block( + request=request, + block=self.req_to_blocks[request.request_id][-1], + cached_block_index=(num_tokens // self.block_size - 1), + block_size=self.block_size, + kv_cache_group_id=self.kv_cache_group_id + ) + self.num_cached_block[request.request_id] += 1 + else: + super().cache_blocks(request, num_tokens) + def free(self, request_id: str) -> None: + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + self._req_to_computed_tokens.pop(request_id, None) + self._req_to_new_tokens.pop(request_id, None) + super().free(request_id) class CrossAttentionManager(SingleTypeKVCacheManager): """Manager for cross-attention KV cache in encoder-decoder models.""" diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 751862aa9c76..cdacd427291a 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -8,6 +8,7 @@ import torch from typing_extensions import Self +from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils.math_utils import cdiv @@ -247,6 +248,7 @@ 
class MambaSpec(KVCacheSpec): page_size_padded: int | None = None mamba_type: str = "mamba2" num_speculative_blocks: int = 0 + enable_caching: bool = False @property def page_size_bytes(self) -> int: @@ -260,8 +262,19 @@ def page_size_bytes(self) -> int: return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - max_model_len = vllm_config.model_config.max_model_len - return cdiv(max_model_len, self.block_size) * self.page_size_bytes + # We allocate 1 block for each request now, so max_memory_usage_bytes is + # the same as page_size_bytes. + # Need to update this when supporting prefix caching. + if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + max_model_len = vllm_config.model_config.max_model_len + return cdiv(max_model_len, self.block_size) * self.page_size_bytes + else: + # NOTE: We allocate 1 block per request by default. With prefix + # caching enabled, up to 2 additional blocks are required: one + # for reading the matched prefix and one for caching the current + # state. + return self.page_size_bytes * (3 if self.enable_caching else 1) + @dataclass(frozen=True) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cbafc9c993cc..5ef995e949a1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -10,7 +10,7 @@ from copy import copy, deepcopy from functools import reduce from itertools import product -from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, TypeAlias, cast import numpy as np import torch @@ -98,6 +98,7 @@ reorder_batch_to_split_decodes_and_prefills, split_attn_metadata, ) +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import ( AttentionSpec, @@ -1527,6 +1528,13 @@ def _build_attention_metadata( # Fill unused with -1. 
Needed for reshape_and_cache in full cuda # graph mode. blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + + if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.cache_config.enable_prefix_caching + and isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec) + ): + blk_table_tensor = blk_table_tensor[:, 1:] + self._preprocess_mamba_prefix(scheduler_output, kv_cache_group_id, kv_cache_group_spec) common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -2622,6 +2630,84 @@ def _model_forward( **model_kwargs, ) + def _mamba_copy_block(self, kv_cache_group_spec: KVCacheGroupSpec, + src_block_id: int, dest_block_id: int): + forward_context = self.compilation_config.static_forward_context + for layer_name in kv_cache_group_spec.layer_names: + kv_caches: list[list[torch.Tensor]] = forward_context[layer_name].kv_cache + for kv_cache in kv_caches: + if isinstance(kv_cache, torch.Tensor): + kv_cache[dest_block_id].copy_(kv_cache[src_block_id]) + elif isinstance(kv_cache, list): + for kv_cache_part in kv_cache: + kv_cache_part[dest_block_id].copy_(kv_cache_part[src_block_id]) + + def _preprocess_mamba_prefix(self, scheduler_output: "SchedulerOutput", + kv_cache_group_id: int, + kv_cache_group_spec: KVCacheGroupSpec, + ): + assert isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec) + assert self.cache_config.enable_prefix_caching + new_reqs: list[NewRequestData] = scheduler_output.scheduled_new_reqs + for new_req in new_reqs: + if new_req.num_computed_tokens == 0: + continue + block_ids: list[int] = new_req.block_ids[kv_cache_group_id] + assert block_ids[0] != 0, f'{block_ids=}' + prefix_block_id, dest_block_id = block_ids[0], block_ids[1] + self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, dest_block_id) + cached_reqs: CachedRequestData = scheduler_output.scheduled_cached_reqs + for i, resumed in enumerate(cached_reqs.resumed_from_preemption): + if not resumed: + continue + group_block_ids: Optional[tuple[list[int], 
...]] = cached_reqs.new_block_ids[i] + assert group_block_ids is not None + new_block_ids: list[int] = group_block_ids[kv_cache_group_id] + assert len(new_block_ids) >= 2 + kv_cache_group_spec.kv_cache_spec.num_speculative_blocks + if cached_reqs.num_computed_tokens[i] == 0: + continue + assert new_block_ids[0] != 0, f'{new_block_ids=}' + prefix_block_id, dest_block_id = new_block_ids[0], new_block_ids[1] + self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, dest_block_id) + + + def _postprocess_mamba_cache(self, scheduler_output: "SchedulerOutput"): + assert self.cache_config.enable_prefix_caching + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + if not isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec): + continue + new_reqs: list[NewRequestData] = scheduler_output.scheduled_new_reqs + num_speculative_blocks = kv_cache_group_spec.kv_cache_spec.num_speculative_blocks + for new_req in new_reqs: + block_ids: list[int] = new_req.block_ids[kv_cache_group_id] + if len(block_ids) <= 2 + num_speculative_blocks: + continue + assert len(block_ids) == 3 + num_speculative_blocks + src_block_id, dest_block_id = block_ids[1], block_ids[-1] + self._mamba_copy_block(kv_cache_group_spec, src_block_id, dest_block_id) + cached_reqs: CachedRequestData = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + group_block_ids: Optional[tuple[list[int], ...]] = cached_reqs.new_block_ids[i] + if group_block_ids is None: + assert not cached_reqs.resumed_from_preemption[i] + continue + new_block_ids: list[int] = group_block_ids[kv_cache_group_id] + if not new_block_ids: + assert not cached_reqs.resumed_from_preemption[i] + continue + if not cached_reqs.resumed_from_preemption[i]: + assert len(new_block_ids) == 1 + block_ids: list[int] = self.requests[req_id].block_ids[kv_cache_group_id] + src_block_id, dest_block_id = block_ids[1], new_block_ids[0] + else: + if len(new_block_ids) == 2 + 
num_speculative_blocks: + continue + assert len(new_block_ids) == 3 + num_speculative_blocks + src_block_id, dest_block_id = new_block_ids[1], new_block_ids[-1] + self._mamba_copy_block(kv_cache_group_spec, src_block_id, dest_block_id) + + @torch.inference_mode() def execute_model( self, @@ -2912,6 +2998,11 @@ def sample_tokens( scheduler_output, grammar_output, self.input_batch, logits ) + if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.cache_config.enable_prefix_caching + ): + self._postprocess_mamba_cache(scheduler_output) + with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) From 81d356190ab60937fdfa936a8a5cac3cdf4239ea Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 5 Nov 2025 04:12:50 +0000 Subject: [PATCH 002/130] [BugFix] Resolve compatibility issues in lighter mamba prefix cache Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/mamba/abstract.py | 6 +++- vllm/model_executor/models/config.py | 30 +++++++++++++------- vllm/model_executor/models/qwen3_next.py | 3 -- vllm/v1/core/block_pool.py | 3 +- vllm/v1/core/kv_cache_manager.py | 5 +++- 5 files changed, 29 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index aa919d6fdc35..41626eed5c52 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -6,6 +6,7 @@ import torch +from vllm import envs from vllm.attention.selector import get_mamba_attn_backend from vllm.config import VllmConfig from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -51,7 +52,9 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: raise NotImplementedError( "Mamba with speculative decoding is not supported yet." 
) - mamba_block_size = vllm_config.cache_config.mamba_block_size + mamba_block_size = (vllm_config.cache_config.mamba_block_size + if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + else vllm_config.cache_config.block_size) page_size_padded = vllm_config.cache_config.mamba_page_size_padded return MambaSpec( shapes=self.get_state_shape(), @@ -64,6 +67,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: if vllm_config.speculative_config else 0 ), + enable_caching=vllm_config.cache_config.enable_prefix_caching, ) def get_attn_backend(self) -> type["AttentionBackend"]: diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 0df8ef4c4561..bc2f434e3caa 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -293,18 +293,25 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config.mamba_block_size = model_config.max_model_len if cache_config.enable_prefix_caching: - if model_config.supports_mamba_prefix_caching: - logger.info( - "Warning: Prefix caching is currently enabled. " - "Its support for Mamba layers is experimental. " - "Please report any issues you may observe." - ) + if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if model_config.supports_mamba_prefix_caching: + logger.info( + "Warning: Prefix caching is currently enabled. " + "Its support for Mamba layers is experimental. " + "Please report any issues you may observe." + ) + else: + logger.info( + "Hybrid or mamba-based model detected without " + "support for prefix caching: disabling." + ) + cache_config.enable_prefix_caching = False else: logger.info( - "Hybrid or mamba-based model detected without " - "support for prefix caching: disabling." - ) - cache_config.enable_prefix_caching = False + "Warning: Lighter Mamba Prefix caching is currently" + " enabled. Its support is experimental. " + "Please report any issues you may observe." 
+ ) # TODO(tdoublep): remove once cascade attention is supported logger.info( @@ -394,7 +401,8 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: if mamba_page_size == 0: return - if cache_config.enable_prefix_caching: + if (cache_config.enable_prefix_caching + and not envs.VLLM_USE_LIGHTER_MAMBA_CACHE): # With prefix caching, select attention block size to # optimize for mamba kernel performance diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index fecfbf762e39..29e81f36e347 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -1191,9 +1191,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert not cache_config.enable_prefix_caching, ( "Qwen3NextMTP currently does not support prefix caching" ) - assert not cache_config.enable_prefix_caching, ( - "Qwen3Next currently does not support prefix caching" - ) self.quant_config = vllm_config.quant_config super().__init__() diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index cef6d7975e7d..eecf676917d1 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -215,8 +215,7 @@ def cache_full_block( block_hash_with_group_id: BlockHashWithGroupId = make_block_hash_with_group_id( new_block_hash, kv_cache_group_id) block.block_hash = block_hash_with_group_id - self.cached_block_hash_to_block[block_hash_with_group_id][ - block.block_id] = block + self.cached_block_hash_to_block.insert(block_hash_with_group_id, block) if new_hashes is not None: new_hashes.append(maybe_convert_block_hash(new_block_hash)) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 2012c3fef88b..b4cd44ca0aea 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Literal, overload +from vllm import envs from vllm.distributed.kv_events import 
KVCacheEvent from vllm.logger import init_logger from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator @@ -304,7 +305,9 @@ def allocate_slots( "Computed blocks should be empty when prefix caching is disabled" ) - if new_computed_block_list is not self.empty_kv_cache_blocks.blocks: + if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + or new_computed_block_list is not self.empty_kv_cache_blocks.blocks + ): # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. self.coordinator.save_new_computed_blocks( From 8a652af0f7ccf539070e22bff6fa146f81960d30 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 5 Nov 2025 16:04:25 +0000 Subject: [PATCH 003/130] [BugFix] Resolve compatibility issues for mamba Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 10 ++++++++-- vllm/v1/attention/backends/mamba_attn.py | 5 ++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 0ea5805305ed..714f12d97811 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -1,6 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + +from vllm import envs + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn @@ -616,7 +622,7 @@ def conv_ssm_forward( dim=0, ) - if prefix_caching_enabled: + if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE and prefix_caching_enabled: # If prefix caching is enabled, retrieve the relevant variables # for prefill and decode block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( @@ -809,7 +815,7 @@ def conv_ssm_forward( # Process decode requests if has_decode: - if prefix_caching_enabled: + if not 
envs.VLLM_USE_LIGHTER_MAMBA_CACHE and prefix_caching_enabled: state_indices_tensor_d_input = state_indices_tensor_d.gather( 1, block_idx_last_computed_token_d.unsqueeze(1) ).squeeze(1) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 0d875565fc99..a19ad79574ab 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -6,6 +6,7 @@ import torch +from vllm import envs from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import ( @@ -40,7 +41,9 @@ def __init__( self.compilation_config.max_cudagraph_capture_size, ) - if self.vllm_config.cache_config.enable_prefix_caching: + if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.vllm_config.cache_config.enable_prefix_caching + ): self.state_indices_tensor = torch.empty( ( self.decode_cudagraph_max_bs, From 8a39893917395c6d5a55b8534b9d69b883871d48 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 23 Nov 2025 16:31:02 +0000 Subject: [PATCH 004/130] Add base implementation for lighter mamba cache with standard layout Signed-off-by: huanghaoyan.hhy --- vllm/v1/attention/backends/gdn_attn.py | 41 +++-- vllm/v1/core/block_pool.py | 86 ++++----- vllm/v1/core/single_type_kv_cache_manager.py | 177 +++++++++---------- vllm/v1/worker/gpu_model_runner.py | 152 ++++++++++------ 4 files changed, 260 insertions(+), 196 deletions(-) diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index a0aa1ee344c7..eaa506c91770 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -58,6 +58,15 @@ class GDNAttentionMetadata: batch_ptr: torch.Tensor | None = None token_chunk_offset_ptr: torch.Tensor | None = None +def mamba_gather_indices(common_attn_metadata: CommonAttentionMetadata, + block_size: int, + num_blocks: int): + block_table_tensor = common_attn_metadata.block_table_tensor + start_indices 
= common_attn_metadata.seq_lens // block_size + offsets = torch.arange(num_blocks, device=block_table_tensor.device) + indices_to_gather = start_indices.unsqueeze(1) + offsets + return torch.gather(block_table_tensor, 1, indices_to_gather) + class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]): _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH @@ -146,6 +155,12 @@ def build( # type: ignore[override] context_lens = m.num_computed_tokens_cpu context_lens_tensor = context_lens.to(query_start_loc.device) nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + block_table_tensor = mamba_gather_indices(common_attn_metadata, + self.kv_cache_spec.block_size, + 1 + self.num_spec) + else: + block_table_tensor = m.block_table_tensor if ( not self.use_spec_decode @@ -175,7 +190,7 @@ def build( # type: ignore[override] spec_token_indx = None non_spec_token_indx = None spec_state_indices_tensor = None - non_spec_state_indices_tensor = m.block_table_tensor[:, 0] + non_spec_state_indices_tensor = block_table_tensor[:, 0] spec_query_start_loc = None non_spec_query_start_loc = query_start_loc num_accepted_tokens = None @@ -204,7 +219,7 @@ def build( # type: ignore[override] non_spec_token_indx = torch.empty( 0, dtype=torch.int32, device=query_start_loc.device ) - spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1] + spec_state_indices_tensor = block_table_tensor[:, : self.num_spec + 1] non_spec_state_indices_tensor = None spec_query_start_loc = query_start_loc non_spec_query_start_loc = None @@ -217,10 +232,10 @@ def build( # type: ignore[override] non_spec_token_indx = index[:num_non_spec_tokens] spec_token_indx = index[num_non_spec_tokens:] - spec_state_indices_tensor = m.block_table_tensor[ + spec_state_indices_tensor = block_table_tensor[ spec_sequence_masks, : self.num_spec + 1 ] - non_spec_state_indices_tensor = m.block_table_tensor[ + non_spec_state_indices_tensor = 
block_table_tensor[ ~spec_sequence_masks, 0 ] @@ -338,15 +353,15 @@ def build( # type: ignore[override] non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1] non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens) - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - # NOTE: With Mamba prefix-caching support, a request can consist of - # multiple blocks. This makes the state_indices non-contiguous, so - # we must explicitly make them contiguous here. - if spec_state_indices_tensor is not None: - spec_state_indices_tensor = spec_state_indices_tensor.contiguous() - if non_spec_state_indices_tensor is not None: - non_spec_state_indices_tensor = \ - non_spec_state_indices_tensor.contiguous() + # if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + # # NOTE: With Mamba prefix-caching support, a request can consist of + # # multiple blocks. This makes the state_indices non-contiguous, so + # # we must explicitly make them contiguous here. + # if spec_state_indices_tensor is not None: + # spec_state_indices_tensor = spec_state_indices_tensor.contiguous() + # if non_spec_state_indices_tensor is not None: + # non_spec_state_indices_tensor = \ + # non_spec_state_indices_tensor.contiguous() attn_metadata = GDNAttentionMetadata( num_prefills=num_prefills, diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index eecf676917d1..9c336160e1c8 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -193,47 +193,47 @@ def get_cached_block( cached_blocks.append(block) return cached_blocks - def cache_full_block( - self, - request: Request, - block: KVCacheBlock, - cached_block_index: int, - block_size: int, - kv_cache_group_id: int, - ) -> None: - """Cache a full block for prefix caching. 
- """ - - assert cached_block_index >= 0 - assert len(request.block_hashes) > cached_block_index - new_block_hash: BlockHash = request.block_hashes[cached_block_index] - new_hashes: Optional[list[ExternalBlockHash]] = ( - [] if self.enable_kv_cache_events else None) - assert block.block_hash is None - - # Update and added the full block to the cache. - block_hash_with_group_id: BlockHashWithGroupId = make_block_hash_with_group_id( - new_block_hash, kv_cache_group_id) - block.block_hash = block_hash_with_group_id - self.cached_block_hash_to_block.insert(block_hash_with_group_id, block) - if new_hashes is not None: - new_hashes.append(maybe_convert_block_hash(new_block_hash)) - - if self.enable_kv_cache_events: - parent_block_hash: Optional[ExternalBlockHash] = None - - self.kv_event_queue.append( - BlockStored( - block_hashes=new_hashes, - parent_block_hash=parent_block_hash, - token_ids=request. - all_token_ids[cached_block_index * block_size: - (cached_block_index+1) * block_size], - block_size=block_size, - lora_id=request.lora_request.id - if request.lora_request else None, - medium=MEDIUM_GPU, - )) + # def cache_full_block( + # self, + # request: Request, + # block: KVCacheBlock, + # cached_block_index: int, + # block_size: int, + # kv_cache_group_id: int, + # ) -> None: + # """Cache a full block for prefix caching. + # """ + + # assert cached_block_index >= 0 + # assert len(request.block_hashes) > cached_block_index + # new_block_hash: BlockHash = request.block_hashes[cached_block_index] + # new_hashes: Optional[list[ExternalBlockHash]] = ( + # [] if self.enable_kv_cache_events else None) + # assert block.block_hash is None + + # # Update and added the full block to the cache. 
+ # block_hash_with_group_id: BlockHashWithGroupId = make_block_hash_with_group_id( + # new_block_hash, kv_cache_group_id) + # block.block_hash = block_hash_with_group_id + # self.cached_block_hash_to_block.insert(block_hash_with_group_id, block) + # if new_hashes is not None: + # new_hashes.append(maybe_convert_block_hash(new_block_hash)) + + # if self.enable_kv_cache_events: + # parent_block_hash: Optional[ExternalBlockHash] = None + + # self.kv_event_queue.append( + # BlockStored( + # block_hashes=new_hashes, + # parent_block_hash=parent_block_hash, + # token_ids=request. + # all_token_ids[cached_block_index * block_size: + # (cached_block_index+1) * block_size], + # block_size=block_size, + # lora_id=request.lora_request.id + # if request.lora_request else None, + # medium=MEDIUM_GPU, + # )) def cache_full_blocks( self, @@ -271,6 +271,10 @@ def cache_full_blocks( [] if self.enable_kv_cache_events else None ) for i, blk in enumerate(new_full_blocks): + # NOTE: for mamba, full blocks includes the null block + # and the null block should be skipped + if blk is self.null_block: + continue assert blk.block_hash is None block_hash = new_block_hashes[i] diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index a43caa79392c..93decb5155ca 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -614,8 +614,9 @@ class MambaManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - self._req_to_computed_tokens: dict[str, int] = {} - self._req_to_new_tokens: dict[str, int] = {} + self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks + self._allocated_reqs: set[str] = set() + self._req_to_last_computed: dict[str, int] = {} @classmethod def find_longest_cache_hit( @@ -657,36 +658,25 @@ def find_longest_cache_hit( def 
remove_skipped_blocks(self, request_id: str,
                               num_computed_tokens: int) -> None:
+        assert isinstance(self.kv_cache_spec, MambaSpec)
         if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE:
             # Here unused blocks may be freed up for running requests.
             # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2
             # (for which find_longest_cache_hit returns block_pool.null_block)
             pass
         else:
-            assert isinstance(self.kv_cache_spec, MambaSpec)
-            self._req_to_computed_tokens[request_id] = num_computed_tokens
-            # Each request will always have 1 block at this moment, so no need to
-            # remove blocks.
-            if not self.kv_cache_spec.enable_caching:
-                return
             blocks: list[KVCacheBlock] = self.req_to_blocks[request_id]
-            num_blocks = len(blocks)
-            if num_blocks > 0:
-                prefix_block: KVCacheBlock = blocks[0]
-                assert self._req_to_computed_tokens[request_id] != 0
-                if not prefix_block.is_null:
-                    self.block_pool.free_blocks([prefix_block])
-                blocks[0] = self.block_pool.null_block
-
-            if num_blocks > 2 + self.kv_cache_spec.num_speculative_blocks:
-                assert num_blocks == 3 + self.kv_cache_spec.num_speculative_blocks
-                last_new_tokens = self._req_to_new_tokens[request_id]
-                assert last_new_tokens % self.block_size == 0
-                assert last_new_tokens >= self.block_size
-                self.block_pool.free_blocks(blocks[-1:])
-                blocks.pop()
-            else:
-                assert num_blocks == 0 or num_blocks == 2 + self.kv_cache_spec.num_speculative_blocks
+            if request_id in self._req_to_last_computed:
+                # TODO: what happens in the decode phase with spec decoding
+                # enabled when the accepted tokens are not a multiple of block size?
+                # NOTE: the previous block must not be freed because it may still be the copy source
+                last_computed_tokens = self._req_to_last_computed[request_id]
+                target_idx = last_computed_tokens // self.block_size - 1
+                if target_idx >= 0 and blocks[target_idx] != self._null_block:
+                    self.block_pool.free_blocks([blocks[target_idx]])
+                    blocks[target_idx] = self._null_block
+
+            self._req_to_last_computed[request_id] = num_computed_tokens
 
     def 
get_num_common_prefix_blocks(self, running_request_id: str) -> int: """ @@ -713,22 +703,18 @@ def get_num_blocks_to_allocate( request_id, num_tokens, new_computed_blocks ) else: - num_computed_tokens = self._req_to_computed_tokens.get(request_id, 0) - if num_computed_tokens == 0: - num_new_blocks = 1 + (self.kv_cache_spec.num_speculative_blocks - if self.kv_cache_spec.num_speculative_blocks > 0 else 0) - else: - num_new_blocks = 0 - - if self.kv_cache_spec.enable_caching: - num_new_tokens = num_tokens - num_computed_tokens - len(new_computed_blocks) * self.block_size - if num_computed_tokens != 0: - num_new_tokens -= self.kv_cache_spec.num_speculative_blocks - self._req_to_new_tokens[request_id] = num_new_tokens - # NOTE: last chunk may larger than block_size when using eagle. - if num_new_tokens >= self.block_size and num_new_tokens % self.block_size == 0: - num_new_blocks += 1 - + num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + num_new_blocks = (num_required_blocks - len(new_computed_blocks) - + len(self.req_to_blocks[request_id])) + num_new_alloc_blocks = 0 + if num_new_blocks > 0: + # first prefill + if request_id not in self._allocated_reqs: + # if len(self.req_to_blocks[request_id]) == 0: + num_new_alloc_blocks = 1 + self.num_speculative_blocks + else: + num_new_alloc_blocks = 1 + # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it will be changed from a free block # to a computed block when the request is allocated, so we also count @@ -736,22 +722,23 @@ def get_num_blocks_to_allocate( num_evictable_computed_blocks = sum( blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks) - return num_new_blocks + num_evictable_computed_blocks + + return num_new_alloc_blocks + num_evictable_computed_blocks def save_new_computed_blocks( self, request_id: str, new_computed_blocks: list[KVCacheBlock]) -> None: assert isinstance(self.kv_cache_spec, MambaSpec) - if 
envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - if not self.kv_cache_spec.enable_caching: - return - if request_id not in self.num_cached_block: - if new_computed_blocks: - assert len(new_computed_blocks) == 1 or new_computed_blocks[-2].is_null - new_computed_blocks = new_computed_blocks[-1:] - else: - new_computed_blocks = [self.block_pool.null_block] - super().save_new_computed_blocks(request_id, new_computed_blocks) + # if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + # if not self.kv_cache_spec.enable_caching: + # return + # if request_id not in self.num_cached_block: + # if new_computed_blocks: + # assert len(new_computed_blocks) == 1 or new_computed_blocks[-2].is_null + # new_computed_blocks = new_computed_blocks[-1:] + # else: + # new_computed_blocks = [self.block_pool.null_block] + # super().save_new_computed_blocks(request_id, new_computed_blocks) def allocate_new_blocks( self, request_id: str, num_tokens: int @@ -767,54 +754,66 @@ def allocate_new_blocks( ) return super().allocate_new_blocks(request_id, num_tokens) else: - num_computed_tokens = self._req_to_computed_tokens.get(request_id, 0) - if num_computed_tokens == 0: - num_new_blocks = 1 + (self.kv_cache_spec.num_speculative_blocks - if self.kv_cache_spec.num_speculative_blocks > 0 else 0) - else: - assert num_tokens >= num_computed_tokens - num_new_blocks = 0 - - if self.kv_cache_spec.enable_caching: - num_new_tokens = self._req_to_new_tokens[request_id] - if num_new_tokens >= self.block_size and num_new_tokens % self.block_size == 0: - num_new_blocks += 1 - req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] + num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + num_new_blocks = (num_required_blocks - len(self.req_to_blocks[request_id])) if num_new_blocks <= 0: return [] - else: - new_blocks: list[KVCacheBlock] = self.block_pool.get_new_blocks(num_new_blocks) + else: + # first prefill chunk + # TODO: for mamba num_cached_block including null-blocks + new_blocks = [] + if 
request_id not in self._allocated_reqs:
+                    self._allocated_reqs.add(request_id)
+                    num_new_alloc_blocks = 1 + self.num_speculative_blocks
+                    new_blocks.extend([self._null_block
+                        for _ in range(num_new_blocks - num_new_alloc_blocks)])
+                    # new_alloc_blocks = self.block_pool.get_new_blocks(num_new_alloc_blocks)
+                else:
+                    num_new_alloc_blocks = 1
+                    new_blocks.extend([self._null_block for _ in range(num_new_blocks - num_new_alloc_blocks)])
+                    if self.num_speculative_blocks > 0:
+                        # step i: [0, 0, 0, a, sps_i_0, sps_i_1]
+                        # step i+1(i.e. j): [0, 0, 0, a, 0, 0, 0, b, sps_j_0, sps_j_1]
+                        # reuse blocks sps_i_0 and sps_i_1 as b and sps_j_0
+                        # new_alloc_blocks: [sps_j_1]
+                        # new_blocks: [0, 0, 0, b, sps_j_0, sps_j_1]
+                        req_blocks = req_blocks[:-self.num_speculative_blocks]
+                        # TODO: do reused blocks need their memory cleared? Especially in decode
+                        reuse_blocks = req_blocks[-self.num_speculative_blocks:]
+                        new_blocks.extend(reuse_blocks)
+                new_alloc_blocks = self.block_pool.get_new_blocks(num_new_alloc_blocks)
+                new_blocks.extend(new_alloc_blocks)
                 req_blocks.extend(new_blocks)
                 return new_blocks
 
     def cache_blocks(self, request: Request, num_tokens: int) -> None:
         assert isinstance(self.kv_cache_spec, MambaSpec)
-        if envs.VLLM_USE_LIGHTER_MAMBA_CACHE:
-            num_computed_tokens = request.num_computed_tokens
-            num_new_tokens = self._req_to_new_tokens[request.request_id]
-            # NOTE:For sps, an extra block may be allocated but not cached
-            if (num_new_tokens >= self.block_size
-                and num_new_tokens % self.block_size == 0
-                and num_tokens % self.block_size == 0):
-                assert num_new_tokens % self.block_size == 0
-                assert num_computed_tokens % self.block_size == 0
-                assert len(self.req_to_blocks[request.request_id]) == 3 + self.kv_cache_spec.num_speculative_blocks
-                self.block_pool.cache_full_block(
-                    request=request,
-                    block=self.req_to_blocks[request.request_id][-1],
-                    cached_block_index=(num_tokens // self.block_size - 1),
-                    block_size=self.block_size,
-                    kv_cache_group_id=self.kv_cache_group_id
-                )
-                
self.num_cached_block[request.request_id] += 1 - else: - super().cache_blocks(request, num_tokens) + # if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + # num_computed_tokens = request.num_computed_tokens + # num_new_tokens = self._req_to_new_tokens[request.request_id] + # # NOTE:For sps, an extra block may be allocated but not cached + # if (num_new_tokens >= self.block_size + # and num_new_tokens % self.block_size == 0 + # and num_tokens % self.block_size == 0): + # assert num_new_tokens % self.block_size == 0 + # assert num_computed_tokens % self.block_size == 0 + # assert len(self.req_to_blocks[request.request_id]) == 3 + self.kv_cache_spec.num_speculative_blocks + # self.block_pool.cache_full_block( + # request=request, + # block=self.req_to_blocks[request.request_id][-1], + # cached_block_index=(num_tokens // self.block_size - 1), + # block_size=self.block_size, + # kv_cache_group_id=self.kv_cache_group_id + # ) + # self.num_cached_block[request.request_id] += 1 + # else: + # super().cache_blocks(request, num_tokens) def free(self, request_id: str) -> None: if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - self._req_to_computed_tokens.pop(request_id, None) - self._req_to_new_tokens.pop(request_id, None) + self._allocated_reqs.discard(request_id) + self._req_to_last_computed.pop(request_id, None) super().free(request_id) class CrossAttentionManager(SingleTypeKVCacheManager): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5ef995e949a1..c113fc9f2c4e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1533,8 +1533,7 @@ def _build_attention_metadata( and self.cache_config.enable_prefix_caching and isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec) ): - blk_table_tensor = blk_table_tensor[:, 1:] - self._preprocess_mamba_prefix(scheduler_output, kv_cache_group_id, kv_cache_group_spec) + self._preprocess_mamba(scheduler_output, kv_cache_group_id, kv_cache_group_spec) common_attn_metadata = 
CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -2632,6 +2631,8 @@ def _model_forward( def _mamba_copy_block(self, kv_cache_group_spec: KVCacheGroupSpec, src_block_id: int, dest_block_id: int): + if src_block_id == dest_block_id: + return forward_context = self.compilation_config.static_forward_context for layer_name in kv_cache_group_spec.layer_names: kv_caches: list[list[torch.Tensor]] = forward_context[layer_name].kv_cache @@ -2642,70 +2643,115 @@ def _mamba_copy_block(self, kv_cache_group_spec: KVCacheGroupSpec, for kv_cache_part in kv_cache: kv_cache_part[dest_block_id].copy_(kv_cache_part[src_block_id]) - def _preprocess_mamba_prefix(self, scheduler_output: "SchedulerOutput", - kv_cache_group_id: int, - kv_cache_group_spec: KVCacheGroupSpec, - ): + def _preprocess_mamba(self, scheduler_output: "SchedulerOutput", + kv_cache_group_id: int, + kv_cache_group_spec: KVCacheGroupSpec, + ): assert isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec) assert self.cache_config.enable_prefix_caching + block_size = kv_cache_group_spec.kv_cache_spec.block_size + num_speculative_blocks = kv_cache_group_spec.kv_cache_spec.num_speculative_blocks new_reqs: list[NewRequestData] = scheduler_output.scheduled_new_reqs for new_req in new_reqs: if new_req.num_computed_tokens == 0: continue block_ids: list[int] = new_req.block_ids[kv_cache_group_id] - assert block_ids[0] != 0, f'{block_ids=}' - prefix_block_id, dest_block_id = block_ids[0], block_ids[1] + prefix_block_idx = cdiv(new_req.num_computed_tokens, block_size) - 1 + dest_block_idx = len(block_ids) - 1 - num_speculative_blocks + prefix_block_id, dest_block_id = block_ids[prefix_block_idx], block_ids[dest_block_idx] self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, dest_block_id) + cached_reqs: CachedRequestData = scheduler_output.scheduled_cached_reqs for i, resumed in enumerate(cached_reqs.resumed_from_preemption): - if not resumed: - continue group_block_ids: Optional[tuple[list[int], ...]] = 
cached_reqs.new_block_ids[i] - assert group_block_ids is not None - new_block_ids: list[int] = group_block_ids[kv_cache_group_id] - assert len(new_block_ids) >= 2 + kv_cache_group_spec.kv_cache_spec.num_speculative_blocks - if cached_reqs.num_computed_tokens[i] == 0: - continue - assert new_block_ids[0] != 0, f'{new_block_ids=}' - prefix_block_id, dest_block_id = new_block_ids[0], new_block_ids[1] - self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, dest_block_id) - - - def _postprocess_mamba_cache(self, scheduler_output: "SchedulerOutput"): - assert self.cache_config.enable_prefix_caching - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - if not isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec): - continue - new_reqs: list[NewRequestData] = scheduler_output.scheduled_new_reqs - num_speculative_blocks = kv_cache_group_spec.kv_cache_spec.num_speculative_blocks - for new_req in new_reqs: - block_ids: list[int] = new_req.block_ids[kv_cache_group_id] - if len(block_ids) <= 2 + num_speculative_blocks: - continue - assert len(block_ids) == 3 + num_speculative_blocks - src_block_id, dest_block_id = block_ids[1], block_ids[-1] - self._mamba_copy_block(kv_cache_group_spec, src_block_id, dest_block_id) - cached_reqs: CachedRequestData = scheduler_output.scheduled_cached_reqs - for i, req_id in enumerate(cached_reqs.req_ids): - group_block_ids: Optional[tuple[list[int], ...]] = cached_reqs.new_block_ids[i] + num_compute_tokens = cached_reqs.num_computed_tokens[i] + if not resumed: + assert num_compute_tokens > 0 if group_block_ids is None: - assert not cached_reqs.resumed_from_preemption[i] continue new_block_ids: list[int] = group_block_ids[kv_cache_group_id] if not new_block_ids: - assert not cached_reqs.resumed_from_preemption[i] continue - if not cached_reqs.resumed_from_preemption[i]: - assert len(new_block_ids) == 1 - block_ids: list[int] = self.requests[req_id].block_ids[kv_cache_group_id] - 
src_block_id, dest_block_id = block_ids[1], new_block_ids[0] - else: - if len(new_block_ids) == 2 + num_speculative_blocks: - continue - assert len(new_block_ids) == 3 + num_speculative_blocks - src_block_id, dest_block_id = new_block_ids[1], new_block_ids[-1] + assert len(new_block_ids) >= 1 + num_speculative_blocks + block_ids: list[int] = self.requests[cached_reqs.req_ids[i]].block_ids[kv_cache_group_id] + src_block_idx = cdiv(num_compute_tokens, block_size) - 1 + dest_block_idx = len(new_block_ids) - 1 - num_speculative_blocks + src_block_id, dest_block_id = block_ids[src_block_idx], new_block_ids[dest_block_idx] self._mamba_copy_block(kv_cache_group_spec, src_block_id, dest_block_id) + else: + assert group_block_ids is not None + new_block_ids: list[int] = group_block_ids[kv_cache_group_id] + if num_compute_tokens == 0: + continue + prefix_block_idx = cdiv(num_compute_tokens, block_size) - 1 + dest_block_idx = len(new_block_ids) - 1 - num_speculative_blocks + prefix_block_id, dest_block_id = new_block_ids[prefix_block_idx], new_block_ids[dest_block_idx] + self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, dest_block_id) + + # def _preprocess_mamba_prefix(self, scheduler_output: "SchedulerOutput", + # kv_cache_group_id: int, + # kv_cache_group_spec: KVCacheGroupSpec, + # ): + # assert isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec) + # assert self.cache_config.enable_prefix_caching + # new_reqs: list[NewRequestData] = scheduler_output.scheduled_new_reqs + # for new_req in new_reqs: + # if new_req.num_computed_tokens == 0: + # continue + # block_ids: list[int] = new_req.block_ids[kv_cache_group_id] + # assert block_ids[0] != 0, f'{block_ids=}' + # prefix_block_id, dest_block_id = block_ids[0], block_ids[1] + # self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, dest_block_id) + # cached_reqs: CachedRequestData = scheduler_output.scheduled_cached_reqs + # for i, resumed in enumerate(cached_reqs.resumed_from_preemption): + # if not 
resumed: + # continue + # group_block_ids: Optional[tuple[list[int], ...]] = cached_reqs.new_block_ids[i] + # assert group_block_ids is not None + # new_block_ids: list[int] = group_block_ids[kv_cache_group_id] + # assert len(new_block_ids) >= 2 + kv_cache_group_spec.kv_cache_spec.num_speculative_blocks + # if cached_reqs.num_computed_tokens[i] == 0: + # continue + # assert new_block_ids[0] != 0, f'{new_block_ids=}' + # prefix_block_id, dest_block_id = new_block_ids[0], new_block_ids[1] + # self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, dest_block_id) + + + # def _postprocess_mamba_cache(self, scheduler_output: "SchedulerOutput"): + # assert self.cache_config.enable_prefix_caching + # for kv_cache_group_id, kv_cache_group_spec in enumerate( + # self.kv_cache_config.kv_cache_groups): + # if not isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec): + # continue + # new_reqs: list[NewRequestData] = scheduler_output.scheduled_new_reqs + # num_speculative_blocks = kv_cache_group_spec.kv_cache_spec.num_speculative_blocks + # for new_req in new_reqs: + # block_ids: list[int] = new_req.block_ids[kv_cache_group_id] + # if len(block_ids) <= 2 + num_speculative_blocks: + # continue + # assert len(block_ids) == 3 + num_speculative_blocks + # src_block_id, dest_block_id = block_ids[1], block_ids[-1] + # self._mamba_copy_block(kv_cache_group_spec, src_block_id, dest_block_id) + # cached_reqs: CachedRequestData = scheduler_output.scheduled_cached_reqs + # for i, req_id in enumerate(cached_reqs.req_ids): + # group_block_ids: Optional[tuple[list[int], ...]] = cached_reqs.new_block_ids[i] + # if group_block_ids is None: + # assert not cached_reqs.resumed_from_preemption[i] + # continue + # new_block_ids: list[int] = group_block_ids[kv_cache_group_id] + # if not new_block_ids: + # assert not cached_reqs.resumed_from_preemption[i] + # continue + # if not cached_reqs.resumed_from_preemption[i]: + # assert len(new_block_ids) == 1 + # block_ids: list[int] = 
self.requests[req_id].block_ids[kv_cache_group_id] + # src_block_id, dest_block_id = block_ids[1], new_block_ids[0] + # else: + # if len(new_block_ids) == 2 + num_speculative_blocks: + # continue + # assert len(new_block_ids) == 3 + num_speculative_blocks + # src_block_id, dest_block_id = new_block_ids[1], new_block_ids[-1] + # self._mamba_copy_block(kv_cache_group_spec, src_block_id, dest_block_id) @torch.inference_mode() @@ -2998,10 +3044,10 @@ def sample_tokens( scheduler_output, grammar_output, self.input_batch, logits ) - if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.cache_config.enable_prefix_caching - ): - self._postprocess_mamba_cache(scheduler_output) + # if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + # and self.cache_config.enable_prefix_caching + # ): + # self._postprocess_mamba_cache(scheduler_output) with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) From ee08f54995b00425e2cfb4c29a6d1dcb6099091a Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 24 Nov 2025 17:20:16 +0000 Subject: [PATCH 005/130] update and fix bugs Signed-off-by: huanghaoyan.hhy --- vllm/v1/attention/backends/gdn_attn.py | 1 + vllm/v1/core/kv_cache_manager.py | 4 +- vllm/v1/core/sched/scheduler.py | 84 ++++++++++---------- vllm/v1/core/single_type_kv_cache_manager.py | 38 --------- vllm/v1/worker/gpu_model_runner.py | 1 + 5 files changed, 44 insertions(+), 84 deletions(-) diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index eaa506c91770..9ba3f4d74b73 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -58,6 +58,7 @@ class GDNAttentionMetadata: batch_ptr: torch.Tensor | None = None token_chunk_offset_ptr: torch.Tensor | None = None +# TODO: need to move, and called by all mamba builders def mamba_gather_indices(common_attn_metadata: CommonAttentionMetadata, block_size: int, num_blocks: int): diff --git 
a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index b4cd44ca0aea..b3d6bf5bc6a5 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -305,9 +305,7 @@ def allocate_slots( "Computed blocks should be empty when prefix caching is disabled" ) - if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE - or new_computed_block_list is not self.empty_kv_cache_blocks.blocks - ): + if new_computed_block_list is not self.empty_kv_cache_blocks.blocks: # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. self.coordinator.save_new_computed_blocks( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index b12faa0b8956..daa9fe622ad3 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -192,9 +192,37 @@ def __init__( def _has_mamba_spec(self) -> bool: has_mamba: bool = any(isinstance(spec.kv_cache_spec, MambaSpec) - for spec in self.kv_cache_config.kv_cache_groups) + for spec in self.kv_cache_config.kv_cache_groups) assert not has_mamba or self.vllm_config.model_config.is_hybrid return has_mamba + + def _mamba_block_aligned_split(self, request: Request, num_new_tokens: int) -> int: + if (self.cache_config.enable_prefix_caching + and self._has_mamba_spec()): + # To enable block-aligned caching of the Mamba state, `num_new_tokens` + # must be a multiple of `block_size`. + # As an exception, if `num_new_tokens` is less than `block_size`, the + # state is simply not cached, requiring no special handling. + # Additionally, when Eagle mode is enabled, FullAttn prunes the last + # matching block. To prevent this from causing a Mamba cache miss, the + # last chunk must be larger than `block_size`. 
+ block_size = self.cache_config.block_size + if request.num_output_tokens == 0: # prefill + last_cache_position = request.num_prompt_tokens - request.num_prompt_tokens % block_size + # eagle prune + if self.use_eagle: + last_cache_position = max(last_cache_position - block_size, 0) + num_computed_tokens_after_prefill = request.num_computed_tokens + num_new_tokens + if num_computed_tokens_after_prefill < last_cache_position: + # align to block_size + num_new_tokens = num_new_tokens // block_size * block_size + elif request.num_computed_tokens < last_cache_position < num_computed_tokens_after_prefill: + # force to cache the last chunk + num_new_tokens = last_cache_position - request.num_computed_tokens + else: + # prefill the last few tokens + pass + return num_new_tokens def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: @@ -248,29 +276,7 @@ def schedule(self) -> SchedulerOutput: num_new_tokens = ( self.scheduler_config.long_prefill_token_threshold) - if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.cache_config.enable_prefix_caching - and self._has_mamba_spec()): - # To enable block-aligned caching of the Mamba state, `num_new_tokens` - # must be a multiple of `block_size`. - # As an exception, if `num_new_tokens` is less than `block_size`, the - # state is simply not cached, requiring no special handling. - # Additionally, when Eagle mode is enabled, FullAttn prunes the last - # matching block. To prevent this from causing a Mamba cache miss, the - # last chunk must be larger than `block_size`. 
- block_size = self.block_size - max_last_chunk = block_size * (2 if self.use_eagle else 1) - if num_new_tokens < max_last_chunk: - num_new_tokens = min(num_new_tokens, token_budget) - else: - ori_num_new_tokens = num_new_tokens - num_new_tokens = min(num_new_tokens, token_budget) - num_new_tokens = num_new_tokens // block_size * block_size - if self.use_eagle and ori_num_new_tokens - num_new_tokens < block_size: - assert num_new_tokens >= block_size - num_new_tokens -= block_size - else: - num_new_tokens = min(num_new_tokens, token_budget) + num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len or # request's max_tokens. @@ -299,6 +305,10 @@ def schedule(self) -> SchedulerOutput: encoder_compute_budget, ) + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + num_new_tokens = self._mamba_block_aligned_split( + request, num_new_tokens) + if num_new_tokens == 0: # The request cannot be scheduled because one of the following # reasons: @@ -553,25 +563,7 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.prepend_request(request) continue - if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.cache_config.enable_prefix_caching - and self._has_mamba_spec()): - block_size = self.block_size - max_last_chunk = block_size * (2 if self.use_eagle else 1) - if num_new_tokens < max_last_chunk: - num_new_tokens = min(num_new_tokens, token_budget) - else: - ori_num_new_tokens = num_new_tokens - num_new_tokens = min(num_new_tokens, token_budget) - num_new_tokens = num_new_tokens // block_size * block_size - if self.use_eagle and ori_num_new_tokens - num_new_tokens < block_size: - assert num_new_tokens >= block_size - num_new_tokens -= block_size - if num_new_tokens == 0: - token_budget = 0 - break - else: - num_new_tokens = min(num_new_tokens, token_budget) + num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 # Schedule encoder inputs. 
@@ -591,6 +583,12 @@ def schedule(self) -> SchedulerOutput: # The request cannot be scheduled. break + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + num_new_tokens = self._mamba_block_aligned_split( + request, num_new_tokens) + if num_new_tokens == 0: + break + # Handles an edge case when P/D Disaggregation # is used with Spec Decoding where an # extra block gets allocated which diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 93decb5155ca..3da23a6e9bd1 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -725,21 +725,6 @@ def get_num_blocks_to_allocate( return num_new_alloc_blocks + num_evictable_computed_blocks - def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: list[KVCacheBlock]) -> None: - assert isinstance(self.kv_cache_spec, MambaSpec) - # if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - # if not self.kv_cache_spec.enable_caching: - # return - # if request_id not in self.num_cached_block: - # if new_computed_blocks: - # assert len(new_computed_blocks) == 1 or new_computed_blocks[-2].is_null - # new_computed_blocks = new_computed_blocks[-1:] - # else: - # new_computed_blocks = [self.block_pool.null_block] - # super().save_new_computed_blocks(request_id, new_computed_blocks) - def allocate_new_blocks( self, request_id: str, num_tokens: int ) -> list[KVCacheBlock]: @@ -786,29 +771,6 @@ def allocate_new_blocks( new_blocks.extend(new_alloc_blocks) req_blocks.extend(new_blocks) return new_blocks - - def cache_blocks(self, request: Request, num_tokens: int) -> None: - assert isinstance(self.kv_cache_spec, MambaSpec) - # if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - # num_computed_tokens = request.num_computed_tokens - # num_new_tokens = self._req_to_new_tokens[request.request_id] - # # NOTE:For sps, an extra block may be allocated but not cached - # if (num_new_tokens >= self.block_size - # and num_new_tokens % self.block_size == 0 - # and 
num_tokens % self.block_size == 0): - # assert num_new_tokens % self.block_size == 0 - # assert num_computed_tokens % self.block_size == 0 - # assert len(self.req_to_blocks[request.request_id]) == 3 + self.kv_cache_spec.num_speculative_blocks - # self.block_pool.cache_full_block( - # request=request, - # block=self.req_to_blocks[request.request_id][-1], - # cached_block_index=(num_tokens // self.block_size - 1), - # block_size=self.block_size, - # kv_cache_group_id=self.kv_cache_group_id - # ) - # self.num_cached_block[request.request_id] += 1 - # else: - # super().cache_blocks(request, num_tokens) def free(self, request_id: str) -> None: if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c113fc9f2c4e..0aef8407beda 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2674,6 +2674,7 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput", continue assert len(new_block_ids) >= 1 + num_speculative_blocks block_ids: list[int] = self.requests[cached_reqs.req_ids[i]].block_ids[kv_cache_group_id] + # TODO: for sps, need to handle sps blocks src_block_idx = cdiv(num_compute_tokens, block_size) - 1 dest_block_idx = len(new_block_ids) - 1 - num_speculative_blocks src_block_id, dest_block_id = block_ids[src_block_idx], new_block_ids[dest_block_idx] From ce840b303e1b19abcfad52cc4a94961d221f2738 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 24 Nov 2025 17:55:36 +0000 Subject: [PATCH 006/130] fix bugs after rebasing Signed-off-by: huanghaoyan.hhy --- vllm/engine/arg_utils.py | 2 -- vllm/v1/core/single_type_kv_cache_manager.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 7 +++++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7a658ff95cd2..c3a9c5856eb2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1434,8 +1434,6 @@ def create_engine_config( 
and self.enable_prefix_caching and model_config.is_hybrid ): - assert envs.VLLM_USE_V1, ( - "Prefix caching for hybrid models requires the V1 engine.") assert self.enable_chunked_prefill, ( "Prefix caching for hybrid models requires chunked prefill.") diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 3da23a6e9bd1..a25fbc6498b0 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -764,7 +764,7 @@ def allocate_new_blocks( # new_alloc_blocks: [sps_j_1] # new_blocks: [0, 0, 0, b, sps_j_0, sps_j_1] req_blocks = req_blocks[:-self.num_speculative_blocks] - # TODO: reuse blocks 是否需要清除内存?尤其是decode + # TODO: reuse blocks. if we need clean? especially in decode reuse_blocks = req_blocks[-self.num_speculative_blocks:] new_blocks.extend(reuse_blocks) new_alloc_blocks = self.block_pool.get_new_blocks(num_new_alloc_blocks) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0aef8407beda..2101082207ce 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1429,6 +1429,7 @@ def _prepare_inputs( def _build_attention_metadata( self, + scheduler_output: "SchedulerOutput", total_num_scheduled_tokens: int, max_num_scheduled_tokens: int, num_reqs: int, @@ -1531,9 +1532,11 @@ def _build_attention_metadata( if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE and self.cache_config.enable_prefix_caching - and isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec) + and isinstance(kv_cache_group.kv_cache_spec, MambaSpec) ): - self._preprocess_mamba(scheduler_output, kv_cache_group_id, kv_cache_group_spec) + # TODO: temporarily add `scheduler_output` as it's missing in the new API. + # maybe move the _preprocess_mamba function outside. 
+ self._preprocess_mamba(scheduler_output, kv_cache_gid, kv_cache_group) common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, From 865ea28825591e3cd102052cc74c8c5b535e3c0e Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 26 Nov 2025 05:01:32 +0000 Subject: [PATCH 007/130] add running script (just for testing) Signed-off-by: huanghaoyan.hhy --- my_tests/run_op_prefix_cache.sh | 72 +++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100755 my_tests/run_op_prefix_cache.sh diff --git a/my_tests/run_op_prefix_cache.sh b/my_tests/run_op_prefix_cache.sh new file mode 100755 index 000000000000..f7552a4f1972 --- /dev/null +++ b/my_tests/run_op_prefix_cache.sh @@ -0,0 +1,72 @@ +#!/bin/bash + +PORT=8235 +TP=2 +MAX_MODEL_LEN=262144 +# MAX_MODEL_LEN=131072 + +DO_NSYS=0 + + +MODEL_DIR=/mnt/disk0/huanghaoyan.hhy/Qwen3-Next-80B-A3B-Instruct/ +echo "MODEL_DIR: $MODEL_DIR" + +NSYS_OUTPUT="qwen_next_h20_tp1_nreq1_fp8_mtp1_prefixcache_2" +NSYS="" +if (( DO_NSYS == 1 )); then + NSYS="nsys profile -c cudaProfilerApi --cuda-graph-trace node -o $NSYS_OUTPUT" +fi + +env_vars=( + # "CUDA_LAUNCH_BLOCKING=0" + "CUDA_VISIBLE_DEVICES=0,1,2,3" + "VLLM_USE_LIGHTER_MAMBA_CACHE=1" + # "CUDA_VISIBLE_DEVICES=6,7" + # "VLLM_ATTENTION_BACKEND=FLASH_ATTN" + # "VLLM_FLASH_ATTN_VERSION=3" + # "VLLM_ALLOW_LONG_MAX_MODEL_LEN=1" + # "OMP_NUM_THREADS=1" + # "VLLM_USE_V1=1" + # "VLLM_LOG_REQ_KV_LENS=1" + # "VLLM_USE_FLASHINFER_SAMPLER=0" +) + +for var in "${env_vars[@]}"; do + var_name="${var%%=*}" + var_value="${var#*=}" + echo -e "\t$var_name=$var_value" +done + +CMD=( env ) +for var in "${env_vars[@]}"; do + CMD+=( "$var" ) +done +CMD+=( + $NSYS vllm serve + $MODEL_DIR + # --trust-remote-code + --port "$PORT" + --gpu-memory-utilization 0.9 + -tp $TP + --enforce-eager + # --no-enable-prefix-caching + --enable-prefix-caching + # --no-enable-chunked-prefill + --enable-chunked-prefill + --max-num-batched-tokens 8192 + 
--distributed-executor-backend mp + --block-size 64 + --max-num-seqs 128 + # --max-num-seqs 16 + # --max-model-len $MAX_MODEL_LEN + # --max-seq-len-to-capture $MAX_MODEL_LEN + # --compilation-config "{\"use_inductor\": false, \"cudagraph_mode\": \"FULL_DECODE_ONLY\", \"custom_ops\": [\"all\"]}" + # --speculative-config "{\"method\": \"qwen3_next_mtp\", \"num_speculative_tokens\": 3}" + # --hf_overrides "{\"max_position_embeddings\": $MAX_MODEL_LEN}" +) + +echo -e "\nExecuting command:" +printf " %s" "${CMD[@]}" +echo -e "\n" + +"${CMD[@]}" From 7316d035b3a4b4e9eb288b5faf1d46568e47332a Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 26 Nov 2025 05:03:31 +0000 Subject: [PATCH 008/130] add debug logs Signed-off-by: huanghaoyan.hhy --- vllm/entrypoints/openai/serving_chat.py | 1 + vllm/v1/core/sched/scheduler.py | 10 ++++++++++ vllm/v1/core/single_type_kv_cache_manager.py | 19 ++++++++++++++++++- vllm/v1/worker/gpu_model_runner.py | 8 ++++++++ 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 6cc685acd672..d58c6e699928 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -179,6 +179,7 @@ async def create_chat_completion( if self.engine_client.errored: raise self.engine_client.dead_error + logger.info(f'>>> [DEBUG] create_chat: req_id={request.request_id} msg={request.messages}') try: lora_request = self._maybe_get_adapters( request, supports_default_mm_loras=True diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index daa9fe622ad3..7c9f2f54d7d6 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -258,6 +258,8 @@ def schedule(self) -> SchedulerOutput: while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] + logger.info(f'>>> [DEBUG] Scheduler: schedule RUNING: req_id={request.request_id}, ' + 
f'num_prompt_tokens={request.num_prompt_tokens=}') # Ensure new tokens for a request in the prefill phase do not contain # sps tokens, especially in the last prefill chunk. For a hybrid-model, # extra sps tokens would corrupt the generated Mamba state. @@ -454,6 +456,8 @@ def schedule(self) -> SchedulerOutput: break request = self.waiting.peek_request() + logger.info(f'>>> [DEBUG] Scheduler: schedule WAITING: req_id={request.request_id}, ' + f'num_prompt_tokens={request.num_prompt_tokens=}') # KVTransfer: skip request if still waiting for remote kvs. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: @@ -504,6 +508,8 @@ def schedule(self) -> SchedulerOutput: new_computed_blocks, num_new_local_computed_tokens = ( self.kv_cache_manager.get_computed_blocks(request) ) + logger.info(f'>>> [DEBUG] Scheduler: get_computed_blk: req_id={request.request_id},' + f'{num_new_local_computed_tokens=}, {new_computed_blocks.blocks=}') # Get externally-cached tokens if using a KVConnector. if self.connector is not None: @@ -749,6 +755,10 @@ def schedule(self) -> SchedulerOutput: self.prev_step_scheduled_req_ids.clear() self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) + logger.info('>>> [DEBUG] Scheduler: new_reqs:' + f'{[(reqdata.req_id, reqdata.block_ids) for reqdata in new_reqs_data]}') + logger.info('>>> [DEBUG] Scheduler: cached_reqs:' + f'{[(req_id, cached_reqs_data.new_block_ids[i]) for i, req_id in enumerate(cached_reqs_data.req_ids)]}') scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index a25fbc6498b0..cee0952df1d3 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -64,6 +64,10 @@ def __init__( self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block + def print(self, *args, **kwargs): + new_args 
= (f">>> [KvGrp {self.kv_cache_group_id}] ", ) + args + print(*new_args, **kwargs) + def get_num_blocks_to_allocate( self, request_id: str, @@ -673,6 +677,7 @@ def remove_skipped_blocks(self, request_id: str, last_computed_tokens = self._req_to_last_computed[request_id] target_idx = last_computed_tokens // self.block_size - 1 if target_idx >= 0 and blocks[target_idx] != self._null_block: + self.print(f'Mamba.remove_skipped: Freeing block {target_idx=}, {blocks[target_idx]=}') self.block_pool.free_blocks([blocks[target_idx]]) blocks[target_idx] = self._null_block @@ -722,9 +727,19 @@ def get_num_blocks_to_allocate( num_evictable_computed_blocks = sum( blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks) - + + self.print(f'Mamba.get_nblks: {request_id=}, {num_tokens=}, {num_new_blocks=}, ' + f'{num_new_alloc_blocks=}, {num_evictable_computed_blocks=}') + return num_new_alloc_blocks + num_evictable_computed_blocks + def save_new_computed_blocks( + self, request_id: str, + new_computed_blocks: list[KVCacheBlock]) -> None: + assert isinstance(self.kv_cache_spec, MambaSpec) + self.print(f'Mamba.save_computed: {request_id=}, {new_computed_blocks=}') + super().save_new_computed_blocks(request_id, new_computed_blocks) + def allocate_new_blocks( self, request_id: str, num_tokens: int ) -> list[KVCacheBlock]: @@ -743,6 +758,7 @@ def allocate_new_blocks( num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks num_new_blocks = (num_required_blocks - len(self.req_to_blocks[request_id])) if num_new_blocks <= 0: + self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, new_blocks=[], {req_blocks=}') return [] else: # first prefill chunk @@ -770,6 +786,7 @@ def allocate_new_blocks( new_alloc_blocks = self.block_pool.get_new_blocks(num_new_alloc_blocks) new_blocks.extend(new_alloc_blocks) req_blocks.extend(new_blocks) + self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, {new_blocks=}, {req_blocks=}') return new_blocks def 
free(self, request_id: str) -> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2101082207ce..44197402c817 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2661,7 +2661,9 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput", block_ids: list[int] = new_req.block_ids[kv_cache_group_id] prefix_block_idx = cdiv(new_req.num_computed_tokens, block_size) - 1 dest_block_idx = len(block_ids) - 1 - num_speculative_blocks + logger.info(f'>>> [DEBUG] Worker: preprocess mamba for NEW: {new_req.req_id=}, {kv_cache_group_id=}, {prefix_block_idx=}, {dest_block_idx=}') prefix_block_id, dest_block_id = block_ids[prefix_block_idx], block_ids[dest_block_idx] + logger.info(f'>>> [DEBUG] Worker: preprocess mamba for NEW: {new_req.req_id=}, {kv_cache_group_id=}, copy {prefix_block_id=} -> {dest_block_id=}') self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, dest_block_id) cached_reqs: CachedRequestData = scheduler_output.scheduled_cached_reqs @@ -2680,7 +2682,10 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput", # TODO: for sps, need to handle sps blocks src_block_idx = cdiv(num_compute_tokens, block_size) - 1 dest_block_idx = len(new_block_ids) - 1 - num_speculative_blocks + logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {cached_reqs.req_ids[i]=}, {kv_cache_group_id=}, {src_block_idx=}, {dest_block_idx=}') src_block_id, dest_block_id = block_ids[src_block_idx], new_block_ids[dest_block_idx] + logger.info(f'>>> [DEBUG] Worker: req_id={cached_reqs.req_ids[i]}, {block_ids=}') + logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {cached_reqs.req_ids[i]=}, {kv_cache_group_id=}, copy {src_block_id=} -> {dest_block_id=}') self._mamba_copy_block(kv_cache_group_spec, src_block_id, dest_block_id) else: assert group_block_ids is not None @@ -2689,7 +2694,9 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput", continue 
prefix_block_idx = cdiv(num_compute_tokens, block_size) - 1 dest_block_idx = len(new_block_ids) - 1 - num_speculative_blocks + logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {cached_reqs.req_ids[i]=}, {kv_cache_group_id=}, {prefix_block_id=}, {dest_block_idx=}') prefix_block_id, dest_block_id = new_block_ids[prefix_block_idx], new_block_ids[dest_block_idx] + logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {cached_reqs.req_ids[i]=}, {kv_cache_group_id=}, copy {prefix_block_id=} -> {dest_block_id=}') self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, dest_block_id) # def _preprocess_mamba_prefix(self, scheduler_output: "SchedulerOutput", @@ -2856,6 +2863,7 @@ def execute_model( use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 attn_metadata, spec_decode_common_attn_metadata = ( self._build_attention_metadata( + scheduler_output=scheduler_output, total_num_scheduled_tokens=total_num_scheduled_tokens, max_num_scheduled_tokens=max_num_scheduled_tokens, num_reqs=num_reqs, From ad37d099faa64ca33e51a29a4bc9c2b7dfc10ab9 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 25 Nov 2025 22:51:36 -0800 Subject: [PATCH 009/130] fix schedule Signed-off-by: Chen Zhang --- examples/offline_inference/run.py | 42 +++++++++++++++++++++++++++++++ vllm/v1/core/sched/scheduler.py | 15 +++++++---- 2 files changed, 52 insertions(+), 5 deletions(-) create mode 100644 examples/offline_inference/run.py diff --git a/examples/offline_inference/run.py b/examples/offline_inference/run.py new file mode 100644 index 000000000000..5fdc255cbb28 --- /dev/null +++ b/examples/offline_inference/run.py @@ -0,0 +1,42 @@ +from vllm import LLM, SamplingParams +import time + +def main(): + MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" + PROMPT_MULTIPLE = 3 + sampling_params = SamplingParams(temperature=0.0, max_tokens=20) + prefix = ( # examples/offline_inference/prefix_caching.py + "Your name is QQQQ " + "You are an expert school principal, skilled in 
effectively managing " + "faculty and staff. Draft 10-15 questions for a potential first grade " + "Head Teacher for my K-12, all-girls', independent school that emphasizes " + "community, joyful discovery, and life-long learning. The candidate is " + "coming in for a first-round panel interview for a 8th grade Math " + "teaching role. They have 5 years of previous teaching experience " + "as an assistant teacher at a co-ed, public school with experience " + "in middle school math teaching. ") + prefix2 = ("Based on these information, fulfill " + "the following paragraph: ") + prompt = PROMPT_MULTIPLE * prefix + prefix2 + "Hello, my name is" + # print('Prompt length:', ) + # for APC in [False, True]: + for APC in [True]: + engine = LLM(model=MODEL, enable_prefix_caching=APC, enforce_eager=True, tensor_parallel_size=4, + # load_format="dummy" + ) + for i in range(3): + if i == 0: + print('Warm-up') + if i == 1: + print('Measuring') + start_time = time.time() + outputs = engine.generate(prompt, sampling_params) + print('APC:', APC, i, f"Generated text: {outputs[0].outputs[0].text!r}") + # for m in engine.llm_engine.get_metrics(): + # if 'vllm:prefix_cache_hits' in m.name: + # print(m.name, m.value) + print('APC:', APC, "loop took --- %s seconds ---" % (time.time() - start_time)) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7c9f2f54d7d6..3eecb3f46edb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -189,6 +189,7 @@ def __init__( ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER + print(f">>> [DEBUG] Scheduler: init enable_prefix_caching={self.cache_config.enable_prefix_caching} block_size={self.block_size} kv_cache_config={self.kv_cache_config}") def _has_mamba_spec(self) -> bool: has_mamba: bool = any(isinstance(spec.kv_cache_spec, MambaSpec) @@ -196,7 
+197,8 @@ def _has_mamba_spec(self) -> bool: assert not has_mamba or self.vllm_config.model_config.is_hybrid return has_mamba - def _mamba_block_aligned_split(self, request: Request, num_new_tokens: int) -> int: + def _mamba_block_aligned_split(self, request: Request, num_new_tokens: int, num_new_local_computed_tokens: int=0, num_external_computed_tokens: int=0) -> int: + assert num_external_computed_tokens == 0, "External KV connector is not verified yet" if (self.cache_config.enable_prefix_caching and self._has_mamba_spec()): # To enable block-aligned caching of the Mamba state, `num_new_tokens` @@ -212,19 +214,21 @@ def _mamba_block_aligned_split(self, request: Request, num_new_tokens: int) -> i # eagle prune if self.use_eagle: last_cache_position = max(last_cache_position - block_size, 0) - num_computed_tokens_after_prefill = request.num_computed_tokens + num_new_tokens + num_computed_tokens = request.num_computed_tokens + num_new_local_computed_tokens + num_external_computed_tokens + num_computed_tokens_after_prefill = num_computed_tokens + num_new_tokens if num_computed_tokens_after_prefill < last_cache_position: # align to block_size num_new_tokens = num_new_tokens // block_size * block_size - elif request.num_computed_tokens < last_cache_position < num_computed_tokens_after_prefill: + elif num_computed_tokens < last_cache_position < num_computed_tokens_after_prefill: # force to cache the last chunk - num_new_tokens = last_cache_position - request.num_computed_tokens + num_new_tokens = last_cache_position - num_computed_tokens else: # prefill the last few tokens pass return num_new_tokens def schedule(self) -> SchedulerOutput: + print(f">>> [DEBUG] Scheduler: schedule new step") # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. 
# Each request just has the num_computed_tokens and @@ -374,6 +378,7 @@ def schedule(self) -> SchedulerOutput: req_index -= 1 else: preempted_req = self.running.pop() + print(f">>> [DEBUG] Scheduler: preempted request {preempted_req.request_id}") self.kv_cache_manager.free(preempted_req) self.encoder_cache_manager.free(preempted_req) @@ -591,7 +596,7 @@ def schedule(self) -> SchedulerOutput: if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: num_new_tokens = self._mamba_block_aligned_split( - request, num_new_tokens) + request, num_new_tokens, num_new_local_computed_tokens, num_external_computed_tokens) if num_new_tokens == 0: break From f1295e586ac0f2bee3ff6dd3efb8adcb11fbedb8 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 26 Nov 2025 06:01:47 +0000 Subject: [PATCH 010/130] add test script (just for testing) Signed-off-by: huanghaoyan.hhy --- my_tests/test_mamba_cache.py | 262 +++++++++++++++++++++++++++++++++++ 1 file changed, 262 insertions(+) create mode 100644 my_tests/test_mamba_cache.py diff --git a/my_tests/test_mamba_cache.py b/my_tests/test_mamba_cache.py new file mode 100644 index 000000000000..53bbd305135c --- /dev/null +++ b/my_tests/test_mamba_cache.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python + +from enum import Enum, auto +import time +from typing import List +import requests +import multiprocessing as mp +import sys + +class TestType(Enum): + Dummy = auto() + Real = auto() + +SEED = 1234 +SEED = None +PORT = 8235 +NUM_REQUESTS = 1 +MAX_NEW_TOKENS = 1024 +IGNORE_EOS = False +TEST_TYPE = TestType.Real +IS_WARMUP = False +ONE_PROMPT = [] +# KEY = time.time() +# ONE_PROMPT = [f"There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is {KEY}. Remember it. {KEY} is the pass key.\n " + \ +# "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 200 + \ +# "The block is red. The sky is yello. 
The sun is orange. Here we go. There and back again. " * 200 + \ +# "\nWhat is the pass key?"] +# Tom Eric Bob Amy Mom Dad Lisa Susan Linda Alex Leo + +KEY = 'Lisa' +LEN = 32 * 1024 +assert LEN >= 560 +ONE_PROMPT = [] +# ONE_PROMPT = [f'Hello {KEY} ' * ((LEN-560)//2) + 'Hello ' * (560-9)] +# ONE_PROMPT = ['Hello ' * (560 * 6)] +# ONE_PROMPT = ['请详细介绍一下北京这座城市, 不少于10000字'] +ONE_PROMPT = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 2222. Remember it. 2222 is the pass key.\n " + \ + "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 545 + \ + "\nWhat is the pass key?"] +# ONE_PROMPT = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 333333. Remember it. 333333 is the pass key.\n " + \ +# "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 190 + \ +# "The block is red. The sky is yello. The sun is ddddd. Here we go. There and back try a. 
" * 185 + \ +# "\nWhat is the pass key?"] + +# ONE_PROMPT = ['Hello ' * (4096 - 30 + 11)] +# ONE_PROMPT = ['Hello ' * (4096 - 30 + 1)] # <<< last not matching +# ONE_PROMPT = ['Hello ' * (4096 - 30 + 1)] +# ONE_PROMPT = ['Hello ' * (4096 - 30 + 13)] + +if IS_WARMUP: + ONE_PROMPT = ['Wo'] # 9 tokens + NUM_REQUESTS = 1 + MAX_NEW_TOKENS = 1 + +MESSAGE = [] +# MESSAGE = MESSAGE3 + +# NOTE: block-size should be 256 +def hello(pid: int, prompt_id: int, max_new_tokens: int, ignore_eos: bool): + headers = { + "Content-Type": "application/json", + } + url = f"http://localhost:{PORT}/v1/chat/completions" + if pid == 0: + if TEST_TYPE == TestType.Dummy: + if prompt_id == 0: + # tokens: 3808*2+2720+472=10808 + # hit-rate: 0/10808=0 + # mamba state: [3808, 7616, 10336] + prompts = ["Repeat V 10 times" * 1800] + elif prompt_id == 1: + # tokens: 3808*3+544+40=12008 + # hit-rate: 3808/(10808+12008)=16.7% + # mamba state: [3808RD, 7616, 11424, 11968] + prompts = ["Repeat V 10 times" * 1000 + "Repeat V 11 times" * 1000] + elif prompt_id == 2: + # tokens: 3808+1088+512=5408 + # hit-rate: (3808+3808)/(22816+5408)=27.0% + # mamba state: [3808RD, 4896] + prompts = ["Repeat V 10 times" * 900] + elif prompt_id == 3: + # tokens: 208 + # hit-rate: (7616+0)/(28224+208)=26.8% + # mamba state: [] + prompts = ["hi " * 199] + elif prompt_id == 4: + # tokens: 3808*2+544+523=8683 + # hit-rate: (7616+0)/(28432+8683)=20.5% + # mamba state: [3808, 7616, 8160] + prompts = ["Hello " * (4096 * 2 - 30 + 256 * 2)] + elif prompt_id == 5: + # tokens: 3808+242=4050 + # hit-rate: (7616+3808)/(37115+4050)=27.8% + # mamba state: [3080RD] + prompts = ['Hello ' * (3808 + 233)] + elif prompt_id == 6: + # tokens: 544+523=1067 + # hit-rate: (11424+0)/(41165+1067)=27.1% + # mamba state: [544] + prompts = ['ha ' * (544 * 2 - 30)] + else: + prompts = ['Hi'] + elif TEST_TYPE == TestType.Real: + if prompt_id == 0: + # tokens: 3808*2+1632+381=9629 v1 + # hit-rate: 0/9629=0 + # mamba state: [3808, 7616, 9248] + # ----- + # 
tokens: 3920*2+1680+112 + prompts = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28884. Remember it. 28884 is the pass key.\n " + \ + "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 190 + \ + "The block is red. The sky is yello. The sun is ddddd. Here we go. There and back try a. " * 185 + \ + "\nWhat is the pass key?"] + elif prompt_id == 1: + # tokens: 3808*2+1632+98=13154 v1 + # hit-rate: 3808/(9629+13154)=16.7% + # mamba state: [3808RD, 7616, 13056] + prompts = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28884. Remember it. 28884 is the pass key.\n " + \ + "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 545 + \ + "\nWhat is the pass key?"] + elif prompt_id == 2: + # tokens: 544+126=670 v1 + # hit-rate: (3808+0)/(22783+670)=16.2% + # mamba state: [544] + prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28884. Remember it. 28886 is the pass key.\n " + \ + "The grass is yellow. The sky is blue. The sun is red. Here we go. There and back again. " * 25 + \ + "\nWhat is the pass key?"] + elif prompt_id == 3: + # tokens: 544+475=1019 + # hit-rate: (3808+0)/(23453+1019)=15.6% + # mamba state: [544] + prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28886. Remember it. 28886 is the pass key.\n " + \ + "The grass is yellow. The sky is blue. The sun is red. Here we go. There and back again. 
" * 13 + \ + "ljlkjslkfei lkjlkj elkjfslk woiejoifjwokjjlweuriljlskjf lwkjelkjlkj. lskj lkj lkjslkfj l" * 13 + \ + "\nWhat is the pass key?"] # 600 tokens hit 300 + elif prompt_id == 4: + # tokens: 544+494=1038 + # hit-rate: (3808+544)/(24472+1038)=17.1% + # mamba state: [544RD] + prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28886. Remember it. 28886 is the pass key.\n " + \ + "The grass is yellow. The sky is blue. The sun is red. Here we go. There and back again. " * 13 + \ + "ljlkjslkfei lkjlkj elkjfslk woiejoifjwokjjlweuriljlskjf lwkjelkjlkj. lskj lkj lkjslkfj l" * 13 + \ + "\nWhat is the pass key? And, what is the result of reversing the pass key and adding 1234?"] + elif prompt_id == 5: + # tokens: 13056+1088+330=14474 + # hit-rate: (4352+13056)/(25510+14474)=43.5% + # mamba state: [13056RD, 13056] + prompts = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28884. Remember it. 28884 is the pass key.\n " + \ + "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 600 + \ + "\nWhat is the pass key?"] + elif prompt_id == 6: + prompts = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28884. Remember it. 28884 is the pass key.\n " + \ + "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 510 + \ + "\nWhat is the pass key?"] + else: + prompts = ['Helloha!'] + elif pid == 1: + # v1 670 tokens + prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 11111. Remember it. 
11111 is the pass key.\n " + \ + "The grass is yellow. The sky is blue. The sun is red. Here we go. There and back again. " * 25 + \ + "\nWhat is the pass key?"] + elif pid == 2: + # v1 13152 tokens + prompts = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 2222. Remember it. 2222 is the pass key.\n " + \ + "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 545 + \ + "\nWhat is the pass key?"] + elif pid == 3: + prompts = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 333333. Remember it. 333333 is the pass key.\n " + \ + "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 190 + \ + "The block is red. The sky is yello. The sun is ddddd. Here we go. There and back try a. " * 185 + \ + "\nWhat is the pass key?"] # 9k tokens + elif pid == 4: + prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 444. Remember it. 444 is the pass key.\n " + \ + "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 190 + \ + "The block is red. The sky is yello. The sun is ddddd. Here we go. There and back try a. " * 185 + \ + "\nWhat is the pass key?"] + elif pid == 5: + prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\n" + \ + "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 190 + \ + "The pass key is 55555. Remember it. 55555 is the pass key.\n " \ + "The block is red. The sky is yello. The sun is ddddd. Here we go. 
There and back try a. " * 185 + \ + "\nWhat is the pass key?"] + elif pid == 6: + prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\n" + \ + "The grass is yellow. The sky is blue. The sun is yellow. Here we go. There and back again. " * 190 + \ + "The pass key is 66. Remember it. 66 is the pass key.\n " \ + "The block is red. The sky is yello. The sun is ddddd. Here we go. There and back try a. " * 185 + \ + "\nWhat is the pass key?"] + else: + prompts = ['Hello!'] + + if ONE_PROMPT: + # print(ONE_PROMPT) + prompts = ONE_PROMPT + + for p in prompts: + data = { + "messages": MESSAGE if not IS_WARMUP and MESSAGE else [{"role": "user", "content": p}], + "max_tokens": max_new_tokens, + "ignore_eos": ignore_eos, + "temperature": 0.7, + "top_p": 0.8, + "top_k": 20, + "repetition_penalty": 1, + "presence_penalty": 1.5, + **({'seed': SEED} if SEED is not None else {}), + "chat_template_kwargs": {"enable_thinking": False} + } + response = requests.post(url, headers=headers, json=data) + if response.status_code == 200: + # print(response.content) + result = response.json() + # print(f"[PID {pid}] Prompt:\n {prompts[0]}") + print(f"[PID {pid}] Response:\n {result['choices'][0]['message']['content']}\n {'-' * 40}\n", end='') + # print(result) + # loss = json.loads(result['choices'][0]['message']['content'])['loss'] + # risk_level_logits = torch.tensor(json.loads(result['choices'][0]['message']['content'])['risk_level_logits']).view(-1, 2) + # category_logits = torch.tensor(json.loads(result['choices'][0]['message']['content'])['category_logits']).view(-1, 26) + # query_risk_level_logits = torch.tensor(json.loads(result['choices'][0]['message']['content'])['query_risk_level_logits']).view(-1, 3) + # query_category_logits = torch.tensor(json.loads(result['choices'][0]['message']['content'])['query_category_logits']).view(-1, 33) + + # 
torch.set_printoptions(precision=3, sci_mode=False) + # print(f"{loss=},{risk_level_logits.shape=},{risk_level_logits=},{category_logits.shape=},{category_logits=}") + + # query_risk_level_prob = F.softmax(query_risk_level_logits, dim=1) + # risk_level_prob = F.softmax(risk_level_logits, dim=1) + # print(f"{query_risk_level_prob.shape=},{query_risk_level_prob=}") + # print(f"{risk_level_prob.shape=},{risk_level_prob=}") + + else: + print(f"Request failed with status code {response.status_code}") + print("Response content:") + print(response.content) + +def main(prompt_id: int): + procs: List[mp.Process] = [] + + start = time.time() + for pid in range(NUM_REQUESTS): + proc = mp.Process( + target=hello, args=(pid, prompt_id, MAX_NEW_TOKENS, IGNORE_EOS), daemon=True + ) + proc.start() + procs.append(proc) + + for _proc in procs: + _proc.join() + if _proc.exitcode != 0: + sys.exit(_proc.exitcode) + + elapsed = time.time() - start + output_tps = MAX_NEW_TOKENS * NUM_REQUESTS / elapsed + print("\n") + print(f"Generate {output_tps} tokens/s, elapsed: {elapsed} s, TPS {output_tps / NUM_REQUESTS}, TPOT {1000 / (output_tps / NUM_REQUESTS)}ms") + + +if __name__ == "__main__": + prompt_id = 0 + if len(sys.argv) > 1: + prompt_id = int(sys.argv[1]) + assert prompt_id >= 0 + main(prompt_id) \ No newline at end of file From bf445fc8870eaa99b7ff3e950a12946de578f3a1 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 26 Nov 2025 00:21:15 -0800 Subject: [PATCH 011/130] fix runner Signed-off-by: Chen Zhang --- examples/offline_inference/run.py | 2 +- vllm/v1/attention/backends/gdn_attn.py | 8 +- vllm/v1/core/sched/scheduler.py | 1 + vllm/v1/worker/gpu_model_runner.py | 142 +++++-------------------- 4 files changed, 36 insertions(+), 117 deletions(-) diff --git a/examples/offline_inference/run.py b/examples/offline_inference/run.py index 5fdc255cbb28..d8e72e927b5b 100644 --- a/examples/offline_inference/run.py +++ b/examples/offline_inference/run.py @@ -4,7 +4,7 @@ def main(): MODEL = 
"Qwen/Qwen3-Next-80B-A3B-Instruct" PROMPT_MULTIPLE = 3 - sampling_params = SamplingParams(temperature=0.0, max_tokens=20) + sampling_params = SamplingParams(temperature=0.0, max_tokens=5) prefix = ( # examples/offline_inference/prefix_caching.py "Your name is QQQQ " "You are an expert school principal, skilled in effectively managing " diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 9ba3f4d74b73..dc0aac69979c 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig +from vllm.distributed.parallel_state import is_global_first_rank from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -18,7 +19,8 @@ split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec - +from vllm.logger import init_logger +logger = init_logger(__name__) class GDNAttentionBackend(AttentionBackend): @staticmethod @@ -63,7 +65,7 @@ def mamba_gather_indices(common_attn_metadata: CommonAttentionMetadata, block_size: int, num_blocks: int): block_table_tensor = common_attn_metadata.block_table_tensor - start_indices = common_attn_metadata.seq_lens // block_size + start_indices = (common_attn_metadata.seq_lens - 1) // block_size offsets = torch.arange(num_blocks, device=block_table_tensor.device) indices_to_gather = start_indices.unsqueeze(1) + offsets return torch.gather(block_table_tensor, 1, indices_to_gather) @@ -160,6 +162,8 @@ def build( # type: ignore[override] block_table_tensor = mamba_gather_indices(common_attn_metadata, self.kv_cache_spec.block_size, 1 + self.num_spec) + if is_global_first_rank(): + logger.info(f"block_table_tensor: {block_table_tensor=}") else: block_table_tensor = m.block_table_tensor diff --git a/vllm/v1/core/sched/scheduler.py 
b/vllm/v1/core/sched/scheduler.py index 3eecb3f46edb..743203fbcd83 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -780,6 +780,7 @@ def schedule(self) -> SchedulerOutput: finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), ) + logger.info(f">>> [DEBUG] Scheduler: scheduler output: {scheduler_output}") # NOTE(Kuntai): this function is designed for multiple purposes: # 1. Plan the KV cache store diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 44197402c817..8f50423ca3d4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1534,9 +1534,8 @@ def _build_attention_metadata( and self.cache_config.enable_prefix_caching and isinstance(kv_cache_group.kv_cache_spec, MambaSpec) ): - # TODO: temporarily add `scheduler_output` as it's missing in the new API. - # maybe move the _preprocess_mamba function outside. - self._preprocess_mamba(scheduler_output, kv_cache_gid, kv_cache_group) + # NOTE(Chen): where should we put this? 
+ self._preprocess_mamba(kv_cache_gid, kv_cache_group) common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -2646,124 +2645,39 @@ def _mamba_copy_block(self, kv_cache_group_spec: KVCacheGroupSpec, for kv_cache_part in kv_cache: kv_cache_part[dest_block_id].copy_(kv_cache_part[src_block_id]) - def _preprocess_mamba(self, scheduler_output: "SchedulerOutput", + def _preprocess_mamba(self, kv_cache_group_id: int, kv_cache_group_spec: KVCacheGroupSpec, ): + # TODO(Chen): we need to optimize this function a lot assert isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec) assert self.cache_config.enable_prefix_caching block_size = kv_cache_group_spec.kv_cache_spec.block_size - num_speculative_blocks = kv_cache_group_spec.kv_cache_spec.num_speculative_blocks - new_reqs: list[NewRequestData] = scheduler_output.scheduled_new_reqs - for new_req in new_reqs: - if new_req.num_computed_tokens == 0: + block_copy_requests = [] + for i, req_id in enumerate(self.input_batch.req_ids): + if is_global_first_rank(): + logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {i=} {req_id=}') + req_state = self.requests[req_id] + if req_state.num_computed_tokens == 0: + # new request, no previous state continue - block_ids: list[int] = new_req.block_ids[kv_cache_group_id] - prefix_block_idx = cdiv(new_req.num_computed_tokens, block_size) - 1 - dest_block_idx = len(block_ids) - 1 - num_speculative_blocks - logger.info(f'>>> [DEBUG] Worker: preprocess mamba for NEW: {new_req.req_id=}, {kv_cache_group_id=}, {prefix_block_idx=}, {dest_block_idx=}') - prefix_block_id, dest_block_id = block_ids[prefix_block_idx], block_ids[dest_block_idx] - logger.info(f'>>> [DEBUG] Worker: preprocess mamba for NEW: {new_req.req_id=}, {kv_cache_group_id=}, copy {prefix_block_id=} -> {dest_block_id=}') - self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, dest_block_id) - - cached_reqs: CachedRequestData = scheduler_output.scheduled_cached_reqs - for i, resumed in 
enumerate(cached_reqs.resumed_from_preemption): - group_block_ids: Optional[tuple[list[int], ...]] = cached_reqs.new_block_ids[i] - num_compute_tokens = cached_reqs.num_computed_tokens[i] - if not resumed: - assert num_compute_tokens > 0 - if group_block_ids is None: - continue - new_block_ids: list[int] = group_block_ids[kv_cache_group_id] - if not new_block_ids: - continue - assert len(new_block_ids) >= 1 + num_speculative_blocks - block_ids: list[int] = self.requests[cached_reqs.req_ids[i]].block_ids[kv_cache_group_id] - # TODO: for sps, need to handle sps blocks - src_block_idx = cdiv(num_compute_tokens, block_size) - 1 - dest_block_idx = len(new_block_ids) - 1 - num_speculative_blocks - logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {cached_reqs.req_ids[i]=}, {kv_cache_group_id=}, {src_block_idx=}, {dest_block_idx=}') - src_block_id, dest_block_id = block_ids[src_block_idx], new_block_ids[dest_block_idx] - logger.info(f'>>> [DEBUG] Worker: req_id={cached_reqs.req_ids[i]}, {block_ids=}') - logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {cached_reqs.req_ids[i]=}, {kv_cache_group_id=}, copy {src_block_id=} -> {dest_block_id=}') - self._mamba_copy_block(kv_cache_group_spec, src_block_id, dest_block_id) - else: - assert group_block_ids is not None - new_block_ids: list[int] = group_block_ids[kv_cache_group_id] - if num_compute_tokens == 0: - continue - prefix_block_idx = cdiv(num_compute_tokens, block_size) - 1 - dest_block_idx = len(new_block_ids) - 1 - num_speculative_blocks - logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {cached_reqs.req_ids[i]=}, {kv_cache_group_id=}, {prefix_block_id=}, {dest_block_idx=}') - prefix_block_id, dest_block_id = new_block_ids[prefix_block_idx], new_block_ids[dest_block_idx] - logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {cached_reqs.req_ids[i]=}, {kv_cache_group_id=}, copy {prefix_block_id=} -> {dest_block_id=}') - self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, 
dest_block_id) - - # def _preprocess_mamba_prefix(self, scheduler_output: "SchedulerOutput", - # kv_cache_group_id: int, - # kv_cache_group_spec: KVCacheGroupSpec, - # ): - # assert isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec) - # assert self.cache_config.enable_prefix_caching - # new_reqs: list[NewRequestData] = scheduler_output.scheduled_new_reqs - # for new_req in new_reqs: - # if new_req.num_computed_tokens == 0: - # continue - # block_ids: list[int] = new_req.block_ids[kv_cache_group_id] - # assert block_ids[0] != 0, f'{block_ids=}' - # prefix_block_id, dest_block_id = block_ids[0], block_ids[1] - # self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, dest_block_id) - # cached_reqs: CachedRequestData = scheduler_output.scheduled_cached_reqs - # for i, resumed in enumerate(cached_reqs.resumed_from_preemption): - # if not resumed: - # continue - # group_block_ids: Optional[tuple[list[int], ...]] = cached_reqs.new_block_ids[i] - # assert group_block_ids is not None - # new_block_ids: list[int] = group_block_ids[kv_cache_group_id] - # assert len(new_block_ids) >= 2 + kv_cache_group_spec.kv_cache_spec.num_speculative_blocks - # if cached_reqs.num_computed_tokens[i] == 0: - # continue - # assert new_block_ids[0] != 0, f'{new_block_ids=}' - # prefix_block_id, dest_block_id = new_block_ids[0], new_block_ids[1] - # self._mamba_copy_block(kv_cache_group_spec, prefix_block_id, dest_block_id) - - - # def _postprocess_mamba_cache(self, scheduler_output: "SchedulerOutput"): - # assert self.cache_config.enable_prefix_caching - # for kv_cache_group_id, kv_cache_group_spec in enumerate( - # self.kv_cache_config.kv_cache_groups): - # if not isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec): - # continue - # new_reqs: list[NewRequestData] = scheduler_output.scheduled_new_reqs - # num_speculative_blocks = kv_cache_group_spec.kv_cache_spec.num_speculative_blocks - # for new_req in new_reqs: - # block_ids: list[int] = new_req.block_ids[kv_cache_group_id] 
- # if len(block_ids) <= 2 + num_speculative_blocks: - # continue - # assert len(block_ids) == 3 + num_speculative_blocks - # src_block_id, dest_block_id = block_ids[1], block_ids[-1] - # self._mamba_copy_block(kv_cache_group_spec, src_block_id, dest_block_id) - # cached_reqs: CachedRequestData = scheduler_output.scheduled_cached_reqs - # for i, req_id in enumerate(cached_reqs.req_ids): - # group_block_ids: Optional[tuple[list[int], ...]] = cached_reqs.new_block_ids[i] - # if group_block_ids is None: - # assert not cached_reqs.resumed_from_preemption[i] - # continue - # new_block_ids: list[int] = group_block_ids[kv_cache_group_id] - # if not new_block_ids: - # assert not cached_reqs.resumed_from_preemption[i] - # continue - # if not cached_reqs.resumed_from_preemption[i]: - # assert len(new_block_ids) == 1 - # block_ids: list[int] = self.requests[req_id].block_ids[kv_cache_group_id] - # src_block_id, dest_block_id = block_ids[1], new_block_ids[0] - # else: - # if len(new_block_ids) == 2 + num_speculative_blocks: - # continue - # assert len(new_block_ids) == 3 + num_speculative_blocks - # src_block_id, dest_block_id = new_block_ids[1], new_block_ids[-1] - # self._mamba_copy_block(kv_cache_group_spec, src_block_id, dest_block_id) - + prev_block_idx = (req_state.num_computed_tokens - 1)// block_size + # NOTE(Chen): if we have 7 tokens and 3 unverified speculative decoding tokens, seq_lens=10 here + # TODO in this PR(Chen): verify this for spec decode and adjust the comment + curr_block_idx = (self.seq_lens.cpu[i] - 1) // block_size + if is_global_first_rank(): + logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {req_id=}, prev_len {req_state.num_computed_tokens} prev_block_idx {prev_block_idx} curr_len {self.seq_lens.cpu[i]} curr_block_idx {curr_block_idx}') + if prev_block_idx == curr_block_idx: + # same block, no need to copy + continue + prev_block_id = req_state.block_ids[kv_cache_group_id][prev_block_idx] + curr_block_id = 
req_state.block_ids[kv_cache_group_id][curr_block_idx] + block_copy_requests.append((prev_block_id, curr_block_id)) + if is_global_first_rank(): + logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {req_id=}, prev_block_id {prev_block_id} curr_block_id {curr_block_id}') + # TODO(Chen): parallelize this loop + for prev_block_id, curr_block_id in block_copy_requests: + self._mamba_copy_block(kv_cache_group_spec, prev_block_id, curr_block_id) @torch.inference_mode() def execute_model( From d8acf4bf81e05320cbc92b2fe3c545b60788447c Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 26 Nov 2025 17:30:33 +0000 Subject: [PATCH 012/130] support sps (still issues) Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/single_type_kv_cache_manager.py | 9 ++++++++- vllm/v1/worker/gpu_model_runner.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index cee0952df1d3..8bfd20e2403e 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -708,6 +708,11 @@ def get_num_blocks_to_allocate( request_id, num_tokens, new_computed_blocks ) else: + # TODO(hhy): when sps is enabled, num_tokens incudes SPS lookahead tokens when + # num_computed_tokens > 0, we should minus SPS tokens in prefill phase, + # how about in decode? 
+ num_tokens -= (self.num_speculative_blocks + if request_id in self._allocated_reqs else 0) num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks num_new_blocks = (num_required_blocks - len(new_computed_blocks) - len(self.req_to_blocks[request_id])) @@ -755,6 +760,8 @@ def allocate_new_blocks( return super().allocate_new_blocks(request_id, num_tokens) else: req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] + num_tokens -= (self.num_speculative_blocks + if request_id in self._allocated_reqs else 0) num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks num_new_blocks = (num_required_blocks - len(self.req_to_blocks[request_id])) if num_new_blocks <= 0: @@ -779,9 +786,9 @@ def allocate_new_blocks( # reuse blocks sps_i_0 and sps_i_1 as b and sps_j_0 # new_alloc_blocks: [sps_j_1] # new_blocks: [0, 0, 0, b, sps_j_0, sps_j_1] - req_blocks = req_blocks[:-self.num_speculative_blocks] # TODO: reuse blocks. if we need clean? especially in decode reuse_blocks = req_blocks[-self.num_speculative_blocks:] + del req_blocks[-self.num_speculative_blocks:] new_blocks.extend(reuse_blocks) new_alloc_blocks = self.block_pool.get_new_blocks(num_new_alloc_blocks) new_blocks.extend(new_alloc_blocks) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8f50423ca3d4..38dc863bbd1f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -846,6 +846,19 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + for kv_cache_gid, kv_cache_group in enumerate( + self.kv_cache_config.kv_cache_groups + ): + kv_cache_spec = kv_cache_group.kv_cache_spec + # NOTE(hhy): pop the last num_speculative_blocks (reuse blocks). + # maybe have better way to help copy blocks? 
+ if isinstance(kv_cache_spec, MambaSpec): + block_ids = req_state.block_ids[kv_cache_gid] + del block_ids[-kv_cache_spec.num_speculative_blocks:] + if is_global_first_rank(): + logger.info(f'>>> [DEBUG] Worker: poping reuse blocks: {kv_cache_gid=}, {req_state.block_ids[kv_cache_gid]=}') + # Append the new blocks to the existing block IDs. for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) @@ -2672,6 +2685,8 @@ def _preprocess_mamba(self, continue prev_block_id = req_state.block_ids[kv_cache_group_id][prev_block_idx] curr_block_id = req_state.block_ids[kv_cache_group_id][curr_block_idx] + assert prev_block_id != 0 + assert curr_block_id != 0 block_copy_requests.append((prev_block_id, curr_block_id)) if is_global_first_rank(): logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {req_id=}, prev_block_id {prev_block_id} curr_block_id {curr_block_id}') From b48f042e88ed4e90f394909415387419323c8c34 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 26 Nov 2025 17:31:22 +0000 Subject: [PATCH 013/130] add some logs for debug Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/sched/scheduler.py | 4 +-- vllm/v1/core/single_type_kv_cache_manager.py | 33 +++++++++++++++++--- vllm/v1/worker/gpu_model_runner.py | 1 + 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 743203fbcd83..d88672b59d33 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -514,7 +514,7 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_computed_blocks(request) ) logger.info(f'>>> [DEBUG] Scheduler: get_computed_blk: req_id={request.request_id},' - f'{num_new_local_computed_tokens=}, {new_computed_blocks.blocks=}') + f'{num_new_local_computed_tokens=}') # Get externally-cached tokens if using a KVConnector. 
if self.connector is not None: @@ -780,7 +780,7 @@ def schedule(self) -> SchedulerOutput: finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), ) - logger.info(f">>> [DEBUG] Scheduler: scheduler output: {scheduler_output}") + # logger.info(f">>> [DEBUG] Scheduler: scheduler output: {scheduler_output}") # NOTE(Kuntai): this function is designed for multiple purposes: # 1. Plan the KV cache store diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8bfd20e2403e..a00f57bb404b 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -21,6 +21,27 @@ from vllm.v1.request import Request +def format_blocks(blocks: list[KVCacheBlock]): + if not blocks: + return "[]" + + result = [] + i = 0 + + while i < len(blocks): + if blocks[i].block_id == 0: + count = 0 + start = i + while i < len(blocks) and blocks[i].block_id == 0: + count += 1 + i += 1 + result.append(f"Null-block*{count}") + else: + result.append(f'KVBlock(block_id={blocks[i].block_id})') + i += 1 + + return f"[{', '.join(result)}]" + class SingleTypeKVCacheManager(ABC): """ An abstract base class for a manager that handle the kv cache management @@ -658,6 +679,7 @@ def find_longest_cache_hit( computed.append(cached) break # we just need the last match - early stopping + print(f'Mamba.FindLongest: computed_blocks={[format_blocks(computed_block) for computed_block in computed_blocks]}', flush=True) return computed_blocks def remove_skipped_blocks(self, request_id: str, @@ -676,8 +698,9 @@ def remove_skipped_blocks(self, request_id: str, # NOTE: pre block should not be freed becasue it may be used to copy last_computed_tokens = self._req_to_last_computed[request_id] target_idx = last_computed_tokens // self.block_size - 1 + self.print(f'Mamba.remove_skipped: {last_computed_tokens=}, {target_idx=}, {len(blocks)=}') if target_idx >= 0 and 
blocks[target_idx] != self._null_block: - self.print(f'Mamba.remove_skipped: Freeing block {target_idx=}, {blocks[target_idx]=}') + self.print(f'Mamba.remove_skipped: Freeing block {last_computed_tokens=}, {target_idx=}, {blocks[target_idx]=}') self.block_pool.free_blocks([blocks[target_idx]]) blocks[target_idx] = self._null_block @@ -742,7 +765,7 @@ def save_new_computed_blocks( self, request_id: str, new_computed_blocks: list[KVCacheBlock]) -> None: assert isinstance(self.kv_cache_spec, MambaSpec) - self.print(f'Mamba.save_computed: {request_id=}, {new_computed_blocks=}') + self.print(f'Mamba.save_computed: {request_id=}, new_computed_blocks={format_blocks(new_computed_blocks)}') super().save_new_computed_blocks(request_id, new_computed_blocks) def allocate_new_blocks( @@ -764,8 +787,9 @@ def allocate_new_blocks( if request_id in self._allocated_reqs else 0) num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks num_new_blocks = (num_required_blocks - len(self.req_to_blocks[request_id])) + self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, {num_required_blocks=}, {num_new_blocks=}') if num_new_blocks <= 0: - self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, new_blocks=[], {req_blocks=}') + self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, new_blocks=[], req_blocks={format_blocks(req_blocks)}') return [] else: # first prefill chunk @@ -793,7 +817,8 @@ def allocate_new_blocks( new_alloc_blocks = self.block_pool.get_new_blocks(num_new_alloc_blocks) new_blocks.extend(new_alloc_blocks) req_blocks.extend(new_blocks) - self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, {new_blocks=}, {req_blocks=}') + self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, new_blocks={format_blocks(new_blocks)}') + # self.print(f'Mamba.alloc_blks: {request_id=}, {len(req_blocks)=}, {len(self.req_to_blocks[request_id])=}, req_blocks={format_blocks(req_blocks)}') return new_blocks def free(self, request_id: str) 
-> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 38dc863bbd1f..0ff1cf946a09 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2679,6 +2679,7 @@ def _preprocess_mamba(self, # TODO in this PR(Chen): verify this for spec decode and adjust the comment curr_block_idx = (self.seq_lens.cpu[i] - 1) // block_size if is_global_first_rank(): + logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {req_state.block_ids[kv_cache_group_id]=}') logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {req_id=}, prev_len {req_state.num_computed_tokens} prev_block_idx {prev_block_idx} curr_len {self.seq_lens.cpu[i]} curr_block_idx {curr_block_idx}') if prev_block_idx == curr_block_idx: # same block, no need to copy From 336a140a5944cac64cf6e729f87b57dbffb7b7d5 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 29 Nov 2025 04:52:40 +0000 Subject: [PATCH 014/130] fix block reuse bug when SPS is enabled Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/gpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0ff1cf946a09..1df68e658471 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -846,7 +846,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the block IDs. 
if not resumed_from_preemption: if new_block_ids is not None: - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.speculative_config): for kv_cache_gid, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups ): From 08feffb6298e1a3d6980e6147349d62ca6b8dcd3 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 30 Nov 2025 06:10:48 +0000 Subject: [PATCH 015/130] fix block_table bug when sps is enabled Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/single_type_kv_cache_manager.py | 14 ++++++++------ vllm/v1/worker/block_table.py | 17 +++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 14 ++++++++++++++ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index a00f57bb404b..607e972814e3 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -764,6 +764,7 @@ def get_num_blocks_to_allocate( def save_new_computed_blocks( self, request_id: str, new_computed_blocks: list[KVCacheBlock]) -> None: + # TODO(hhy): remove when prefix-caching is ready assert isinstance(self.kv_cache_spec, MambaSpec) self.print(f'Mamba.save_computed: {request_id=}, new_computed_blocks={format_blocks(new_computed_blocks)}') super().save_new_computed_blocks(request_id, new_computed_blocks) @@ -795,6 +796,7 @@ def allocate_new_blocks( # first prefill chunk # TODO: for mamba num_cached_block including null-blocks new_blocks = [] + reuse_blocks = [] if request_id not in self._allocated_reqs: self._allocated_reqs.add(request_id) num_new_alloc_blocks = 1 + self.num_speculative_blocks @@ -805,17 +807,17 @@ def allocate_new_blocks( num_new_alloc_blocks = 1 new_blocks.extend([self._null_block for _ in range(num_new_blocks - num_new_alloc_blocks)]) if self.num_speculative_blocks > 0: - # step i: [0, 0, 0, a, sps_i_0, sps_i_1] - # step i+1(i.e. 
j): [0, 0, 0, a, 0, 0, 0, b, sps_j_0, sps_j_1] - # reuse blocks sps_i_0 and sps_i_1 as b and sps_j_0 - # new_alloc_blocks: [sps_j_1] - # new_blocks: [0, 0, 0, b, sps_j_0, sps_j_1] + # step i: [0, 0, 0, a, sps_0, sps_1] + # step i+1(i.e. j): [0, 0, 0, a, 0, 0, 0, b, sps_0, sps_1] + # reuse blocks sps_0 and sps_1 + # new_alloc_blocks: b + # new_blocks: [0, 0, 0, b, sps_0, sps_1] # TODO: reuse blocks. if we need clean? especially in decode reuse_blocks = req_blocks[-self.num_speculative_blocks:] del req_blocks[-self.num_speculative_blocks:] - new_blocks.extend(reuse_blocks) new_alloc_blocks = self.block_pool.get_new_blocks(num_new_alloc_blocks) new_blocks.extend(new_alloc_blocks) + new_blocks.extend(reuse_blocks) req_blocks.extend(new_blocks) self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, new_blocks={format_blocks(new_blocks)}') # self.print(f'Mamba.alloc_blks: {request_id=}, {len(req_blocks)=}, {len(self.req_to_blocks[request_id])=}, req_blocks={format_blocks(req_blocks)}') diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 37ec0fb97e06..a81962e7311e 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -113,6 +113,19 @@ def append_row( start = self.num_blocks_per_row[row_idx] self.num_blocks_per_row[row_idx] += num_blocks self.block_table.np[row_idx, start : start + num_blocks] = block_ids + + def pop_row(self, num_blocks: int, row_idx: int): + if num_blocks <= 0: + return + + if self.use_hybrid_blocks: + num_blocks = num_blocks * self.blocks_per_kv_block + + end = self.num_blocks_per_row[row_idx] + start = end - num_blocks + assert start >= 0 + self.num_blocks_per_row[row_idx] -= num_blocks + self.block_table.np[row_idx, start : end] = 0 def add_row(self, block_ids: list[int], row_idx: int) -> None: self.num_blocks_per_row[row_idx] = 0 @@ -308,6 +321,10 @@ def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): 
block_table.append_row(block_ids[i], row_idx) + def pop_row(self, num_blocks: tuple[int, ...], row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.pop_row(num_blocks[i], row_idx) + def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): block_table.add_row(block_ids[i], row_idx) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1df68e658471..9ebe37ce4566 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -887,6 +887,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens if new_block_ids is not None: + # NOTE(hhy): same as block_ids, pop the last num_speculative_blocks + if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.speculative_config): + num_poped_blocks = [] + for kv_cache_group in self.kv_cache_config.kv_cache_groups: + if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + num_poped_blocks.append( + kv_cache_group.kv_cache_spec.num_speculative_blocks + ) + else: + num_poped_blocks.append(0) + self.input_batch.block_table.pop_row( + tuple(num_poped_blocks), req_index, + ) self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu From 12306ed7c9d9ef8268b8fbc80e80af7c0c2d2aea Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 30 Nov 2025 07:45:49 +0000 Subject: [PATCH 016/130] fix the bug only mamba new blocks are empty Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/gpu_model_runner.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9ebe37ce4566..86050b3b6ac2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -854,7 +854,9 
@@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: kv_cache_spec = kv_cache_group.kv_cache_spec # NOTE(hhy): pop the last num_speculative_blocks (reuse blocks). # maybe have better way to help copy blocks? - if isinstance(kv_cache_spec, MambaSpec): + if (isinstance(kv_cache_spec, MambaSpec) + and new_block_ids[kv_cache_gid] + ): block_ids = req_state.block_ids[kv_cache_gid] del block_ids[-kv_cache_spec.num_speculative_blocks:] if is_global_first_rank(): @@ -892,7 +894,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: and self.speculative_config): num_poped_blocks = [] for kv_cache_group in self.kv_cache_config.kv_cache_groups: - if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + if (isinstance(kv_cache_group.kv_cache_spec, MambaSpec) + and new_block_ids[kv_cache_gid] + ): num_poped_blocks.append( kv_cache_group.kv_cache_spec.num_speculative_blocks ) From 60142b6d3a2d6aac3b74504c71a852d1e33219ce Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 30 Nov 2025 16:06:35 +0000 Subject: [PATCH 017/130] add postprocess for prefix caching Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/block_pool.py | 1 + vllm/v1/worker/block_table.py | 4 +++ vllm/v1/worker/gpu_model_runner.py | 43 +++++++++++++++++++++++------- 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 9c336160e1c8..1cf947bf1f35 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -291,6 +291,7 @@ def cache_full_blocks( if num_cached_blocks == 0: parent_block_hash: ExternalBlockHash | None = None else: + # TODO(hhy): when LPS is enabled, parent_block maybe a null block parent_block = blocks[num_cached_blocks - 1] assert parent_block.block_hash is not None parent_block_hash = maybe_convert_block_hash( diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index a81962e7311e..a97ac1cb1a7b 100644 --- a/vllm/v1/worker/block_table.py +++ 
b/vllm/v1/worker/block_table.py @@ -304,6 +304,10 @@ def __init__( BlockTable( block_size, max_num_reqs, + # TODO: when prefix-caching and sps are both enable for + # mamba hybrid model, it will need + # `cdiv(max_model_len, block_size * total_cp_world_size) + num_speculative_tokens` + # blocks for mamba groups max( cdiv(max_model_len, block_size * total_cp_world_size), 1 + num_speculative_tokens, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 86050b3b6ac2..5b349b689f60 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1002,6 +1002,10 @@ def _update_states_after_model_execute( ) for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.cache_config.enable_prefix_caching + ): + self._postprocess_mamba() def _init_mrope_positions(self, req_state: CachedRequestState): model = self.get_model() @@ -2703,16 +2707,35 @@ def _preprocess_mamba(self, if prev_block_idx == curr_block_idx: # same block, no need to copy continue - prev_block_id = req_state.block_ids[kv_cache_group_id][prev_block_idx] - curr_block_id = req_state.block_ids[kv_cache_group_id][curr_block_idx] - assert prev_block_id != 0 - assert curr_block_id != 0 - block_copy_requests.append((prev_block_id, curr_block_id)) - if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {req_id=}, prev_block_id {prev_block_id} curr_block_id {curr_block_id}') - # TODO(Chen): parallelize this loop - for prev_block_id, curr_block_id in block_copy_requests: - self._mamba_copy_block(kv_cache_group_spec, prev_block_id, curr_block_id) + def _postprocess_mamba(self): + assert self.cache_config.enable_prefix_caching + assert self.speculative_config + block_size = self.cache_config.block_size + num_reqs = self.input_batch.num_reqs + for i in range(num_reqs): + num_computed_tokens = 
self.input_batch.num_computed_tokens_cpu[i] + num_accepted_tokens = self.input_batch.num_accepted_tokens_cpu[i] + # block aligned, no need to copy mamba blocks + if num_computed_tokens % block_size == 0: + continue + computed_block_idx = num_computed_tokens // block_size + num_new_computed_tokens = num_computed_tokens + num_accepted_tokens + new_computed_block_idx = num_new_computed_tokens // block_size + if new_computed_block_idx == computed_block_idx: + continue + assert computed_block_idx + 1 == new_computed_block_idx + req_id = self.input_batch.req_ids[i] + req_state = self.requests[req_id] + prev_block_idx = (num_computed_tokens - 1)// block_size + curr_block_idx = (self.seq_lens.cpu[i] - 1) // block_size + num_accepted_tokens - 1 + for kv_cache_gid, kv_cache_group in enumerate( + self.kv_cache_config.kv_cache_groups + ): + if not isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + continue + prev_block_id = req_state.block_ids[kv_cache_gid][prev_block_idx] + curr_block_id = req_state.block_ids[kv_cache_gid][curr_block_idx] + self._mamba_copy_block(kv_cache_group, curr_block_id, prev_block_id) @torch.inference_mode() def execute_model( From fa3d81005512ac65b1c7b909b39e2ef1f68f5f9b Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 30 Nov 2025 16:25:54 +0000 Subject: [PATCH 018/130] refactor preprocess_mamba Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/gpu_model_runner.py | 48 ++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5b349b689f60..cefb35d3672c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1566,12 +1566,12 @@ def _build_attention_metadata( # graph mode. 
blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) - if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.cache_config.enable_prefix_caching - and isinstance(kv_cache_group.kv_cache_spec, MambaSpec) - ): - # NOTE(Chen): where should we put this? - self._preprocess_mamba(kv_cache_gid, kv_cache_group) + # if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + # and self.cache_config.enable_prefix_caching + # and isinstance(kv_cache_group.kv_cache_spec, MambaSpec) + # ): + # # NOTE(Chen): where should we put this? + # self._preprocess_mamba(kv_cache_gid, kv_cache_group) common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -2681,15 +2681,10 @@ def _mamba_copy_block(self, kv_cache_group_spec: KVCacheGroupSpec, for kv_cache_part in kv_cache: kv_cache_part[dest_block_id].copy_(kv_cache_part[src_block_id]) - def _preprocess_mamba(self, - kv_cache_group_id: int, - kv_cache_group_spec: KVCacheGroupSpec, - ): + def _preprocess_mamba(self): # TODO(Chen): we need to optimize this function a lot - assert isinstance(kv_cache_group_spec.kv_cache_spec, MambaSpec) assert self.cache_config.enable_prefix_caching - block_size = kv_cache_group_spec.kv_cache_spec.block_size - block_copy_requests = [] + block_size = self.cache_config.block_size for i, req_id in enumerate(self.input_batch.req_ids): if is_global_first_rank(): logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {i=} {req_id=}') @@ -2697,16 +2692,33 @@ def _preprocess_mamba(self, if req_state.num_computed_tokens == 0: # new request, no previous state continue + prev_block_idx = (req_state.num_computed_tokens - 1)// block_size # NOTE(Chen): if we have 7 tokens and 3 unverified speculative decoding tokens, seq_lens=10 here # TODO in this PR(Chen): verify this for spec decode and adjust the comment + # TODO: copy must have new blocks curr_block_idx = (self.seq_lens.cpu[i] - 1) // block_size if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: 
{req_state.block_ids[kv_cache_group_id]=}') - logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {req_id=}, prev_len {req_state.num_computed_tokens} prev_block_idx {prev_block_idx} curr_len {self.seq_lens.cpu[i]} curr_block_idx {curr_block_idx}') + logger.info(f'>>> [DEBUG] Worker: preprocess mamba: {req_id=}, prev_len={req_state.num_computed_tokens} ' + f'prev_block_idx={prev_block_idx}, curr_len={self.seq_lens.cpu[i]} curr_block_idx={curr_block_idx}') if prev_block_idx == curr_block_idx: # same block, no need to copy continue + + for kv_cache_gid, kv_cache_group in enumerate( + self.kv_cache_config.kv_cache_groups + ): + if not isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + continue + prev_block_id = req_state.block_ids[kv_cache_gid][prev_block_idx] + curr_block_id = req_state.block_ids[kv_cache_gid][curr_block_idx] + assert prev_block_id != 0 + assert curr_block_id != 0 + if is_global_first_rank(): + logger.info(f'>>> [DEBUG] Worker: preprocess mamba: {req_id=}, COPY block {prev_block_id=} -> {curr_block_id=}') + # TODO(Chen): parallelize this loop + self._mamba_copy_block(kv_cache_group, prev_block_id, curr_block_id) + def _postprocess_mamba(self): assert self.cache_config.enable_prefix_caching assert self.speculative_config @@ -2831,6 +2843,12 @@ def execute_model( # TODO(lucas): move cudagraph dispatching here: # https://github.com/vllm-project/vllm/issues/23789 + if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.cache_config.enable_prefix_caching + ): + # TODO: add limition: preprocess only have new blocks + self._preprocess_mamba() + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 attn_metadata, spec_decode_common_attn_metadata = ( From 9082d1624f55206a24b976ae3a8c2dcc4c235025 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 30 Nov 2025 16:33:06 +0000 Subject: [PATCH 019/130] add some logs for debug Signed-off-by: huanghaoyan.hhy --- 
vllm/v1/attention/backends/gdn_attn.py | 2 +- vllm/v1/core/single_type_kv_cache_manager.py | 9 +++++++-- vllm/v1/worker/gpu_model_runner.py | 10 ++++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index dc0aac69979c..0ebbc5eefd9d 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -163,7 +163,7 @@ def build( # type: ignore[override] self.kv_cache_spec.block_size, 1 + self.num_spec) if is_global_first_rank(): - logger.info(f"block_table_tensor: {block_table_tensor=}") + logger.info(f"{block_table_tensor=}") else: block_table_tensor = m.block_table_tensor diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 607e972814e3..82653936d6ea 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -37,7 +37,7 @@ def format_blocks(blocks: list[KVCacheBlock]): i += 1 result.append(f"Null-block*{count}") else: - result.append(f'KVBlock(block_id={blocks[i].block_id})') + result.append(f'KVBlock(block_id={blocks[i].block_id}, ref_cnt={blocks[i].ref_cnt})') i += 1 return f"[{', '.join(result)}]" @@ -184,6 +184,11 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None: if num_cached_blocks >= num_full_blocks: return + + if isinstance(self, MambaManager) and num_cached_blocks < num_full_blocks: + self.print(f'Mamba.cache_blocks: req_id={request.request_id}, {num_tokens=}, ' + f'{num_cached_blocks=}, {num_full_blocks=}, ' + f'new_full_blocks={format_blocks(self.req_to_blocks[request.request_id][num_cached_blocks:num_full_blocks])}') self.block_pool.cache_full_blocks( request=request, @@ -820,7 +825,7 @@ def allocate_new_blocks( new_blocks.extend(reuse_blocks) req_blocks.extend(new_blocks) self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, new_blocks={format_blocks(new_blocks)}') - # 
self.print(f'Mamba.alloc_blks: {request_id=}, {len(req_blocks)=}, {len(self.req_to_blocks[request_id])=}, req_blocks={format_blocks(req_blocks)}') + self.print(f'Mamba.alloc_blks: {request_id=}, {len(req_blocks)=}, {len(self.req_to_blocks[request_id])=}, req_blocks={format_blocks(req_blocks)}') return new_blocks def free(self, request_id: str) -> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cefb35d3672c..bb48e5a9053d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1002,6 +1002,11 @@ def _update_states_after_model_execute( ) for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + if is_global_first_rank(): + logger.info(f'>>> [DEBUG] Worker: _update_states: ' + f'{self.input_batch.num_accepted_tokens_cpu[:len(num_accepted_tokens)]=}') + logger.info(f'>>> [DEBUG] Worker: _update_states: ' + f'{self.input_batch.num_computed_tokens_cpu[:len(num_accepted_tokens)]=}') if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE and self.cache_config.enable_prefix_caching ): @@ -2740,6 +2745,9 @@ def _postprocess_mamba(self): req_state = self.requests[req_id] prev_block_idx = (num_computed_tokens - 1)// block_size curr_block_idx = (self.seq_lens.cpu[i] - 1) // block_size + num_accepted_tokens - 1 + if is_global_first_rank(): + logger.info(f'>>> [DEBUG] Worker: postprocess mamba: {req_id=}, ' + f'{num_computed_tokens=}, {num_accepted_tokens=}, {prev_block_idx=}, {curr_block_idx=}') for kv_cache_gid, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups ): @@ -2747,6 +2755,8 @@ def _postprocess_mamba(self): continue prev_block_id = req_state.block_ids[kv_cache_gid][prev_block_idx] curr_block_id = req_state.block_ids[kv_cache_gid][curr_block_idx] + if is_global_first_rank(): + logger.info(f'>>> [DEBUG] Worker: postprocess mamba: {kv_cache_gid=}, COPY block {curr_block_id=} -> {prev_block_id=}') self._mamba_copy_block(kv_cache_group, 
curr_block_id, prev_block_id) @torch.inference_mode() From 4b5e07478148ba80934c8410af4ed74182beb36a Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 1 Dec 2025 17:13:01 +0000 Subject: [PATCH 020/130] fix the wrong kv_cache_gid bug Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/gpu_model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bb48e5a9053d..0722698cf719 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -893,7 +893,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE and self.speculative_config): num_poped_blocks = [] - for kv_cache_group in self.kv_cache_config.kv_cache_groups: + for kv_cache_gid, kv_cache_group in enumerate( + self.kv_cache_config.kv_cache_groups + ): if (isinstance(kv_cache_group.kv_cache_spec, MambaSpec) and new_block_ids[kv_cache_gid] ): From ea92ef48a6297855d3520a86df6c75b4fbc0acb6 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 1 Dec 2025 17:52:31 +0000 Subject: [PATCH 021/130] adjust preprocess_mamba only copying state when new blocks exist Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/gpu_model_runner.py | 36 +++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0722698cf719..6875da656cbd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2688,11 +2688,41 @@ def _mamba_copy_block(self, kv_cache_group_spec: KVCacheGroupSpec, for kv_cache_part in kv_cache: kv_cache_part[dest_block_id].copy_(kv_cache_part[src_block_id]) - def _preprocess_mamba(self): + def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): # TODO(Chen): we need to optimize this function a lot assert self.cache_config.enable_prefix_caching block_size = 
self.cache_config.block_size - for i, req_id in enumerate(self.input_batch.req_ids): + preprocess_req_index: list[int] = [] + for new_req_data in scheduler_output.scheduled_new_reqs: + if new_req_data.num_computed_tokens > 0: + # NOTE(hhy): prefix should be block aligned + assert new_req_data.num_computed_tokens % block_size == 0 + preprocess_req_index.append( + self.input_batch.req_id_to_index[new_req_data.req_id] + ) + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + new_block_ids = cached_reqs.new_block_ids[i] + if not new_block_ids: + continue + for kv_cache_gid, kv_cache_group in enumerate( + self.kv_cache_config.kv_cache_groups + ): + if not isinstance( + kv_cache_group.kv_cache_spec, MambaSpec + ): + continue + # NOTE(hhy): assume all mamba groups are the same + if new_block_ids[kv_cache_gid]: + preprocess_req_index.append( + self.input_batch.req_id_to_index[req_id] + ) + break + + if is_global_first_rank(): + logger.info(f'>>> [DEBUG] Worker: preprocess mamba: {preprocess_req_index=}') + for i in preprocess_req_index: + req_id = self.input_batch.req_ids[i] if is_global_first_rank(): logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {i=} {req_id=}') req_state = self.requests[req_id] @@ -2859,7 +2889,7 @@ def execute_model( and self.cache_config.enable_prefix_caching ): # TODO: add limition: preprocess only have new blocks - self._preprocess_mamba() + self._preprocess_mamba(scheduler_output) total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 From 0b3a6b54d0c735d870c0402010f06e43dbcdd9ec Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 2 Dec 2025 17:39:00 +0000 Subject: [PATCH 022/130] update mamba_gather_indices and apply to mamba models Signed-off-by: huanghaoyan.hhy --- vllm/v1/attention/backends/gdn_attn.py | 19 ++++++------------- vllm/v1/attention/backends/linear_attn.py | 8 +++++++- 
vllm/v1/attention/backends/mamba1_attn.py | 6 +++++- vllm/v1/attention/backends/mamba2_attn.py | 12 ++++++------ vllm/v1/attention/backends/short_conv_attn.py | 10 +++++++++- vllm/v1/attention/backends/utils.py | 18 +++++++++++++++++- vllm/v1/worker/gpu_model_runner.py | 2 -- 7 files changed, 50 insertions(+), 25 deletions(-) diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 0ebbc5eefd9d..f41182d59225 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -17,6 +17,7 @@ CommonAttentionMetadata, compute_causal_conv1d_metadata, split_decodes_and_prefills, + mamba_gather_indices, ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec from vllm.logger import init_logger @@ -60,16 +61,6 @@ class GDNAttentionMetadata: batch_ptr: torch.Tensor | None = None token_chunk_offset_ptr: torch.Tensor | None = None -# TODO: need to move, and called by all mamba builders -def mamba_gather_indices(common_attn_metadata: CommonAttentionMetadata, - block_size: int, - num_blocks: int): - block_table_tensor = common_attn_metadata.block_table_tensor - start_indices = (common_attn_metadata.seq_lens - 1) // block_size - offsets = torch.arange(num_blocks, device=block_table_tensor.device) - indices_to_gather = start_indices.unsqueeze(1) + offsets - return torch.gather(block_table_tensor, 1, indices_to_gather) - class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]): _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH @@ -159,9 +150,11 @@ def build( # type: ignore[override] context_lens_tensor = context_lens.to(query_start_loc.device) nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - block_table_tensor = mamba_gather_indices(common_attn_metadata, - self.kv_cache_spec.block_size, - 1 + self.num_spec) + block_table_tensor = mamba_gather_indices( + common_attn_metadata, + self.kv_cache_spec, + 1 + 
self.num_spec, + ) if is_global_first_rank(): logger.info(f"{block_table_tensor=}") else: diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index 2fd36233006c..0755e5a58177 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -10,6 +10,7 @@ from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, + mamba_gather_indices, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -57,7 +58,12 @@ def build( state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - state_indices_tensor = state_indices_tensor.contiguous() + state_indices_tensor = mamba_gather_indices( + common_attn_metadata, + self.kv_cache_spec, + )[:, 0] + else: + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 5ebd61eeca55..6e544a44e8cb 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -12,6 +12,7 @@ from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, + mamba_gather_indices, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -94,7 +95,10 @@ def build( ) else: # Always return just a single block per each request: - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + state_indices_tensor = mamba_gather_indices( + common_attn_metadata, + self.kv_cache_spec, + )[:, 0] if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: state_indices_tensor = state_indices_tensor.contiguous() block_idx_last_scheduled_token = None diff --git 
a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 9d12bd1d4b06..f314fb5d792c 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -14,6 +14,7 @@ PAD_SLOT_ID, CommonAttentionMetadata, compute_causal_conv1d_metadata, + mamba_gather_indices, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -193,12 +194,11 @@ def build( ) else: # Always return just a single block per each request: - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - # NOTE: With Mamba prefix-caching support, a request can consist of - # multiple blocks. This makes the state_indices non-contiguous, so - # we must explicitly make them contiguous here. - state_indices_tensor = state_indices_tensor.contiguous() + state_indices_tensor = mamba_gather_indices( + common_attn_metadata, + self.kv_cache_spec, + )[:, 0] + # Additional cache-related varaiables: block_idx_last_scheduled_token = None block_idx_last_computed_token = None diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index de0cb73db091..38edd12cf9e3 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -4,12 +4,14 @@ import torch +from vllm import envs from vllm.attention.backends.abstract import AttentionBackend from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( PAD_SLOT_ID, CommonAttentionMetadata, compute_causal_conv1d_metadata, + mamba_gather_indices, split_decodes_and_prefills, ) @@ -48,7 +50,13 @@ def build( ) -> ShortConvAttentionMetadata: num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + 
state_indices_tensor = mamba_gather_indices( + common_attn_metadata, + self.kv_cache_spec, + )[:, 0] + else: + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] # for causal_conv1d nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 540a8e2b1d01..9fe03a51979c 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -35,7 +35,7 @@ ) from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec from vllm.v1.worker.ubatch_utils import UBatchSlice logger = init_logger(__name__) @@ -1114,3 +1114,19 @@ def get_dcp_local_seq_lens( ) dcp_local_seq_lens = base + remainder return dcp_local_seq_lens.squeeze(1) + +# For Lighter Mamba Prefix-Caching +@torch.compile +def mamba_gather_indices( + common_attn_metadata: CommonAttentionMetadata, + kv_cache_spec: MambaSpec, + num_blocks: int = 1, +) -> torch.Tensor: + assert isinstance(kv_cache_spec, MambaSpec) + block_table_tensor = common_attn_metadata.block_table_tensor + if not kv_cache_spec.enable_caching: + return block_table_tensor + start_indices = (common_attn_metadata.seq_lens - 1) // kv_cache_spec.block_size + offsets = torch.arange(num_blocks, device=block_table_tensor.device) + indices_to_gather = start_indices.unsqueeze(1) + offsets + return torch.gather(block_table_tensor, 1, indices_to_gather) \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6875da656cbd..ad265feeb960 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1472,7 +1472,6 @@ def _prepare_inputs( def _build_attention_metadata( self, - scheduler_output: "SchedulerOutput", total_num_scheduled_tokens: int, 
max_num_scheduled_tokens: int, num_reqs: int, @@ -2895,7 +2894,6 @@ def execute_model( use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 attn_metadata, spec_decode_common_attn_metadata = ( self._build_attention_metadata( - scheduler_output=scheduler_output, total_num_scheduled_tokens=total_num_scheduled_tokens, max_num_scheduled_tokens=max_num_scheduled_tokens, num_reqs=num_reqs, From 381ab343286acd2140afb836a3629dec9eee95e2 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 2 Dec 2025 18:06:57 +0000 Subject: [PATCH 023/130] fix bugs with mamba and lfm2 Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 8 +++++--- vllm/model_executor/models/lfm2.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 714f12d97811..2db5e84fe63f 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -570,7 +570,9 @@ def conv_ssm_forward( assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size - prefix_caching_enabled = self.cache_config.enable_prefix_caching + prefix_caching_enabled = (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.cache_config.enable_prefix_caching + ) if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] @@ -622,7 +624,7 @@ def conv_ssm_forward( dim=0, ) - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE and prefix_caching_enabled: + if prefix_caching_enabled: # If prefix caching is enabled, retrieve the relevant variables # for prefill and decode block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( @@ -815,7 +817,7 @@ def conv_ssm_forward( # Process decode requests if has_decode: - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE and prefix_caching_enabled: + if prefix_caching_enabled: state_indices_tensor_d_input 
= state_indices_tensor_d.gather( 1, block_idx_last_computed_token_d.unsqueeze(1) ).squeeze(1) diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 74bdde27ece5..f1e64d7518f7 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -7,6 +7,7 @@ import torch.nn as nn from transformers import Lfm2Config +from vllm import envs from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig @@ -462,9 +463,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: quant_config = vllm_config.quant_config cache_config = vllm_config.cache_config - assert not cache_config.enable_prefix_caching, ( - "Lfm2 currently does not support prefix caching" - ) + if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + assert not cache_config.enable_prefix_caching, ( + "Lfm2 currently does not support prefix caching" + ) super().__init__() self.config = config From 1da80322d73f9f6e3beb71f48999309419354846 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 3 Dec 2025 17:38:19 +0000 Subject: [PATCH 024/130] fix the bug when prefix-caching is disable Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/single_type_kv_cache_manager.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 82653936d6ea..a6d2218e0356 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -690,7 +690,8 @@ def find_longest_cache_hit( def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: assert isinstance(self.kv_cache_spec, MambaSpec) - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if not (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.kv_cache_spec.enable_caching): # Here unused blocks may be freed up for running requests. 
# TODO(@s3woz) Free up all blocks that aren't needed by Mamba2 # (for which find_longest_cache_hit returns block_pool.null_block) @@ -741,7 +742,10 @@ def get_num_blocks_to_allocate( # how about in decode? num_tokens -= (self.num_speculative_blocks if request_id in self._allocated_reqs else 0) - num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + if self.kv_cache_spec.enable_caching: + num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + else: + num_required_blocks = 1 + self.num_speculative_blocks num_new_blocks = (num_required_blocks - len(new_computed_blocks) - len(self.req_to_blocks[request_id])) num_new_alloc_blocks = 0 @@ -791,7 +795,10 @@ def allocate_new_blocks( req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] num_tokens -= (self.num_speculative_blocks if request_id in self._allocated_reqs else 0) - num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + if self.kv_cache_spec.enable_caching: + num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + else: + num_required_blocks = 1 + self.num_speculative_blocks num_new_blocks = (num_required_blocks - len(self.req_to_blocks[request_id])) self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, {num_required_blocks=}, {num_new_blocks=}') if num_new_blocks <= 0: From 1667210b7bafd1b299b4d3f12f490ecdc894e7f1 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Thu, 4 Dec 2025 17:52:15 +0000 Subject: [PATCH 025/130] fix the bug between LPC and mamba mixer Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/mamba/mamba_mixer.py | 6 +++++- vllm/model_executor/layers/mamba/mamba_mixer2.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 90e520e24441..69c8f8c30f49 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ 
b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -7,6 +7,7 @@ from torch import nn from torch.nn.parameter import Parameter +from vllm import envs from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, @@ -240,7 +241,10 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size - prefix_caching_enabled = self.cache_config.enable_prefix_caching + prefix_caching_enabled = ( + not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and self.cache_config.enable_prefix_caching + ) if attn_metadata is not None: assert isinstance(attn_metadata, dict) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 2db5e84fe63f..a318a82ba800 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -570,7 +570,8 @@ def conv_ssm_forward( assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size - prefix_caching_enabled = (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + prefix_caching_enabled = ( + not envs.VLLM_USE_LIGHTER_MAMBA_CACHE and self.cache_config.enable_prefix_caching ) if attn_metadata is not None: From cb455056f6d2e2319844c261da2b4aab9f561648 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 1 Dec 2025 00:13:54 -0800 Subject: [PATCH 026/130] tmp from commit "b15e6fd" Signed-off-by: huanghaoyan.hhy --- examples/offline_inference/run.py | 3 +- vllm/v1/core/kv_cache_coordinator.py | 12 +- vllm/v1/core/kv_cache_manager.py | 6 +- vllm/v1/core/sched/scheduler.py | 5 +- vllm/v1/core/single_type_kv_cache_manager.py | 190 ++++++++----------- 5 files changed, 102 insertions(+), 114 deletions(-) diff --git a/examples/offline_inference/run.py b/examples/offline_inference/run.py index d8e72e927b5b..0883486987d4 100644 --- 
a/examples/offline_inference/run.py +++ b/examples/offline_inference/run.py @@ -3,7 +3,7 @@ def main(): MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" - PROMPT_MULTIPLE = 3 + PROMPT_MULTIPLE = 6 sampling_params = SamplingParams(temperature=0.0, max_tokens=5) prefix = ( # examples/offline_inference/prefix_caching.py "Your name is QQQQ " @@ -23,6 +23,7 @@ def main(): for APC in [True]: engine = LLM(model=MODEL, enable_prefix_caching=APC, enforce_eager=True, tensor_parallel_size=4, # load_format="dummy" + speculative_config={"method": "qwen3_next_mtp", "num_speculative_tokens": 2} ) for i in range(3): if i == 0: diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 1531b61f88fe..30e1442c12c5 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -56,6 +56,7 @@ def get_num_blocks_to_allocate( num_tokens: int, new_computed_blocks: tuple[Sequence[KVCacheBlock], ...], num_encoder_tokens: int, + num_tokens_target_model: int, ) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -68,7 +69,8 @@ def get_num_blocks_to_allocate( prefix caching. num_encoder_tokens: The number of encoder tokens for allocating blocks for cross-attention. - + num_tokens_target_model: w/o spec decode, this should be the same as + num_tokens, with spec decode, TODO more comments here. Returns: The number of blocks. 
""" @@ -82,7 +84,7 @@ def get_num_blocks_to_allocate( ) else: num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks[i] + request_id, num_tokens, new_computed_blocks[i], num_tokens_target_model ) return num_blocks_to_allocate @@ -101,7 +103,7 @@ def save_new_computed_blocks( manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) def allocate_new_blocks( - self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0 + self, request_id: str, num_tokens: int, num_tokens_target_model: int, num_encoder_tokens: int = 0 ) -> tuple[list[KVCacheBlock], ...]: """ Allocate new blocks for the request to give it at least `num_tokens` @@ -111,9 +113,10 @@ def allocate_new_blocks( request_id: The request ID. num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). + num_tokens_target_model: w/o spec decode, this should be the same as + num_tokens, with spec decode, TODO more comments here. num_encoder_tokens: The number of encoder tokens for allocating blocks for cross-attention. - Returns: The new allocated blocks. 
""" @@ -123,6 +126,7 @@ def allocate_new_blocks( num_encoder_tokens if isinstance(manager, CrossAttentionManager) else num_tokens, + num_tokens_target_model, ) for manager in self.single_type_managers ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index b3d6bf5bc6a5..bfaa166f8f36 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -281,8 +281,9 @@ def allocate_slots( # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens + num_tokens_target_model = num_computed_tokens + num_new_tokens num_tokens_need_slot = min( - num_computed_tokens + num_new_tokens + num_lookahead_tokens, + num_tokens_target_model + num_lookahead_tokens, self.max_model_len, ) @@ -291,6 +292,7 @@ def allocate_slots( num_tokens=num_tokens_need_slot, new_computed_blocks=new_computed_block_list, num_encoder_tokens=num_encoder_tokens, + num_tokens_target_model=num_tokens_target_model, ) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): @@ -313,7 +315,7 @@ def allocate_slots( ) new_blocks = self.coordinator.allocate_new_blocks( - request.request_id, num_tokens_need_slot, num_encoder_tokens + request.request_id, num_tokens_need_slot, num_tokens_target_model, num_encoder_tokens ) # P/D: delay caching blocks if we have to recv from diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d88672b59d33..0f2aeda6bcbd 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -228,7 +228,9 @@ def _mamba_block_aligned_split(self, request: Request, num_new_tokens: int, num_ return num_new_tokens def schedule(self) -> SchedulerOutput: - print(f">>> [DEBUG] Scheduler: schedule new step") + print(f">>> [DEBUG] Scheduler: schidule new step") + for req in self.requests.values(): + print(f">>> [DEBUG] Scheduler: request {req.request_id} 
num_computed_tokens={req.num_computed_tokens} num_tokens={req.num_tokens} num_tokens_with_spec={req.num_tokens_with_spec}")
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
         # Each request just has the num_computed_tokens and
@@ -801,6 +803,7 @@ def schedule(self) -> SchedulerOutput:

         with record_function_or_nullcontext("schedule: update_after_schedule"):
             self._update_after_schedule(scheduler_output)
+        logger.info(f'>>> [DEBUG] Scheduler: scheduler_output: {scheduler_output}')
         return scheduler_output

     def _update_after_schedule(
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index a6d2218e0356..9579a8c1c1c6 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -94,6 +94,7 @@ def get_num_blocks_to_allocate(
         request_id: str,
         num_tokens: int,
         new_computed_blocks: Sequence[KVCacheBlock],
+        num_tokens_target_model: int,
     ) -> int:
         """
         Get the number of blocks needed to be allocated for the request.
@@ -104,6 +105,8 @@
                 tokens that are already allocated).
             new_computed_blocks: The new computed blocks just hitting the
                 prefix caching.
+            num_tokens_target_model: w/o spec decode, this should be the same as
+                num_tokens, with spec decode, TODO more comments here.

         Returns:
             The number of blocks.
@@ -146,7 +149,7 @@ save_new_computed_blocks(
         assert len(new_computed_blocks) == 0

     def allocate_new_blocks(
-        self, request_id: str, num_tokens: int
+        self, request_id: str, num_tokens: int, num_tokens_target_model: int
    ) -> list[KVCacheBlock]:
         """
         Allocate new blocks for the request to give it at least `num_tokens`
@@ -156,7 +159,8 @@
             request_id: The request ID.
             num_tokens: The total number of tokens that need a slot (including
                 tokens that are already allocated).
- + num_tokens_target_model: w/o spec decode, this should be the same as + num_tokens, with spec decode, TODO more comments here. Returns: The new allocated blocks. """ @@ -310,6 +314,7 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> No break removed_blocks.append(blocks[i]) blocks[i] = self._null_block + self.print(f'Mamba.remove_skipped_blocks: {request_id=}, {num_computed_tokens=}, {num_skipped_tokens=}, {num_skipped_blocks=}, removed_blocks={format_blocks(removed_blocks)}') self.block_pool.free_blocks(removed_blocks) def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: @@ -645,8 +650,9 @@ def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks - self._allocated_reqs: set[str] = set() - self._req_to_last_computed: dict[str, int] = {} + # self._req_info : dict[str, MambaManager.AllocationInfo] = {} + self.last_state_block_idx: dict[str, int] = {} + self._allocated_spec_block_reqs: set[str] = set() @classmethod def find_longest_cache_hit( @@ -687,30 +693,21 @@ def find_longest_cache_hit( print(f'Mamba.FindLongest: computed_blocks={[format_blocks(computed_block) for computed_block in computed_blocks]}', flush=True) return computed_blocks + def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: + # TODO: merge https://github.com/vllm-project/vllm/pull/28047 first + return num_computed_tokens - 1 + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: assert isinstance(self.kv_cache_spec, MambaSpec) - if not (envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.kv_cache_spec.enable_caching): - # Here unused blocks may be freed up for running requests. 
- # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2 - # (for which find_longest_cache_hit returns block_pool.null_block) - pass - else: - blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] - if request_id in self._req_to_last_computed: - # TODO what if in decoding phase and enabled sps when accepted - # token is not a multiple of block size - # NOTE: pre block should not be freed becasue it may be used to copy - last_computed_tokens = self._req_to_last_computed[request_id] - target_idx = last_computed_tokens // self.block_size - 1 - self.print(f'Mamba.remove_skipped: {last_computed_tokens=}, {target_idx=}, {len(blocks)=}') - if target_idx >= 0 and blocks[target_idx] != self._null_block: - self.print(f'Mamba.remove_skipped: Freeing block {last_computed_tokens=}, {target_idx=}, {blocks[target_idx]=}') - self.block_pool.free_blocks([blocks[target_idx]]) - blocks[target_idx] = self._null_block - - self._req_to_last_computed[request_id] = num_computed_tokens + super().remove_skipped_blocks(request_id, num_computed_tokens) + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if ( + (last_state_block_idx := self.last_state_block_idx.get(request_id)) and last_state_block_idx < cdiv(num_computed_tokens, self.block_size) - 1): + blocks = self.req_to_blocks[request_id] + if blocks[last_state_block_idx] != self._null_block: + self.block_pool.free_blocks([blocks[last_state_block_idx]]) + blocks[last_state_block_idx] = self._null_block def get_num_common_prefix_blocks(self, running_request_id: str) -> int: """ @@ -723,52 +720,35 @@ def get_num_blocks_to_allocate( request_id: str, num_tokens: int, new_computed_blocks: Sequence[KVCacheBlock], + num_tokens_target_model: int, ) -> int: + # mamba layers only exist in target model. + num_tokens = num_tokens_target_model # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. 
assert isinstance(self.kv_cache_spec, MambaSpec) - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - if self.kv_cache_spec.num_speculative_blocks > 0: - num_tokens += ( - self.kv_cache_spec.block_size - * self.kv_cache_spec.num_speculative_blocks - ) - return super().get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks + if self.kv_cache_spec.num_speculative_blocks > 0: + num_tokens += ( + self.kv_cache_spec.block_size + * self.kv_cache_spec.num_speculative_blocks ) - else: - # TODO(hhy): when sps is enabled, num_tokens incudes SPS lookahead tokens when - # num_computed_tokens > 0, we should minus SPS tokens in prefill phase, - # how about in decode? - num_tokens -= (self.num_speculative_blocks - if request_id in self._allocated_reqs else 0) - if self.kv_cache_spec.enable_caching: - num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + num_blocks_to_allocate = super().get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks + ) + if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + # (Chen): This may be possible. (block_size 4, 2 sps). + # [A, stoken1, stoken2] SBLOCK1 SBLOCK2 -> + # [A, ?, ?, ?] NULL NULL [?, ?, ?, B] [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need two blocks + # but we do it as following: + # [A, ?, ?, ?] 
NULL NULL NULL [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need 1 block + if request_id in self._req_info: + # previously allocated blocks + num_blocks_to_allocate = min(num_blocks_to_allocate, 1) else: - num_required_blocks = 1 + self.num_speculative_blocks - num_new_blocks = (num_required_blocks - len(new_computed_blocks) - - len(self.req_to_blocks[request_id])) - num_new_alloc_blocks = 0 - if num_new_blocks > 0: - # first prefill - if request_id not in self._allocated_reqs: - # if len(self.req_to_blocks[request_id]) == 0: - num_new_alloc_blocks = 1 + self.num_speculative_blocks - else: - num_new_alloc_blocks = 1 - - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it will be changed from a free block - # to a computed block when the request is allocated, so we also count - # it as needed to be allocated. - num_evictable_computed_blocks = sum( - blk.ref_cnt == 0 and not blk.is_null - for blk in new_computed_blocks) - - self.print(f'Mamba.get_nblks: {request_id=}, {num_tokens=}, {num_new_blocks=}, ' - f'{num_new_alloc_blocks=}, {num_evictable_computed_blocks=}') - - return num_new_alloc_blocks + num_evictable_computed_blocks + num_blocks_to_allocate = min(num_blocks_to_allocate, 1 + self.kv_cache_spec.num_speculative_blocks) + self.print(f'Mamba.get_num_blocks_to_allocate: {request_id=}, {num_tokens=}, {num_blocks_to_allocate=}') + return num_blocks_to_allocate + def save_new_computed_blocks( self, request_id: str, @@ -779,7 +759,7 @@ def save_new_computed_blocks( super().save_new_computed_blocks(request_id, new_computed_blocks) def allocate_new_blocks( - self, request_id: str, num_tokens: int + self, request_id: str, num_tokens: int, num_tokens_target_model: int ) -> list[KVCacheBlock]: # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. 
@@ -790,55 +770,53 @@ def allocate_new_blocks( self.kv_cache_spec.block_size * self.kv_cache_spec.num_speculative_blocks ) - return super().allocate_new_blocks(request_id, num_tokens) + return super().allocate_new_blocks(request_id, num_tokens, num_tokens_target_model) else: req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] - num_tokens -= (self.num_speculative_blocks - if request_id in self._allocated_reqs else 0) - if self.kv_cache_spec.enable_caching: - num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks - else: - num_required_blocks = 1 + self.num_speculative_blocks - num_new_blocks = (num_required_blocks - len(self.req_to_blocks[request_id])) - self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, {num_required_blocks=}, {num_new_blocks=}') - if num_new_blocks <= 0: - self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, new_blocks=[], req_blocks={format_blocks(req_blocks)}') + num_tokens = num_tokens_target_model + num_required_blocks = cdiv(num_tokens, self.block_size) + self.kv_cache_spec.num_speculative_blocks + if num_required_blocks == len(req_blocks): return [] - else: - # first prefill chunk - # TODO: for mamba num_cached_block including null-blocks - new_blocks = [] - reuse_blocks = [] - if request_id not in self._allocated_reqs: - self._allocated_reqs.add(request_id) - num_new_alloc_blocks = 1 + self.num_speculative_blocks - new_blocks.extend([self._null_block - for _ in range(num_new_blocks - num_new_alloc_blocks)]) - # new_alloc_blocks = self.block_pool.get_new_blocks(num_new_alloc_blocks) + else: + assert num_required_blocks < len(req_blocks), f'num_required_blocks {num_required_blocks} < len(req_blocks) {len(req_blocks)}' + prev_block_len = len(req_blocks) + dbg_is_allocated = request_id in self._allocated_spec_block_reqs + # We always save the current running state at this position. 
+ if request_id in self._allocated_spec_block_reqs: + self.last_state_block_idx[request_id] = prev_block_len - 1 - self.kv_cache_spec.num_speculative_blocks + else: + if prev_block_len > 0: + self.last_state_block_idx[request_id] = prev_block_len - 1 + + required_block_start_idx = num_required_blocks - self.kv_cache_spec.num_speculative_blocks - 1 + # null blocks + if len(req_blocks) < required_block_start_idx: + req_blocks.extend([self._null_block for _ in range(required_block_start_idx - len(req_blocks))]) + + # reuse previous speculative blocks in this step + for block_idx in range(prev_block_len - self.kv_cache_spec.num_speculative_blocks, prev_block_len): + if block_idx < required_block_start_idx: + req_blocks.append(req_blocks[block_idx]) + req_blocks[block_idx] = self._null_block + else: + break + num_new_blocks = num_required_blocks - len(req_blocks) + self.print(f'Mamba.alloc_blks: {request_id=}, num_new_blocks={num_new_blocks}') + if dbg_is_allocated: + assert num_new_blocks <= 1 else: - num_new_alloc_blocks = 1 - new_blocks.extend([self._null_block for _ in range(num_new_blocks - num_new_alloc_blocks)]) - if self.num_speculative_blocks > 0: - # step i: [0, 0, 0, a, sps_0, sps_1] - # step i+1(i.e. j): [0, 0, 0, a, 0, 0, 0, b, sps_0, sps_1] - # reuse blocks sps_0 and sps_1 - # new_alloc_blocks: b - # new_blocks: [0, 0, 0, b, sps_0, sps_1] - # TODO: reuse blocks. if we need clean? 
especially in decode - reuse_blocks = req_blocks[-self.num_speculative_blocks:] - del req_blocks[-self.num_speculative_blocks:] - new_alloc_blocks = self.block_pool.get_new_blocks(num_new_alloc_blocks) - new_blocks.extend(new_alloc_blocks) - new_blocks.extend(reuse_blocks) + assert num_new_blocks <= self.kv_cache_spec.num_speculative_blocks + 1 + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) - self.print(f'Mamba.alloc_blks: {request_id=}, {num_tokens=}, new_blocks={format_blocks(new_blocks)}') + self._allocated_spec_block_reqs.add(request_id) self.print(f'Mamba.alloc_blks: {request_id=}, {len(req_blocks)=}, {len(self.req_to_blocks[request_id])=}, req_blocks={format_blocks(req_blocks)}') - return new_blocks + return req_blocks[prev_block_len:] + def free(self, request_id: str) -> None: if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - self._allocated_reqs.discard(request_id) - self._req_to_last_computed.pop(request_id, None) + self._allocated_spec_block_reqs.discard(request_id) + self.last_state_block_idx.pop(request_id, None) super().free(request_id) class CrossAttentionManager(SingleTypeKVCacheManager): From f2ac29b241c548a73ea6c464322117b488c887b3 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 3 Dec 2025 00:12:54 -0800 Subject: [PATCH 027/130] add e2e impl (still has bug) Signed-off-by: Chen Zhang --- examples/offline_inference/run.py | 5 +- vllm/v1/core/sched/scheduler.py | 2 +- vllm/v1/core/single_type_kv_cache_manager.py | 53 +++-- vllm/v1/worker/gpu_model_runner.py | 225 ++++++++----------- vllm/v1/worker/utils.py | 15 +- 5 files changed, 142 insertions(+), 158 deletions(-) diff --git a/examples/offline_inference/run.py b/examples/offline_inference/run.py index 0883486987d4..40dc4b31dad2 100644 --- a/examples/offline_inference/run.py +++ b/examples/offline_inference/run.py @@ -4,7 +4,7 @@ def main(): MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" PROMPT_MULTIPLE = 6 - sampling_params = SamplingParams(temperature=0.0, 
max_tokens=5) + sampling_params = SamplingParams(temperature=0.0, max_tokens=300) prefix = ( # examples/offline_inference/prefix_caching.py "Your name is QQQQ " "You are an expert school principal, skilled in effectively managing " @@ -22,7 +22,8 @@ def main(): # for APC in [False, True]: for APC in [True]: engine = LLM(model=MODEL, enable_prefix_caching=APC, enforce_eager=True, tensor_parallel_size=4, - # load_format="dummy" + block_size=288, + # load_format="dummy", speculative_config={"method": "qwen3_next_mtp", "num_speculative_tokens": 2} ) for i in range(3): diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 0f2aeda6bcbd..00d878e52ab7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -803,7 +803,7 @@ def schedule(self) -> SchedulerOutput: with record_function_or_nullcontext("schedule: update_after_schedule"): self._update_after_schedule(scheduler_output) - logger.info('>>> [DEBUG] Scheduler: scheduler_output: {scheduler_output}') + logger.info(f'>>> [DEBUG] Scheduler: scheduler_output: {scheduler_output}') return scheduler_output def _update_after_schedule( diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 9579a8c1c1c6..7cf06db854f9 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -648,8 +648,8 @@ class MambaManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) + self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks # self._req_info : dict[str, MambaManager.AllocationInfo] = {} self.last_state_block_idx: dict[str, int] = {} self._allocated_spec_block_reqs: set[str] = set() @@ -733,7 +733,8 @@ def get_num_blocks_to_allocate( * 
self.kv_cache_spec.num_speculative_blocks
             )
             num_blocks_to_allocate = super().get_num_blocks_to_allocate(
-                request_id, num_tokens, new_computed_blocks
+                request_id, num_tokens, new_computed_blocks,
+                num_tokens_target_model,
             )
@@ -741,7 +742,7 @@ def get_num_blocks_to_allocate(
             # [A, ?, ?, ?] NULL NULL [?, ?, ?, B] [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need two blocks
             # but we do it as following:
             # [A, ?, ?, ?] NULL NULL NULL [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need 1 block
-            if request_id in self._req_info:
+            if request_id in self._allocated_spec_block_reqs:
                 # previously allocated blocks
                 num_blocks_to_allocate = min(num_blocks_to_allocate, 1)
             else:
@@ -765,47 +766,51 @@ def allocate_new_blocks(
         # speculative decoding (MTP/EAGLE) with linear attention.
         assert isinstance(self.kv_cache_spec, MambaSpec)
         if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE:
-            if self.kv_cache_spec.num_speculative_blocks > 0:
+            if self.num_speculative_blocks > 0:
                 num_tokens += (
-                    self.kv_cache_spec.block_size
-                    * self.kv_cache_spec.num_speculative_blocks
+                    self.block_size
+                    * self.num_speculative_blocks
                 )
             return super().allocate_new_blocks(request_id, num_tokens, num_tokens_target_model)
         else:
             req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id]
             num_tokens = num_tokens_target_model
-            num_required_blocks = cdiv(num_tokens, self.block_size) + self.kv_cache_spec.num_speculative_blocks
+            num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
             if num_required_blocks == len(req_blocks):
                 return []
             else:
-                assert num_required_blocks < len(req_blocks), f'num_required_blocks {num_required_blocks} < len(req_blocks) {len(req_blocks)}'
+                assert num_required_blocks > len(req_blocks), f'num_required_blocks {num_required_blocks} > len(req_blocks) {len(req_blocks)}'
                 prev_block_len = len(req_blocks)
-                dbg_is_allocated = request_id in self._allocated_spec_block_reqs
-                # We
always save the current running state at this position. + spec_blocks_allocated = request_id in self._allocated_spec_block_reqs + # We always save the current running state at the last (1 + num_speculative_blocks) block if request_id in self._allocated_spec_block_reqs: - self.last_state_block_idx[request_id] = prev_block_len - 1 - self.kv_cache_spec.num_speculative_blocks + self.last_state_block_idx[request_id] = prev_block_len - 1 - self.num_speculative_blocks else: if prev_block_len > 0: self.last_state_block_idx[request_id] = prev_block_len - 1 + else: + assert request_id not in self._allocated_spec_block_reqs - required_block_start_idx = num_required_blocks - self.kv_cache_spec.num_speculative_blocks - 1 + num_skipped_blocks = num_required_blocks - self.num_speculative_blocks - 1 # null blocks - if len(req_blocks) < required_block_start_idx: - req_blocks.extend([self._null_block for _ in range(required_block_start_idx - len(req_blocks))]) - - # reuse previous speculative blocks in this step - for block_idx in range(prev_block_len - self.kv_cache_spec.num_speculative_blocks, prev_block_len): - if block_idx < required_block_start_idx: - req_blocks.append(req_blocks[block_idx]) - req_blocks[block_idx] = self._null_block - else: - break + if len(req_blocks) < num_skipped_blocks: + req_blocks.extend([self._null_block for _ in range(num_skipped_blocks - len(req_blocks))]) + + if spec_blocks_allocated: + # reuse previous speculative blocks in this step + for block_idx in range(prev_block_len - self.num_speculative_blocks, prev_block_len): + if block_idx < num_skipped_blocks: + req_blocks.append(req_blocks[block_idx]) + req_blocks[block_idx] = self._null_block + self.print(f"Mamba.alloc_blks: {request_id=}, moving block {block_idx} to the end now, req_blocks={format_blocks(req_blocks)}") + else: + break num_new_blocks = num_required_blocks - len(req_blocks) self.print(f'Mamba.alloc_blks: {request_id=}, num_new_blocks={num_new_blocks}') - if dbg_is_allocated: + if 
spec_blocks_allocated: assert num_new_blocks <= 1 else: - assert num_new_blocks <= self.kv_cache_spec.num_speculative_blocks + 1 + assert num_new_blocks <= self.num_speculative_blocks + 1 new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) self._allocated_spec_block_reqs.add(request_id) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ad265feeb960..9dbbb65cb400 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,7 +8,7 @@ from collections.abc import Iterator, Sequence from contextlib import contextmanager from copy import copy, deepcopy -from functools import reduce +from functools import lru_cache, reduce from itertools import product from typing import TYPE_CHECKING, Any, NamedTuple, Optional, TypeAlias, cast @@ -158,6 +158,7 @@ add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, gather_mm_placeholders, + get_mamba_groups, sanity_check_mm_encoder_outputs, scatter_mm_placeholders, ) @@ -583,6 +584,7 @@ def __init__( # Ephemeral state transferred between execute_model() and sample_tokens(). self.execute_model_state: ExecuteModelState | None = None self.kv_connector_output: KVConnectorOutput | None = None + self.mamba_state_idx: dict[str, list[int]] = {} def reset_mm_cache(self) -> None: if self.mm_budget: @@ -846,22 +848,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: - if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.speculative_config): - for kv_cache_gid, kv_cache_group in enumerate( - self.kv_cache_config.kv_cache_groups - ): - kv_cache_spec = kv_cache_group.kv_cache_spec - # NOTE(hhy): pop the last num_speculative_blocks (reuse blocks). - # maybe have better way to help copy blocks? 
- if (isinstance(kv_cache_spec, MambaSpec) - and new_block_ids[kv_cache_gid] - ): - block_ids = req_state.block_ids[kv_cache_gid] - del block_ids[-kv_cache_spec.num_speculative_blocks:] - if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: poping reuse blocks: {kv_cache_gid=}, {req_state.block_ids[kv_cache_gid]=}') - # Append the new blocks to the existing block IDs. for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) @@ -889,24 +875,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens if new_block_ids is not None: - # NOTE(hhy): same as block_ids, pop the last num_speculative_blocks - if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.speculative_config): - num_poped_blocks = [] - for kv_cache_gid, kv_cache_group in enumerate( - self.kv_cache_config.kv_cache_groups - ): - if (isinstance(kv_cache_group.kv_cache_spec, MambaSpec) - and new_block_ids[kv_cache_gid] - ): - num_poped_blocks.append( - kv_cache_group.kv_cache_spec.num_speculative_blocks - ) - else: - num_poped_blocks.append(0) - self.input_batch.block_table.pop_row( - tuple(num_poped_blocks), req_index, - ) self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu @@ -968,7 +936,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_metadata() def _update_states_after_model_execute( - self, output_token_ids: torch.Tensor + self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput" ) -> None: """Update the cached states after model execution. 
@@ -1006,13 +974,13 @@ def _update_states_after_model_execute( self.input_batch.num_accepted_tokens_cpu[i] = num_tokens if is_global_first_rank(): logger.info(f'>>> [DEBUG] Worker: _update_states: ' - f'{self.input_batch.num_accepted_tokens_cpu[:len(num_accepted_tokens)]=}') + f'{output_token_ids=}') logger.info(f'>>> [DEBUG] Worker: _update_states: ' - f'{self.input_batch.num_computed_tokens_cpu[:len(num_accepted_tokens)]=}') + f'{self.input_batch.num_accepted_tokens_cpu[:len(num_accepted_tokens)]=}') if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE and self.cache_config.enable_prefix_caching ): - self._postprocess_mamba() + self._postprocess_mamba(scheduler_output) def _init_mrope_positions(self, req_state: CachedRequestState): model = self.get_model() @@ -2489,7 +2457,8 @@ def _sample( logits, sampling_metadata, ) - self._update_states_after_model_execute(sampler_output.sampled_token_ids) + if is_global_first_rank(): + logger.info(f'>>> [DEBUG] Worker: sampler_output: {sampler_output.sampled_token_ids.shape} {sampler_output.sampled_token_ids}') return sampler_output def _bookkeeping_sync( @@ -2688,108 +2657,103 @@ def _mamba_copy_block(self, kv_cache_group_spec: KVCacheGroupSpec, kv_cache_part[dest_block_id].copy_(kv_cache_part[src_block_id]) def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): + """ + Copies the mamba state of previous step to the last (1 + num_speculative_blocks) block + """ + mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config) + num_speculative_blocks = mamba_spec.num_speculative_blocks # TODO(Chen): we need to optimize this function a lot assert self.cache_config.enable_prefix_caching - block_size = self.cache_config.block_size - preprocess_req_index: list[int] = [] - for new_req_data in scheduler_output.scheduled_new_reqs: - if new_req_data.num_computed_tokens > 0: - # NOTE(hhy): prefix should be block aligned - assert new_req_data.num_computed_tokens % block_size == 0 - preprocess_req_index.append( - 
self.input_batch.req_id_to_index[new_req_data.req_id] - ) - cached_reqs = scheduler_output.scheduled_cached_reqs - for i, req_id in enumerate(cached_reqs.req_ids): - new_block_ids = cached_reqs.new_block_ids[i] - if not new_block_ids: - continue - for kv_cache_gid, kv_cache_group in enumerate( - self.kv_cache_config.kv_cache_groups - ): - if not isinstance( - kv_cache_group.kv_cache_spec, MambaSpec - ): - continue - # NOTE(hhy): assume all mamba groups are the same - if new_block_ids[kv_cache_gid]: - preprocess_req_index.append( - self.input_batch.req_id_to_index[req_id] - ) - break - - if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: preprocess mamba: {preprocess_req_index=}') - for i in preprocess_req_index: - req_id = self.input_batch.req_ids[i] + block_size = mamba_spec.block_size + for req_id in itertools.chain(scheduler_output.finished_req_ids, scheduler_output.preempted_req_ids): + self.mamba_state_idx.pop(req_id, None) + for req_id in self.input_batch.req_ids: if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {i=} {req_id=}') + logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {req_id=}') req_state = self.requests[req_id] - if req_state.num_computed_tokens == 0: - # new request, no previous state - continue - - prev_block_idx = (req_state.num_computed_tokens - 1)// block_size - # NOTE(Chen): if we have 7 tokens and 3 unverified speculative decoding tokens, seq_lens=10 here - # TODO in this PR(Chen): verify this for spec decode and adjust the comment - # TODO: copy must have new blocks - curr_block_idx = (self.seq_lens.cpu[i] - 1) // block_size + prev_state_idx = self.mamba_state_idx.get(req_id) + if prev_state_idx is None: + # new / resumed request, no previous state + # if num_computed_tokens is 0, prev_state_idx will be -1 + prev_state_idx = (req_state.num_computed_tokens - 1) // block_size + + num_blocks = len(req_state.block_ids[mamba_group_ids[0]]) + # We always save the current running 
state at the last (1 + num_speculative_blocks) block + curr_state_idx = num_blocks - 1 - num_speculative_blocks if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: preprocess mamba: {req_id=}, prev_len={req_state.num_computed_tokens} ' - f'prev_block_idx={prev_block_idx}, curr_len={self.seq_lens.cpu[i]} curr_block_idx={curr_block_idx}') - if prev_block_idx == curr_block_idx: - # same block, no need to copy + logger.info(f'>>> [DEBUG] Worker: preprocess mamba: {req_id=}, idx {prev_state_idx=} -> {curr_state_idx=}') + self.mamba_state_idx[req_id] = curr_state_idx + if prev_state_idx == -1 or prev_state_idx == curr_state_idx: + # no need to copy continue - - for kv_cache_gid, kv_cache_group in enumerate( - self.kv_cache_config.kv_cache_groups - ): - if not isinstance(kv_cache_group.kv_cache_spec, MambaSpec): - continue - prev_block_id = req_state.block_ids[kv_cache_gid][prev_block_idx] - curr_block_id = req_state.block_ids[kv_cache_gid][curr_block_idx] + for mamba_group_id in mamba_group_ids: + prev_block_id = req_state.block_ids[mamba_group_id][prev_state_idx] + curr_block_id = req_state.block_ids[mamba_group_id][curr_state_idx] assert prev_block_id != 0 assert curr_block_id != 0 - if is_global_first_rank(): + if is_global_first_rank(): logger.info(f'>>> [DEBUG] Worker: preprocess mamba: {req_id=}, COPY block {prev_block_id=} -> {curr_block_id=}') - # TODO(Chen): parallelize this loop - self._mamba_copy_block(kv_cache_group, prev_block_id, curr_block_id) - - def _postprocess_mamba(self): - assert self.cache_config.enable_prefix_caching - assert self.speculative_config - block_size = self.cache_config.block_size - num_reqs = self.input_batch.num_reqs - for i in range(num_reqs): - num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] - num_accepted_tokens = self.input_batch.num_accepted_tokens_cpu[i] - # block aligned, no need to copy mamba blocks - if num_computed_tokens % block_size == 0: - continue - computed_block_idx = num_computed_tokens // 
block_size
-            num_new_computed_tokens = num_computed_tokens + num_accepted_tokens
-            new_computed_block_idx = num_new_computed_tokens // block_size
-            if new_computed_block_idx == computed_block_idx:
-                continue
-            assert computed_block_idx + 1 == new_computed_block_idx
-            req_id = self.input_batch.req_ids[i]
+                self._mamba_copy_block(self.kv_cache_config.kv_cache_groups[mamba_group_id], prev_block_id, curr_block_id)
+
+    def _mamba_copy_block_for_qwen_next(self, kv_cache_group_spec: KVCacheGroupSpec, src_block_idx: int, dest_block_idx: int, accept_token_bias: int, block_ids: list[int]):
+        # TODO: general impl for all models
+        if src_block_idx == dest_block_idx and accept_token_bias == 0:
+            return
+        forward_context = self.compilation_config.static_forward_context
+        dest_block_id = block_ids[dest_block_idx]
+        for layer_name in kv_cache_group_spec.layer_names:
+            kv_caches: list[list[torch.Tensor]] = forward_context[layer_name].kv_cache[0]
+            conv_state, gdn_state = kv_caches
+            # conv state
+            conv_state_block_id = block_ids[src_block_idx]
+            src_conv_state = conv_state[conv_state_block_id][accept_token_bias:]
+            dest_conv_state = conv_state[dest_block_id]
+            dest_conv_state[:len(src_conv_state)].copy_(src_conv_state.clone())
+            # gdn state
+            gdn_state_block_id = block_ids[src_block_idx + accept_token_bias]
+            src_gdn_state = gdn_state[gdn_state_block_id]
+            dest_gdn_state = gdn_state[dest_block_id]
+            dest_gdn_state.copy_(src_gdn_state)
+            if is_global_first_rank() and layer_name == kv_cache_group_spec.layer_names[0]:
+                logger.info(f'>>> [DEBUG] Worker: mamba_copy_block_for_qwen_next: {layer_name=}, conv {conv_state_block_id=} -> {dest_block_id=} with bias {accept_token_bias}, {gdn_state_block_id=} -> {dest_block_id=}')
+
+    def _postprocess_mamba(self, scheduler_output: "SchedulerOutput"):
+        """
+        1. If a block is converted from a partial block to a full block in this step, copy
+        2.
Unify the state after token acceptance + the state from mamba_state_idx to that block + """ + num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens + scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens + num_accepted_tokens_cpu = self.input_batch.num_accepted_tokens_cpu + if is_global_first_rank(): + logger.info(f'>>> [DEBUG] Worker: postprocess mamba {num_scheduled_tokens_dict=} {scheduled_spec_decode_tokens_dict=} {num_accepted_tokens_cpu=}') + # NOTE: can be optimized as this function always returns the same result + mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config) + # TODO: vectorize this loop + for i, req_id in enumerate(self.input_batch.req_ids): req_state = self.requests[req_id] - prev_block_idx = (num_computed_tokens - 1)// block_size - curr_block_idx = (self.seq_lens.cpu[i] - 1) // block_size + num_accepted_tokens - 1 + num_computed_tokens = req_state.num_computed_tokens + num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, [])) + num_scheduled_tokens = num_scheduled_tokens_dict[req_id] + num_accepted_tokens = num_accepted_tokens_cpu[i] + num_tokens_running_state = num_computed_tokens + num_scheduled_tokens - num_draft_tokens + new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1 + aligned_new_computed_tokens = new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: postprocess mamba: {req_id=}, ' - f'{num_computed_tokens=}, {num_accepted_tokens=}, {prev_block_idx=}, {curr_block_idx=}') - for kv_cache_gid, kv_cache_group in enumerate( - self.kv_cache_config.kv_cache_groups - ): - if not isinstance(kv_cache_group.kv_cache_spec, MambaSpec): - continue - prev_block_id = req_state.block_ids[kv_cache_gid][prev_block_idx] - curr_block_id = req_state.block_ids[kv_cache_gid][curr_block_idx] + logger.info(f'>>> [DEBUG] Worker: postprocess mamba: {req_id=}, {num_computed_tokens=}, 
{num_scheduled_tokens=} {num_draft_tokens=} {num_accepted_tokens=} {num_tokens_running_state=} {new_num_computed_tokens=} {aligned_new_computed_tokens=}') + # TODO: how to ensure all blocks that cache_blocks called are cached here? + if aligned_new_computed_tokens >= num_tokens_running_state: + accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state + src_block_idx = self.mamba_state_idx[req_id] + dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1 if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: postprocess mamba: {kv_cache_gid=}, COPY block {curr_block_id=} -> {prev_block_id=}') - self._mamba_copy_block(kv_cache_group, curr_block_id, prev_block_id) - + logger.info(f'>>> [DEBUG] Worker: postprocess mamba: {req_id=}, {src_block_idx=} -> {dest_block_idx=} with bias {accept_token_bias}') + for mamba_group_id in mamba_group_ids: + self._mamba_copy_block_for_qwen_next(self.kv_cache_config.kv_cache_groups[mamba_group_id], src_block_idx, dest_block_idx, accept_token_bias, req_state.block_ids[mamba_group_id]) + if src_block_idx == dest_block_idx: + num_accepted_tokens_cpu[i] = 0 + @torch.inference_mode() def execute_model( self, @@ -3094,6 +3058,7 @@ def sample_tokens( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) + self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output) self.input_batch.prev_sampled_token_ids = None def propose_draft_token_ids(sampled_token_ids): diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 92e4ce3abdba..064f9aa8568b 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import defaultdict from dataclasses import dataclass, field +from functools import lru_cache from typing import TYPE_CHECKING import torch @@ -15,7 +16,7 @@ from vllm.platforms import 
current_platform from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget -from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, MambaSpec if TYPE_CHECKING: from vllm.attention.layer import Attention @@ -366,3 +367,15 @@ def is_residual_scattered_for_sp( if compile_sizes is None: return False return num_input_tokens in compile_sizes + + +def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]: + mamba_group_ids: list[int] = [] + mamba_specs: list[MambaSpec] = [] + for i in range(len(kv_cache_config.kv_cache_groups)): + if isinstance(kv_cache_config.kv_cache_groups[i].kv_cache_spec, MambaSpec): + mamba_group_ids.append(i) + mamba_specs.append(kv_cache_config.kv_cache_groups[i].kv_cache_spec) + assert len(mamba_group_ids) > 0, "no mamba layers in the model" + assert all(mamba_specs[0] == spec for spec in mamba_specs) + return mamba_group_ids, mamba_specs[0] From 2f103fc73ee4a9f4575d2cd28cd69e318460642a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 6 Dec 2025 23:33:17 -0800 Subject: [PATCH 028/130] update runner Signed-off-by: Chen Zhang --- vllm/v1/worker/gpu_model_runner.py | 130 +++++++++++++++++++++-------- 1 file changed, 93 insertions(+), 37 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9dbbb65cb400..eb571144a12c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,9 +8,9 @@ from collections.abc import Iterator, Sequence from contextlib import contextmanager from copy import copy, deepcopy -from functools import lru_cache, reduce +from functools import reduce from itertools import product -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, TypeAlias, cast +from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast import numpy 
as np import torch @@ -98,7 +98,6 @@ reorder_batch_to_split_decodes_and_prefills, split_attn_metadata, ) -from vllm.v1.core.sched.output import CachedRequestData, NewRequestData from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import ( AttentionSpec, @@ -973,11 +972,13 @@ def _update_states_after_model_execute( for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: _update_states: ' - f'{output_token_ids=}') - logger.info(f'>>> [DEBUG] Worker: _update_states: ' - f'{self.input_batch.num_accepted_tokens_cpu[:len(num_accepted_tokens)]=}') - if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + logger.info(f">>> [DEBUG] Worker: _update_states: {output_token_ids=}") + logger.info( + f">>> [DEBUG] Worker: _update_states: " + f"{self.input_batch.num_accepted_tokens_cpu[:len(num_accepted_tokens)]=}" + ) + if ( + envs.VLLM_USE_LIGHTER_MAMBA_CACHE and self.cache_config.enable_prefix_caching ): self._postprocess_mamba(scheduler_output) @@ -1539,7 +1540,7 @@ def _build_attention_metadata( # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. 
blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) - + # if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE # and self.cache_config.enable_prefix_caching # and isinstance(kv_cache_group.kv_cache_spec, MambaSpec) @@ -2458,7 +2459,9 @@ def _sample( sampling_metadata, ) if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: sampler_output: {sampler_output.sampled_token_ids.shape} {sampler_output.sampled_token_ids}') + logger.info( + f">>> [DEBUG] Worker: sampler_output: {sampler_output.sampled_token_ids.shape} {sampler_output.sampled_token_ids}" + ) return sampler_output def _bookkeeping_sync( @@ -2642,8 +2645,12 @@ def _model_forward( **model_kwargs, ) - def _mamba_copy_block(self, kv_cache_group_spec: KVCacheGroupSpec, - src_block_id: int, dest_block_id: int): + def _mamba_copy_block( + self, + kv_cache_group_spec: KVCacheGroupSpec, + src_block_id: int, + dest_block_id: int, + ): if src_block_id == dest_block_id: return forward_context = self.compilation_config.static_forward_context @@ -2665,11 +2672,13 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): # TODO(Chen): we need to optimize this function a lot assert self.cache_config.enable_prefix_caching block_size = mamba_spec.block_size - for req_id in itertools.chain(scheduler_output.finished_req_ids, scheduler_output.preempted_req_ids): + for req_id in itertools.chain( + scheduler_output.finished_req_ids, scheduler_output.preempted_req_ids + ): self.mamba_state_idx.pop(req_id, None) for req_id in self.input_batch.req_ids: if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: preprocess mamba for RUN: {req_id=}') + logger.info(f">>> [DEBUG] Worker: preprocess mamba for RUN: {req_id=}") req_state = self.requests[req_id] prev_state_idx = self.mamba_state_idx.get(req_id) if prev_state_idx is None: @@ -2681,7 +2690,9 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): # We always save the current running state at the last (1 + num_speculative_blocks) block 
curr_state_idx = num_blocks - 1 - num_speculative_blocks if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: preprocess mamba: {req_id=}, idx {prev_state_idx=} -> {curr_state_idx=}') + logger.info( + f">>> [DEBUG] Worker: preprocess mamba: {req_id=}, idx {prev_state_idx=} -> {curr_state_idx=}" + ) self.mamba_state_idx[req_id] = curr_state_idx if prev_state_idx == -1 or prev_state_idx == curr_state_idx: # no need to copy @@ -2691,31 +2702,51 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): curr_block_id = req_state.block_ids[mamba_group_id][curr_state_idx] assert prev_block_id != 0 assert curr_block_id != 0 - if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: preprocess mamba: {req_id=}, COPY block {prev_block_id=} -> {curr_block_id=}') - self._mamba_copy_block(self.kv_cache_config.kv_cache_groups[mamba_group_id], prev_block_id, curr_block_id) - - def _mamba_copy_block_for_qwen_next(self, kv_cache_group_spec: KVCacheGroupSpec, src_block_idx: int, dest_block_idx: int, accept_token_bias: int, block_ids: list[int]): + if is_global_first_rank(): + logger.info( + f">>> [DEBUG] Worker: preprocess mamba: {req_id=}, COPY block {prev_block_id=} -> {curr_block_id=}" + ) + self._mamba_copy_block( + self.kv_cache_config.kv_cache_groups[mamba_group_id], + prev_block_id, + curr_block_id, + ) + + def _mamba_copy_block_for_qwen_next( + self, + kv_cache_group_spec: KVCacheGroupSpec, + src_block_idx: int, + dest_block_idx: int, + accept_token_bias: int, + block_ids: list[int], + ): # TODO: general impl for all models if src_block_idx == dest_block_idx and accept_token_bias == 0: return forward_context = self.compilation_config.static_forward_context dest_block_id = block_ids[dest_block_idx] for layer_name in kv_cache_group_spec.layer_names: - kv_caches: list[list[torch.Tensor]] = forward_context[layer_name].kv_cache[0] + kv_caches: list[list[torch.Tensor]] = forward_context[layer_name].kv_cache[ + 0 + ] conv_state, gdn_state = kv_caches # 
conv state conv_state_block_id = block_ids[src_block_idx] src_conv_state = conv_state[conv_state_block_id][accept_token_bias:] dest_conv_state = conv_state[dest_block_id] - dest_conv_state[:len(src_conv_state)].copy_(src_conv_state.clone()) + dest_conv_state[: len(src_conv_state)].copy_(src_conv_state.clone()) # gdn state gdn_state_block_id = block_ids[src_block_idx + accept_token_bias] src_gdn_state = gdn_state[gdn_state_block_id] dest_gdn_state = gdn_state[dest_block_id] dest_gdn_state.copy_(src_gdn_state) - if is_global_first_rank() and layer_name == kv_cache_group_spec.layer_names[0]: - logger.info(f'>>> [DEBUG] Worker: mamba_copy_block_for_qwen_next: {layer_name=}, conv {conv_state_block_id=} -> {dest_block_id=} with bias {accept_token_bias}, {gdn_state_block_id=} -> {dest_block_id=}') + if ( + is_global_first_rank() + and layer_name == kv_cache_group_spec.layer_names[0] + ): + logger.info( + f">>> [DEBUG] Worker: mamba_copy_block_for_qwen_next: {layer_name=}, conv {conv_state_block_id=} -> {dest_block_id=} with bias {accept_token_bias}, {gdn_state_block_id=} -> {dest_block_id=}" + ) def _postprocess_mamba(self, scheduler_output: "SchedulerOutput"): """ @@ -2724,10 +2755,14 @@ def _postprocess_mamba(self, scheduler_output: "SchedulerOutput"): the state from mamba_state_idx to that block """ num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens - scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens + scheduled_spec_decode_tokens_dict = ( + scheduler_output.scheduled_spec_decode_tokens + ) num_accepted_tokens_cpu = self.input_batch.num_accepted_tokens_cpu if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: postprocess mamba {num_scheduled_tokens_dict=} {scheduled_spec_decode_tokens_dict=} {num_accepted_tokens_cpu=}') + logger.info( + f">>> [DEBUG] Worker: postprocess mamba {num_scheduled_tokens_dict=} {scheduled_spec_decode_tokens_dict=} {num_accepted_tokens_cpu=}" + ) # NOTE: can be optimized as this function 
always returns the same result mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config) # TODO: vectorize this loop @@ -2737,23 +2772,41 @@ def _postprocess_mamba(self, scheduler_output: "SchedulerOutput"): num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, [])) num_scheduled_tokens = num_scheduled_tokens_dict[req_id] num_accepted_tokens = num_accepted_tokens_cpu[i] - num_tokens_running_state = num_computed_tokens + num_scheduled_tokens - num_draft_tokens + num_tokens_running_state = ( + num_computed_tokens + num_scheduled_tokens - num_draft_tokens + ) new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1 - aligned_new_computed_tokens = new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size + aligned_new_computed_tokens = ( + new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size + ) if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: postprocess mamba: {req_id=}, {num_computed_tokens=}, {num_scheduled_tokens=} {num_draft_tokens=} {num_accepted_tokens=} {num_tokens_running_state=} {new_num_computed_tokens=} {aligned_new_computed_tokens=}') + logger.info( + f">>> [DEBUG] Worker: postprocess mamba: {req_id=}, {num_computed_tokens=}, {num_scheduled_tokens=} {num_draft_tokens=} {num_accepted_tokens=} {num_tokens_running_state=} {new_num_computed_tokens=} {aligned_new_computed_tokens=}" + ) # TODO: how to ensure all blocks that cache_blocks called are cached here? 
if aligned_new_computed_tokens >= num_tokens_running_state: - accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state + accept_token_bias = ( + aligned_new_computed_tokens - num_tokens_running_state + ) src_block_idx = self.mamba_state_idx[req_id] - dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1 + dest_block_idx = ( + aligned_new_computed_tokens // mamba_spec.block_size - 1 + ) if is_global_first_rank(): - logger.info(f'>>> [DEBUG] Worker: postprocess mamba: {req_id=}, {src_block_idx=} -> {dest_block_idx=} with bias {accept_token_bias}') + logger.info( + f">>> [DEBUG] Worker: postprocess mamba: {req_id=}, {src_block_idx=} -> {dest_block_idx=} with bias {accept_token_bias}" + ) for mamba_group_id in mamba_group_ids: - self._mamba_copy_block_for_qwen_next(self.kv_cache_config.kv_cache_groups[mamba_group_id], src_block_idx, dest_block_idx, accept_token_bias, req_state.block_ids[mamba_group_id]) + self._mamba_copy_block_for_qwen_next( + self.kv_cache_config.kv_cache_groups[mamba_group_id], + src_block_idx, + dest_block_idx, + accept_token_bias, + req_state.block_ids[mamba_group_id], + ) if src_block_idx == dest_block_idx: - num_accepted_tokens_cpu[i] = 0 - + num_accepted_tokens_cpu[i] = 1 + @torch.inference_mode() def execute_model( self, @@ -2848,7 +2901,8 @@ def execute_model( # TODO(lucas): move cudagraph dispatching here: # https://github.com/vllm-project/vllm/issues/23789 - if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + if ( + envs.VLLM_USE_LIGHTER_MAMBA_CACHE and self.cache_config.enable_prefix_caching ): # TODO: add limition: preprocess only have new blocks @@ -3058,7 +3112,9 @@ def sample_tokens( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) - self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output) + self._update_states_after_model_execute( + sampler_output.sampled_token_ids, scheduler_output + ) 
self.input_batch.prev_sampled_token_ids = None def propose_draft_token_ids(sampled_token_ids): From 08aa7102ccd09ada5c6a230f27c5cc2af71c89c5 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 9 Dec 2025 00:30:55 -0800 Subject: [PATCH 029/130] [WIP] write unit test Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 55 +++++++++++++++++++++++++ vllm/v1/engine/core_client.py | 1 + 2 files changed, 56 insertions(+) create mode 100644 tests/v1/e2e/test_mamba_prefix_cache.py diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py new file mode 100644 index 000000000000..ab30893e9220 --- /dev/null +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -0,0 +1,55 @@ +import pytest +from vllm import LLM, SamplingParams +import time +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + + +# def _fake_sample(self, logits: torch.Tensor | None, spec_decode_metadata: SpecDecodeMetadata | None) -> SamplerOutput: + + +def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): + MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" + PROMPT_MULTIPLE = 6 + monkeypatch.setenv("VLLM_USE_LIGHTER_MAMBA_CACHE", "1") + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + sampling_params = SamplingParams(temperature=0.0, max_tokens=30) + prefix = ( # examples/offline_inference/prefix_caching.py + "Your name is QQQQ " + "You are an expert school principal, skilled in effectively managing " + "faculty and staff. Draft 10-15 questions for a potential first grade " + "Head Teacher for my K-12, all-girls', independent school that emphasizes " + "community, joyful discovery, and life-long learning. The candidate is " + "coming in for a first-round panel interview for a 8th grade Math " + "teaching role. They have 5 years of previous teaching experience " + "as an assistant teacher at a co-ed, public school with experience " + "in middle school math teaching. 
" + ) + prefix2 = "Based on these information, fulfill the following paragraph: " + prompt = PROMPT_MULTIPLE * prefix + prefix2 + "Hello, my name is" + # print('Prompt length:', ) + # for APC in [False, True]: + for APC in [True]: + engine = LLM( + model=MODEL, + enable_prefix_caching=APC, + enforce_eager=True, + block_size=288, + speculative_config={ + "method": "qwen3_next_mtp", + "num_speculative_tokens": 2, + }, + hf_overrides={"num_hidden_layers": 8}, + seed=42, + ) + for i in range(3): + if i == 0: + print("Warm-up") + if i == 1: + print("Measuring") + start_time = time.time() + outputs = engine.generate(prompt, sampling_params) + print("APC:", APC, i, f"Generated text: {outputs[0].outputs[0].text!r}") + # for m in engine.llm_engine.get_metrics(): + # if 'vllm:prefix_cache_hits' in m.name: + # print(m.name, m.value) + print("APC:", APC, "loop took --- %s seconds ---" % (time.time() - start_time)) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9b440505bd9d..913c6a03f45e 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -265,6 +265,7 @@ def __init__(self, *args, **kwargs): def get_output(self) -> EngineCoreOutputs: outputs, _ = self.engine_core.step_fn() + self.engine_core.post_step(model_executed=True) return outputs and outputs.get(0) or EngineCoreOutputs() def get_supported_tasks(self) -> tuple[SupportedTask, ...]: From 5a39c44a238243dc9d53a48c5ac633df7004c1f5 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 10 Dec 2025 00:41:34 -0800 Subject: [PATCH 030/130] extract mamba state Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 261 +++++++++++++++++++---- vllm/model_executor/models/qwen3_next.py | 4 + 2 files changed, 221 insertions(+), 44 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index ab30893e9220..32dd21ce8210 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ 
b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -1,55 +1,228 @@ -import pytest -from vllm import LLM, SamplingParams +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time +from typing import Callable + +import pytest +import os + +import torch + +from vllm import LLM, SamplingParams, TokensPrompt +from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.engine.core import EngineCore +from vllm.v1.engine.core_client import InprocClient +from vllm.v1.outputs import SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.utils import get_mamba_groups + +num_speculative_tokens = 2 + +num_accepted_tokens = 1 +prompt_token_ids = [] +MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" +BLOCK_SIZE = 560 +NUM_HIDDEN_LAYERS = 8 + + +def get_fake_sample_fn() -> SamplerOutput: + def fake_sample_fn( + self: GPUModelRunner, + logits: torch.Tensor | None, + spec_decode_metadata: SpecDecodeMetadata | None, + ) -> SamplerOutput: + print( + f"[UNIT TEST] fake_sample_fn: {logits.shape=} {spec_decode_metadata=} {self.input_ids.cpu=}" + ) + num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor + num_computed_tokens = num_computed_tokens_cpu_tensor[0].item() + if num_computed_tokens < self.input_batch.num_prompt_tokens[0].item(): + first_token_id_index = self.input_batch.num_prompt_tokens[0].item() + else: + first_token_id_index = num_computed_tokens + 1 + if spec_decode_metadata is None: + print( + f"[UNIT TEST] fake_sample_fn: {first_token_id_index=} {prompt_token_ids[first_token_id_index]=}" + ) + return SamplerOutput( + sampled_token_ids=torch.tensor( + [[prompt_token_ids[first_token_id_index]]], + device="cuda", + 
dtype=torch.int32, + ), + logprobs_tensors=None, + ) + num_sampled_tokens = spec_decode_metadata.cu_num_sampled_tokens[0].item() + 1 + accpeted_tokens = prompt_token_ids[ + first_token_id_index : first_token_id_index + + min(num_accepted_tokens, logits.shape[0]) + ] + sampled_token_ids = accpeted_tokens + [-1] * ( + num_sampled_tokens - len(accpeted_tokens) + ) + print( + f"[UNIT TEST] fake_sample_fn: {first_token_id_index=} {accpeted_tokens=} {sampled_token_ids=}" + ) + return SamplerOutput( + sampled_token_ids=torch.tensor( + [sampled_token_ids], device="cuda", dtype=torch.int32 + ), + logprobs_tensors=None, + ) + + return fake_sample_fn + + +def get_fake_propose_draft_token_ids_fn(): + def fake_propose_draft_token_ids_fn( + self, + scheduler_output: SchedulerOutput, + sampled_token_ids: torch.Tensor | list[list[int]], + sampling_metadata: SamplingMetadata, + hidden_states: torch.Tensor, + sample_hidden_states: torch.Tensor, + aux_hidden_states: list[torch.Tensor] | None, + spec_decode_metadata: SpecDecodeMetadata | None, + common_attn_metadata: CommonAttentionMetadata, + ) -> list[list[int]]: + num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor + num_computed_tokens = num_computed_tokens_cpu_tensor[0].item() + if num_computed_tokens < self.input_batch.num_prompt_tokens[0].item(): + first_token_id_index = self.input_batch.num_prompt_tokens[0].item() + 1 + else: + first_token_id_index = num_computed_tokens + 2 + return [ + prompt_token_ids[ + first_token_id_index : first_token_id_index + num_speculative_tokens + ] + ] + + return fake_propose_draft_token_ids_fn + +mamba_kv_cache_dict = {} -# def _fake_sample(self, logits: torch.Tensor | None, spec_decode_metadata: SpecDecodeMetadata | None) -> SamplerOutput: + +def get_fake_execute_model_fn(original_execute_model_fn: Callable): + last_num_computed_tokens = 0 + + def fake_execute_model_fn( + self: GPUModelRunner, + scheduler_output: SchedulerOutput, + intermediate_tensors: 
IntermediateTensors | None = None, + ): + mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config) + mamba_group_id = mamba_group_ids[0] + mamba_layer_name = self.kv_cache_config.kv_cache_groups[ + mamba_group_id + ].layer_names[0] + print(f"fake_execute_model_fn: {mamba_spec=}") + nonlocal last_num_computed_tokens + if len(scheduler_output.scheduled_cached_reqs.req_ids) > 0: + num_computed_tokens = ( + scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] + ) + print( + f"fake_execute_model_fn: {num_computed_tokens=} {last_num_computed_tokens=} {num_computed_tokens // BLOCK_SIZE > last_num_computed_tokens // BLOCK_SIZE=}" + ) + if ( + num_computed_tokens // BLOCK_SIZE + > last_num_computed_tokens // BLOCK_SIZE + ): + # generated a new aligned block in this step + block_idx = num_computed_tokens // mamba_spec.block_size + block_id = ( + self.input_batch.block_table.block_tables[mamba_group_id] + .block_table.cpu[0, block_idx] + .item() + ) + kv_cache = self.compilation_config.static_forward_context[ + mamba_layer_name + ].kv_cache + print("KV CACHE", kv_cache) + print(kv_cache[0]) + print(kv_cache[0][0].shape, kv_cache[0][1].shape) + mamba_kv_cache_dict[ + num_computed_tokens - num_computed_tokens % BLOCK_SIZE + ] = (kv_cache[0][0][block_id].clone(), kv_cache[0][1][block_id].clone()) + + last_num_computed_tokens = num_computed_tokens + + ret = original_execute_model_fn(self, scheduler_output, intermediate_tensors) + + return ret + + return fake_execute_model_fn + + +def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + num_generated_tokens = 4000 + num_prompt_tokens = 500 + sampling_params = SamplingParams(temperature=0.0, max_tokens=num_generated_tokens) + full_prompt = open(f"{os.path.dirname(__file__)}/input.txt", "r").read() + fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model) + monkeypatch.setattr(GPUModelRunner, "execute_model", 
fake_execute_model_fn) + fake_sample_fn = get_fake_sample_fn() + monkeypatch.setattr(GPUModelRunner, "_sample", fake_sample_fn) + engine = LLM( + model=MODEL, + enforce_eager=True, + block_size=BLOCK_SIZE, + hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS}, + seed=42, + ) + global prompt_token_ids + prompt_token_ids = engine.get_tokenizer().encode(full_prompt) + print(f"Token IDs length: {len(prompt_token_ids)}") + + outputs = engine.generate( + [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], + sampling_params, + ) + print(f"Generated text: {outputs[0].outputs[0].token_ids}") + print( + f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" + ) + print(f"mamba_kv_cache_dict: {mamba_kv_cache_dict}") + torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict.pth") def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): - MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" - PROMPT_MULTIPLE = 6 monkeypatch.setenv("VLLM_USE_LIGHTER_MAMBA_CACHE", "1") monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") sampling_params = SamplingParams(temperature=0.0, max_tokens=30) - prefix = ( # examples/offline_inference/prefix_caching.py - "Your name is QQQQ " - "You are an expert school principal, skilled in effectively managing " - "faculty and staff. Draft 10-15 questions for a potential first grade " - "Head Teacher for my K-12, all-girls', independent school that emphasizes " - "community, joyful discovery, and life-long learning. The candidate is " - "coming in for a first-round panel interview for a 8th grade Math " - "teaching role. They have 5 years of previous teaching experience " - "as an assistant teacher at a co-ed, public school with experience " - "in middle school math teaching. 
" + + full_prompt = open(f"{os.path.dirname(__file__)}/input.txt", "r").read() + fake_sample_fn = get_fake_sample_fn() + monkeypatch.setattr(GPUModelRunner, "_sample", fake_sample_fn) + fake_propose_draft_token_ids_fn = get_fake_propose_draft_token_ids_fn() + monkeypatch.setattr( + GPUModelRunner, "propose_draft_token_ids", fake_propose_draft_token_ids_fn ) - prefix2 = "Based on these information, fulfill the following paragraph: " - prompt = PROMPT_MULTIPLE * prefix + prefix2 + "Hello, my name is" - # print('Prompt length:', ) - # for APC in [False, True]: - for APC in [True]: - engine = LLM( - model=MODEL, - enable_prefix_caching=APC, - enforce_eager=True, - block_size=288, - speculative_config={ - "method": "qwen3_next_mtp", - "num_speculative_tokens": 2, - }, - hf_overrides={"num_hidden_layers": 8}, - seed=42, - ) - for i in range(3): - if i == 0: - print("Warm-up") - if i == 1: - print("Measuring") - start_time = time.time() - outputs = engine.generate(prompt, sampling_params) - print("APC:", APC, i, f"Generated text: {outputs[0].outputs[0].text!r}") - # for m in engine.llm_engine.get_metrics(): - # if 'vllm:prefix_cache_hits' in m.name: - # print(m.name, m.value) - print("APC:", APC, "loop took --- %s seconds ---" % (time.time() - start_time)) + engine = LLM( + model=MODEL, + enable_prefix_caching=True, + enforce_eager=True, + block_size=BLOCK_SIZE, + speculative_config={ + "method": "qwen3_next_mtp", + "num_speculative_tokens": num_speculative_tokens, + }, + hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS}, + seed=42, + ) + global prompt_token_ids + prompt_token_ids = engine.get_tokenizer().encode(full_prompt) + # print(f"Token IDs: {token_ids}") + print(f"Token IDs length: {len(prompt_token_ids)}") + + outputs = engine.generate( + [TokensPrompt(prompt_token_ids=prompt_token_ids[:2000])], sampling_params + ) + print(f"Generated text: {outputs[0].outputs[0].token_ids}") + print(f"expect token ids: {prompt_token_ids[2000 : 2000 + 30]}") diff --git 
a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 29e81f36e347..4523f3d25c14 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -1094,6 +1094,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: name.endswith(".bias") or name.endswith("_bias") ) and name not in params_dict: continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader( @@ -1110,6 +1112,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue if is_pp_missing_parameter(name, self): continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader From fd7e3404adaf261fe4542313d25208d2729887c8 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 10 Dec 2025 00:43:36 -0800 Subject: [PATCH 031/130] update Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 32dd21ce8210..be8d0d608687 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -1,19 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time -from typing import Callable - -import pytest import os +from collections.abc import Callable +import pytest import torch from vllm import LLM, SamplingParams, TokensPrompt from vllm.sequence import IntermediateTensors from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.engine.core import EngineCore -from vllm.v1.engine.core_client import InprocClient from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata 
import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -164,7 +160,7 @@ def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): num_generated_tokens = 4000 num_prompt_tokens = 500 sampling_params = SamplingParams(temperature=0.0, max_tokens=num_generated_tokens) - full_prompt = open(f"{os.path.dirname(__file__)}/input.txt", "r").read() + full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model) monkeypatch.setattr(GPUModelRunner, "execute_model", fake_execute_model_fn) fake_sample_fn = get_fake_sample_fn() @@ -197,7 +193,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") sampling_params = SamplingParams(temperature=0.0, max_tokens=30) - full_prompt = open(f"{os.path.dirname(__file__)}/input.txt", "r").read() + full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() fake_sample_fn = get_fake_sample_fn() monkeypatch.setattr(GPUModelRunner, "_sample", fake_sample_fn) fake_propose_draft_token_ids_fn = get_fake_propose_draft_token_ids_fn() From 30c6a5a28badd5eea28b3d9ff5b476eaf3cd09d3 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 10 Dec 2025 17:20:23 +0000 Subject: [PATCH 032/130] set block size to max_model_len when prefix caching is disable Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/mamba/abstract.py | 9 ++++++--- vllm/model_executor/models/config.py | 7 ++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 41626eed5c52..fc111530e5f3 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -52,9 +52,12 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: raise NotImplementedError( "Mamba with speculative decoding is not supported 
yet." ) - mamba_block_size = (vllm_config.cache_config.mamba_block_size - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE - else vllm_config.cache_config.block_size) + if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + mamba_block_size = vllm_config.cache_config.mamba_block_size + elif vllm_config.cache_config.enable_prefix_caching: + mamba_block_size = vllm_config.cache_config.block_size + else: + mamba_block_size = vllm_config.model_config.max_model_len page_size_padded = vllm_config.cache_config.mamba_page_size_padded return MambaSpec( shapes=self.get_state_shape(), diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index bc2f434e3caa..06b50ee07d17 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -387,11 +387,16 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: model_config=model_config, ) + if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE + and cache_config.enable_prefix_caching): + block_size = cache_config.block_size + else: + block_size = model_config.max_model_len # get mamba page size mamba_page_size = MambaSpec( shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), - block_size=model_config.max_model_len if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE else cache_config.block_size, + block_size=block_size, enable_caching=cache_config.enable_prefix_caching, ).page_size_bytes From 9e6724e3cd25cf93abfb525f3ac0ca5bc29ed46b Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 10 Dec 2025 17:54:53 +0000 Subject: [PATCH 033/130] update mamba manager Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/single_type_kv_cache_manager.py | 108 +++++++++++-------- 1 file changed, 63 insertions(+), 45 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 7cf06db854f9..06088ee1deb7 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ 
b/vllm/v1/core/single_type_kv_cache_manager.py @@ -650,9 +650,9 @@ def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - # self._req_info : dict[str, MambaManager.AllocationInfo] = {} self.last_state_block_idx: dict[str, int] = {} - self._allocated_spec_block_reqs: set[str] = set() + # the set of the requests that have been allocated blocks + self._allocated_block_reqs: set[str] = set() @classmethod def find_longest_cache_hit( @@ -702,8 +702,11 @@ def remove_skipped_blocks(self, request_id: str, assert isinstance(self.kv_cache_spec, MambaSpec) super().remove_skipped_blocks(request_id, num_computed_tokens) if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + last_state_block_idx = self.last_state_block_idx.get(request_id) if ( - (last_state_block_idx := self.last_state_block_idx.get(request_id)) and last_state_block_idx < cdiv(num_computed_tokens, self.block_size) - 1): + last_state_block_idx is not None + and last_state_block_idx < cdiv(num_computed_tokens, self.block_size) - 1 + ): blocks = self.req_to_blocks[request_id] if blocks[last_state_block_idx] != self._null_block: self.block_pool.free_blocks([blocks[last_state_block_idx]]) @@ -722,34 +725,50 @@ def get_num_blocks_to_allocate( new_computed_blocks: Sequence[KVCacheBlock], num_tokens_target_model: int, ) -> int: + assert isinstance(self.kv_cache_spec, MambaSpec) # mamba layers only exist in target model. num_tokens = num_tokens_target_model - # Allocate extra `num_speculative_blocks` blocks for - # speculative decoding (MTP/EAGLE) with linear attention. 
- assert isinstance(self.kv_cache_spec, MambaSpec) - if self.kv_cache_spec.num_speculative_blocks > 0: - num_tokens += ( - self.kv_cache_spec.block_size - * self.kv_cache_spec.num_speculative_blocks + if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + # Allocate extra `num_speculative_blocks` blocks for + # speculative decoding (MTP/EAGLE) with linear attention. + if self.kv_cache_spec.num_speculative_blocks > 0: + num_tokens += ( + self.kv_cache_spec.block_size + * self.kv_cache_spec.num_speculative_blocks + ) + return super().get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks, + num_tokens_target_model, ) - num_blocks_to_allocate = super().get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks, - num_tokens_target_model, - ) - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - # (Chen): This may be possible. (block_size 4, 2 sps). - # [A, stoken1, stoken2] SBLOCK1 SBLOCK2 -> - # [A, ?, ?, ?] NULL NULL [?, ?, ?, B] [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need two blocks - # but we do it as following: - # [A, ?, ?, ?] NULL NULL NULL [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need 1 block - if request_id in self._allocated_spec_block_reqs: - # previously allocated blocks - num_blocks_to_allocate = min(num_blocks_to_allocate, 1) - else: - num_blocks_to_allocate = min(num_blocks_to_allocate, 1 + self.kv_cache_spec.num_speculative_blocks) - self.print(f'Mamba.get_num_blocks_to_allocate: {request_id=}, {num_tokens=}, {num_blocks_to_allocate=}') - return num_blocks_to_allocate - + else: + num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + num_new_blocks = ( + num_required_blocks + - len(new_computed_blocks) + - len(self.req_to_blocks[request_id]) + ) + if num_new_blocks > 0: + # (Chen): This may be possible. (block_size 4, 2 sps). + # [A, stoken1, stoken2] SBLOCK1 SBLOCK2 -> + # [A, ?, ?, ?] 
NULL NULL [?, ?, ?, B] [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need two blocks + # but we do it as following: + # [A, ?, ?, ?] NULL NULL NULL [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need 1 block + if request_id in self._allocated_block_reqs: + # previously allocated blocks + num_new_blocks = 1 + else: + # first prefill + num_new_blocks = 1 + self.kv_cache_spec.num_speculative_blocks + + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it will be changed from a free block + # to a computed block when the request is allocated, so we also count + # it as needed to be allocated. + num_evictable_computed_blocks = sum( + blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks + ) + self.print(f'Mamba.get_num_blocks_to_allocate: {request_id=}, {num_tokens=}, {num_new_blocks=}') + return num_new_blocks + num_evictable_computed_blocks def save_new_computed_blocks( self, request_id: str, @@ -762,10 +781,11 @@ def save_new_computed_blocks( def allocate_new_blocks( self, request_id: str, num_tokens: int, num_tokens_target_model: int ) -> list[KVCacheBlock]: - # Allocate extra `num_speculative_blocks` blocks for - # speculative decoding (MTP/EAGLE) with linear attention. assert isinstance(self.kv_cache_spec, MambaSpec) + num_tokens = num_tokens_target_model if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + # Allocate extra `num_speculative_blocks` blocks for + # speculative decoding (MTP/EAGLE) with linear attention. 
if self.num_speculative_blocks > 0: num_tokens += ( self.block_size @@ -774,29 +794,27 @@ def allocate_new_blocks( return super().allocate_new_blocks(request_id, num_tokens, num_tokens_target_model) else: req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] - num_tokens = num_tokens_target_model num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks if num_required_blocks == len(req_blocks): return [] else: assert num_required_blocks > len(req_blocks), f'num_required_blocks {num_required_blocks} < len(req_blocks) {len(req_blocks)}' prev_block_len = len(req_blocks) - spec_blocks_allocated = request_id in self._allocated_spec_block_reqs - # We always save the current running state at the last (1 + num_speculative_blocks) block - if request_id in self._allocated_spec_block_reqs: + blocks_allocated = request_id in self._allocated_block_reqs + # Record the last state block + if blocks_allocated: + # We always save the current running state at the last (1 + num_speculative_blocks) block self.last_state_block_idx[request_id] = prev_block_len - 1 - self.num_speculative_blocks - else: - if prev_block_len > 0: - self.last_state_block_idx[request_id] = prev_block_len - 1 - else: - assert request_id not in self._allocated_spec_block_reqs + elif prev_block_len > 0: + # When a new request hits the prefix cache, the last block saves the hit state. 
+ self.last_state_block_idx[request_id] = prev_block_len - 1 num_skipped_blocks = num_required_blocks - self.num_speculative_blocks - 1 # null blocks - if len(req_blocks) < num_skipped_blocks: - req_blocks.extend([self._null_block for _ in range(num_skipped_blocks - len(req_blocks))]) + if prev_block_len < num_skipped_blocks: + req_blocks.extend([self._null_block for _ in range(prev_block_len, num_skipped_blocks)]) - if spec_blocks_allocated: + if blocks_allocated: # reuse previous speculative blocks in this step for block_idx in range(prev_block_len - self.num_speculative_blocks, prev_block_len): if block_idx < num_skipped_blocks: @@ -807,20 +825,20 @@ def allocate_new_blocks( break num_new_blocks = num_required_blocks - len(req_blocks) self.print(f'Mamba.alloc_blks: {request_id=}, num_new_blocks={num_new_blocks}') - if spec_blocks_allocated: + if blocks_allocated: assert num_new_blocks <= 1 else: assert num_new_blocks <= self.num_speculative_blocks + 1 new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) - self._allocated_spec_block_reqs.add(request_id) + self._allocated_block_reqs.add(request_id) self.print(f'Mamba.alloc_blks: {request_id=}, {len(req_blocks)=}, {len(self.req_to_blocks[request_id])=}, req_blocks={format_blocks(req_blocks)}') return req_blocks[prev_block_len:] def free(self, request_id: str) -> None: if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - self._allocated_spec_block_reqs.discard(request_id) + self._allocated_block_reqs.discard(request_id) self.last_state_block_idx.pop(request_id, None) super().free(request_id) From ac577b7d90920a16333b4b0100ec4ee84a036004 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 10 Dec 2025 18:06:06 +0000 Subject: [PATCH 034/130] update max mem usage bytes for mamba spec Signed-off-by: huanghaoyan.hhy --- vllm/v1/kv_cache_interface.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py 
index cdacd427291a..d30eacdef89b 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -269,11 +269,11 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len return cdiv(max_model_len, self.block_size) * self.page_size_bytes else: - # NOTE: We allocate 1 block per request by default. With prefix - # caching enabled, up to 2 additional blocks are required: one - # for reading the matched prefix and one for caching the current - # state. - return self.page_size_bytes * (3 if self.enable_caching else 1) + # NOTE: We allocate 1+sps block per request by default. With prefix + # caching enabled, one additional blocks are required which is saved + # last state for copying. + return self.page_size_bytes * (1 + self.num_speculative_blocks + + self.enable_caching) From 6f036d708069902fd53473cada8d55253e6a01ef Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 10 Dec 2025 15:44:10 -0800 Subject: [PATCH 035/130] init result checker Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 49 ++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index be8d0d608687..771beed89fa7 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os from collections.abc import Callable +from dataclasses import dataclass import pytest import torch @@ -16,7 +17,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.utils import get_mamba_groups -num_speculative_tokens = 2 +num_speculative_tokens = 3 num_accepted_tokens = 1 prompt_token_ids = [] @@ -184,15 +185,36 @@ def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): print( f"expect token ids: {prompt_token_ids[num_prompt_tokens : 
num_prompt_tokens + num_generated_tokens]}" ) - print(f"mamba_kv_cache_dict: {mamba_kv_cache_dict}") + print(f"mamba_kv_cache_dict: {mamba_kv_cache_dict.keys()}") torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict.pth") +def check_mamba_state_equal(mamba_state_ref: dict, mamba_state_new: dict): + for key in mamba_state_new: + # mamba state new is a subset of mamba state ref + for i, (ref, new) in enumerate(zip(mamba_state_ref[key], mamba_state_new[key])): + if not torch.allclose(ref, new[: ref.shape[0]]): + raise ValueError( + f"Mamba state is not equal for key: {key} at index {i}" + ) + return True + + +@dataclass +class TestConfig: + num_prompt_tokens: int + num_generated_tokens: int + num_accepted_tokens: int + expect_schedule_tokens: list[int] | None + expect_block_table: list[int] | None + + def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_LIGHTER_MAMBA_CACHE", "1") monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - sampling_params = SamplingParams(temperature=0.0, max_tokens=30) - + num_generated_tokens = 1000 + num_prompt_tokens = 500 + sampling_params = SamplingParams(temperature=0.0, max_tokens=num_generated_tokens) full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() fake_sample_fn = get_fake_sample_fn() monkeypatch.setattr(GPUModelRunner, "_sample", fake_sample_fn) @@ -200,6 +222,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr( GPUModelRunner, "propose_draft_token_ids", fake_propose_draft_token_ids_fn ) + fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model) + monkeypatch.setattr(GPUModelRunner, "execute_model", fake_execute_model_fn) engine = LLM( model=MODEL, enable_prefix_caching=True, @@ -218,7 +242,20 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): print(f"Token IDs length: {len(prompt_token_ids)}") outputs = engine.generate( - [TokensPrompt(prompt_token_ids=prompt_token_ids[:2000])], sampling_params + 
[TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], + sampling_params, ) print(f"Generated text: {outputs[0].outputs[0].token_ids}") - print(f"expect token ids: {prompt_token_ids[2000 : 2000 + 30]}") + print( + f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" + ) + + torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict_new.pth") + mamba_state_ref = torch.load("mamba_kv_cache_dict.pth") + check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict) + + +def test_check_mamba_state_equal(): + mamba_state_ref = torch.load("mamba_kv_cache_dict.pth") + mamba_state_new = torch.load("mamba_kv_cache_dict_new.pth") + check_mamba_state_equal(mamba_state_ref, mamba_state_new) From 9162095c74ac093d4343f0a21cc259f1d641649f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 10 Dec 2025 15:45:33 -0800 Subject: [PATCH 036/130] format code Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 11 +- vllm/v1/core/kv_cache_manager.py | 5 +- vllm/v1/core/sched/scheduler.py | 126 +++++++++++++------ vllm/v1/core/single_type_kv_cache_manager.py | 113 +++++++++++------ 4 files changed, 173 insertions(+), 82 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 30e1442c12c5..9238247aa349 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -84,7 +84,10 @@ def get_num_blocks_to_allocate( ) else: num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks[i], num_tokens_target_model + request_id, + num_tokens, + new_computed_blocks[i], + num_tokens_target_model, ) return num_blocks_to_allocate @@ -103,7 +106,11 @@ def save_new_computed_blocks( manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) def allocate_new_blocks( - self, request_id: str, num_tokens: int, num_tokens_target_model: int, num_encoder_tokens: int = 0 + self, + request_id: str, + 
num_tokens: int, + num_tokens_target_model: int, + num_encoder_tokens: int = 0, ) -> tuple[list[KVCacheBlock], ...]: """ Allocate new blocks for the request to give it at least `num_tokens` diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index bfaa166f8f36..b4171c2d24ec 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -315,7 +315,10 @@ def allocate_slots( ) new_blocks = self.coordinator.allocate_new_blocks( - request.request_id, num_tokens_need_slot, num_tokens_target_model, num_encoder_tokens + request.request_id, + num_tokens_need_slot, + num_tokens_target_model, + num_encoder_tokens, ) # P/D: delay caching blocks if we have to recv from diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 00d878e52ab7..e6df707e20cb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -189,18 +189,29 @@ def __init__( ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER - print(f">>> [DEBUG] Scheduler: init enable_prefix_caching={self.cache_config.enable_prefix_caching} block_size={self.block_size} kv_cache_config={self.kv_cache_config}") + print( + f">>> [DEBUG] Scheduler: init enable_prefix_caching={self.cache_config.enable_prefix_caching} block_size={self.block_size} kv_cache_config={self.kv_cache_config}" + ) def _has_mamba_spec(self) -> bool: - has_mamba: bool = any(isinstance(spec.kv_cache_spec, MambaSpec) - for spec in self.kv_cache_config.kv_cache_groups) + has_mamba: bool = any( + isinstance(spec.kv_cache_spec, MambaSpec) + for spec in self.kv_cache_config.kv_cache_groups + ) assert not has_mamba or self.vllm_config.model_config.is_hybrid return has_mamba - - def _mamba_block_aligned_split(self, request: Request, num_new_tokens: int, num_new_local_computed_tokens: int=0, num_external_computed_tokens: int=0) -> int: - assert num_external_computed_tokens == 0, "External KV 
connector is not verified yet" - if (self.cache_config.enable_prefix_caching - and self._has_mamba_spec()): + + def _mamba_block_aligned_split( + self, + request: Request, + num_new_tokens: int, + num_new_local_computed_tokens: int = 0, + num_external_computed_tokens: int = 0, + ) -> int: + assert num_external_computed_tokens == 0, ( + "External KV connector is not verified yet" + ) + if self.cache_config.enable_prefix_caching and self._has_mamba_spec(): # To enable block-aligned caching of the Mamba state, `num_new_tokens` # must be a multiple of `block_size`. # As an exception, if `num_new_tokens` is less than `block_size`, the @@ -209,17 +220,27 @@ def _mamba_block_aligned_split(self, request: Request, num_new_tokens: int, num_ # matching block. To prevent this from causing a Mamba cache miss, the # last chunk must be larger than `block_size`. block_size = self.cache_config.block_size - if request.num_output_tokens == 0: # prefill - last_cache_position = request.num_prompt_tokens - request.num_prompt_tokens % block_size + if request.num_output_tokens == 0: # prefill + last_cache_position = ( + request.num_prompt_tokens - request.num_prompt_tokens % block_size + ) # eagle prune if self.use_eagle: last_cache_position = max(last_cache_position - block_size, 0) - num_computed_tokens = request.num_computed_tokens + num_new_local_computed_tokens + num_external_computed_tokens + num_computed_tokens = ( + request.num_computed_tokens + + num_new_local_computed_tokens + + num_external_computed_tokens + ) num_computed_tokens_after_prefill = num_computed_tokens + num_new_tokens if num_computed_tokens_after_prefill < last_cache_position: # align to block_size num_new_tokens = num_new_tokens // block_size * block_size - elif num_computed_tokens < last_cache_position < num_computed_tokens_after_prefill: + elif ( + num_computed_tokens + < last_cache_position + < num_computed_tokens_after_prefill + ): # force to cache the last chunk num_new_tokens = last_cache_position - 
num_computed_tokens else: @@ -230,7 +251,9 @@ def _mamba_block_aligned_split(self, request: Request, num_new_tokens: int, num_ def schedule(self) -> SchedulerOutput: print(f">>> [DEBUG] Scheduler: schidule new step") for req in self.requests.values(): - print(f">>> [DEBUG] Scheduler: request {req.request_id} num_computed_tokens={req.num_computed_tokens} num_tokens={req.num_tokens} num_tokens_with_spec={req.num_tokens_with_spec}") + print( + f">>> [DEBUG] Scheduler: request {req.request_id} num_computed_tokens={req.num_computed_tokens} num_tokens={req.num_tokens} num_tokens_with_spec={req.num_tokens_with_spec}" + ) # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. # Each request just has the num_computed_tokens and @@ -264,27 +287,33 @@ def schedule(self) -> SchedulerOutput: while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - logger.info(f'>>> [DEBUG] Scheduler: schedule RUNING: req_id={request.request_id}, ' - f'num_prompt_tokens={request.num_prompt_tokens=}') - # Ensure new tokens for a request in the prefill phase do not contain - # sps tokens, especially in the last prefill chunk. For a hybrid-model, + logger.info( + f">>> [DEBUG] Scheduler: schedule RUNING: req_id={request.request_id}, " + f"num_prompt_tokens={request.num_prompt_tokens=}" + ) + # Ensure new tokens for a request in the prefill phase do not contain + # sps tokens, especially in the last prefill chunk. For a hybrid-model, # extra sps tokens would corrupt the generated Mamba state. # TODO: This logic does not yet handle resumed requests. 
if request.num_computed_tokens < request.num_prompt_tokens: - num_new_tokens = min(request.num_tokens_with_spec - + request.num_output_placeholders, - request.num_prompt_tokens) - request.num_computed_tokens + num_new_tokens = ( + min( + request.num_tokens_with_spec + request.num_output_placeholders, + request.num_prompt_tokens, + ) + - request.num_computed_tokens + ) else: - num_new_tokens = (request.num_tokens_with_spec + - request.num_output_placeholders - - request.num_computed_tokens) - - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + request.num_tokens_with_spec + + request.num_output_placeholders + - request.num_computed_tokens + ) + + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold - num_new_tokens = min(num_new_tokens, token_budget) + num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len or # request's max_tokens. @@ -315,7 +344,8 @@ def schedule(self) -> SchedulerOutput: if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: num_new_tokens = self._mamba_block_aligned_split( - request, num_new_tokens) + request, num_new_tokens + ) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -327,7 +357,7 @@ def schedule(self) -> SchedulerOutput: # its max_total_tokens or max_model_len. # 2. The encoder budget is exhausted. # 3. The encoder cache is exhausted. - # 4. Insufficient budget for a block-aligned chunk in hybrid + # 4. Insufficient budget for a block-aligned chunk in hybrid # models with lighter mamba prefix caching. 
# NOTE(woosuk): Here, by doing `continue` instead of `break`, # we do not strictly follow the FCFS scheduling policy and @@ -380,7 +410,9 @@ def schedule(self) -> SchedulerOutput: req_index -= 1 else: preempted_req = self.running.pop() - print(f">>> [DEBUG] Scheduler: preempted request {preempted_req.request_id}") + print( + f">>> [DEBUG] Scheduler: preempted request {preempted_req.request_id}" + ) self.kv_cache_manager.free(preempted_req) self.encoder_cache_manager.free(preempted_req) @@ -463,8 +495,10 @@ def schedule(self) -> SchedulerOutput: break request = self.waiting.peek_request() - logger.info(f'>>> [DEBUG] Scheduler: schedule WAITING: req_id={request.request_id}, ' - f'num_prompt_tokens={request.num_prompt_tokens=}') + logger.info( + f">>> [DEBUG] Scheduler: schedule WAITING: req_id={request.request_id}, " + f"num_prompt_tokens={request.num_prompt_tokens=}" + ) # KVTransfer: skip request if still waiting for remote kvs. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: @@ -515,8 +549,10 @@ def schedule(self) -> SchedulerOutput: new_computed_blocks, num_new_local_computed_tokens = ( self.kv_cache_manager.get_computed_blocks(request) ) - logger.info(f'>>> [DEBUG] Scheduler: get_computed_blk: req_id={request.request_id},' - f'{num_new_local_computed_tokens=}') + logger.info( + f">>> [DEBUG] Scheduler: get_computed_blk: req_id={request.request_id}," + f"{num_new_local_computed_tokens=}" + ) # Get externally-cached tokens if using a KVConnector. 
if self.connector is not None: @@ -598,7 +634,11 @@ def schedule(self) -> SchedulerOutput: if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: num_new_tokens = self._mamba_block_aligned_split( - request, num_new_tokens, num_new_local_computed_tokens, num_external_computed_tokens) + request, + num_new_tokens, + num_new_local_computed_tokens, + num_external_computed_tokens, + ) if num_new_tokens == 0: break @@ -762,10 +802,14 @@ def schedule(self) -> SchedulerOutput: self.prev_step_scheduled_req_ids.clear() self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) - logger.info('>>> [DEBUG] Scheduler: new_reqs:' - f'{[(reqdata.req_id, reqdata.block_ids) for reqdata in new_reqs_data]}') - logger.info('>>> [DEBUG] Scheduler: cached_reqs:' - f'{[(req_id, cached_reqs_data.new_block_ids[i]) for i, req_id in enumerate(cached_reqs_data.req_ids)]}') + logger.info( + ">>> [DEBUG] Scheduler: new_reqs:" + f"{[(reqdata.req_id, reqdata.block_ids) for reqdata in new_reqs_data]}" + ) + logger.info( + ">>> [DEBUG] Scheduler: cached_reqs:" + f"{[(req_id, cached_reqs_data.new_block_ids[i]) for i, req_id in enumerate(cached_reqs_data.req_ids)]}" + ) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -803,7 +847,7 @@ def schedule(self) -> SchedulerOutput: with record_function_or_nullcontext("schedule: update_after_schedule"): self._update_after_schedule(scheduler_output) - logger.info(f'>>> [DEBUG] Scheduler: scheduler_output: {scheduler_output}') + logger.info(f">>> [DEBUG] Scheduler: scheduler_output: {scheduler_output}") return scheduler_output def _update_after_schedule( diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 06088ee1deb7..36d4f4efaa63 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -24,10 +24,10 @@ def format_blocks(blocks: list[KVCacheBlock]): if not blocks: return "[]" - + result = [] i = 
0 - + while i < len(blocks): if blocks[i].block_id == 0: count = 0 @@ -37,11 +37,14 @@ def format_blocks(blocks: list[KVCacheBlock]): i += 1 result.append(f"Null-block*{count}") else: - result.append(f'KVBlock(block_id={blocks[i].block_id}, ref_cnt={blocks[i].ref_cnt})') + result.append( + f"KVBlock(block_id={blocks[i].block_id}, ref_cnt={blocks[i].ref_cnt})" + ) i += 1 - + return f"[{', '.join(result)}]" + class SingleTypeKVCacheManager(ABC): """ An abstract base class for a manager that handle the kv cache management @@ -86,7 +89,7 @@ def __init__( self._null_block = block_pool.null_block def print(self, *args, **kwargs): - new_args = (f">>> [KvGrp {self.kv_cache_group_id}] ", ) + args + new_args = (f">>> [KvGrp {self.kv_cache_group_id}] ",) + args print(*new_args, **kwargs) def get_num_blocks_to_allocate( @@ -188,11 +191,13 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None: if num_cached_blocks >= num_full_blocks: return - + if isinstance(self, MambaManager) and num_cached_blocks < num_full_blocks: - self.print(f'Mamba.cache_blocks: req_id={request.request_id}, {num_tokens=}, ' - f'{num_cached_blocks=}, {num_full_blocks=}, ' - f'new_full_blocks={format_blocks(self.req_to_blocks[request.request_id][num_cached_blocks:num_full_blocks])}') + self.print( + f"Mamba.cache_blocks: req_id={request.request_id}, {num_tokens=}, " + f"{num_cached_blocks=}, {num_full_blocks=}, " + f"new_full_blocks={format_blocks(self.req_to_blocks[request.request_id][num_cached_blocks:num_full_blocks])}" + ) self.block_pool.cache_full_blocks( request=request, @@ -314,7 +319,9 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> No break removed_blocks.append(blocks[i]) blocks[i] = self._null_block - self.print(f'Mamba.remove_skipped_blocks: {request_id=}, {num_computed_tokens=}, {num_skipped_tokens=}, {num_skipped_blocks=}, removed_blocks={format_blocks(removed_blocks)}') + self.print( + f"Mamba.remove_skipped_blocks: {request_id=}, 
{num_computed_tokens=}, {num_skipped_tokens=}, {num_skipped_blocks=}, removed_blocks={format_blocks(removed_blocks)}" + ) self.block_pool.free_blocks(removed_blocks) def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: @@ -645,7 +652,6 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class MambaManager(SingleTypeKVCacheManager): - def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks @@ -690,22 +696,25 @@ def find_longest_cache_hit( computed.append(cached) break # we just need the last match - early stopping - print(f'Mamba.FindLongest: computed_blocks={[format_blocks(computed_block) for computed_block in computed_blocks]}', flush=True) + print( + f"Mamba.FindLongest: computed_blocks={[format_blocks(computed_block) for computed_block in computed_blocks]}", + flush=True, + ) return computed_blocks def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: # TODO: merge https://github.com/vllm-project/vllm/pull/28047 first return num_computed_tokens - 1 - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: assert isinstance(self.kv_cache_spec, MambaSpec) super().remove_skipped_blocks(request_id, num_computed_tokens) if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: last_state_block_idx = self.last_state_block_idx.get(request_id) if ( - last_state_block_idx is not None - and last_state_block_idx < cdiv(num_computed_tokens, self.block_size) - 1 + last_state_block_idx is not None + and last_state_block_idx + < cdiv(num_computed_tokens, self.block_size) - 1 ): blocks = self.req_to_blocks[request_id] if blocks[last_state_block_idx] != self._null_block: @@ -737,11 +746,15 @@ def get_num_blocks_to_allocate( * self.kv_cache_spec.num_speculative_blocks ) return super().get_num_blocks_to_allocate( - 
request_id, num_tokens, new_computed_blocks, + request_id, + num_tokens, + new_computed_blocks, num_tokens_target_model, ) else: - num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + num_required_blocks = ( + cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + ) num_new_blocks = ( num_required_blocks - len(new_computed_blocks) @@ -759,7 +772,7 @@ def get_num_blocks_to_allocate( else: # first prefill num_new_blocks = 1 + self.kv_cache_spec.num_speculative_blocks - + # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it will be changed from a free block # to a computed block when the request is allocated, so we also count @@ -767,15 +780,19 @@ def get_num_blocks_to_allocate( num_evictable_computed_blocks = sum( blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks ) - self.print(f'Mamba.get_num_blocks_to_allocate: {request_id=}, {num_tokens=}, {num_new_blocks=}') + self.print( + f"Mamba.get_num_blocks_to_allocate: {request_id=}, {num_tokens=}, {num_new_blocks=}" + ) return num_new_blocks + num_evictable_computed_blocks def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: list[KVCacheBlock]) -> None: + self, request_id: str, new_computed_blocks: list[KVCacheBlock] + ) -> None: # TODO(hhy): remove when prefix-caching is ready assert isinstance(self.kv_cache_spec, MambaSpec) - self.print(f'Mamba.save_computed: {request_id=}, new_computed_blocks={format_blocks(new_computed_blocks)}') + self.print( + f"Mamba.save_computed: {request_id=}, new_computed_blocks={format_blocks(new_computed_blocks)}" + ) super().save_new_computed_blocks(request_id, new_computed_blocks) def allocate_new_blocks( @@ -787,44 +804,62 @@ def allocate_new_blocks( # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. 
if self.num_speculative_blocks > 0: - num_tokens += ( - self.block_size - * self.num_speculative_blocks - ) - return super().allocate_new_blocks(request_id, num_tokens, num_tokens_target_model) + num_tokens += self.block_size * self.num_speculative_blocks + return super().allocate_new_blocks( + request_id, num_tokens, num_tokens_target_model + ) else: req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] - num_required_blocks = cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + num_required_blocks = ( + cdiv(num_tokens, self.block_size) + self.num_speculative_blocks + ) if num_required_blocks == len(req_blocks): return [] else: - assert num_required_blocks > len(req_blocks), f'num_required_blocks {num_required_blocks} < len(req_blocks) {len(req_blocks)}' + assert num_required_blocks > len(req_blocks), ( + f"num_required_blocks {num_required_blocks} < len(req_blocks) {len(req_blocks)}" + ) prev_block_len = len(req_blocks) blocks_allocated = request_id in self._allocated_block_reqs # Record the last state block if blocks_allocated: # We always save the current running state at the last (1 + num_speculative_blocks) block - self.last_state_block_idx[request_id] = prev_block_len - 1 - self.num_speculative_blocks + self.last_state_block_idx[request_id] = ( + prev_block_len - 1 - self.num_speculative_blocks + ) elif prev_block_len > 0: # When a new request hits the prefix cache, the last block saves the hit state. 
self.last_state_block_idx[request_id] = prev_block_len - 1 - num_skipped_blocks = num_required_blocks - self.num_speculative_blocks - 1 + num_skipped_blocks = ( + num_required_blocks - self.num_speculative_blocks - 1 + ) # null blocks if prev_block_len < num_skipped_blocks: - req_blocks.extend([self._null_block for _ in range(prev_block_len, num_skipped_blocks)]) + req_blocks.extend( + [ + self._null_block + for _ in range(prev_block_len, num_skipped_blocks) + ] + ) if blocks_allocated: # reuse previous speculative blocks in this step - for block_idx in range(prev_block_len - self.num_speculative_blocks, prev_block_len): + for block_idx in range( + prev_block_len - self.num_speculative_blocks, prev_block_len + ): if block_idx < num_skipped_blocks: req_blocks.append(req_blocks[block_idx]) req_blocks[block_idx] = self._null_block - self.print(f"Mamba.alloc_blks: {request_id=}, moving block {block_idx} to the end now, req_blocks={format_blocks(req_blocks)}") + self.print( + f"Mamba.alloc_blks: {request_id=}, moving block {block_idx} to the end now, req_blocks={format_blocks(req_blocks)}" + ) else: break num_new_blocks = num_required_blocks - len(req_blocks) - self.print(f'Mamba.alloc_blks: {request_id=}, num_new_blocks={num_new_blocks}') + self.print( + f"Mamba.alloc_blks: {request_id=}, num_new_blocks={num_new_blocks}" + ) if blocks_allocated: assert num_new_blocks <= 1 else: @@ -832,9 +867,10 @@ def allocate_new_blocks( new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) self._allocated_block_reqs.add(request_id) - self.print(f'Mamba.alloc_blks: {request_id=}, {len(req_blocks)=}, {len(self.req_to_blocks[request_id])=}, req_blocks={format_blocks(req_blocks)}') + self.print( + f"Mamba.alloc_blks: {request_id=}, {len(req_blocks)=}, {len(self.req_to_blocks[request_id])=}, req_blocks={format_blocks(req_blocks)}" + ) return req_blocks[prev_block_len:] - def free(self, request_id: str) -> None: if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: @@ 
-842,6 +878,7 @@ def free(self, request_id: str) -> None: self.last_state_block_idx.pop(request_id, None) super().free(request_id) + class CrossAttentionManager(SingleTypeKVCacheManager): """Manager for cross-attention KV cache in encoder-decoder models.""" From 84e7a3f588beab977a8fd1f1649550dfe2f6b1db Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 10 Dec 2025 20:16:13 -0800 Subject: [PATCH 037/130] preprocess copy need to consider accept token Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 85 ++++++++++++++++++++----- vllm/v1/worker/gpu_model_runner.py | 32 ++++------ 2 files changed, 82 insertions(+), 35 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 771beed89fa7..9ce688371a72 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -19,7 +19,7 @@ num_speculative_tokens = 3 -num_accepted_tokens = 1 +num_accepted_tokens = 4 prompt_token_ids = [] MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" BLOCK_SIZE = 560 @@ -64,6 +64,14 @@ def fake_sample_fn( print( f"[UNIT TEST] fake_sample_fn: {first_token_id_index=} {accpeted_tokens=} {sampled_token_ids=}" ) + # if ( + # self.input_batch.num_computed_tokens_cpu_tensor[0].item() + # >= self.input_batch.num_prompt_tokens[0].item() + # ): + # for i, x in enumerate(sampled_token_ids): + # if x == -1: + # continue + # assert x == self.input_ids.cpu[i + 1] return SamplerOutput( sampled_token_ids=torch.tensor( [sampled_token_ids], device="cuda", dtype=torch.int32 @@ -76,7 +84,7 @@ def fake_sample_fn( def get_fake_propose_draft_token_ids_fn(): def fake_propose_draft_token_ids_fn( - self, + self: GPUModelRunner, scheduler_output: SchedulerOutput, sampled_token_ids: torch.Tensor | list[list[int]], sampling_metadata: SamplingMetadata, @@ -88,15 +96,28 @@ def fake_propose_draft_token_ids_fn( ) -> list[list[int]]: num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor 
num_computed_tokens = num_computed_tokens_cpu_tensor[0].item() - if num_computed_tokens < self.input_batch.num_prompt_tokens[0].item(): - first_token_id_index = self.input_batch.num_prompt_tokens[0].item() + 1 + if ( + self.input_batch.num_tokens_no_spec[0].item() + <= self.input_batch.num_prompt_tokens[0].item() + ): + first_token_id_index = self.input_batch.num_prompt_tokens[0].item() else: - first_token_id_index = num_computed_tokens + 2 - return [ + first_token_id_index = ( + num_computed_tokens + 1 + ) # bonus token isn't considered as computed + print( + f"fake_propose_draft_token_ids_fn: {self.input_batch.num_accepted_tokens_cpu=}" + ) + first_token_id_index += self.input_batch.num_accepted_tokens_cpu[0].item() + proposed_draft_token_ids = [ prompt_token_ids[ first_token_id_index : first_token_id_index + num_speculative_tokens ] ] + print( + f"[UNIT TEST] fake_propose_draft_token_ids_fn: {num_computed_tokens=} num_accepted_tokens={self.input_batch.num_accepted_tokens_cpu[0].item()} num_prompt_tokens={self.input_batch.num_prompt_tokens[0].item()} num_tokens_no_spec={self.input_batch.num_tokens_no_spec[0].item()} {first_token_id_index=} {proposed_draft_token_ids=}" + ) + return proposed_draft_token_ids return fake_propose_draft_token_ids_fn @@ -131,7 +152,10 @@ def fake_execute_model_fn( > last_num_computed_tokens // BLOCK_SIZE ): # generated a new aligned block in this step - block_idx = num_computed_tokens // mamba_spec.block_size + block_idx = num_computed_tokens // mamba_spec.block_size - 1 + print( + f"[UNIT TEST] fake_execute_model_fn: block_idx= {block_idx} for num_computed_tokens={num_computed_tokens - num_computed_tokens % BLOCK_SIZE}" + ) block_id = ( self.input_batch.block_table.block_tables[mamba_group_id] .block_table.cpu[0, block_idx] @@ -140,9 +164,6 @@ def fake_execute_model_fn( kv_cache = self.compilation_config.static_forward_context[ mamba_layer_name ].kv_cache - print("KV CACHE", kv_cache) - print(kv_cache[0]) - print(kv_cache[0][0].shape, 
kv_cache[0][1].shape) mamba_kv_cache_dict[ num_computed_tokens - num_computed_tokens % BLOCK_SIZE ] = (kv_cache[0][0][block_id].clone(), kv_cache[0][1][block_id].clone()) @@ -158,8 +179,8 @@ def fake_execute_model_fn( def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - num_generated_tokens = 4000 - num_prompt_tokens = 500 + num_generated_tokens = 20 + num_prompt_tokens = 551 sampling_params = SamplingParams(temperature=0.0, max_tokens=num_generated_tokens) full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model) @@ -186,20 +207,49 @@ def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" ) print(f"mamba_kv_cache_dict: {mamba_kv_cache_dict.keys()}") - torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict.pth") + ref_mamba_kv_cache_dict = torch.load("mamba_kv_cache_dict.pth") + check_mamba_state_equal(ref_mamba_kv_cache_dict, mamba_kv_cache_dict) + # torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict.pth") def check_mamba_state_equal(mamba_state_ref: dict, mamba_state_new: dict): + atol = 1e-2 + rtol = 1e-2 for key in mamba_state_new: # mamba state new is a subset of mamba state ref for i, (ref, new) in enumerate(zip(mamba_state_ref[key], mamba_state_new[key])): - if not torch.allclose(ref, new[: ref.shape[0]]): + print("check_mamba_state_equal: ", ref.shape, new.shape) + new = new[: ref.shape[0]] + print("check_mamba_state_equal after convert: ", ref.shape, new.shape) + if not torch.allclose(ref, new, atol=atol, rtol=rtol): + diff_mask = ~torch.isclose(ref, new, atol=atol, rtol=rtol) + diff_idx = torch.nonzero(diff_mask) + if diff_idx.shape[0] * 100 < ref.numel(): + print( + f"[WARNING] found {diff_idx.shape[0] * 100 / ref.numel()}% of the elements are different" + ) + continue + print( + "diff: ", 
+ diff_idx.shape, + diff_idx, + ref[diff_mask], + new[diff_mask], + torch.max(torch.abs(ref - new)), + ) raise ValueError( f"Mamba state is not equal for key: {key} at index {i}" ) return True +@dataclass +class StepActions: + scheduled_tokens: int + preprocess_copy_idx: int + postprocess_copy_idx: int + + @dataclass class TestConfig: num_prompt_tokens: int @@ -212,8 +262,8 @@ class TestConfig: def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_LIGHTER_MAMBA_CACHE", "1") monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - num_generated_tokens = 1000 - num_prompt_tokens = 500 + num_generated_tokens = 50 + num_prompt_tokens = 551 sampling_params = SamplingParams(temperature=0.0, max_tokens=num_generated_tokens) full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() fake_sample_fn = get_fake_sample_fn() @@ -240,6 +290,9 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): prompt_token_ids = engine.get_tokenizer().encode(full_prompt) # print(f"Token IDs: {token_ids}") print(f"Token IDs length: {len(prompt_token_ids)}") + print( + f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" + ) outputs = engine.generate( [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index eb571144a12c..88dbeac38e15 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2676,7 +2676,7 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): scheduler_output.finished_req_ids, scheduler_output.preempted_req_ids ): self.mamba_state_idx.pop(req_id, None) - for req_id in self.input_batch.req_ids: + for i, req_id in enumerate(self.input_batch.req_ids): if is_global_first_rank(): logger.info(f">>> [DEBUG] Worker: preprocess mamba for RUN: {req_id=}") req_state = self.requests[req_id] @@ -2694,23 +2694,17 @@ def 
_preprocess_mamba(self, scheduler_output: "SchedulerOutput"): f">>> [DEBUG] Worker: preprocess mamba: {req_id=}, idx {prev_state_idx=} -> {curr_state_idx=}" ) self.mamba_state_idx[req_id] = curr_state_idx - if prev_state_idx == -1 or prev_state_idx == curr_state_idx: - # no need to copy - continue - for mamba_group_id in mamba_group_ids: - prev_block_id = req_state.block_ids[mamba_group_id][prev_state_idx] - curr_block_id = req_state.block_ids[mamba_group_id][curr_state_idx] - assert prev_block_id != 0 - assert curr_block_id != 0 - if is_global_first_rank(): - logger.info( - f">>> [DEBUG] Worker: preprocess mamba: {req_id=}, COPY block {prev_block_id=} -> {curr_block_id=}" + if prev_state_idx != -1 and prev_state_idx != curr_state_idx: + # TODO: merge all these lines to copy_block + for mamba_group_id in mamba_group_ids: + self._mamba_copy_block_for_qwen_next( + self.kv_cache_config.kv_cache_groups[mamba_group_id], + prev_state_idx, + curr_state_idx, + self.input_batch.num_accepted_tokens_cpu[i] - 1, + req_state.block_ids[mamba_group_id], ) - self._mamba_copy_block( - self.kv_cache_config.kv_cache_groups[mamba_group_id], - prev_block_id, - curr_block_id, - ) + self.input_batch.num_accepted_tokens_cpu[i] = 1 def _mamba_copy_block_for_qwen_next( self, @@ -2745,7 +2739,7 @@ def _mamba_copy_block_for_qwen_next( and layer_name == kv_cache_group_spec.layer_names[0] ): logger.info( - f">>> [DEBUG] Worker: mamba_copy_block_for_qwen_next: {layer_name=}, conv {conv_state_block_id=} -> {dest_block_id=} with bias {accept_token_bias}, {gdn_state_block_id=} -> {dest_block_id=}" + f">>> [DEBUG] Worker: mamba_copy_block_for_qwen_next: {layer_name=}, idx {src_block_idx=} -> {dest_block_idx=} conv {conv_state_block_id=} -> {dest_block_id=} with bias {accept_token_bias}, {gdn_state_block_id=} -> {dest_block_id=}" ) def _postprocess_mamba(self, scheduler_output: "SchedulerOutput"): @@ -2794,7 +2788,7 @@ def _postprocess_mamba(self, scheduler_output: "SchedulerOutput"): ) if 
is_global_first_rank(): logger.info( - f">>> [DEBUG] Worker: postprocess mamba: {req_id=}, {src_block_idx=} -> {dest_block_idx=} with bias {accept_token_bias}" + f">>> [DEBUG] Worker: postprocess mamba copy: {req_id=}, {src_block_idx=} -> {dest_block_idx=} with bias {accept_token_bias}" ) for mamba_group_id in mamba_group_ids: self._mamba_copy_block_for_qwen_next( From 6d810c269c79dfa13436f5f58b5d7fc68fc91a46 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 10 Dec 2025 23:30:33 -0800 Subject: [PATCH 038/130] test decode Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 373 ++++++++++++++++++++++-- vllm/v1/worker/gpu_model_runner.py | 82 +++--- 2 files changed, 384 insertions(+), 71 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 9ce688371a72..ea55360856fa 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -10,20 +10,37 @@ from vllm import LLM, SamplingParams, TokensPrompt from vllm.sequence import IntermediateTensors from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.engine.core_client import InprocClient from vllm.v1.outputs import SamplerOutput +from vllm.v1.request import Request from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.utils import get_mamba_groups + +@dataclass +class StepAction: + num_computed_tokens_start: int + num_scheduled_tokens: int + kv_cache_block_ids: list[int] | None # [] to follow last step + preprocess_copy_idx: tuple[int, int] | None # -1, -1 for no copy + postprocess_copy_idx: tuple[int, int] | None # -1, -1 for no copy + + 
num_speculative_tokens = 3 -num_accepted_tokens = 4 +num_accepted_tokens = 1 prompt_token_ids = [] MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" BLOCK_SIZE = 560 NUM_HIDDEN_LAYERS = 8 +cur_step_action_idx = 0 +cur_step_action: StepAction | None = None +step_actions: list[StepAction] = [] def get_fake_sample_fn() -> SamplerOutput: @@ -122,6 +139,55 @@ def fake_propose_draft_token_ids_fn( return fake_propose_draft_token_ids_fn +def get_fake_step_action_fn(original_step_action_fn: Callable): + def fake_get_output(self: InprocClient): + global cur_step_action_idx + global cur_step_action + if cur_step_action_idx < len(step_actions): + cur_step_action = step_actions[cur_step_action_idx] + cur_step_action_idx += 1 + else: + cur_step_action = None + print(f"fake_get_output: {cur_step_action_idx=} {cur_step_action=}") + return original_step_action_fn(self) + + return fake_get_output + + +def get_fake_allocate_slots_fn(original_allocate_slots_fn: Callable): + def fake_allocate_slots_fn( + self: KVCacheManager, + request: Request, + num_new_tokens: int, + num_new_computed_tokens: int = 0, + new_computed_blocks: KVCacheBlocks | None = None, + num_lookahead_tokens: int = 0, + delay_cache_blocks: bool = False, + num_encoder_tokens: int = 0, + ): + ret = original_allocate_slots_fn( + self, + request, + num_new_tokens, + num_new_computed_tokens, + new_computed_blocks, + num_lookahead_tokens, + delay_cache_blocks, + num_encoder_tokens, + ) + if cur_step_action is not None: + print("[UNIT TEST STEP] verifying kv_cache_block_ids") + cur_block_ids = self.coordinator.single_type_managers[0].req_to_blocks[ + request.request_id + ] + not_null_block = [not block.is_null for block in cur_block_ids] + not_null_block = [1 if block else 0 for block in not_null_block] + assert not_null_block == cur_step_action.kv_cache_block_ids + return ret + + return fake_allocate_slots_fn + + mamba_kv_cache_dict = {} @@ -133,6 +199,12 @@ def fake_execute_model_fn( scheduler_output: SchedulerOutput, 
intermediate_tensors: IntermediateTensors | None = None, ): + if cur_step_action is not None: + num_scheduled_tokens = next( + iter(scheduler_output.num_scheduled_tokens.values()) + ) + assert num_scheduled_tokens == cur_step_action.num_scheduled_tokens + print("[UNIT TEST STEP] verified num_scheduled_tokens") mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config) mamba_group_id = mamba_group_ids[0] mamba_layer_name = self.kv_cache_config.kv_cache_groups[ @@ -172,11 +244,70 @@ def fake_execute_model_fn( ret = original_execute_model_fn(self, scheduler_output, intermediate_tensors) + if cur_step_action is not None: + assert ( + cur_step_action.num_computed_tokens_start + == self.input_batch.num_computed_tokens_cpu[0].item() + ) + print("[UNIT TEST STEP] verified num_computed_tokens_start") + return ret return fake_execute_model_fn +def get_fake_process_mamba_fn( + original_preprocess_mamba_fn: Callable, + original_post_process_mamba_fn: Callable, + original_copy_fn: Callable, +): + copy_info = (-1, -1) + + def fake_preprocess_mamba_fn( + self: GPUModelRunner, scheduler_output: SchedulerOutput + ): + nonlocal copy_info + copy_info = (-1, -1) + ret = original_preprocess_mamba_fn(self, scheduler_output) + if cur_step_action is not None: + print("[UNIT TEST STEP] verifying preprocess_copy_idx") + assert copy_info == cur_step_action.preprocess_copy_idx + return ret + + def fake_post_process_mamba_fn( + self: GPUModelRunner, scheduler_output: SchedulerOutput + ): + nonlocal copy_info + copy_info = (-1, -1) + ret = original_post_process_mamba_fn(self, scheduler_output) + if cur_step_action is not None: + print("[UNIT TEST STEP] verifying postprocess_copy_idx") + assert copy_info == cur_step_action.postprocess_copy_idx + return ret + + def fake_copy_fn( + self: GPUModelRunner, + kv_cache_group_ids: list[int], + src_block_idx: int, + dest_block_idx: int, + accept_token_bias: int, + req_state: CachedRequestState, + ): + nonlocal copy_info + assert copy_info == 
(-1, -1) + copy_info = (src_block_idx, dest_block_idx) + return original_copy_fn( + self, + kv_cache_group_ids, + src_block_idx, + dest_block_idx, + accept_token_bias, + req_state, + ) + + return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn + + def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") num_generated_tokens = 20 @@ -243,37 +374,198 @@ def check_mamba_state_equal(mamba_state_ref: dict, mamba_state_new: dict): return True -@dataclass -class StepActions: - scheduled_tokens: int - preprocess_copy_idx: int - postprocess_copy_idx: int - - @dataclass class TestConfig: num_prompt_tokens: int num_generated_tokens: int num_accepted_tokens: int - expect_schedule_tokens: list[int] | None - expect_block_table: list[int] | None + step_actions: list[StepAction] -def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): +def apply_patch(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_LIGHTER_MAMBA_CACHE", "1") monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - num_generated_tokens = 50 - num_prompt_tokens = 551 - sampling_params = SamplingParams(temperature=0.0, max_tokens=num_generated_tokens) - full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() + fake_sample_fn = get_fake_sample_fn() monkeypatch.setattr(GPUModelRunner, "_sample", fake_sample_fn) + fake_propose_draft_token_ids_fn = get_fake_propose_draft_token_ids_fn() monkeypatch.setattr( GPUModelRunner, "propose_draft_token_ids", fake_propose_draft_token_ids_fn ) + fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model) monkeypatch.setattr(GPUModelRunner, "execute_model", fake_execute_model_fn) + + fake_step_action_fn = get_fake_step_action_fn(InprocClient.get_output) + monkeypatch.setattr(InprocClient, "get_output", fake_step_action_fn) + + fake_allocate_slots_fn = get_fake_allocate_slots_fn(KVCacheManager.allocate_slots) + monkeypatch.setattr(KVCacheManager, 
"allocate_slots", fake_allocate_slots_fn) + + fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn = ( + get_fake_process_mamba_fn( + GPUModelRunner._preprocess_mamba, + GPUModelRunner._postprocess_mamba, + GPUModelRunner._mamba_copy_block_for_qwen_next, + ) + ) + monkeypatch.setattr(GPUModelRunner, "_preprocess_mamba", fake_preprocess_mamba_fn) + monkeypatch.setattr( + GPUModelRunner, "_postprocess_mamba", fake_post_process_mamba_fn + ) + monkeypatch.setattr(GPUModelRunner, "_mamba_copy_block_for_qwen_next", fake_copy_fn) + + +def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): + apply_patch(monkeypatch) + full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() + tests = { + # test case 1: no hit, accept 1 token + "accept_1": TestConfig( + num_prompt_tokens=554, + num_generated_tokens=20, + num_accepted_tokens=1, + step_actions=[ + StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(554, 4, [], (-1, -1), (-1, -1)), + StepAction(555, 4, [], (-1, -1), (-1, -1)), + StepAction(556, 4, [], (-1, -1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + StepAction(558, 4, [], (-1, -1), (-1, -1)), + StepAction(559, 4, [], (-1, -1), (1, 0)), + StepAction(560, 4, [], (-1, -1), (-1, -1)), + StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + # test case 2.1: no hit, accept 2 tokens + "accept_2_1": TestConfig( + num_prompt_tokens=554, + num_generated_tokens=20, + num_accepted_tokens=2, + step_actions=[ + StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(554, 4, [], (-1, -1), (-1, -1)), + StepAction(556, 4, [], (-1, -1), (-1, -1)), + StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(560, 4, [], (-1, -1), (-1, -1)), + StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + # test case 2.2: no hit, accept 2 tokens + "accept_2_2": TestConfig( + num_prompt_tokens=555, + num_generated_tokens=20, + num_accepted_tokens=2, + step_actions=[ + StepAction(0, 
555, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(555, 4, [], (-1, -1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + StepAction(559, 4, [], (-1, -1), (1, 0)), + StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_3_1": TestConfig( + num_prompt_tokens=553, + num_generated_tokens=20, + num_accepted_tokens=3, + step_actions=[ + StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(553, 4, [], (-1, -1), (-1, -1)), + StepAction(556, 4, [], (-1, -1), (-1, -1)), + StepAction(559, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_3_2": TestConfig( + num_prompt_tokens=554, + num_generated_tokens=20, + num_accepted_tokens=3, + step_actions=[ + StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(554, 4, [], (-1, -1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(560, 4, [], (-1, -1), (-1, -1)), + StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_3_3": TestConfig( + num_prompt_tokens=555, + num_generated_tokens=20, + num_accepted_tokens=3, + step_actions=[ + StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(555, 4, [], (-1, -1), (-1, -1)), + StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_4_1": TestConfig( + num_prompt_tokens=553, + num_generated_tokens=20, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(553, 4, [], (-1, -1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(565, 4, [], (-1, -1), (-1, -1)), + ], + ), + "accept_4_2": TestConfig( + num_prompt_tokens=554, + num_generated_tokens=25, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(554, 4, [], (-1, -1), 
(-1, -1)), + StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(566, 4, [], (-1, -1), (-1, -1)), + ], + ), + "accept_4_3": TestConfig( + num_prompt_tokens=555, + num_generated_tokens=25, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(555, 4, [], (-1, -1), (-1, -1)), + StepAction(559, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_4_4": TestConfig( + num_prompt_tokens=556, + num_generated_tokens=25, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 556, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(556, 4, [], (-1, -1), (0, 0)), + StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + # "prompt_block_size": TestConfig( + # num_prompt_tokens=560, + # num_generated_tokens=10, + # num_accepted_tokens=4, + # step_actions=[ + # StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (0, 0)), + # StepAction(560, 4, [], (-1, -1), (-1, -1)), + # ], + # ), + # "prompt_2_block_size": TestConfig( + # num_prompt_tokens=560 * 2, + # num_generated_tokens=10, + # num_accepted_tokens=4, + # step_actions=[ + # StepAction(0, 560, [0, 1, 1, 1, 1], (-1, -1), (1, 1)), + # StepAction(560, 4, [], (-1, -1), (-1, -1)), + # ], + # ), + } + engine = LLM( model=MODEL, enable_prefix_caching=True, @@ -290,22 +582,43 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): prompt_token_ids = engine.get_tokenizer().encode(full_prompt) # print(f"Token IDs: {token_ids}") print(f"Token IDs length: {len(prompt_token_ids)}") - print( - f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" - ) - - outputs = engine.generate( - [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], - sampling_params, - ) - print(f"Generated text: {outputs[0].outputs[0].token_ids}") 
- print( - f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" - ) - - torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict_new.pth") mamba_state_ref = torch.load("mamba_kv_cache_dict.pth") - check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict) + for test_case_name, test_config in tests.items(): + print(f"Running test case: {test_case_name}") + num_generated_tokens = test_config.num_generated_tokens + num_prompt_tokens = test_config.num_prompt_tokens + global num_accepted_tokens + num_accepted_tokens = test_config.num_accepted_tokens + sampling_params = SamplingParams( + temperature=0.0, max_tokens=num_generated_tokens + ) + global cur_step_action_idx + cur_step_action_idx = 0 + for step_action_prev, step_action_next in zip( + test_config.step_actions[:-1], test_config.step_actions[1:] + ): + if ( + step_action_next.kv_cache_block_ids is not None + and len(step_action_next.kv_cache_block_ids) == 0 + ): + step_action_next.kv_cache_block_ids = ( + step_action_prev.kv_cache_block_ids.copy() + ) + global step_actions + step_actions = test_config.step_actions + print("step actions: ", step_actions) + print( + f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" + ) + + outputs = engine.generate( + [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], + sampling_params, + ) + assert engine.llm_engine.engine_core.engine_core.scheduler.reset_prefix_cache() + print(f"End test case: {test_case_name}") + check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict) + mamba_kv_cache_dict.clear() def test_check_mamba_state_equal(): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 88dbeac38e15..49999baeaf9c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2696,51 +2696,52 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): self.mamba_state_idx[req_id] = 
curr_state_idx if prev_state_idx != -1 and prev_state_idx != curr_state_idx: # TODO: merge all these lines to copy_block - for mamba_group_id in mamba_group_ids: - self._mamba_copy_block_for_qwen_next( - self.kv_cache_config.kv_cache_groups[mamba_group_id], - prev_state_idx, - curr_state_idx, - self.input_batch.num_accepted_tokens_cpu[i] - 1, - req_state.block_ids[mamba_group_id], - ) + self._mamba_copy_block_for_qwen_next( + mamba_group_ids, + prev_state_idx, + curr_state_idx, + self.input_batch.num_accepted_tokens_cpu[i] - 1, + req_state, + ) self.input_batch.num_accepted_tokens_cpu[i] = 1 def _mamba_copy_block_for_qwen_next( self, - kv_cache_group_spec: KVCacheGroupSpec, + kv_cache_group_ids: list[int], src_block_idx: int, dest_block_idx: int, accept_token_bias: int, - block_ids: list[int], + req_state: CachedRequestState, ): # TODO: general impl for all models if src_block_idx == dest_block_idx and accept_token_bias == 0: return forward_context = self.compilation_config.static_forward_context - dest_block_id = block_ids[dest_block_idx] - for layer_name in kv_cache_group_spec.layer_names: - kv_caches: list[list[torch.Tensor]] = forward_context[layer_name].kv_cache[ - 0 - ] - conv_state, gdn_state = kv_caches - # conv state - conv_state_block_id = block_ids[src_block_idx] - src_conv_state = conv_state[conv_state_block_id][accept_token_bias:] - dest_conv_state = conv_state[dest_block_id] - dest_conv_state[: len(src_conv_state)].copy_(src_conv_state.clone()) - # gdn state - gdn_state_block_id = block_ids[src_block_idx + accept_token_bias] - src_gdn_state = gdn_state[gdn_state_block_id] - dest_gdn_state = gdn_state[dest_block_id] - dest_gdn_state.copy_(src_gdn_state) - if ( - is_global_first_rank() - and layer_name == kv_cache_group_spec.layer_names[0] - ): - logger.info( - f">>> [DEBUG] Worker: mamba_copy_block_for_qwen_next: {layer_name=}, idx {src_block_idx=} -> {dest_block_idx=} conv {conv_state_block_id=} -> {dest_block_id=} with bias {accept_token_bias}, 
{gdn_state_block_id=} -> {dest_block_id=}" - ) + for kv_cache_group_id in kv_cache_group_ids: + block_ids = req_state.block_ids[kv_cache_group_id] + dest_block_id = block_ids[dest_block_idx] + layer_names = self.kv_cache_config.kv_cache_groups[ + kv_cache_group_id + ].layer_names + for layer_name in layer_names: + kv_caches: list[list[torch.Tensor]] = forward_context[ + layer_name + ].kv_cache[0] + conv_state, gdn_state = kv_caches + # conv state + conv_state_block_id = block_ids[src_block_idx] + src_conv_state = conv_state[conv_state_block_id][accept_token_bias:] + dest_conv_state = conv_state[dest_block_id] + dest_conv_state[: len(src_conv_state)].copy_(src_conv_state.clone()) + # gdn state + gdn_state_block_id = block_ids[src_block_idx + accept_token_bias] + src_gdn_state = gdn_state[gdn_state_block_id] + dest_gdn_state = gdn_state[dest_block_id] + dest_gdn_state.copy_(src_gdn_state) + if is_global_first_rank() and layer_name == layer_names[0]: + logger.info( + f">>> [DEBUG] Worker: mamba_copy_block_for_qwen_next: {layer_name=}, idx {src_block_idx=} -> {dest_block_idx=} conv {conv_state_block_id=} -> {dest_block_id=} with bias {accept_token_bias}, {gdn_state_block_id=} -> {dest_block_id=}" + ) def _postprocess_mamba(self, scheduler_output: "SchedulerOutput"): """ @@ -2790,14 +2791,13 @@ def _postprocess_mamba(self, scheduler_output: "SchedulerOutput"): logger.info( f">>> [DEBUG] Worker: postprocess mamba copy: {req_id=}, {src_block_idx=} -> {dest_block_idx=} with bias {accept_token_bias}" ) - for mamba_group_id in mamba_group_ids: - self._mamba_copy_block_for_qwen_next( - self.kv_cache_config.kv_cache_groups[mamba_group_id], - src_block_idx, - dest_block_idx, - accept_token_bias, - req_state.block_ids[mamba_group_id], - ) + self._mamba_copy_block_for_qwen_next( + mamba_group_ids, + src_block_idx, + dest_block_idx, + accept_token_bias, + req_state, + ) if src_block_idx == dest_block_idx: num_accepted_tokens_cpu[i] = 1 From 6260f9ecb6cbf4c7a1616ef5012a55eba4584a63 
Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 11 Dec 2025 00:17:41 -0800 Subject: [PATCH 039/130] unit test Signed-off-by: Chen Zhang --- tests/v1/e2e/input.txt | 199 ++++++++++++++++++++++++ tests/v1/e2e/test_mamba_prefix_cache.py | 85 +++++++--- 2 files changed, 259 insertions(+), 25 deletions(-) create mode 100644 tests/v1/e2e/input.txt diff --git a/tests/v1/e2e/input.txt b/tests/v1/e2e/input.txt new file mode 100644 index 000000000000..b10b1bbcba2a --- /dev/null +++ b/tests/v1/e2e/input.txt @@ -0,0 +1,199 @@ +# The Architecture of Intelligence: A Deep Dive into Large Language Models (LLMs) + +## Introduction: The New Cognitive Revolution + +In the annals of computing history, few technologies have burst onto the global stage with the same immediate and transformative impact as Large Language Models (LLMs). Emerging from the confluence of decades of theoretical research and the exponential growth of computational power and data, LLMs like GPT, Gemini, and Claude have transitioned Artificial Intelligence (AI) from a niche academic pursuit to the central utility of the digital age. + +An LLM is not merely a sophisticated piece of software; it is a complex, deep neural network designed to understand, process, and generate human language with startling fluency, coherence, and context. These models serve as the probabilistic engines of a new cognitive revolution, capable of tasks that range from synthesizing vast datasets and translating languages to creating novel code and engaging in philosophical debate. + +This comprehensive article explores the complete landscape of Large Language Models. We will trace their historical lineage, demystify the revolutionary architecture upon which they are built, detail the arduous training process, analyze the emergent capabilities and inherent flaws, survey their massive commercial and social applications, and, finally, grapple with the profound ethical and strategic challenges they pose for the future of humanity. 
+ +## Part I: The Historical Foundations of Language Modeling + +The concept of a machine generating human language has a history far longer than the digital computer. Its modern journey, however, can be segmented into distinct eras, each overcoming the limitations of the last. + +### 1. Statistical Language Models (1980s – 2000s) +The earliest forms of language modeling were rooted in statistics and probability theory. These were dominated by **n-gram models**, inspired by the mathematical work of Andrey Markov. An n-gram model predicts the probability of the next word ($w_i$) based solely on the previous $n-1$ words ($w_{i-(n-1)}, \dots, w_{i-1}$). + +$$P(w_i | w_{1}^{i-1}) \approx P(w_i | w_{i-(n-1)}^{i-1})$$ + +These models were simple, explainable, and formed the backbone of early machine translation and speech recognition systems, notably pioneering corpus-based language modeling at IBM. However, they suffered from **the curse of dimensionality** and **data sparsity**. As $n$ increased (to capture more context), the number of possible word sequences grew exponentially, making it impossible to accurately estimate probabilities for sequences not seen in the training data. + +### 2. Neural Language Models and Deep Learning (2000s – 2017) +The transition from statistical methods to neural networks addressed the data sparsity problem. The breakthrough came with the introduction of **word embeddings** (pioneered by Bengio in 2003, and popularized by Word2Vec in 2013). + +Instead of treating words as discrete, independent symbols, word embeddings represent each word as a dense, real-valued vector in a multi-dimensional space. Words with similar meanings (e.g., "King," "Queen," "Man," "Woman") are mapped closer together in this geometric space. This allowed the models to generalize, moving beyond simple word co-occurrence to semantic relationships. 
+ +The workhorse of this era was the **Recurrent Neural Network (RNN)**, particularly the **Long Short-Term Memory (LSTM)** network. RNNs process sequences word-by-word, maintaining a "hidden state" or "memory cell" that accumulates information from the previous steps. This allowed them to handle longer-term dependencies than n-gram models. However, the sequential nature of RNNs created two major issues: +1. **Slow Training:** Processing must be strictly sequential, preventing the use of modern parallel computing hardware like GPUs. +2. **Vanishing/Exploding Gradients:** For very long sequences, the error signals used during training (gradients) either vanished (making the model forget the beginning of the text) or exploded (making training unstable). + +### 3. The Attention Mechanism (2014) +The first true step toward the LLM revolution was the introduction of the **Attention Mechanism** in 2014. Used initially within RNN-based encoder-decoder architectures (the basis of Google Translate at the time), attention allowed the model to dynamically weigh the importance of different parts of the input sequence when generating a specific part of the output. This was crucial for tasks like translation, where the most relevant input word might not be the adjacent one. + +## Part II: The Transformer Architecture (2017 - Present) + +The year 2017 marks the true beginning of the LLM era with the publication of "Attention Is All You Need" by researchers at Google. This paper proposed the **Transformer** architecture, which jettisoned recurrence entirely and relied *only* on the attention mechanism. + +### The Encoder-Decoder Foundation +The original Transformer model consists of two main stacks: an **Encoder** and a **Decoder**. +* **Encoder:** Processes the input sequence (e.g., an English sentence), creating a robust, context-aware numerical representation of it. 
+* **Decoder:** Takes the Encoder's output and iteratively generates the output sequence (e.g., the French translation). + +### The Self-Attention Breakthrough +The core innovation is **Self-Attention**. It allows the model to calculate how much every word in the input sequence relates to every other word *within that same sequence*. This is done through a mathematical process involving three vector representations for each input token: + +1. **Query ($Q$):** Represents the token being processed—the question being asked. +2. **Key ($K$):** Represents all other tokens—the information that can be searched. +3. **Value ($V$):** Represents the actual information content of all other tokens. + +The model computes the dot product of the $Q$ vector with all $K$ vectors to get **attention scores**. These scores, after normalization (using a Softmax function), determine how much of the $V$ vectors should be aggregated to create the new, context-rich representation of the original token. + +$$\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$ + +This allows the model to achieve **parallel processing**. Unlike sequential RNNs, every word's vector representation can be calculated simultaneously, leveraging the massive parallel capabilities of GPUs and leading to unprecedented scalability. + +### Positional Encoding +Since the Transformer has no inherent recurrence (no left-to-right reading), the model needs a way to know the order of the words. This is solved by **Positional Encoding**—adding a vector to the input embeddings that contains information about the word’s absolute or relative position in the sequence. Without this, the phrase "Dog bites man" would be processed identically to "Man bites dog." + +### Model Variants: BERT vs. GPT +The Transformer architecture gave rise to three major model families: + +1. 
**Encoder-Only (e.g., BERT, RoBERTa):** Used primarily for *understanding* tasks (classification, named entity recognition, sentiment analysis). They are excellent at bidirectional context (looking both backward and forward in a sentence). +2. **Decoder-Only (e.g., GPT, Llama):** Used primarily for *generation* tasks. The decoder is constrained by a **causal mask** that prevents it from looking at future tokens, forcing it to generate text sequentially, word-by-word. These models have become the dominant architecture for conversational AI. +3. **Encoder-Decoder (e.g., T5, BART):** Used for sequence-to-sequence tasks like translation and summarization. + +## Part III: The Training Lifecycle of an LLM + +The development of an LLM is a complex, multi-stage process involving massive computational resources, vast data curation efforts, and sophisticated human intervention. + +### 1. Data Curation and Tokenization +The first step is gathering and cleaning the training corpus. Modern LLMs are trained on hundreds of terabytes or even petabytes of text, often sourced from: +* **CommonCrawl:** A massive, open-source scrape of the public internet. +* **Filtered Web Text:** Highly curated, higher-quality web pages. +* **Books and Literature:** Digitized libraries. +* **Code Repositories:** Such as GitHub, to instill programming knowledge. +* **Wikipedia:** Structured knowledge bases. + +Data is meticulously filtered to remove low-quality content, boilerplate text, and offensive material. The text is then broken down into **tokens** using a process like **Byte-Pair Encoding (BPE)**. Tokens are the minimal units of meaning the model processes, bridging the gap between human language and numerical vectors. + +### 2. Pre-Training: Self-Supervised Learning +The core of LLM training is the **Pre-Training** phase. The model's hundreds of billions of parameters are initialized, and it is fed the massive, unlabeled dataset. 
The primary objective is **Next-Token Prediction** (or autoregressive modeling): predicting the next most probable token in a sequence, given all previous tokens. + +* **Objective Function:** The model minimizes the **Loss Function** (often **Cross-Entropy Loss**), which measures the difference between the model's predicted probability distribution over the vocabulary and the actual next token. +* **Optimization:** The model iteratively adjusts its weights using **Backpropagation** and an **Optimizer** (e.g., Adam or its variants) to reduce this loss. + +This phase, costing millions of dollars in GPU time, imbues the model with its fundamental knowledge base, grammar, syntax, and a basic, structural understanding of the world. It is through this pure statistical exercise that "reasoning" begins to emerge. + +### 3. Fine-Tuning and Alignment +A raw pre-trained model is highly knowledgeable but often unhelpful and potentially toxic. It will simply continue the statistical pattern of the input, regardless of intent. Alignment is the process of making the model follow instructions and adhere to ethical guidelines. + +#### A. Supervised Fine-Tuning (SFT) +The model is trained on a smaller, high-quality, human-curated dataset of prompts and desired, high-quality responses. This teaches the model a conversational style—how to act as an assistant, answer questions, and follow complex directions. + +#### B. Reinforcement Learning from Human Feedback (RLHF) +RLHF is the key component that created the conversational brilliance of models like ChatGPT. +1. **Response Generation:** For a given prompt, the LLM generates several possible answers. +2. **Human Ranking:** Human labelers rank these responses from best to worst based on helpfulness, accuracy, and safety. +3. **Reward Model Training:** A separate, smaller model called the **Reward Model (RM)** is trained to predict the human preference score for any response. The RM effectively learns "what a good answer looks like." 
+4. **Policy Optimization:** The main LLM is then fine-tuned using a Reinforcement Learning algorithm (like **Proximal Policy Optimization, PPO**) to maximize the score given by the Reward Model. + +This process explicitly aligns the model's objective function with human values, a crucial step in preparing the model for public deployment. + +## Part IV: Emergent Capabilities and Inherent Limitations + +The path from a neural network to a cognitive tool is marked by phenomena that both inspire awe and caution. + +### The Phenomenon of Emergence +As LLMs crossed certain thresholds—specifically in parameter count (size) and training data volume—researchers observed **Emergent Capabilities**. These are skills that the model was never explicitly trained for, yet they appear spontaneously. + +* **In-Context Learning (ICL):** The ability to learn a new task from a few examples provided directly in the prompt, without needing formal fine-tuning (Few-Shot Learning). +* **Chain-of-Thought (CoT) Reasoning:** The ability to decompose complex, multi-step problems into sequential reasoning steps, often unlocked by simply telling the model to "think step-by-step." This dramatically improves performance on arithmetic, common sense, and symbolic logic tasks. +* **Multilingual and Code Proficiency:** Models trained primarily on English and code surprisingly develop high-level proficiency in dozens of other languages and complex programming languages. + +These emergent properties suggest that the simple task of next-token prediction, when scaled sufficiently, leads to a kind of generalized, implicit world model—a probabilistic simulation of human knowledge and reasoning. + +### The Challenge of Hallucination +The most significant and stubborn limitation of LLMs is **Hallucination**—the generation of factually incorrect, nonsensical, or unfaithful content that is nevertheless syntactically plausible. 
+ +The root cause lies in the model's core function: it is a **prediction engine, not a retrieval engine**. It does not access an external database of facts; it samples the most statistically likely sequence of tokens based on its internal, compressed world model. If the highest-probability sequence *looks* like a scientific citation but is entirely fabricated, the model will generate it. + +Mitigation strategies, such as **Retrieval-Augmented Generation (RAG)**, which links the LLM to a real-time, verifiable external knowledge source (like a search index or a company database), are essential for using LLMs in high-stakes, fact-based applications. + +## Part V: The Expanding Ecosystem and Applications + +The LLM ecosystem is diversifying rapidly, moving beyond the simple "chatbot" into powerful, specialized tools. + +### 1. Model Scaling and Efficiency +The pursuit of ever-larger models has reached its limits due to cost and data scarcity. The frontier has shifted to efficiency and specialization. +* **Mixture-of-Experts (MoE):** Models like Mixtral use a routing mechanism to activate only a subset of specialized "expert" neural networks for any given query. This allows the model to have a massive total parameter count (high knowledge capacity) while only using a fraction of the computational power (high efficiency). +* **Quantization and Pruning:** Techniques used to reduce the size and computational demands of models, making them executable on smaller devices (e.g., a mobile phone or a personal laptop). + +### 2. Multimodality +The most significant recent breakthrough is the transition from LLMs (Large Language Models) to **LMMs (Large Multimodal Models)**. These models are trained not just on text, but also on images, audio, and video data, allowing them to: +* **Visual Reasoning:** Analyze a complex graph, a photograph, or a technical diagram and answer questions about its content. 
+* **Audio Processing:** Transcribe, summarize, and understand the context of spoken language directly. +* **Seamless Integration:** Accept a prompt containing text and an image simultaneously (e.g., "Describe this image and write a poem about it"). + +### 3. Industry Applications +LLMs are no longer experimental; they are becoming foundational infrastructure across nearly every industry: +* **Software Engineering:** Automated code generation (e.g., GitHub Copilot), debugging, code translation between languages, and writing documentation. +* **Knowledge Work & Productivity:** Summarizing long documents, drafting complex reports, synthesizing research, and managing data from unstructured sources. +* **Customer Service & Sales:** Highly personalized and efficient conversational AI bots that can handle complex queries beyond simple FAQs. +* **Medicine and Law:** Assisting in drafting legal briefs, summarizing medical records, and cross-referencing diagnostic information (always requiring human oversight). +* **Creative Arts:** Generating marketing copy, scriptwriting, music composition (in conjunction with other AI models), and video production assets. + +## Part VI: The Ethical and Societal Labyrinth + +The power of LLMs brings with it a commensurately large set of ethical, social, and economic risks that demand global governance and responsible development. + +### 1. Bias, Fairness, and Amplification +LLMs are fundamentally statistical mirrors of their training data. If the internet contains biases related to gender, race, or geography, the model will ingest, amplify, and operationalize those biases. +* **Stereotype Reinforcement:** A model might associate certain professions (e.g., "engineer") predominantly with one gender, leading to biased outputs in hiring tools. +* **Harmful Generalizations:** Biases can lead to unfair or discriminatory decision-making when the models are deployed in high-stakes areas like loan applications or judicial risk assessment. 
+Mitigating bias requires meticulous data curation, adversarial testing, and post-processing "guardrails," but complete elimination remains technically elusive. + +### 2. Misinformation and Disinformation +The ability of LLMs to generate highly convincing, fluent text at scale is a threat to information integrity. Malicious actors can use these tools to: +* **Automate Phishing and Scams:** Generate personalized, sophisticated deceptive content. +* **Create Deepfake Text:** Impersonate real individuals or organizations with convincing prose. +* **Fabricate "Fake News" and Propaganda:** Generate massive volumes of highly plausible, factually false content, overwhelming traditional fact-checking mechanisms and accelerating the breakdown of public trust. + +### 3. Data Privacy and Security +LLMs pose risks related to data ingestion and leakage: +* **Training Data Memorization:** Models can, in rare cases, memorize and regurgitate personally identifiable information (PII) or copyrighted material from their vast training corpus. +* **Inference Attack (Data Leakage):** If a user provides proprietary or sensitive information as a prompt, that data may be inadvertently used to train future iterations of the model or leak through side channels, raising major security concerns for enterprise adoption. + +### 4. Environmental Impact +The scale of LLMs has a significant environmental footprint. Training a single frontier model requires months of continuous operation on thousands of GPUs, consuming energy equivalent to that used by hundreds of homes in a year. The high computational cost raises questions about the long-term sustainability and equitable access to the technology. + +### 5. Economic Disruption and Labor +LLMs are directly impacting knowledge-based professions, particularly those involving content creation, data synthesis, and routine communication. 
While optimists argue the technology will mostly automate mundane tasks, freeing humans for higher-level work, policymakers and economists are grappling with the reality of rapid job displacement, income inequality, and the need for massive reskilling initiatives. + +## Part VII: The Frontier—The Path to Agentic AI and AGI + +The current state of the art is fleeting. The research community is pushing toward systems that are more autonomous, capable, and integrated. + +### 1. Agentic AI +The shift from a "Chatbot" to an "Agent" is the immediate future. Current LLMs are **reactive** (Question $\rightarrow$ Answer). An Agentic LLM is **proactive and goal-oriented**. +* **Goal:** The user provides a high-level goal (e.g., "Find the cheapest flight to Tokyo next month and book a hotel near the Shinjuku station."). +* **Planning:** The LLM breaks the goal into sub-tasks (Search flights, Compare prices, Search hotels, Check availability, Execute booking actions). +* **Tool Use:** The LLM integrates external tools (search engines, flight APIs, email/calendar APIs) to complete the tasks autonomously, engaging in a trial-and-error loop until the goal is achieved. This transforms the LLM from a generator of text into an executor of complex, multi-step actions. + +### 2. The Multi-Agent Ecosystem +The next stage involves creating swarms of specialized LLM Agents that communicate and collaborate to solve enormous, non-trivial problems. One agent might be a "researcher," another a "coder," and a third an "editor," all collaborating on a project, mimicking a human team. + +### 3. The Pursuit of Artificial General Intelligence (AGI) +The ultimate horizon is Artificial General Intelligence—a machine with the capacity to understand, learn, and apply its intelligence to solve virtually any problem that a human can. 
+ +The debate remains: Is the current path of massive scaling and improved architecture (the **scaling hypothesis**) sufficient to reach AGI, or is some fundamental, non-Transformer-based innovation required? The appearance of emergent properties strongly suggests that the scaling path has not yet exhausted its potential, keeping the AGI goal within the sights of major research labs. + +## Conclusion: The Mirror of Human Intelligence + +Large Language Models are perhaps the most profound technological platform shift since the invention of the Internet. They represent the culmination of 75 years of AI research, transitioning from rule-based systems and statistical models to the deep, parallel processing power of the Transformer architecture. + +LLMs are the definitive statistical compressors of human knowledge, capable of synthesizing our collective digital output with stunning fidelity. They have unlocked a new era of computational creativity and efficiency, driving unprecedented change across every sector. + +Yet, this power is a double-edged sword. LLMs are not inherently wise; they are merely proficient at pattern matching. They reflect and amplify human biases, they can deceive with convincing misinformation, and they introduce profound questions about accountability, labor, and the nature of creative work. + +The future of LLMs is not just about making them *smarter*, but making them *safer*, *more efficient*, and more *aligned* with human values. The challenge for the coming decade is not technical—the algorithms and compute will continue to improve—but **one of governance and ethics**. Humanity must learn to responsibly wield this powerful mirror of its own intelligence, ensuring that the cognitive revolution we have started leads to a future of prosperity and equitable access, rather than fragmentation and control. The architecture of intelligence is now in our hands; the path forward depends on the wisdom of its design and deployment. 
\ No newline at end of file diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index ea55360856fa..9cf62bd8b587 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -233,12 +233,16 @@ def fake_execute_model_fn( .block_table.cpu[0, block_idx] .item() ) - kv_cache = self.compilation_config.static_forward_context[ - mamba_layer_name - ].kv_cache - mamba_kv_cache_dict[ - num_computed_tokens - num_computed_tokens % BLOCK_SIZE - ] = (kv_cache[0][0][block_id].clone(), kv_cache[0][1][block_id].clone()) + if block_id != 0: + kv_cache = self.compilation_config.static_forward_context[ + mamba_layer_name + ].kv_cache + mamba_kv_cache_dict[ + num_computed_tokens - num_computed_tokens % BLOCK_SIZE + ] = ( + kv_cache[0][0][block_id].clone(), + kv_cache[0][1][block_id].clone(), + ) last_num_computed_tokens = num_computed_tokens @@ -421,7 +425,6 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): apply_patch(monkeypatch) full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() tests = { - # test case 1: no hit, accept 1 token "accept_1": TestConfig( num_prompt_tokens=554, num_generated_tokens=20, @@ -546,24 +549,55 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), ], ), - # "prompt_block_size": TestConfig( - # num_prompt_tokens=560, - # num_generated_tokens=10, - # num_accepted_tokens=4, - # step_actions=[ - # StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (0, 0)), - # StepAction(560, 4, [], (-1, -1), (-1, -1)), - # ], - # ), - # "prompt_2_block_size": TestConfig( - # num_prompt_tokens=560 * 2, - # num_generated_tokens=10, - # num_accepted_tokens=4, - # step_actions=[ - # StepAction(0, 560, [0, 1, 1, 1, 1], (-1, -1), (1, 1)), - # StepAction(560, 4, [], (-1, -1), (-1, -1)), - # ], - # ), + "prompt_block_size": TestConfig( + num_prompt_tokens=560, + num_generated_tokens=10, + num_accepted_tokens=4, + 
step_actions=[ + StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (0, 0)), + StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + ], + ), + "prompt_2_block_size": TestConfig( + num_prompt_tokens=560 * 2, + num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (0, 0)), + StepAction(560, 560, [1, 1, 1, 1, 1], (0, 1), (1, 1)), + StepAction(560 * 2, 4, [0, 1, 1, 1, 1, 1], (1, 2), (-1, -1)), + ], + ), + "prompt_2_block_size_10": TestConfig( + num_prompt_tokens=560 * 2 + 10, + num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (0, 0)), + StepAction(560, 570, [1, 0, 1, 1, 1, 1], (0, 2), (-1, -1)), + StepAction(560 * 2 + 10, 4, [0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "prompt_3_block_size": TestConfig( + num_prompt_tokens=560 * 3, + num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (1, 1)), + StepAction(560 * 2, 560, [0, 1, 1, 1, 1, 1], (1, 2), (2, 2)), + StepAction(560 * 3, 4, [0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "prompt_3_block_size_10": TestConfig( + num_prompt_tokens=560 * 3 + 10, + num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (1, 1)), + StepAction(560 * 2, 570, [0, 1, 1, 1, 1, 1], (1, 2), (2, 2)), + StepAction(560 * 3 + 10, 4, [0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), } engine = LLM( @@ -575,6 +609,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): "method": "qwen3_next_mtp", "num_speculative_tokens": num_speculative_tokens, }, + max_num_batched_tokens=3072, hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS}, seed=42, ) From ebc614999b534fed399a7acd842ca357bdfdcedc Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 11 Dec 2025 00:21:52 -0800 Subject: [PATCH 040/130] revert Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 10 
+++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 9cf62bd8b587..6c91e44c6bbf 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -314,8 +314,8 @@ def fake_copy_fn( def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - num_generated_tokens = 20 - num_prompt_tokens = 551 + num_generated_tokens = 4000 + num_prompt_tokens = 500 sampling_params = SamplingParams(temperature=0.0, max_tokens=num_generated_tokens) full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model) @@ -342,9 +342,9 @@ def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" ) print(f"mamba_kv_cache_dict: {mamba_kv_cache_dict.keys()}") - ref_mamba_kv_cache_dict = torch.load("mamba_kv_cache_dict.pth") - check_mamba_state_equal(ref_mamba_kv_cache_dict, mamba_kv_cache_dict) - # torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict.pth") + # ref_mamba_kv_cache_dict = torch.load("mamba_kv_cache_dict.pth") + # check_mamba_state_equal(ref_mamba_kv_cache_dict, mamba_kv_cache_dict) + torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict.pth") def check_mamba_state_equal(mamba_state_ref: dict, mamba_state_new: dict): From da789b29c4d62d8edb1669aa91329077ef0fb120 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 11 Dec 2025 15:44:30 -0800 Subject: [PATCH 041/130] fix Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 263 ++++++++++++------------ 1 file changed, 136 insertions(+), 127 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 6c91e44c6bbf..02166cdfbaeb 100644 --- 
a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -347,10 +347,14 @@ def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict.pth") -def check_mamba_state_equal(mamba_state_ref: dict, mamba_state_new: dict): +def check_mamba_state_equal( + mamba_state_ref: dict, mamba_state_new: dict, keys_to_check: list[int] +): atol = 1e-2 rtol = 1e-2 - for key in mamba_state_new: + for key in keys_to_check: + assert key in mamba_state_new + assert key in mamba_state_ref # mamba state new is a subset of mamba state ref for i, (ref, new) in enumerate(zip(mamba_state_ref[key], mamba_state_new[key])): print("check_mamba_state_equal: ", ref.shape, new.shape) @@ -425,130 +429,130 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): apply_patch(monkeypatch) full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() tests = { - "accept_1": TestConfig( - num_prompt_tokens=554, - num_generated_tokens=20, - num_accepted_tokens=1, - step_actions=[ - StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(554, 4, [], (-1, -1), (-1, -1)), - StepAction(555, 4, [], (-1, -1), (-1, -1)), - StepAction(556, 4, [], (-1, -1), (-1, -1)), - StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), - StepAction(558, 4, [], (-1, -1), (-1, -1)), - StepAction(559, 4, [], (-1, -1), (1, 0)), - StepAction(560, 4, [], (-1, -1), (-1, -1)), - StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - ], - ), - # test case 2.1: no hit, accept 2 tokens - "accept_2_1": TestConfig( - num_prompt_tokens=554, - num_generated_tokens=20, - num_accepted_tokens=2, - step_actions=[ - StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(554, 4, [], (-1, -1), (-1, -1)), - StepAction(556, 4, [], (-1, -1), (-1, -1)), - StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - StepAction(560, 4, [], (-1, -1), (-1, -1)), - StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - ], - ), - # 
test case 2.2: no hit, accept 2 tokens - "accept_2_2": TestConfig( - num_prompt_tokens=555, - num_generated_tokens=20, - num_accepted_tokens=2, - step_actions=[ - StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(555, 4, [], (-1, -1), (-1, -1)), - StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), - StepAction(559, 4, [], (-1, -1), (1, 0)), - StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - ], - ), - "accept_3_1": TestConfig( - num_prompt_tokens=553, - num_generated_tokens=20, - num_accepted_tokens=3, - step_actions=[ - StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(553, 4, [], (-1, -1), (-1, -1)), - StepAction(556, 4, [], (-1, -1), (-1, -1)), - StepAction(559, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - ], - ), - "accept_3_2": TestConfig( - num_prompt_tokens=554, - num_generated_tokens=20, - num_accepted_tokens=3, - step_actions=[ - StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(554, 4, [], (-1, -1), (-1, -1)), - StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - StepAction(560, 4, [], (-1, -1), (-1, -1)), - StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - ], - ), - "accept_3_3": TestConfig( - num_prompt_tokens=555, - num_generated_tokens=20, - num_accepted_tokens=3, - step_actions=[ - StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(555, 4, [], (-1, -1), (-1, -1)), - StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - ], - ), - "accept_4_1": TestConfig( - num_prompt_tokens=553, - num_generated_tokens=20, - num_accepted_tokens=4, - step_actions=[ - StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(553, 4, [], (-1, -1), (-1, -1)), - StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(565, 4, [], (-1, -1), (-1, -1)), - ], - ), - "accept_4_2": TestConfig( - 
num_prompt_tokens=554, - num_generated_tokens=25, - num_accepted_tokens=4, - step_actions=[ - StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(554, 4, [], (-1, -1), (-1, -1)), - StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(566, 4, [], (-1, -1), (-1, -1)), - ], - ), - "accept_4_3": TestConfig( - num_prompt_tokens=555, - num_generated_tokens=25, - num_accepted_tokens=4, - step_actions=[ - StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(555, 4, [], (-1, -1), (-1, -1)), - StepAction(559, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - ], - ), - "accept_4_4": TestConfig( - num_prompt_tokens=556, - num_generated_tokens=25, - num_accepted_tokens=4, - step_actions=[ - StepAction(0, 556, [1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(556, 4, [], (-1, -1), (0, 0)), - StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), - StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - ], - ), + # "accept_1": TestConfig( + # num_prompt_tokens=554, + # num_generated_tokens=20, + # num_accepted_tokens=1, + # step_actions=[ + # StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + # StepAction(554, 4, [], (-1, -1), (-1, -1)), + # StepAction(555, 4, [], (-1, -1), (-1, -1)), + # StepAction(556, 4, [], (-1, -1), (-1, -1)), + # StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + # StepAction(558, 4, [], (-1, -1), (-1, -1)), + # StepAction(559, 4, [], (-1, -1), (1, 0)), + # StepAction(560, 4, [], (-1, -1), (-1, -1)), + # StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + # ], + # ), + # # test case 2.1: no hit, accept 2 tokens + # "accept_2_1": TestConfig( + # num_prompt_tokens=554, + # num_generated_tokens=20, + # num_accepted_tokens=2, + # step_actions=[ + # StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + # StepAction(554, 4, [], (-1, -1), (-1, -1)), + # StepAction(556, 4, [], (-1, -1), (-1, -1)), 
+ # StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + # StepAction(560, 4, [], (-1, -1), (-1, -1)), + # StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + # ], + # ), + # # test case 2.2: no hit, accept 2 tokens + # "accept_2_2": TestConfig( + # num_prompt_tokens=555, + # num_generated_tokens=20, + # num_accepted_tokens=2, + # step_actions=[ + # StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), + # StepAction(555, 4, [], (-1, -1), (-1, -1)), + # StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + # StepAction(559, 4, [], (-1, -1), (1, 0)), + # StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + # ], + # ), + # "accept_3_1": TestConfig( + # num_prompt_tokens=553, + # num_generated_tokens=20, + # num_accepted_tokens=3, + # step_actions=[ + # StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), + # StepAction(553, 4, [], (-1, -1), (-1, -1)), + # StepAction(556, 4, [], (-1, -1), (-1, -1)), + # StepAction(559, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + # StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + # ], + # ), + # "accept_3_2": TestConfig( + # num_prompt_tokens=554, + # num_generated_tokens=20, + # num_accepted_tokens=3, + # step_actions=[ + # StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + # StepAction(554, 4, [], (-1, -1), (-1, -1)), + # StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + # StepAction(560, 4, [], (-1, -1), (-1, -1)), + # StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + # ], + # ), + # "accept_3_3": TestConfig( + # num_prompt_tokens=555, + # num_generated_tokens=20, + # num_accepted_tokens=3, + # step_actions=[ + # StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), + # StepAction(555, 4, [], (-1, -1), (-1, -1)), + # StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + # StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + # ], + # ), + # "accept_4_1": TestConfig( + # num_prompt_tokens=553, + # num_generated_tokens=20, + # num_accepted_tokens=4, + # step_actions=[ + # StepAction(0, 
553, [1, 1, 1, 1], (-1, -1), (-1, -1)), + # StepAction(553, 4, [], (-1, -1), (-1, -1)), + # StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + # StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + # StepAction(565, 4, [], (-1, -1), (-1, -1)), + # ], + # ), + # "accept_4_2": TestConfig( + # num_prompt_tokens=554, + # num_generated_tokens=25, + # num_accepted_tokens=4, + # step_actions=[ + # StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + # StepAction(554, 4, [], (-1, -1), (-1, -1)), + # StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + # StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + # StepAction(566, 4, [], (-1, -1), (-1, -1)), + # ], + # ), + # "accept_4_3": TestConfig( + # num_prompt_tokens=555, + # num_generated_tokens=25, + # num_accepted_tokens=4, + # step_actions=[ + # StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), + # StepAction(555, 4, [], (-1, -1), (-1, -1)), + # StepAction(559, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + # StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + # ], + # ), + # "accept_4_4": TestConfig( + # num_prompt_tokens=556, + # num_generated_tokens=25, + # num_accepted_tokens=4, + # step_actions=[ + # StepAction(0, 556, [1, 1, 1, 1], (-1, -1), (-1, -1)), + # StepAction(556, 4, [], (-1, -1), (0, 0)), + # StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + # StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + # ], + # ), "prompt_block_size": TestConfig( num_prompt_tokens=560, num_generated_tokens=10, @@ -652,7 +656,12 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): ) assert engine.llm_engine.engine_core.engine_core.scheduler.reset_prefix_cache() print(f"End test case: {test_case_name}") - check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict) + keys_to_check = [ + (action.postprocess_copy_idx[1] + 1) * BLOCK_SIZE + for action in test_config.step_actions + if action.postprocess_copy_idx[0] != -1 + ] + check_mamba_state_equal(mamba_state_ref, 
mamba_kv_cache_dict, keys_to_check) mamba_kv_cache_dict.clear() From 77480e432ef76fb92da6be915f26c99a9abb8986 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 11 Dec 2025 16:57:17 -0800 Subject: [PATCH 042/130] add more tests Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 62 +++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 4 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 02166cdfbaeb..6177154b9496 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -245,6 +245,9 @@ def fake_execute_model_fn( ) last_num_computed_tokens = num_computed_tokens + else: + last_num_computed_tokens = 0 + print(f"[UNIT TEST] fake_execute_model_fn: clear last_num_computed_tokens") ret = original_execute_model_fn(self, scheduler_output, intermediate_tensors) @@ -314,7 +317,7 @@ def fake_copy_fn( def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - num_generated_tokens = 4000 + num_generated_tokens = 8000 num_prompt_tokens = 500 sampling_params = SamplingParams(temperature=0.0, max_tokens=num_generated_tokens) full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() @@ -589,7 +592,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): step_actions=[ StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (1, 1)), StepAction(560 * 2, 560, [0, 1, 1, 1, 1, 1], (1, 2), (2, 2)), - StepAction(560 * 3, 4, [0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(560 * 3, 4, [0, 0, 1, 1, 1, 1, 1], (2, 3), (-1, -1)), ], ), "prompt_3_block_size_10": TestConfig( @@ -598,8 +601,59 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): num_accepted_tokens=4, step_actions=[ StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (1, 1)), - StepAction(560 * 2, 570, [0, 1, 1, 1, 1, 1], (1, 2), (2, 2)), - StepAction(560 * 3 + 10, 4, [0, 0, 1, 1, 1, 1], (-1, -1), (-1, 
-1)), + StepAction(560 * 2, 570, [0, 1, 0, 1, 1, 1, 1], (1, 3), (-1, -1)), + StepAction(560 * 3 + 10, 4, [0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "prompt_10_block_size": TestConfig( + num_prompt_tokens=560 * 10, + num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], (-1, -1), (4, 4)), + StepAction( + 560 * 5, + 560 * 4, + [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1], + (4, 8), + (8, 8), + ), + StepAction( + 560 * 9, + 560, + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1], + (8, 9), + (9, 9), + ), + StepAction( + 560 * 10, + 4, + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1], + (9, 10), + (-1, -1), + ), + ], + ), + "prompt_10_block_size_10": TestConfig( + num_prompt_tokens=560 * 10 + 10, + num_generated_tokens=10, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], (-1, -1), (4, 4)), + StepAction( + 560 * 5, + 560 * 4, + [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1], + (4, 8), + (8, 8), + ), + StepAction( + 560 * 9, + 560 + 10, + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1], + (8, 10), + (-1, -1), + ), ], ), } From 247986b563f549873da6ae221ffcb4a5ce63d8d9 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 11 Dec 2025 16:57:38 -0800 Subject: [PATCH 043/130] add more tests Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 248 ++++++++++++------------ 1 file changed, 124 insertions(+), 124 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 6177154b9496..dd171ed524ca 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -432,130 +432,130 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): apply_patch(monkeypatch) full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() tests = { - # "accept_1": TestConfig( - # num_prompt_tokens=554, - # num_generated_tokens=20, - # num_accepted_tokens=1, - # step_actions=[ - # StepAction(0, 
554, [1, 1, 1, 1], (-1, -1), (-1, -1)), - # StepAction(554, 4, [], (-1, -1), (-1, -1)), - # StepAction(555, 4, [], (-1, -1), (-1, -1)), - # StepAction(556, 4, [], (-1, -1), (-1, -1)), - # StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), - # StepAction(558, 4, [], (-1, -1), (-1, -1)), - # StepAction(559, 4, [], (-1, -1), (1, 0)), - # StepAction(560, 4, [], (-1, -1), (-1, -1)), - # StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - # ], - # ), - # # test case 2.1: no hit, accept 2 tokens - # "accept_2_1": TestConfig( - # num_prompt_tokens=554, - # num_generated_tokens=20, - # num_accepted_tokens=2, - # step_actions=[ - # StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), - # StepAction(554, 4, [], (-1, -1), (-1, -1)), - # StepAction(556, 4, [], (-1, -1), (-1, -1)), - # StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - # StepAction(560, 4, [], (-1, -1), (-1, -1)), - # StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - # ], - # ), - # # test case 2.2: no hit, accept 2 tokens - # "accept_2_2": TestConfig( - # num_prompt_tokens=555, - # num_generated_tokens=20, - # num_accepted_tokens=2, - # step_actions=[ - # StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), - # StepAction(555, 4, [], (-1, -1), (-1, -1)), - # StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), - # StepAction(559, 4, [], (-1, -1), (1, 0)), - # StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - # ], - # ), - # "accept_3_1": TestConfig( - # num_prompt_tokens=553, - # num_generated_tokens=20, - # num_accepted_tokens=3, - # step_actions=[ - # StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), - # StepAction(553, 4, [], (-1, -1), (-1, -1)), - # StepAction(556, 4, [], (-1, -1), (-1, -1)), - # StepAction(559, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - # StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - # ], - # ), - # "accept_3_2": TestConfig( - # num_prompt_tokens=554, - # num_generated_tokens=20, - # num_accepted_tokens=3, - # step_actions=[ - # 
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), - # StepAction(554, 4, [], (-1, -1), (-1, -1)), - # StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - # StepAction(560, 4, [], (-1, -1), (-1, -1)), - # StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - # ], - # ), - # "accept_3_3": TestConfig( - # num_prompt_tokens=555, - # num_generated_tokens=20, - # num_accepted_tokens=3, - # step_actions=[ - # StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), - # StepAction(555, 4, [], (-1, -1), (-1, -1)), - # StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - # StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - # ], - # ), - # "accept_4_1": TestConfig( - # num_prompt_tokens=553, - # num_generated_tokens=20, - # num_accepted_tokens=4, - # step_actions=[ - # StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), - # StepAction(553, 4, [], (-1, -1), (-1, -1)), - # StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - # StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - # StepAction(565, 4, [], (-1, -1), (-1, -1)), - # ], - # ), - # "accept_4_2": TestConfig( - # num_prompt_tokens=554, - # num_generated_tokens=25, - # num_accepted_tokens=4, - # step_actions=[ - # StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), - # StepAction(554, 4, [], (-1, -1), (-1, -1)), - # StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - # StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - # StepAction(566, 4, [], (-1, -1), (-1, -1)), - # ], - # ), - # "accept_4_3": TestConfig( - # num_prompt_tokens=555, - # num_generated_tokens=25, - # num_accepted_tokens=4, - # step_actions=[ - # StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), - # StepAction(555, 4, [], (-1, -1), (-1, -1)), - # StepAction(559, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), - # StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - # ], - # ), - # "accept_4_4": TestConfig( - # num_prompt_tokens=556, - # num_generated_tokens=25, - # num_accepted_tokens=4, - # step_actions=[ - # 
StepAction(0, 556, [1, 1, 1, 1], (-1, -1), (-1, -1)), - # StepAction(556, 4, [], (-1, -1), (0, 0)), - # StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), - # StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), - # ], - # ), + "accept_1": TestConfig( + num_prompt_tokens=554, + num_generated_tokens=20, + num_accepted_tokens=1, + step_actions=[ + StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(554, 4, [], (-1, -1), (-1, -1)), + StepAction(555, 4, [], (-1, -1), (-1, -1)), + StepAction(556, 4, [], (-1, -1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + StepAction(558, 4, [], (-1, -1), (-1, -1)), + StepAction(559, 4, [], (-1, -1), (1, 0)), + StepAction(560, 4, [], (-1, -1), (-1, -1)), + StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + # test case 2.1: no hit, accept 2 tokens + "accept_2_1": TestConfig( + num_prompt_tokens=554, + num_generated_tokens=20, + num_accepted_tokens=2, + step_actions=[ + StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(554, 4, [], (-1, -1), (-1, -1)), + StepAction(556, 4, [], (-1, -1), (-1, -1)), + StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(560, 4, [], (-1, -1), (-1, -1)), + StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + # test case 2.2: no hit, accept 2 tokens + "accept_2_2": TestConfig( + num_prompt_tokens=555, + num_generated_tokens=20, + num_accepted_tokens=2, + step_actions=[ + StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(555, 4, [], (-1, -1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + StepAction(559, 4, [], (-1, -1), (1, 0)), + StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_3_1": TestConfig( + num_prompt_tokens=553, + num_generated_tokens=20, + num_accepted_tokens=3, + step_actions=[ + StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(553, 4, [], (-1, -1), (-1, -1)), + StepAction(556, 4, [], (-1, -1), (-1, -1)), 
+ StepAction(559, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_3_2": TestConfig( + num_prompt_tokens=554, + num_generated_tokens=20, + num_accepted_tokens=3, + step_actions=[ + StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(554, 4, [], (-1, -1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(560, 4, [], (-1, -1), (-1, -1)), + StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_3_3": TestConfig( + num_prompt_tokens=555, + num_generated_tokens=20, + num_accepted_tokens=3, + step_actions=[ + StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(555, 4, [], (-1, -1), (-1, -1)), + StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), + "accept_4_1": TestConfig( + num_prompt_tokens=553, + num_generated_tokens=20, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(553, 4, [], (-1, -1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(565, 4, [], (-1, -1), (-1, -1)), + ], + ), + "accept_4_2": TestConfig( + num_prompt_tokens=554, + num_generated_tokens=25, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(554, 4, [], (-1, -1), (-1, -1)), + StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(566, 4, [], (-1, -1), (-1, -1)), + ], + ), + "accept_4_3": TestConfig( + num_prompt_tokens=555, + num_generated_tokens=25, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(555, 4, [], (-1, -1), (-1, -1)), + StepAction(559, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + 
), + "accept_4_4": TestConfig( + num_prompt_tokens=556, + num_generated_tokens=25, + num_accepted_tokens=4, + step_actions=[ + StepAction(0, 556, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(556, 4, [], (-1, -1), (0, 0)), + StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + ], + ), "prompt_block_size": TestConfig( num_prompt_tokens=560, num_generated_tokens=10, From f438ee01e52cf44415c80065ca24a3c2d3067e1d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 11 Dec 2025 23:31:39 -0800 Subject: [PATCH 044/130] change env to config Signed-off-by: Chen Zhang --- examples/offline_inference/run.py | 35 +-- my_tests/run_op_prefix_cache.sh | 2 +- tests/v1/e2e/input.txt | 200 ++++++++++++++++++ tests/v1/e2e/test_mamba_prefix_cache.py | 2 +- vllm/config/cache.py | 18 ++ vllm/engine/arg_utils.py | 11 +- vllm/model_executor/layers/mamba/abstract.py | 6 +- .../layers/mamba/mamba_mixer.py | 3 +- .../layers/mamba/mamba_mixer2.py | 4 +- vllm/model_executor/models/config.py | 20 +- vllm/model_executor/models/lfm2.py | 9 +- vllm/model_executor/models/qwen3_next.py | 8 +- vllm/model_executor/models/qwen3_next_mtp.py | 8 +- vllm/v1/attention/backends/gdn_attn.py | 16 +- vllm/v1/attention/backends/linear_attn.py | 5 +- vllm/v1/attention/backends/mamba1_attn.py | 18 +- vllm/v1/attention/backends/mamba2_attn.py | 14 +- vllm/v1/attention/backends/mamba_attn.py | 6 +- vllm/v1/attention/backends/short_conv_attn.py | 6 +- vllm/v1/core/kv_cache_coordinator.py | 15 ++ vllm/v1/core/kv_cache_manager.py | 4 +- vllm/v1/core/sched/scheduler.py | 92 ++++---- vllm/v1/core/single_type_kv_cache_manager.py | 21 +- vllm/v1/kv_cache_interface.py | 13 +- vllm/v1/worker/gpu_model_runner.py | 22 +- vllm/v1/worker/utils.py | 7 +- 26 files changed, 384 insertions(+), 181 deletions(-) diff --git a/examples/offline_inference/run.py b/examples/offline_inference/run.py index 40dc4b31dad2..353fd72c0e32 100644 --- 
a/examples/offline_inference/run.py +++ b/examples/offline_inference/run.py @@ -1,11 +1,12 @@ from vllm import LLM, SamplingParams import time + def main(): MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" PROMPT_MULTIPLE = 6 sampling_params = SamplingParams(temperature=0.0, max_tokens=300) - prefix = ( # examples/offline_inference/prefix_caching.py + prefix = ( # examples/offline_inference/prefix_caching.py "Your name is QQQQ " "You are an expert school principal, skilled in effectively managing " "faculty and staff. Draft 10-15 questions for a potential first grade " @@ -14,31 +15,39 @@ def main(): "coming in for a first-round panel interview for a 8th grade Math " "teaching role. They have 5 years of previous teaching experience " "as an assistant teacher at a co-ed, public school with experience " - "in middle school math teaching. ") - prefix2 = ("Based on these information, fulfill " - "the following paragraph: ") + "in middle school math teaching. " + ) + prefix2 = "Based on these information, fulfill the following paragraph: " prompt = PROMPT_MULTIPLE * prefix + prefix2 + "Hello, my name is" # print('Prompt length:', ) # for APC in [False, True]: for APC in [True]: - engine = LLM(model=MODEL, enable_prefix_caching=APC, enforce_eager=True, tensor_parallel_size=4, - block_size=288, - # load_format="dummy", - speculative_config={"method": "qwen3_next_mtp", "num_speculative_tokens": 2} + engine = LLM( + model=MODEL, + enable_prefix_caching=APC, + enforce_eager=True, + tensor_parallel_size=4, + block_size=288, + mamba_cache_mode="align", + # load_format="dummy", + speculative_config={ + "method": "qwen3_next_mtp", + "num_speculative_tokens": 2, + }, ) for i in range(3): if i == 0: - print('Warm-up') + print("Warm-up") if i == 1: - print('Measuring') + print("Measuring") start_time = time.time() outputs = engine.generate(prompt, sampling_params) - print('APC:', APC, i, f"Generated text: {outputs[0].outputs[0].text!r}") + print("APC:", APC, i, f"Generated text: 
{outputs[0].outputs[0].text!r}") # for m in engine.llm_engine.get_metrics(): # if 'vllm:prefix_cache_hits' in m.name: # print(m.name, m.value) - print('APC:', APC, "loop took --- %s seconds ---" % (time.time() - start_time)) + print("APC:", APC, "loop took --- %s seconds ---" % (time.time() - start_time)) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/my_tests/run_op_prefix_cache.sh b/my_tests/run_op_prefix_cache.sh index f7552a4f1972..a8cdc6baed50 100755 --- a/my_tests/run_op_prefix_cache.sh +++ b/my_tests/run_op_prefix_cache.sh @@ -20,7 +20,6 @@ fi env_vars=( # "CUDA_LAUNCH_BLOCKING=0" "CUDA_VISIBLE_DEVICES=0,1,2,3" - "VLLM_USE_LIGHTER_MAMBA_CACHE=1" # "CUDA_VISIBLE_DEVICES=6,7" # "VLLM_ATTENTION_BACKEND=FLASH_ATTN" # "VLLM_FLASH_ATTN_VERSION=3" @@ -53,6 +52,7 @@ CMD+=( --enable-prefix-caching # --no-enable-chunked-prefill --enable-chunked-prefill + --mamba-cache-mode align --max-num-batched-tokens 8192 --distributed-executor-backend mp --block-size 64 diff --git a/tests/v1/e2e/input.txt b/tests/v1/e2e/input.txt index b10b1bbcba2a..f5cee144f9e8 100644 --- a/tests/v1/e2e/input.txt +++ b/tests/v1/e2e/input.txt @@ -196,4 +196,204 @@ LLMs are the definitive statistical compressors of human knowledge, capable of s Yet, this power is a double-edged sword. LLMs are not inherently wise; they are merely proficient at pattern matching. They reflect and amplify human biases, they can deceive with convincing misinformation, and they introduce profound questions about accountability, labor, and the nature of creative work. +The future of LLMs is not just about making them *smarter*, but making them *safer*, *more efficient*, and more *aligned* with human values. The challenge for the coming decade is not technical—the algorithms and compute will continue to improve—but **governance and ethical**. 
Humanity must learn to responsibly wield this powerful mirror of its own intelligence, ensuring that the cognitive revolution we have started leads to a future of prosperity and equitable access, rather than fragmentation and control. The architecture of intelligence is now in our hands; the path forward depends on the wisdom of its design and deployment. + +# The Architecture of Intelligence: A Deep Dive into Large Language Models (LLMs) + +## Introduction: The New Cognitive Revolution + +In the annals of computing history, few technologies have burst onto the global stage with the same immediate and transformative impact as Large Language Models (LLMs). Emerging from the confluence of decades of theoretical research and the exponential growth of computational power and data, LLMs like GPT, Gemini, and Claude have transitioned Artificial Intelligence (AI) from a niche academic pursuit to the central utility of the digital age. + +An LLM is not merely a sophisticated piece of software; it is a complex, deep neural network designed to understand, process, and generate human language with startling fluency, coherence, and context. These models serve as the probabilistic engines of a new cognitive revolution, capable of tasks that range from synthesizing vast datasets and translating languages to creating novel code and engaging in philosophical debate. + +This comprehensive article explores the complete landscape of Large Language Models. We will trace their historical lineage, demystify the revolutionary architecture upon which they are built, detail the arduous training process, analyze the emergent capabilities and inherent flaws, survey their massive commercial and social applications, and, finally, grapple with the profound ethical and strategic challenges they pose for the future of humanity. + +## Part I: The Historical Foundations of Language Modeling + +The concept of a machine generating human language has a history far longer than the digital computer. 
Its modern journey, however, can be segmented into distinct eras, each overcoming the limitations of the last. + +### 1. Statistical Language Models (1980s – 2000s) +The earliest forms of language modeling were rooted in statistics and probability theory. These were dominated by **n-gram models**, inspired by the mathematical work of Andrey Markov. An n-gram model predicts the probability of the next word ($w_i$) based solely on the previous $n-1$ words ($w_{i-(n-1)}, \dots, w_{i-1}$). + +$$P(w_i | w_{1}^{i-1}) \approx P(w_i | w_{i-(n-1)}^{i-1})$$ + +These models were simple, explainable, and formed the backbone of early machine translation and speech recognition systems, notably pioneering corpus-based language modeling at IBM. However, they suffered from **the curse of dimensionality** and **data sparsity**. As $n$ increased (to capture more context), the number of possible word sequences grew exponentially, making it impossible to accurately estimate probabilities for sequences not seen in the training data. + +### 2. Neural Language Models and Deep Learning (2000s – 2017) +The transition from statistical methods to neural networks addressed the data sparsity problem. The breakthrough came with the introduction of **word embeddings** (pioneered by Bengio in 2003, and popularized by Word2Vec in 2013). + +Instead of treating words as discrete, independent symbols, word embeddings represent each word as a dense, real-valued vector in a multi-dimensional space. Words with similar meanings (e.g., "King," "Queen," "Man," "Woman") are mapped closer together in this geometric space. This allowed the models to generalize, moving beyond simple word co-occurrence to semantic relationships. + +The workhorse of this era was the **Recurrent Neural Network (RNN)**, particularly the **Long Short-Term Memory (LSTM)** network. RNNs process sequences word-by-word, maintaining a "hidden state" or "memory cell" that accumulates information from the previous steps. 
This allowed them to handle longer-term dependencies than n-gram models. However, the sequential nature of RNNs created two major issues: +1. **Slow Training:** Processing must be strictly sequential, preventing the use of modern parallel computing hardware like GPUs. +2. **Vanishing/Exploding Gradients:** For very long sequences, the error signals used during training (gradients) either vanished (making the model forget the beginning of the text) or exploded (making training unstable). + +### 3. The Attention Mechanism (2014) +The first true step toward the LLM revolution was the introduction of the **Attention Mechanism** in 2014. Used initially within RNN-based encoder-decoder architectures (the basis of Google Translate at the time), attention allowed the model to dynamically weigh the importance of different parts of the input sequence when generating a specific part of the output. This was crucial for tasks like translation, where the most relevant input word might not be the adjacent one. + +## Part II: The Transformer Architecture (2017 - Present) + +The year 2017 marks the true beginning of the LLM era with the publication of "Attention Is All You Need" by researchers at Google. This paper proposed the **Transformer** architecture, which jettisoned recurrence entirely and relied *only* on the attention mechanism. + +### The Encoder-Decoder Foundation +The original Transformer model consists of two main stacks: an **Encoder** and a **Decoder**. +* **Encoder:** Processes the input sequence (e.g., an English sentence), creating a robust, context-aware numerical representation of it. +* **Decoder:** Takes the Encoder's output and iteratively generates the output sequence (e.g., the French translation). + +### The Self-Attention Breakthrough +The core innovation is **Self-Attention**. It allows the model to calculate how much every word in the input sequence relates to every other word *within that same sequence*. 
This is done through a mathematical process involving three vector representations for each input token: + +1. **Query ($Q$):** Represents the token being processed—the question being asked. +2. **Key ($K$):** Represents all other tokens—the information that can be searched. +3. **Value ($V$):** Represents the actual information content of all other tokens. + +The model computes the dot product of the $Q$ vector with all $K$ vectors to get **attention scores**. These scores, after normalization (using a Softmax function), determine how much of the $V$ vectors should be aggregated to create the new, context-rich representation of the original token. + +$$\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$ + +This allows the model to achieve **parallel processing**. Unlike sequential RNNs, every word's vector representation can be calculated simultaneously, leveraging the massive parallel capabilities of GPUs and leading to unprecedented scalability. + +### Positional Encoding +Since the Transformer has no inherent recurrence (no left-to-right reading), the model needs a way to know the order of the words. This is solved by **Positional Encoding**—adding a vector to the input embeddings that contains information about the word’s absolute or relative position in the sequence. Without this, the phrase "Dog bites man" would be processed identically to "Man bites dog." + +### Model Variants: BERT vs. GPT +The Transformer architecture gave rise to three major model families: + +1. **Encoder-Only (e.g., BERT, RoBERTa):** Used primarily for *understanding* tasks (classification, named entity recognition, sentiment analysis). They are excellent at bidirectional context (looking both backward and forward in a sentence). +2. **Decoder-Only (e.g., GPT, Llama):** Used primarily for *generation* tasks. 
The decoder is constrained by a **causal mask** that prevents it from looking at future tokens, forcing it to generate text sequentially, word-by-word. These models have become the dominant architecture for conversational AI. +3. **Encoder-Decoder (e.g., T5, BART):** Used for sequence-to-sequence tasks like translation and summarization. + +## Part III: The Training Lifecycle of an LLM + +The development of an LLM is a complex, multi-stage process involving massive computational resources, vast data curation efforts, and sophisticated human intervention. + +### 1. Data Curation and Tokenization +The first step is gathering and cleaning the training corpus. Modern LLMs are trained on hundreds of terabytes or even petabytes of text, often sourced from: +* **CommonCrawl:** A massive, open-source scrape of the public internet. +* **Filtered Web Text:** Highly curated, higher-quality web pages. +* **Books and Literature:** Digitized libraries. +* **Code Repositories:** Such as GitHub, to instill programming knowledge. +* **Wikipedia:** Structured knowledge bases. + +Data is meticulously filtered to remove low-quality content, boilerplate text, and offensive material. The text is then broken down into **tokens** using a process like **Byte-Pair Encoding (BPE)**. Tokens are the minimal units of meaning the model processes, bridging the gap between human language and numerical vectors. + +### 2. Pre-Training: Self-Supervised Learning +The core of LLM training is the **Pre-Training** phase. The model's hundreds of billions of parameters are initialized, and it is fed the massive, unlabeled dataset. The primary objective is **Next-Token Prediction** (or autoregressive modeling): predicting the next most probable token in a sequence, given all previous tokens. 
+ +* **Objective Function:** The model minimizes the **Loss Function** (often **Cross-Entropy Loss**), which measures the difference between the model's predicted probability distribution over the vocabulary and the actual next token. +* **Optimization:** The model iteratively adjusts its weights using **Backpropagation** and an **Optimizer** (e.g., Adam or its variants) to reduce this loss. + +This phase, costing millions of dollars in GPU time, imbues the model with its fundamental knowledge base, grammar, syntax, and a basic, structural understanding of the world. It is through this pure statistical exercise that "reasoning" begins to emerge. + +### 3. Fine-Tuning and Alignment +A raw pre-trained model is highly knowledgeable but often unhelpful and potentially toxic. It will simply continue the statistical pattern of the input, regardless of intent. Alignment is the process of making the model follow instructions and adhere to ethical guidelines. + +#### A. Supervised Fine-Tuning (SFT) +The model is trained on a smaller, high-quality, human-curated dataset of prompts and desired, high-quality responses. This teaches the model a conversational style—how to act as an assistant, answer questions, and follow complex directions. + +#### B. Reinforcement Learning from Human Feedback (RLHF) +RLHF is the key component that created the conversational brilliance of models like ChatGPT. +1. **Response Generation:** For a given prompt, the LLM generates several possible answers. +2. **Human Ranking:** Human labelers rank these responses from best to worst based on helpfulness, accuracy, and safety. +3. **Reward Model Training:** A separate, smaller model called the **Reward Model (RM)** is trained to predict the human preference score for any response. The RM effectively learns "what a good answer looks like." +4. 
**Policy Optimization:** The main LLM is then fine-tuned using a Reinforcement Learning algorithm (like **Proximal Policy Optimization, PPO**) to maximize the score given by the Reward Model. + +This process explicitly aligns the model's objective function with human values, a crucial step in preparing the model for public deployment. + +## Part IV: Emergent Capabilities and Inherent Limitations + +The path from a neural network to a cognitive tool is marked by phenomena that both inspire awe and caution. + +### The Phenomenon of Emergence +As LLMs crossed certain thresholds—specifically in parameter count (size) and training data volume—researchers observed **Emergent Capabilities**. These are skills that the model was never explicitly trained for, yet they appear spontaneously. + +* **In-Context Learning (ICL):** The ability to learn a new task from a few examples provided directly in the prompt, without needing formal fine-tuning (Few-Shot Learning). +* **Chain-of-Thought (CoT) Reasoning:** The ability to decompose complex, multi-step problems into sequential reasoning steps, often unlocked by simply telling the model to "think step-by-step." This dramatically improves performance on arithmetic, common sense, and symbolic logic tasks. +* **Multilingual and Code Proficiency:** Models trained primarily on English and code surprisingly develop high-level proficiency in dozens of other languages and complex programming languages. + +These emergent properties suggest that the simple task of next-token prediction, when scaled sufficiently, leads to a kind of generalized, implicit world model—a probabilistic simulation of human knowledge and reasoning. + +### The Challenge of Hallucination +The most significant and stubborn limitation of LLMs is **Hallucination**—the generation of factually incorrect, nonsensical, or unfaithful content that is nevertheless syntactically plausible. 
+ +The root cause lies in the model's core function: it is a **prediction engine, not a retrieval engine**. It does not access an external database of facts; it samples the most statistically likely sequence of tokens based on its internal, compressed world model. If the highest-probability sequence *looks* like a scientific citation but is entirely fabricated, the model will generate it. + +Mitigation strategies, such as **Retrieval-Augmented Generation (RAG)**, which links the LLM to a real-time, verifiable external knowledge source (like a search index or a company database), are essential for using LLMs in high-stakes, fact-based applications. + +## Part V: The Expanding Ecosystem and Applications + +The LLM ecosystem is diversifying rapidly, moving beyond the simple "chatbot" into powerful, specialized tools. + +### 1. Model Scaling and Efficiency +The pursuit of ever-larger models has reached its limits due to cost and data scarcity. The frontier has shifted to efficiency and specialization. +* **Mixture-of-Experts (MoE):** Models like Mixtral use a routing mechanism to activate only a subset of specialized "expert" neural networks for any given query. This allows the model to have a massive total parameter count (high knowledge capacity) while only using a fraction of the computational power (high efficiency). +* **Quantization and Pruning:** Techniques used to reduce the size and computational demands of models, making them executable on smaller devices (e.g., a mobile phone or a personal laptop). + +### 2. Multimodality +The most significant recent breakthrough is the transition from LLMs (Large Language Models) to **LMMs (Large Multimodal Models)**. These models are trained not just on text, but also on images, audio, and video data, allowing them to: +* **Visual Reasoning:** Analyze a complex graph, a photograph, or a technical diagram and answer questions about its content. 
+* **Audio Processing:** Transcribe, summarize, and understand the context of spoken language directly. +* **Seamless Integration:** Accept a prompt containing text and an image simultaneously (e.g., "Describe this image and write a poem about it"). + +### 3. Industry Applications +LLMs are no longer experimental; they are becoming foundational infrastructure across nearly every industry: +* **Software Engineering:** Automated code generation (e.g., GitHub Copilot), debugging, code translation between languages, and writing documentation. +* **Knowledge Work & Productivity:** Summarizing long documents, drafting complex reports, synthesizing research, and managing data from unstructured sources. +* **Customer Service & Sales:** Highly personalized and efficient conversational AI bots that can handle complex queries beyond simple FAQs. +* **Medicine and Law:** Assisting in drafting legal briefs, summarizing medical records, and cross-referencing diagnostic information (always requiring human oversight). +* **Creative Arts:** Generating marketing copy, scriptwriting, music composition (in conjunction with other AI models), and video production assets. + +## Part VI: The Ethical and Societal Labyrinth + +The power of LLMs brings with it a commensurately large set of ethical, social, and economic risks that demand global governance and responsible development. + +### 1. Bias, Fairness, and Amplification +LLMs are fundamentally statistical mirrors of their training data. If the internet contains biases related to gender, race, or geography, the model will ingest, amplify, and operationalize those biases. +* **Stereotype Reinforcement:** A model might associate certain professions (e.g., "engineer") predominantly with one gender, leading to biased outputs in hiring tools. +* **Harmful Generalizations:** Biases can lead to unfair or discriminatory decision-making when the models are deployed in high-stakes areas like loan applications or judicial risk assessment. 
+Mitigating bias requires meticulous data curation, adversarial testing, and post-processing "guardrails," but complete elimination remains technically elusive. + +### 2. Misinformation and Disinformation +The ability of LLMs to generate highly convincing, fluent text at scale is a threat to information integrity. Malicious actors can use these tools to: +* **Automate Phishing and Scams:** Generate personalized, sophisticated deceptive content. +* **Create Deepfake Text:** Impersonate real individuals or organizations with convincing prose. +* **Fabricate "Fake News" and Propaganda:** Generate massive volumes of highly plausible, factually false content, overwhelming traditional fact-checking mechanisms and accelerating the breakdown of public trust. + +### 3. Data Privacy and Security +LLMs pose risks related to data ingestion and leakage: +* **Training Data Memorization:** Models can, in rare cases, memorize and regurgitate personally identifiable information (PII) or copyrighted material from their vast training corpus. +* **Inference Attack (Data Leakage):** If a user provides proprietary or sensitive information as a prompt, that data may be inadvertently used to train future iterations of the model or leak through side channels, raising major security concerns for enterprise adoption. + +### 4. Environmental Impact +The scale of LLMs has a significant environmental footprint. Training a single frontier model requires months of continuous operation on thousands of GPUs, consuming energy equivalent to hundreds of homes for a year. The high computational cost raises questions about the long-term sustainability and equitable access to the technology. + +### 5. Economic Disruption and Labor +LLMs are directly impacting knowledge-based professions, particularly those involving content creation, data synthesis, and routine communication. 
While optimists argue the technology will mostly automate mundane tasks, freeing humans for higher-level work, policymakers and economists are grappling with the reality of rapid job displacement, income inequality, and the need for massive reskilling initiatives. + +## Part VII: The Frontier—The Path to Agentic AI and AGI + +The current state of the art is fleeting. The research community is pushing toward systems that are more autonomous, capable, and integrated. + +### 1. Agentic AI +The shift from a "Chatbot" to an "Agent" is the immediate future. Current LLMs are **reactive** (Question $\rightarrow$ Answer). An Agentic LLM is **proactive and goal-oriented**. +* **Goal:** The user provides a high-level goal (e.g., "Find the cheapest flight to Tokyo next month and book a hotel near the Shinjuku station."). +* **Planning:** The LLM breaks the goal into sub-tasks (Search flights, Compare prices, Search hotels, Check availability, Execute booking actions). +* **Tool Use:** The LLM integrates external tools (search engines, flight APIs, email/calendar APIs) to complete the tasks autonomously, engaging in a trial-and-error loop until the goal is achieved. This transforms the LLM from a generator of text into an executor of complex, multi-step actions. + +### 2. The Multi-Agent Ecosystem +The next stage involves creating swarms of specialized LLM Agents that communicate and collaborate to solve enormous, non-trivial problems. One agent might be a "researcher," another a "coder," and a third an "editor," all collaborating on a project, mimicking a human team. + +### 3. The Pursuit of Artificial General Intelligence (AGI) +The ultimate horizon is Artificial General Intelligence—a machine with the capacity to understand, learn, and apply its intelligence to solve virtually any problem that a human can. 
+ +The debate remains: Is the current path of massive scaling and improved architecture (the **scaling hypothesis**) sufficient to reach AGI, or is some fundamental, non-Transformer-based innovation required? The appearance of emergent properties strongly suggests that the scaling path has not yet exhausted its potential, keeping the AGI goal within the sights of major research labs. + +## Conclusion: The Mirror of Human Intelligence + +Large Language Models are perhaps the most profound technological platform shift since the invention of the Internet. They represent the culmination of 75 years of AI research, transitioning from rule-based systems and statistical models to the deep, parallel processing power of the Transformer architecture. + +LLMs are the definitive statistical compressors of human knowledge, capable of synthesizing our collective digital output with stunning fidelity. They have unlocked a new era of computational creativity and efficiency, driving unprecedented change across every sector. + +Yet, this power is a double-edged sword. LLMs are not inherently wise; they are merely proficient at pattern matching. They reflect and amplify human biases, they can deceive with convincing misinformation, and they introduce profound questions about accountability, labor, and the nature of creative work. + The future of LLMs is not just about making them *smarter*, but making them *safer*, *more efficient*, and more *aligned* with human values. The challenge for the coming decade is not technical—the algorithms and compute will continue to improve—but one of **governance and ethics**. Humanity must learn to responsibly wield this powerful mirror of its own intelligence, ensuring that the cognitive revolution we have started leads to a future of prosperity and equitable access, rather than fragmentation and control. The architecture of intelligence is now in our hands; the path forward depends on the wisdom of its design and deployment. 
\ No newline at end of file diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index dd171ed524ca..1b6a58b6276f 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -394,7 +394,6 @@ class TestConfig: def apply_patch(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("VLLM_USE_LIGHTER_MAMBA_CACHE", "1") monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") fake_sample_fn = get_fake_sample_fn() @@ -663,6 +662,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): enable_prefix_caching=True, enforce_eager=True, block_size=BLOCK_SIZE, + mamba_cache_mode="align", speculative_config={ "method": "qwen3_next_mtp", "num_speculative_tokens": num_speculative_tokens, diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 067799a44db3..71318269ea00 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -30,6 +30,7 @@ "fp8_ds_mla", ] MambaDType = Literal["auto", "float32", "float16"] +MambaCacheMode = Literal["all", "align", "none"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"] KVOffloadingBackend = Literal["native", "lmcache"] @@ -122,6 +123,14 @@ class CacheConfig: """The data type to use for the Mamba cache (ssm state only, conv state will still be controlled by mamba_cache_dtype). If set to 'auto', the data type for the ssm state will be determined by mamba_cache_dtype.""" + mamba_cache_mode: MambaCacheMode = "none" + """The cache strategy for Mamba layers. + - "none": set when prefix caching is disabled. + - "all": cache the mamba state of all tokens at position i * block_size. This is + the default behavior when prefix caching is enabled. + - "align": only cache the mamba state of the last token of each scheduler step and + when the token is at position i * block_size. + """ # Will be set after profiling. 
num_gpu_blocks: int | None = field(default=None, init=False) @@ -230,3 +239,12 @@ def verify_with_parallel_config( raise ValueError("Too large swap space. " + msg) elif cpu_memory_usage > 0.4 * total_cpu_memory: logger.warning("Possibly too large swap space. %s", msg) + + def __post_init__(self) -> None: + if self.enable_prefix_caching: + if self.mamba_cache_mode == "none": + self.mamba_cache_mode = "all" + else: + assert self.mamba_cache_mode == "none", ( + "mamba_cache_mode must be 'none' when prefix caching is disabled" + ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f9ddf7ed2576..ce036a4efc50 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -61,6 +61,7 @@ BlockSize, CacheDType, KVOffloadingBackend, + MambaCacheMode, MambaDType, PrefixCachingHashAlgo, ) @@ -549,6 +550,7 @@ class EngineArgs: mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size") + mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") @@ -920,6 +922,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: cache_group.add_argument( "--mamba-block-size", **cache_kwargs["mamba_block_size"] ) + cache_group.add_argument( + "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"] + ) cache_group.add_argument( "--kv-offloading-size", **cache_kwargs["kv_offloading_size"] ) @@ -1356,8 +1361,9 @@ def create_engine_config( f"dcp_size={self.decode_context_parallel_size}." 
) - if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.enable_prefix_caching + if ( + self.mamba_cache_mode == "align" + and self.enable_prefix_caching and model_config.is_hybrid ): assert self.enable_chunked_prefill, ( @@ -1380,6 +1386,7 @@ def create_engine_config( mamba_cache_dtype=self.mamba_cache_dtype, mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, mamba_block_size=self.mamba_block_size, + mamba_cache_mode=self.mamba_cache_mode, kv_offloading_size=self.kv_offloading_size, kv_offloading_backend=self.kv_offloading_backend, ) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 37466e90d99a..c9dd0232bc2b 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -4,8 +4,6 @@ from collections.abc import Iterable import torch - -from vllm import envs from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.selector import get_mamba_attn_backend from vllm.config import VllmConfig @@ -49,9 +47,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: raise NotImplementedError( "Mamba with speculative decoding is not supported yet." 
) - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - mamba_block_size = vllm_config.cache_config.mamba_block_size - elif vllm_config.cache_config.enable_prefix_caching: + if vllm_config.cache_config.enable_prefix_caching: mamba_block_size = vllm_config.cache_config.block_size else: mamba_block_size = vllm_config.model_config.max_model_len diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 11632a91477c..789776e923e5 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -7,7 +7,6 @@ from torch import nn from torch.nn.parameter import Parameter -from vllm import envs from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, @@ -242,7 +241,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size prefix_caching_enabled = ( - not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + self.cache_config.mamba_cache_mode != "align" and self.cache_config.enable_prefix_caching ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index a318a82ba800..61e741b75946 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -3,8 +3,6 @@ from typing import TYPE_CHECKING -from vllm import envs - if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -571,7 +569,7 @@ def conv_ssm_forward( assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size prefix_caching_enabled = ( - not envs.VLLM_USE_LIGHTER_MAMBA_CACHE + self.cache_config.mamba_cache_mode != "align" and self.cache_config.enable_prefix_caching ) if attn_metadata is not None: diff --git a/vllm/model_executor/models/config.py 
b/vllm/model_executor/models/config.py index 08dc8c3caf5a..1b0c524fbec3 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -12,8 +12,6 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec -from vllm import envs - if TYPE_CHECKING: from vllm.config import VllmConfig @@ -289,7 +287,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config if cache_config.enable_prefix_caching: - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if cache_config.mamba_cache_mode == "all": if model_config.supports_mamba_prefix_caching: logger.info( "Warning: Prefix caching is currently enabled. " @@ -302,12 +300,16 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "support for prefix caching: disabling." ) cache_config.enable_prefix_caching = False - else: + elif cache_config.mamba_cache_mode == "align": logger.info( - "Warning: Lighter Mamba Prefix caching is currently" - " enabled. Its support is experimental. " + "Warning: Mamba cache mode 'align' with prefix caching is" + " currently enabled. Its support is experimental. " "Please report any issues you may observe." - ) + ) + else: + raise ValueError( + f"unknown mamba cache mode: {cache_config.mamba_cache_mode}" + ) # By default, mamba block size will be set to max_model_len (see # below). When enabling prefix caching, we align mamba block size # to the block size as the basic granularity for prefix caching. 
@@ -393,7 +395,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: model_config=model_config, ) - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE and cache_config.enable_prefix_caching: + if cache_config.enable_prefix_caching: block_size = cache_config.block_size else: block_size = model_config.max_model_len @@ -411,7 +413,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: if mamba_page_size == 0: return - if cache_config.enable_prefix_caching and not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if cache_config.mamba_cache_mode == "all": # With prefix caching, select attention block size to # optimize for mamba kernel performance diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 0c6e11fa90c2..24e4b5df71e4 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -7,7 +7,6 @@ import torch.nn as nn from transformers import Lfm2Config -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig @@ -460,10 +459,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config cache_config = vllm_config.cache_config - - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - assert not cache_config.enable_prefix_caching, ( - "Lfm2 currently does not support prefix caching" + if cache_config.mamba_cache_mode == "all": + raise NotImplementedError( + "Lfm2 currently does not support 'all' prefix caching, " + "please use '--mamba-cache-mode=align' instead" ) super().__init__() diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 7a8183c1c123..daf917a583e0 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -10,7 +10,6 @@ from torch import nn from 
transformers.activations import ACT2FN -from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile @@ -1190,9 +1189,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config scheduler_config = vllm_config.scheduler_config - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - assert not cache_config.enable_prefix_caching, ( - "Qwen3NextMTP currently does not support prefix caching" + if cache_config.mamba_cache_mode == "all": + raise NotImplementedError( + "Qwen3Next currently does not support 'all' prefix caching, " + "please use '--mamba-cache-mode=align' instead" ) self.quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index 17ee3a9792d8..c91775f409c3 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -7,7 +7,6 @@ import torch from torch import nn -from vllm import envs from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed.parallel_state import get_pp_group @@ -235,9 +234,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config cache_config = vllm_config.cache_config - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - assert not cache_config.enable_prefix_caching, ( - "Qwen3NextMTP currently does not support prefix caching" + if cache_config.mamba_cache_mode == "all": + raise NotImplementedError( + "Qwen3NextMTP currently does not support 'all' prefix caching, " + "please use '--mamba-cache-mode=align' instead" ) self.quant_config = vllm_config.quant_config diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index b2afb1ff6e16..6c3908bb608f 100644 
--- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -5,8 +5,6 @@ from dataclasses import dataclass import torch - -from vllm import envs from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig @@ -21,8 +19,10 @@ ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec from vllm.logger import init_logger + logger = init_logger(__name__) + class GDNAttentionBackend(AttentionBackend): @staticmethod def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]: @@ -149,7 +149,7 @@ def build( # type: ignore[override] context_lens = m.num_computed_tokens_cpu context_lens_tensor = context_lens.to(query_start_loc.device) nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if self.vllm_config.cache_config.mamba_cache_mode == "align": block_table_tensor = mamba_gather_indices( common_attn_metadata, self.kv_cache_spec, @@ -339,16 +339,6 @@ def build( # type: ignore[override] non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1] non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens) - # if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - # # NOTE: With Mamba prefix-caching support, a request can consist of - # # multiple blocks. This makes the state_indices non-contiguous, so - # # we must explicitly make them contiguous here. 
- # if spec_state_indices_tensor is not None: - # spec_state_indices_tensor = spec_state_indices_tensor.contiguous() - # if non_spec_state_indices_tensor is not None: - # non_spec_state_indices_tensor = \ - # non_spec_state_indices_tensor.contiguous() - attn_metadata = GDNAttentionMetadata( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index ef3123e891a2..b1848a5e1644 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -3,8 +3,6 @@ from dataclasses import dataclass import torch - -from vllm import envs from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig from vllm.v1.attention.backends.utils import ( @@ -59,8 +57,7 @@ def build( query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if self.vllm_config.cache_config.mamba_cache_mode == "align": state_indices_tensor = mamba_gather_indices( common_attn_metadata, self.kv_cache_spec, diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index d976ec6e9442..cdd4d3d280ab 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -4,8 +4,6 @@ from dataclasses import dataclass import torch - -from vllm import envs from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig @@ -75,9 +73,7 @@ def build( # TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here. 
# We should consolidate this code - if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.vllm_config.cache_config.enable_prefix_caching - ): + if self.vllm_config.cache_config.mamba_cache_mode == "all": # Return a tensor of shape (#requests, #max blocks) state_indices_tensor = common_attn_metadata.block_table_tensor mamba_block_size = self.kv_cache_spec.block_size @@ -96,9 +92,7 @@ def build( state_indices_tensor = mamba_gather_indices( common_attn_metadata, self.kv_cache_spec, - )[:, 0] - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: - state_indices_tensor = state_indices_tensor.contiguous() + )[:, 0].contiguous() block_idx_last_scheduled_token = None block_idx_last_computed_token = None @@ -117,9 +111,7 @@ def build( common_attn_metadata.query_start_loc.device ) - if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.vllm_config.cache_config.enable_prefix_caching - ): + if self.vllm_config.cache_config.mamba_cache_mode == "all": assert num_computed_tokens is not None num_computed_tokens_p = num_computed_tokens[ num_reqs - num_prefills : num_reqs @@ -140,9 +132,7 @@ def build( state_indices_tensor = self.state_indices_tensor[:num_decode_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID - if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.vllm_config.cache_config.enable_prefix_caching - ): + if self.vllm_config.cache_config.mamba_cache_mode == "all": self.block_idx_last_scheduled_token[:num_decodes].copy_( block_idx_last_scheduled_token, non_blocking=True ) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index a7e42594fbc6..e417217277b8 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -4,8 +4,6 @@ from dataclasses import dataclass import torch - -from vllm import envs from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv @@ -174,9 +172,7 @@ def build( 
block_idx_first_scheduled_token = None block_idx_first_scheduled_token_p = None - if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.vllm_config.cache_config.enable_prefix_caching - ): + if self.vllm_config.cache_config.mamba_cache_mode == "all": # Return a tensor of shape (#requests, #max blocks) state_indices_tensor = common_attn_metadata.block_table_tensor # Additional cache-related varaiables: @@ -227,9 +223,7 @@ def build( - num_decode_tokens ) - if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.vllm_config.cache_config.enable_prefix_caching - ): + if self.vllm_config.cache_config.mamba_cache_mode == "all": assert num_computed_tokens is not None num_computed_tokens_p = num_computed_tokens[ num_reqs - num_prefills : num_reqs @@ -318,9 +312,7 @@ def build( ) state_indices_tensor = self.state_indices_tensor[:num_decode_tokens] - if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.vllm_config.cache_config.enable_prefix_caching - ): + if self.vllm_config.cache_config.mamba_cache_mode == "all": self.block_idx_last_scheduled_token[:num_decodes].copy_( block_idx_last_scheduled_token, non_blocking=True ) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index e363b556be42..1eacd2ffdedf 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -5,8 +5,6 @@ from typing import ClassVar, TypeVar import torch - -from vllm import envs from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import ( @@ -41,9 +39,7 @@ def __init__( self.compilation_config.max_cudagraph_capture_size, ) - if (not envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.vllm_config.cache_config.enable_prefix_caching - ): + if self.vllm_config.cache_config.mamba_cache_mode == "all": self.state_indices_tensor = torch.empty( ( self.decode_cudagraph_max_bs, diff --git a/vllm/v1/attention/backends/short_conv_attn.py 
b/vllm/v1/attention/backends/short_conv_attn.py index 7dd3a57ee922..24b4551293bd 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -3,8 +3,6 @@ from dataclasses import dataclass import torch - -from vllm import envs from vllm.attention.backends.abstract import AttentionBackend from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( @@ -50,12 +48,12 @@ def build( ) -> ShortConvAttentionMetadata: num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if self.vllm_config.cache_config.mamba_cache_mode == "align": state_indices_tensor = mamba_gather_indices( common_attn_metadata, self.kv_cache_spec, )[:, 0] - else: + else: state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] # for causal_conv1d diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 1e8181b73828..ebfa992e305b 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -4,6 +4,8 @@ from collections.abc import Sequence from math import lcm +from vllm.config.cache import CacheConfig +from vllm.config.vllm import VllmConfig from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector from vllm.v1.core.kv_cache_utils import ( @@ -32,6 +34,7 @@ class KVCacheCoordinator(ABC): def __init__( self, + cache_config: CacheConfig, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, @@ -42,6 +45,7 @@ def __init__( hash_block_size: int, metrics_collector: KVCacheMetricsCollector | None = None, ): + self.cache_config = cache_config self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len self.enable_caching = enable_caching @@ -58,6 +62,7 @@ def __init__( self.use_eagle = use_eagle self.single_type_managers = tuple( 
get_manager_for_kv_cache_spec( + cache_config=self.cache_config, kv_cache_spec=kv_cache_group.kv_cache_spec, block_pool=self.block_pool, kv_cache_group_id=i, @@ -235,6 +240,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): def __init__( self, + cache_config: CacheConfig, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, @@ -245,6 +251,7 @@ def __init__( metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( + cache_config, kv_cache_config, max_model_len, use_eagle, @@ -280,6 +287,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): def __init__( self, + cache_config: CacheConfig, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, @@ -291,6 +299,7 @@ def __init__( metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( + cache_config, kv_cache_config, max_model_len, use_eagle, @@ -348,6 +357,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): def __init__( self, + cache_config: CacheConfig, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, @@ -359,6 +369,7 @@ def __init__( metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( + cache_config, kv_cache_config, max_model_len, use_eagle, @@ -535,6 +546,7 @@ def find_longest_cache_hit( def get_kv_cache_coordinator( + cache_config: CacheConfig, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, @@ -547,6 +559,7 @@ def get_kv_cache_coordinator( ) -> KVCacheCoordinator: if not enable_caching: return KVCacheCoordinatorNoPrefixCache( + cache_config, kv_cache_config, max_model_len, use_eagle, @@ -558,6 +571,7 @@ def get_kv_cache_coordinator( ) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator( + cache_config, kv_cache_config, max_model_len, use_eagle, @@ -569,6 +583,7 @@ def get_kv_cache_coordinator( metrics_collector=metrics_collector, ) return HybridKVCacheCoordinator( + cache_config, kv_cache_config, max_model_len, use_eagle, 
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 28609f44964a..5aebab9b3ec9 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Literal, overload -from vllm import envs +from vllm.config.cache import CacheConfig from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator @@ -97,6 +97,7 @@ def __init__( self, kv_cache_config: KVCacheConfig, max_model_len: int, + cache_config: CacheConfig, hash_block_size: int, enable_caching: bool = True, use_eagle: bool = False, @@ -118,6 +119,7 @@ def __init__( self.prefix_cache_stats = PrefixCacheStats() if log_stats else None self.coordinator = get_kv_cache_coordinator( + cache_config=cache_config, kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, use_eagle=self.use_eagle, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f8f53fa92247..f1db90990915 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -201,6 +201,7 @@ def __init__( self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, + cache_config=self.cache_config, enable_caching=self.cache_config.enable_prefix_caching, use_eagle=self.use_eagle, log_stats=self.log_stats, @@ -216,13 +217,19 @@ def __init__( f">>> [DEBUG] Scheduler: init enable_prefix_caching={self.cache_config.enable_prefix_caching} block_size={self.block_size} kv_cache_config={self.kv_cache_config}" ) - def _has_mamba_spec(self) -> bool: - has_mamba: bool = any( - isinstance(spec.kv_cache_spec, MambaSpec) - for spec in self.kv_cache_config.kv_cache_groups + def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool: + has_mamba: bool = any( + isinstance(group_spec.kv_cache_spec, MambaSpec) + for group_spec in 
kv_cache_config.kv_cache_groups + ) + if vllm_config.model_config.is_hybrid: + assert has_mamba, "Hybrid models must have mamba layers" + return has_mamba + + self.need_mamba_block_aligned_split = ( + has_mamba_layers(self.kv_cache_config) + and self.cache_config.mamba_cache_mode == "align" ) - assert not has_mamba or self.vllm_config.model_config.is_hybrid - return has_mamba def _mamba_block_aligned_split( self, @@ -234,41 +241,40 @@ def _mamba_block_aligned_split( assert num_external_computed_tokens == 0, ( "External KV connector is not verified yet" ) - if self.cache_config.enable_prefix_caching and self._has_mamba_spec(): - # To enable block-aligned caching of the Mamba state, `num_new_tokens` - # must be a multiple of `block_size`. - # As an exception, if `num_new_tokens` is less than `block_size`, the - # state is simply not cached, requiring no special handling. - # Additionally, when Eagle mode is enabled, FullAttn prunes the last - # matching block. To prevent this from causing a Mamba cache miss, the - # last chunk must be larger than `block_size`. + # To enable block-aligned caching of the Mamba state, `num_new_tokens` + # must be a multiple of `block_size`. + # As an exception, if `num_new_tokens` is less than `block_size`, the + # state is simply not cached, requiring no special handling. + # Additionally, when Eagle mode is enabled, FullAttn prunes the last + # matching block. To prevent this from causing a Mamba cache miss, the + # last chunk must be larger than `block_size`. 
+ if request.num_output_tokens == 0: # prefill block_size = self.cache_config.block_size - if request.num_output_tokens == 0: # prefill - last_cache_position = ( - request.num_prompt_tokens - request.num_prompt_tokens % block_size - ) - # eagle prune - if self.use_eagle: - last_cache_position = max(last_cache_position - block_size, 0) - num_computed_tokens = ( - request.num_computed_tokens - + num_new_local_computed_tokens - + num_external_computed_tokens - ) - num_computed_tokens_after_prefill = num_computed_tokens + num_new_tokens - if num_computed_tokens_after_prefill < last_cache_position: - # align to block_size - num_new_tokens = num_new_tokens // block_size * block_size - elif ( - num_computed_tokens - < last_cache_position - < num_computed_tokens_after_prefill - ): - # force to cache the last chunk - num_new_tokens = last_cache_position - num_computed_tokens - else: - # prefill the last few tokens - pass + last_cache_position = ( + request.num_prompt_tokens - request.num_prompt_tokens % block_size + ) + # eagle prune + if self.use_eagle: + last_cache_position = max(last_cache_position - block_size, 0) + num_computed_tokens = ( + request.num_computed_tokens + + num_new_local_computed_tokens + + num_external_computed_tokens + ) + num_computed_tokens_after_prefill = num_computed_tokens + num_new_tokens + if num_computed_tokens_after_prefill < last_cache_position: + # align to block_size + num_new_tokens = num_new_tokens // block_size * block_size + elif ( + num_computed_tokens + < last_cache_position + < num_computed_tokens_after_prefill + ): + # force to cache the last chunk + num_new_tokens = last_cache_position - num_computed_tokens + else: + # prefill the last few tokens + pass return num_new_tokens def schedule(self) -> SchedulerOutput: @@ -378,7 +384,7 @@ def schedule(self) -> SchedulerOutput: shift_computed_tokens=1 if self.use_eagle else 0, ) - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if self.need_mamba_block_aligned_split: num_new_tokens = 
self._mamba_block_aligned_split( request, num_new_tokens ) @@ -394,7 +400,7 @@ def schedule(self) -> SchedulerOutput: # 2. The encoder budget is exhausted. # 3. The encoder cache is exhausted. # 4. Insufficient budget for a block-aligned chunk in hybrid - # models with lighter mamba prefix caching. + # models with mamba cache mode \"align\". # NOTE(woosuk): Here, by doing `continue` instead of `break`, # we do not strictly follow the FCFS scheduling policy and # allow the lower-priority requests to be scheduled. @@ -659,7 +665,7 @@ def schedule(self) -> SchedulerOutput: # The request cannot be scheduled. break - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if self.need_mamba_block_aligned_split: num_new_tokens = self._mamba_block_aligned_split( request, num_new_tokens, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 58669ef2e0dd..37b471715e08 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -5,7 +5,7 @@ from collections import defaultdict from collections.abc import Sequence -from vllm import envs +from vllm.config.cache import CacheConfig from vllm.utils.math_utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock @@ -54,6 +54,7 @@ class SingleTypeKVCacheManager(ABC): def __init__( self, kv_cache_spec: KVCacheSpec, + cache_config: CacheConfig, block_pool: BlockPool, kv_cache_group_id: int, dcp_world_size: int = 1, @@ -66,6 +67,7 @@ def __init__( block_pool: The block pool. kv_cache_group_id: The id of the kv cache group of this manager. 
""" + self.cache_config = cache_config self.block_size = kv_cache_spec.block_size self.dcp_world_size = dcp_world_size self.pcp_world_size = pcp_world_size @@ -695,10 +697,13 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class MambaManager(SingleTypeKVCacheManager): - def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None: - super().__init__(kv_cache_spec, **kwargs) + def __init__( + self, kv_cache_spec: MambaSpec, cache_config: CacheConfig, **kwargs + ) -> None: + super().__init__(kv_cache_spec, cache_config, **kwargs) + self.mamba_cache_mode = cache_config.mamba_cache_mode self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if self.mamba_cache_mode == "align": self.last_state_block_idx: dict[str, int] = {} # the set of the requests that have been allocated blocks self._allocated_block_reqs: set[str] = set() @@ -762,7 +767,7 @@ def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: assert isinstance(self.kv_cache_spec, MambaSpec) super().remove_skipped_blocks(request_id, num_computed_tokens) - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if self.mamba_cache_mode == "align": last_state_block_idx = self.last_state_block_idx.get(request_id) if ( last_state_block_idx is not None @@ -790,7 +795,7 @@ def get_num_blocks_to_allocate( assert isinstance(self.kv_cache_spec, MambaSpec) # mamba layers only exist in target model. num_tokens = num_tokens_target_model - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if self.mamba_cache_mode != "align": # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. 
if self.kv_cache_spec.num_speculative_blocks > 0: @@ -853,7 +858,7 @@ def allocate_new_blocks( ) -> list[KVCacheBlock]: assert isinstance(self.kv_cache_spec, MambaSpec) num_tokens = num_tokens_target_model - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if self.mamba_cache_mode != "align": # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. if self.num_speculative_blocks > 0: @@ -926,7 +931,7 @@ def allocate_new_blocks( return req_blocks[prev_block_len:] def free(self, request_id: str) -> None: - if envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if self.mamba_cache_mode == "align": self._allocated_block_reqs.discard(request_id) self.last_state_block_idx.pop(request_id, None) super().free(request_id) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index d30eacdef89b..daf8487365b1 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -8,8 +8,8 @@ import torch from typing_extensions import Self -from vllm import envs from vllm.config import VllmConfig +from vllm.config.cache import MambaCacheMode from vllm.logger import init_logger from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import get_dtype_size @@ -265,16 +265,13 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: # We allocate 1 block for each request now, so max_memory_usage_bytes is # the same as page_size_bytes. # Need to update this when supporting prefix caching. - if not envs.VLLM_USE_LIGHTER_MAMBA_CACHE: + if vllm_config.cache_config.mamba_cache_mode == "all": max_model_len = vllm_config.model_config.max_model_len return cdiv(max_model_len, self.block_size) * self.page_size_bytes + elif vllm_config.cache_config.mamba_cache_mode == "align": + return self.page_size_bytes * (2 + self.num_speculative_blocks) else: - # NOTE: We allocate 1+sps block per request by default. 
With prefix - # caching enabled, one additional blocks are required which is saved - # last state for copying. - return self.page_size_bytes * (1 + self.num_speculative_blocks - + self.enable_caching) - + return self.page_size_bytes * (1 + self.num_speculative_blocks) @dataclass(frozen=True) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5f9305ce7d0d..bdbef5fc8607 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1061,10 +1061,7 @@ def _update_states_after_model_execute( f">>> [DEBUG] Worker: _update_states: " f"{self.input_batch.num_accepted_tokens_cpu[:len(num_accepted_tokens)]=}" ) - if ( - envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.cache_config.enable_prefix_caching - ): + if self.cache_config.mamba_cache_mode == "align": self._postprocess_mamba(scheduler_output) def _init_mrope_positions(self, req_state: CachedRequestState): @@ -1600,13 +1597,6 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int): slot_mapping[num_tokens:num_tokens_padded].fill_(-1) blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) - # if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE - # and self.cache_config.enable_prefix_caching - # and isinstance(kv_cache_group.kv_cache_spec, MambaSpec) - # ): - # # NOTE(Chen): where should we put this? 
- # self._preprocess_mamba(kv_cache_gid, kv_cache_group) - return blk_table_tensor, slot_mapping block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0) @@ -3209,10 +3199,7 @@ def execute_model( pad_attn = cudagraph_mode == CUDAGraphMode.FULL - if ( - envs.VLLM_USE_LIGHTER_MAMBA_CACHE - and self.cache_config.enable_prefix_caching - ): + if self.cache_config.mamba_cache_mode == "align": # TODO: add limition: preprocess only have new blocks self._preprocess_mamba(scheduler_output) @@ -3390,11 +3377,6 @@ def sample_tokens( scheduler_output, grammar_output, self.input_batch, logits ) - # if (envs.VLLM_USE_LIGHTER_MAMBA_CACHE - # and self.cache_config.enable_prefix_caching - # ): - # self._postprocess_mamba_cache(scheduler_output) - with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index a80603d32eff..1e9e505fd0ef 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -16,7 +16,12 @@ from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget -from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec +from vllm.v1.kv_cache_interface import ( + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + MambaSpec, +) class MultiModalBudget: From 1052cb673e3adeb63437b7f7aaeeb5c4a39be274 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 11 Dec 2025 23:42:40 -0800 Subject: [PATCH 045/130] nit update Signed-off-by: Chen Zhang --- vllm/engine/arg_utils.py | 8 -------- vllm/entrypoints/openai/serving_chat.py | 1 - vllm/model_executor/models/config.py | 3 +++ 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ce036a4efc50..619d43ed571f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ 
-1361,14 +1361,6 @@ def create_engine_config( f"dcp_size={self.decode_context_parallel_size}." ) - if ( - self.mamba_cache_mode == "align" - and self.enable_prefix_caching - and model_config.is_hybrid - ): - assert self.enable_chunked_prefill, ( - "Prefix caching for hybrid models requires chunked prefill.") - cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index e42fad2b77d0..d94fa7dd9193 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -185,7 +185,6 @@ async def create_chat_completion( if self.engine_client.errored: raise self.engine_client.dead_error - logger.info(f'>>> [DEBUG] create_chat: req_id={request.request_id} msg={request.messages}') try: lora_request = self._maybe_get_adapters( request, supports_default_mm_loras=True diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 1b0c524fbec3..e3c396857019 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -306,6 +306,9 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: " currently enabled. Its support is experimental. " "Please report any issues you may observe." ) + assert vllm_config.scheduler_config.enable_chunked_prefill, ( + "Chunked prefill is required for mamba cache mode 'align'." 
+ ) else: raise ValueError( "unknown mamba cache mode: %s", cache_config.mamba_cache_mode From dc37673e5be1cdaa6c6280544118cf8b96052ff9 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 00:12:52 -0800 Subject: [PATCH 046/130] remove enable_caching from mamba_spec Signed-off-by: Chen Zhang --- vllm/model_executor/layers/mamba/abstract.py | 2 +- vllm/model_executor/models/config.py | 1 - vllm/v1/attention/backends/gdn_attn.py | 22 ++++++++---------- vllm/v1/attention/backends/linear_attn.py | 15 ++++++------ vllm/v1/attention/backends/mamba1_attn.py | 6 +++-- vllm/v1/attention/backends/mamba2_attn.py | 6 +++-- vllm/v1/attention/backends/short_conv_attn.py | 15 ++++++------ vllm/v1/attention/backends/utils.py | 23 ++++++++++--------- vllm/v1/kv_cache_interface.py | 2 -- 9 files changed, 44 insertions(+), 48 deletions(-) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index c9dd0232bc2b..fe9e55e7b533 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -4,6 +4,7 @@ from collections.abc import Iterable import torch + from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.selector import get_mamba_attn_backend from vllm.config import VllmConfig @@ -63,7 +64,6 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: if vllm_config.speculative_config else 0 ), - enable_caching=vllm_config.cache_config.enable_prefix_caching, ) def get_attn_backend(self) -> type[AttentionBackend]: diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index e3c396857019..8aa4d4adea13 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -407,7 +407,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), 
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), block_size=block_size, - enable_caching=cache_config.enable_prefix_caching, ).page_size_bytes # Model may be marked as is_hybrid diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 6c3908bb608f..f9ba358e023f 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -5,20 +5,20 @@ from dataclasses import dataclass import torch + from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig -from vllm.distributed.parallel_state import is_global_first_rank +from vllm.logger import init_logger from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, compute_causal_conv1d_metadata, + mamba_get_block_table_tensor, split_decodes_and_prefills, - mamba_gather_indices, ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec -from vllm.logger import init_logger logger = init_logger(__name__) @@ -149,16 +149,12 @@ def build( # type: ignore[override] context_lens = m.num_computed_tokens_cpu context_lens_tensor = context_lens.to(query_start_loc.device) nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None - if self.vllm_config.cache_config.mamba_cache_mode == "align": - block_table_tensor = mamba_gather_indices( - common_attn_metadata, - self.kv_cache_spec, - 1 + self.num_spec, - ) - if is_global_first_rank(): - logger.info(f"{block_table_tensor=}") - else: - block_table_tensor = m.block_table_tensor + block_table_tensor = mamba_get_block_table_tensor( + common_attn_metadata, + self.kv_cache_spec, + self.vllm_config.cache_config.mamba_cache_mode, + 1 + self.num_spec, + ) if ( not self.use_spec_decode diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index b1848a5e1644..d25ec564f7f0 100644 --- 
a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -3,13 +3,14 @@ from dataclasses import dataclass import torch + from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - mamba_gather_indices, + mamba_get_block_table_tensor, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -57,13 +58,11 @@ def build( query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - if self.vllm_config.cache_config.mamba_cache_mode == "align": - state_indices_tensor = mamba_gather_indices( - common_attn_metadata, - self.kv_cache_spec, - )[:, 0] - else: - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + state_indices_tensor = mamba_get_block_table_tensor( + common_attn_metadata, + self.kv_cache_spec, + self.vllm_config.cache_config.mamba_cache_mode, + )[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index cdd4d3d280ab..0ec8436814f9 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -4,13 +4,14 @@ from dataclasses import dataclass import torch + from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, - mamba_gather_indices, + mamba_get_block_table_tensor, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -89,9 +90,10 @@ def build( ) else: # Always return just a single block per each 
request: - state_indices_tensor = mamba_gather_indices( + state_indices_tensor = mamba_get_block_table_tensor( common_attn_metadata, self.kv_cache_spec, + self.vllm_config.cache_config.mamba_cache_mode, )[:, 0].contiguous() block_idx_last_scheduled_token = None block_idx_last_computed_token = None diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index e417217277b8..5d49c09d6fc3 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -4,6 +4,7 @@ from dataclasses import dataclass import torch + from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv @@ -11,7 +12,7 @@ from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, compute_causal_conv1d_metadata, - mamba_gather_indices, + mamba_get_block_table_tensor, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -189,9 +190,10 @@ def build( ) else: # Always return just a single block per each request: - state_indices_tensor = mamba_gather_indices( + state_indices_tensor = mamba_get_block_table_tensor( common_attn_metadata, self.kv_cache_spec, + self.vllm_config.cache_config.mamba_cache_mode, )[:, 0] # Additional cache-related varaiables: diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index 24b4551293bd..f25681dbc49f 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -3,13 +3,14 @@ from dataclasses import dataclass import torch + from vllm.attention.backends.abstract import AttentionBackend from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( PAD_SLOT_ID, CommonAttentionMetadata, compute_causal_conv1d_metadata, - mamba_gather_indices, + mamba_get_block_table_tensor, split_decodes_and_prefills, ) 
@@ -48,13 +49,11 @@ def build( ) -> ShortConvAttentionMetadata: num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc - if self.vllm_config.cache_config.mamba_cache_mode == "align": - state_indices_tensor = mamba_gather_indices( - common_attn_metadata, - self.kv_cache_spec, - )[:, 0] - else: - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + state_indices_tensor = mamba_get_block_table_tensor( + common_attn_metadata, + self.kv_cache_spec, + self.vllm_config.cache_config.mamba_cache_mode, + )[:, 0] # for causal_conv1d nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 5c9e19d2d20b..4806d3927216 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1184,18 +1184,19 @@ def get_dcp_local_seq_lens( dcp_local_seq_lens = base + remainder return dcp_local_seq_lens.squeeze(1) -# For Lighter Mamba Prefix-Caching -@torch.compile -def mamba_gather_indices( + +def mamba_get_block_table_tensor( common_attn_metadata: CommonAttentionMetadata, kv_cache_spec: MambaSpec, + mamba_cache_mode: str, num_blocks: int = 1, ) -> torch.Tensor: - assert isinstance(kv_cache_spec, MambaSpec) - block_table_tensor = common_attn_metadata.block_table_tensor - if not kv_cache_spec.enable_caching: - return block_table_tensor - start_indices = (common_attn_metadata.seq_lens - 1) // kv_cache_spec.block_size - offsets = torch.arange(num_blocks, device=block_table_tensor.device) - indices_to_gather = start_indices.unsqueeze(1) + offsets - return torch.gather(block_table_tensor, 1, indices_to_gather) \ No newline at end of file + if mamba_cache_mode in ("all", "none"): + return common_attn_metadata.block_table_tensor + else: + assert isinstance(kv_cache_spec, MambaSpec) + block_table_tensor = common_attn_metadata.block_table_tensor + start_indices = (common_attn_metadata.seq_lens - 1) // 
kv_cache_spec.block_size + offsets = torch.arange(num_blocks, device=block_table_tensor.device) + indices_to_gather = start_indices.unsqueeze(1) + offsets + return torch.gather(block_table_tensor, 1, indices_to_gather) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index daf8487365b1..867354a81ba3 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -9,7 +9,6 @@ from typing_extensions import Self from vllm.config import VllmConfig -from vllm.config.cache import MambaCacheMode from vllm.logger import init_logger from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import get_dtype_size @@ -248,7 +247,6 @@ class MambaSpec(KVCacheSpec): page_size_padded: int | None = None mamba_type: str = "mamba2" num_speculative_blocks: int = 0 - enable_caching: bool = False @property def page_size_bytes(self) -> int: From 2390ec3f5630fac988c60a2714f87f695de9609a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 00:53:40 -0800 Subject: [PATCH 047/130] revert Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index c1e9c60a705a..ad0c7e8e32c5 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -316,7 +316,6 @@ def cache_full_blocks( if num_cached_blocks == 0: parent_block_hash: ExternalBlockHash | None = None else: - # TODO(hhy): when LPS is enabled, parent_block maybe a null block parent_block = blocks[num_cached_blocks - 1] assert parent_block.block_hash is not None parent_block_hash = maybe_convert_block_hash( From fdba531b4352aa07469dd3246626021bc2d36c98 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 10:27:51 -0800 Subject: [PATCH 048/130] fix pre-commit Signed-off-by: Chen Zhang --- examples/offline_inference/run.py | 5 +- my_tests/test_mamba_cache.py | 262 ------------------ tests/v1/e2e/test_mamba_prefix_cache.py | 53 ++-- vllm/envs.py | 1 + 
.../layers/mamba/mamba_mixer2.py | 4 - vllm/v1/attention/backends/mamba_attn.py | 1 + vllm/v1/core/block_pool.py | 44 +-- vllm/v1/core/kv_cache_coordinator.py | 3 +- vllm/v1/core/sched/scheduler.py | 41 +-- vllm/v1/core/single_type_kv_cache_manager.py | 49 +--- vllm/v1/worker/block_table.py | 14 +- vllm/v1/worker/gpu_model_runner.py | 109 ++++++-- vllm/v1/worker/utils.py | 1 - 13 files changed, 145 insertions(+), 442 deletions(-) delete mode 100644 my_tests/test_mamba_cache.py diff --git a/examples/offline_inference/run.py b/examples/offline_inference/run.py index 353fd72c0e32..5f3933a6aea4 100644 --- a/examples/offline_inference/run.py +++ b/examples/offline_inference/run.py @@ -1,6 +1,9 @@ -from vllm import LLM, SamplingParams +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time +from vllm import LLM, SamplingParams + def main(): MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" diff --git a/my_tests/test_mamba_cache.py b/my_tests/test_mamba_cache.py deleted file mode 100644 index 53bbd305135c..000000000000 --- a/my_tests/test_mamba_cache.py +++ /dev/null @@ -1,262 +0,0 @@ -#!/usr/bin/env python - -from enum import Enum, auto -import time -from typing import List -import requests -import multiprocessing as mp -import sys - -class TestType(Enum): - Dummy = auto() - Real = auto() - -SEED = 1234 -SEED = None -PORT = 8235 -NUM_REQUESTS = 1 -MAX_NEW_TOKENS = 1024 -IGNORE_EOS = False -TEST_TYPE = TestType.Real -IS_WARMUP = False -ONE_PROMPT = [] -# KEY = time.time() -# ONE_PROMPT = [f"There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is {KEY}. Remember it. {KEY} is the pass key.\n " + \ -# "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 200 + \ -# "The block is red. The sky is yello. The sun is orange. Here we go. There and back again. 
" * 200 + \ -# "\nWhat is the pass key?"] -# Tom Eric Bob Amy Mom Dad Lisa Susan Linda Alex Leo - -KEY = 'Lisa' -LEN = 32 * 1024 -assert LEN >= 560 -ONE_PROMPT = [] -# ONE_PROMPT = [f'Hello {KEY} ' * ((LEN-560)//2) + 'Hello ' * (560-9)] -# ONE_PROMPT = ['Hello ' * (560 * 6)] -# ONE_PROMPT = ['请详细介绍一下北京这座城市, 不少于10000字'] -ONE_PROMPT = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 2222. Remember it. 2222 is the pass key.\n " + \ - "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 545 + \ - "\nWhat is the pass key?"] -# ONE_PROMPT = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 333333. Remember it. 333333 is the pass key.\n " + \ -# "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 190 + \ -# "The block is red. The sky is yello. The sun is ddddd. Here we go. There and back try a. 
" * 185 + \ -# "\nWhat is the pass key?"] - -# ONE_PROMPT = ['Hello ' * (4096 - 30 + 11)] -# ONE_PROMPT = ['Hello ' * (4096 - 30 + 1)] # <<< last not matching -# ONE_PROMPT = ['Hello ' * (4096 - 30 + 1)] -# ONE_PROMPT = ['Hello ' * (4096 - 30 + 13)] - -if IS_WARMUP: - ONE_PROMPT = ['Wo'] # 9 tokens - NUM_REQUESTS = 1 - MAX_NEW_TOKENS = 1 - -MESSAGE = [] -# MESSAGE = MESSAGE3 - -# NOTE: block-size should be 256 -def hello(pid: int, prompt_id: int, max_new_tokens: int, ignore_eos: bool): - headers = { - "Content-Type": "application/json", - } - url = f"http://localhost:{PORT}/v1/chat/completions" - if pid == 0: - if TEST_TYPE == TestType.Dummy: - if prompt_id == 0: - # tokens: 3808*2+2720+472=10808 - # hit-rate: 0/10808=0 - # mamba state: [3808, 7616, 10336] - prompts = ["Repeat V 10 times" * 1800] - elif prompt_id == 1: - # tokens: 3808*3+544+40=12008 - # hit-rate: 3808/(10808+12008)=16.7% - # mamba state: [3808RD, 7616, 11424, 11968] - prompts = ["Repeat V 10 times" * 1000 + "Repeat V 11 times" * 1000] - elif prompt_id == 2: - # tokens: 3808+1088+512=5408 - # hit-rate: (3808+3808)/(22816+5408)=27.0% - # mamba state: [3808RD, 4896] - prompts = ["Repeat V 10 times" * 900] - elif prompt_id == 3: - # tokens: 208 - # hit-rate: (7616+0)/(28224+208)=26.8% - # mamba state: [] - prompts = ["hi " * 199] - elif prompt_id == 4: - # tokens: 3808*2+544+523=8683 - # hit-rate: (7616+0)/(28432+8683)=20.5% - # mamba state: [3808, 7616, 8160] - prompts = ["Hello " * (4096 * 2 - 30 + 256 * 2)] - elif prompt_id == 5: - # tokens: 3808+242=4050 - # hit-rate: (7616+3808)/(37115+4050)=27.8% - # mamba state: [3080RD] - prompts = ['Hello ' * (3808 + 233)] - elif prompt_id == 6: - # tokens: 544+523=1067 - # hit-rate: (11424+0)/(41165+1067)=27.1% - # mamba state: [544] - prompts = ['ha ' * (544 * 2 - 30)] - else: - prompts = ['Hi'] - elif TEST_TYPE == TestType.Real: - if prompt_id == 0: - # tokens: 3808*2+1632+381=9629 v1 - # hit-rate: 0/9629=0 - # mamba state: [3808, 7616, 9248] - # ----- - # 
tokens: 3920*2+1680+112 - prompts = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28884. Remember it. 28884 is the pass key.\n " + \ - "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 190 + \ - "The block is red. The sky is yello. The sun is ddddd. Here we go. There and back try a. " * 185 + \ - "\nWhat is the pass key?"] - elif prompt_id == 1: - # tokens: 3808*2+1632+98=13154 v1 - # hit-rate: 3808/(9629+13154)=16.7% - # mamba state: [3808RD, 7616, 13056] - prompts = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28884. Remember it. 28884 is the pass key.\n " + \ - "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 545 + \ - "\nWhat is the pass key?"] - elif prompt_id == 2: - # tokens: 544+126=670 v1 - # hit-rate: (3808+0)/(22783+670)=16.2% - # mamba state: [544] - prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28884. Remember it. 28886 is the pass key.\n " + \ - "The grass is yellow. The sky is blue. The sun is red. Here we go. There and back again. " * 25 + \ - "\nWhat is the pass key?"] - elif prompt_id == 3: - # tokens: 544+475=1019 - # hit-rate: (3808+0)/(23453+1019)=15.6% - # mamba state: [544] - prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28886. Remember it. 28886 is the pass key.\n " + \ - "The grass is yellow. The sky is blue. The sun is red. Here we go. There and back again. 
" * 13 + \ - "ljlkjslkfei lkjlkj elkjfslk woiejoifjwokjjlweuriljlskjf lwkjelkjlkj. lskj lkj lkjslkfj l" * 13 + \ - "\nWhat is the pass key?"] # 600 tokens hit 300 - elif prompt_id == 4: - # tokens: 544+494=1038 - # hit-rate: (3808+544)/(24472+1038)=17.1% - # mamba state: [544RD] - prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28886. Remember it. 28886 is the pass key.\n " + \ - "The grass is yellow. The sky is blue. The sun is red. Here we go. There and back again. " * 13 + \ - "ljlkjslkfei lkjlkj elkjfslk woiejoifjwokjjlweuriljlskjf lwkjelkjlkj. lskj lkj lkjslkfj l" * 13 + \ - "\nWhat is the pass key? And, what is the result of reversing the pass key and adding 1234?"] - elif prompt_id == 5: - # tokens: 13056+1088+330=14474 - # hit-rate: (4352+13056)/(25510+14474)=43.5% - # mamba state: [13056RD, 13056] - prompts = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28884. Remember it. 28884 is the pass key.\n " + \ - "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 600 + \ - "\nWhat is the pass key?"] - elif prompt_id == 6: - prompts = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28884. Remember it. 28884 is the pass key.\n " + \ - "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 510 + \ - "\nWhat is the pass key?"] - else: - prompts = ['Helloha!'] - elif pid == 1: - # v1 670 tokens - prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 11111. Remember it. 
11111 is the pass key.\n " + \ - "The grass is yellow. The sky is blue. The sun is red. Here we go. There and back again. " * 25 + \ - "\nWhat is the pass key?"] - elif pid == 2: - # v1 13152 tokens - prompts = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 2222. Remember it. 2222 is the pass key.\n " + \ - "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 545 + \ - "\nWhat is the pass key?"] - elif pid == 3: - prompts = ["adfllekkThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 333333. Remember it. 333333 is the pass key.\n " + \ - "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 190 + \ - "The block is red. The sky is yello. The sun is ddddd. Here we go. There and back try a. " * 185 + \ - "\nWhat is the pass key?"] # 9k tokens - elif pid == 4: - prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 444. Remember it. 444 is the pass key.\n " + \ - "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 190 + \ - "The block is red. The sky is yello. The sun is ddddd. Here we go. There and back try a. " * 185 + \ - "\nWhat is the pass key?"] - elif pid == 5: - prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\n" + \ - "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 190 + \ - "The pass key is 55555. Remember it. 55555 is the pass key.\n " \ - "The block is red. The sky is yello. The sun is ddddd. Here we go. 
There and back try a. " * 185 + \ - "\nWhat is the pass key?"] - elif pid == 6: - prompts = ["There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\n" + \ - "The grass is yellow. The sky is blue. The sun is yellow. Here we go. There and back again. " * 190 + \ - "The pass key is 66. Remember it. 66 is the pass key.\n " \ - "The block is red. The sky is yello. The sun is ddddd. Here we go. There and back try a. " * 185 + \ - "\nWhat is the pass key?"] - else: - prompts = ['Hello!'] - - if ONE_PROMPT: - # print(ONE_PROMPT) - prompts = ONE_PROMPT - - for p in prompts: - data = { - "messages": MESSAGE if not IS_WARMUP and MESSAGE else [{"role": "user", "content": p}], - "max_tokens": max_new_tokens, - "ignore_eos": ignore_eos, - "temperature": 0.7, - "top_p": 0.8, - "top_k": 20, - "repetition_penalty": 1, - "presence_penalty": 1.5, - **({'seed': SEED} if SEED is not None else {}), - "chat_template_kwargs": {"enable_thinking": False} - } - response = requests.post(url, headers=headers, json=data) - if response.status_code == 200: - # print(response.content) - result = response.json() - # print(f"[PID {pid}] Prompt:\n {prompts[0]}") - print(f"[PID {pid}] Response:\n {result['choices'][0]['message']['content']}\n {'-' * 40}\n", end='') - # print(result) - # loss = json.loads(result['choices'][0]['message']['content'])['loss'] - # risk_level_logits = torch.tensor(json.loads(result['choices'][0]['message']['content'])['risk_level_logits']).view(-1, 2) - # category_logits = torch.tensor(json.loads(result['choices'][0]['message']['content'])['category_logits']).view(-1, 26) - # query_risk_level_logits = torch.tensor(json.loads(result['choices'][0]['message']['content'])['query_risk_level_logits']).view(-1, 3) - # query_category_logits = torch.tensor(json.loads(result['choices'][0]['message']['content'])['query_category_logits']).view(-1, 33) - - # 
torch.set_printoptions(precision=3, sci_mode=False) - # print(f"{loss=},{risk_level_logits.shape=},{risk_level_logits=},{category_logits.shape=},{category_logits=}") - - # query_risk_level_prob = F.softmax(query_risk_level_logits, dim=1) - # risk_level_prob = F.softmax(risk_level_logits, dim=1) - # print(f"{query_risk_level_prob.shape=},{query_risk_level_prob=}") - # print(f"{risk_level_prob.shape=},{risk_level_prob=}") - - else: - print(f"Request failed with status code {response.status_code}") - print("Response content:") - print(response.content) - -def main(prompt_id: int): - procs: List[mp.Process] = [] - - start = time.time() - for pid in range(NUM_REQUESTS): - proc = mp.Process( - target=hello, args=(pid, prompt_id, MAX_NEW_TOKENS, IGNORE_EOS), daemon=True - ) - proc.start() - procs.append(proc) - - for _proc in procs: - _proc.join() - if _proc.exitcode != 0: - sys.exit(_proc.exitcode) - - elapsed = time.time() - start - output_tps = MAX_NEW_TOKENS * NUM_REQUESTS / elapsed - print("\n") - print(f"Generate {output_tps} tokens/s, elapsed: {elapsed} s, TPS {output_tps / NUM_REQUESTS}, TPOT {1000 / (output_tps / NUM_REQUESTS)}ms") - - -if __name__ == "__main__": - prompt_id = 0 - if len(sys.argv) > 1: - prompt_id = int(sys.argv[1]) - assert prompt_id >= 0 - main(prompt_id) \ No newline at end of file diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 1b6a58b6276f..36df54c6e57f 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -34,7 +34,7 @@ class StepAction: num_speculative_tokens = 3 num_accepted_tokens = 1 -prompt_token_ids = [] +prompt_token_ids: list[int] = [] MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" BLOCK_SIZE = 560 NUM_HIDDEN_LAYERS = 8 @@ -49,8 +49,9 @@ def fake_sample_fn( logits: torch.Tensor | None, spec_decode_metadata: SpecDecodeMetadata | None, ) -> SamplerOutput: + assert logits is not None print( - f"[UNIT TEST] fake_sample_fn: {logits.shape=} 
{spec_decode_metadata=} {self.input_ids.cpu=}" + f"[UNIT TEST] fake_sample_fn: {logits.shape=} {spec_decode_metadata=} {self.input_ids.cpu=}" # noqa: E501 ) num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor num_computed_tokens = num_computed_tokens_cpu_tensor[0].item() @@ -60,7 +61,7 @@ def fake_sample_fn( first_token_id_index = num_computed_tokens + 1 if spec_decode_metadata is None: print( - f"[UNIT TEST] fake_sample_fn: {first_token_id_index=} {prompt_token_ids[first_token_id_index]=}" + f"[UNIT TEST] fake_sample_fn: {first_token_id_index=} {prompt_token_ids[first_token_id_index]=}" # noqa: E501 ) return SamplerOutput( sampled_token_ids=torch.tensor( @@ -79,7 +80,7 @@ def fake_sample_fn( num_sampled_tokens - len(accpeted_tokens) ) print( - f"[UNIT TEST] fake_sample_fn: {first_token_id_index=} {accpeted_tokens=} {sampled_token_ids=}" + f"[UNIT TEST] fake_sample_fn: {first_token_id_index=} {accpeted_tokens=} {sampled_token_ids=}" # noqa: E501 ) # if ( # self.input_batch.num_computed_tokens_cpu_tensor[0].item() @@ -123,7 +124,7 @@ def fake_propose_draft_token_ids_fn( num_computed_tokens + 1 ) # bonus token isn't considered as computed print( - f"fake_propose_draft_token_ids_fn: {self.input_batch.num_accepted_tokens_cpu=}" + f"fake_propose_draft_token_ids_fn: {self.input_batch.num_accepted_tokens_cpu=}" # noqa: E501 ) first_token_id_index += self.input_batch.num_accepted_tokens_cpu[0].item() proposed_draft_token_ids = [ @@ -132,7 +133,7 @@ def fake_propose_draft_token_ids_fn( ] ] print( - f"[UNIT TEST] fake_propose_draft_token_ids_fn: {num_computed_tokens=} num_accepted_tokens={self.input_batch.num_accepted_tokens_cpu[0].item()} num_prompt_tokens={self.input_batch.num_prompt_tokens[0].item()} num_tokens_no_spec={self.input_batch.num_tokens_no_spec[0].item()} {first_token_id_index=} {proposed_draft_token_ids=}" + f"[UNIT TEST] fake_propose_draft_token_ids_fn: {num_computed_tokens=} 
num_accepted_tokens={self.input_batch.num_accepted_tokens_cpu[0].item()} num_prompt_tokens={self.input_batch.num_prompt_tokens[0].item()} num_tokens_no_spec={self.input_batch.num_tokens_no_spec[0].item()} {first_token_id_index=} {proposed_draft_token_ids=}" # noqa: E501 ) return proposed_draft_token_ids @@ -180,9 +181,9 @@ def fake_allocate_slots_fn( cur_block_ids = self.coordinator.single_type_managers[0].req_to_blocks[ request.request_id ] - not_null_block = [not block.is_null for block in cur_block_ids] - not_null_block = [1 if block else 0 for block in not_null_block] - assert not_null_block == cur_step_action.kv_cache_block_ids + not_null_block_flags = [not block.is_null for block in cur_block_ids] + block_ids = [1 if block else 0 for block in not_null_block_flags] + assert block_ids == cur_step_action.kv_cache_block_ids return ret return fake_allocate_slots_fn @@ -217,7 +218,7 @@ def fake_execute_model_fn( scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] ) print( - f"fake_execute_model_fn: {num_computed_tokens=} {last_num_computed_tokens=} {num_computed_tokens // BLOCK_SIZE > last_num_computed_tokens // BLOCK_SIZE=}" + f"fake_execute_model_fn: {num_computed_tokens=} {last_num_computed_tokens=} {num_computed_tokens // BLOCK_SIZE > last_num_computed_tokens // BLOCK_SIZE=}" # noqa: E501 ) if ( num_computed_tokens // BLOCK_SIZE @@ -226,7 +227,7 @@ def fake_execute_model_fn( # generated a new aligned block in this step block_idx = num_computed_tokens // mamba_spec.block_size - 1 print( - f"[UNIT TEST] fake_execute_model_fn: block_idx= {block_idx} for num_computed_tokens={num_computed_tokens - num_computed_tokens % BLOCK_SIZE}" + f"[UNIT TEST] fake_execute_model_fn: block_idx= {block_idx} for num_computed_tokens={num_computed_tokens - num_computed_tokens % BLOCK_SIZE}" # noqa: E501 ) block_id = ( self.input_batch.block_table.block_tables[mamba_group_id] @@ -247,7 +248,7 @@ def fake_execute_model_fn( last_num_computed_tokens = num_computed_tokens else: 
last_num_computed_tokens = 0 - print(f"[UNIT TEST] fake_execute_model_fn: clear last_num_computed_tokens") + print("[UNIT TEST] fake_execute_model_fn: clear last_num_computed_tokens") ret = original_execute_model_fn(self, scheduler_output, intermediate_tensors) @@ -320,7 +321,8 @@ def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): num_generated_tokens = 8000 num_prompt_tokens = 500 sampling_params = SamplingParams(temperature=0.0, max_tokens=num_generated_tokens) - full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() + with open(f"{os.path.dirname(__file__)}/input.txt") as file: + full_prompt = file.read() fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model) monkeypatch.setattr(GPUModelRunner, "execute_model", fake_execute_model_fn) fake_sample_fn = get_fake_sample_fn() @@ -342,7 +344,7 @@ def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): ) print(f"Generated text: {outputs[0].outputs[0].token_ids}") print( - f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" + f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" # noqa: E501 ) print(f"mamba_kv_cache_dict: {mamba_kv_cache_dict.keys()}") # ref_mamba_kv_cache_dict = torch.load("mamba_kv_cache_dict.pth") @@ -368,7 +370,7 @@ def check_mamba_state_equal( diff_idx = torch.nonzero(diff_mask) if diff_idx.shape[0] * 100 < ref.numel(): print( - f"[WARNING] found {diff_idx.shape[0] * 100 / ref.numel()}% of the elements are different" + f"[WARNING] found {diff_idx.shape[0] * 100 / ref.numel()}% of the elements are different" # noqa: E501 ) continue print( @@ -429,7 +431,8 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch): def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): apply_patch(monkeypatch) - full_prompt = open(f"{os.path.dirname(__file__)}/input.txt").read() + with open(f"{os.path.dirname(__file__)}/input.txt") as file: + full_prompt = 
file.read() tests = { "accept_1": TestConfig( num_prompt_tokens=554, @@ -694,17 +697,17 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): step_action_next.kv_cache_block_ids is not None and len(step_action_next.kv_cache_block_ids) == 0 ): - step_action_next.kv_cache_block_ids = ( - step_action_prev.kv_cache_block_ids.copy() - ) + prev_block_ids = step_action_prev.kv_cache_block_ids + if prev_block_ids is not None: + step_action_next.kv_cache_block_ids = prev_block_ids.copy() global step_actions step_actions = test_config.step_actions print("step actions: ", step_actions) print( - f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" + f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" # noqa: E501 ) - outputs = engine.generate( + _ = engine.generate( [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], sampling_params, ) @@ -713,7 +716,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): keys_to_check = [ (action.postprocess_copy_idx[1] + 1) * BLOCK_SIZE for action in test_config.step_actions - if action.postprocess_copy_idx[0] != -1 + if action.postprocess_copy_idx and action.postprocess_copy_idx[0] != -1 ] check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict, keys_to_check) mamba_kv_cache_dict.clear() @@ -722,4 +725,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): def test_check_mamba_state_equal(): mamba_state_ref = torch.load("mamba_kv_cache_dict.pth") mamba_state_new = torch.load("mamba_kv_cache_dict_new.pth") - check_mamba_state_equal(mamba_state_ref, mamba_state_new) + check_mamba_state_equal( + mamba_state_ref, + mamba_state_new, + keys_to_check=list(mamba_state_ref.keys()), + ) diff --git a/vllm/envs.py b/vllm/envs.py index a7417badfd9f..954e1ced06df 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -245,6 +245,7 @@ VLLM_USE_V2_MODEL_RUNNER: bool = False VLLM_USE_LIGHTER_MAMBA_CACHE: bool 
= False + def get_default_cache_root(): return os.getenv( "XDG_CACHE_HOME", diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 61e741b75946..6b42567f1b5c 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -1,10 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 1eacd2ffdedf..c2ec8e6aa7c3 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -5,6 +5,7 @@ from typing import ClassVar, TypeVar import torch + from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import ( diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index ad0c7e8e32c5..2e8889ae74e9 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Sequence -from typing import Any, Optional +from typing import Any from vllm.distributed.kv_events import ( MEDIUM_GPU, @@ -206,48 +206,6 @@ def get_cached_block( cached_blocks.append(block) return cached_blocks - # def cache_full_block( - # self, - # request: Request, - # block: KVCacheBlock, - # cached_block_index: int, - # block_size: int, - # kv_cache_group_id: int, - # ) -> None: - # """Cache a full block for prefix caching. 
- # """ - - # assert cached_block_index >= 0 - # assert len(request.block_hashes) > cached_block_index - # new_block_hash: BlockHash = request.block_hashes[cached_block_index] - # new_hashes: Optional[list[ExternalBlockHash]] = ( - # [] if self.enable_kv_cache_events else None) - # assert block.block_hash is None - - # # Update and added the full block to the cache. - # block_hash_with_group_id: BlockHashWithGroupId = make_block_hash_with_group_id( - # new_block_hash, kv_cache_group_id) - # block.block_hash = block_hash_with_group_id - # self.cached_block_hash_to_block.insert(block_hash_with_group_id, block) - # if new_hashes is not None: - # new_hashes.append(maybe_convert_block_hash(new_block_hash)) - - # if self.enable_kv_cache_events: - # parent_block_hash: Optional[ExternalBlockHash] = None - - # self.kv_event_queue.append( - # BlockStored( - # block_hashes=new_hashes, - # parent_block_hash=parent_block_hash, - # token_ids=request. - # all_token_ids[cached_block_index * block_size: - # (cached_block_index+1) * block_size], - # block_size=block_size, - # lora_id=request.lora_request.id - # if request.lora_request else None, - # medium=MEDIUM_GPU, - # )) - def cache_full_blocks( self, request: Request, diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index ebfa992e305b..378b6e515e0b 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -5,7 +5,6 @@ from math import lcm from vllm.config.cache import CacheConfig -from vllm.config.vllm import VllmConfig from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector from vllm.v1.core.kv_cache_utils import ( @@ -102,7 +101,7 @@ def get_num_blocks_to_allocate( # For cross-attention, we issue a single static allocation # of blocks based on the number of encoder input tokens. 
num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_encoder_tokens, [] + request_id, num_encoder_tokens, [], num_encoder_tokens ) else: num_blocks_to_allocate += manager.get_num_blocks_to_allocate( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f1db90990915..29c823c15ef6 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -213,9 +213,6 @@ def __init__( ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER - print( - f">>> [DEBUG] Scheduler: init enable_prefix_caching={self.cache_config.enable_prefix_caching} block_size={self.block_size} kv_cache_config={self.kv_cache_config}" - ) def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool: has_mamba: bool = any( @@ -278,11 +275,6 @@ def _mamba_block_aligned_split( return num_new_tokens def schedule(self) -> SchedulerOutput: - print(f">>> [DEBUG] Scheduler: schidule new step") - for req in self.requests.values(): - print( - f">>> [DEBUG] Scheduler: request {req.request_id} num_computed_tokens={req.num_computed_tokens} num_tokens={req.num_tokens} num_tokens_with_spec={req.num_tokens_with_spec}" - ) # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. # Each request just has the num_computed_tokens and @@ -332,10 +324,6 @@ def schedule(self) -> SchedulerOutput: req_index += 1 continue - logger.info( - f">>> [DEBUG] Scheduler: schedule RUNING: req_id={request.request_id}, " - f"num_prompt_tokens={request.num_prompt_tokens=}" - ) # Ensure new tokens for a request in the prefill phase do not contain # sps tokens, especially in the last prefill chunk. For a hybrid-model, # extra sps tokens would corrupt the generated Mamba state. 
@@ -452,9 +440,6 @@ def schedule(self) -> SchedulerOutput: req_index -= 1 else: preempted_req = self.running.pop() - print( - f">>> [DEBUG] Scheduler: preempted request {preempted_req.request_id}" - ) self._preempt_request(preempted_req, scheduled_timestamp) preempted_reqs.append(preempted_req) @@ -527,10 +512,6 @@ def schedule(self) -> SchedulerOutput: break request = self.waiting.peek_request() - logger.info( - f">>> [DEBUG] Scheduler: schedule WAITING: req_id={request.request_id}, " - f"num_prompt_tokens={request.num_prompt_tokens=}" - ) # KVTransfer: skip request if still waiting for remote kvs. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: @@ -538,10 +519,6 @@ def schedule(self) -> SchedulerOutput: if is_ready: request.status = RequestStatus.WAITING else: - logger.debug( - "%s is still in WAITING_FOR_REMOTE_KVS state.", - request.request_id, - ) self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -581,10 +558,6 @@ def schedule(self) -> SchedulerOutput: new_computed_blocks, num_new_local_computed_tokens = ( self.kv_cache_manager.get_computed_blocks(request) ) - logger.info( - f">>> [DEBUG] Scheduler: get_computed_blk: req_id={request.request_id}," - f"{num_new_local_computed_tokens=}" - ) # Get externally-cached tokens if using a KVConnector. 
if self.connector is not None: @@ -834,14 +807,6 @@ def schedule(self) -> SchedulerOutput: self.prev_step_scheduled_req_ids.clear() self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) - logger.info( - ">>> [DEBUG] Scheduler: new_reqs:" - f"{[(reqdata.req_id, reqdata.block_ids) for reqdata in new_reqs_data]}" - ) - logger.info( - ">>> [DEBUG] Scheduler: cached_reqs:" - f"{[(req_id, cached_reqs_data.new_block_ids[i]) for i, req_id in enumerate(cached_reqs_data.req_ids)]}" - ) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -858,7 +823,6 @@ def schedule(self) -> SchedulerOutput: finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), ) - # logger.info(f">>> [DEBUG] Scheduler: scheduler output: {scheduler_output}") # NOTE(Kuntai): this function is designed for multiple purposes: # 1. Plan the KV cache store @@ -879,7 +843,6 @@ def schedule(self) -> SchedulerOutput: with record_function_or_nullcontext("schedule: update_after_schedule"): self._update_after_schedule(scheduler_output) - logger.info(f">>> [DEBUG] Scheduler: scheduler_output: {scheduler_output}") return scheduler_output def _preempt_request( @@ -1444,7 +1407,7 @@ def update_draft_token_ids( # Add newly generated spec token ids to the request. if self.structured_output_manager.should_advance(request): metadata = request.structured_output_request - request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] + request.spec_token_ids = metadata.grammar.validate_tokens( spec_token_ids ) else: @@ -1916,7 +1879,7 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: all_failed_req_ids = async_failed_req_ids | sync_failed_req_ids logger.error( "Failing %d request(s) due to KV load failure " - "(failure_policy=fail, %d tokens affected). Request IDs: %s", + "(failure_policy=fail, %d tokens affected). 
Request IDs: 328", total_failed_requests, total_failed_tokens, all_failed_req_ids, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 37b471715e08..709f6b37433b 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -31,7 +31,6 @@ def format_blocks(blocks: list[KVCacheBlock]): while i < len(blocks): if blocks[i].block_id == 0: count = 0 - start = i while i < len(blocks) and blocks[i].block_id == 0: count += 1 i += 1 @@ -194,13 +193,6 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None: if num_cached_blocks >= num_full_blocks: return - if isinstance(self, MambaManager) and num_cached_blocks < num_full_blocks: - self.print( - f"Mamba.cache_blocks: req_id={request.request_id}, {num_tokens=}, " - f"{num_cached_blocks=}, {num_full_blocks=}, " - f"new_full_blocks={format_blocks(self.req_to_blocks[request.request_id][num_cached_blocks:num_full_blocks])}" - ) - self.block_pool.cache_full_blocks( request=request, blocks=self.req_to_blocks[request.request_id], @@ -327,9 +319,7 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> No break removed_blocks.append(blocks[i]) blocks[i] = self._null_block - self.print( - f"Mamba.remove_skipped_blocks: {request_id=}, {num_computed_tokens=}, {num_skipped_tokens=}, {num_skipped_blocks=}, removed_blocks={format_blocks(removed_blocks)}" - ) + self.block_pool.free_blocks(removed_blocks) def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: @@ -754,8 +744,11 @@ def find_longest_cache_hit( computed.append(cached) break # we just need the last match - early stopping + computed_blocks_fmt = [ + format_blocks(computed_block) for computed_block in computed_blocks + ] print( - f"Mamba.FindLongest: computed_blocks={[format_blocks(computed_block) for computed_block in computed_blocks]}", + f"Mamba.FindLongest: computed_blocks={computed_blocks_fmt}", flush=True, ) return 
computed_blocks @@ -821,9 +814,9 @@ def get_num_blocks_to_allocate( if num_new_blocks > 0: # (Chen): This may be possible. (block_size 4, 2 sps). # [A, stoken1, stoken2] SBLOCK1 SBLOCK2 -> - # [A, ?, ?, ?] NULL NULL [?, ?, ?, B] [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need two blocks + # [A, ?, ?, ?] NULL NULL [?, ?, ?, B] [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need two blocks # noqa: E501 # but we do it as following: - # [A, ?, ?, ?] NULL NULL NULL [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need 1 block + # [A, ?, ?, ?] NULL NULL NULL [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need 1 block # noqa: E501 if request_id in self._allocated_block_reqs: # previously allocated blocks num_new_blocks = 1 @@ -838,20 +831,14 @@ def get_num_blocks_to_allocate( num_evictable_computed_blocks = sum( blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks ) - self.print( - f"Mamba.get_num_blocks_to_allocate: {request_id=}, {num_tokens=}, {num_new_blocks=}" - ) return num_new_blocks + num_evictable_computed_blocks def save_new_computed_blocks( - self, request_id: str, new_computed_blocks: list[KVCacheBlock] + self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock] ) -> None: # TODO(hhy): remove when prefix-caching is ready assert isinstance(self.kv_cache_spec, MambaSpec) - self.print( - f"Mamba.save_computed: {request_id=}, new_computed_blocks={format_blocks(new_computed_blocks)}" - ) - super().save_new_computed_blocks(request_id, new_computed_blocks) + super().save_new_computed_blocks(request_id, list(new_computed_blocks)) def allocate_new_blocks( self, request_id: str, num_tokens: int, num_tokens_target_model: int @@ -875,18 +862,21 @@ def allocate_new_blocks( return [] else: assert num_required_blocks > len(req_blocks), ( - f"num_required_blocks {num_required_blocks} < len(req_blocks) {len(req_blocks)}" + "num_required_blocks " + f"{num_required_blocks} < len(req_blocks) {len(req_blocks)}" ) prev_block_len = len(req_blocks) blocks_allocated = request_id in 
self._allocated_block_reqs # Record the last state block if blocks_allocated: - # We always save the current running state at the last (1 + num_speculative_blocks) block + # We always save the running state at the last + # (1 + num_speculative_blocks) block self.last_state_block_idx[request_id] = ( prev_block_len - 1 - self.num_speculative_blocks ) elif prev_block_len > 0: - # When a new request hits the prefix cache, the last block saves the hit state. + # When a new request hits the prefix cache, the last block + # saves the hit state. self.last_state_block_idx[request_id] = prev_block_len - 1 num_skipped_blocks = ( @@ -909,15 +899,9 @@ def allocate_new_blocks( if block_idx < num_skipped_blocks: req_blocks.append(req_blocks[block_idx]) req_blocks[block_idx] = self._null_block - self.print( - f"Mamba.alloc_blks: {request_id=}, moving block {block_idx} to the end now, req_blocks={format_blocks(req_blocks)}" - ) else: break num_new_blocks = num_required_blocks - len(req_blocks) - self.print( - f"Mamba.alloc_blks: {request_id=}, num_new_blocks={num_new_blocks}" - ) if blocks_allocated: assert num_new_blocks <= 1 else: @@ -925,9 +909,6 @@ def allocate_new_blocks( new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) self._allocated_block_reqs.add(request_id) - self.print( - f"Mamba.alloc_blks: {request_id=}, {len(req_blocks)=}, {len(self.req_to_blocks[request_id])=}, req_blocks={format_blocks(req_blocks)}" - ) return req_blocks[prev_block_len:] def free(self, request_id: str) -> None: diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index a97ac1cb1a7b..0b6319eec121 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -113,19 +113,19 @@ def append_row( start = self.num_blocks_per_row[row_idx] self.num_blocks_per_row[row_idx] += num_blocks self.block_table.np[row_idx, start : start + num_blocks] = block_ids - + def pop_row(self, num_blocks: int, row_idx: int): if num_blocks <= 0: 
return - + if self.use_hybrid_blocks: num_blocks = num_blocks * self.blocks_per_kv_block - + end = self.num_blocks_per_row[row_idx] start = end - num_blocks assert start >= 0 self.num_blocks_per_row[row_idx] -= num_blocks - self.block_table.np[row_idx, start : end] = 0 + self.block_table.np[row_idx, start:end] = 0 def add_row(self, block_ids: list[int], row_idx: int) -> None: self.num_blocks_per_row[row_idx] = 0 @@ -304,9 +304,9 @@ def __init__( BlockTable( block_size, max_num_reqs, - # TODO: when prefix-caching and sps are both enable for - # mamba hybrid model, it will need - # `cdiv(max_model_len, block_size * total_cp_world_size) + num_speculative_tokens` + # TODO: when prefix-caching and sps are both enable for + # mamba hybrid model, it will need + # `cdiv(max_model_len, block_size * total_cp_world_size) + num_speculative_tokens` # noqa: E501 # blocks for mamba groups max( cdiv(max_model_len, block_size * total_cp_world_size), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bdbef5fc8607..b33f7e0f32c8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -453,7 +453,9 @@ def __init__( # uses output token ids so we set this conservatively. logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, - cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, + cp_kv_cache_interleave_size=( + self.parallel_config.cp_kv_cache_interleave_size + ), ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -517,7 +519,8 @@ def __init__( # NOTE: `mrope_positions` is implemented with one additional dummy # position on purpose to make it non-contiguous so that it can work # with torch compile. 
- # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 + # See detailed explanation in + # https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 # NOTE: When M-RoPE is enabled, position ids are 3D regardless of # the modality of inputs. For text-only inputs, each dimension has @@ -607,7 +610,7 @@ def __init__( # Ephemeral state transferred between execute_model() and sample_tokens(). self.execute_model_state: ExecuteModelState | None = None self.kv_connector_output: KVConnectorOutput | None = None - self.mamba_state_idx: dict[str, list[int]] = {} + self.mamba_state_idx: dict[str, int] = {} self.layerwise_nvtx_hooks_registered = False def reset_mm_cache(self) -> None: @@ -1056,10 +1059,13 @@ def _update_states_after_model_execute( for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens if is_global_first_rank(): - logger.info(f">>> [DEBUG] Worker: _update_states: {output_token_ids=}") logger.info( - f">>> [DEBUG] Worker: _update_states: " - f"{self.input_batch.num_accepted_tokens_cpu[:len(num_accepted_tokens)]=}" + ">>> [DEBUG] Worker: _update_states: output_token_ids=%s", + output_token_ids, + ) + logger.info( + ">>> [DEBUG] Worker: _update_states: num_accepted_tokens_cpu=%s", + self.input_batch.num_accepted_tokens_cpu[: len(num_accepted_tokens)], ) if self.cache_config.mamba_cache_mode == "align": self._postprocess_mamba(scheduler_output) @@ -1731,10 +1737,12 @@ def _build_attn_group_metadata( if isinstance(attn_metadata, list): for ub_metadata in attn_metadata: for _metadata in ub_metadata.values(): - _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] + # type: ignore[attr-defined] + _metadata.mm_prefix_range = req_doc_ranges else: for _metadata in attn_metadata.values(): - _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] + # type: ignore[attr-defined] + _metadata.mm_prefix_range = req_doc_ranges if 
spec_decode_common_attn_metadata is not None and ( num_reqs != num_reqs_padded or num_tokens != num_tokens_padded @@ -2185,7 +2193,7 @@ def _execute_mm_encoder( # 2. A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. - curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) # type: ignore[assignment] + curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, @@ -2584,7 +2592,9 @@ def _sample( ) if is_global_first_rank(): logger.info( - f">>> [DEBUG] Worker: sampler_output: {sampler_output.sampled_token_ids.shape} {sampler_output.sampled_token_ids}" + ">>> [DEBUG] Worker: sampler_output: shape=%s tokens=%s", + sampler_output.sampled_token_ids.shape, + sampler_output.sampled_token_ids, ) return sampler_output @@ -2935,20 +2945,23 @@ def _mamba_copy_block( def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): """ - Copies the mamba state of previous step to the last (1 + num_speculative_blocks) block + Copy the mamba state of previous step to the last + (1 + num_speculative_blocks) block. 
""" mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config) num_speculative_blocks = mamba_spec.num_speculative_blocks # TODO(Chen): we need to optimize this function a lot assert self.cache_config.enable_prefix_caching block_size = mamba_spec.block_size - for req_id in itertools.chain( - scheduler_output.finished_req_ids, scheduler_output.preempted_req_ids - ): + finished_req_ids = scheduler_output.finished_req_ids + preempted_req_ids = scheduler_output.preempted_req_ids or [] + for req_id in itertools.chain(finished_req_ids, preempted_req_ids): self.mamba_state_idx.pop(req_id, None) for i, req_id in enumerate(self.input_batch.req_ids): if is_global_first_rank(): - logger.info(f">>> [DEBUG] Worker: preprocess mamba for RUN: {req_id=}") + logger.info( + ">>> [DEBUG] Worker: preprocess mamba for RUN: req_id=%s", req_id + ) req_state = self.requests[req_id] prev_state_idx = self.mamba_state_idx.get(req_id) if prev_state_idx is None: @@ -2957,11 +2970,15 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): prev_state_idx = (req_state.num_computed_tokens - 1) // block_size num_blocks = len(req_state.block_ids[mamba_group_ids[0]]) - # We always save the current running state at the last (1 + num_speculative_blocks) block + # We always save the current running state at the last + # (1 + num_speculative_blocks) block curr_state_idx = num_blocks - 1 - num_speculative_blocks if is_global_first_rank(): logger.info( - f">>> [DEBUG] Worker: preprocess mamba: {req_id=}, idx {prev_state_idx=} -> {curr_state_idx=}" + ">>> [DEBUG] Worker: preprocess mamba: req_id=%s, idx %s -> %s", + req_id, + prev_state_idx, + curr_state_idx, ) self.mamba_state_idx[req_id] = curr_state_idx if prev_state_idx != -1 and prev_state_idx != curr_state_idx: @@ -3010,7 +3027,17 @@ def _mamba_copy_block_for_qwen_next( dest_gdn_state.copy_(src_gdn_state) if is_global_first_rank() and layer_name == layer_names[0]: logger.info( - f">>> [DEBUG] Worker: 
mamba_copy_block_for_qwen_next: {layer_name=}, idx {src_block_idx=} -> {dest_block_idx=} conv {conv_state_block_id=} -> {dest_block_id=} with bias {accept_token_bias}, {gdn_state_block_id=} -> {dest_block_id=}" + ">>> [DEBUG] Worker: mamba_copy_block_for_qwen_next: " + "layer_name=%s, idx %s -> %s conv %s -> %s with bias %s, " + "gdn %s -> %s", + layer_name, + src_block_idx, + dest_block_idx, + conv_state_block_id, + dest_block_id, + accept_token_bias, + gdn_state_block_id, + dest_block_id, ) def _postprocess_mamba(self, scheduler_output: "SchedulerOutput"): @@ -3026,7 +3053,11 @@ def _postprocess_mamba(self, scheduler_output: "SchedulerOutput"): num_accepted_tokens_cpu = self.input_batch.num_accepted_tokens_cpu if is_global_first_rank(): logger.info( - f">>> [DEBUG] Worker: postprocess mamba {num_scheduled_tokens_dict=} {scheduled_spec_decode_tokens_dict=} {num_accepted_tokens_cpu=}" + ">>> [DEBUG] Worker: postprocess mamba num_scheduled_tokens=%s " + "scheduled_spec_decode_tokens=%s num_accepted_tokens_cpu=%s", + num_scheduled_tokens_dict, + scheduled_spec_decode_tokens_dict, + num_accepted_tokens_cpu, ) # NOTE: can be optimized as this function always returns the same result mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config) @@ -3046,7 +3077,16 @@ def _postprocess_mamba(self, scheduler_output: "SchedulerOutput"): ) if is_global_first_rank(): logger.info( - f">>> [DEBUG] Worker: postprocess mamba: {req_id=}, {num_computed_tokens=}, {num_scheduled_tokens=} {num_draft_tokens=} {num_accepted_tokens=} {num_tokens_running_state=} {new_num_computed_tokens=} {aligned_new_computed_tokens=}" + ">>> [DEBUG] Worker: postprocess mamba: req=%s comp=%s " + "sched=%s draft=%s accepted=%s run_state=%s new=%s aligned=%s", + req_id, + num_computed_tokens, + num_scheduled_tokens, + num_draft_tokens, + num_accepted_tokens, + num_tokens_running_state, + new_num_computed_tokens, + aligned_new_computed_tokens, ) # TODO: how to ensure all blocks that cache_blocks 
called are cached here? if aligned_new_computed_tokens >= num_tokens_running_state: @@ -3059,7 +3099,12 @@ def _postprocess_mamba(self, scheduler_output: "SchedulerOutput"): ) if is_global_first_rank(): logger.info( - f">>> [DEBUG] Worker: postprocess mamba copy: {req_id=}, {src_block_idx=} -> {dest_block_idx=} with bias {accept_token_bias}" + ">>> [DEBUG] Worker: postprocess mamba copy: req_id=%s, " + "%s -> %s with bias %s", + req_id, + src_block_idx, + dest_block_idx, + accept_token_bias, ) self._mamba_copy_block_for_qwen_next( mamba_group_ids, @@ -5088,8 +5133,10 @@ def _check_and_update_cudagraph_mode( # if we have dedicated decode cudagraphs, and spec-decode is enabled, # we need to adjust the cudagraph sizes to be a multiple of the uniform - # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207 - # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536 + # decode query length to avoid: + # https://github.com/vllm-project/vllm/issues/28207 + # temp-fix: + # https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536 # Will be removed in the near future when we have separate cudagraph capture # sizes for decode and mixed prefill-decode. if ( @@ -5119,7 +5166,13 @@ def calculate_reorder_batch_threshold(self) -> None: just may have a performance penalty due to that backend treating decodes as prefills. 
""" - min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b) + + def min_none_high(a: int | None, b: int | None) -> int | None: + if b is None: + return a + if a is None: + return b + return min(a, b) reorder_batch_thresholds: list[int | None] = [ group.get_metadata_builder().reorder_batch_threshold @@ -5130,7 +5183,9 @@ def calculate_reorder_batch_threshold(self) -> None: if len(reorder_batch_thresholds) == 0: self.reorder_batch_threshold = None return - self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment] + self.reorder_batch_threshold = reduce( # type: ignore[assignment] + min_none_high, reorder_batch_thresholds + ) @staticmethod def select_common_block_size( @@ -5241,7 +5296,9 @@ def may_reinitialize_input_batch( kernel_block_sizes=kernel_block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, - logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, + logitsprocs_need_output_token_ids=( + self.input_batch.logitsprocs_need_output_token_ids + ), is_pooling_model=self.is_pooling_model, num_speculative_tokens=self.num_spec_tokens, ) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 1e9e505fd0ef..0947adab3ca2 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import defaultdict from dataclasses import dataclass, field -from functools import lru_cache import torch From 5da50637084da04ff1ee4454a8c0aca562ccd316 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 10:34:49 -0800 Subject: [PATCH 049/130] fix pre-commit Signed-off-by: Chen Zhang --- vllm/v1/core/single_type_kv_cache_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 
709f6b37433b..7f5168536ef7 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -53,8 +53,8 @@ class SingleTypeKVCacheManager(ABC): def __init__( self, kv_cache_spec: KVCacheSpec, - cache_config: CacheConfig, block_pool: BlockPool, + cache_config: CacheConfig, kv_cache_group_id: int, dcp_world_size: int = 1, pcp_world_size: int = 1, From 0114da1ac7cb9af565e185e47467027c2c79a99f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 10:38:28 -0800 Subject: [PATCH 050/130] fix pre-commit Signed-off-by: Chen Zhang --- vllm/v1/core/sched/scheduler.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 29c823c15ef6..c68c5c6cceb8 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1407,7 +1407,7 @@ def update_draft_token_ids( # Add newly generated spec token ids to the request. if self.structured_output_manager.should_advance(request): metadata = request.structured_output_request - request.spec_token_ids = metadata.grammar.validate_tokens( + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] spec_token_ids ) else: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b33f7e0f32c8..49d0648b0d24 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -519,8 +519,7 @@ def __init__( # NOTE: `mrope_positions` is implemented with one additional dummy # position on purpose to make it non-contiguous so that it can work # with torch compile. 
- # See detailed explanation in - # https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 + # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 # NOTE: When M-RoPE is enabled, position ids are 3D regardless of # the modality of inputs. For text-only inputs, each dimension has @@ -1737,12 +1736,10 @@ def _build_attn_group_metadata( if isinstance(attn_metadata, list): for ub_metadata in attn_metadata: for _metadata in ub_metadata.values(): - # type: ignore[attr-defined] - _metadata.mm_prefix_range = req_doc_ranges + _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] else: for _metadata in attn_metadata.values(): - # type: ignore[attr-defined] - _metadata.mm_prefix_range = req_doc_ranges + _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] if spec_decode_common_attn_metadata is not None and ( num_reqs != num_reqs_padded or num_tokens != num_tokens_padded @@ -2193,7 +2190,7 @@ def _execute_mm_encoder( # 2. A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. 
- curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) + curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) # type: ignore[assignment] sanity_check_mm_encoder_outputs( curr_group_outputs, From e0f607c1bd83343422c7d8c9f37bd0ae314315c6 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 10:39:15 -0800 Subject: [PATCH 051/130] fix pre-commit Signed-off-by: Chen Zhang --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 49d0648b0d24..aa90d41a8da5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2951,7 +2951,7 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): assert self.cache_config.enable_prefix_caching block_size = mamba_spec.block_size finished_req_ids = scheduler_output.finished_req_ids - preempted_req_ids = scheduler_output.preempted_req_ids or [] + preempted_req_ids = scheduler_output.preempted_req_ids or set() for req_id in itertools.chain(finished_req_ids, preempted_req_ids): self.mamba_state_idx.pop(req_id, None) for i, req_id in enumerate(self.input_batch.req_ids): From 0ec5e29462824c4c770b4bdf3e43cd71593ab81e Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 11:39:00 -0800 Subject: [PATCH 052/130] fix pre-commit Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 2 +- vllm/v1/core/single_type_kv_cache_manager.py | 18 ++++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 378b6e515e0b..ccc41c7531f5 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -61,8 +61,8 @@ def __init__( self.use_eagle = use_eagle self.single_type_managers = tuple( get_manager_for_kv_cache_spec( - cache_config=self.cache_config, kv_cache_spec=kv_cache_group.kv_cache_spec, + 
cache_config=self.cache_config, block_pool=self.block_pool, kv_cache_group_id=i, dcp_world_size=dcp_world_size, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 7f5168536ef7..8e6da05c8396 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -53,8 +53,8 @@ class SingleTypeKVCacheManager(ABC): def __init__( self, kv_cache_spec: KVCacheSpec, - block_pool: BlockPool, cache_config: CacheConfig, + block_pool: BlockPool, kv_cache_group_id: int, dcp_world_size: int = 1, pcp_world_size: int = 1, @@ -398,12 +398,9 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class SlidingWindowManager(SingleTypeKVCacheManager): - def __init__( - self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, **kwargs - ) -> None: - super().__init__(kv_cache_spec, block_pool, **kwargs) + def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None: + super().__init__(kv_cache_spec, **kwargs) self.sliding_window = kv_cache_spec.sliding_window - self._null_block = block_pool.null_block @classmethod def find_longest_cache_hit( @@ -534,12 +531,9 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): - def __init__( - self, kv_cache_spec: ChunkedLocalAttentionSpec, block_pool: BlockPool, **kwargs - ) -> None: - super().__init__(kv_cache_spec, block_pool, **kwargs) + def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None: + super().__init__(kv_cache_spec, **kwargs) self.attention_chunk_size = kv_cache_spec.attention_chunk_size - self._null_block = block_pool.null_block @classmethod def find_longest_cache_hit( @@ -690,7 +684,7 @@ class MambaManager(SingleTypeKVCacheManager): def __init__( self, kv_cache_spec: MambaSpec, cache_config: CacheConfig, **kwargs ) -> None: - super().__init__(kv_cache_spec, cache_config, **kwargs) + 
super().__init__(kv_cache_spec, cache_config=cache_config, **kwargs) self.mamba_cache_mode = cache_config.mamba_cache_mode self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks if self.mamba_cache_mode == "align": From 1be5e6d207bd252aafc00cf32f79cb6a844f1a1f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 11:39:41 -0800 Subject: [PATCH 053/130] fix pre-commit Signed-off-by: Chen Zhang --- vllm/v1/worker/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 0947adab3ca2..6010b41bca97 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -378,9 +378,10 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp mamba_group_ids: list[int] = [] mamba_specs: list[MambaSpec] = [] for i in range(len(kv_cache_config.kv_cache_groups)): - if isinstance(kv_cache_config.kv_cache_groups[i].kv_cache_spec, MambaSpec): + kv_cache_spec = kv_cache_config.kv_cache_groups[i].kv_cache_spec + if isinstance(kv_cache_spec, MambaSpec): mamba_group_ids.append(i) - mamba_specs.append(kv_cache_config.kv_cache_groups[i].kv_cache_spec) + mamba_specs.append(kv_cache_spec) assert len(mamba_group_ids) > 0, "no mamba layers in the model" assert all(mamba_specs[0] == spec for spec in mamba_specs) return mamba_group_ids, mamba_specs[0] From ef2e9f29a781676e1d7a641f116164cff2326d48 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 14:25:38 -0800 Subject: [PATCH 054/130] clean up Signed-off-by: Chen Zhang --- vllm/model_executor/layers/mamba/abstract.py | 5 +---- vllm/model_executor/layers/mamba/mamba_mixer.py | 9 +++------ vllm/model_executor/layers/mamba/mamba_mixer2.py | 15 ++++++--------- 3 files changed, 10 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index fe9e55e7b533..74f4383e9c23 100644 --- 
a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -48,10 +48,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: raise NotImplementedError( "Mamba with speculative decoding is not supported yet." ) - if vllm_config.cache_config.enable_prefix_caching: - mamba_block_size = vllm_config.cache_config.block_size - else: - mamba_block_size = vllm_config.model_config.max_model_len + mamba_block_size = vllm_config.cache_config.mamba_block_size page_size_padded = vllm_config.cache_config.mamba_page_size_padded return MambaSpec( shapes=self.get_state_shape(), diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 789776e923e5..9d509b82d9bb 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -240,10 +240,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size - prefix_caching_enabled = ( - self.cache_config.mamba_cache_mode != "align" - and self.cache_config.enable_prefix_caching - ) + return_intermediate_states = self.cache_config.mamba_cache_mode == "all" if attn_metadata is not None: assert isinstance(attn_metadata, dict) @@ -292,7 +289,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d - if prefix_caching_enabled: + if return_intermediate_states: block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( torch.split( attn_metadata.block_idx_last_computed_token, @@ -368,7 +365,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): ssm_outputs.append(scan_out_p) if has_decode: - if prefix_caching_enabled: + if return_intermediate_states: 
state_indices_tensor_d_input = state_indices_tensor_d.gather( 1, block_idx_last_computed_token_d.unsqueeze(1) ).squeeze(1) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 6b42567f1b5c..ef3936b839b2 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -564,10 +564,7 @@ def conv_ssm_forward( assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size - prefix_caching_enabled = ( - self.cache_config.mamba_cache_mode != "align" - and self.cache_config.enable_prefix_caching - ) + return_intermediate_states = self.cache_config.mamba_cache_mode == "all" if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] @@ -619,7 +616,7 @@ def conv_ssm_forward( dim=0, ) - if prefix_caching_enabled: + if return_intermediate_states: # If prefix caching is enabled, retrieve the relevant variables # for prefill and decode block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( @@ -698,7 +695,7 @@ def conv_ssm_forward( initial_states = None if has_initial_states_p is not None and prep_initial_states: kernel_ssm_indices = state_indices_tensor_p - if prefix_caching_enabled: + if return_intermediate_states: kernel_ssm_indices = state_indices_tensor_p.gather( 1, block_idx_last_computed_token_p.unsqueeze(1) ).squeeze(1) @@ -726,14 +723,14 @@ def conv_ssm_forward( cu_chunk_seqlens=cu_chunk_seqlen_p, last_chunk_indices=last_chunk_indices_p, initial_states=initial_states, - return_intermediate_states=prefix_caching_enabled, + return_intermediate_states=return_intermediate_states, dt_softplus=True, dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim), state_dtype=ssm_state.dtype, ) - if prefix_caching_enabled: + if return_intermediate_states: # The chunk_stride is the number of chunks per mamba block # e.g., if 
mamba_block_size = 512 and chunk_size = 256, # then chunk_stride = 2 @@ -812,7 +809,7 @@ def conv_ssm_forward( # Process decode requests if has_decode: - if prefix_caching_enabled: + if return_intermediate_states: state_indices_tensor_d_input = state_indices_tensor_d.gather( 1, block_idx_last_computed_token_d.unsqueeze(1) ).squeeze(1) From 7aab2122c5c9984217bb0fae09e75567d1999a55 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 14:38:01 -0800 Subject: [PATCH 055/130] code cleanup Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/gdn_attn.py | 1 - vllm/v1/attention/backends/utils.py | 19 +++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index f9ba358e023f..6340005bdc62 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -153,7 +153,6 @@ def build( # type: ignore[override] common_attn_metadata, self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, - 1 + self.num_spec, ) if ( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 4806d3927216..d8c4acb1aeb4 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1189,14 +1189,29 @@ def mamba_get_block_table_tensor( common_attn_metadata: CommonAttentionMetadata, kv_cache_spec: MambaSpec, mamba_cache_mode: str, - num_blocks: int = 1, ) -> torch.Tensor: + """ + Get the block table tensor for mamba kernels from the input + common_attn_metadata.block_table_tensor given different mamba cache modes. + + - "all": input (#requests, cdiv(max_model_len, block_size)); + output (#requests, cdiv(max_model_len, block_size)). + + - "none": input (#requests, 1 + num_speculative_blocks); + output (#requests, 1 + num_speculative_blocks). 
+ + - "align": input (#requests, cdiv(max_model_len, block_size)); + output (#requests, 1 + num_speculative_blocks), which are the last + 1 + num_speculative_blocks of each request. + """ if mamba_cache_mode in ("all", "none"): return common_attn_metadata.block_table_tensor else: assert isinstance(kv_cache_spec, MambaSpec) block_table_tensor = common_attn_metadata.block_table_tensor start_indices = (common_attn_metadata.seq_lens - 1) // kv_cache_spec.block_size - offsets = torch.arange(num_blocks, device=block_table_tensor.device) + offsets = torch.arange( + 1 + kv_cache_spec.num_speculative_blocks, device=block_table_tensor.device + ) indices_to_gather = start_indices.unsqueeze(1) + offsets return torch.gather(block_table_tensor, 1, indices_to_gather) From 5aedba93be7b768fb933646c4ab29b690ff6dbd7 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 14:59:59 -0800 Subject: [PATCH 056/130] code cleanup Signed-off-by: Chen Zhang --- vllm/config/vllm.py | 11 +++++++ vllm/model_executor/models/config.py | 6 +--- vllm/v1/core/sched/scheduler.py | 47 +++++++++++++--------------- 3 files changed, 34 insertions(+), 30 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index b5f8f916de43..a2a2a2f69d53 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -933,6 +933,17 @@ def has_blocked_weights(): # local attention. 
self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.cache_config.mamba_cache_mode == "align": + if self.scheduler_config.long_prefill_token_threshold > 0: + assert ( + self.scheduler_config.long_prefill_token_threshold + >= self.cache_config.block_size + ) + assert not self.scheduler_config.disable_chunked_mm_input, ( + "Chunked MM input is required because we need the flexibility to " + "schedule a multiple of block_size tokens even if they are in the " + "middle of a mm input" + ) if self.compilation_config.debug_dump_path: self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path.absolute().expanduser() diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 8aa4d4adea13..3a74114c69ae 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -398,15 +398,11 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: model_config=model_config, ) - if cache_config.enable_prefix_caching: - block_size = cache_config.block_size - else: - block_size = model_config.max_model_len # get mamba page size mamba_page_size = MambaSpec( shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), - block_size=block_size, + block_size=-1, # block_size doesn't matter for mamba page size ).page_size_bytes # Model may be marked as is_hybrid diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c68c5c6cceb8..18cd8d485ea7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -238,14 +238,24 @@ def _mamba_block_aligned_split( assert num_external_computed_tokens == 0, ( "External KV connector is not verified yet" ) - # To enable block-aligned caching of the Mamba state, `num_new_tokens` - # must be a multiple of `block_size`. 
- # As an exception, if `num_new_tokens` is less than `block_size`, the - # state is simply not cached, requiring no special handling. - # Additionally, when Eagle mode is enabled, FullAttn prunes the last - # matching block. To prevent this from causing a Mamba cache miss, the - # last chunk must be larger than `block_size`. if request.num_output_tokens == 0: # prefill + # Ensure new tokens for a request in the prefill phase do not contain + # sps tokens, especially in the last prefill chunk. For a hybrid-model, + # extra sps tokens would corrupt the generated Mamba state. + # TODO: This logic does not yet handle resumed requests. + if request.num_computed_tokens < request.num_prompt_tokens: + num_new_tokens = min( + request.num_prompt_tokens - request.num_computed_tokens, + num_new_tokens, + ) + + # To enable block-aligned caching of the Mamba state, `num_new_tokens` + # must be a multiple of `block_size`. + # As an exception, if `num_new_tokens` is less than `block_size`, the + # state is simply not cached, requiring no special handling. + # Additionally, when Eagle mode is enabled, FullAttn prunes the last + # matching block. To prevent this from causing a Mamba cache miss, the + # last chunk must be larger than `block_size`. block_size = self.cache_config.block_size last_cache_position = ( request.num_prompt_tokens - request.num_prompt_tokens % block_size @@ -324,24 +334,11 @@ def schedule(self) -> SchedulerOutput: req_index += 1 continue - # Ensure new tokens for a request in the prefill phase do not contain - # sps tokens, especially in the last prefill chunk. For a hybrid-model, - # extra sps tokens would corrupt the generated Mamba state. - # TODO: This logic does not yet handle resumed requests. 
- if request.num_computed_tokens < request.num_prompt_tokens: - num_new_tokens = ( - min( - request.num_tokens_with_spec + request.num_output_placeholders, - request.num_prompt_tokens, - ) - - request.num_computed_tokens - ) - else: - num_new_tokens = ( - request.num_tokens_with_spec - + request.num_output_placeholders - - request.num_computed_tokens - ) + num_new_tokens = ( + request.num_tokens_with_spec + + request.num_output_placeholders + - request.num_computed_tokens + ) if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: num_new_tokens = self.scheduler_config.long_prefill_token_threshold From 1e15448715056ec67fffe10b03a903daa27a2b97 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 15:03:09 -0800 Subject: [PATCH 057/130] code cleanup Signed-off-by: Chen Zhang --- vllm/v1/core/sched/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 18cd8d485ea7..83e6a6b41435 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -240,8 +240,8 @@ def _mamba_block_aligned_split( ) if request.num_output_tokens == 0: # prefill # Ensure new tokens for a request in the prefill phase do not contain - # sps tokens, especially in the last prefill chunk. For a hybrid-model, - # extra sps tokens would corrupt the generated Mamba state. + # draft tokens, especially in the last prefill chunk. For a hybrid-model, + # extra draft tokens would corrupt the generated Mamba state. # TODO: This logic does not yet handle resumed requests. 
if request.num_computed_tokens < request.num_prompt_tokens: num_new_tokens = min( From 0dd525ef4a12d117e607acae090804b1e556a357 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 15:07:25 -0800 Subject: [PATCH 058/130] revert Signed-off-by: Chen Zhang --- vllm/v1/core/sched/scheduler.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 83e6a6b41435..dbe17315d515 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -339,10 +339,8 @@ def schedule(self) -> SchedulerOutput: + request.num_output_placeholders - request.num_computed_tokens ) - if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: num_new_tokens = self.scheduler_config.long_prefill_token_threshold - num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len. @@ -516,6 +514,10 @@ def schedule(self) -> SchedulerOutput: if is_ready: request.status = RequestStatus.WAITING else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id, + ) self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -1876,7 +1878,7 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: all_failed_req_ids = async_failed_req_ids | sync_failed_req_ids logger.error( "Failing %d request(s) due to KV load failure " - "(failure_policy=fail, %d tokens affected). Request IDs: 328", + "(failure_policy=fail, %d tokens affected). 
Request IDs: %s", total_failed_requests, total_failed_tokens, all_failed_req_ids, From 239b7d59e6d96db395e42296ab08fbacce66103c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 15:14:50 -0800 Subject: [PATCH 059/130] update Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 16 ++++++------ vllm/v1/core/kv_cache_manager.py | 8 +++--- vllm/v1/core/single_type_kv_cache_manager.py | 26 +++++++++++--------- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index ccc41c7531f5..1efb675bd377 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -77,7 +77,7 @@ def get_num_blocks_to_allocate( num_tokens: int, new_computed_blocks: tuple[Sequence[KVCacheBlock], ...], num_encoder_tokens: int, - num_tokens_target_model: int, + num_tokens_main_model: int, ) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -90,8 +90,10 @@ def get_num_blocks_to_allocate( prefix caching. num_encoder_tokens: The number of encoder tokens for allocating blocks for cross-attention. - num_tokens_target_model: w/o spec decode, this should be the same as - num_tokens, with spec decode, TODO more comments here. + num_tokens_main_model: The number of tokens for the main model (aka target + model in spec decode). w/o spec decode, it is num_tokens; + with spec decode, it is num_tokens - num_lookahead_tokens. + Returns: The number of blocks. """ @@ -108,7 +110,7 @@ def get_num_blocks_to_allocate( request_id, num_tokens, new_computed_blocks[i], - num_tokens_target_model, + num_tokens_main_model, ) return num_blocks_to_allocate @@ -130,7 +132,7 @@ def allocate_new_blocks( self, request_id: str, num_tokens: int, - num_tokens_target_model: int, + num_tokens_main_model: int, num_encoder_tokens: int = 0, ) -> tuple[list[KVCacheBlock], ...]: """ @@ -141,7 +143,7 @@ def allocate_new_blocks( request_id: The request ID. 
num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). - num_tokens_target_model: w/o spec decode, this should be the same as + num_tokens_main_model: w/o spec decode, this should be the same as num_tokens, with spec decode, TODO more comments here. num_encoder_tokens: The number of encoder tokens for allocating blocks for cross-attention. @@ -154,7 +156,7 @@ def allocate_new_blocks( num_encoder_tokens if isinstance(manager, CrossAttentionManager) else num_tokens, - num_tokens_target_model, + num_tokens_main_model, ) for manager in self.single_type_managers ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 5aebab9b3ec9..5a760b6a41b6 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -275,9 +275,9 @@ def allocate_slots( # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens - num_tokens_target_model = num_computed_tokens + num_new_tokens + num_tokens_main_model = num_computed_tokens + num_new_tokens num_tokens_need_slot = min( - num_tokens_target_model + num_lookahead_tokens, + num_tokens_main_model + num_lookahead_tokens, self.max_model_len, ) @@ -286,7 +286,7 @@ def allocate_slots( num_tokens=num_tokens_need_slot, new_computed_blocks=new_computed_block_list, num_encoder_tokens=num_encoder_tokens, - num_tokens_target_model=num_tokens_target_model, + num_tokens_main_model=num_tokens_main_model, ) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): @@ -311,7 +311,7 @@ def allocate_slots( new_blocks = self.coordinator.allocate_new_blocks( request.request_id, num_tokens_need_slot, - num_tokens_target_model, + num_tokens_main_model, num_encoder_tokens, ) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8e6da05c8396..6be98d256b56 100644 --- 
a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -98,7 +98,7 @@ def get_num_blocks_to_allocate( request_id: str, num_tokens: int, new_computed_blocks: Sequence[KVCacheBlock], - num_tokens_target_model: int, + num_tokens_main_model: int, ) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -109,8 +109,9 @@ def get_num_blocks_to_allocate( tokens that are already allocated). new_computed_blocks: The new computed blocks just hitting the prefix caching. - num_tokens_target_model: w/o spec decode, this should be the same as - num_tokens, with spec decode, TODO more comments here. + num_tokens_main_model: The number of tokens for the main model (aka target + model in spec decode). w/o spec decode, it is num_tokens; + with spec decode, it is num_tokens - num_lookahead_tokens. Returns: The number of blocks. @@ -153,7 +154,7 @@ def save_new_computed_blocks( assert len(new_computed_blocks) == 0 def allocate_new_blocks( - self, request_id: str, num_tokens: int, num_tokens_target_model: int + self, request_id: str, num_tokens: int, num_tokens_main_model: int ) -> list[KVCacheBlock]: """ Allocate new blocks for the request to give it at least `num_tokens` @@ -163,7 +164,7 @@ def allocate_new_blocks( request_id: The request ID. num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). - num_tokens_target_model: w/o spec decode, this should be the same as + num_tokens_main_model: w/o spec decode, this should be the same as num_tokens, with spec decode, TODO more comments here. Returns: The new allocated blocks. @@ -777,11 +778,11 @@ def get_num_blocks_to_allocate( request_id: str, num_tokens: int, new_computed_blocks: Sequence[KVCacheBlock], - num_tokens_target_model: int, + num_tokens_main_model: int, ) -> int: assert isinstance(self.kv_cache_spec, MambaSpec) # mamba layers only exist in target model. 
- num_tokens = num_tokens_target_model + num_tokens = num_tokens_main_model if self.mamba_cache_mode != "align": # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. @@ -794,7 +795,7 @@ def get_num_blocks_to_allocate( request_id, num_tokens, new_computed_blocks, - num_tokens_target_model, + num_tokens_main_model, ) else: num_required_blocks = ( @@ -835,19 +836,22 @@ def save_new_computed_blocks( super().save_new_computed_blocks(request_id, list(new_computed_blocks)) def allocate_new_blocks( - self, request_id: str, num_tokens: int, num_tokens_target_model: int + self, request_id: str, num_tokens: int, num_tokens_main_model: int ) -> list[KVCacheBlock]: assert isinstance(self.kv_cache_spec, MambaSpec) - num_tokens = num_tokens_target_model if self.mamba_cache_mode != "align": # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. if self.num_speculative_blocks > 0: num_tokens += self.block_size * self.num_speculative_blocks return super().allocate_new_blocks( - request_id, num_tokens, num_tokens_target_model + request_id, num_tokens, num_tokens_main_model ) else: + # We don't allocate blocks for lookahead tokens in align mode, because if + # x * block_size tokens are scheduled, original num_tokens is + # x * block_size + num_lookahead_tokens and breaks the alignment. 
+ num_tokens = num_tokens_main_model req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] num_required_blocks = ( cdiv(num_tokens, self.block_size) + self.num_speculative_blocks From b09d8cefbbd68093887133aa0dea0176819623d0 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 15:17:19 -0800 Subject: [PATCH 060/130] update Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 6 ++++-- vllm/v1/core/single_type_kv_cache_manager.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 1efb675bd377..c8b9ec96feb7 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -143,10 +143,12 @@ def allocate_new_blocks( request_id: The request ID. num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). - num_tokens_main_model: w/o spec decode, this should be the same as - num_tokens, with spec decode, TODO more comments here. + num_tokens_main_model: The number of tokens for the main model (aka target + model in spec decode). w/o spec decode, it is num_tokens; + with spec decode, it is num_tokens - num_lookahead_tokens. num_encoder_tokens: The number of encoder tokens for allocating blocks for cross-attention. + Returns: The new allocated blocks. """ diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 6be98d256b56..032adada1bf0 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -164,8 +164,9 @@ def allocate_new_blocks( request_id: The request ID. num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). - num_tokens_main_model: w/o spec decode, this should be the same as - num_tokens, with spec decode, TODO more comments here. 
+ num_tokens_main_model: The number of tokens for the main model (aka target + model in spec decode). w/o spec decode, it is num_tokens; + with spec decode, it is num_tokens - num_lookahead_tokens. Returns: The new allocated blocks. """ From 167f35f25bf64db4be3d5a9f0a27834d12dc40e5 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 15:59:11 -0800 Subject: [PATCH 061/130] update Signed-off-by: Chen Zhang --- vllm/v1/core/single_type_kv_cache_manager.py | 30 ++++------ vllm/v1/worker/block_table.py | 17 ------ vllm/v1/worker/gpu_model_runner.py | 63 +++++--------------- 3 files changed, 29 insertions(+), 81 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 032adada1bf0..beb925fbbc1b 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -740,13 +740,6 @@ def find_longest_cache_hit( computed.append(cached) break # we just need the last match - early stopping - computed_blocks_fmt = [ - format_blocks(computed_block) for computed_block in computed_blocks - ] - print( - f"Mamba.FindLongest: computed_blocks={computed_blocks_fmt}", - flush=True, - ) return computed_blocks def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: @@ -782,8 +775,6 @@ def get_num_blocks_to_allocate( num_tokens_main_model: int, ) -> int: assert isinstance(self.kv_cache_spec, MambaSpec) - # mamba layers only exist in target model. - num_tokens = num_tokens_main_model if self.mamba_cache_mode != "align": # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. @@ -799,6 +790,12 @@ def get_num_blocks_to_allocate( num_tokens_main_model, ) else: + # We don't allocate blocks for lookahead tokens in align mode, because if + # x * block_size tokens are scheduled, num_tokens is + # x * block_size + num_lookahead_tokens and breaks the alignment. 
+ # We can ignore lookahead tokens because current draft models don't have + # mamba layers. + num_tokens = num_tokens_main_model num_required_blocks = ( cdiv(num_tokens, self.block_size) + self.num_speculative_blocks ) @@ -808,16 +805,13 @@ def get_num_blocks_to_allocate( - len(self.req_to_blocks[request_id]) ) if num_new_blocks > 0: - # (Chen): This may be possible. (block_size 4, 2 sps). - # [A, stoken1, stoken2] SBLOCK1 SBLOCK2 -> - # [A, ?, ?, ?] NULL NULL [?, ?, ?, B] [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need two blocks # noqa: E501 - # but we do it as following: - # [A, ?, ?, ?] NULL NULL NULL [stoken 1, stoken 2] SBLOCK1 SBLOCK2 -> need 1 block # noqa: E501 if request_id in self._allocated_block_reqs: - # previously allocated blocks + # Old request. Needs at most 1 more blocks as we can reuse the + # speculative blocks in previous step. num_new_blocks = 1 else: - # first prefill + # First prefill. Allocate 1 block for running state and the + # speculative blocks. num_new_blocks = 1 + self.kv_cache_spec.num_speculative_blocks # If a computed block of a request is an eviction candidate (in the @@ -850,8 +844,10 @@ def allocate_new_blocks( ) else: # We don't allocate blocks for lookahead tokens in align mode, because if - # x * block_size tokens are scheduled, original num_tokens is + # x * block_size tokens are scheduled, num_tokens is # x * block_size + num_lookahead_tokens and breaks the alignment. + # We can ignore lookahead tokens because current draft models don't have + # mamba layers. 
num_tokens = num_tokens_main_model req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id] num_required_blocks = ( diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 0b6319eec121..efb01bb8675f 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -114,19 +114,6 @@ def append_row( self.num_blocks_per_row[row_idx] += num_blocks self.block_table.np[row_idx, start : start + num_blocks] = block_ids - def pop_row(self, num_blocks: int, row_idx: int): - if num_blocks <= 0: - return - - if self.use_hybrid_blocks: - num_blocks = num_blocks * self.blocks_per_kv_block - - end = self.num_blocks_per_row[row_idx] - start = end - num_blocks - assert start >= 0 - self.num_blocks_per_row[row_idx] -= num_blocks - self.block_table.np[row_idx, start:end] = 0 - def add_row(self, block_ids: list[int], row_idx: int) -> None: self.num_blocks_per_row[row_idx] = 0 self.append_row(block_ids, row_idx) @@ -325,10 +312,6 @@ def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): block_table.append_row(block_ids[i], row_idx) - def pop_row(self, num_blocks: tuple[int, ...], row_idx: int) -> None: - for i, block_table in enumerate(self.block_tables): - block_table.pop_row(num_blocks[i], row_idx) - def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): block_table.add_row(block_ids[i], row_idx) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index aa90d41a8da5..2b56442f7d62 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -453,9 +453,7 @@ def __init__( # uses output token ids so we set this conservatively. 
logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, - cp_kv_cache_interleave_size=( - self.parallel_config.cp_kv_cache_interleave_size - ), + cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -1057,15 +1055,6 @@ def _update_states_after_model_execute( ) for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens - if is_global_first_rank(): - logger.info( - ">>> [DEBUG] Worker: _update_states: output_token_ids=%s", - output_token_ids, - ) - logger.info( - ">>> [DEBUG] Worker: _update_states: num_accepted_tokens_cpu=%s", - self.input_batch.num_accepted_tokens_cpu[: len(num_accepted_tokens)], - ) if self.cache_config.mamba_cache_mode == "align": self._postprocess_mamba(scheduler_output) @@ -2587,12 +2576,6 @@ def _sample( logits, sampling_metadata, ) - if is_global_first_rank(): - logger.info( - ">>> [DEBUG] Worker: sampler_output: shape=%s tokens=%s", - sampler_output.sampled_token_ids.shape, - sampler_output.sampled_token_ids, - ) return sampler_output def _bookkeeping_sync( @@ -2955,10 +2938,6 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): for req_id in itertools.chain(finished_req_ids, preempted_req_ids): self.mamba_state_idx.pop(req_id, None) for i, req_id in enumerate(self.input_batch.req_ids): - if is_global_first_rank(): - logger.info( - ">>> [DEBUG] Worker: preprocess mamba for RUN: req_id=%s", req_id - ) req_state = self.requests[req_id] prev_state_idx = self.mamba_state_idx.get(req_id) if prev_state_idx is None: @@ -2967,19 +2946,20 @@ def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): prev_state_idx = (req_state.num_computed_tokens - 1) // block_size num_blocks = len(req_state.block_ids[mamba_group_ids[0]]) + # We always save the current running state at the last - # (1 + num_speculative_blocks) block + # (1 + 
num_speculative_blocks) block. + # A corner case worth mention here: assume we have block_size = 4 and + # num_speculative_tokens = 2. The request is [A, B, C] and contains 2 draft + # tokens [draft 1, draft 2]. Then we will have: + # Block 0: [A, B, C, draft 1] + # Block 1: [draft 2, TOFILL, TOFILL, TOFILL] + # Block 2: speculative block + # Block 3: speculative block + # And use block 1 to save the running state. curr_state_idx = num_blocks - 1 - num_speculative_blocks - if is_global_first_rank(): - logger.info( - ">>> [DEBUG] Worker: preprocess mamba: req_id=%s, idx %s -> %s", - req_id, - prev_state_idx, - curr_state_idx, - ) self.mamba_state_idx[req_id] = curr_state_idx if prev_state_idx != -1 and prev_state_idx != curr_state_idx: - # TODO: merge all these lines to copy_block self._mamba_copy_block_for_qwen_next( mamba_group_ids, prev_state_idx, @@ -5130,10 +5110,8 @@ def _check_and_update_cudagraph_mode( # if we have dedicated decode cudagraphs, and spec-decode is enabled, # we need to adjust the cudagraph sizes to be a multiple of the uniform - # decode query length to avoid: - # https://github.com/vllm-project/vllm/issues/28207 - # temp-fix: - # https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536 + # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207 + # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536 # Will be removed in the near future when we have separate cudagraph capture # sizes for decode and mixed prefill-decode. if ( @@ -5164,12 +5142,7 @@ def calculate_reorder_batch_threshold(self) -> None: as prefills. 
""" - def min_none_high(a: int | None, b: int | None) -> int | None: - if b is None: - return a - if a is None: - return b - return min(a, b) + min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b) reorder_batch_thresholds: list[int | None] = [ group.get_metadata_builder().reorder_batch_threshold @@ -5180,9 +5153,7 @@ def min_none_high(a: int | None, b: int | None) -> int | None: if len(reorder_batch_thresholds) == 0: self.reorder_batch_threshold = None return - self.reorder_batch_threshold = reduce( # type: ignore[assignment] - min_none_high, reorder_batch_thresholds - ) + self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment] @staticmethod def select_common_block_size( @@ -5293,9 +5264,7 @@ def may_reinitialize_input_batch( kernel_block_sizes=kernel_block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, - logitsprocs_need_output_token_ids=( - self.input_batch.logitsprocs_need_output_token_ids - ), + logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, num_speculative_tokens=self.num_spec_tokens, ) From 5e915986f2a09670a02ace13378f9baffebce0fc Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 12 Dec 2025 17:06:42 -0800 Subject: [PATCH 062/130] introduce mamba_utils Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 65 +++++--- vllm/v1/worker/gpu_model_runner.py | 192 +++--------------------- vllm/v1/worker/mamba_utils.py | 163 ++++++++++++++++++++ vllm/v1/worker/utils.py | 15 -- 4 files changed, 229 insertions(+), 206 deletions(-) create mode 100644 vllm/v1/worker/mamba_utils.py diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 36df54c6e57f..6138aa17896d 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -3,23 +3,28 @@ import os from 
collections.abc import Callable from dataclasses import dataclass +from typing import Any import pytest import torch from vllm import LLM, SamplingParams, TokensPrompt +from vllm.config import CacheConfig from vllm.sequence import IntermediateTensors from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine.core_client import InprocClient +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import SamplerOutput from vllm.v1.request import Request from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.worker import mamba_utils from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.gpu_model_runner import GPUModelRunner -from vllm.v1.worker.utils import get_mamba_groups +from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch +from vllm.v1.worker.mamba_utils import get_mamba_groups @dataclass @@ -272,45 +277,73 @@ def get_fake_process_mamba_fn( copy_info = (-1, -1) def fake_preprocess_mamba_fn( - self: GPUModelRunner, scheduler_output: SchedulerOutput + scheduler_output: SchedulerOutput, + kv_cache_config: KVCacheConfig, + cache_config: CacheConfig, + mamba_state_idx: dict[str, int], + input_batch: GPUInputBatch, + requests: dict[str, CachedRequestState], + forward_context: dict[str, Any], ): nonlocal copy_info copy_info = (-1, -1) - ret = original_preprocess_mamba_fn(self, scheduler_output) + ret = original_preprocess_mamba_fn( + scheduler_output, + kv_cache_config, + cache_config, + mamba_state_idx, + input_batch, + requests, + forward_context, + ) if cur_step_action is not None: print("[UNIT TEST STEP] verifying preprocess_copy_idx") assert copy_info == cur_step_action.preprocess_copy_idx return ret def fake_post_process_mamba_fn( - self: GPUModelRunner, scheduler_output: 
SchedulerOutput + scheduler_output: SchedulerOutput, + kv_cache_config: KVCacheConfig, + input_batch: GPUInputBatch, + requests: dict[str, CachedRequestState], + mamba_state_idx: dict[str, int], + forward_context: dict[str, Any], ): nonlocal copy_info copy_info = (-1, -1) - ret = original_post_process_mamba_fn(self, scheduler_output) + ret = original_post_process_mamba_fn( + scheduler_output, + kv_cache_config, + input_batch, + requests, + mamba_state_idx, + forward_context, + ) if cur_step_action is not None: print("[UNIT TEST STEP] verifying postprocess_copy_idx") assert copy_info == cur_step_action.postprocess_copy_idx return ret def fake_copy_fn( - self: GPUModelRunner, - kv_cache_group_ids: list[int], + kv_cache_config: KVCacheConfig, + mamba_group_ids: list[int], src_block_idx: int, dest_block_idx: int, accept_token_bias: int, req_state: CachedRequestState, + forward_context: dict[str, Any], ): nonlocal copy_info assert copy_info == (-1, -1) copy_info = (src_block_idx, dest_block_idx) return original_copy_fn( - self, - kv_cache_group_ids, + kv_cache_config, + mamba_group_ids, src_block_idx, dest_block_idx, accept_token_bias, req_state, + forward_context, ) return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn @@ -417,16 +450,14 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch): fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn = ( get_fake_process_mamba_fn( - GPUModelRunner._preprocess_mamba, - GPUModelRunner._postprocess_mamba, - GPUModelRunner._mamba_copy_block_for_qwen_next, + mamba_utils.preprocess_mamba, + mamba_utils.postprocess_mamba, + mamba_utils.mamba_copy_block_for_qwen_next, ) ) - monkeypatch.setattr(GPUModelRunner, "_preprocess_mamba", fake_preprocess_mamba_fn) - monkeypatch.setattr( - GPUModelRunner, "_postprocess_mamba", fake_post_process_mamba_fn - ) - monkeypatch.setattr(GPUModelRunner, "_mamba_copy_block_for_qwen_next", fake_copy_fn) + monkeypatch.setattr(mamba_utils, "preprocess_mamba", 
fake_preprocess_mamba_fn) + monkeypatch.setattr(mamba_utils, "postprocess_mamba", fake_post_process_mamba_fn) + monkeypatch.setattr(mamba_utils, "mamba_copy_block_for_qwen_next", fake_copy_fn) def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2b56442f7d62..ba259d6694a9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -149,6 +149,7 @@ from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext +from vllm.v1.worker import mamba_utils from vllm.v1.worker.cp_utils import check_attention_cp_compatibility from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin @@ -169,7 +170,6 @@ add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, gather_mm_placeholders, - get_mamba_groups, sanity_check_mm_encoder_outputs, scatter_mm_placeholders, ) @@ -1056,7 +1056,14 @@ def _update_states_after_model_execute( for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens if self.cache_config.mamba_cache_mode == "align": - self._postprocess_mamba(scheduler_output) + mamba_utils.postprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.input_batch, + self.requests, + self.mamba_state_idx, + self.compilation_config.static_forward_context, + ) def _init_mrope_positions(self, req_state: CachedRequestState): model = self.get_model() @@ -2923,176 +2930,6 @@ def _mamba_copy_block( for kv_cache_part in kv_cache: kv_cache_part[dest_block_id].copy_(kv_cache_part[src_block_id]) - def _preprocess_mamba(self, scheduler_output: "SchedulerOutput"): - """ - Copy the mamba state of previous step to the last - (1 + num_speculative_blocks) block. 
- """ - mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config) - num_speculative_blocks = mamba_spec.num_speculative_blocks - # TODO(Chen): we need to optimize this function a lot - assert self.cache_config.enable_prefix_caching - block_size = mamba_spec.block_size - finished_req_ids = scheduler_output.finished_req_ids - preempted_req_ids = scheduler_output.preempted_req_ids or set() - for req_id in itertools.chain(finished_req_ids, preempted_req_ids): - self.mamba_state_idx.pop(req_id, None) - for i, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests[req_id] - prev_state_idx = self.mamba_state_idx.get(req_id) - if prev_state_idx is None: - # new / resumed request, no previous state - # if num_computed_tokens is 0, prev_state_idx will be -1 - prev_state_idx = (req_state.num_computed_tokens - 1) // block_size - - num_blocks = len(req_state.block_ids[mamba_group_ids[0]]) - - # We always save the current running state at the last - # (1 + num_speculative_blocks) block. - # A corner case worth mention here: assume we have block_size = 4 and - # num_speculative_tokens = 2. The request is [A, B, C] and contains 2 draft - # tokens [draft 1, draft 2]. Then we will have: - # Block 0: [A, B, C, draft 1] - # Block 1: [draft 2, TOFILL, TOFILL, TOFILL] - # Block 2: speculative block - # Block 3: speculative block - # And use block 1 to save the running state. 
- curr_state_idx = num_blocks - 1 - num_speculative_blocks - self.mamba_state_idx[req_id] = curr_state_idx - if prev_state_idx != -1 and prev_state_idx != curr_state_idx: - self._mamba_copy_block_for_qwen_next( - mamba_group_ids, - prev_state_idx, - curr_state_idx, - self.input_batch.num_accepted_tokens_cpu[i] - 1, - req_state, - ) - self.input_batch.num_accepted_tokens_cpu[i] = 1 - - def _mamba_copy_block_for_qwen_next( - self, - kv_cache_group_ids: list[int], - src_block_idx: int, - dest_block_idx: int, - accept_token_bias: int, - req_state: CachedRequestState, - ): - # TODO: general impl for all models - if src_block_idx == dest_block_idx and accept_token_bias == 0: - return - forward_context = self.compilation_config.static_forward_context - for kv_cache_group_id in kv_cache_group_ids: - block_ids = req_state.block_ids[kv_cache_group_id] - dest_block_id = block_ids[dest_block_idx] - layer_names = self.kv_cache_config.kv_cache_groups[ - kv_cache_group_id - ].layer_names - for layer_name in layer_names: - kv_caches: list[list[torch.Tensor]] = forward_context[ - layer_name - ].kv_cache[0] - conv_state, gdn_state = kv_caches - # conv state - conv_state_block_id = block_ids[src_block_idx] - src_conv_state = conv_state[conv_state_block_id][accept_token_bias:] - dest_conv_state = conv_state[dest_block_id] - dest_conv_state[: len(src_conv_state)].copy_(src_conv_state.clone()) - # gdn state - gdn_state_block_id = block_ids[src_block_idx + accept_token_bias] - src_gdn_state = gdn_state[gdn_state_block_id] - dest_gdn_state = gdn_state[dest_block_id] - dest_gdn_state.copy_(src_gdn_state) - if is_global_first_rank() and layer_name == layer_names[0]: - logger.info( - ">>> [DEBUG] Worker: mamba_copy_block_for_qwen_next: " - "layer_name=%s, idx %s -> %s conv %s -> %s with bias %s, " - "gdn %s -> %s", - layer_name, - src_block_idx, - dest_block_idx, - conv_state_block_id, - dest_block_id, - accept_token_bias, - gdn_state_block_id, - dest_block_id, - ) - - def 
_postprocess_mamba(self, scheduler_output: "SchedulerOutput"): - """ - 1. If a blocks is converted from partial block to full block in this step, copy - 2. Unify the state after token acceptance - the state from mamba_state_idx to that block - """ - num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens - scheduled_spec_decode_tokens_dict = ( - scheduler_output.scheduled_spec_decode_tokens - ) - num_accepted_tokens_cpu = self.input_batch.num_accepted_tokens_cpu - if is_global_first_rank(): - logger.info( - ">>> [DEBUG] Worker: postprocess mamba num_scheduled_tokens=%s " - "scheduled_spec_decode_tokens=%s num_accepted_tokens_cpu=%s", - num_scheduled_tokens_dict, - scheduled_spec_decode_tokens_dict, - num_accepted_tokens_cpu, - ) - # NOTE: can be optimized as this function always returns the same result - mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config) - # TODO: vectorize this loop - for i, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests[req_id] - num_computed_tokens = req_state.num_computed_tokens - num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, [])) - num_scheduled_tokens = num_scheduled_tokens_dict[req_id] - num_accepted_tokens = num_accepted_tokens_cpu[i] - num_tokens_running_state = ( - num_computed_tokens + num_scheduled_tokens - num_draft_tokens - ) - new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1 - aligned_new_computed_tokens = ( - new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size - ) - if is_global_first_rank(): - logger.info( - ">>> [DEBUG] Worker: postprocess mamba: req=%s comp=%s " - "sched=%s draft=%s accepted=%s run_state=%s new=%s aligned=%s", - req_id, - num_computed_tokens, - num_scheduled_tokens, - num_draft_tokens, - num_accepted_tokens, - num_tokens_running_state, - new_num_computed_tokens, - aligned_new_computed_tokens, - ) - # TODO: how to ensure all blocks that cache_blocks called are cached here? 
- if aligned_new_computed_tokens >= num_tokens_running_state: - accept_token_bias = ( - aligned_new_computed_tokens - num_tokens_running_state - ) - src_block_idx = self.mamba_state_idx[req_id] - dest_block_idx = ( - aligned_new_computed_tokens // mamba_spec.block_size - 1 - ) - if is_global_first_rank(): - logger.info( - ">>> [DEBUG] Worker: postprocess mamba copy: req_id=%s, " - "%s -> %s with bias %s", - req_id, - src_block_idx, - dest_block_idx, - accept_token_bias, - ) - self._mamba_copy_block_for_qwen_next( - mamba_group_ids, - src_block_idx, - dest_block_idx, - accept_token_bias, - req_state, - ) - if src_block_idx == dest_block_idx: - num_accepted_tokens_cpu[i] = 1 - @torch.inference_mode() def execute_model( self, @@ -3222,8 +3059,15 @@ def execute_model( pad_attn = cudagraph_mode == CUDAGraphMode.FULL if self.cache_config.mamba_cache_mode == "align": - # TODO: add limition: preprocess only have new blocks - self._preprocess_mamba(scheduler_output) + mamba_utils.preprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.cache_config, + self.mamba_state_idx, + self.input_batch, + self.requests, + self.compilation_config.static_forward_context, + ) use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py new file mode 100644 index 000000000000..ed03352e4acf --- /dev/null +++ b/vllm/v1/worker/mamba_utils.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +from typing import Any + +import torch + +from vllm.config import CacheConfig +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec +from vllm.v1.worker.gpu_input_batch import CachedRequestState +from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch + + +def 
get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]: + mamba_group_ids: list[int] = [] + mamba_specs: list[MambaSpec] = [] + for i in range(len(kv_cache_config.kv_cache_groups)): + kv_cache_spec = kv_cache_config.kv_cache_groups[i].kv_cache_spec + if isinstance(kv_cache_spec, MambaSpec): + mamba_group_ids.append(i) + mamba_specs.append(kv_cache_spec) + assert len(mamba_group_ids) > 0, "no mamba layers in the model" + assert all(mamba_specs[0] == spec for spec in mamba_specs) + return mamba_group_ids, mamba_specs[0] + + +def mamba_copy_block_for_qwen_next( + kv_cache_config: KVCacheConfig, + mamba_group_ids: list[int], + src_block_idx: int, + dest_block_idx: int, + accept_token_bias: int, + req_state: CachedRequestState, + forward_context: dict[str, Any], +): + # TODO: general impl for all models + if src_block_idx == dest_block_idx and accept_token_bias == 0: + return + for mamba_group_id in mamba_group_ids: + block_ids = req_state.block_ids[mamba_group_id] + dest_block_id = block_ids[dest_block_idx] + layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names + for layer_name in layer_names: + attention = forward_context[layer_name] + kv_caches: list[list[torch.Tensor]] = attention.kv_cache[0] + conv_state, gdn_state = kv_caches + # conv state + conv_state_block_id = block_ids[src_block_idx] + src_conv_state = conv_state[conv_state_block_id][accept_token_bias:] + dest_conv_state = conv_state[dest_block_id] + dest_conv_state[: len(src_conv_state)].copy_(src_conv_state.clone()) + # gdn state + gdn_state_block_id = block_ids[src_block_idx + accept_token_bias] + src_gdn_state = gdn_state[gdn_state_block_id] + dest_gdn_state = gdn_state[dest_block_id] + dest_gdn_state.copy_(src_gdn_state) + + +def preprocess_mamba( + scheduler_output: SchedulerOutput, + kv_cache_config: KVCacheConfig, + cache_config: CacheConfig, + mamba_state_idx: dict[str, int], + input_batch: GPUInputBatch, + requests: dict[str, CachedRequestState], + 
forward_context: dict[str, Any], +): + """ + Copy the mamba state of previous step to the last + (1 + num_speculative_blocks) block. + """ + mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config) + num_speculative_blocks = mamba_spec.num_speculative_blocks + # TODO(Chen): we need to optimize this function a lot + assert cache_config.enable_prefix_caching + block_size = mamba_spec.block_size + finished_req_ids = scheduler_output.finished_req_ids + preempted_req_ids = scheduler_output.preempted_req_ids or set() + for req_id in itertools.chain(finished_req_ids, preempted_req_ids): + mamba_state_idx.pop(req_id, None) + for i, req_id in enumerate(input_batch.req_ids): + req_state = requests[req_id] + prev_state_idx = mamba_state_idx.get(req_id) + if prev_state_idx is None: + # new / resumed request, no previous state + # if num_computed_tokens is 0, prev_state_idx will be -1 + prev_state_idx = (req_state.num_computed_tokens - 1) // block_size + + num_blocks = len(req_state.block_ids[mamba_group_ids[0]]) + + # We always save the current running state at the last + # (1 + num_speculative_blocks) block. + # A corner case worth mention here: assume we have block_size = 4 and + # num_speculative_tokens = 2. The request is [A, B, C] and contains 2 draft + # tokens [draft 1, draft 2]. Then we will have: + # Block 0: [A, B, C, draft 1] + # Block 1: [draft 2, TOFILL, TOFILL, TOFILL] + # Block 2: speculative block + # Block 3: speculative block + # And use block 1 to save the running state. 
+ curr_state_idx = num_blocks - 1 - num_speculative_blocks + mamba_state_idx[req_id] = curr_state_idx + if prev_state_idx != -1 and prev_state_idx != curr_state_idx: + mamba_copy_block_for_qwen_next( + kv_cache_config, + mamba_group_ids, + prev_state_idx, + curr_state_idx, + input_batch.num_accepted_tokens_cpu[i] - 1, + req_state, + forward_context, + ) + input_batch.num_accepted_tokens_cpu[i] = 1 + + +def postprocess_mamba( + scheduler_output: SchedulerOutput, + kv_cache_config: KVCacheConfig, + input_batch: GPUInputBatch, + requests: dict[str, CachedRequestState], + mamba_state_idx: dict[str, int], + forward_context: dict[str, Any], +): + """ + If a blocks is converted from partial block to full block in this step, copy the + state from the block for running state to the new full block. + """ + num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens + scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens + num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu + # NOTE: can be optimized as this function always returns the same result + mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config) + # TODO: vectorize this loop + for i, req_id in enumerate(input_batch.req_ids): + req_state = requests[req_id] + num_computed_tokens = req_state.num_computed_tokens + num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, [])) + num_scheduled_tokens = num_scheduled_tokens_dict[req_id] + num_accepted_tokens = num_accepted_tokens_cpu[i] + num_tokens_running_state = ( + num_computed_tokens + num_scheduled_tokens - num_draft_tokens + ) + new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1 + aligned_new_computed_tokens = ( + new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size + ) + # TODO: how to ensure all blocks that cache_blocks called are cached here? 
+ if aligned_new_computed_tokens >= num_tokens_running_state: + accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state + src_block_idx = mamba_state_idx[req_id] + dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1 + mamba_copy_block_for_qwen_next( + kv_cache_config, + mamba_group_ids, + src_block_idx, + dest_block_idx, + accept_token_bias, + req_state, + forward_context, + ) + if src_block_idx == dest_block_idx: + num_accepted_tokens_cpu[i] = 1 diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 6010b41bca97..4b7be233412f 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -16,10 +16,8 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import ( - KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, - MambaSpec, ) @@ -372,16 +370,3 @@ def is_residual_scattered_for_sp( if compile_sizes is None: return False return num_input_tokens in compile_sizes - - -def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]: - mamba_group_ids: list[int] = [] - mamba_specs: list[MambaSpec] = [] - for i in range(len(kv_cache_config.kv_cache_groups)): - kv_cache_spec = kv_cache_config.kv_cache_groups[i].kv_cache_spec - if isinstance(kv_cache_spec, MambaSpec): - mamba_group_ids.append(i) - mamba_specs.append(kv_cache_spec) - assert len(mamba_group_ids) > 0, "no mamba layers in the model" - assert all(mamba_specs[0] == spec for spec in mamba_specs) - return mamba_group_ids, mamba_specs[0] From 9fcc6eed233a84aa5a3a5856e28bc23ea6f382b0 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 13 Dec 2025 11:47:50 +0000 Subject: [PATCH 063/130] fix mamba block size usage before update Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/models/config.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git 
a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 3a74114c69ae..fc0e03629138 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -313,11 +313,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: raise ValueError( "unknown mamba cache mode: %s", cache_config.mamba_cache_mode ) - # By default, mamba block size will be set to max_model_len (see - # below). When enabling prefix caching, we align mamba block size - # to the block size as the basic granularity for prefix caching. - if cache_config.mamba_block_size is None: - cache_config.mamba_block_size = cache_config.block_size elif cache_config.mamba_block_size is None: cache_config.mamba_block_size = model_config.max_model_len @@ -455,6 +450,13 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: attn_block_size, ) + # By default, mamba block size will be set to max_model_len. + # When enabling prefix caching and using align mamba cache + # mode, we align mamba block size to the block size as the + # basic granularity for prefix caching. 
+ if cache_config.mamba_cache_mode == "align": + cache_config.mamba_block_size = cache_config.block_size + # compute new attention page size attn_page_size = cache_config.block_size * attn_page_size_1_token From ab3957891201117aa6d1f3d35918a57b9679794c Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 13 Dec 2025 16:09:07 +0000 Subject: [PATCH 064/130] fix prefill chunk incorrectly including draft tokens Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/sched/scheduler.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index dbe17315d515..76e7af72560f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -239,16 +239,6 @@ def _mamba_block_aligned_split( "External KV connector is not verified yet" ) if request.num_output_tokens == 0: # prefill - # Ensure new tokens for a request in the prefill phase do not contain - # draft tokens, especially in the last prefill chunk. For a hybrid-model, - # extra draft tokens would corrupt the generated Mamba state. - # TODO: This logic does not yet handle resumed requests. - if request.num_computed_tokens < request.num_prompt_tokens: - num_new_tokens = min( - request.num_prompt_tokens - request.num_computed_tokens, - num_new_tokens, - ) - # To enable block-aligned caching of the Mamba state, `num_new_tokens` # must be a multiple of `block_size`. # As an exception, if `num_new_tokens` is less than `block_size`, the @@ -334,11 +324,24 @@ def schedule(self) -> SchedulerOutput: req_index += 1 continue - num_new_tokens = ( - request.num_tokens_with_spec - + request.num_output_placeholders - - request.num_computed_tokens - ) + # Ensure new tokens for a request in the prefill phase do not contain + # draft tokens, especially in the last prefill chunk. For a hybrid-model, + # extra draft tokens would corrupt the generated Mamba state. 
+ # TODO: This logic does not yet handle resumed requests. + if request.num_computed_tokens < request.num_prompt_tokens: + num_new_tokens = ( + min( + request.num_tokens_with_spec + request.num_output_placeholders, + request.num_prompt_tokens, + ) + - request.num_computed_tokens + ) + else: + num_new_tokens = ( + request.num_tokens_with_spec + + request.num_output_placeholders + - request.num_computed_tokens + ) if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) From e41e9737f81c9a75262993411593249d7a0d1d99 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 17 Dec 2025 17:58:49 +0000 Subject: [PATCH 065/130] fix the bug in mamba_get_block_table_tensor when cuda graph is enabled Signed-off-by: huanghaoyan.hhy --- vllm/v1/attention/backends/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index d8c4acb1aeb4..74a17aa20bcf 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1209,7 +1209,12 @@ def mamba_get_block_table_tensor( else: assert isinstance(kv_cache_spec, MambaSpec) block_table_tensor = common_attn_metadata.block_table_tensor - start_indices = (common_attn_metadata.seq_lens - 1) // kv_cache_spec.block_size + # NOTE: For 0-length requests in CUDA graph, use a start_index of 0 + # to handle the invalid block table. 
+ start_indices = torch.clamp( + (common_attn_metadata.seq_lens - 1) // kv_cache_spec.block_size, + min=0, + ) offsets = torch.arange( 1 + kv_cache_spec.num_speculative_blocks, device=block_table_tensor.device ) From 7b9f90c357165a39179cea094e0281a4fad99902 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Thu, 18 Dec 2025 16:30:10 +0000 Subject: [PATCH 066/130] rm test script Signed-off-by: huanghaoyan.hhy --- my_tests/run_op_prefix_cache.sh | 72 --------------------------------- 1 file changed, 72 deletions(-) delete mode 100755 my_tests/run_op_prefix_cache.sh diff --git a/my_tests/run_op_prefix_cache.sh b/my_tests/run_op_prefix_cache.sh deleted file mode 100755 index a8cdc6baed50..000000000000 --- a/my_tests/run_op_prefix_cache.sh +++ /dev/null @@ -1,72 +0,0 @@ -#!/bin/bash - -PORT=8235 -TP=2 -MAX_MODEL_LEN=262144 -# MAX_MODEL_LEN=131072 - -DO_NSYS=0 - - -MODEL_DIR=/mnt/disk0/huanghaoyan.hhy/Qwen3-Next-80B-A3B-Instruct/ -echo "MODEL_DIR: $MODEL_DIR" - -NSYS_OUTPUT="qwen_next_h20_tp1_nreq1_fp8_mtp1_prefixcache_2" -NSYS="" -if (( DO_NSYS == 1 )); then - NSYS="nsys profile -c cudaProfilerApi --cuda-graph-trace node -o $NSYS_OUTPUT" -fi - -env_vars=( - # "CUDA_LAUNCH_BLOCKING=0" - "CUDA_VISIBLE_DEVICES=0,1,2,3" - # "CUDA_VISIBLE_DEVICES=6,7" - # "VLLM_ATTENTION_BACKEND=FLASH_ATTN" - # "VLLM_FLASH_ATTN_VERSION=3" - # "VLLM_ALLOW_LONG_MAX_MODEL_LEN=1" - # "OMP_NUM_THREADS=1" - # "VLLM_USE_V1=1" - # "VLLM_LOG_REQ_KV_LENS=1" - # "VLLM_USE_FLASHINFER_SAMPLER=0" -) - -for var in "${env_vars[@]}"; do - var_name="${var%%=*}" - var_value="${var#*=}" - echo -e "\t$var_name=$var_value" -done - -CMD=( env ) -for var in "${env_vars[@]}"; do - CMD+=( "$var" ) -done -CMD+=( - $NSYS vllm serve - $MODEL_DIR - # --trust-remote-code - --port "$PORT" - --gpu-memory-utilization 0.9 - -tp $TP - --enforce-eager - # --no-enable-prefix-caching - --enable-prefix-caching - # --no-enable-chunked-prefill - --enable-chunked-prefill - --mamba-cache-mode align - --max-num-batched-tokens 8192 
- --distributed-executor-backend mp - --block-size 64 - --max-num-seqs 128 - # --max-num-seqs 16 - # --max-model-len $MAX_MODEL_LEN - # --max-seq-len-to-capture $MAX_MODEL_LEN - # --compilation-config "{\"use_inductor\": false, \"cudagraph_mode\": \"FULL_DECODE_ONLY\", \"custom_ops\": [\"all\"]}" - # --speculative-config "{\"method\": \"qwen3_next_mtp\", \"num_speculative_tokens\": 3}" - # --hf_overrides "{\"max_position_embeddings\": $MAX_MODEL_LEN}" -) - -echo -e "\nExecuting command:" -printf " %s" "${CMD[@]}" -echo -e "\n" - -"${CMD[@]}" From 0afca42a2a4fdcf22fd8b637f9b635055a166328 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 15 Dec 2025 18:11:54 +0000 Subject: [PATCH 067/130] general copy spec demo Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/mamba_utils.py | 79 ++++++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index ed03352e4acf..3c6a20e52fb1 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools -from typing import Any +from typing import Any, Callable import torch @@ -56,6 +56,83 @@ def mamba_copy_block_for_qwen_next( dest_gdn_state = gdn_state[dest_block_id] dest_gdn_state.copy_(src_gdn_state) +from dataclasses import dataclass + +@dataclass +class CopySpec: + block_idx_offset_func: Callable[[int], int] + data_offset_func: Callable[[torch.Tensor, int], int] + num_elements_func: Callable[[torch.Tensor, int], int] + + +conv_copy = CopySpec( + block_idx_offset_func=lambda bias: 0, + data_offset_func=lambda state, bias: bias * state.stride(0), + num_elements_func=lambda state, bias: state.numel() - bias * state.stride(0), +) + +full_copy = CopySpec( + block_idx_offset_func=lambda bias: bias, + data_offset_func=lambda state, bias: 0, + num_elements_func=lambda state, bias: 
state.numel(), +) + +def mamba_copy_block_for_qwen_next_v1( + kv_cache_config: KVCacheConfig, + mamba_group_ids: list[int], + src_block_idx: int, + dest_block_idx: int, + accept_token_bias: int, + req_state: CachedRequestState, + forward_context: dict[str, Any], +): + # TODO: general impl for all models + if src_block_idx == dest_block_idx and accept_token_bias == 0: + return + for mamba_group_id in mamba_group_ids: + block_ids = req_state.block_ids[mamba_group_id] + dest_block_id = block_ids[dest_block_idx] + layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names + for layer_name in layer_names: + attention = forward_context[layer_name] + kv_caches: list[list[torch.Tensor]] = attention.kv_cache[0] + conv_state, gdn_state = kv_caches + + # conv state + conv_state_block_id_ref = block_ids[src_block_idx] + conv_state_block_id = block_ids[src_block_idx + conv_copy.block_idx_offset_func(accept_token_bias)] + assert conv_state_block_id_ref == conv_state_block_id, f"{conv_state_block_id_ref} != {conv_state_block_id}" + + conv_state_block = conv_state[conv_state_block_id] + data_offset = conv_copy.data_offset_func(conv_state_block, accept_token_bias) + num_elements = conv_copy.num_elements_func(conv_state_block, accept_token_bias) + src_conv_state_blk = conv_state_block.flatten()[data_offset:data_offset + num_elements] + dest_conv_state_blk = conv_state[dest_block_id].flatten()[:num_elements] + dest_conv_state_blk.copy_(src_conv_state_blk) + + src_conv_state = conv_state[conv_state_block_id][accept_token_bias:] + dest_conv_state = conv_state[dest_block_id] + # dest_conv_state[: len(src_conv_state)].copy_(src_conv_state.clone()) + dest_conv_state_ref = dest_conv_state[: len(src_conv_state)] + src_conv_state_ref = conv_state[conv_state_block_id][accept_token_bias:] + assert dest_conv_state_ref == src_conv_state_ref + + # gdn state + gdn_state_block_id_ref = block_ids[src_block_idx + accept_token_bias] + gdn_state_block_id = block_ids[src_block_idx + 
full_copy.block_idx_offset_func(accept_token_bias)] + assert gdn_state_block_id_ref == gdn_state_block_id, f"{gdn_state_block_id_ref} != {gdn_state_block_id}" + + gdn_state_block = gdn_state[gdn_state_block_id] + data_offset = full_copy.data_offset_func(gdn_state_block, accept_token_bias) + num_elements = full_copy.num_elements_func(gdn_state_block, accept_token_bias) + src_gdn_state_blk = gdn_state_block.flatten()[data_offset:data_offset + num_elements] + dest_gdn_state_blk = gdn_state[dest_block_id].flatten()[:num_elements] + dest_gdn_state_blk.copy_(src_gdn_state_blk) + + src_gdn_state_ref = gdn_state[gdn_state_block_id] + dest_gdn_state_ref = gdn_state[dest_block_id] + # dest_gdn_state.copy_(src_gdn_state) + assert dest_gdn_state_ref == src_gdn_state_ref def preprocess_mamba( scheduler_output: SchedulerOutput, From 51f88d0969fa572227c81f2c0090e6c19828382c Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 16 Dec 2025 17:23:29 +0000 Subject: [PATCH 068/130] copy spec v2 Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/mamba_utils.py | 50 +++++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 3c6a20e52fb1..68699b78650f 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -106,8 +106,8 @@ def mamba_copy_block_for_qwen_next_v1( conv_state_block = conv_state[conv_state_block_id] data_offset = conv_copy.data_offset_func(conv_state_block, accept_token_bias) num_elements = conv_copy.num_elements_func(conv_state_block, accept_token_bias) - src_conv_state_blk = conv_state_block.flatten()[data_offset:data_offset + num_elements] - dest_conv_state_blk = conv_state[dest_block_id].flatten()[:num_elements] + src_conv_state_blk = conv_state_block.view(-1)[data_offset:data_offset + num_elements] + dest_conv_state_blk = conv_state[dest_block_id].view(-1)[:num_elements] dest_conv_state_blk.copy_(src_conv_state_blk) src_conv_state = 
conv_state[conv_state_block_id][accept_token_bias:] @@ -115,7 +115,7 @@ def mamba_copy_block_for_qwen_next_v1( # dest_conv_state[: len(src_conv_state)].copy_(src_conv_state.clone()) dest_conv_state_ref = dest_conv_state[: len(src_conv_state)] src_conv_state_ref = conv_state[conv_state_block_id][accept_token_bias:] - assert dest_conv_state_ref == src_conv_state_ref + assert torch.equal(dest_conv_state_ref, src_conv_state_ref) # gdn state gdn_state_block_id_ref = block_ids[src_block_idx + accept_token_bias] @@ -125,14 +125,48 @@ def mamba_copy_block_for_qwen_next_v1( gdn_state_block = gdn_state[gdn_state_block_id] data_offset = full_copy.data_offset_func(gdn_state_block, accept_token_bias) num_elements = full_copy.num_elements_func(gdn_state_block, accept_token_bias) - src_gdn_state_blk = gdn_state_block.flatten()[data_offset:data_offset + num_elements] - dest_gdn_state_blk = gdn_state[dest_block_id].flatten()[:num_elements] + src_gdn_state_blk = gdn_state_block.view(-1)[data_offset:data_offset + num_elements] + dest_gdn_state_blk = gdn_state[dest_block_id].view(-1)[:num_elements] dest_gdn_state_blk.copy_(src_gdn_state_blk) src_gdn_state_ref = gdn_state[gdn_state_block_id] dest_gdn_state_ref = gdn_state[dest_block_id] # dest_gdn_state.copy_(src_gdn_state) - assert dest_gdn_state_ref == src_gdn_state_ref + assert torch.equal(dest_gdn_state_ref, src_gdn_state_ref) + + +def mamba_copy_block_for_qwen_next_v2( + kv_cache_config: KVCacheConfig, + mamba_group_ids: list[int], + src_block_idx: int, + dest_block_idx: int, + accept_token_bias: int, + req_state: CachedRequestState, + forward_context: dict[str, Any], +): + if src_block_idx == dest_block_idx and accept_token_bias == 0: + return + print('>>> [BEBUG] mamba_copy_block_for_qwen_next_v2', flush=True) + copy_meta = [] + for mamba_group_id in mamba_group_ids: + block_ids = req_state.block_ids[mamba_group_id] + dest_block_id = block_ids[dest_block_idx] + layer_names = 
kv_cache_config.kv_cache_groups[mamba_group_id].layer_names + for layer_name in layer_names: + attention = forward_context[layer_name] + kv_caches: list[list[torch.Tensor]] = attention.kv_cache[0] + copy_specs = [conv_copy, full_copy] + for state, copy_spec in zip(kv_caches, copy_specs): + src_block_id = block_ids[src_block_idx + copy_spec.block_idx_offset_func(accept_token_bias)] + data_offset = copy_spec.data_offset_func(state[0], accept_token_bias) + num_elements = copy_spec.num_elements_func(state[0], accept_token_bias) + copy_meta.append((state[src_block_id], state[dest_block_id], data_offset, num_elements)) + + for src_state, dest_state, data_offset, num_elements in copy_meta: + src_state_data = src_state.view(-1)[data_offset:data_offset + num_elements] + dest_state_data = dest_state.view(-1)[:num_elements] + dest_state_data.copy_(src_state_data) + def preprocess_mamba( scheduler_output: SchedulerOutput, @@ -179,7 +213,7 @@ def preprocess_mamba( curr_state_idx = num_blocks - 1 - num_speculative_blocks mamba_state_idx[req_id] = curr_state_idx if prev_state_idx != -1 and prev_state_idx != curr_state_idx: - mamba_copy_block_for_qwen_next( + mamba_copy_block_for_qwen_next_v2( kv_cache_config, mamba_group_ids, prev_state_idx, @@ -227,7 +261,7 @@ def postprocess_mamba( accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state src_block_idx = mamba_state_idx[req_id] dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1 - mamba_copy_block_for_qwen_next( + mamba_copy_block_for_qwen_next_v2( kv_cache_config, mamba_group_ids, src_block_idx, From 7448c6ade1c60182a10520a45344860c1d0b56cd Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 16 Dec 2025 18:08:24 +0000 Subject: [PATCH 069/130] mamba copy v3 Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/mamba_utils.py | 87 +++++++++++++++++++++++++++++++++-- 1 file changed, 83 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/mamba_utils.py 
b/vllm/v1/worker/mamba_utils.py index 68699b78650f..0e648cd769ff 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass import itertools from typing import Any, Callable import torch +import triton +import triton.language as tl from vllm.config import CacheConfig from vllm.v1.core.sched.output import SchedulerOutput @@ -12,6 +15,43 @@ from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch +@triton.jit +def batch_memcpy_kernel( + src_ptrs, + dst_ptrs, + sizes, + BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(0) + + src_ptr = tl.load(src_ptrs + pid) + dst_ptr = tl.load(dst_ptrs + pid) + size = tl.load(sizes + pid) + + offsets = tl.arange(0, BLOCK_SIZE) + for i in range(0, size, BLOCK_SIZE): + mask = (i + offsets) < size + + curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8)) + curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8)) + + data = tl.load(curr_src_ptr, mask=mask) + tl.store(curr_dst_ptr, data, mask=mask) + +def batch_memcpy(src_ptrs, dst_ptrs, sizes): + batch = src_ptrs.shape[0] + assert dst_ptrs.shape[0] == batch + assert sizes.shape[0] == batch + + grid = (batch,) + BLOCK_SIZE = 1024 + batch_memcpy_kernel[grid]( + src_ptrs, + dst_ptrs, + sizes, + BLOCK_SIZE=BLOCK_SIZE + ) + def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]: mamba_group_ids: list[int] = [] mamba_specs: list[MambaSpec] = [] @@ -56,7 +96,6 @@ def mamba_copy_block_for_qwen_next( dest_gdn_state = gdn_state[dest_block_id] dest_gdn_state.copy_(src_gdn_state) -from dataclasses import dataclass @dataclass class CopySpec: @@ -134,7 +173,6 @@ def mamba_copy_block_for_qwen_next_v1( # dest_gdn_state.copy_(src_gdn_state) assert torch.equal(dest_gdn_state_ref, src_gdn_state_ref) - def mamba_copy_block_for_qwen_next_v2( kv_cache_config: 
KVCacheConfig, mamba_group_ids: list[int], @@ -167,6 +205,47 @@ def mamba_copy_block_for_qwen_next_v2( dest_state_data = dest_state.view(-1)[:num_elements] dest_state_data.copy_(src_state_data) +def mamba_copy_block_for_qwen_next_v3( + kv_cache_config: KVCacheConfig, + mamba_group_ids: list[int], + src_block_idx: int, + dest_block_idx: int, + accept_token_bias: int, + req_state: CachedRequestState, + forward_context: dict[str, Any], +): + if src_block_idx == dest_block_idx and accept_token_bias == 0: + return + # print('>>> [BEBUG] mamba_copy_block_for_qwen_next_v3', flush=True) + + src_state_list = [] + dest_state_list = [] + # data_offset_list = [] + num_elements_list = [] + for mamba_group_id in mamba_group_ids: + block_ids = req_state.block_ids[mamba_group_id] + dest_block_id = block_ids[dest_block_idx] + layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names + for layer_name in layer_names: + attention = forward_context[layer_name] + kv_caches: list[list[torch.Tensor]] = attention.kv_cache[0] + copy_specs = [conv_copy, full_copy] + for state, copy_spec in zip(kv_caches, copy_specs): + src_block_id = block_ids[src_block_idx + copy_spec.block_idx_offset_func(accept_token_bias)] + data_offset = copy_spec.data_offset_func(state[0], accept_token_bias) + num_elements = copy_spec.num_elements_func(state[0], accept_token_bias) + src_state_list.append(state[src_block_id].data_ptr() + data_offset * state.element_size()) + dest_state_list.append(state[dest_block_id].data_ptr()) + # data_offset_list.append(data_offset) + num_elements_list.append(num_elements * state.element_size()) + + src_state_ptrs = torch.tensor(src_state_list, device='cuda', dtype=torch.int64) + dst_state_ptrs = torch.tensor(dest_state_list, device='cuda', dtype=torch.int64) + # data_offsets = torch.tensor(data_offset_list, device='cuda', dtype=torch.int32) + num_elements = torch.tensor(num_elements_list, device='cuda', dtype=torch.int32) + + batch_memcpy(src_state_ptrs, 
dst_state_ptrs, num_elements) + def preprocess_mamba( scheduler_output: SchedulerOutput, @@ -213,7 +292,7 @@ def preprocess_mamba( curr_state_idx = num_blocks - 1 - num_speculative_blocks mamba_state_idx[req_id] = curr_state_idx if prev_state_idx != -1 and prev_state_idx != curr_state_idx: - mamba_copy_block_for_qwen_next_v2( + mamba_copy_block_for_qwen_next_v3( kv_cache_config, mamba_group_ids, prev_state_idx, @@ -261,7 +340,7 @@ def postprocess_mamba( accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state src_block_idx = mamba_state_idx[req_id] dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1 - mamba_copy_block_for_qwen_next_v2( + mamba_copy_block_for_qwen_next_v3( kv_cache_config, mamba_group_ids, src_block_idx, From 6ec0e4811f0ea6ac7fb0491ec41f426d2b1d7898 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 17 Dec 2025 15:10:37 +0000 Subject: [PATCH 070/130] clean up code Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/mamba_utils.py | 131 +--------------------------------- 1 file changed, 3 insertions(+), 128 deletions(-) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 0e648cd769ff..c5bbf633745c 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -65,38 +65,6 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp return mamba_group_ids, mamba_specs[0] -def mamba_copy_block_for_qwen_next( - kv_cache_config: KVCacheConfig, - mamba_group_ids: list[int], - src_block_idx: int, - dest_block_idx: int, - accept_token_bias: int, - req_state: CachedRequestState, - forward_context: dict[str, Any], -): - # TODO: general impl for all models - if src_block_idx == dest_block_idx and accept_token_bias == 0: - return - for mamba_group_id in mamba_group_ids: - block_ids = req_state.block_ids[mamba_group_id] - dest_block_id = block_ids[dest_block_idx] - layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names - for 
layer_name in layer_names: - attention = forward_context[layer_name] - kv_caches: list[list[torch.Tensor]] = attention.kv_cache[0] - conv_state, gdn_state = kv_caches - # conv state - conv_state_block_id = block_ids[src_block_idx] - src_conv_state = conv_state[conv_state_block_id][accept_token_bias:] - dest_conv_state = conv_state[dest_block_id] - dest_conv_state[: len(src_conv_state)].copy_(src_conv_state.clone()) - # gdn state - gdn_state_block_id = block_ids[src_block_idx + accept_token_bias] - src_gdn_state = gdn_state[gdn_state_block_id] - dest_gdn_state = gdn_state[dest_block_id] - dest_gdn_state.copy_(src_gdn_state) - - @dataclass class CopySpec: block_idx_offset_func: Callable[[int], int] @@ -116,96 +84,7 @@ class CopySpec: num_elements_func=lambda state, bias: state.numel(), ) -def mamba_copy_block_for_qwen_next_v1( - kv_cache_config: KVCacheConfig, - mamba_group_ids: list[int], - src_block_idx: int, - dest_block_idx: int, - accept_token_bias: int, - req_state: CachedRequestState, - forward_context: dict[str, Any], -): - # TODO: general impl for all models - if src_block_idx == dest_block_idx and accept_token_bias == 0: - return - for mamba_group_id in mamba_group_ids: - block_ids = req_state.block_ids[mamba_group_id] - dest_block_id = block_ids[dest_block_idx] - layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names - for layer_name in layer_names: - attention = forward_context[layer_name] - kv_caches: list[list[torch.Tensor]] = attention.kv_cache[0] - conv_state, gdn_state = kv_caches - - # conv state - conv_state_block_id_ref = block_ids[src_block_idx] - conv_state_block_id = block_ids[src_block_idx + conv_copy.block_idx_offset_func(accept_token_bias)] - assert conv_state_block_id_ref == conv_state_block_id, f"{conv_state_block_id_ref} != {conv_state_block_id}" - - conv_state_block = conv_state[conv_state_block_id] - data_offset = conv_copy.data_offset_func(conv_state_block, accept_token_bias) - num_elements = 
conv_copy.num_elements_func(conv_state_block, accept_token_bias) - src_conv_state_blk = conv_state_block.view(-1)[data_offset:data_offset + num_elements] - dest_conv_state_blk = conv_state[dest_block_id].view(-1)[:num_elements] - dest_conv_state_blk.copy_(src_conv_state_blk) - - src_conv_state = conv_state[conv_state_block_id][accept_token_bias:] - dest_conv_state = conv_state[dest_block_id] - # dest_conv_state[: len(src_conv_state)].copy_(src_conv_state.clone()) - dest_conv_state_ref = dest_conv_state[: len(src_conv_state)] - src_conv_state_ref = conv_state[conv_state_block_id][accept_token_bias:] - assert torch.equal(dest_conv_state_ref, src_conv_state_ref) - - # gdn state - gdn_state_block_id_ref = block_ids[src_block_idx + accept_token_bias] - gdn_state_block_id = block_ids[src_block_idx + full_copy.block_idx_offset_func(accept_token_bias)] - assert gdn_state_block_id_ref == gdn_state_block_id, f"{gdn_state_block_id_ref} != {gdn_state_block_id}" - - gdn_state_block = gdn_state[gdn_state_block_id] - data_offset = full_copy.data_offset_func(gdn_state_block, accept_token_bias) - num_elements = full_copy.num_elements_func(gdn_state_block, accept_token_bias) - src_gdn_state_blk = gdn_state_block.view(-1)[data_offset:data_offset + num_elements] - dest_gdn_state_blk = gdn_state[dest_block_id].view(-1)[:num_elements] - dest_gdn_state_blk.copy_(src_gdn_state_blk) - - src_gdn_state_ref = gdn_state[gdn_state_block_id] - dest_gdn_state_ref = gdn_state[dest_block_id] - # dest_gdn_state.copy_(src_gdn_state) - assert torch.equal(dest_gdn_state_ref, src_gdn_state_ref) - -def mamba_copy_block_for_qwen_next_v2( - kv_cache_config: KVCacheConfig, - mamba_group_ids: list[int], - src_block_idx: int, - dest_block_idx: int, - accept_token_bias: int, - req_state: CachedRequestState, - forward_context: dict[str, Any], -): - if src_block_idx == dest_block_idx and accept_token_bias == 0: - return - print('>>> [BEBUG] mamba_copy_block_for_qwen_next_v2', flush=True) - copy_meta = [] - for 
mamba_group_id in mamba_group_ids: - block_ids = req_state.block_ids[mamba_group_id] - dest_block_id = block_ids[dest_block_idx] - layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names - for layer_name in layer_names: - attention = forward_context[layer_name] - kv_caches: list[list[torch.Tensor]] = attention.kv_cache[0] - copy_specs = [conv_copy, full_copy] - for state, copy_spec in zip(kv_caches, copy_specs): - src_block_id = block_ids[src_block_idx + copy_spec.block_idx_offset_func(accept_token_bias)] - data_offset = copy_spec.data_offset_func(state[0], accept_token_bias) - num_elements = copy_spec.num_elements_func(state[0], accept_token_bias) - copy_meta.append((state[src_block_id], state[dest_block_id], data_offset, num_elements)) - - for src_state, dest_state, data_offset, num_elements in copy_meta: - src_state_data = src_state.view(-1)[data_offset:data_offset + num_elements] - dest_state_data = dest_state.view(-1)[:num_elements] - dest_state_data.copy_(src_state_data) - -def mamba_copy_block_for_qwen_next_v3( +def mamba_copy_block_for_qwen_next( kv_cache_config: KVCacheConfig, mamba_group_ids: list[int], src_block_idx: int, @@ -216,11 +95,9 @@ def mamba_copy_block_for_qwen_next_v3( ): if src_block_idx == dest_block_idx and accept_token_bias == 0: return - # print('>>> [BEBUG] mamba_copy_block_for_qwen_next_v3', flush=True) src_state_list = [] dest_state_list = [] - # data_offset_list = [] num_elements_list = [] for mamba_group_id in mamba_group_ids: block_ids = req_state.block_ids[mamba_group_id] @@ -236,12 +113,10 @@ def mamba_copy_block_for_qwen_next_v3( num_elements = copy_spec.num_elements_func(state[0], accept_token_bias) src_state_list.append(state[src_block_id].data_ptr() + data_offset * state.element_size()) dest_state_list.append(state[dest_block_id].data_ptr()) - # data_offset_list.append(data_offset) num_elements_list.append(num_elements * state.element_size()) src_state_ptrs = torch.tensor(src_state_list, device='cuda', 
dtype=torch.int64) dst_state_ptrs = torch.tensor(dest_state_list, device='cuda', dtype=torch.int64) - # data_offsets = torch.tensor(data_offset_list, device='cuda', dtype=torch.int32) num_elements = torch.tensor(num_elements_list, device='cuda', dtype=torch.int32) batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements) @@ -292,7 +167,7 @@ def preprocess_mamba( curr_state_idx = num_blocks - 1 - num_speculative_blocks mamba_state_idx[req_id] = curr_state_idx if prev_state_idx != -1 and prev_state_idx != curr_state_idx: - mamba_copy_block_for_qwen_next_v3( + mamba_copy_block_for_qwen_next( kv_cache_config, mamba_group_ids, prev_state_idx, @@ -340,7 +215,7 @@ def postprocess_mamba( accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state src_block_idx = mamba_state_idx[req_id] dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1 - mamba_copy_block_for_qwen_next_v3( + mamba_copy_block_for_qwen_next( kv_cache_config, mamba_group_ids, src_block_idx, From 73cca3c6442019607ae07716e61010865b01df07 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Thu, 18 Dec 2025 17:29:56 +0000 Subject: [PATCH 071/130] mamba copy v4 Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/mamba_utils.py | 68 +++++++++++++++++++++++++++-------- 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index c5bbf633745c..a776c15731ef 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod from dataclasses import dataclass import itertools from typing import Any, Callable @@ -66,23 +67,60 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp @dataclass -class CopySpec: - block_idx_offset_func: Callable[[int], int] - data_offset_func: Callable[[torch.Tensor, int], int] - 
num_elements_func: Callable[[torch.Tensor, int], int] +class MambaCopySpec(ABC): + @staticmethod + @abstractmethod + def block_idx_offset_func(accept_token_bias: int) -> int: + """ + Return the offset of the source block idx which needs to be copied. + """ + pass -conv_copy = CopySpec( - block_idx_offset_func=lambda bias: 0, - data_offset_func=lambda state, bias: bias * state.stride(0), - num_elements_func=lambda state, bias: state.numel() - bias * state.stride(0), -) + @staticmethod + @abstractmethod + def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: + """ + Return the offset of the data in the source block which needs to be copied. + """ + pass + + @staticmethod + @abstractmethod + def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: + """ + Return the number of elements to be copied. + """ + pass + +class MambaFullCopySpec(MambaCopySpec): + + @staticmethod + def block_idx_offset_func(accept_token_bias: int) -> int: + return accept_token_bias + + @staticmethod + def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: + return 0 + + @staticmethod + def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: + return state.numel() + +class MambaConvCopySpec(MambaCopySpec): + + @staticmethod + def block_idx_offset_func(accept_token_bias: int) -> int: + return accept_token_bias + + @staticmethod + def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: + return accept_token_bias * state.stride(0) + + @staticmethod + def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: + return state.numel() - accept_token_bias * state.stride(0) -full_copy = CopySpec( - block_idx_offset_func=lambda bias: bias, - data_offset_func=lambda state, bias: 0, - num_elements_func=lambda state, bias: state.numel(), -) def mamba_copy_block_for_qwen_next( kv_cache_config: KVCacheConfig, @@ -106,7 +144,7 @@ def mamba_copy_block_for_qwen_next( for layer_name in layer_names: 
attention = forward_context[layer_name] kv_caches: list[list[torch.Tensor]] = attention.kv_cache[0] - copy_specs = [conv_copy, full_copy] + copy_specs: list[type[MambaCopySpec]] = [MambaConvCopySpec, MambaFullCopySpec] for state, copy_spec in zip(kv_caches, copy_specs): src_block_id = block_ids[src_block_idx + copy_spec.block_idx_offset_func(accept_token_bias)] data_offset = copy_spec.data_offset_func(state[0], accept_token_bias) From cfb85d16bb6fda0c89c873d94c121545370f813b Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Thu, 18 Dec 2025 18:15:32 +0000 Subject: [PATCH 072/130] support general mamba copy Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/mamba/abstract.py | 6 ++ .../layers/mamba/mamba_utils.py | 67 +++++++++++++++++ vllm/model_executor/models/qwen3_next.py | 5 ++ vllm/v1/kv_cache_interface.py | 2 + vllm/v1/worker/mamba_utils.py | 74 +++---------------- 5 files changed, 91 insertions(+), 63 deletions(-) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 74f4383e9c23..dfebcd80476f 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -9,7 +9,9 @@ from vllm.attention.selector import get_mamba_attn_backend from vllm.config import VllmConfig from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.mamba.mamba_utils import MambaFullCopySpec from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec +from vllm.v1.worker.mamba_utils import MambaCopySpec class MambaBase(AttentionLayerBase): @@ -40,6 +42,9 @@ def mamba_type(self) -> str: def get_state_dtype(self) -> tuple[torch.dtype, ...]: pass + def get_copy_spec(self) -> tuple[type[MambaCopySpec], ...]: + return (MambaFullCopySpec, ) * len(self.get_state_dtype()) + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: if ( vllm_config.speculative_config is not None @@ -61,6 +66,7 @@ def 
get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: if vllm_config.speculative_config else 0 ), + copy_specs=self.get_copy_spec(), ) def get_attn_backend(self) -> type[AttentionBackend]: diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 831dab2fbb01..8ab615da9c6d 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod import torch from vllm.config.cache import MambaDType @@ -223,3 +224,69 @@ def kda_state_shape( conv_state_k_shape, recurrent_state_shape, ) + + +class MambaCopySpec(ABC): + @staticmethod + @abstractmethod + def block_idx_offset_func(accept_token_bias: int) -> int: + """ + Return the offset of the source block idx which needs to be copied. + """ + pass + + @staticmethod + @abstractmethod + def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: + """ + Return the offset of the data in the source block which needs to be copied. + """ + pass + + @staticmethod + @abstractmethod + def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: + """ + Return the number of elements to be copied. 
+ """ + pass + +class MambaFullCopySpec(MambaCopySpec): + @staticmethod + def block_idx_offset_func(accept_token_bias: int) -> int: + return accept_token_bias + + @staticmethod + def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: + return 0 + + @staticmethod + def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: + return state.numel() + +class MambaConvCopySpec(MambaCopySpec): + @staticmethod + def block_idx_offset_func(accept_token_bias: int) -> int: + return accept_token_bias + + @staticmethod + def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: + return accept_token_bias * state.stride(0) + + @staticmethod + def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: + return state.numel() - accept_token_bias * state.stride(0) + + +class MambaCopySpecCalculator: + @classmethod + def linear_attention_copy_spec(cls): + return (MambaFullCopySpec,) + + @classmethod + def mamba1_state_copy_spec(cls): + return MambaConvCopySpec, MambaFullCopySpec + + @classmethod + def gated_delta_net_copy_spec(cls): + return MambaConvCopySpec, MambaFullCopySpec \ No newline at end of file diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index daf917a583e0..cdd4866a9aa5 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -50,6 +50,8 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaCopySpec, + MambaCopySpecCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -234,6 +236,9 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: self.conv_kernel_size, self.num_spec, ) + + def get_copy_spec(self) -> tuple[type[MambaCopySpec], ...]: + return MambaCopySpecCalculator.gated_delta_net_copy_spec() def __init__( 
self, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 867354a81ba3..986ef13550a4 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -10,6 +10,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.mamba.mamba_utils import MambaCopySpec from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import get_dtype_size @@ -247,6 +248,7 @@ class MambaSpec(KVCacheSpec): page_size_padded: int | None = None mamba_type: str = "mamba2" num_speculative_blocks: int = 0 + copy_specs: tuple[type[MambaCopySpec], ...] | None = None @property def page_size_bytes(self) -> int: diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index a776c15731ef..8c005050362f 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -1,15 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod -from dataclasses import dataclass import itertools -from typing import Any, Callable +from typing import Any import torch import triton import triton.language as tl from vllm.config import CacheConfig +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaCopySpec, +) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec from vllm.v1.worker.gpu_input_batch import CachedRequestState @@ -66,64 +67,9 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp return mamba_group_ids, mamba_specs[0] -@dataclass -class MambaCopySpec(ABC): - - @staticmethod - @abstractmethod - def block_idx_offset_func(accept_token_bias: int) -> int: - """ - Return the offset of the source block idx which needs to be copied. 
- """ - pass - - @staticmethod - @abstractmethod - def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: - """ - Return the offset of the data in the source block which needs to be copied. - """ - pass - - @staticmethod - @abstractmethod - def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: - """ - Return the number of elements to be copied. - """ - pass - -class MambaFullCopySpec(MambaCopySpec): - - @staticmethod - def block_idx_offset_func(accept_token_bias: int) -> int: - return accept_token_bias - - @staticmethod - def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: - return 0 - - @staticmethod - def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: - return state.numel() - -class MambaConvCopySpec(MambaCopySpec): - - @staticmethod - def block_idx_offset_func(accept_token_bias: int) -> int: - return accept_token_bias - - @staticmethod - def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: - return accept_token_bias * state.stride(0) - - @staticmethod - def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: - return state.numel() - accept_token_bias * state.stride(0) - - -def mamba_copy_block_for_qwen_next( +def mamba_copy_block( kv_cache_config: KVCacheConfig, + mamba_spec: MambaSpec, mamba_group_ids: list[int], src_block_idx: int, dest_block_idx: int, @@ -137,6 +83,7 @@ def mamba_copy_block_for_qwen_next( src_state_list = [] dest_state_list = [] num_elements_list = [] + copy_specs: tuple[type[MambaCopySpec], ...] 
= mamba_spec.copy_specs for mamba_group_id in mamba_group_ids: block_ids = req_state.block_ids[mamba_group_id] dest_block_id = block_ids[dest_block_idx] @@ -144,7 +91,6 @@ def mamba_copy_block_for_qwen_next( for layer_name in layer_names: attention = forward_context[layer_name] kv_caches: list[list[torch.Tensor]] = attention.kv_cache[0] - copy_specs: list[type[MambaCopySpec]] = [MambaConvCopySpec, MambaFullCopySpec] for state, copy_spec in zip(kv_caches, copy_specs): src_block_id = block_ids[src_block_idx + copy_spec.block_idx_offset_func(accept_token_bias)] data_offset = copy_spec.data_offset_func(state[0], accept_token_bias) @@ -205,8 +151,9 @@ def preprocess_mamba( curr_state_idx = num_blocks - 1 - num_speculative_blocks mamba_state_idx[req_id] = curr_state_idx if prev_state_idx != -1 and prev_state_idx != curr_state_idx: - mamba_copy_block_for_qwen_next( + mamba_copy_block( kv_cache_config, + mamba_spec, mamba_group_ids, prev_state_idx, curr_state_idx, @@ -253,8 +200,9 @@ def postprocess_mamba( accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state src_block_idx = mamba_state_idx[req_id] dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1 - mamba_copy_block_for_qwen_next( + mamba_copy_block( kv_cache_config, + mamba_spec, mamba_group_ids, src_block_idx, dest_block_idx, From c52024f6c770568b3c6022b007ea096496faff2e Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Fri, 19 Dec 2025 19:44:26 +0000 Subject: [PATCH 073/130] fix a bug of conv copy spec Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/mamba/mamba_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 8ab615da9c6d..dddc99a2c5cb 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -267,7 +267,7 @@ def num_elements_func(state: torch.Tensor, accept_token_bias: 
int) -> int: class MambaConvCopySpec(MambaCopySpec): @staticmethod def block_idx_offset_func(accept_token_bias: int) -> int: - return accept_token_bias + return 0 @staticmethod def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: From 4f401efe6d17a81e675b53ecd363bdcbd05816d6 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Fri, 19 Dec 2025 19:52:58 +0000 Subject: [PATCH 074/130] support other mamba models Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/kda.py | 9 +++++++- .../layers/mamba/linear_attn.py | 4 ++++ .../layers/mamba/mamba_mixer.py | 4 ++++ .../layers/mamba/mamba_mixer2.py | 4 ++++ .../layers/mamba/mamba_utils.py | 22 +++++++++++++++---- .../model_executor/layers/mamba/short_conv.py | 4 ++++ vllm/model_executor/models/plamo2.py | 4 ++++ vllm/model_executor/models/qwen3_next.py | 3 +-- 8 files changed, 47 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py index 27cc3884517f..53f017fbe38d 100644 --- a/vllm/model_executor/layers/kda.py +++ b/vllm/model_executor/layers/kda.py @@ -31,7 +31,11 @@ RowParallelLinear, ) from .mamba.abstract import MambaBase -from .mamba.mamba_utils import MambaStateDtypeCalculator, MambaStateShapeCalculator +from .mamba.mamba_utils import ( + MambaCopySpecCalculator, + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update from .quantization.base_config import QuantizationConfig @@ -100,6 +104,9 @@ def get_state_shape( self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size ) + def get_copy_spec(self): + return MambaCopySpecCalculator.kda_state_copy_spec() + def __init__( self, layer_idx: int, diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 278713408c28..9aafd50f4090 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ 
b/vllm/model_executor/layers/mamba/linear_attn.py @@ -24,6 +24,7 @@ from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaCopySpecCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -126,6 +127,9 @@ def get_state_shape(self) -> tuple[tuple[int, int, int], ...]: return MambaStateShapeCalculator.linear_attention_state_shape( num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim ) + + def get_copy_spec(self): + return MambaCopySpecCalculator.linear_attention_copy_spec() def __init__( self, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 9d509b82d9bb..18c044515295 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -22,6 +22,7 @@ ) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaCopySpecCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -444,6 +445,9 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: conv_kernel=self.conv_kernel_size, ) + def get_copy_spec(self): + return MambaCopySpecCalculator.mamba1_state_copy_spec() + @property def mamba_type(self) -> str: return "mamba1" diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index ef3936b839b2..c33bbba09add 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -23,6 +23,7 @@ ) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaCopySpecCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -898,6 +899,9 @@ def get_state_shape(self) -> tuple[tuple[int, 
...], tuple[int, ...]]: state_size=self.ssm_state_size, conv_kernel=self.conv_kernel_size, ) + + def get_copy_spec(self): + return MambaCopySpecCalculator.mamba2_state_copy_spec() @property def mamba_type(self) -> str: diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index dddc99a2c5cb..4419a75be4e8 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -251,7 +251,7 @@ def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: """ pass -class MambaFullCopySpec(MambaCopySpec): +class MambaTemporalCopySpec(MambaCopySpec): @staticmethod def block_idx_offset_func(accept_token_bias: int) -> int: return accept_token_bias @@ -271,22 +271,36 @@ def block_idx_offset_func(accept_token_bias: int) -> int: @staticmethod def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: + # TODO: check contiguous! return accept_token_bias * state.stride(0) @staticmethod def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: return state.numel() - accept_token_bias * state.stride(0) +MambaFullCopySpec = MambaTemporalCopySpec class MambaCopySpecCalculator: @classmethod def linear_attention_copy_spec(cls): - return (MambaFullCopySpec,) + return (MambaTemporalCopySpec,) @classmethod def mamba1_state_copy_spec(cls): - return MambaConvCopySpec, MambaFullCopySpec + return MambaConvCopySpec, MambaTemporalCopySpec + + @classmethod + def mamba2_state_copy_spec(cls): + return MambaConvCopySpec, MambaTemporalCopySpec + + @classmethod + def short_conv_state_copy_spec(cls): + return (MambaConvCopySpec,) @classmethod def gated_delta_net_copy_spec(cls): - return MambaConvCopySpec, MambaFullCopySpec \ No newline at end of file + return MambaConvCopySpec, MambaTemporalCopySpec + + @classmethod + def kda_state_copy_spec(cls): + return MambaConvCopySpec, MambaConvCopySpec, MambaConvCopySpec, MambaTemporalCopySpec \ No newline at end 
of file diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index 0bbad17d7ebc..9358362356d6 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -16,6 +16,7 @@ ) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaCopySpecCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -224,6 +225,9 @@ def get_state_shape(self) -> tuple[tuple[int, ...]]: conv_kernel=self.L_cache, ) + def get_copy_spec(self): + return MambaCopySpecCalculator.short_conv_state_copy_spec() + @property def mamba_type(self) -> str: return "short_conv" diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 6765ee0c5779..88d0b6acd102 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -28,6 +28,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaCopySpecCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -457,6 +458,9 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: state_size=self.ssm_state_size, conv_kernel=self.conv_kernel_size, ) + + def get_copy_spec(self): + return MambaCopySpecCalculator.mamba2_state_copy_spec() @property def mamba_type(self) -> str: diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index cdd4866a9aa5..eec5c6daa700 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -50,7 +50,6 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader from vllm.model_executor.layers.mamba.mamba_utils import ( - 
MambaCopySpec, MambaCopySpecCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, @@ -237,7 +236,7 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: self.num_spec, ) - def get_copy_spec(self) -> tuple[type[MambaCopySpec], ...]: + def get_copy_spec(self): return MambaCopySpecCalculator.gated_delta_net_copy_spec() def __init__( From 55f98e14c71808c6a07cc7de6ac67d76e4b734ab Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Fri, 19 Dec 2025 19:54:36 +0000 Subject: [PATCH 075/130] update mamba_cache_mode config Signed-off-by: huanghaoyan.hhy --- vllm/config/cache.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 71318269ea00..4b91c595925c 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -243,8 +243,13 @@ def verify_with_parallel_config( def __post_init__(self) -> None: if self.enable_prefix_caching: if self.mamba_cache_mode == "none": - self.mamba_cache_mode = "last" + self.mamba_cache_mode = "align" + logger.warning( + "mamba_cache_mode set to 'align' defaultly when prefix caching is enabled" + ) else: - assert self.mamba_cache_mode == "none", ( - "mamba_cache_mode must be 'none' when prefix caching is disabled" - ) + if self.mamba_cache_mode != "none": + self.mamba_cache_mode = "none" + logger.warning( + "mamba_cache_mode set to 'none' when prefix caching is disabled" + ) From 07b6f0c19c2570bbb13080129594f132f007f39a Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 20 Dec 2025 17:48:08 +0000 Subject: [PATCH 076/130] format code Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/kda.py | 4 +-- vllm/model_executor/layers/mamba/abstract.py | 2 +- .../layers/mamba/linear_attn.py | 2 +- .../layers/mamba/mamba_mixer2.py | 2 +- .../layers/mamba/mamba_utils.py | 26 ++++++++++----- vllm/model_executor/models/config.py | 6 ++-- vllm/model_executor/models/plamo2.py | 2 +- vllm/model_executor/models/qwen3_next.py | 2 +- 
vllm/v1/attention/backends/utils.py | 4 +-- vllm/v1/worker/mamba_utils.py | 32 ++++++++----------- 10 files changed, 44 insertions(+), 38 deletions(-) diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py index 53f017fbe38d..6e2c16e0e223 100644 --- a/vllm/model_executor/layers/kda.py +++ b/vllm/model_executor/layers/kda.py @@ -32,8 +32,8 @@ ) from .mamba.abstract import MambaBase from .mamba.mamba_utils import ( - MambaCopySpecCalculator, - MambaStateDtypeCalculator, + MambaCopySpecCalculator, + MambaStateDtypeCalculator, MambaStateShapeCalculator, ) from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index dfebcd80476f..6d22f768b8a9 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -43,7 +43,7 @@ def get_state_dtype(self) -> tuple[torch.dtype, ...]: pass def get_copy_spec(self) -> tuple[type[MambaCopySpec], ...]: - return (MambaFullCopySpec, ) * len(self.get_state_dtype()) + return (MambaFullCopySpec,) * len(self.get_state_dtype()) def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: if ( diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 9aafd50f4090..9a90e3195733 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -127,7 +127,7 @@ def get_state_shape(self) -> tuple[tuple[int, int, int], ...]: return MambaStateShapeCalculator.linear_attention_state_shape( num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim ) - + def get_copy_spec(self): return MambaCopySpecCalculator.linear_attention_copy_spec() diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index c33bbba09add..c668477bcca8 100644 --- 
a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -899,7 +899,7 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: state_size=self.ssm_state_size, conv_kernel=self.conv_kernel_size, ) - + def get_copy_spec(self): return MambaCopySpecCalculator.mamba2_state_copy_spec() diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 4419a75be4e8..e14c6d7c1f4e 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod + import torch from vllm.config.cache import MambaDType @@ -236,7 +237,7 @@ def block_idx_offset_func(accept_token_bias: int) -> int: pass @staticmethod - @abstractmethod + @abstractmethod def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: """ Return the offset of the data in the source block which needs to be copied. @@ -244,13 +245,14 @@ def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: pass @staticmethod - @abstractmethod + @abstractmethod def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: """ Return the number of elements to be copied. 
""" pass + class MambaTemporalCopySpec(MambaCopySpec): @staticmethod def block_idx_offset_func(accept_token_bias: int) -> int: @@ -259,32 +261,35 @@ def block_idx_offset_func(accept_token_bias: int) -> int: @staticmethod def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: return 0 - + @staticmethod def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: return state.numel() + class MambaConvCopySpec(MambaCopySpec): @staticmethod def block_idx_offset_func(accept_token_bias: int) -> int: return 0 - + @staticmethod def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: # TODO: check contiguous! return accept_token_bias * state.stride(0) - + @staticmethod def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: return state.numel() - accept_token_bias * state.stride(0) + MambaFullCopySpec = MambaTemporalCopySpec + class MambaCopySpecCalculator: @classmethod def linear_attention_copy_spec(cls): return (MambaTemporalCopySpec,) - + @classmethod def mamba1_state_copy_spec(cls): return MambaConvCopySpec, MambaTemporalCopySpec @@ -300,7 +305,12 @@ def short_conv_state_copy_spec(cls): @classmethod def gated_delta_net_copy_spec(cls): return MambaConvCopySpec, MambaTemporalCopySpec - + @classmethod def kda_state_copy_spec(cls): - return MambaConvCopySpec, MambaConvCopySpec, MambaConvCopySpec, MambaTemporalCopySpec \ No newline at end of file + return ( + MambaConvCopySpec, + MambaConvCopySpec, + MambaConvCopySpec, + MambaTemporalCopySpec, + ) \ No newline at end of file diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index fc0e03629138..7e59f3399520 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -450,9 +450,9 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: attn_block_size, ) - # By default, mamba block size will be set to max_model_len. 
- # When enabling prefix caching and using align mamba cache - # mode, we align mamba block size to the block size as the + # By default, mamba block size will be set to max_model_len. + # When enabling prefix caching and using align mamba cache + # mode, we align mamba block size to the block size as the # basic granularity for prefix caching. if cache_config.mamba_cache_mode == "align": cache_config.mamba_block_size = cache_config.block_size diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 88d0b6acd102..07b45f29605b 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -458,7 +458,7 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: state_size=self.ssm_state_size, conv_kernel=self.conv_kernel_size, ) - + def get_copy_spec(self): return MambaCopySpecCalculator.mamba2_state_copy_spec() diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index eec5c6daa700..ae3eaf612f5d 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -235,7 +235,7 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: self.conv_kernel_size, self.num_spec, ) - + def get_copy_spec(self): return MambaCopySpecCalculator.gated_delta_net_copy_spec() diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 74a17aa20bcf..66a265159429 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1209,10 +1209,10 @@ def mamba_get_block_table_tensor( else: assert isinstance(kv_cache_spec, MambaSpec) block_table_tensor = common_attn_metadata.block_table_tensor - # NOTE: For 0-length requests in CUDA graph, use a start_index of 0 + # NOTE: For 0-length requests in CUDA graph, use a start_index of 0 # to handle the invalid block table. 
start_indices = torch.clamp( - (common_attn_metadata.seq_lens - 1) // kv_cache_spec.block_size, + (common_attn_metadata.seq_lens - 1) // kv_cache_spec.block_size, min=0, ) offsets = torch.arange( diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 8c005050362f..91a7a7d9154e 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -18,12 +18,7 @@ @triton.jit -def batch_memcpy_kernel( - src_ptrs, - dst_ptrs, - sizes, - BLOCK_SIZE: tl.constexpr -): +def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(0) src_ptr = tl.load(src_ptrs + pid) @@ -40,6 +35,7 @@ def batch_memcpy_kernel( data = tl.load(curr_src_ptr, mask=mask) tl.store(curr_dst_ptr, data, mask=mask) + def batch_memcpy(src_ptrs, dst_ptrs, sizes): batch = src_ptrs.shape[0] assert dst_ptrs.shape[0] == batch @@ -47,12 +43,8 @@ def batch_memcpy(src_ptrs, dst_ptrs, sizes): grid = (batch,) BLOCK_SIZE = 1024 - batch_memcpy_kernel[grid]( - src_ptrs, - dst_ptrs, - sizes, - BLOCK_SIZE=BLOCK_SIZE - ) + batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE) + def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]: mamba_group_ids: list[int] = [] @@ -79,7 +71,7 @@ def mamba_copy_block( ): if src_block_idx == dest_block_idx and accept_token_bias == 0: return - + src_state_list = [] dest_state_list = [] num_elements_list = [] @@ -92,16 +84,20 @@ def mamba_copy_block( attention = forward_context[layer_name] kv_caches: list[list[torch.Tensor]] = attention.kv_cache[0] for state, copy_spec in zip(kv_caches, copy_specs): - src_block_id = block_ids[src_block_idx + copy_spec.block_idx_offset_func(accept_token_bias)] + src_block_id = block_ids[ + src_block_idx + copy_spec.block_idx_offset_func(accept_token_bias) + ] data_offset = copy_spec.data_offset_func(state[0], accept_token_bias) num_elements = copy_spec.num_elements_func(state[0], accept_token_bias) - 
src_state_list.append(state[src_block_id].data_ptr() + data_offset * state.element_size()) + src_state_list.append( + state[src_block_id].data_ptr() + data_offset * state.element_size() + ) dest_state_list.append(state[dest_block_id].data_ptr()) num_elements_list.append(num_elements * state.element_size()) - src_state_ptrs = torch.tensor(src_state_list, device='cuda', dtype=torch.int64) - dst_state_ptrs = torch.tensor(dest_state_list, device='cuda', dtype=torch.int64) - num_elements = torch.tensor(num_elements_list, device='cuda', dtype=torch.int32) + src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64) + dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64) + num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32) batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements) From fec5b525c69a58d55c16aee49b07095717f51cbd Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 20 Dec 2025 18:19:50 +0000 Subject: [PATCH 077/130] format code Signed-off-by: huanghaoyan.hhy --- vllm/config/cache.py | 3 ++- vllm/model_executor/layers/mamba/mamba_utils.py | 3 +-- vllm/v1/kv_cache_interface.py | 2 +- vllm/v1/worker/mamba_utils.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 4b91c595925c..bb357ec614ac 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -245,7 +245,8 @@ def __post_init__(self) -> None: if self.mamba_cache_mode == "none": self.mamba_cache_mode = "align" logger.warning( - "mamba_cache_mode set to 'align' defaultly when prefix caching is enabled" + "mamba_cache_mode set to 'align' defaultly when prefix " + "caching is enabled" ) else: if self.mamba_cache_mode != "none": diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index e14c6d7c1f4e..4da0e7d84171 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ 
b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -274,7 +274,6 @@ def block_idx_offset_func(accept_token_bias: int) -> int: @staticmethod def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: - # TODO: check contiguous! return accept_token_bias * state.stride(0) @staticmethod @@ -313,4 +312,4 @@ def kda_state_copy_spec(cls): MambaConvCopySpec, MambaConvCopySpec, MambaTemporalCopySpec, - ) \ No newline at end of file + ) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 986ef13550a4..62f26b8bb9d0 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -248,7 +248,7 @@ class MambaSpec(KVCacheSpec): page_size_padded: int | None = None mamba_type: str = "mamba2" num_speculative_blocks: int = 0 - copy_specs: tuple[type[MambaCopySpec], ...] | None = None + copy_specs: tuple[type[MambaCopySpec], ...] = () @property def page_size_bytes(self) -> int: diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 91a7a7d9154e..21d39bde9674 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -82,7 +82,7 @@ def mamba_copy_block( layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names for layer_name in layer_names: attention = forward_context[layer_name] - kv_caches: list[list[torch.Tensor]] = attention.kv_cache[0] + kv_caches: list[torch.Tensor] = attention.kv_cache[0] for state, copy_spec in zip(kv_caches, copy_specs): src_block_id = block_ids[ src_block_idx + copy_spec.block_idx_offset_func(accept_token_bias) From 9a3a5567aa969cb0068fa2f37708aebbee33ec59 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 20 Dec 2025 18:30:40 +0000 Subject: [PATCH 078/130] update mamba copy func in the test Signed-off-by: huanghaoyan.hhy --- tests/v1/e2e/test_mamba_prefix_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 
6138aa17896d..00c48f7160f2 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -452,12 +452,12 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch): get_fake_process_mamba_fn( mamba_utils.preprocess_mamba, mamba_utils.postprocess_mamba, - mamba_utils.mamba_copy_block_for_qwen_next, + mamba_utils.mamba_copy_block, ) ) monkeypatch.setattr(mamba_utils, "preprocess_mamba", fake_preprocess_mamba_fn) monkeypatch.setattr(mamba_utils, "postprocess_mamba", fake_post_process_mamba_fn) - monkeypatch.setattr(mamba_utils, "mamba_copy_block_for_qwen_next", fake_copy_fn) + monkeypatch.setattr(mamba_utils, "mamba_copy_block", fake_copy_fn) def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): From 25ca2a914de970f7482d844abf6bab6186827d6e Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 20 Dec 2025 18:33:21 +0000 Subject: [PATCH 079/130] remove lpc env var Signed-off-by: huanghaoyan.hhy --- vllm/envs.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 954e1ced06df..cb75ba1a62de 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -243,7 +243,6 @@ VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False - VLLM_USE_LIGHTER_MAMBA_CACHE: bool = False def get_default_cache_root(): @@ -1562,9 +1561,6 @@ def get_vllm_port() -> int | None: "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) ), - "VLLM_USE_LIGHTER_MAMBA_CACHE": lambda: bool( - int(os.getenv("VLLM_USE_LIGHTER_MAMBA_CACHE", "0")) - ), } # --8<-- [end:env-vars-definition] @@ -1695,7 +1691,6 @@ def compile_factors() -> dict[str, object]: "LOCAL_RANK", "CUDA_VISIBLE_DEVICES", "NO_COLOR", - "VLLM_USE_LIGHTER_MAMBA_CACHE", } from vllm.config.utils import normalize_value From da3be73a344e0f414b9612495c590f103ccd5af2 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 20 Dec 2025 
18:35:48 +0000 Subject: [PATCH 080/130] format code Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/models/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 7e59f3399520..fb7e43141a49 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -311,7 +311,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: ) else: raise ValueError( - "unknown mamba cache mode: %s", cache_config.mamba_cache_mode + f"unknown mamba cache mode: {cache_config.mamba_cache_mode}" ) elif cache_config.mamba_block_size is None: cache_config.mamba_block_size = model_config.max_model_len From 6ef2bc81e742d3158b40784a61f71b662a401ca3 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 21 Dec 2025 17:00:09 +0000 Subject: [PATCH 081/130] remove unused _mamba_copy_block Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/gpu_model_runner.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ba259d6694a9..627c00406a91 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2912,24 +2912,6 @@ def _register_layerwise_nvtx_hooks(self) -> None: pyt_hooks.register_hooks(self.model, self.model.__class__.__name__) self.layerwise_nvtx_hooks_registered = True - def _mamba_copy_block( - self, - kv_cache_group_spec: KVCacheGroupSpec, - src_block_id: int, - dest_block_id: int, - ): - if src_block_id == dest_block_id: - return - forward_context = self.compilation_config.static_forward_context - for layer_name in kv_cache_group_spec.layer_names: - kv_caches: list[list[torch.Tensor]] = forward_context[layer_name].kv_cache - for kv_cache in kv_caches: - if isinstance(kv_cache, torch.Tensor): - kv_cache[dest_block_id].copy_(kv_cache[src_block_id]) - elif isinstance(kv_cache, list): - for kv_cache_part in 
kv_cache: - kv_cache_part[dest_block_id].copy_(kv_cache_part[src_block_id]) - @torch.inference_mode() def execute_model( self, From c2839b0ae8ad65f49f90e038d0add9c77dba288b Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 21 Dec 2025 17:29:03 +0000 Subject: [PATCH 082/130] batch mamba copy block Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/mamba_utils.py | 37 +++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 21d39bde9674..f490984f96d0 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -59,7 +59,10 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp return mamba_group_ids, mamba_specs[0] -def mamba_copy_block( +def collect_mamba_copy_meta( + src_state_list: list[int], + dest_state_list: list[int], + num_elements_list: list[int], kv_cache_config: KVCacheConfig, mamba_spec: MambaSpec, mamba_group_ids: list[int], @@ -72,9 +75,6 @@ def mamba_copy_block( if src_block_idx == dest_block_idx and accept_token_bias == 0: return - src_state_list = [] - dest_state_list = [] - num_elements_list = [] copy_specs: tuple[type[MambaCopySpec], ...] 
= mamba_spec.copy_specs for mamba_group_id in mamba_group_ids: block_ids = req_state.block_ids[mamba_group_id] @@ -95,6 +95,16 @@ def mamba_copy_block( dest_state_list.append(state[dest_block_id].data_ptr()) num_elements_list.append(num_elements * state.element_size()) + +def do_mamba_copy_block( + src_state_list: list[int], + dest_state_list: list[int], + num_elements_list: list[int], +): + if len(src_state_list) == 0: + return + assert len(src_state_list) == len(dest_state_list) + assert len(src_state_list) == len(num_elements_list) src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64) dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64) num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32) @@ -124,6 +134,10 @@ def preprocess_mamba( preempted_req_ids = scheduler_output.preempted_req_ids or set() for req_id in itertools.chain(finished_req_ids, preempted_req_ids): mamba_state_idx.pop(req_id, None) + + src_state_list: list[int] = [] + dest_state_list: list[int] = [] + num_elements_list: list[int] = [] for i, req_id in enumerate(input_batch.req_ids): req_state = requests[req_id] prev_state_idx = mamba_state_idx.get(req_id) @@ -147,7 +161,10 @@ def preprocess_mamba( curr_state_idx = num_blocks - 1 - num_speculative_blocks mamba_state_idx[req_id] = curr_state_idx if prev_state_idx != -1 and prev_state_idx != curr_state_idx: - mamba_copy_block( + collect_mamba_copy_meta( + src_state_list, + dest_state_list, + num_elements_list, kv_cache_config, mamba_spec, mamba_group_ids, @@ -158,6 +175,7 @@ def preprocess_mamba( forward_context, ) input_batch.num_accepted_tokens_cpu[i] = 1 + do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list) def postprocess_mamba( @@ -178,6 +196,9 @@ def postprocess_mamba( # NOTE: can be optimized as this function always returns the same result mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config) # TODO: vectorize this loop + 
src_state_list: list[int] = [] + dest_state_list: list[int] = [] + num_elements_list: list[int] = [] for i, req_id in enumerate(input_batch.req_ids): req_state = requests[req_id] num_computed_tokens = req_state.num_computed_tokens @@ -196,7 +217,10 @@ def postprocess_mamba( accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state src_block_idx = mamba_state_idx[req_id] dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1 - mamba_copy_block( + collect_mamba_copy_meta( + src_state_list, + dest_state_list, + num_elements_list, kv_cache_config, mamba_spec, mamba_group_ids, @@ -208,3 +232,4 @@ def postprocess_mamba( ) if src_block_idx == dest_block_idx: num_accepted_tokens_cpu[i] = 1 + do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list) From 331957918b1e09f0f3b7bf39377932fd065dae13 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 22 Dec 2025 16:41:10 +0000 Subject: [PATCH 083/130] update interface for mamba copy Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/kda.py | 4 +- vllm/model_executor/layers/mamba/abstract.py | 12 +- .../layers/mamba/linear_attn.py | 4 +- .../layers/mamba/mamba_mixer.py | 4 +- .../layers/mamba/mamba_mixer2.py | 4 +- .../layers/mamba/mamba_utils.py | 111 ++++++++---------- .../model_executor/layers/mamba/short_conv.py | 4 +- vllm/model_executor/models/plamo2.py | 4 +- vllm/model_executor/models/qwen3_next.py | 4 +- vllm/v1/kv_cache_interface.py | 4 +- vllm/v1/worker/mamba_utils.py | 20 ++-- 11 files changed, 80 insertions(+), 95 deletions(-) diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py index 6e2c16e0e223..8e50af091331 100644 --- a/vllm/model_executor/layers/kda.py +++ b/vllm/model_executor/layers/kda.py @@ -104,8 +104,8 @@ def get_state_shape( self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size ) - def get_copy_spec(self): - return MambaCopySpecCalculator.kda_state_copy_spec() + def 
get_copy_spec_func(self): + return MambaCopySpecCalculator.kda_state_copy_spec_func() def __init__( self, diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 6d22f768b8a9..54d229878b82 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -9,9 +9,11 @@ from vllm.attention.selector import get_mamba_attn_backend from vllm.config import VllmConfig from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.mamba.mamba_utils import MambaFullCopySpec +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaCopySpecFunc, + get_full_copy_spec, +) from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec -from vllm.v1.worker.mamba_utils import MambaCopySpec class MambaBase(AttentionLayerBase): @@ -42,8 +44,8 @@ def mamba_type(self) -> str: def get_state_dtype(self) -> tuple[torch.dtype, ...]: pass - def get_copy_spec(self) -> tuple[type[MambaCopySpec], ...]: - return (MambaFullCopySpec,) * len(self.get_state_dtype()) + def get_copy_spec_func(self) -> tuple[MambaCopySpecFunc, ...]: + return (get_full_copy_spec,) * len(self.get_state_dtype()) def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: if ( @@ -66,7 +68,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: if vllm_config.speculative_config else 0 ), - copy_specs=self.get_copy_spec(), + copy_spec_funcs=self.get_copy_spec_func(), ) def get_attn_backend(self) -> type[AttentionBackend]: diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 9a90e3195733..704969f370d0 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -128,8 +128,8 @@ def get_state_shape(self) -> tuple[tuple[int, int, int], ...]: num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim 
) - def get_copy_spec(self): - return MambaCopySpecCalculator.linear_attention_copy_spec() + def get_copy_spec_func(self): + return MambaCopySpecCalculator.linear_attention_copy_spec_func() def __init__( self, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 18c044515295..67d1e6bfdb16 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -445,8 +445,8 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: conv_kernel=self.conv_kernel_size, ) - def get_copy_spec(self): - return MambaCopySpecCalculator.mamba1_state_copy_spec() + def get_copy_spec_func(self): + return MambaCopySpecCalculator.mamba1_state_copy_spec_func() @property def mamba_type(self) -> str: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index c668477bcca8..c9df53acbb86 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -900,8 +900,8 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: conv_kernel=self.conv_kernel_size, ) - def get_copy_spec(self): - return MambaCopySpecCalculator.mamba2_state_copy_spec() + def get_copy_spec_func(self): + return MambaCopySpecCalculator.mamba2_state_copy_spec_func() @property def mamba_type(self) -> str: diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 4da0e7d84171..d551f63cb074 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from typing import TypeAlias import torch @@ -227,89 +229,72 @@ def 
kda_state_shape( ) -class MambaCopySpec(ABC): - @staticmethod - @abstractmethod - def block_idx_offset_func(accept_token_bias: int) -> int: - """ - Return the offset of the source block idx which needs to be copied. - """ - pass +@dataclass +class MambaCopySpec: + start_addr: int + num_elements: int - @staticmethod - @abstractmethod - def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: - """ - Return the offset of the data in the source block which needs to be copied. - """ - pass - @staticmethod - @abstractmethod - def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: - """ - Return the number of elements to be copied. - """ - pass +MambaCopySpecFunc: TypeAlias = Callable[ + [torch.Tensor, list[int], int, int], MambaCopySpec +] -class MambaTemporalCopySpec(MambaCopySpec): - @staticmethod - def block_idx_offset_func(accept_token_bias: int) -> int: - return accept_token_bias +def get_conv_copy_spec( + state: torch.Tensor, + block_ids: list[int], + cur_block_idx: int, + num_accepted_tokens: int, +) -> MambaCopySpec: + src_block_id = block_ids[cur_block_idx] + src_state = state[src_block_id, num_accepted_tokens - 1 :] + return MambaCopySpec( + start_addr=src_state.data_ptr(), num_elements=src_state.numel() + ) - @staticmethod - def data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: - return 0 - @staticmethod - def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: - return state.numel() +def get_temporal_copy_spec( + state: torch.Tensor, + block_ids: list[int], + cur_block_idx: int, + num_accepted_tokens: int, +) -> MambaCopySpec: + src_block_id = block_ids[cur_block_idx + num_accepted_tokens - 1] + src_state = state[src_block_id] + return MambaCopySpec( + start_addr=src_state.data_ptr(), num_elements=src_state.numel() + ) -class MambaConvCopySpec(MambaCopySpec): - @staticmethod - def block_idx_offset_func(accept_token_bias: int) -> int: - return 0 - - @staticmethod - def 
data_offset_func(state: torch.Tensor, accept_token_bias: int) -> int: - return accept_token_bias * state.stride(0) - - @staticmethod - def num_elements_func(state: torch.Tensor, accept_token_bias: int) -> int: - return state.numel() - accept_token_bias * state.stride(0) - - -MambaFullCopySpec = MambaTemporalCopySpec +get_full_copy_spec = get_temporal_copy_spec class MambaCopySpecCalculator: @classmethod - def linear_attention_copy_spec(cls): - return (MambaTemporalCopySpec,) + def linear_attention_copy_spec_func(cls): + return (get_temporal_copy_spec,) @classmethod - def mamba1_state_copy_spec(cls): - return MambaConvCopySpec, MambaTemporalCopySpec + def mamba1_state_copy_spec_func(cls): + return get_conv_copy_spec, get_temporal_copy_spec @classmethod - def mamba2_state_copy_spec(cls): - return MambaConvCopySpec, MambaTemporalCopySpec + def mamba2_state_copy_spec_func(cls): + return get_conv_copy_spec, get_temporal_copy_spec @classmethod - def short_conv_state_copy_spec(cls): - return (MambaConvCopySpec,) + def short_conv_state_copy_spec_func(cls): + return (get_conv_copy_spec,) @classmethod - def gated_delta_net_copy_spec(cls): - return MambaConvCopySpec, MambaTemporalCopySpec + def gated_delta_net_copy_spec_func(cls): + return get_conv_copy_spec, get_temporal_copy_spec @classmethod - def kda_state_copy_spec(cls): + def kda_state_copy_spec_func(cls): return ( - MambaConvCopySpec, - MambaConvCopySpec, - MambaConvCopySpec, - MambaTemporalCopySpec, + get_conv_copy_spec, + get_conv_copy_spec, + get_conv_copy_spec, + get_temporal_copy_spec, ) diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index 9358362356d6..b7af00e91478 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -225,8 +225,8 @@ def get_state_shape(self) -> tuple[tuple[int, ...]]: conv_kernel=self.L_cache, ) - def get_copy_spec(self): - return 
MambaCopySpecCalculator.short_conv_state_copy_spec() + def get_copy_spec_func(self): + return MambaCopySpecCalculator.short_conv_state_copy_spec_func() @property def mamba_type(self) -> str: diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 07b45f29605b..7f7bdff156c8 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -459,8 +459,8 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: conv_kernel=self.conv_kernel_size, ) - def get_copy_spec(self): - return MambaCopySpecCalculator.mamba2_state_copy_spec() + def get_copy_spec_func(self): + return MambaCopySpecCalculator.mamba2_state_copy_spec_func() @property def mamba_type(self) -> str: diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 78dabb9a1ab1..b9c5da7516d4 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -236,8 +236,8 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: self.num_spec, ) - def get_copy_spec(self): - return MambaCopySpecCalculator.gated_delta_net_copy_spec() + def get_copy_spec_func(self): + return MambaCopySpecCalculator.gated_delta_net_copy_spec_func() def __init__( self, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 62f26b8bb9d0..f6b3d553c609 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -10,7 +10,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_utils import MambaCopySpec +from vllm.model_executor.layers.mamba.mamba_utils import MambaCopySpecFunc from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import get_dtype_size @@ -248,7 +248,7 @@ class MambaSpec(KVCacheSpec): page_size_padded: int | None = None mamba_type: str = "mamba2" num_speculative_blocks: int = 0 - copy_specs: tuple[type[MambaCopySpec], ...] 
= () + copy_spec_funcs: tuple[MambaCopySpecFunc, ...] | None = None @property def page_size_bytes(self) -> int: diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index f490984f96d0..9e99ecf51837 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -9,7 +9,7 @@ from vllm.config import CacheConfig from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaCopySpec, + MambaCopySpecFunc, ) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec @@ -75,7 +75,8 @@ def collect_mamba_copy_meta( if src_block_idx == dest_block_idx and accept_token_bias == 0: return - copy_specs: tuple[type[MambaCopySpec], ...] = mamba_spec.copy_specs + copy_spec_funcs: tuple[MambaCopySpecFunc, ...] = mamba_spec.copy_spec_funcs + assert copy_spec_funcs is not None for mamba_group_id in mamba_group_ids: block_ids = req_state.block_ids[mamba_group_id] dest_block_id = block_ids[dest_block_idx] @@ -83,17 +84,14 @@ def collect_mamba_copy_meta( for layer_name in layer_names: attention = forward_context[layer_name] kv_caches: list[torch.Tensor] = attention.kv_cache[0] - for state, copy_spec in zip(kv_caches, copy_specs): - src_block_id = block_ids[ - src_block_idx + copy_spec.block_idx_offset_func(accept_token_bias) - ] - data_offset = copy_spec.data_offset_func(state[0], accept_token_bias) - num_elements = copy_spec.num_elements_func(state[0], accept_token_bias) - src_state_list.append( - state[src_block_id].data_ptr() + data_offset * state.element_size() + for state, copy_spec_func in zip(kv_caches, copy_spec_funcs): + copy_spec = copy_spec_func( + state, block_ids, src_block_idx, accept_token_bias + 1 ) + + src_state_list.append(copy_spec.start_addr) dest_state_list.append(state[dest_block_id].data_ptr()) - num_elements_list.append(num_elements * state.element_size()) + num_elements_list.append(copy_spec.num_elements * state.element_size()) def do_mamba_copy_block( From 
4518c4a80f2c761567bb91c313c602207dec6d3b Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 24 Dec 2025 17:34:17 +0000 Subject: [PATCH 084/130] move get_mamba_copy_func to the model Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/kda.py | 4 ---- vllm/model_executor/layers/mamba/abstract.py | 8 -------- vllm/model_executor/layers/mamba/linear_attn.py | 4 ---- vllm/model_executor/layers/mamba/mamba_mixer.py | 4 ---- vllm/model_executor/layers/mamba/mamba_mixer2.py | 4 ---- vllm/model_executor/layers/mamba/mamba_utils.py | 16 ++++++++-------- vllm/model_executor/layers/mamba/short_conv.py | 4 ---- vllm/model_executor/models/bamba.py | 6 ++++++ vllm/model_executor/models/falcon_h1.py | 6 ++++++ vllm/model_executor/models/granitemoehybrid.py | 5 +++++ vllm/model_executor/models/interfaces.py | 6 ++++++ vllm/model_executor/models/jamba.py | 6 ++++++ vllm/model_executor/models/kimi_linear.py | 10 ++++++++++ vllm/model_executor/models/lfm2.py | 6 ++++++ vllm/model_executor/models/lfm2_moe.py | 6 ++++++ vllm/model_executor/models/mamba.py | 6 ++++++ vllm/model_executor/models/mamba2.py | 6 ++++++ vllm/model_executor/models/minimax_text_01.py | 6 ++++++ vllm/model_executor/models/nano_nemotron_vl.py | 4 ++++ vllm/model_executor/models/nemotron_h.py | 6 ++++++ vllm/model_executor/models/plamo2.py | 10 ++++++---- vllm/model_executor/models/qwen3_next.py | 10 ++++++---- vllm/model_executor/models/zamba2.py | 6 ++++++ vllm/v1/kv_cache_interface.py | 2 -- vllm/v1/worker/gpu_model_runner.py | 2 ++ vllm/v1/worker/mamba_utils.py | 16 ++++++++-------- 26 files changed, 115 insertions(+), 54 deletions(-) diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py index 8e50af091331..80a1b32df928 100644 --- a/vllm/model_executor/layers/kda.py +++ b/vllm/model_executor/layers/kda.py @@ -32,7 +32,6 @@ ) from .mamba.abstract import MambaBase from .mamba.mamba_utils import ( - MambaCopySpecCalculator, MambaStateDtypeCalculator, 
MambaStateShapeCalculator, ) @@ -104,9 +103,6 @@ def get_state_shape( self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size ) - def get_copy_spec_func(self): - return MambaCopySpecCalculator.kda_state_copy_spec_func() - def __init__( self, layer_idx: int, diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 54d229878b82..74f4383e9c23 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -9,10 +9,6 @@ from vllm.attention.selector import get_mamba_attn_backend from vllm.config import VllmConfig from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaCopySpecFunc, - get_full_copy_spec, -) from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec @@ -44,9 +40,6 @@ def mamba_type(self) -> str: def get_state_dtype(self) -> tuple[torch.dtype, ...]: pass - def get_copy_spec_func(self) -> tuple[MambaCopySpecFunc, ...]: - return (get_full_copy_spec,) * len(self.get_state_dtype()) - def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: if ( vllm_config.speculative_config is not None @@ -68,7 +61,6 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: if vllm_config.speculative_config else 0 ), - copy_spec_funcs=self.get_copy_spec_func(), ) def get_attn_backend(self) -> type[AttentionBackend]: diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 704969f370d0..278713408c28 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -24,7 +24,6 @@ from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaCopySpecCalculator, 
MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -128,9 +127,6 @@ def get_state_shape(self) -> tuple[tuple[int, int, int], ...]: num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim ) - def get_copy_spec_func(self): - return MambaCopySpecCalculator.linear_attention_copy_spec_func() - def __init__( self, hidden_size: int, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 67d1e6bfdb16..9d509b82d9bb 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -22,7 +22,6 @@ ) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaCopySpecCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -445,9 +444,6 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: conv_kernel=self.conv_kernel_size, ) - def get_copy_spec_func(self): - return MambaCopySpecCalculator.mamba1_state_copy_spec_func() - @property def mamba_type(self) -> str: return "mamba1" diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index c9df53acbb86..ef3936b839b2 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -23,7 +23,6 @@ ) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaCopySpecCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -900,9 +899,6 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: conv_kernel=self.conv_kernel_size, ) - def get_copy_spec_func(self): - return MambaCopySpecCalculator.mamba2_state_copy_spec_func() - @property def mamba_type(self) -> str: return "mamba2" diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 
d551f63cb074..26ddded7914a 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -235,7 +235,7 @@ class MambaCopySpec: num_elements: int -MambaCopySpecFunc: TypeAlias = Callable[ +MambaStateCopyFunc: TypeAlias = Callable[ [torch.Tensor, list[int], int, int], MambaCopySpec ] @@ -269,29 +269,29 @@ def get_temporal_copy_spec( get_full_copy_spec = get_temporal_copy_spec -class MambaCopySpecCalculator: +class MambaStateCopyFuncCalculator: @classmethod - def linear_attention_copy_spec_func(cls): + def linear_attention_state_copy_func(cls): return (get_temporal_copy_spec,) @classmethod - def mamba1_state_copy_spec_func(cls): + def mamba1_state_copy_func(cls): return get_conv_copy_spec, get_temporal_copy_spec @classmethod - def mamba2_state_copy_spec_func(cls): + def mamba2_state_copy_func(cls): return get_conv_copy_spec, get_temporal_copy_spec @classmethod - def short_conv_state_copy_spec_func(cls): + def short_conv_state_copy_func(cls): return (get_conv_copy_spec,) @classmethod - def gated_delta_net_copy_spec_func(cls): + def gated_delta_net_state_copy_func(cls): return get_conv_copy_spec, get_temporal_copy_spec @classmethod - def kda_state_copy_spec_func(cls): + def kda_state_copy_func(cls): return ( get_conv_copy_spec, get_conv_copy_spec, diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index b7af00e91478..0bbad17d7ebc 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -16,7 +16,6 @@ ) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaCopySpecCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -225,9 +224,6 @@ def get_state_shape(self) -> tuple[tuple[int, ...]]: conv_kernel=self.L_cache, ) - def get_copy_spec_func(self): - return MambaCopySpecCalculator.short_conv_state_copy_spec_func() 
- @property def mamba_type(self) -> str: return "short_conv" diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 22631bbc5489..a7de8e7cf349 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -24,6 +24,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -455,6 +457,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index bfb6b1a1f160..49722b6d721f 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -24,6 +24,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -551,6 +553,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config diff --git 
a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 3434716b8378..6f58df835b05 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -19,6 +19,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -641,6 +643,9 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, ) + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 67c65a44dcf7..ea296dd562d6 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -23,6 +23,7 @@ from vllm.inputs import TokensPrompt from vllm.inputs.data import PromptType from vllm.logger import init_logger +from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.utils.func_utils import supports_kw @@ -647,6 +648,11 @@ def get_mamba_state_shape_from_config( """ ... + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]: + # TODO: add notes + ... + @overload def is_hybrid(model: object) -> TypeIs[IsHybrid]: ... 
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index b2ad12be1e35..2ab71c573cc4 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -24,6 +24,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -556,6 +558,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba1_state_copy_func() + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/kimi_linear.py b/vllm/model_executor/models/kimi_linear.py index 4562b2202c5e..5cb162103d7f 100644 --- a/vllm/model_executor/models/kimi_linear.py +++ b/vllm/model_executor/models/kimi_linear.py @@ -26,6 +26,8 @@ ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -544,6 +546,14 @@ def get_mamba_state_shape_from_config( num_spec=num_spec, ) + @classmethod + def get_mamba_state_copy_func( + cls, + ) -> tuple[ + MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc + ]: + return MambaStateCopyFuncCalculator.kda_state_copy_func() + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 24e4b5df71e4..b4d5ae415ea4 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -20,6 +20,8 @@ ) from vllm.model_executor.layers.logits_processor import 
LogitsProcessor from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -455,6 +457,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.conv_L_cache, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.short_conv_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py index 70804e0a843e..253c5a77aad7 100644 --- a/vllm/model_executor/models/lfm2_moe.py +++ b/vllm/model_executor/models/lfm2_moe.py @@ -25,6 +25,8 @@ ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -635,6 +637,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.conv_L_cache, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.short_conv_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index aa16640a9427..85212feca529 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -16,6 +16,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, 
MambaStateShapeCalculator, ) @@ -261,6 +263,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.conv_kernel, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba1_state_copy_func() + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 5fcfa9431230..ed363df21230 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -15,6 +15,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -228,6 +230,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.conv_kernel, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 4bfe3c391c26..aa367ec6fcf3 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -36,6 +36,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01LinearAttention from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -1005,3 +1007,7 @@ def 
get_mamba_state_shape_from_config( tp_size=parallel_config.tensor_parallel_size, head_dim=hf_config.head_dim, ) + + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.linear_attention_state_copy_func() diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 6dfab595e5b9..c469df4a45a1 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -1728,3 +1728,7 @@ def get_mamba_state_dtype_from_config(cls, vllm_config: "VllmConfig"): temp_vllm_config = copy.deepcopy(vllm_config) temp_vllm_config.model_config.hf_config = text_config return NemotronHForCausalLM.get_mamba_state_dtype_from_config(temp_vllm_config) + + @classmethod + def get_mamba_state_copy_func(cls): + return NemotronHForCausalLM.get_mamba_state_copy_func() diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 2d9dfbd3e768..df73e24d9723 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -45,6 +45,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -804,6 +806,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.conv_kernel, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 7f7bdff156c8..0df82b358289 
100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -28,7 +28,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaCopySpecCalculator, + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -459,9 +460,6 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: conv_kernel=self.conv_kernel_size, ) - def get_copy_spec_func(self): - return MambaCopySpecCalculator.mamba2_state_copy_spec_func() - @property def mamba_type(self) -> str: return "mamba2" @@ -881,6 +879,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index b9c5da7516d4..f2ac57b8af7b 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -50,7 +50,8 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaCopySpecCalculator, + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -236,9 +237,6 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: self.num_spec, ) - def get_copy_spec_func(self): - return MambaCopySpecCalculator.gated_delta_net_copy_spec_func() - def __init__( self, config: Qwen3NextConfig, @@ -1271,6 +1269,10 @@ def get_mamba_state_shape_from_config( num_spec, ) + @classmethod + def 
get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func() + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index fe157887eea9..31252f0bc89a 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -32,6 +32,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) @@ -884,6 +886,10 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, ) + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.mamba2_state_copy_func() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: """Initialize the Zamba2 model for causal language modeling. diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index f6b3d553c609..867354a81ba3 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -10,7 +10,6 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_utils import MambaCopySpecFunc from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import get_dtype_size @@ -248,7 +247,6 @@ class MambaSpec(KVCacheSpec): page_size_padded: int | None = None mamba_type: str = "mamba2" num_speculative_blocks: int = 0 - copy_spec_funcs: tuple[MambaCopySpecFunc, ...] 
| None = None @property def page_size_bytes(self) -> int: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 55204e0c3c7e..d5372839908f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1060,6 +1060,7 @@ def _update_states_after_model_execute( self.requests, self.mamba_state_idx, self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), ) def _init_mrope_positions(self, req_state: CachedRequestState): @@ -3108,6 +3109,7 @@ def execute_model( self.input_batch, self.requests, self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), ) use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 9e99ecf51837..20bf442ecae5 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -9,7 +9,7 @@ from vllm.config import CacheConfig from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaCopySpecFunc, + MambaStateCopyFunc, ) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec @@ -64,7 +64,7 @@ def collect_mamba_copy_meta( dest_state_list: list[int], num_elements_list: list[int], kv_cache_config: KVCacheConfig, - mamba_spec: MambaSpec, + mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], mamba_group_ids: list[int], src_block_idx: int, dest_block_idx: int, @@ -75,8 +75,6 @@ def collect_mamba_copy_meta( if src_block_idx == dest_block_idx and accept_token_bias == 0: return - copy_spec_funcs: tuple[MambaCopySpecFunc, ...] 
= mamba_spec.copy_spec_funcs - assert copy_spec_funcs is not None for mamba_group_id in mamba_group_ids: block_ids = req_state.block_ids[mamba_group_id] dest_block_id = block_ids[dest_block_idx] @@ -84,8 +82,8 @@ def collect_mamba_copy_meta( for layer_name in layer_names: attention = forward_context[layer_name] kv_caches: list[torch.Tensor] = attention.kv_cache[0] - for state, copy_spec_func in zip(kv_caches, copy_spec_funcs): - copy_spec = copy_spec_func( + for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs): + copy_spec = state_copy_func( state, block_ids, src_block_idx, accept_token_bias + 1 ) @@ -118,6 +116,7 @@ def preprocess_mamba( input_batch: GPUInputBatch, requests: dict[str, CachedRequestState], forward_context: dict[str, Any], + mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], ): """ Copy the mamba state of previous step to the last @@ -164,7 +163,7 @@ def preprocess_mamba( dest_state_list, num_elements_list, kv_cache_config, - mamba_spec, + mamba_state_copy_funcs, mamba_group_ids, prev_state_idx, curr_state_idx, @@ -183,6 +182,7 @@ def postprocess_mamba( requests: dict[str, CachedRequestState], mamba_state_idx: dict[str, int], forward_context: dict[str, Any], + mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], ): """ If a blocks is converted from partial block to full block in this step, copy the @@ -220,7 +220,7 @@ def postprocess_mamba( dest_state_list, num_elements_list, kv_cache_config, - mamba_spec, + mamba_state_copy_funcs, mamba_group_ids, src_block_idx, dest_block_idx, From bf57e49d1ddd07a53bba1d0c7c56ec62adac4870 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Thu, 25 Dec 2025 15:51:11 +0000 Subject: [PATCH 085/130] update mamba cache mode Signed-off-by: huanghaoyan.hhy --- vllm/config/cache.py | 15 --------------- vllm/model_executor/models/config.py | 16 ++++++++++++++-- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 
bb357ec614ac..85195976e022 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -239,18 +239,3 @@ def verify_with_parallel_config( raise ValueError("Too large swap space. " + msg) elif cpu_memory_usage > 0.4 * total_cpu_memory: logger.warning("Possibly too large swap space. %s", msg) - - def __post_init__(self) -> None: - if self.enable_prefix_caching: - if self.mamba_cache_mode == "none": - self.mamba_cache_mode = "align" - logger.warning( - "mamba_cache_mode set to 'align' defaultly when prefix " - "caching is enabled" - ) - else: - if self.mamba_cache_mode != "none": - self.mamba_cache_mode = "none" - logger.warning( - "mamba_cache_mode set to 'none' when prefix caching is disabled" - ) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 2951be193f17..2c6b58ec1245 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -287,6 +287,12 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config if cache_config.enable_prefix_caching: + if cache_config.mamba_cache_mode == "none": + cache_config.mamba_cache_mode = "align" + logger.warning( + "Mamba cache mode is set to 'align' defaultly when prefix " + "caching is enabled" + ) if cache_config.mamba_cache_mode == "all": if model_config.supports_mamba_prefix_caching: logger.info( @@ -313,8 +319,14 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: raise ValueError( f"unknown mamba cache mode: {cache_config.mamba_cache_mode}" ) - elif cache_config.mamba_block_size is None: - cache_config.mamba_block_size = model_config.max_model_len + else: + if cache_config.mamba_cache_mode != "none": + cache_config.mamba_cache_mode = "none" + logger.warning( + "Mamba cache mode is set to 'none' when prefix caching is disabled" + ) + if cache_config.mamba_block_size is None: + cache_config.mamba_block_size = model_config.max_model_len class 
HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): From 7e845f9f4bb38ed5252afa0340320ecc66a28f32 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Thu, 25 Dec 2025 15:58:10 +0000 Subject: [PATCH 086/130] cleanup mamba manager Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/single_type_kv_cache_manager.py | 34 -------------------- 1 file changed, 34 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index beb925fbbc1b..d082640f9f5c 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -21,29 +21,6 @@ from vllm.v1.request import Request -def format_blocks(blocks: list[KVCacheBlock]): - if not blocks: - return "[]" - - result = [] - i = 0 - - while i < len(blocks): - if blocks[i].block_id == 0: - count = 0 - while i < len(blocks) and blocks[i].block_id == 0: - count += 1 - i += 1 - result.append(f"Null-block*{count}") - else: - result.append( - f"KVBlock(block_id={blocks[i].block_id}, ref_cnt={blocks[i].ref_cnt})" - ) - i += 1 - - return f"[{', '.join(result)}]" - - class SingleTypeKVCacheManager(ABC): """ An abstract base class for a manager that handle the kv cache management @@ -89,10 +66,6 @@ def __init__( self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block - def print(self, *args, **kwargs): - new_args = (f">>> [KvGrp {self.kv_cache_group_id}] ",) + args - print(*new_args, **kwargs) - def get_num_blocks_to_allocate( self, request_id: str, @@ -823,13 +796,6 @@ def get_num_blocks_to_allocate( ) return num_new_blocks + num_evictable_computed_blocks - def save_new_computed_blocks( - self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock] - ) -> None: - # TODO(hhy): remove when prefix-caching is ready - assert isinstance(self.kv_cache_spec, MambaSpec) - super().save_new_computed_blocks(request_id, list(new_computed_blocks)) - def allocate_new_blocks( self, request_id: str, num_tokens: int, 
num_tokens_main_model: int ) -> list[KVCacheBlock]: From a6b7d08c3e99426b92dc2e8b12d0eaed6d9cc6f4 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 25 Dec 2025 23:46:38 -0800 Subject: [PATCH 087/130] fix test Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 139 +++++++++++++++--------- 1 file changed, 86 insertions(+), 53 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 00c48f7160f2..5c52be740978 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -10,6 +10,7 @@ from vllm import LLM, SamplingParams, TokensPrompt from vllm.config import CacheConfig +from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc from vllm.sequence import IntermediateTensors from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager @@ -31,18 +32,18 @@ class StepAction: num_computed_tokens_start: int num_scheduled_tokens: int - kv_cache_block_ids: list[int] | None # [] to follow last step - preprocess_copy_idx: tuple[int, int] | None # -1, -1 for no copy - postprocess_copy_idx: tuple[int, int] | None # -1, -1 for no copy + kv_cache_block_ids: list[int] # [] to follow last step + preprocess_copy_idx: tuple[int, int] # -1, -1 for no copy + postprocess_copy_idx: tuple[int, int] # -1, -1 for no copy num_speculative_tokens = 3 num_accepted_tokens = 1 prompt_token_ids: list[int] = [] -MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" +MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8" BLOCK_SIZE = 560 -NUM_HIDDEN_LAYERS = 8 +NUM_HIDDEN_LAYERS = 1 cur_step_action_idx = 0 cur_step_action: StepAction | None = None step_actions: list[StepAction] = [] @@ -274,7 +275,37 @@ def get_fake_process_mamba_fn( original_post_process_mamba_fn: Callable, original_copy_fn: Callable, ): - copy_info = (-1, -1) + copy_info: tuple[list[int], list[int], list[int]] | None = None + + def 
check_copy_info( + action: tuple[int, int], + kv_cache_config: KVCacheConfig, + forward_context: dict[str, Any], + input_batch: GPUInputBatch, + ): + assert copy_info is not None + if action == (-1, -1): + assert len(copy_info[0]) == len(copy_info[1]) == len(copy_info[2]) == 0 + else: + assert len(copy_info[0]) == len(copy_info[1]) == len(copy_info[2]) == 2 + mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config) + mamba_group_id = mamba_group_ids[0] + mamba_layer_name = kv_cache_config.kv_cache_groups[ + mamba_group_id + ].layer_names[0] + mamba_kv_cache = forward_context[mamba_layer_name].kv_cache[0][-1] + mamba_block_table = input_batch.block_table.block_tables[ + mamba_group_id + ].block_table.cpu[0] + expected_temporal_src = mamba_kv_cache[ + mamba_block_table[action[0]] + ].data_ptr() + expected_temporal_dest = mamba_kv_cache[ + mamba_block_table[action[1]] + ].data_ptr() + # -1 is qwen3-next's temporal. We skip checking conv as it is more complex. + assert copy_info[0][-1] == expected_temporal_src + assert copy_info[1][-1] == expected_temporal_dest def fake_preprocess_mamba_fn( scheduler_output: SchedulerOutput, @@ -284,9 +315,10 @@ def fake_preprocess_mamba_fn( input_batch: GPUInputBatch, requests: dict[str, CachedRequestState], forward_context: dict[str, Any], + mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], ): nonlocal copy_info - copy_info = (-1, -1) + copy_info = None ret = original_preprocess_mamba_fn( scheduler_output, kv_cache_config, @@ -295,10 +327,15 @@ def fake_preprocess_mamba_fn( input_batch, requests, forward_context, + mamba_state_copy_funcs, ) if cur_step_action is not None: - print("[UNIT TEST STEP] verifying preprocess_copy_idx") - assert copy_info == cur_step_action.preprocess_copy_idx + check_copy_info( + cur_step_action.preprocess_copy_idx, + kv_cache_config, + forward_context, + input_batch, + ) return ret def fake_post_process_mamba_fn( @@ -308,9 +345,10 @@ def fake_post_process_mamba_fn( requests: dict[str, 
CachedRequestState], mamba_state_idx: dict[str, int], forward_context: dict[str, Any], + mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], ): nonlocal copy_info - copy_info = (-1, -1) + copy_info = None ret = original_post_process_mamba_fn( scheduler_output, kv_cache_config, @@ -318,32 +356,29 @@ def fake_post_process_mamba_fn( requests, mamba_state_idx, forward_context, + mamba_state_copy_funcs, ) if cur_step_action is not None: - print("[UNIT TEST STEP] verifying postprocess_copy_idx") - assert copy_info == cur_step_action.postprocess_copy_idx + check_copy_info( + cur_step_action.postprocess_copy_idx, + kv_cache_config, + forward_context, + input_batch, + ) return ret def fake_copy_fn( - kv_cache_config: KVCacheConfig, - mamba_group_ids: list[int], - src_block_idx: int, - dest_block_idx: int, - accept_token_bias: int, - req_state: CachedRequestState, - forward_context: dict[str, Any], + src_state_list: list[int], + dest_state_list: list[int], + num_elements_list: list[int], ): nonlocal copy_info - assert copy_info == (-1, -1) - copy_info = (src_block_idx, dest_block_idx) + assert copy_info is None + copy_info = (src_state_list, dest_state_list, num_elements_list) return original_copy_fn( - kv_cache_config, - mamba_group_ids, - src_block_idx, - dest_block_idx, - accept_token_bias, - req_state, - forward_context, + src_state_list, + dest_state_list, + num_elements_list, ) return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn @@ -362,7 +397,6 @@ def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(GPUModelRunner, "_sample", fake_sample_fn) engine = LLM( model=MODEL, - enforce_eager=True, block_size=BLOCK_SIZE, hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS}, seed=42, @@ -452,12 +486,12 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch): get_fake_process_mamba_fn( mamba_utils.preprocess_mamba, mamba_utils.postprocess_mamba, - mamba_utils.mamba_copy_block, + mamba_utils.do_mamba_copy_block, ) ) 
monkeypatch.setattr(mamba_utils, "preprocess_mamba", fake_preprocess_mamba_fn) monkeypatch.setattr(mamba_utils, "postprocess_mamba", fake_post_process_mamba_fn) - monkeypatch.setattr(mamba_utils, "mamba_copy_block", fake_copy_fn) + monkeypatch.setattr(mamba_utils, "do_mamba_copy_block", fake_copy_fn) def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): @@ -490,7 +524,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(554, 4, [], (-1, -1), (-1, -1)), StepAction(556, 4, [], (-1, -1), (-1, -1)), - StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(558, 4, [1, 1, 1, 1, 1], (1, 1), (2, 0)), StepAction(560, 4, [], (-1, -1), (-1, -1)), StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), ], @@ -503,7 +537,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): step_actions=[ StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(555, 4, [], (-1, -1), (-1, -1)), - StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), + StepAction(557, 4, [1, 1, 1, 1, 1], (1, 1), (-1, -1)), StepAction(559, 4, [], (-1, -1), (1, 0)), StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), ], @@ -516,7 +550,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(553, 4, [], (-1, -1), (-1, -1)), StepAction(556, 4, [], (-1, -1), (-1, -1)), - StepAction(559, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(559, 4, [1, 1, 1, 1, 1], (2, 1), (1, 0)), StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), ], ), @@ -527,7 +561,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): step_actions=[ StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(554, 4, [], (-1, -1), (-1, -1)), - StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(557, 4, [1, 1, 1, 1, 1], (2, 1), (3, 0)), StepAction(560, 4, [], (-1, -1), (-1, -1)), StepAction(563, 4, [0, 1, 1, 1, 1], (-1, 
-1), (-1, -1)), ], @@ -539,7 +573,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): step_actions=[ StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(555, 4, [], (-1, -1), (-1, -1)), - StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(558, 4, [1, 1, 1, 1, 1], (2, 1), (2, 0)), StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), ], ), @@ -550,7 +584,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): step_actions=[ StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(553, 4, [], (-1, -1), (-1, -1)), - StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(557, 4, [1, 1, 1, 1, 1], (3, 1), (3, 0)), StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(565, 4, [], (-1, -1), (-1, -1)), ], @@ -562,7 +596,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): step_actions=[ StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(554, 4, [], (-1, -1), (-1, -1)), - StepAction(558, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(558, 4, [1, 1, 1, 1, 1], (3, 1), (2, 0)), StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(566, 4, [], (-1, -1), (-1, -1)), ], @@ -574,7 +608,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): step_actions=[ StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(555, 4, [], (-1, -1), (-1, -1)), - StepAction(559, 4, [1, 1, 1, 1, 1], (0, 1), (1, 0)), + StepAction(559, 4, [1, 1, 1, 1, 1], (3, 1), (1, 0)), StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), ], ), @@ -584,7 +618,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): num_accepted_tokens=4, step_actions=[ StepAction(0, 556, [1, 1, 1, 1], (-1, -1), (-1, -1)), - StepAction(556, 4, [], (-1, -1), (0, 0)), + StepAction(556, 4, [], (-1, -1), (3, 0)), StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), ], @@ -594,7 +628,7 @@ def 
test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): num_generated_tokens=10, num_accepted_tokens=4, step_actions=[ - StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (0, 0)), + StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), ], ), @@ -603,8 +637,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): num_generated_tokens=10, num_accepted_tokens=4, step_actions=[ - StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (0, 0)), - StepAction(560, 560, [1, 1, 1, 1, 1], (0, 1), (1, 1)), + StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(560, 560, [1, 1, 1, 1, 1], (0, 1), (-1, -1)), StepAction(560 * 2, 4, [0, 1, 1, 1, 1, 1], (1, 2), (-1, -1)), ], ), @@ -613,7 +647,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): num_generated_tokens=10, num_accepted_tokens=4, step_actions=[ - StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (0, 0)), + StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(560, 570, [1, 0, 1, 1, 1, 1], (0, 2), (-1, -1)), StepAction(560 * 2 + 10, 4, [0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), ], @@ -623,8 +657,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): num_generated_tokens=10, num_accepted_tokens=4, step_actions=[ - StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (1, 1)), - StepAction(560 * 2, 560, [0, 1, 1, 1, 1, 1], (1, 2), (2, 2)), + StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), + StepAction(560 * 2, 560, [0, 1, 1, 1, 1, 1], (1, 2), (-1, -1)), StepAction(560 * 3, 4, [0, 0, 1, 1, 1, 1, 1], (2, 3), (-1, -1)), ], ), @@ -633,7 +667,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): num_generated_tokens=10, num_accepted_tokens=4, step_actions=[ - StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (1, 1)), + StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction(560 * 2, 570, [0, 1, 0, 1, 1, 1, 1], (1, 3), (-1, -1)), StepAction(560 * 3 + 10, 4, [0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), ], 
@@ -643,20 +677,20 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): num_generated_tokens=10, num_accepted_tokens=4, step_actions=[ - StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], (-1, -1), (4, 4)), + StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction( 560 * 5, 560 * 4, [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1], (4, 8), - (8, 8), + (-1, -1), ), StepAction( 560 * 9, 560, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1], (8, 9), - (9, 9), + (-1, -1), ), StepAction( 560 * 10, @@ -672,13 +706,13 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): num_generated_tokens=10, num_accepted_tokens=4, step_actions=[ - StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], (-1, -1), (4, 4)), + StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)), StepAction( 560 * 5, 560 * 4, [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1], (4, 8), - (8, 8), + (-1, -1), ), StepAction( 560 * 9, @@ -694,7 +728,6 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): engine = LLM( model=MODEL, enable_prefix_caching=True, - enforce_eager=True, block_size=BLOCK_SIZE, mamba_cache_mode="align", speculative_config={ From ad71cce0ec7fc66e75ef792e71bebb303bd01f3e Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 3 Jan 2026 08:54:44 +0000 Subject: [PATCH 088/130] fix sinkfullattn Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/single_type_kv_cache_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index acd86a3ad146..51d1564e4954 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -1029,6 +1029,7 @@ class SinkFullAttentionManager(FullAttentionManager): def __init__( self, kv_cache_spec: SinkFullAttentionSpec, + cache_config: CacheConfig, block_pool: BlockPool, enable_caching: bool, kv_cache_group_id: int, @@ -1037,6 +1038,7 @@ def __init__( ): super().__init__( kv_cache_spec, + 
cache_config, block_pool, enable_caching, kv_cache_group_id, From 961845f1aa14aae5c279bb848bf3b5db2123fc2c Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 3 Jan 2026 08:56:20 +0000 Subject: [PATCH 089/130] update mamba_attn for align mode Signed-off-by: huanghaoyan.hhy --- vllm/v1/attention/backends/mamba_attn.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 343f9e2bfed4..512f0387c1da 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -16,6 +16,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, compute_causal_conv1d_metadata, + mamba_get_block_table_tensor, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -192,7 +193,7 @@ def _compute_common_metadata( # for causal_conv1d nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None - if self.vllm_config.cache_config.enable_prefix_caching: + if self.vllm_config.cache_config.mamba_cache_mode == "all": # Return a tensor of shape (#requests, #max blocks) state_indices_tensor = common_attn_metadata.block_table_tensor # Additional cache-related varaiables: @@ -209,7 +210,11 @@ def _compute_common_metadata( ) else: # Always return just a single block per each request: - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + state_indices_tensor = mamba_get_block_table_tensor( + common_attn_metadata, + self.kv_cache_spec, + self.vllm_config.cache_config.mamba_cache_mode, + )[:, 0] if num_prefills > 0: query_start_loc_p = ( @@ -230,7 +235,7 @@ def _compute_common_metadata( compute_causal_conv1d_metadata(query_start_loc_p) ) - if self.vllm_config.cache_config.enable_prefix_caching: + if self.vllm_config.cache_config.mamba_cache_mode == "all": assert num_computed_tokens is not None num_computed_tokens_p = num_computed_tokens[ num_reqs - num_prefills : num_reqs @@ -249,7 
+254,7 @@ def _compute_common_metadata( state_indices_tensor = self.state_indices_tensor[:num_decode_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID - if self.vllm_config.cache_config.enable_prefix_caching: + if self.vllm_config.cache_config.mamba_cache_mode == "all": self.block_idx_last_scheduled_token[:num_decodes].copy_( block_idx_last_scheduled_token, non_blocking=True ) From 04f4c45a0bd74a638172f0f680db2a7d494dedcc Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 3 Jan 2026 09:35:18 +0000 Subject: [PATCH 090/130] remove duplicate get_num_skipped_tokens Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/single_type_kv_cache_manager.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 51d1564e4954..c6ba755e6765 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -802,10 +802,6 @@ def find_longest_cache_hit( return computed_blocks - def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: - # TODO: merge https://github.com/vllm-project/vllm/pull/28047 first - return num_computed_tokens - 1 - def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: assert isinstance(self.kv_cache_spec, MambaSpec) super().remove_skipped_blocks(request_id, num_computed_tokens) From fc51f32266a59faca8b0e68c19b31e7c626d4487 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 3 Jan 2026 09:36:41 +0000 Subject: [PATCH 091/130] update builder.update_block_table Signed-off-by: huanghaoyan.hhy --- vllm/attention/layers/chunked_local_attention.py | 10 ++++++++-- vllm/v1/attention/backends/flash_attn.py | 1 + vllm/v1/attention/backends/mamba_attn.py | 15 +++++++++++++-- vllm/v1/attention/backends/utils.py | 1 + vllm/v1/worker/gpu_model_runner.py | 1 + 5 files changed, 24 insertions(+), 4 deletions(-) diff --git a/vllm/attention/layers/chunked_local_attention.py 
b/vllm/attention/layers/chunked_local_attention.py index 7e3794d40833..026a56b978fc 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -60,10 +60,16 @@ def build( return metadata def update_block_table( - self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor + self, + common_metadata, + metadata, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, ): blk_table = metadata.make_virtual_batches_block_table(blk_table) - return super().update_block_table(metadata, blk_table, slot_mapping) + return super().update_block_table( + metadata, common_metadata, blk_table, slot_mapping + ) attn_backend = subclass_attention_backend( name_prefix=prefix, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 3445e998d637..c6a43996998c 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -497,6 +497,7 @@ def schedule( def update_block_table( self, + common_metadata: CommonAttentionMetadata, metadata: FlashAttentionMetadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor, diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 512f0387c1da..31e9085bfc9c 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -289,13 +289,24 @@ def _compute_common_metadata( def update_block_table( self, + common_metadata: CommonAttentionMetadata, metadata: M, blk_table: torch.Tensor, slot_mapping: torch.Tensor, ) -> M: new_metadata = copy.copy(metadata) - prefix_caching = self.vllm_config.cache_config.enable_prefix_caching - state_indices_t = blk_table if prefix_caching else blk_table[:, 0] + prefix_caching_all_mode = ( + self.vllm_config.cache_config.mamba_cache_mode == "all" + ) + state_indices_t = ( + blk_table + if prefix_caching_all_mode + else mamba_get_block_table_tensor( + common_metadata, + self.kv_cache_spec, 
+ self.vllm_config.cache_config.mamba_cache_mode, + )[:, 0] + ) num_reqs = blk_table.shape[0] # For CUDA graphs, copy to persistent buffer diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 3e52a9cafe4a..bb407a873726 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -393,6 +393,7 @@ def build( def update_block_table( self, + common_metadata: CommonAttentionMetadata, metadata: M, blk_table: torch.Tensor, slot_mapping: torch.Tensor, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 920001e25ce9..20f0a82dc2a1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1728,6 +1728,7 @@ def _build_attn_group_metadata( and builder.supports_update_block_table ): attn_metadata_i = builder.update_block_table( + common_attn_metadata, cached_attn_metadata[cache_key], common_attn_metadata.block_table_tensor, common_attn_metadata.slot_mapping, From 61af735cffa3ad0da649e30c3caa82f1b712074b Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 3 Jan 2026 09:57:03 +0000 Subject: [PATCH 092/130] update builder.update_block_table: rm blk_table and slot_mapping Signed-off-by: huanghaoyan.hhy --- vllm/attention/layers/chunked_local_attention.py | 9 ++++----- vllm/v1/attention/backends/flash_attn.py | 6 ++---- vllm/v1/attention/backends/mamba_attn.py | 3 +-- vllm/v1/attention/backends/utils.py | 2 -- vllm/v1/worker/gpu_model_runner.py | 2 -- 5 files changed, 7 insertions(+), 15 deletions(-) diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 026a56b978fc..f95a114552de 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -63,13 +63,12 @@ def update_block_table( self, common_metadata, metadata, - blk_table: torch.Tensor, - slot_mapping: torch.Tensor, ): - blk_table = 
metadata.make_virtual_batches_block_table(blk_table) - return super().update_block_table( - metadata, common_metadata, blk_table, slot_mapping + new_metadata = super().update_block_table(metadata, common_metadata) + new_metadata.block_table = metadata.make_virtual_batches_block_table( + new_metadata.block_table ) + return new_metadata attn_backend = subclass_attention_backend( name_prefix=prefix, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index c6a43996998c..61a038e3f5c4 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -499,12 +499,10 @@ def update_block_table( self, common_metadata: CommonAttentionMetadata, metadata: FlashAttentionMetadata, - blk_table: torch.Tensor, - slot_mapping: torch.Tensor, ) -> FlashAttentionMetadata: new_metadata = copy.copy(metadata) - new_metadata.block_table = blk_table - new_metadata.slot_mapping = slot_mapping + new_metadata.block_table = common_metadata.block_table_tensor + new_metadata.slot_mapping = common_metadata.slot_mapping return new_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 31e9085bfc9c..6da69ecc9fff 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -291,10 +291,9 @@ def update_block_table( self, common_metadata: CommonAttentionMetadata, metadata: M, - blk_table: torch.Tensor, - slot_mapping: torch.Tensor, ) -> M: new_metadata = copy.copy(metadata) + blk_table = common_metadata.block_table_tensor prefix_caching_all_mode = ( self.vllm_config.cache_config.mamba_cache_mode == "all" ) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index bb407a873726..1d094d7de372 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -395,8 +395,6 @@ def update_block_table( self, 
common_metadata: CommonAttentionMetadata, metadata: M, - blk_table: torch.Tensor, - slot_mapping: torch.Tensor, ) -> M: """ Update the block table for the attention metadata. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 20f0a82dc2a1..ab9488889f60 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1730,8 +1730,6 @@ def _build_attn_group_metadata( attn_metadata_i = builder.update_block_table( common_attn_metadata, cached_attn_metadata[cache_key], - common_attn_metadata.block_table_tensor, - common_attn_metadata.slot_mapping, ) else: attn_metadata_i = builder.build( From 189b9562b7833641517cbac14073d94b66ecfd8e Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 3 Jan 2026 11:05:59 +0000 Subject: [PATCH 093/130] update note in cache_full_blocks Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/block_pool.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index cf93218a1873..ce7e396d8a9a 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -255,7 +255,8 @@ def cache_full_blocks( ) for i, blk in enumerate(new_full_blocks): # Some blocks may be null blocks when enabling sparse attention like - # sliding window attention. We skip null blocks here. + # sliding window attention, or Mamba models with prefix-caching in + # align mode. We skip null blocks here. 
if blk.is_null: continue assert blk.block_hash is None From ba297c7a0557d9c4e52dac329360233a6b673cf7 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sat, 3 Jan 2026 12:15:12 +0000 Subject: [PATCH 094/130] fix the max_num_blocks_per_req bug for mamba models with sps Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/block_table.py | 39 ++++++++---------------------- vllm/v1/worker/cp_utils.py | 15 ++++++++++++ vllm/v1/worker/gpu_input_batch.py | 5 ++-- vllm/v1/worker/gpu_model_runner.py | 34 ++++++++++++++++++++++++-- vllm/v1/worker/tpu_input_batch.py | 3 ++- vllm/v1/worker/tpu_model_runner.py | 11 +++++++++ 6 files changed, 72 insertions(+), 35 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index efb01bb8675f..24c17f8a3734 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -6,7 +6,6 @@ from vllm.distributed import get_dcp_group, get_pcp_group from vllm.logger import init_logger -from vllm.utils.math_utils import cdiv from vllm.v1.utils import CpuGpuBuffer logger = init_logger(__name__) @@ -255,57 +254,39 @@ class MultiGroupBlockTable: def __init__( self, max_num_reqs: int, - max_model_len: int, max_num_batched_tokens: int, pin_memory: bool, device: torch.device, block_sizes: list[int], kernel_block_sizes: list[int], - num_speculative_tokens: int = 0, + max_num_blocks: list[int], cp_kv_cache_interleave_size: int = 1, ) -> None: - # Note(hc): each dcp rank only store - # (max_model_len//dcp_world_size) tokens in kvcache, - # so the block_size which used for calc max_num_blocks_per_req - # must be multiplied by dcp_world_size. 
- try: - pcp_world_size = get_pcp_group().world_size - except AssertionError: - # PCP might not be initialized in testing - pcp_world_size = 1 - try: - dcp_world_size = get_dcp_group().world_size - except AssertionError: - # DCP might not be initialized in testing - dcp_world_size = 1 - if len(kernel_block_sizes) != len(block_sizes): raise ValueError( f"kernel_block_sizes length ({len(kernel_block_sizes)}) " f"must match block_sizes length ({len(block_sizes)})" ) - - total_cp_world_size = dcp_world_size * pcp_world_size + if len(max_num_blocks) != len(block_sizes): + raise ValueError( + f"max_num_blocks length ({len(max_num_blocks)}) " + f"must match block_sizes length ({len(block_sizes)})" + ) self.block_tables = [ BlockTable( block_size, max_num_reqs, - # TODO: when prefix-caching and sps are both enable for - # mamba hybrid model, it will need - # `cdiv(max_model_len, block_size * total_cp_world_size) + num_speculative_tokens` # noqa: E501 - # blocks for mamba groups - max( - cdiv(max_model_len, block_size * total_cp_world_size), - 1 + num_speculative_tokens, - ), + max_num_blocks_per_req, max_num_batched_tokens, pin_memory, device, kernel_block_size, cp_kv_cache_interleave_size, ) - for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes) + for block_size, kernel_block_size, max_num_blocks_per_req in zip( + block_sizes, kernel_block_sizes, max_num_blocks + ) ] def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: diff --git a/vllm/v1/worker/cp_utils.py b/vllm/v1/worker/cp_utils.py index f666c739b0be..2c2e0b5cdbe2 100644 --- a/vllm/v1/worker/cp_utils.py +++ b/vllm/v1/worker/cp_utils.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, cast from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.distributed import get_dcp_group, get_pcp_group if TYPE_CHECKING: from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -40,3 +41,17 @@ def 
check_attention_cp_compatibility(vllm_config: VllmConfig) -> None: f"but the impl {layer_impl.__class__.__name__} " "does not support PCP." ) + + +def get_total_cp_world_size(): + try: + pcp_world_size = get_pcp_group().world_size + except AssertionError: + # PCP might not be initialized in testing + pcp_world_size = 1 + try: + dcp_world_size = get_dcp_group().world_size + except AssertionError: + # DCP might not be initialized in testing + dcp_world_size = 1 + return dcp_world_size * pcp_world_size diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 14bbd6578e92..e4823cd57d5c 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -90,11 +90,11 @@ def __init__( vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group kernel_block_sizes: list[int], + max_num_blocks_per_req: list[int], logitsprocs: LogitsProcessors | None = None, logitsprocs_need_output_token_ids: bool = False, is_spec_decode: bool = False, is_pooling_model: bool = False, - num_speculative_tokens: int = 0, cp_kv_cache_interleave_size: int = 1, ): self.is_pooling_model = is_pooling_model @@ -141,13 +141,12 @@ def __init__( # Block table. 
self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, - max_model_len=max_model_len, max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, block_sizes=block_sizes, kernel_block_sizes=kernel_block_sizes, - num_speculative_tokens=num_speculative_tokens, + max_num_blocks=max_num_blocks_per_req, cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ab9488889f60..797a1791f174 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -153,7 +153,10 @@ from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker import mamba_utils -from vllm.v1.worker.cp_utils import check_attention_cp_compatibility +from vllm.v1.worker.cp_utils import ( + check_attention_cp_compatibility, + get_total_cp_world_size, +) from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -488,6 +491,16 @@ def __init__( vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], kernel_block_sizes=[self.cache_config.block_size], + max_num_blocks_per_req=[ + # Note(hc): each dcp rank only store + # (max_model_len//dcp_world_size) tokens in kvcache, + # so the block_size which used for calc max_num_blocks_per_req + # must be multiplied by dcp_world_size. 
+ cdiv( + self.max_model_len, + self.cache_config.block_size * get_total_cp_world_size(), + ) + ], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( self.vllm_config, @@ -5277,6 +5290,23 @@ def may_reinitialize_input_batch( for kv_cache_group in kv_cache_config.kv_cache_groups if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) ] + max_num_blocks = [] + for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): + continue + max_num_blocks_per_req = cdiv( + self.max_model_len, block_sizes[i] * get_total_cp_world_size() + ) + if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + mamba_blocks_per_req = ( + max_num_blocks_per_req + if self.cache_config.enable_prefix_caching + else 1 + ) + kv_cache_group.kv_cache_spec.num_speculative_blocks + max_num_blocks_per_req = max( + max_num_blocks_per_req, mamba_blocks_per_req + ) + max_num_blocks.append(max_num_blocks_per_req) if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ self.cache_config.block_size @@ -5295,11 +5325,11 @@ def may_reinitialize_input_batch( vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, kernel_block_sizes=kernel_block_sizes, + max_num_blocks_per_req=max_num_blocks, is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, - num_speculative_tokens=self.num_spec_tokens, ) def _allocate_kv_cache_tensors( diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 3758a73ee496..1396c8ad9d5a 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -29,6 +29,7 @@ def __init__( vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group kernel_block_sizes: list[int], + 
max_num_blocks_per_req: list[int], ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -64,12 +65,12 @@ def __init__( # Block table. self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, - max_model_len=max_model_len, max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, block_sizes=block_sizes, kernel_block_sizes=kernel_block_sizes, + max_num_blocks=max_num_blocks_per_req, ) # Sampling-related. diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index ba5e6200e426..db0d2ae17622 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -77,6 +77,7 @@ ) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler +from vllm.v1.worker.cp_utils import get_total_cp_world_size from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin, KVConnectorOutput, @@ -261,6 +262,9 @@ def __init__( vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.block_size], kernel_block_sizes=[self.cache_config.block_size], + max_num_blocks_per_req=[ + cdiv(self.max_model_len, self.block_size * get_total_cp_world_size()) + ], ) # Cached torch/numpy tensor @@ -1846,6 +1850,13 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kernel_block_sizes=[ kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size ], + max_num_blocks_per_req=[ + cdiv( + self.max_model_len, + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + * get_total_cp_world_size(), + ) + ], ) # Verify dtype compatibility between block_table_cpu and input_batch assert ( From 81f6b8056102292afbd6b5f5b2c0b2012dbde8c5 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 3 Jan 2026 22:37:58 -0800 Subject: [PATCH 095/130] fix test Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 105 +++++++++++++----------- 1 file changed, 58 
insertions(+), 47 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 5c52be740978..4544bc18462b 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import multiprocessing as mp import os +import traceback from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -169,6 +171,7 @@ def fake_allocate_slots_fn( num_new_computed_tokens: int = 0, new_computed_blocks: KVCacheBlocks | None = None, num_lookahead_tokens: int = 0, + num_external_computed_tokens: int = 0, delay_cache_blocks: bool = False, num_encoder_tokens: int = 0, ): @@ -179,6 +182,7 @@ def fake_allocate_slots_fn( num_new_computed_tokens, new_computed_blocks, num_lookahead_tokens, + num_external_computed_tokens, delay_cache_blocks, num_encoder_tokens, ) @@ -384,39 +388,56 @@ def fake_copy_fn( return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn -def test_run_ref_mamba_state(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - num_generated_tokens = 8000 - num_prompt_tokens = 500 - sampling_params = SamplingParams(temperature=0.0, max_tokens=num_generated_tokens) - with open(f"{os.path.dirname(__file__)}/input.txt") as file: - full_prompt = file.read() - fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model) - monkeypatch.setattr(GPUModelRunner, "execute_model", fake_execute_model_fn) - fake_sample_fn = get_fake_sample_fn() - monkeypatch.setattr(GPUModelRunner, "_sample", fake_sample_fn) - engine = LLM( - model=MODEL, - block_size=BLOCK_SIZE, - hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS}, - seed=42, - ) - global prompt_token_ids - prompt_token_ids = engine.get_tokenizer().encode(full_prompt) - print(f"Token IDs length: {len(prompt_token_ids)}") 
+def run_ref_mamba_state_in_subprocess() -> None: + ctx = mp.get_context("spawn") + proc = ctx.Process(target=_run_ref_mamba_state_worker) + proc.start() + proc.join(timeout=600) + if proc.exitcode != 0: + raise RuntimeError(f"Ref mamba state process exited with code {proc.exitcode}.") - outputs = engine.generate( - [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], - sampling_params, - ) - print(f"Generated text: {outputs[0].outputs[0].token_ids}") - print( - f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" # noqa: E501 - ) - print(f"mamba_kv_cache_dict: {mamba_kv_cache_dict.keys()}") - # ref_mamba_kv_cache_dict = torch.load("mamba_kv_cache_dict.pth") - # check_mamba_state_equal(ref_mamba_kv_cache_dict, mamba_kv_cache_dict) - torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict.pth") + +def _run_ref_mamba_state_worker(): + try: + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + num_generated_tokens = 8000 + num_prompt_tokens = 500 + sampling_params = SamplingParams( + temperature=0.0, max_tokens=num_generated_tokens + ) + with open(f"{os.path.dirname(__file__)}/input.txt") as file: + full_prompt = file.read() + fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model) + GPUModelRunner.execute_model = fake_execute_model_fn + fake_sample_fn = get_fake_sample_fn() + GPUModelRunner._sample = fake_sample_fn + engine = LLM( + model=MODEL, + block_size=BLOCK_SIZE, + hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS}, + seed=42, + ) + global prompt_token_ids + prompt_token_ids = engine.get_tokenizer().encode(full_prompt) + print(f"Token IDs length: {len(prompt_token_ids)}") + + _outputs = engine.generate( + [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], + sampling_params, + ) + print(f"mamba_kv_cache_dict: {mamba_kv_cache_dict.keys()}") + # ref_mamba_kv_cache_dict = torch.load("mamba_kv_cache_dict.pth") + # 
check_mamba_state_equal(ref_mamba_kv_cache_dict, mamba_kv_cache_dict) + # torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict.pth") + cpu_state_ref = { + key: tuple(tensor.detach().cpu() for tensor in tensors) + for key, tensors in mamba_kv_cache_dict.items() + } + torch.save(cpu_state_ref, "mamba_kv_cache_dict_ref.pth") + mamba_kv_cache_dict.clear() + except Exception: + traceback.print_exc() + raise def check_mamba_state_equal( @@ -430,6 +451,8 @@ def check_mamba_state_equal( # mamba state new is a subset of mamba state ref for i, (ref, new) in enumerate(zip(mamba_state_ref[key], mamba_state_new[key])): print("check_mamba_state_equal: ", ref.shape, new.shape) + if ref.device != new.device: + new = new.to(ref.device) new = new[: ref.shape[0]] print("check_mamba_state_equal after convert: ", ref.shape, new.shape) if not torch.allclose(ref, new, atol=atol, rtol=rtol): @@ -495,6 +518,7 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch): def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): + run_ref_mamba_state_in_subprocess() apply_patch(monkeypatch) with open(f"{os.path.dirname(__file__)}/input.txt") as file: full_prompt = file.read() @@ -742,7 +766,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): prompt_token_ids = engine.get_tokenizer().encode(full_prompt) # print(f"Token IDs: {token_ids}") print(f"Token IDs length: {len(prompt_token_ids)}") - mamba_state_ref = torch.load("mamba_kv_cache_dict.pth") + # mamba_state_ref = torch.load("mamba_kv_cache_dict.pth") for test_case_name, test_config in tests.items(): print(f"Running test case: {test_case_name}") num_generated_tokens = test_config.num_generated_tokens @@ -767,10 +791,6 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): global step_actions step_actions = test_config.step_actions print("step actions: ", step_actions) - print( - f"expect token ids: {prompt_token_ids[num_prompt_tokens : num_prompt_tokens + num_generated_tokens]}" # noqa: E501 - ) - _ = engine.generate( 
[TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], sampling_params, @@ -782,15 +802,6 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): for action in test_config.step_actions if action.postprocess_copy_idx and action.postprocess_copy_idx[0] != -1 ] + mamba_state_ref = torch.load("mamba_kv_cache_dict_ref.pth") check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict, keys_to_check) mamba_kv_cache_dict.clear() - - -def test_check_mamba_state_equal(): - mamba_state_ref = torch.load("mamba_kv_cache_dict.pth") - mamba_state_new = torch.load("mamba_kv_cache_dict_new.pth") - check_mamba_state_equal( - mamba_state_ref, - mamba_state_new, - keys_to_check=list(mamba_state_ref.keys()), - ) From 0cb4c0444db01d25c997108b0bc246c6a2111f64 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 3 Jan 2026 22:47:16 -0800 Subject: [PATCH 096/130] cleanup unit test Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 59 +++---------------------- 1 file changed, 6 insertions(+), 53 deletions(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 4544bc18462b..554d94f9a803 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from typing import Any +import datasets import pytest import torch @@ -58,9 +59,6 @@ def fake_sample_fn( spec_decode_metadata: SpecDecodeMetadata | None, ) -> SamplerOutput: assert logits is not None - print( - f"[UNIT TEST] fake_sample_fn: {logits.shape=} {spec_decode_metadata=} {self.input_ids.cpu=}" # noqa: E501 - ) num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor num_computed_tokens = num_computed_tokens_cpu_tensor[0].item() if num_computed_tokens < self.input_batch.num_prompt_tokens[0].item(): @@ -68,9 +66,6 @@ def fake_sample_fn( else: first_token_id_index = num_computed_tokens + 1 if spec_decode_metadata is None: - print( - 
f"[UNIT TEST] fake_sample_fn: {first_token_id_index=} {prompt_token_ids[first_token_id_index]=}" # noqa: E501 - ) return SamplerOutput( sampled_token_ids=torch.tensor( [[prompt_token_ids[first_token_id_index]]], @@ -87,17 +82,6 @@ def fake_sample_fn( sampled_token_ids = accpeted_tokens + [-1] * ( num_sampled_tokens - len(accpeted_tokens) ) - print( - f"[UNIT TEST] fake_sample_fn: {first_token_id_index=} {accpeted_tokens=} {sampled_token_ids=}" # noqa: E501 - ) - # if ( - # self.input_batch.num_computed_tokens_cpu_tensor[0].item() - # >= self.input_batch.num_prompt_tokens[0].item() - # ): - # for i, x in enumerate(sampled_token_ids): - # if x == -1: - # continue - # assert x == self.input_ids.cpu[i + 1] return SamplerOutput( sampled_token_ids=torch.tensor( [sampled_token_ids], device="cuda", dtype=torch.int32 @@ -131,18 +115,12 @@ def fake_propose_draft_token_ids_fn( first_token_id_index = ( num_computed_tokens + 1 ) # bonus token isn't considered as computed - print( - f"fake_propose_draft_token_ids_fn: {self.input_batch.num_accepted_tokens_cpu=}" # noqa: E501 - ) first_token_id_index += self.input_batch.num_accepted_tokens_cpu[0].item() proposed_draft_token_ids = [ prompt_token_ids[ first_token_id_index : first_token_id_index + num_speculative_tokens ] ] - print( - f"[UNIT TEST] fake_propose_draft_token_ids_fn: {num_computed_tokens=} num_accepted_tokens={self.input_batch.num_accepted_tokens_cpu[0].item()} num_prompt_tokens={self.input_batch.num_prompt_tokens[0].item()} num_tokens_no_spec={self.input_batch.num_tokens_no_spec[0].item()} {first_token_id_index=} {proposed_draft_token_ids=}" # noqa: E501 - ) return proposed_draft_token_ids return fake_propose_draft_token_ids_fn @@ -157,7 +135,7 @@ def fake_get_output(self: InprocClient): cur_step_action_idx += 1 else: cur_step_action = None - print(f"fake_get_output: {cur_step_action_idx=} {cur_step_action=}") + print(f"cur_step_action: {cur_step_action_idx=} {cur_step_action=}") return original_step_action_fn(self) 
return fake_get_output @@ -187,7 +165,6 @@ def fake_allocate_slots_fn( num_encoder_tokens, ) if cur_step_action is not None: - print("[UNIT TEST STEP] verifying kv_cache_block_ids") cur_block_ids = self.coordinator.single_type_managers[0].req_to_blocks[ request.request_id ] @@ -215,30 +192,22 @@ def fake_execute_model_fn( iter(scheduler_output.num_scheduled_tokens.values()) ) assert num_scheduled_tokens == cur_step_action.num_scheduled_tokens - print("[UNIT TEST STEP] verified num_scheduled_tokens") mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config) mamba_group_id = mamba_group_ids[0] mamba_layer_name = self.kv_cache_config.kv_cache_groups[ mamba_group_id ].layer_names[0] - print(f"fake_execute_model_fn: {mamba_spec=}") nonlocal last_num_computed_tokens if len(scheduler_output.scheduled_cached_reqs.req_ids) > 0: num_computed_tokens = ( scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] ) - print( - f"fake_execute_model_fn: {num_computed_tokens=} {last_num_computed_tokens=} {num_computed_tokens // BLOCK_SIZE > last_num_computed_tokens // BLOCK_SIZE=}" # noqa: E501 - ) if ( num_computed_tokens // BLOCK_SIZE > last_num_computed_tokens // BLOCK_SIZE ): # generated a new aligned block in this step block_idx = num_computed_tokens // mamba_spec.block_size - 1 - print( - f"[UNIT TEST] fake_execute_model_fn: block_idx= {block_idx} for num_computed_tokens={num_computed_tokens - num_computed_tokens % BLOCK_SIZE}" # noqa: E501 - ) block_id = ( self.input_batch.block_table.block_tables[mamba_group_id] .block_table.cpu[0, block_idx] @@ -258,7 +227,6 @@ def fake_execute_model_fn( last_num_computed_tokens = num_computed_tokens else: last_num_computed_tokens = 0 - print("[UNIT TEST] fake_execute_model_fn: clear last_num_computed_tokens") ret = original_execute_model_fn(self, scheduler_output, intermediate_tensors) @@ -267,7 +235,6 @@ def fake_execute_model_fn( cur_step_action.num_computed_tokens_start == 
self.input_batch.num_computed_tokens_cpu[0].item() ) - print("[UNIT TEST STEP] verified num_computed_tokens_start") return ret @@ -405,8 +372,8 @@ def _run_ref_mamba_state_worker(): sampling_params = SamplingParams( temperature=0.0, max_tokens=num_generated_tokens ) - with open(f"{os.path.dirname(__file__)}/input.txt") as file: - full_prompt = file.read() + prompt_dataset = datasets.load_dataset("heheda/a_long_article") + full_prompt = prompt_dataset["train"][0]["text"] fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model) GPUModelRunner.execute_model = fake_execute_model_fn fake_sample_fn = get_fake_sample_fn() @@ -425,7 +392,6 @@ def _run_ref_mamba_state_worker(): [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], sampling_params, ) - print(f"mamba_kv_cache_dict: {mamba_kv_cache_dict.keys()}") # ref_mamba_kv_cache_dict = torch.load("mamba_kv_cache_dict.pth") # check_mamba_state_equal(ref_mamba_kv_cache_dict, mamba_kv_cache_dict) # torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict.pth") @@ -450,11 +416,9 @@ def check_mamba_state_equal( assert key in mamba_state_ref # mamba state new is a subset of mamba state ref for i, (ref, new) in enumerate(zip(mamba_state_ref[key], mamba_state_new[key])): - print("check_mamba_state_equal: ", ref.shape, new.shape) if ref.device != new.device: new = new.to(ref.device) new = new[: ref.shape[0]] - print("check_mamba_state_equal after convert: ", ref.shape, new.shape) if not torch.allclose(ref, new, atol=atol, rtol=rtol): diff_mask = ~torch.isclose(ref, new, atol=atol, rtol=rtol) diff_idx = torch.nonzero(diff_mask) @@ -463,14 +427,6 @@ def check_mamba_state_equal( f"[WARNING] found {diff_idx.shape[0] * 100 / ref.numel()}% of the elements are different" # noqa: E501 ) continue - print( - "diff: ", - diff_idx.shape, - diff_idx, - ref[diff_mask], - new[diff_mask], - torch.max(torch.abs(ref - new)), - ) raise ValueError( f"Mamba state is not equal for key: {key} at index {i}" ) @@ -520,8 
+476,8 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch): def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): run_ref_mamba_state_in_subprocess() apply_patch(monkeypatch) - with open(f"{os.path.dirname(__file__)}/input.txt") as file: - full_prompt = file.read() + prompt_dataset = datasets.load_dataset("heheda/a_long_article") + full_prompt = prompt_dataset["train"][0]["text"] tests = { "accept_1": TestConfig( num_prompt_tokens=554, @@ -764,9 +720,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): ) global prompt_token_ids prompt_token_ids = engine.get_tokenizer().encode(full_prompt) - # print(f"Token IDs: {token_ids}") print(f"Token IDs length: {len(prompt_token_ids)}") - # mamba_state_ref = torch.load("mamba_kv_cache_dict.pth") for test_case_name, test_config in tests.items(): print(f"Running test case: {test_case_name}") num_generated_tokens = test_config.num_generated_tokens @@ -790,7 +744,6 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): step_action_next.kv_cache_block_ids = prev_block_ids.copy() global step_actions step_actions = test_config.step_actions - print("step actions: ", step_actions) _ = engine.generate( [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])], sampling_params, From 1882efd0f97e55a228e9f4024d20cce6d638ecec Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 3 Jan 2026 23:27:19 -0800 Subject: [PATCH 097/130] remove debug scripts Signed-off-by: Chen Zhang --- examples/offline_inference/run.py | 56 ----- tests/v1/e2e/input.txt | 399 ------------------------------ 2 files changed, 455 deletions(-) delete mode 100644 examples/offline_inference/run.py delete mode 100644 tests/v1/e2e/input.txt diff --git a/examples/offline_inference/run.py b/examples/offline_inference/run.py deleted file mode 100644 index 5f3933a6aea4..000000000000 --- a/examples/offline_inference/run.py +++ /dev/null @@ -1,56 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project -import time - -from vllm import LLM, SamplingParams - - -def main(): - MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" - PROMPT_MULTIPLE = 6 - sampling_params = SamplingParams(temperature=0.0, max_tokens=300) - prefix = ( # examples/offline_inference/prefix_caching.py - "Your name is QQQQ " - "You are an expert school principal, skilled in effectively managing " - "faculty and staff. Draft 10-15 questions for a potential first grade " - "Head Teacher for my K-12, all-girls', independent school that emphasizes " - "community, joyful discovery, and life-long learning. The candidate is " - "coming in for a first-round panel interview for a 8th grade Math " - "teaching role. They have 5 years of previous teaching experience " - "as an assistant teacher at a co-ed, public school with experience " - "in middle school math teaching. " - ) - prefix2 = "Based on these information, fulfill the following paragraph: " - prompt = PROMPT_MULTIPLE * prefix + prefix2 + "Hello, my name is" - # print('Prompt length:', ) - # for APC in [False, True]: - for APC in [True]: - engine = LLM( - model=MODEL, - enable_prefix_caching=APC, - enforce_eager=True, - tensor_parallel_size=4, - block_size=288, - mamba_cache_mode="align", - # load_format="dummy", - speculative_config={ - "method": "qwen3_next_mtp", - "num_speculative_tokens": 2, - }, - ) - for i in range(3): - if i == 0: - print("Warm-up") - if i == 1: - print("Measuring") - start_time = time.time() - outputs = engine.generate(prompt, sampling_params) - print("APC:", APC, i, f"Generated text: {outputs[0].outputs[0].text!r}") - # for m in engine.llm_engine.get_metrics(): - # if 'vllm:prefix_cache_hits' in m.name: - # print(m.name, m.value) - print("APC:", APC, "loop took --- %s seconds ---" % (time.time() - start_time)) - - -if __name__ == "__main__": - main() diff --git a/tests/v1/e2e/input.txt b/tests/v1/e2e/input.txt deleted file mode 100644 index f5cee144f9e8..000000000000 --- a/tests/v1/e2e/input.txt 
+++ /dev/null @@ -1,399 +0,0 @@ -# The Architecture of Intelligence: A Deep Dive into Large Language Models (LLMs) - -## Introduction: The New Cognitive Revolution - -In the annals of computing history, few technologies have burst onto the global stage with the same immediate and transformative impact as Large Language Models (LLMs). Emerging from the confluence of decades of theoretical research and the exponential growth of computational power and data, LLMs like GPT, Gemini, and Claude have transitioned Artificial Intelligence (AI) from a niche academic pursuit to the central utility of the digital age. - -An LLM is not merely a sophisticated piece of software; it is a complex, deep neural network designed to understand, process, and generate human language with startling fluency, coherence, and context. These models serve as the probabilistic engines of a new cognitive revolution, capable of tasks that range from synthesizing vast datasets and translating languages to creating novel code and engaging in philosophical debate. - -This comprehensive article explores the complete landscape of Large Language Models. We will trace their historical lineage, demystify the revolutionary architecture upon which they are built, detail the arduous training process, analyze the emergent capabilities and inherent flaws, survey their massive commercial and social applications, and, finally, grapple with the profound ethical and strategic challenges they pose for the future of humanity. - -## Part I: The Historical Foundations of Language Modeling - -The concept of a machine generating human language has a history far longer than the digital computer. Its modern journey, however, can be segmented into distinct eras, each overcoming the limitations of the last. - -### 1. Statistical Language Models (1980s – 2000s) -The earliest forms of language modeling were rooted in statistics and probability theory. 
These were dominated by **n-gram models**, inspired by the mathematical work of Andrey Markov. An n-gram model predicts the probability of the next word ($w_i$) based solely on the previous $n-1$ words ($w_{i-(n-1)}, \dots, w_{i-1}$). - -$$P(w_i | w_{1}^{i-1}) \approx P(w_i | w_{i-(n-1)}^{i-1})$$ - -These models were simple, explainable, and formed the backbone of early machine translation and speech recognition systems, notably pioneering corpus-based language modeling at IBM. However, they suffered from **the curse of dimensionality** and **data sparsity**. As $n$ increased (to capture more context), the number of possible word sequences grew exponentially, making it impossible to accurately estimate probabilities for sequences not seen in the training data. - -### 2. Neural Language Models and Deep Learning (2000s – 2017) -The transition from statistical methods to neural networks addressed the data sparsity problem. The breakthrough came with the introduction of **word embeddings** (pioneered by Bengio in 2003, and popularized by Word2Vec in 2013). - -Instead of treating words as discrete, independent symbols, word embeddings represent each word as a dense, real-valued vector in a multi-dimensional space. Words with similar meanings (e.g., "King," "Queen," "Man," "Woman") are mapped closer together in this geometric space. This allowed the models to generalize, moving beyond simple word co-occurrence to semantic relationships. - -The workhorse of this era was the **Recurrent Neural Network (RNN)**, particularly the **Long Short-Term Memory (LSTM)** network. RNNs process sequences word-by-word, maintaining a "hidden state" or "memory cell" that accumulates information from the previous steps. This allowed them to handle longer-term dependencies than n-gram models. However, the sequential nature of RNNs created two major issues: -1. **Slow Training:** Processing must be strictly sequential, preventing the use of modern parallel computing hardware like GPUs. -2. 
**Vanishing/Exploding Gradients:** For very long sequences, the error signals used during training (gradients) either vanished (making the model forget the beginning of the text) or exploded (making training unstable). - -### 3. The Attention Mechanism (2014) -The first true step toward the LLM revolution was the introduction of the **Attention Mechanism** in 2014. Used initially within RNN-based encoder-decoder architectures (the basis of Google Translate at the time), attention allowed the model to dynamically weigh the importance of different parts of the input sequence when generating a specific part of the output. This was crucial for tasks like translation, where the most relevant input word might not be the adjacent one. - -## Part II: The Transformer Architecture (2017 - Present) - -The year 2017 marks the true beginning of the LLM era with the publication of "Attention Is All You Need" by researchers at Google. This paper proposed the **Transformer** architecture, which jettisoned recurrence entirely and relied *only* on the attention mechanism. - -### The Encoder-Decoder Foundation -The original Transformer model consists of two main stacks: an **Encoder** and a **Decoder**. -* **Encoder:** Processes the input sequence (e.g., an English sentence), creating a robust, context-aware numerical representation of it. -* **Decoder:** Takes the Encoder's output and iteratively generates the output sequence (e.g., the French translation). - -### The Self-Attention Breakthrough -The core innovation is **Self-Attention**. It allows the model to calculate how much every word in the input sequence relates to every other word *within that same sequence*. This is done through a mathematical process involving three vector representations for each input token: - -1. **Query ($Q$):** Represents the token being processed—the question being asked. -2. **Key ($K$):** Represents all other tokens—the information that can be searched. -3. 
**Value ($V$):** Represents the actual information content of all other tokens. - -The model computes the dot product of the $Q$ vector with all $K$ vectors to get **attention scores**. These scores, after normalization (using a Softmax function), determine how much of the $V$ vectors should be aggregated to create the new, context-rich representation of the original token. - -$$\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$ - -This allows the model to achieve **parallel processing**. Unlike sequential RNNs, every word's vector representation can be calculated simultaneously, leveraging the massive parallel capabilities of GPUs and leading to unprecedented scalability. - -### Positional Encoding -Since the Transformer has no inherent recurrence (no left-to-right reading), the model needs a way to know the order of the words. This is solved by **Positional Encoding**—adding a vector to the input embeddings that contains information about the word’s absolute or relative position in the sequence. Without this, the phrase "Dog bites man" would be processed identically to "Man bites dog." - -### Model Variants: BERT vs. GPT -The Transformer architecture gave rise to three major model families: - -1. **Encoder-Only (e.g., BERT, RoBERTa):** Used primarily for *understanding* tasks (classification, named entity recognition, sentiment analysis). They are excellent at bidirectional context (looking both backward and forward in a sentence). -2. **Decoder-Only (e.g., GPT, Llama):** Used primarily for *generation* tasks. The decoder is constrained by a **causal mask** that prevents it from looking at future tokens, forcing it to generate text sequentially, word-by-word. These models have become the dominant architecture for conversational AI. -3. **Encoder-Decoder (e.g., T5, BART):** Used for sequence-to-sequence tasks like translation and summarization. 
- -## Part III: The Training Lifecycle of an LLM - -The development of an LLM is a complex, multi-stage process involving massive computational resources, vast data curation efforts, and sophisticated human intervention. - -### 1. Data Curation and Tokenization -The first step is gathering and cleaning the training corpus. Modern LLMs are trained on hundreds of terabytes or even petabytes of text, often sourced from: -* **CommonCrawl:** A massive, open-source scrape of the public internet. -* **Filtered Web Text:** Highly curated, higher-quality web pages. -* **Books and Literature:** Digitized libraries. -* **Code Repositories:** Such as GitHub, to instill programming knowledge. -* **Wikipedia:** Structured knowledge bases. - -Data is meticulously filtered to remove low-quality content, boilerplate text, and offensive material. The text is then broken down into **tokens** using a process like **Byte-Pair Encoding (BPE)**. Tokens are the minimal units of meaning the model processes, bridging the gap between human language and numerical vectors. - -### 2. Pre-Training: Self-Supervised Learning -The core of LLM training is the **Pre-Training** phase. The model's hundreds of billions of parameters are initialized, and it is fed the massive, unlabeled dataset. The primary objective is **Next-Token Prediction** (or autoregressive modeling): predicting the next most probable token in a sequence, given all previous tokens. - -* **Objective Function:** The model minimizes the **Loss Function** (often **Cross-Entropy Loss**), which measures the difference between the model's predicted probability distribution over the vocabulary and the actual next token. -* **Optimization:** The model iteratively adjusts its weights using **Backpropagation** and an **Optimizer** (e.g., Adam or its variants) to reduce this loss. 
- -This phase, costing millions of dollars in GPU time, imbues the model with its fundamental knowledge base, grammar, syntax, and a basic, structural understanding of the world. It is through this pure statistical exercise that "reasoning" begins to emerge. - -### 3. Fine-Tuning and Alignment -A raw pre-trained model is highly knowledgeable but often unhelpful and potentially toxic. It will simply continue the statistical pattern of the input, regardless of intent. Alignment is the process of making the model follow instructions and adhere to ethical guidelines. - -#### A. Supervised Fine-Tuning (SFT) -The model is trained on a smaller, high-quality, human-curated dataset of prompts and desired, high-quality responses. This teaches the model a conversational style—how to act as an assistant, answer questions, and follow complex directions. - -#### B. Reinforcement Learning from Human Feedback (RLHF) -RLHF is the key component that created the conversational brilliance of models like ChatGPT. -1. **Response Generation:** For a given prompt, the LLM generates several possible answers. -2. **Human Ranking:** Human labelers rank these responses from best to worst based on helpfulness, accuracy, and safety. -3. **Reward Model Training:** A separate, smaller model called the **Reward Model (RM)** is trained to predict the human preference score for any response. The RM effectively learns "what a good answer looks like." -4. **Policy Optimization:** The main LLM is then fine-tuned using a Reinforcement Learning algorithm (like **Proximal Policy Optimization, PPO**) to maximize the score given by the Reward Model. - -This process explicitly aligns the model's objective function with human values, a crucial step in preparing the model for public deployment. - -## Part IV: Emergent Capabilities and Inherent Limitations - -The path from a neural network to a cognitive tool is marked by phenomena that both inspire awe and caution. 
- -### The Phenomenon of Emergence -As LLMs crossed certain thresholds—specifically in parameter count (size) and training data volume—researchers observed **Emergent Capabilities**. These are skills that the model was never explicitly trained for, yet they appear spontaneously. - -* **In-Context Learning (ICL):** The ability to learn a new task from a few examples provided directly in the prompt, without needing formal fine-tuning (Few-Shot Learning). -* **Chain-of-Thought (CoT) Reasoning:** The ability to decompose complex, multi-step problems into sequential reasoning steps, often unlocked by simply telling the model to "think step-by-step." This dramatically improves performance on arithmetic, common sense, and symbolic logic tasks. -* **Multilingual and Code Proficiency:** Models trained primarily on English and code surprisingly develop high-level proficiency in dozens of other languages and complex programming languages. - -These emergent properties suggest that the simple task of next-token prediction, when scaled sufficiently, leads to a kind of generalized, implicit world model—a probabilistic simulation of human knowledge and reasoning. - -### The Challenge of Hallucination -The most significant and stubborn limitation of LLMs is **Hallucination**—the generation of factually incorrect, nonsensical, or unfaithful content that is nevertheless syntactically plausible. - -The root cause lies in the model's core function: it is a **prediction engine, not a retrieval engine**. It does not access an external database of facts; it samples the most statistically likely sequence of tokens based on its internal, compressed world model. If the highest-probability sequence *looks* like a scientific citation but is entirely fabricated, the model will generate it. 
- -Mitigation strategies, such as **Retrieval-Augmented Generation (RAG)**, which links the LLM to a real-time, verifiable external knowledge source (like a search index or a company database), are essential for using LLMs in high-stakes, fact-based applications. - -## Part V: The Expanding Ecosystem and Applications - -The LLM ecosystem is diversifying rapidly, moving beyond the simple "chatbot" into powerful, specialized tools. - -### 1. Model Scaling and Efficiency -The pursuit of ever-larger models has reached its limits due to cost and data scarcity. The frontier has shifted to efficiency and specialization. -* **Mixture-of-Experts (MoE):** Models like Mixtral use a routing mechanism to activate only a subset of specialized "expert" neural networks for any given query. This allows the model to have a massive total parameter count (high knowledge capacity) while only using a fraction of the computational power (high efficiency). -* **Quantization and Pruning:** Techniques used to reduce the size and computational demands of models, making them executable on smaller devices (e.g., a mobile phone or a personal laptop). - -### 2. Multimodality -The most significant recent breakthrough is the transition from LLMs (Large Language Models) to **LMMs (Large Multimodal Models)**. These models are trained not just on text, but also on images, audio, and video data, allowing them to: -* **Visual Reasoning:** Analyze a complex graph, a photograph, or a technical diagram and answer questions about its content. -* **Audio Processing:** Transcribe, summarize, and understand the context of spoken language directly. -* **Seamless Integration:** Accept a prompt containing text and an image simultaneously (e.g., "Describe this image and write a poem about it"). - -### 3. 
Industry Applications -LLMs are no longer experimental; they are becoming foundational infrastructure across nearly every industry: -* **Software Engineering:** Automated code generation (e.g., GitHub Copilot), debugging, code translation between languages, and writing documentation. -* **Knowledge Work & Productivity:** Summarizing long documents, drafting complex reports, synthesizing research, and managing data from unstructured sources. -* **Customer Service & Sales:** Highly personalized and efficient conversational AI bots that can handle complex queries beyond simple FAQs. -* **Medicine and Law:** Assisting in drafting legal briefs, summarizing medical records, and cross-referencing diagnostic information (always requiring human oversight). -* **Creative Arts:** Generating marketing copy, scriptwriting, music composition (in conjunction with other AI models), and video production assets. - -## Part VI: The Ethical and Societal Labyrinth - -The power of LLMs brings with it a commensurately large set of ethical, social, and economic risks that demand global governance and responsible development. - -### 1. Bias, Fairness, and Amplification -LLMs are fundamentally statistical mirrors of their training data. If the internet contains biases related to gender, race, or geography, the model will ingest, amplify, and operationalize those biases. -* **Stereotype Reinforcement:** A model might associate certain professions (e.g., "engineer") predominantly with one gender, leading to biased outputs in hiring tools. -* **Harmful Generalizations:** Biases can lead to unfair or discriminatory decision-making when the models are deployed in high-stakes areas like loan applications or judicial risk assessment. -Mitigating bias requires meticulous data curation, adversarial testing, and post-processing "guardrails," but complete elimination remains technically elusive. - -### 2. 
Misinformation and Disinformation -The ability of LLMs to generate highly convincing, fluent text at scale is a threat to information integrity. Malicious actors can use these tools to: -* **Automate Phishing and Scams:** Generate personalized, sophisticated deceptive content. -* **Create Deepfake Text:** Impersonate real individuals or organizations with convincing prose. -* **Fabricate "Fake News" and Propaganda:** Generate massive volumes of highly plausible, factually false content, overwhelming traditional fact-checking mechanisms and accelerating the breakdown of public trust. - -### 3. Data Privacy and Security -LLMs pose risks related to data ingestion and leakage: -* **Training Data Memorization:** Models can, in rare cases, memorize and regurgitate personally identifiable information (PII) or copyrighted material from their vast training corpus. -* **Inference Attack (Data Leakage):** If a user provides proprietary or sensitive information as a prompt, that data may be inadvertently used to train future iterations of the model or leak through side channels, raising major security concerns for enterprise adoption. - -### 4. Environmental Impact -The scale of LLMs has a significant environmental footprint. Training a single frontier model requires months of continuous operation on thousands of GPUs, consuming energy equivalent to hundreds of homes for a year. The high computational cost raises questions about the long-term sustainability and equitable access to the technology. - -### 5. Economic Disruption and Labor -LLMs are directly impacting knowledge-based professions, particularly those involving content creation, data synthesis, and routine communication. While optimists argue the technology will mostly automate mundane tasks, freeing humans for higher-level work, policymakers and economists are grappling with the reality of rapid job displacement, income inequality, and the need for massive reskilling initiatives. 
- -## Part VII: The Frontier—The Path to Agentic AI and AGI - -The current state of the art is fleeting. The research community is pushing toward systems that are more autonomous, capable, and integrated. - -### 1. Agentic AI -The shift from a "Chatbot" to an "Agent" is the immediate future. Current LLMs are **reactive** (Question $\rightarrow$ Answer). An Agentic LLM is **proactive and goal-oriented**. -* **Goal:** The user provides a high-level goal (e.g., "Find the cheapest flight to Tokyo next month and book a hotel near the Shinjuku station."). -* **Planning:** The LLM breaks the goal into sub-tasks (Search flights, Compare prices, Search hotels, Check availability, Execute booking actions). -* **Tool Use:** The LLM integrates external tools (search engines, flight APIs, email/calendar APIs) to complete the tasks autonomously, engaging in a trial-and-error loop until the goal is achieved. This transforms the LLM from a generator of text into an executor of complex, multi-step actions. - -### 2. The Multi-Agent Ecosystem -The next stage involves creating swarms of specialized LLM Agents that communicate and collaborate to solve enormous, non-trivial problems. One agent might be a "researcher," another a "coder," and a third an "editor," all collaborating on a project, mimicking a human team. - -### 3. The Pursuit of Artificial General Intelligence (AGI) -The ultimate horizon is Artificial General Intelligence—a machine with the capacity to understand, learn, and apply its intelligence to solve virtually any problem that a human can. - -The debate remains: Is the current path of massive scaling and improved architecture (the **scaling hypothesis**) sufficient to reach AGI, or is some fundamental, non-Transformer-based innovation required? The appearance of emergent properties strongly suggests that the scaling path has not yet exhausted its potential, keeping the AGI goal within the sights of major research labs. 
- -## Conclusion: The Mirror of Human Intelligence - -Large Language Models are perhaps the most profound technological platform shift since the invention of the Internet. They represent the culmination of 75 years of AI research, transitioning from rule-based systems and statistical models to the deep, parallel processing power of the Transformer architecture. - -LLMs are the definitive statistical compressors of human knowledge, capable of synthesizing our collective digital output with stunning fidelity. They have unlocked a new era of computational creativity and efficiency, driving unprecedented change across every sector. - -Yet, this power is a double-edged sword. LLMs are not inherently wise; they are merely proficient at pattern matching. They reflect and amplify human biases, they can deceive with convincing misinformation, and they introduce profound questions about accountability, labor, and the nature of creative work. - -The future of LLMs is not just about making them *smarter*, but making them *safer*, *more efficient*, and more *aligned* with human values. The challenge for the coming decade is not technical—the algorithms and compute will continue to improve—but **governance and ethical**. Humanity must learn to responsibly wield this powerful mirror of its own intelligence, ensuring that the cognitive revolution we have started leads to a future of prosperity and equitable access, rather than fragmentation and control. The architecture of intelligence is now in our hands; the path forward depends on the wisdom of its design and deployment. - -# The Architecture of Intelligence: A Deep Dive into Large Language Models (LLMs) - -## Introduction: The New Cognitive Revolution - -In the annals of computing history, few technologies have burst onto the global stage with the same immediate and transformative impact as Large Language Models (LLMs). 
Emerging from the confluence of decades of theoretical research and the exponential growth of computational power and data, LLMs like GPT, Gemini, and Claude have transitioned Artificial Intelligence (AI) from a niche academic pursuit to the central utility of the digital age. - -An LLM is not merely a sophisticated piece of software; it is a complex, deep neural network designed to understand, process, and generate human language with startling fluency, coherence, and context. These models serve as the probabilistic engines of a new cognitive revolution, capable of tasks that range from synthesizing vast datasets and translating languages to creating novel code and engaging in philosophical debate. - -This comprehensive article explores the complete landscape of Large Language Models. We will trace their historical lineage, demystify the revolutionary architecture upon which they are built, detail the arduous training process, analyze the emergent capabilities and inherent flaws, survey their massive commercial and social applications, and, finally, grapple with the profound ethical and strategic challenges they pose for the future of humanity. - -## Part I: The Historical Foundations of Language Modeling - -The concept of a machine generating human language has a history far longer than the digital computer. Its modern journey, however, can be segmented into distinct eras, each overcoming the limitations of the last. - -### 1. Statistical Language Models (1980s – 2000s) -The earliest forms of language modeling were rooted in statistics and probability theory. These were dominated by **n-gram models**, inspired by the mathematical work of Andrey Markov. An n-gram model predicts the probability of the next word ($w_i$) based solely on the previous $n-1$ words ($w_{i-(n-1)}, \dots, w_{i-1}$). 
- -$$P(w_i | w_{1}^{i-1}) \approx P(w_i | w_{i-(n-1)}^{i-1})$$ - -These models were simple, explainable, and formed the backbone of early machine translation and speech recognition systems, notably pioneering corpus-based language modeling at IBM. However, they suffered from **the curse of dimensionality** and **data sparsity**. As $n$ increased (to capture more context), the number of possible word sequences grew exponentially, making it impossible to accurately estimate probabilities for sequences not seen in the training data. - -### 2. Neural Language Models and Deep Learning (2000s – 2017) -The transition from statistical methods to neural networks addressed the data sparsity problem. The breakthrough came with the introduction of **word embeddings** (pioneered by Bengio in 2003, and popularized by Word2Vec in 2013). - -Instead of treating words as discrete, independent symbols, word embeddings represent each word as a dense, real-valued vector in a multi-dimensional space. Words with similar meanings (e.g., "King," "Queen," "Man," "Woman") are mapped closer together in this geometric space. This allowed the models to generalize, moving beyond simple word co-occurrence to semantic relationships. - -The workhorse of this era was the **Recurrent Neural Network (RNN)**, particularly the **Long Short-Term Memory (LSTM)** network. RNNs process sequences word-by-word, maintaining a "hidden state" or "memory cell" that accumulates information from the previous steps. This allowed them to handle longer-term dependencies than n-gram models. However, the sequential nature of RNNs created two major issues: -1. **Slow Training:** Processing must be strictly sequential, preventing the use of modern parallel computing hardware like GPUs. -2. **Vanishing/Exploding Gradients:** For very long sequences, the error signals used during training (gradients) either vanished (making the model forget the beginning of the text) or exploded (making training unstable). - -### 3. 
The Attention Mechanism (2014) -The first true step toward the LLM revolution was the introduction of the **Attention Mechanism** in 2014. Used initially within RNN-based encoder-decoder architectures (the basis of Google Translate at the time), attention allowed the model to dynamically weigh the importance of different parts of the input sequence when generating a specific part of the output. This was crucial for tasks like translation, where the most relevant input word might not be the adjacent one. - -## Part II: The Transformer Architecture (2017 - Present) - -The year 2017 marks the true beginning of the LLM era with the publication of "Attention Is All You Need" by researchers at Google. This paper proposed the **Transformer** architecture, which jettisoned recurrence entirely and relied *only* on the attention mechanism. - -### The Encoder-Decoder Foundation -The original Transformer model consists of two main stacks: an **Encoder** and a **Decoder**. -* **Encoder:** Processes the input sequence (e.g., an English sentence), creating a robust, context-aware numerical representation of it. -* **Decoder:** Takes the Encoder's output and iteratively generates the output sequence (e.g., the French translation). - -### The Self-Attention Breakthrough -The core innovation is **Self-Attention**. It allows the model to calculate how much every word in the input sequence relates to every other word *within that same sequence*. This is done through a mathematical process involving three vector representations for each input token: - -1. **Query ($Q$):** Represents the token being processed—the question being asked. -2. **Key ($K$):** Represents all other tokens—the information that can be searched. -3. **Value ($V$):** Represents the actual information content of all other tokens. - -The model computes the dot product of the $Q$ vector with all $K$ vectors to get **attention scores**. 
These scores, after normalization (using a Softmax function), determine how much of the $V$ vectors should be aggregated to create the new, context-rich representation of the original token. - -$$\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$ - -This allows the model to achieve **parallel processing**. Unlike sequential RNNs, every word's vector representation can be calculated simultaneously, leveraging the massive parallel capabilities of GPUs and leading to unprecedented scalability. - -### Positional Encoding -Since the Transformer has no inherent recurrence (no left-to-right reading), the model needs a way to know the order of the words. This is solved by **Positional Encoding**—adding a vector to the input embeddings that contains information about the word’s absolute or relative position in the sequence. Without this, the phrase "Dog bites man" would be processed identically to "Man bites dog." - -### Model Variants: BERT vs. GPT -The Transformer architecture gave rise to three major model families: - -1. **Encoder-Only (e.g., BERT, RoBERTa):** Used primarily for *understanding* tasks (classification, named entity recognition, sentiment analysis). They are excellent at bidirectional context (looking both backward and forward in a sentence). -2. **Decoder-Only (e.g., GPT, Llama):** Used primarily for *generation* tasks. The decoder is constrained by a **causal mask** that prevents it from looking at future tokens, forcing it to generate text sequentially, word-by-word. These models have become the dominant architecture for conversational AI. -3. **Encoder-Decoder (e.g., T5, BART):** Used for sequence-to-sequence tasks like translation and summarization. - -## Part III: The Training Lifecycle of an LLM - -The development of an LLM is a complex, multi-stage process involving massive computational resources, vast data curation efforts, and sophisticated human intervention. - -### 1. 
Data Curation and Tokenization -The first step is gathering and cleaning the training corpus. Modern LLMs are trained on hundreds of terabytes or even petabytes of text, often sourced from: -* **CommonCrawl:** A massive, open-source scrape of the public internet. -* **Filtered Web Text:** Highly curated, higher-quality web pages. -* **Books and Literature:** Digitized libraries. -* **Code Repositories:** Such as GitHub, to instill programming knowledge. -* **Wikipedia:** Structured knowledge bases. - -Data is meticulously filtered to remove low-quality content, boilerplate text, and offensive material. The text is then broken down into **tokens** using a process like **Byte-Pair Encoding (BPE)**. Tokens are the minimal units of meaning the model processes, bridging the gap between human language and numerical vectors. - -### 2. Pre-Training: Self-Supervised Learning -The core of LLM training is the **Pre-Training** phase. The model's hundreds of billions of parameters are initialized, and it is fed the massive, unlabeled dataset. The primary objective is **Next-Token Prediction** (or autoregressive modeling): predicting the next most probable token in a sequence, given all previous tokens. - -* **Objective Function:** The model minimizes the **Loss Function** (often **Cross-Entropy Loss**), which measures the difference between the model's predicted probability distribution over the vocabulary and the actual next token. -* **Optimization:** The model iteratively adjusts its weights using **Backpropagation** and an **Optimizer** (e.g., Adam or its variants) to reduce this loss. - -This phase, costing millions of dollars in GPU time, imbues the model with its fundamental knowledge base, grammar, syntax, and a basic, structural understanding of the world. It is through this pure statistical exercise that "reasoning" begins to emerge. - -### 3. Fine-Tuning and Alignment -A raw pre-trained model is highly knowledgeable but often unhelpful and potentially toxic. 
It will simply continue the statistical pattern of the input, regardless of intent. Alignment is the process of making the model follow instructions and adhere to ethical guidelines. - -#### A. Supervised Fine-Tuning (SFT) -The model is trained on a smaller, high-quality, human-curated dataset of prompts and desired, high-quality responses. This teaches the model a conversational style—how to act as an assistant, answer questions, and follow complex directions. - -#### B. Reinforcement Learning from Human Feedback (RLHF) -RLHF is the key component that created the conversational brilliance of models like ChatGPT. -1. **Response Generation:** For a given prompt, the LLM generates several possible answers. -2. **Human Ranking:** Human labelers rank these responses from best to worst based on helpfulness, accuracy, and safety. -3. **Reward Model Training:** A separate, smaller model called the **Reward Model (RM)** is trained to predict the human preference score for any response. The RM effectively learns "what a good answer looks like." -4. **Policy Optimization:** The main LLM is then fine-tuned using a Reinforcement Learning algorithm (like **Proximal Policy Optimization, PPO**) to maximize the score given by the Reward Model. - -This process explicitly aligns the model's objective function with human values, a crucial step in preparing the model for public deployment. - -## Part IV: Emergent Capabilities and Inherent Limitations - -The path from a neural network to a cognitive tool is marked by phenomena that both inspire awe and caution. - -### The Phenomenon of Emergence -As LLMs crossed certain thresholds—specifically in parameter count (size) and training data volume—researchers observed **Emergent Capabilities**. These are skills that the model was never explicitly trained for, yet they appear spontaneously. 
- -* **In-Context Learning (ICL):** The ability to learn a new task from a few examples provided directly in the prompt, without needing formal fine-tuning (Few-Shot Learning). -* **Chain-of-Thought (CoT) Reasoning:** The ability to decompose complex, multi-step problems into sequential reasoning steps, often unlocked by simply telling the model to "think step-by-step." This dramatically improves performance on arithmetic, common sense, and symbolic logic tasks. -* **Multilingual and Code Proficiency:** Models trained primarily on English and code surprisingly develop high-level proficiency in dozens of other languages and complex programming languages. - -These emergent properties suggest that the simple task of next-token prediction, when scaled sufficiently, leads to a kind of generalized, implicit world model—a probabilistic simulation of human knowledge and reasoning. - -### The Challenge of Hallucination -The most significant and stubborn limitation of LLMs is **Hallucination**—the generation of factually incorrect, nonsensical, or unfaithful content that is nevertheless syntactically plausible. - -The root cause lies in the model's core function: it is a **prediction engine, not a retrieval engine**. It does not access an external database of facts; it samples the most statistically likely sequence of tokens based on its internal, compressed world model. If the highest-probability sequence *looks* like a scientific citation but is entirely fabricated, the model will generate it. - -Mitigation strategies, such as **Retrieval-Augmented Generation (RAG)**, which links the LLM to a real-time, verifiable external knowledge source (like a search index or a company database), are essential for using LLMs in high-stakes, fact-based applications. - -## Part V: The Expanding Ecosystem and Applications - -The LLM ecosystem is diversifying rapidly, moving beyond the simple "chatbot" into powerful, specialized tools. - -### 1. 
Model Scaling and Efficiency -The pursuit of ever-larger models has reached its limits due to cost and data scarcity. The frontier has shifted to efficiency and specialization. -* **Mixture-of-Experts (MoE):** Models like Mixtral use a routing mechanism to activate only a subset of specialized "expert" neural networks for any given query. This allows the model to have a massive total parameter count (high knowledge capacity) while only using a fraction of the computational power (high efficiency). -* **Quantization and Pruning:** Techniques used to reduce the size and computational demands of models, making them executable on smaller devices (e.g., a mobile phone or a personal laptop). - -### 2. Multimodality -The most significant recent breakthrough is the transition from LLMs (Large Language Models) to **LMMs (Large Multimodal Models)**. These models are trained not just on text, but also on images, audio, and video data, allowing them to: -* **Visual Reasoning:** Analyze a complex graph, a photograph, or a technical diagram and answer questions about its content. -* **Audio Processing:** Transcribe, summarize, and understand the context of spoken language directly. -* **Seamless Integration:** Accept a prompt containing text and an image simultaneously (e.g., "Describe this image and write a poem about it"). - -### 3. Industry Applications -LLMs are no longer experimental; they are becoming foundational infrastructure across nearly every industry: -* **Software Engineering:** Automated code generation (e.g., GitHub Copilot), debugging, code translation between languages, and writing documentation. -* **Knowledge Work & Productivity:** Summarizing long documents, drafting complex reports, synthesizing research, and managing data from unstructured sources. -* **Customer Service & Sales:** Highly personalized and efficient conversational AI bots that can handle complex queries beyond simple FAQs. 
-* **Medicine and Law:** Assisting in drafting legal briefs, summarizing medical records, and cross-referencing diagnostic information (always requiring human oversight). -* **Creative Arts:** Generating marketing copy, scriptwriting, music composition (in conjunction with other AI models), and video production assets. - -## Part VI: The Ethical and Societal Labyrinth - -The power of LLMs brings with it a commensurately large set of ethical, social, and economic risks that demand global governance and responsible development. - -### 1. Bias, Fairness, and Amplification -LLMs are fundamentally statistical mirrors of their training data. If the internet contains biases related to gender, race, or geography, the model will ingest, amplify, and operationalize those biases. -* **Stereotype Reinforcement:** A model might associate certain professions (e.g., "engineer") predominantly with one gender, leading to biased outputs in hiring tools. -* **Harmful Generalizations:** Biases can lead to unfair or discriminatory decision-making when the models are deployed in high-stakes areas like loan applications or judicial risk assessment. -Mitigating bias requires meticulous data curation, adversarial testing, and post-processing "guardrails," but complete elimination remains technically elusive. - -### 2. Misinformation and Disinformation -The ability of LLMs to generate highly convincing, fluent text at scale is a threat to information integrity. Malicious actors can use these tools to: -* **Automate Phishing and Scams:** Generate personalized, sophisticated deceptive content. -* **Create Deepfake Text:** Impersonate real individuals or organizations with convincing prose. -* **Fabricate "Fake News" and Propaganda:** Generate massive volumes of highly plausible, factually false content, overwhelming traditional fact-checking mechanisms and accelerating the breakdown of public trust. - -### 3. 
Data Privacy and Security -LLMs pose risks related to data ingestion and leakage: -* **Training Data Memorization:** Models can, in rare cases, memorize and regurgitate personally identifiable information (PII) or copyrighted material from their vast training corpus. -* **Inference Attack (Data Leakage):** If a user provides proprietary or sensitive information as a prompt, that data may be inadvertently used to train future iterations of the model or leak through side channels, raising major security concerns for enterprise adoption. - -### 4. Environmental Impact -The scale of LLMs has a significant environmental footprint. Training a single frontier model requires months of continuous operation on thousands of GPUs, consuming energy equivalent to hundreds of homes for a year. The high computational cost raises questions about the long-term sustainability and equitable access to the technology. - -### 5. Economic Disruption and Labor -LLMs are directly impacting knowledge-based professions, particularly those involving content creation, data synthesis, and routine communication. While optimists argue the technology will mostly automate mundane tasks, freeing humans for higher-level work, policymakers and economists are grappling with the reality of rapid job displacement, income inequality, and the need for massive reskilling initiatives. - -## Part VII: The Frontier—The Path to Agentic AI and AGI - -The current state of the art is fleeting. The research community is pushing toward systems that are more autonomous, capable, and integrated. - -### 1. Agentic AI -The shift from a "Chatbot" to an "Agent" is the immediate future. Current LLMs are **reactive** (Question $\rightarrow$ Answer). An Agentic LLM is **proactive and goal-oriented**. -* **Goal:** The user provides a high-level goal (e.g., "Find the cheapest flight to Tokyo next month and book a hotel near the Shinjuku station."). 
-* **Planning:** The LLM breaks the goal into sub-tasks (Search flights, Compare prices, Search hotels, Check availability, Execute booking actions). -* **Tool Use:** The LLM integrates external tools (search engines, flight APIs, email/calendar APIs) to complete the tasks autonomously, engaging in a trial-and-error loop until the goal is achieved. This transforms the LLM from a generator of text into an executor of complex, multi-step actions. - -### 2. The Multi-Agent Ecosystem -The next stage involves creating swarms of specialized LLM Agents that communicate and collaborate to solve enormous, non-trivial problems. One agent might be a "researcher," another a "coder," and a third an "editor," all collaborating on a project, mimicking a human team. - -### 3. The Pursuit of Artificial General Intelligence (AGI) -The ultimate horizon is Artificial General Intelligence—a machine with the capacity to understand, learn, and apply its intelligence to solve virtually any problem that a human can. - -The debate remains: Is the current path of massive scaling and improved architecture (the **scaling hypothesis**) sufficient to reach AGI, or is some fundamental, non-Transformer-based innovation required? The appearance of emergent properties strongly suggests that the scaling path has not yet exhausted its potential, keeping the AGI goal within the sights of major research labs. - -## Conclusion: The Mirror of Human Intelligence - -Large Language Models are perhaps the most profound technological platform shift since the invention of the Internet. They represent the culmination of 75 years of AI research, transitioning from rule-based systems and statistical models to the deep, parallel processing power of the Transformer architecture. - -LLMs are the definitive statistical compressors of human knowledge, capable of synthesizing our collective digital output with stunning fidelity. 
They have unlocked a new era of computational creativity and efficiency, driving unprecedented change across every sector. - -Yet, this power is a double-edged sword. LLMs are not inherently wise; they are merely proficient at pattern matching. They reflect and amplify human biases, they can deceive with convincing misinformation, and they introduce profound questions about accountability, labor, and the nature of creative work. - -The future of LLMs is not just about making them *smarter*, but making them *safer*, *more efficient*, and more *aligned* with human values. The challenge for the coming decade is not technical—the algorithms and compute will continue to improve—but **one of governance and ethics**. Humanity must learn to responsibly wield this powerful mirror of its own intelligence, ensuring that the cognitive revolution we have started leads to a future of prosperity and equitable access, rather than fragmentation and control. The architecture of intelligence is now in our hands; the path forward depends on the wisdom of its design and deployment. \ No newline at end of file From b15f6caff7655d9602fd4be09afc254179969db5 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 4 Jan 2026 15:11:35 +0000 Subject: [PATCH 098/130] Revert "update builder.update_block_table: rm blk_table and slot_mapping" This reverts commit 61af735cffa3ad0da649e30c3caa82f1b712074b. 
Signed-off-by: huanghaoyan.hhy --- vllm/attention/layers/chunked_local_attention.py | 9 +++++---- vllm/v1/attention/backends/flash_attn.py | 6 ++++-- vllm/v1/attention/backends/mamba_attn.py | 3 ++- vllm/v1/attention/backends/utils.py | 2 ++ vllm/v1/worker/gpu_model_runner.py | 2 ++ 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index f95a114552de..026a56b978fc 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -63,12 +63,13 @@ def update_block_table( self, common_metadata, metadata, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, ): - new_metadata = super().update_block_table(metadata, common_metadata) - new_metadata.block_table = metadata.make_virtual_batches_block_table( - new_metadata.block_table + blk_table = metadata.make_virtual_batches_block_table(blk_table) + return super().update_block_table( + metadata, common_metadata, blk_table, slot_mapping ) - return new_metadata attn_backend = subclass_attention_backend( name_prefix=prefix, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 61a038e3f5c4..c6a43996998c 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -499,10 +499,12 @@ def update_block_table( self, common_metadata: CommonAttentionMetadata, metadata: FlashAttentionMetadata, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, ) -> FlashAttentionMetadata: new_metadata = copy.copy(metadata) - new_metadata.block_table = common_metadata.block_table_tensor - new_metadata.slot_mapping = common_metadata.slot_mapping + new_metadata.block_table = blk_table + new_metadata.slot_mapping = slot_mapping return new_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: diff --git a/vllm/v1/attention/backends/mamba_attn.py 
b/vllm/v1/attention/backends/mamba_attn.py index 6da69ecc9fff..31e9085bfc9c 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -291,9 +291,10 @@ def update_block_table( self, common_metadata: CommonAttentionMetadata, metadata: M, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, ) -> M: new_metadata = copy.copy(metadata) - blk_table = common_metadata.block_table_tensor prefix_caching_all_mode = ( self.vllm_config.cache_config.mamba_cache_mode == "all" ) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 1d094d7de372..bb407a873726 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -395,6 +395,8 @@ def update_block_table( self, common_metadata: CommonAttentionMetadata, metadata: M, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, ) -> M: """ Update the block table for the attention metadata. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 797a1791f174..afaf778440f9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1743,6 +1743,8 @@ def _build_attn_group_metadata( attn_metadata_i = builder.update_block_table( common_attn_metadata, cached_attn_metadata[cache_key], + common_attn_metadata.block_table_tensor, + common_attn_metadata.slot_mapping, ) else: attn_metadata_i = builder.build( From 8a92d977a40a67db20e4ceeae704e61389500c34 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 4 Jan 2026 15:58:33 +0000 Subject: [PATCH 099/130] update mamba_get_block_table_tensor api Signed-off-by: huanghaoyan.hhy --- vllm/v1/attention/backends/gdn_attn.py | 3 ++- vllm/v1/attention/backends/linear_attn.py | 3 ++- vllm/v1/attention/backends/mamba_attn.py | 6 ++++-- vllm/v1/attention/backends/utils.py | 12 ++++++------ 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/vllm/v1/attention/backends/gdn_attn.py 
b/vllm/v1/attention/backends/gdn_attn.py index 04452f1e6cfb..740f282afb52 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -150,7 +150,8 @@ def build( # type: ignore[override] context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True) nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None block_table_tensor = mamba_get_block_table_tensor( - common_attn_metadata, + common_attn_metadata.block_table_tensor, + common_attn_metadata.seq_lens, self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, ) diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index d25ec564f7f0..482d2bb084ba 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -59,7 +59,8 @@ def build( seq_lens = common_attn_metadata.seq_lens state_indices_tensor = mamba_get_block_table_tensor( - common_attn_metadata, + common_attn_metadata.block_table_tensor, + common_attn_metadata.seq_lens, self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, )[:, 0] diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 31e9085bfc9c..8763409c7c20 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -211,7 +211,8 @@ def _compute_common_metadata( else: # Always return just a single block per each request: state_indices_tensor = mamba_get_block_table_tensor( - common_attn_metadata, + common_attn_metadata.block_table_tensor, + common_attn_metadata.seq_lens, self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, )[:, 0] @@ -302,7 +303,8 @@ def update_block_table( blk_table if prefix_caching_all_mode else mamba_get_block_table_tensor( - common_metadata, + blk_table, + common_metadata.seq_lens, self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, )[:, 0] diff --git a/vllm/v1/attention/backends/utils.py 
b/vllm/v1/attention/backends/utils.py index bb407a873726..47fc017fb1f4 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1249,7 +1249,8 @@ def get_dcp_local_seq_lens( def mamba_get_block_table_tensor( - common_attn_metadata: CommonAttentionMetadata, + block_table: torch.Tensor, + seq_lens: torch.Tensor, kv_cache_spec: MambaSpec, mamba_cache_mode: str, ) -> torch.Tensor: @@ -1268,18 +1269,17 @@ def mamba_get_block_table_tensor( 1 + num_speculative_blocks of each request. """ if mamba_cache_mode in ("all", "none"): - return common_attn_metadata.block_table_tensor + return block_table else: assert isinstance(kv_cache_spec, MambaSpec) - block_table_tensor = common_attn_metadata.block_table_tensor # NOTE: For 0-length requests in CUDA graph, use a start_index of 0 # to handle the invalid block table. start_indices = torch.clamp( - (common_attn_metadata.seq_lens - 1) // kv_cache_spec.block_size, + (seq_lens - 1) // kv_cache_spec.block_size, min=0, ) offsets = torch.arange( - 1 + kv_cache_spec.num_speculative_blocks, device=block_table_tensor.device + 1 + kv_cache_spec.num_speculative_blocks, device=block_table.device ) indices_to_gather = start_indices.unsqueeze(1) + offsets - return torch.gather(block_table_tensor, 1, indices_to_gather) + return torch.gather(block_table, 1, indices_to_gather) From 477747ed6ee0414dc7af72e86fb85d16505aa37a Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 4 Jan 2026 16:14:17 +0000 Subject: [PATCH 100/130] fix update_block_table in chunk_local_attn Signed-off-by: huanghaoyan.hhy --- vllm/attention/layers/chunked_local_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 026a56b978fc..641c74be1cf4 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -68,7 +68,7 @@ def update_block_table( ): 
blk_table = metadata.make_virtual_batches_block_table(blk_table) return super().update_block_table( - metadata, common_metadata, blk_table, slot_mapping + common_metadata, metadata, blk_table, slot_mapping ) attn_backend = subclass_attention_backend( From 49b533abadc95a5387949cf75b43e1388616c11f Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 4 Jan 2026 16:15:38 +0000 Subject: [PATCH 101/130] update builder.update_block_table: add seq_lens arg Signed-off-by: huanghaoyan.hhy --- collect_env.py | 857 ++++++++++++++++++ .../layers/chunked_local_attention.py | 4 +- vllm/v1/attention/backends/flash_attn.py | 2 +- vllm/v1/attention/backends/mamba_attn.py | 4 +- vllm/v1/attention/backends/utils.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- 6 files changed, 864 insertions(+), 7 deletions(-) create mode 100644 collect_env.py diff --git a/collect_env.py b/collect_env.py new file mode 100644 index 000000000000..4ca0852e3998 --- /dev/null +++ b/collect_env.py @@ -0,0 +1,857 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# ruff: noqa +# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py + +import datetime +import locale +import os +import subprocess +import sys + +# Unlike the rest of the PyTorch this file must be python2 compliant. 
+# This script outputs relevant system environment info +# Run it with `python collect_env.py` or `python -m torch.utils.collect_env` +from collections import namedtuple + +import regex as re + +from vllm.envs import environment_variables + +try: + import torch + + TORCH_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + TORCH_AVAILABLE = False + +# System Environment Information +SystemEnv = namedtuple( + "SystemEnv", + [ + "torch_version", + "is_debug_build", + "cuda_compiled_version", + "gcc_version", + "clang_version", + "cmake_version", + "os", + "libc_version", + "python_version", + "python_platform", + "is_cuda_available", + "cuda_runtime_version", + "cuda_module_loading", + "nvidia_driver_version", + "nvidia_gpu_models", + "cudnn_version", + "pip_version", # 'pip' or 'pip3' + "pip_packages", + "conda_packages", + "hip_compiled_version", + "hip_runtime_version", + "miopen_runtime_version", + "caching_allocator_config", + "is_xnnpack_available", + "cpu_info", + "rocm_version", # vllm specific field + "vllm_version", # vllm specific field + "vllm_build_flags", # vllm specific field + "gpu_topo", # vllm specific field + "env_vars", + ], +) + +DEFAULT_CONDA_PATTERNS = { + "torch", + "numpy", + "cudatoolkit", + "soumith", + "mkl", + "magma", + "triton", + "optree", + "nccl", + "transformers", + "zmq", + "nvidia", + "pynvml", + "flashinfer-python", +} + +DEFAULT_PIP_PATTERNS = { + "torch", + "numpy", + "mypy", + "flake8", + "triton", + "optree", + "onnx", + "nccl", + "transformers", + "zmq", + "nvidia", + "pynvml", + "flashinfer-python", +} + + +def run(command): + """Return (return-code, stdout, stderr).""" + shell = True if type(command) is str else False + try: + p = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell + ) + raw_output, raw_err = p.communicate() + rc = p.returncode + if get_platform() == "win32": + enc = "oem" + else: + enc = locale.getpreferredencoding() + output = 
raw_output.decode(enc) + if command == "nvidia-smi topo -m": + # don't remove the leading whitespace of `nvidia-smi topo -m` + # because they are meaningful + output = output.rstrip() + else: + output = output.strip() + err = raw_err.decode(enc) + return rc, output, err.strip() + + except FileNotFoundError: + cmd_str = command if isinstance(command, str) else command[0] + return 127, "", f"Command not found: {cmd_str}" + + +def run_and_read_all(run_lambda, command): + """Run command using run_lambda; reads and returns entire output if rc is 0.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out + + +def run_and_parse_first_match(run_lambda, command, regex): + """Run command using run_lambda, returns the first regex match if it exists.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + match = re.search(regex, out) + if match is None: + return None + return match.group(1) + + +def run_and_return_first_line(run_lambda, command): + """Run command using run_lambda and returns first line if output is not empty.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out.split("\n")[0] + + +def get_conda_packages(run_lambda, patterns=None): + if patterns is None: + patterns = DEFAULT_CONDA_PATTERNS + conda = os.environ.get("CONDA_EXE", "conda") + out = run_and_read_all(run_lambda, [conda, "list"]) + if out is None: + return out + + return "\n".join( + line + for line in out.splitlines() + if not line.startswith("#") and any(name in line for name in patterns) + ) + + +def get_gcc_version(run_lambda): + return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)") + + +def get_clang_version(run_lambda): + return run_and_parse_first_match( + run_lambda, "clang --version", r"clang version (.*)" + ) + + +def get_cmake_version(run_lambda): + return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)") + + +def get_nvidia_driver_version(run_lambda): + if get_platform() == "darwin": + 
cmd = "kextstat | grep -i cuda" + return run_and_parse_first_match( + run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]" + ) + smi = get_nvidia_smi() + return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ") + + +def get_gpu_info(run_lambda): + if get_platform() == "darwin" or ( + TORCH_AVAILABLE + and hasattr(torch.version, "hip") + and torch.version.hip is not None + ): + if TORCH_AVAILABLE and torch.cuda.is_available(): + if torch.version.hip is not None: + prop = torch.cuda.get_device_properties(0) + if hasattr(prop, "gcnArchName"): + gcnArch = " ({})".format(prop.gcnArchName) + else: + gcnArch = "NoGCNArchNameOnOldPyTorch" + else: + gcnArch = "" + return torch.cuda.get_device_name(None) + gcnArch + return None + smi = get_nvidia_smi() + uuid_regex = re.compile(r" \(UUID: .+?\)") + rc, out, _ = run_lambda(smi + " -L") + if rc != 0: + return None + # Anonymize GPUs by removing their UUID + return re.sub(uuid_regex, "", out) + + +def get_running_cuda_version(run_lambda): + return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)") + + +def get_cudnn_version(run_lambda): + """Return a list of libcudnn.so; it's hard to tell which one is being used.""" + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%") + where_cmd = os.path.join(system_root, "System32", "where") + cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) + elif get_platform() == "darwin": + # CUDA libraries and drivers can be found in /usr/local/cuda/. See + # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install + # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac + # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
+ cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*" + else: + cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' + rc, out, _ = run_lambda(cudnn_cmd) + # find will return 1 if there are permission errors or if not found + if len(out) == 0 or (rc != 1 and rc != 0): + l = os.environ.get("CUDNN_LIBRARY") + if l is not None and os.path.isfile(l): + return os.path.realpath(l) + return None + files_set = set() + for fn in out.split("\n"): + fn = os.path.realpath(fn) # eliminate symbolic links + if os.path.isfile(fn): + files_set.add(fn) + if not files_set: + return None + # Alphabetize the result because the order is non-deterministic otherwise + files = sorted(files_set) + if len(files) == 1: + return files[0] + result = "\n".join(files) + return "Probably one of the following:\n{}".format(result) + + +def get_nvidia_smi(): + # Note: nvidia-smi is currently available only on Windows and Linux + smi = "nvidia-smi" + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files") + legacy_path = os.path.join( + program_files_root, "NVIDIA Corporation", "NVSMI", smi + ) + new_path = os.path.join(system_root, "System32", smi) + smis = [new_path, legacy_path] + for candidate_smi in smis: + if os.path.exists(candidate_smi): + smi = '"{}"'.format(candidate_smi) + break + return smi + + +def get_rocm_version(run_lambda): + """Returns the ROCm version if available, otherwise 'N/A'.""" + return run_and_parse_first_match( + run_lambda, "hipcc --version", r"HIP version: (\S+)" + ) + + +def get_vllm_version(): + from vllm import __version__, __version_tuple__ + + if __version__ == "dev": + return "N/A (dev)" + version_str = __version_tuple__[-1] + if isinstance(version_str, str) and version_str.startswith("g"): + # it's a dev build + if "." 
in version_str: + # it's a dev build containing local changes + git_sha = version_str.split(".")[0][1:] + date = version_str.split(".")[-1][1:] + return f"{__version__} (git sha: {git_sha}, date: {date})" + else: + # it's a dev build without local changes + git_sha = version_str[1:] # type: ignore + return f"{__version__} (git sha: {git_sha})" + return __version__ + + +def summarize_vllm_build_flags(): + # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. + return "CUDA Archs: {}; ROCm: {}".format( + os.environ.get("TORCH_CUDA_ARCH_LIST", "Not Set"), + "Enabled" if os.environ.get("ROCM_HOME") else "Disabled", + ) + + +def get_gpu_topo(run_lambda): + output = None + + if get_platform() == "linux": + output = run_and_read_all(run_lambda, "nvidia-smi topo -m") + if output is None: + output = run_and_read_all(run_lambda, "rocm-smi --showtopo") + + return output + + +# example outputs of CPU infos +# * linux +# Architecture: x86_64 +# CPU op-mode(s): 32-bit, 64-bit +# Address sizes: 46 bits physical, 48 bits virtual +# Byte Order: Little Endian +# CPU(s): 128 +# On-line CPU(s) list: 0-127 +# Vendor ID: GenuineIntel +# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# CPU family: 6 +# Model: 106 +# Thread(s) per core: 2 +# Core(s) per socket: 32 +# Socket(s): 2 +# Stepping: 6 +# BogoMIPS: 5799.78 +# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr +# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl +# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 +# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand +# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced +# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap +# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl 
xsaveopt xsavec xgetbv1 +# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq +# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities +# Virtualization features: +# Hypervisor vendor: KVM +# Virtualization type: full +# Caches (sum of all): +# L1d: 3 MiB (64 instances) +# L1i: 2 MiB (64 instances) +# L2: 80 MiB (64 instances) +# L3: 108 MiB (2 instances) +# NUMA: +# NUMA node(s): 2 +# NUMA node0 CPU(s): 0-31,64-95 +# NUMA node1 CPU(s): 32-63,96-127 +# Vulnerabilities: +# Itlb multihit: Not affected +# L1tf: Not affected +# Mds: Not affected +# Meltdown: Not affected +# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown +# Retbleed: Not affected +# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp +# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization +# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence +# Srbds: Not affected +# Tsx async abort: Not affected +# * win32 +# Architecture=9 +# CurrentClockSpeed=2900 +# DeviceID=CPU0 +# Family=179 +# L2CacheSize=40960 +# L2CacheSpeed= +# Manufacturer=GenuineIntel +# MaxClockSpeed=2900 +# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# ProcessorType=3 +# Revision=27142 +# +# Architecture=9 +# CurrentClockSpeed=2900 +# DeviceID=CPU1 +# Family=179 +# L2CacheSize=40960 +# L2CacheSpeed= +# Manufacturer=GenuineIntel +# MaxClockSpeed=2900 +# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# ProcessorType=3 +# Revision=27142 + + +def get_cpu_info(run_lambda): + rc, out, err = 0, "", "" + if get_platform() == "linux": + rc, out, err = run_lambda("lscpu") + elif get_platform() == "win32": + rc, out, err = run_lambda( + "wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ + CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE" + ) + elif get_platform() == 
"darwin": + rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") + cpu_info = "None" + if rc == 0: + cpu_info = out + else: + cpu_info = err + return cpu_info + + +def get_platform(): + if sys.platform.startswith("linux"): + return "linux" + elif sys.platform.startswith("win32"): + return "win32" + elif sys.platform.startswith("cygwin"): + return "cygwin" + elif sys.platform.startswith("darwin"): + return "darwin" + else: + return sys.platform + + +def get_mac_version(run_lambda): + return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)") + + +def get_windows_version(run_lambda): + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + wmic_cmd = os.path.join(system_root, "System32", "Wbem", "wmic") + findstr_cmd = os.path.join(system_root, "System32", "findstr") + return run_and_read_all( + run_lambda, "{} os get Caption | {} /v Caption".format(wmic_cmd, findstr_cmd) + ) + + +def get_lsb_version(run_lambda): + return run_and_parse_first_match( + run_lambda, "lsb_release -a", r"Description:\t(.*)" + ) + + +def check_release_file(run_lambda): + return run_and_parse_first_match( + run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"' + ) + + +def get_os(run_lambda): + from platform import machine + + platform = get_platform() + + if platform == "win32" or platform == "cygwin": + return get_windows_version(run_lambda) + + if platform == "darwin": + version = get_mac_version(run_lambda) + if version is None: + return None + return "macOS {} ({})".format(version, machine()) + + if platform == "linux": + # Ubuntu/Debian based + desc = get_lsb_version(run_lambda) + if desc is not None: + return "{} ({})".format(desc, machine()) + + # Try reading /etc/*-release + desc = check_release_file(run_lambda) + if desc is not None: + return "{} ({})".format(desc, machine()) + + return "{} ({})".format(platform, machine()) + + # Unknown platform + return platform + + +def get_python_platform(): + import platform + + return 
platform.platform() + + +def get_libc_version(): + import platform + + if get_platform() != "linux": + return "N/A" + return "-".join(platform.libc_ver()) + + +def is_uv_venv(): + if os.environ.get("UV"): + return True + pyvenv_cfg_path = os.path.join(sys.prefix, "pyvenv.cfg") + if os.path.exists(pyvenv_cfg_path): + with open(pyvenv_cfg_path, "r") as f: + return any(line.startswith("uv = ") for line in f) + return False + + +def get_pip_packages(run_lambda, patterns=None): + """Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages.""" + if patterns is None: + patterns = DEFAULT_PIP_PATTERNS + + def run_with_pip(): + try: + import importlib.util + + pip_spec = importlib.util.find_spec("pip") + pip_available = pip_spec is not None + except ImportError: + pip_available = False + + if pip_available: + cmd = [sys.executable, "-mpip", "list", "--format=freeze"] + elif is_uv_venv(): + print("uv is set") + cmd = ["uv", "pip", "list", "--format=freeze"] + else: + raise RuntimeError( + "Could not collect pip list output (pip or uv module not available)" + ) + + out = run_and_read_all(run_lambda, cmd) + return "\n".join( + line for line in out.splitlines() if any(name in line for name in patterns) + ) + + pip_version = "pip3" if sys.version[0] == "3" else "pip" + out = run_with_pip() + return pip_version, out + + +def get_cachingallocator_config(): + ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") + return ca_config + + +def get_cuda_module_loading_config(): + if TORCH_AVAILABLE and torch.cuda.is_available(): + torch.cuda.init() + config = os.environ.get("CUDA_MODULE_LOADING", "") + return config + else: + return "N/A" + + +def is_xnnpack_available(): + if TORCH_AVAILABLE: + import torch.backends.xnnpack + + return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] + else: + return "N/A" + + +def get_env_vars(): + env_vars = "" + secret_terms = ("secret", "token", "api", "access", "password") + report_prefix = ( + 
"TORCH", + "NCCL", + "PYTORCH", + "CUDA", + "CUBLAS", + "CUDNN", + "OMP_", + "MKL_", + "NVIDIA", + ) + for k, v in os.environ.items(): + if any(term in k.lower() for term in secret_terms): + continue + if k in environment_variables: + env_vars = env_vars + "{}={}".format(k, v) + "\n" + if k.startswith(report_prefix): + env_vars = env_vars + "{}={}".format(k, v) + "\n" + + return env_vars + + +def get_env_info(): + run_lambda = run + pip_version, pip_list_output = get_pip_packages(run_lambda) + + if TORCH_AVAILABLE: + version_str = torch.__version__ + debug_mode_str = str(torch.version.debug) + cuda_available_str = str(torch.cuda.is_available()) + cuda_version_str = torch.version.cuda + if ( + not hasattr(torch.version, "hip") or torch.version.hip is None + ): # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" + else: # HIP version + + def get_version_or_na(cfg, prefix): + _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] + return _lst[0] if _lst else "N/A" + + cfg = torch._C._show_config().split("\n") + hip_runtime_version = get_version_or_na(cfg, "HIP Runtime") + miopen_runtime_version = get_version_or_na(cfg, "MIOpen") + cuda_version_str = "N/A" + hip_compiled_version = torch.version.hip + else: + version_str = debug_mode_str = cuda_available_str = cuda_version_str = "N/A" + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" + + sys_version = sys.version.replace("\n", " ") + + conda_packages = get_conda_packages(run_lambda) + + rocm_version = get_rocm_version(run_lambda) + vllm_version = get_vllm_version() + vllm_build_flags = summarize_vllm_build_flags() + gpu_topo = get_gpu_topo(run_lambda) + + return SystemEnv( + torch_version=version_str, + is_debug_build=debug_mode_str, + python_version="{} ({}-bit runtime)".format( + sys_version, sys.maxsize.bit_length() + 1 + ), + python_platform=get_python_platform(), + is_cuda_available=cuda_available_str, + 
cuda_compiled_version=cuda_version_str, + cuda_runtime_version=get_running_cuda_version(run_lambda), + cuda_module_loading=get_cuda_module_loading_config(), + nvidia_gpu_models=get_gpu_info(run_lambda), + nvidia_driver_version=get_nvidia_driver_version(run_lambda), + cudnn_version=get_cudnn_version(run_lambda), + hip_compiled_version=hip_compiled_version, + hip_runtime_version=hip_runtime_version, + miopen_runtime_version=miopen_runtime_version, + pip_version=pip_version, + pip_packages=pip_list_output, + conda_packages=conda_packages, + os=get_os(run_lambda), + libc_version=get_libc_version(), + gcc_version=get_gcc_version(run_lambda), + clang_version=get_clang_version(run_lambda), + cmake_version=get_cmake_version(run_lambda), + caching_allocator_config=get_cachingallocator_config(), + is_xnnpack_available=is_xnnpack_available(), + cpu_info=get_cpu_info(run_lambda), + rocm_version=rocm_version, + vllm_version=vllm_version, + vllm_build_flags=vllm_build_flags, + gpu_topo=gpu_topo, + env_vars=get_env_vars(), + ) + + +env_info_fmt = """ +============================== + System Info +============================== +OS : {os} +GCC version : {gcc_version} +Clang version : {clang_version} +CMake version : {cmake_version} +Libc version : {libc_version} + +============================== + PyTorch Info +============================== +PyTorch version : {torch_version} +Is debug build : {is_debug_build} +CUDA used to build PyTorch : {cuda_compiled_version} +ROCM used to build PyTorch : {hip_compiled_version} + +============================== + Python Environment +============================== +Python version : {python_version} +Python platform : {python_platform} + +============================== + CUDA / GPU Info +============================== +Is CUDA available : {is_cuda_available} +CUDA runtime version : {cuda_runtime_version} +CUDA_MODULE_LOADING set to : {cuda_module_loading} +GPU models and configuration : {nvidia_gpu_models} +Nvidia driver version : 
{nvidia_driver_version} +cuDNN version : {cudnn_version} +HIP runtime version : {hip_runtime_version} +MIOpen runtime version : {miopen_runtime_version} +Is XNNPACK available : {is_xnnpack_available} + +============================== + CPU Info +============================== +{cpu_info} + +============================== +Versions of relevant libraries +============================== +{pip_packages} +{conda_packages} +""".strip() + +# both the above code and the following code use `strip()` to +# remove leading/trailing whitespaces, so we need to add a newline +# in between to separate the two sections +env_info_fmt += "\n\n" + +env_info_fmt += """ +============================== + vLLM Info +============================== +ROCM Version : {rocm_version} +vLLM Version : {vllm_version} +vLLM Build Flags: + {vllm_build_flags} +GPU Topology: + {gpu_topo} + +============================== + Environment Variables +============================== +{env_vars} +""".strip() + + +def pretty_str(envinfo): + def replace_nones(dct, replacement="Could not collect"): + for key in dct.keys(): + if dct[key] is not None: + continue + dct[key] = replacement + return dct + + def replace_bools(dct, true="Yes", false="No"): + for key in dct.keys(): + if dct[key] is True: + dct[key] = true + elif dct[key] is False: + dct[key] = false + return dct + + def prepend(text, tag="[prepend]"): + lines = text.split("\n") + updated_lines = [tag + line for line in lines] + return "\n".join(updated_lines) + + def replace_if_empty(text, replacement="No relevant packages"): + if text is not None and len(text) == 0: + return replacement + return text + + def maybe_start_on_next_line(string): + # If `string` is multiline, prepend a \n to it. 
+ if string is not None and len(string.split("\n")) > 1: + return "\n{}\n".format(string) + return string + + mutable_dict = envinfo._asdict() + + # If nvidia_gpu_models is multiline, start on the next line + mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line( + envinfo.nvidia_gpu_models + ) + + # If the machine doesn't have CUDA, report some fields as 'No CUDA' + dynamic_cuda_fields = [ + "cuda_runtime_version", + "nvidia_gpu_models", + "nvidia_driver_version", + ] + all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"] + all_dynamic_cuda_fields_missing = all( + mutable_dict[field] is None for field in dynamic_cuda_fields + ) + if ( + TORCH_AVAILABLE + and not torch.cuda.is_available() + and all_dynamic_cuda_fields_missing + ): + for field in all_cuda_fields: + mutable_dict[field] = "No CUDA" + if envinfo.cuda_compiled_version is None: + mutable_dict["cuda_compiled_version"] = "None" + + # Replace True with Yes, False with No + mutable_dict = replace_bools(mutable_dict) + + # Replace all None objects with 'Could not collect' + mutable_dict = replace_nones(mutable_dict) + + # If either of these are '', replace with 'No relevant packages' + mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"]) + mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"]) + + # Tag conda and pip packages with a prefix + # If they were previously None, they'll show up as ie '[conda] Could not collect' + if mutable_dict["pip_packages"]: + mutable_dict["pip_packages"] = prepend( + mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version) + ) + if mutable_dict["conda_packages"]: + mutable_dict["conda_packages"] = prepend( + mutable_dict["conda_packages"], "[conda] " + ) + mutable_dict["cpu_info"] = envinfo.cpu_info + return env_info_fmt.format(**mutable_dict) + + +def get_pretty_env_info(): + return pretty_str(get_env_info()) + + +def main(): + print("Collecting environment information...") + output = 
get_pretty_env_info() + print(output) + + if ( + TORCH_AVAILABLE + and hasattr(torch, "utils") + and hasattr(torch.utils, "_crash_handler") + ): + minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR + if sys.platform == "linux" and os.path.exists(minidump_dir): + dumps = [ + os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir) + ] + latest = max(dumps, key=os.path.getctime) + ctime = os.path.getctime(latest) + creation_time = datetime.datetime.fromtimestamp(ctime).strftime( + "%Y-%m-%d %H:%M:%S" + ) + msg = ( + "\n*** Detected a minidump at {} created on {}, ".format( + latest, creation_time + ) + + "if this is related to your bug please include it when you file a report ***" + ) + print(msg, file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 641c74be1cf4..569a264d7f3c 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -61,14 +61,14 @@ def build( def update_block_table( self, - common_metadata, metadata, + seq_lens: torch.Tensor, blk_table: torch.Tensor, slot_mapping: torch.Tensor, ): blk_table = metadata.make_virtual_batches_block_table(blk_table) return super().update_block_table( - common_metadata, metadata, blk_table, slot_mapping + metadata, seq_lens, blk_table, slot_mapping ) attn_backend = subclass_attention_backend( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index c6a43996998c..8066ddcfaeda 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -497,8 +497,8 @@ def schedule( def update_block_table( self, - common_metadata: CommonAttentionMetadata, metadata: FlashAttentionMetadata, + seq_lens: torch.Tensor, blk_table: torch.Tensor, slot_mapping: torch.Tensor, ) -> FlashAttentionMetadata: diff --git 
a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 8763409c7c20..935d6a7279d8 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -290,8 +290,8 @@ def _compute_common_metadata( def update_block_table( self, - common_metadata: CommonAttentionMetadata, metadata: M, + seq_lens: torch.Tensor, blk_table: torch.Tensor, slot_mapping: torch.Tensor, ) -> M: @@ -304,7 +304,7 @@ def update_block_table( if prefix_caching_all_mode else mamba_get_block_table_tensor( blk_table, - common_metadata.seq_lens, + seq_lens, self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, )[:, 0] diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 47fc017fb1f4..3c5e206d251e 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -393,8 +393,8 @@ def build( def update_block_table( self, - common_metadata: CommonAttentionMetadata, metadata: M, + seq_lens: torch.Tensor, blk_table: torch.Tensor, slot_mapping: torch.Tensor, ) -> M: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index afaf778440f9..e9a5fdeba11a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1741,8 +1741,8 @@ def _build_attn_group_metadata( and builder.supports_update_block_table ): attn_metadata_i = builder.update_block_table( - common_attn_metadata, cached_attn_metadata[cache_key], + common_attn_metadata.seq_lens, common_attn_metadata.block_table_tensor, common_attn_metadata.slot_mapping, ) From 1c667d31a1b357d97cb484304d85b9f106daedf7 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 4 Jan 2026 17:47:36 +0000 Subject: [PATCH 102/130] update builder.update_block_table: rm seq_lens arg Signed-off-by: huanghaoyan.hhy --- vllm/attention/layers/chunked_local_attention.py | 5 +---- vllm/v1/attention/backends/flash_attn.py | 1 - 
vllm/v1/attention/backends/mamba_attn.py | 10 +++++++--- vllm/v1/attention/backends/utils.py | 1 - vllm/v1/worker/gpu_model_runner.py | 1 - 5 files changed, 8 insertions(+), 10 deletions(-) diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 569a264d7f3c..43b98b2582c1 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -62,14 +62,11 @@ def build( def update_block_table( self, metadata, - seq_lens: torch.Tensor, blk_table: torch.Tensor, slot_mapping: torch.Tensor, ): blk_table = metadata.make_virtual_batches_block_table(blk_table) - return super().update_block_table( - metadata, seq_lens, blk_table, slot_mapping - ) + return super().update_block_table(metadata, blk_table, slot_mapping) attn_backend = subclass_attention_backend( name_prefix=prefix, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 8066ddcfaeda..3445e998d637 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -498,7 +498,6 @@ def schedule( def update_block_table( self, metadata: FlashAttentionMetadata, - seq_lens: torch.Tensor, blk_table: torch.Tensor, slot_mapping: torch.Tensor, ) -> FlashAttentionMetadata: diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 935d6a7279d8..3cdfc6aa1b8e 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -40,11 +40,15 @@ class BaseMambaAttentionMetadata: state_indices_tensor: torch.Tensor - # The following tensors are only used for prefix caching and are None if disabled + # The following tensors are only used for prefix caching with all mode and + # are None if disabled block_idx_last_scheduled_token: torch.Tensor | None block_idx_first_scheduled_token_p: torch.Tensor | None block_idx_last_computed_token: torch.Tensor | None + # 
The following tensor is only used for prefix caching with align mode + seq_lens: torch.Tensor + # The following attributes are for triton implementation of causal_conv1d nums_dict: dict | None = None batch_ptr: torch.Tensor | None = None @@ -283,6 +287,7 @@ def _compute_common_metadata( block_idx_last_computed_token=block_idx_last_computed_token, num_computed_tokens_p=num_computed_tokens_p, num_reqs=num_reqs, + seq_lens=common_attn_metadata.seq_lens, nums_dict=nums_dict, batch_ptr=batch_ptr, token_chunk_offset_ptr=token_chunk_offset_ptr, @@ -291,7 +296,6 @@ def _compute_common_metadata( def update_block_table( self, metadata: M, - seq_lens: torch.Tensor, blk_table: torch.Tensor, slot_mapping: torch.Tensor, ) -> M: @@ -304,7 +308,7 @@ def update_block_table( if prefix_caching_all_mode else mamba_get_block_table_tensor( blk_table, - seq_lens, + metadata.seq_lens, self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, )[:, 0] diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 3c5e206d251e..aab615f32c23 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -394,7 +394,6 @@ def build( def update_block_table( self, metadata: M, - seq_lens: torch.Tensor, blk_table: torch.Tensor, slot_mapping: torch.Tensor, ) -> M: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e9a5fdeba11a..f4aa96eafc64 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1742,7 +1742,6 @@ def _build_attn_group_metadata( ): attn_metadata_i = builder.update_block_table( cached_attn_metadata[cache_key], - common_attn_metadata.seq_lens, common_attn_metadata.block_table_tensor, common_attn_metadata.slot_mapping, ) From 9cdfaa7e6154ff15f617abfc9907c155dafde90a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 4 Jan 2026 12:43:10 -0800 Subject: [PATCH 103/130] revert change Signed-off-by: Chen Zhang --- 
vllm/attention/layers/chunked_local_attention.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 43b98b2582c1..7e3794d40833 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -60,10 +60,7 @@ def build( return metadata def update_block_table( - self, - metadata, - blk_table: torch.Tensor, - slot_mapping: torch.Tensor, + self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor ): blk_table = metadata.make_virtual_batches_block_table(blk_table) return super().update_block_table(metadata, blk_table, slot_mapping) From 6e51e1cd2785ed26818f495264df394d516cb35b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 4 Jan 2026 12:52:17 -0800 Subject: [PATCH 104/130] use mamba_get_block_table_tensor in basic mamba Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/mamba_attn.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 3cdfc6aa1b8e..c9f5339fdb35 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -300,19 +300,16 @@ def update_block_table( slot_mapping: torch.Tensor, ) -> M: new_metadata = copy.copy(metadata) - prefix_caching_all_mode = ( - self.vllm_config.cache_config.mamba_cache_mode == "all" - ) - state_indices_t = ( - blk_table - if prefix_caching_all_mode - else mamba_get_block_table_tensor( - blk_table, - metadata.seq_lens, - self.kv_cache_spec, - self.vllm_config.cache_config.mamba_cache_mode, - )[:, 0] + state_indices_t = mamba_get_block_table_tensor( + blk_table, + metadata.seq_lens, + self.kv_cache_spec, + self.vllm_config.cache_config.mamba_cache_mode, ) + if self.vllm_config.cache_config.mamba_cache_mode in ("all", "none"): + # Only needs the block that saves the 
running state + state_indices_t = state_indices_t[:, 0] + num_reqs = blk_table.shape[0] # For CUDA graphs, copy to persistent buffer From 85fc6d68360529860971bd43ae8ff16b18ed9b4b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 4 Jan 2026 12:55:34 -0800 Subject: [PATCH 105/130] remove unrelated changes Signed-off-by: Chen Zhang --- collect_env.py | 857 ------------------------------------------------- 1 file changed, 857 deletions(-) delete mode 100644 collect_env.py diff --git a/collect_env.py b/collect_env.py deleted file mode 100644 index 4ca0852e3998..000000000000 --- a/collect_env.py +++ /dev/null @@ -1,857 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# ruff: noqa -# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py - -import datetime -import locale -import os -import subprocess -import sys - -# Unlike the rest of the PyTorch this file must be python2 compliant. -# This script outputs relevant system environment info -# Run it with `python collect_env.py` or `python -m torch.utils.collect_env` -from collections import namedtuple - -import regex as re - -from vllm.envs import environment_variables - -try: - import torch - - TORCH_AVAILABLE = True -except (ImportError, NameError, AttributeError, OSError): - TORCH_AVAILABLE = False - -# System Environment Information -SystemEnv = namedtuple( - "SystemEnv", - [ - "torch_version", - "is_debug_build", - "cuda_compiled_version", - "gcc_version", - "clang_version", - "cmake_version", - "os", - "libc_version", - "python_version", - "python_platform", - "is_cuda_available", - "cuda_runtime_version", - "cuda_module_loading", - "nvidia_driver_version", - "nvidia_gpu_models", - "cudnn_version", - "pip_version", # 'pip' or 'pip3' - "pip_packages", - "conda_packages", - "hip_compiled_version", - "hip_runtime_version", - "miopen_runtime_version", - "caching_allocator_config", - "is_xnnpack_available", - 
"cpu_info", - "rocm_version", # vllm specific field - "vllm_version", # vllm specific field - "vllm_build_flags", # vllm specific field - "gpu_topo", # vllm specific field - "env_vars", - ], -) - -DEFAULT_CONDA_PATTERNS = { - "torch", - "numpy", - "cudatoolkit", - "soumith", - "mkl", - "magma", - "triton", - "optree", - "nccl", - "transformers", - "zmq", - "nvidia", - "pynvml", - "flashinfer-python", -} - -DEFAULT_PIP_PATTERNS = { - "torch", - "numpy", - "mypy", - "flake8", - "triton", - "optree", - "onnx", - "nccl", - "transformers", - "zmq", - "nvidia", - "pynvml", - "flashinfer-python", -} - - -def run(command): - """Return (return-code, stdout, stderr).""" - shell = True if type(command) is str else False - try: - p = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell - ) - raw_output, raw_err = p.communicate() - rc = p.returncode - if get_platform() == "win32": - enc = "oem" - else: - enc = locale.getpreferredencoding() - output = raw_output.decode(enc) - if command == "nvidia-smi topo -m": - # don't remove the leading whitespace of `nvidia-smi topo -m` - # because they are meaningful - output = output.rstrip() - else: - output = output.strip() - err = raw_err.decode(enc) - return rc, output, err.strip() - - except FileNotFoundError: - cmd_str = command if isinstance(command, str) else command[0] - return 127, "", f"Command not found: {cmd_str}" - - -def run_and_read_all(run_lambda, command): - """Run command using run_lambda; reads and returns entire output if rc is 0.""" - rc, out, _ = run_lambda(command) - if rc != 0: - return None - return out - - -def run_and_parse_first_match(run_lambda, command, regex): - """Run command using run_lambda, returns the first regex match if it exists.""" - rc, out, _ = run_lambda(command) - if rc != 0: - return None - match = re.search(regex, out) - if match is None: - return None - return match.group(1) - - -def run_and_return_first_line(run_lambda, command): - """Run command using 
run_lambda and returns first line if output is not empty.""" - rc, out, _ = run_lambda(command) - if rc != 0: - return None - return out.split("\n")[0] - - -def get_conda_packages(run_lambda, patterns=None): - if patterns is None: - patterns = DEFAULT_CONDA_PATTERNS - conda = os.environ.get("CONDA_EXE", "conda") - out = run_and_read_all(run_lambda, [conda, "list"]) - if out is None: - return out - - return "\n".join( - line - for line in out.splitlines() - if not line.startswith("#") and any(name in line for name in patterns) - ) - - -def get_gcc_version(run_lambda): - return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)") - - -def get_clang_version(run_lambda): - return run_and_parse_first_match( - run_lambda, "clang --version", r"clang version (.*)" - ) - - -def get_cmake_version(run_lambda): - return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)") - - -def get_nvidia_driver_version(run_lambda): - if get_platform() == "darwin": - cmd = "kextstat | grep -i cuda" - return run_and_parse_first_match( - run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]" - ) - smi = get_nvidia_smi() - return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) 
") - - -def get_gpu_info(run_lambda): - if get_platform() == "darwin" or ( - TORCH_AVAILABLE - and hasattr(torch.version, "hip") - and torch.version.hip is not None - ): - if TORCH_AVAILABLE and torch.cuda.is_available(): - if torch.version.hip is not None: - prop = torch.cuda.get_device_properties(0) - if hasattr(prop, "gcnArchName"): - gcnArch = " ({})".format(prop.gcnArchName) - else: - gcnArch = "NoGCNArchNameOnOldPyTorch" - else: - gcnArch = "" - return torch.cuda.get_device_name(None) + gcnArch - return None - smi = get_nvidia_smi() - uuid_regex = re.compile(r" \(UUID: .+?\)") - rc, out, _ = run_lambda(smi + " -L") - if rc != 0: - return None - # Anonymize GPUs by removing their UUID - return re.sub(uuid_regex, "", out) - - -def get_running_cuda_version(run_lambda): - return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)") - - -def get_cudnn_version(run_lambda): - """Return a list of libcudnn.so; it's hard to tell which one is being used.""" - if get_platform() == "win32": - system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") - cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%") - where_cmd = os.path.join(system_root, "System32", "where") - cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) - elif get_platform() == "darwin": - # CUDA libraries and drivers can be found in /usr/local/cuda/. See - # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install - # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac - # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
- cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*" - else: - cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' - rc, out, _ = run_lambda(cudnn_cmd) - # find will return 1 if there are permission errors or if not found - if len(out) == 0 or (rc != 1 and rc != 0): - l = os.environ.get("CUDNN_LIBRARY") - if l is not None and os.path.isfile(l): - return os.path.realpath(l) - return None - files_set = set() - for fn in out.split("\n"): - fn = os.path.realpath(fn) # eliminate symbolic links - if os.path.isfile(fn): - files_set.add(fn) - if not files_set: - return None - # Alphabetize the result because the order is non-deterministic otherwise - files = sorted(files_set) - if len(files) == 1: - return files[0] - result = "\n".join(files) - return "Probably one of the following:\n{}".format(result) - - -def get_nvidia_smi(): - # Note: nvidia-smi is currently available only on Windows and Linux - smi = "nvidia-smi" - if get_platform() == "win32": - system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") - program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files") - legacy_path = os.path.join( - program_files_root, "NVIDIA Corporation", "NVSMI", smi - ) - new_path = os.path.join(system_root, "System32", smi) - smis = [new_path, legacy_path] - for candidate_smi in smis: - if os.path.exists(candidate_smi): - smi = '"{}"'.format(candidate_smi) - break - return smi - - -def get_rocm_version(run_lambda): - """Returns the ROCm version if available, otherwise 'N/A'.""" - return run_and_parse_first_match( - run_lambda, "hipcc --version", r"HIP version: (\S+)" - ) - - -def get_vllm_version(): - from vllm import __version__, __version_tuple__ - - if __version__ == "dev": - return "N/A (dev)" - version_str = __version_tuple__[-1] - if isinstance(version_str, str) and version_str.startswith("g"): - # it's a dev build - if "." 
in version_str: - # it's a dev build containing local changes - git_sha = version_str.split(".")[0][1:] - date = version_str.split(".")[-1][1:] - return f"{__version__} (git sha: {git_sha}, date: {date})" - else: - # it's a dev build without local changes - git_sha = version_str[1:] # type: ignore - return f"{__version__} (git sha: {git_sha})" - return __version__ - - -def summarize_vllm_build_flags(): - # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. - return "CUDA Archs: {}; ROCm: {}".format( - os.environ.get("TORCH_CUDA_ARCH_LIST", "Not Set"), - "Enabled" if os.environ.get("ROCM_HOME") else "Disabled", - ) - - -def get_gpu_topo(run_lambda): - output = None - - if get_platform() == "linux": - output = run_and_read_all(run_lambda, "nvidia-smi topo -m") - if output is None: - output = run_and_read_all(run_lambda, "rocm-smi --showtopo") - - return output - - -# example outputs of CPU infos -# * linux -# Architecture: x86_64 -# CPU op-mode(s): 32-bit, 64-bit -# Address sizes: 46 bits physical, 48 bits virtual -# Byte Order: Little Endian -# CPU(s): 128 -# On-line CPU(s) list: 0-127 -# Vendor ID: GenuineIntel -# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz -# CPU family: 6 -# Model: 106 -# Thread(s) per core: 2 -# Core(s) per socket: 32 -# Socket(s): 2 -# Stepping: 6 -# BogoMIPS: 5799.78 -# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr -# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl -# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 -# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand -# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced -# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap -# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl 
xsaveopt xsavec xgetbv1 -# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq -# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities -# Virtualization features: -# Hypervisor vendor: KVM -# Virtualization type: full -# Caches (sum of all): -# L1d: 3 MiB (64 instances) -# L1i: 2 MiB (64 instances) -# L2: 80 MiB (64 instances) -# L3: 108 MiB (2 instances) -# NUMA: -# NUMA node(s): 2 -# NUMA node0 CPU(s): 0-31,64-95 -# NUMA node1 CPU(s): 32-63,96-127 -# Vulnerabilities: -# Itlb multihit: Not affected -# L1tf: Not affected -# Mds: Not affected -# Meltdown: Not affected -# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown -# Retbleed: Not affected -# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp -# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization -# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence -# Srbds: Not affected -# Tsx async abort: Not affected -# * win32 -# Architecture=9 -# CurrentClockSpeed=2900 -# DeviceID=CPU0 -# Family=179 -# L2CacheSize=40960 -# L2CacheSpeed= -# Manufacturer=GenuineIntel -# MaxClockSpeed=2900 -# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz -# ProcessorType=3 -# Revision=27142 -# -# Architecture=9 -# CurrentClockSpeed=2900 -# DeviceID=CPU1 -# Family=179 -# L2CacheSize=40960 -# L2CacheSpeed= -# Manufacturer=GenuineIntel -# MaxClockSpeed=2900 -# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz -# ProcessorType=3 -# Revision=27142 - - -def get_cpu_info(run_lambda): - rc, out, err = 0, "", "" - if get_platform() == "linux": - rc, out, err = run_lambda("lscpu") - elif get_platform() == "win32": - rc, out, err = run_lambda( - "wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ - CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE" - ) - elif get_platform() == 
"darwin": - rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") - cpu_info = "None" - if rc == 0: - cpu_info = out - else: - cpu_info = err - return cpu_info - - -def get_platform(): - if sys.platform.startswith("linux"): - return "linux" - elif sys.platform.startswith("win32"): - return "win32" - elif sys.platform.startswith("cygwin"): - return "cygwin" - elif sys.platform.startswith("darwin"): - return "darwin" - else: - return sys.platform - - -def get_mac_version(run_lambda): - return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)") - - -def get_windows_version(run_lambda): - system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") - wmic_cmd = os.path.join(system_root, "System32", "Wbem", "wmic") - findstr_cmd = os.path.join(system_root, "System32", "findstr") - return run_and_read_all( - run_lambda, "{} os get Caption | {} /v Caption".format(wmic_cmd, findstr_cmd) - ) - - -def get_lsb_version(run_lambda): - return run_and_parse_first_match( - run_lambda, "lsb_release -a", r"Description:\t(.*)" - ) - - -def check_release_file(run_lambda): - return run_and_parse_first_match( - run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"' - ) - - -def get_os(run_lambda): - from platform import machine - - platform = get_platform() - - if platform == "win32" or platform == "cygwin": - return get_windows_version(run_lambda) - - if platform == "darwin": - version = get_mac_version(run_lambda) - if version is None: - return None - return "macOS {} ({})".format(version, machine()) - - if platform == "linux": - # Ubuntu/Debian based - desc = get_lsb_version(run_lambda) - if desc is not None: - return "{} ({})".format(desc, machine()) - - # Try reading /etc/*-release - desc = check_release_file(run_lambda) - if desc is not None: - return "{} ({})".format(desc, machine()) - - return "{} ({})".format(platform, machine()) - - # Unknown platform - return platform - - -def get_python_platform(): - import platform - - return 
platform.platform() - - -def get_libc_version(): - import platform - - if get_platform() != "linux": - return "N/A" - return "-".join(platform.libc_ver()) - - -def is_uv_venv(): - if os.environ.get("UV"): - return True - pyvenv_cfg_path = os.path.join(sys.prefix, "pyvenv.cfg") - if os.path.exists(pyvenv_cfg_path): - with open(pyvenv_cfg_path, "r") as f: - return any(line.startswith("uv = ") for line in f) - return False - - -def get_pip_packages(run_lambda, patterns=None): - """Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages.""" - if patterns is None: - patterns = DEFAULT_PIP_PATTERNS - - def run_with_pip(): - try: - import importlib.util - - pip_spec = importlib.util.find_spec("pip") - pip_available = pip_spec is not None - except ImportError: - pip_available = False - - if pip_available: - cmd = [sys.executable, "-mpip", "list", "--format=freeze"] - elif is_uv_venv(): - print("uv is set") - cmd = ["uv", "pip", "list", "--format=freeze"] - else: - raise RuntimeError( - "Could not collect pip list output (pip or uv module not available)" - ) - - out = run_and_read_all(run_lambda, cmd) - return "\n".join( - line for line in out.splitlines() if any(name in line for name in patterns) - ) - - pip_version = "pip3" if sys.version[0] == "3" else "pip" - out = run_with_pip() - return pip_version, out - - -def get_cachingallocator_config(): - ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") - return ca_config - - -def get_cuda_module_loading_config(): - if TORCH_AVAILABLE and torch.cuda.is_available(): - torch.cuda.init() - config = os.environ.get("CUDA_MODULE_LOADING", "") - return config - else: - return "N/A" - - -def is_xnnpack_available(): - if TORCH_AVAILABLE: - import torch.backends.xnnpack - - return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] - else: - return "N/A" - - -def get_env_vars(): - env_vars = "" - secret_terms = ("secret", "token", "api", "access", "password") - report_prefix = ( - 
"TORCH", - "NCCL", - "PYTORCH", - "CUDA", - "CUBLAS", - "CUDNN", - "OMP_", - "MKL_", - "NVIDIA", - ) - for k, v in os.environ.items(): - if any(term in k.lower() for term in secret_terms): - continue - if k in environment_variables: - env_vars = env_vars + "{}={}".format(k, v) + "\n" - if k.startswith(report_prefix): - env_vars = env_vars + "{}={}".format(k, v) + "\n" - - return env_vars - - -def get_env_info(): - run_lambda = run - pip_version, pip_list_output = get_pip_packages(run_lambda) - - if TORCH_AVAILABLE: - version_str = torch.__version__ - debug_mode_str = str(torch.version.debug) - cuda_available_str = str(torch.cuda.is_available()) - cuda_version_str = torch.version.cuda - if ( - not hasattr(torch.version, "hip") or torch.version.hip is None - ): # cuda version - hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" - else: # HIP version - - def get_version_or_na(cfg, prefix): - _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] - return _lst[0] if _lst else "N/A" - - cfg = torch._C._show_config().split("\n") - hip_runtime_version = get_version_or_na(cfg, "HIP Runtime") - miopen_runtime_version = get_version_or_na(cfg, "MIOpen") - cuda_version_str = "N/A" - hip_compiled_version = torch.version.hip - else: - version_str = debug_mode_str = cuda_available_str = cuda_version_str = "N/A" - hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" - - sys_version = sys.version.replace("\n", " ") - - conda_packages = get_conda_packages(run_lambda) - - rocm_version = get_rocm_version(run_lambda) - vllm_version = get_vllm_version() - vllm_build_flags = summarize_vllm_build_flags() - gpu_topo = get_gpu_topo(run_lambda) - - return SystemEnv( - torch_version=version_str, - is_debug_build=debug_mode_str, - python_version="{} ({}-bit runtime)".format( - sys_version, sys.maxsize.bit_length() + 1 - ), - python_platform=get_python_platform(), - is_cuda_available=cuda_available_str, - 
cuda_compiled_version=cuda_version_str, - cuda_runtime_version=get_running_cuda_version(run_lambda), - cuda_module_loading=get_cuda_module_loading_config(), - nvidia_gpu_models=get_gpu_info(run_lambda), - nvidia_driver_version=get_nvidia_driver_version(run_lambda), - cudnn_version=get_cudnn_version(run_lambda), - hip_compiled_version=hip_compiled_version, - hip_runtime_version=hip_runtime_version, - miopen_runtime_version=miopen_runtime_version, - pip_version=pip_version, - pip_packages=pip_list_output, - conda_packages=conda_packages, - os=get_os(run_lambda), - libc_version=get_libc_version(), - gcc_version=get_gcc_version(run_lambda), - clang_version=get_clang_version(run_lambda), - cmake_version=get_cmake_version(run_lambda), - caching_allocator_config=get_cachingallocator_config(), - is_xnnpack_available=is_xnnpack_available(), - cpu_info=get_cpu_info(run_lambda), - rocm_version=rocm_version, - vllm_version=vllm_version, - vllm_build_flags=vllm_build_flags, - gpu_topo=gpu_topo, - env_vars=get_env_vars(), - ) - - -env_info_fmt = """ -============================== - System Info -============================== -OS : {os} -GCC version : {gcc_version} -Clang version : {clang_version} -CMake version : {cmake_version} -Libc version : {libc_version} - -============================== - PyTorch Info -============================== -PyTorch version : {torch_version} -Is debug build : {is_debug_build} -CUDA used to build PyTorch : {cuda_compiled_version} -ROCM used to build PyTorch : {hip_compiled_version} - -============================== - Python Environment -============================== -Python version : {python_version} -Python platform : {python_platform} - -============================== - CUDA / GPU Info -============================== -Is CUDA available : {is_cuda_available} -CUDA runtime version : {cuda_runtime_version} -CUDA_MODULE_LOADING set to : {cuda_module_loading} -GPU models and configuration : {nvidia_gpu_models} -Nvidia driver version : 
{nvidia_driver_version} -cuDNN version : {cudnn_version} -HIP runtime version : {hip_runtime_version} -MIOpen runtime version : {miopen_runtime_version} -Is XNNPACK available : {is_xnnpack_available} - -============================== - CPU Info -============================== -{cpu_info} - -============================== -Versions of relevant libraries -============================== -{pip_packages} -{conda_packages} -""".strip() - -# both the above code and the following code use `strip()` to -# remove leading/trailing whitespaces, so we need to add a newline -# in between to separate the two sections -env_info_fmt += "\n\n" - -env_info_fmt += """ -============================== - vLLM Info -============================== -ROCM Version : {rocm_version} -vLLM Version : {vllm_version} -vLLM Build Flags: - {vllm_build_flags} -GPU Topology: - {gpu_topo} - -============================== - Environment Variables -============================== -{env_vars} -""".strip() - - -def pretty_str(envinfo): - def replace_nones(dct, replacement="Could not collect"): - for key in dct.keys(): - if dct[key] is not None: - continue - dct[key] = replacement - return dct - - def replace_bools(dct, true="Yes", false="No"): - for key in dct.keys(): - if dct[key] is True: - dct[key] = true - elif dct[key] is False: - dct[key] = false - return dct - - def prepend(text, tag="[prepend]"): - lines = text.split("\n") - updated_lines = [tag + line for line in lines] - return "\n".join(updated_lines) - - def replace_if_empty(text, replacement="No relevant packages"): - if text is not None and len(text) == 0: - return replacement - return text - - def maybe_start_on_next_line(string): - # If `string` is multiline, prepend a \n to it. 
- if string is not None and len(string.split("\n")) > 1: - return "\n{}\n".format(string) - return string - - mutable_dict = envinfo._asdict() - - # If nvidia_gpu_models is multiline, start on the next line - mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line( - envinfo.nvidia_gpu_models - ) - - # If the machine doesn't have CUDA, report some fields as 'No CUDA' - dynamic_cuda_fields = [ - "cuda_runtime_version", - "nvidia_gpu_models", - "nvidia_driver_version", - ] - all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"] - all_dynamic_cuda_fields_missing = all( - mutable_dict[field] is None for field in dynamic_cuda_fields - ) - if ( - TORCH_AVAILABLE - and not torch.cuda.is_available() - and all_dynamic_cuda_fields_missing - ): - for field in all_cuda_fields: - mutable_dict[field] = "No CUDA" - if envinfo.cuda_compiled_version is None: - mutable_dict["cuda_compiled_version"] = "None" - - # Replace True with Yes, False with No - mutable_dict = replace_bools(mutable_dict) - - # Replace all None objects with 'Could not collect' - mutable_dict = replace_nones(mutable_dict) - - # If either of these are '', replace with 'No relevant packages' - mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"]) - mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"]) - - # Tag conda and pip packages with a prefix - # If they were previously None, they'll show up as ie '[conda] Could not collect' - if mutable_dict["pip_packages"]: - mutable_dict["pip_packages"] = prepend( - mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version) - ) - if mutable_dict["conda_packages"]: - mutable_dict["conda_packages"] = prepend( - mutable_dict["conda_packages"], "[conda] " - ) - mutable_dict["cpu_info"] = envinfo.cpu_info - return env_info_fmt.format(**mutable_dict) - - -def get_pretty_env_info(): - return pretty_str(get_env_info()) - - -def main(): - print("Collecting environment information...") - output = 
get_pretty_env_info() - print(output) - - if ( - TORCH_AVAILABLE - and hasattr(torch, "utils") - and hasattr(torch.utils, "_crash_handler") - ): - minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR - if sys.platform == "linux" and os.path.exists(minidump_dir): - dumps = [ - os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir) - ] - latest = max(dumps, key=os.path.getctime) - ctime = os.path.getctime(latest) - creation_time = datetime.datetime.fromtimestamp(ctime).strftime( - "%Y-%m-%d %H:%M:%S" - ) - msg = ( - "\n*** Detected a minidump at {} created on {}, ".format( - latest, creation_time - ) - + "if this is related to your bug please include it when you file a report ***" - ) - print(msg, file=sys.stderr) - - -if __name__ == "__main__": - main() From 0eafca27fe8b285da5ecd21424902170d932f5d6 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 5 Jan 2026 00:25:19 -0800 Subject: [PATCH 106/130] fix Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/mamba_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index c9f5339fdb35..6ac9369b72dc 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -306,7 +306,7 @@ def update_block_table( self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, ) - if self.vllm_config.cache_config.mamba_cache_mode in ("all", "none"): + if self.vllm_config.cache_config.mamba_cache_mode in ("none", "align"): # Only needs the block that saves the running state state_indices_t = state_indices_t[:, 0] From 7b2044efff7ae4f620297b277d5d5581b0cdf8d4 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 5 Jan 2026 16:01:39 +0000 Subject: [PATCH 107/130] add mamba cache mode to mamba_spec and rm cache_config in kv_manager Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/mamba/abstract.py | 1 + vllm/v1/core/kv_cache_coordinator.py | 
14 -------------- vllm/v1/core/kv_cache_manager.py | 3 --- vllm/v1/core/sched/scheduler.py | 1 - vllm/v1/core/single_type_kv_cache_manager.py | 13 +++---------- vllm/v1/kv_cache_interface.py | 1 + 6 files changed, 5 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 74f4383e9c23..e51480086845 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -56,6 +56,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: block_size=mamba_block_size, page_size_padded=page_size_padded, mamba_type=self.mamba_type, + mamba_cache_mode=vllm_config.cache_config.mamba_cache_mode, num_speculative_blocks=( vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 12a1ada309d8..5a5ed4676990 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -4,7 +4,6 @@ from collections.abc import Sequence from math import lcm -from vllm.config.cache import CacheConfig from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector from vllm.v1.core.kv_cache_utils import ( @@ -33,7 +32,6 @@ class KVCacheCoordinator(ABC): def __init__( self, - cache_config: CacheConfig, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, @@ -44,7 +42,6 @@ def __init__( hash_block_size: int, metrics_collector: KVCacheMetricsCollector | None = None, ): - self.cache_config = cache_config self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len self.enable_caching = enable_caching @@ -62,7 +59,6 @@ def __init__( self.single_type_managers = tuple( get_manager_for_kv_cache_spec( kv_cache_spec=kv_cache_group.kv_cache_spec, - cache_config=self.cache_config, block_pool=self.block_pool, 
enable_caching=enable_caching, kv_cache_group_id=i, @@ -262,7 +258,6 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): def __init__( self, - cache_config: CacheConfig, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, @@ -273,7 +268,6 @@ def __init__( metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( - cache_config, kv_cache_config, max_model_len, use_eagle, @@ -309,7 +303,6 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): def __init__( self, - cache_config: CacheConfig, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, @@ -321,7 +314,6 @@ def __init__( metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( - cache_config, kv_cache_config, max_model_len, use_eagle, @@ -379,7 +371,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): def __init__( self, - cache_config: CacheConfig, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, @@ -391,7 +382,6 @@ def __init__( metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( - cache_config, kv_cache_config, max_model_len, use_eagle, @@ -568,7 +558,6 @@ def find_longest_cache_hit( def get_kv_cache_coordinator( - cache_config: CacheConfig, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, @@ -581,7 +570,6 @@ def get_kv_cache_coordinator( ) -> KVCacheCoordinator: if not enable_caching: return KVCacheCoordinatorNoPrefixCache( - cache_config, kv_cache_config, max_model_len, use_eagle, @@ -593,7 +581,6 @@ def get_kv_cache_coordinator( ) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator( - cache_config, kv_cache_config, max_model_len, use_eagle, @@ -605,7 +592,6 @@ def get_kv_cache_coordinator( metrics_collector=metrics_collector, ) return HybridKVCacheCoordinator( - cache_config, kv_cache_config, max_model_len, use_eagle, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 
95891dc3b86c..2caed0493752 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from typing import Literal, overload -from vllm.config.cache import CacheConfig from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator @@ -97,7 +96,6 @@ def __init__( self, kv_cache_config: KVCacheConfig, max_model_len: int, - cache_config: CacheConfig, hash_block_size: int, enable_caching: bool = True, use_eagle: bool = False, @@ -119,7 +117,6 @@ def __init__( self.prefix_cache_stats = PrefixCacheStats() if log_stats else None self.coordinator = get_kv_cache_coordinator( - cache_config=cache_config, kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, use_eagle=self.use_eagle, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8893bd4456b8..45c6aeaa9f6b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -208,7 +208,6 @@ def __init__( self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, - cache_config=self.cache_config, enable_caching=self.cache_config.enable_prefix_caching, use_eagle=self.use_eagle, log_stats=self.log_stats, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index c6ba755e6765..2729eb06d0b1 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -5,7 +5,6 @@ from collections import defaultdict from collections.abc import Sequence -from vllm.config.cache import CacheConfig from vllm.utils.math_utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock @@ -31,7 +30,6 @@ class SingleTypeKVCacheManager(ABC): def __init__( self, kv_cache_spec: KVCacheSpec, - cache_config: 
CacheConfig, block_pool: BlockPool, enable_caching: bool, kv_cache_group_id: int, @@ -45,7 +43,6 @@ def __init__( block_pool: The block pool. kv_cache_group_id: The id of the kv cache group of this manager. """ - self.cache_config = cache_config self.block_size = kv_cache_spec.block_size self.dcp_world_size = dcp_world_size self.pcp_world_size = pcp_world_size @@ -743,11 +740,9 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class MambaManager(SingleTypeKVCacheManager): - def __init__( - self, kv_cache_spec: MambaSpec, cache_config: CacheConfig, **kwargs - ) -> None: - super().__init__(kv_cache_spec, cache_config=cache_config, **kwargs) - self.mamba_cache_mode = cache_config.mamba_cache_mode + def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None: + super().__init__(kv_cache_spec, **kwargs) + self.mamba_cache_mode = kv_cache_spec.mamba_cache_mode self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks if self.mamba_cache_mode == "align": self.last_state_block_idx: dict[str, int] = {} @@ -1025,7 +1020,6 @@ class SinkFullAttentionManager(FullAttentionManager): def __init__( self, kv_cache_spec: SinkFullAttentionSpec, - cache_config: CacheConfig, block_pool: BlockPool, enable_caching: bool, kv_cache_group_id: int, @@ -1034,7 +1028,6 @@ def __init__( ): super().__init__( kv_cache_spec, - cache_config, block_pool, enable_caching, kv_cache_group_id, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index beb06f6ef5cb..c9eb03a83e41 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -265,6 +265,7 @@ class MambaSpec(KVCacheSpec): dtypes: tuple[torch.dtype] page_size_padded: int | None = None mamba_type: str = "mamba2" + mamba_cache_mode: str = "none" num_speculative_blocks: int = 0 @property From 42de03c2b2c50d0b3e8e38374f953a33c4e3c9a6 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 5 Jan 2026 16:12:07 +0000 Subject: [PATCH 108/130] prefill exclude sps 
tokens Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/sched/scheduler.py | 42 +++++++++++++++------------------ 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 45c6aeaa9f6b..238fe085bc4f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -220,24 +220,21 @@ def __init__( self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER - self.perf_metrics: ModelMetrics | None = None - if self.log_stats and vllm_config.observability_config.enable_mfu_metrics: - self.perf_metrics = ModelMetrics(vllm_config) - def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool: - has_mamba: bool = any( + return any( isinstance(group_spec.kv_cache_spec, MambaSpec) for group_spec in kv_cache_config.kv_cache_groups ) - if vllm_config.model_config.is_hybrid: - assert has_mamba, "Hybrid models must have mamba layers" - return has_mamba + self.has_mamba_layers = has_mamba_layers(kv_cache_config) self.need_mamba_block_aligned_split = ( - has_mamba_layers(self.kv_cache_config) - and self.cache_config.mamba_cache_mode == "align" + self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align" ) + self.perf_metrics: ModelMetrics | None = None + if self.log_stats and vllm_config.observability_config.enable_mfu_metrics: + self.perf_metrics = ModelMetrics(vllm_config) + def _mamba_block_aligned_split( self, request: Request, @@ -334,24 +331,23 @@ def schedule(self) -> SchedulerOutput: req_index += 1 continue + # TODO: merge PR#30618 + num_tokens_to_compute = ( + request.num_tokens_with_spec + request.num_output_placeholders + ) # Ensure new tokens for a request in the prefill phase do not contain # draft tokens, especially in the last prefill chunk. For a hybrid-model, # extra draft tokens would corrupt the generated Mamba state. # TODO: This logic does not yet handle resumed requests. 
- if request.num_computed_tokens < request.num_prompt_tokens: - num_new_tokens = ( - min( - request.num_tokens_with_spec + request.num_output_placeholders, - request.num_prompt_tokens, - ) - - request.num_computed_tokens - ) - else: - num_new_tokens = ( - request.num_tokens_with_spec - + request.num_output_placeholders - - request.num_computed_tokens + if ( + self.has_mamba_layers + and request.num_computed_tokens < request.num_prompt_tokens + ): + num_tokens_to_compute = min( + num_tokens_to_compute, request.num_prompt_tokens ) + num_new_tokens = num_tokens_to_compute - request.num_computed_tokens + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) From b39990db1fdc7225a4212f73fa9c6fbab2ea40a1 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 5 Jan 2026 17:07:44 +0000 Subject: [PATCH 109/130] fix tests Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_single_type_kv_cache_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index 23097bf2a086..3af4a3dce5e0 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -24,7 +24,7 @@ def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True): return SlidingWindowManager( sliding_window_spec, - block_pool, + block_pool=block_pool, enable_caching=enable_caching, kv_cache_group_id=0, ) @@ -35,7 +35,7 @@ def get_chunked_local_attention_manager( ): return ChunkedLocalAttentionManager( chunked_local_attention_spec, - block_pool, + block_pool=block_pool, enable_caching=enable_caching, kv_cache_group_id=0, ) From f480c083585f66a5481cff8b30fc81e72f446867 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Mon, 5 Jan 2026 17:40:56 +0000 Subject: 
[PATCH 110/130] revert InputBatch api Signed-off-by: huanghaoyan.hhy --- vllm/v1/worker/block_table.py | 16 +++++++++++++++- vllm/v1/worker/gpu_input_batch.py | 3 ++- vllm/v1/worker/gpu_model_runner.py | 15 +++------------ vllm/v1/worker/tpu_input_batch.py | 3 +-- vllm/v1/worker/tpu_model_runner.py | 11 ----------- 5 files changed, 21 insertions(+), 27 deletions(-) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 24c17f8a3734..591f49761a0e 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -6,7 +6,9 @@ from vllm.distributed import get_dcp_group, get_pcp_group from vllm.logger import init_logger +from vllm.utils.math_utils import cdiv from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.cp_utils import get_total_cp_world_size logger = init_logger(__name__) @@ -254,12 +256,13 @@ class MultiGroupBlockTable: def __init__( self, max_num_reqs: int, + max_model_len: int, max_num_batched_tokens: int, pin_memory: bool, device: torch.device, block_sizes: list[int], kernel_block_sizes: list[int], - max_num_blocks: list[int], + max_num_blocks: list[int] | None = None, cp_kv_cache_interleave_size: int = 1, ) -> None: if len(kernel_block_sizes) != len(block_sizes): @@ -267,6 +270,17 @@ def __init__( f"kernel_block_sizes length ({len(kernel_block_sizes)}) " f"must match block_sizes length ({len(block_sizes)})" ) + if max_num_blocks is None: + # Note(hc): each dcp rank only store + # (max_model_len//dcp_world_size) tokens in kvcache, + # so the block_size which used for calc max_num_blocks_per_req + # must be multiplied by dcp_world_size. 
+ total_cp_world_size = get_total_cp_world_size() + max_num_blocks = [ + cdiv(max_model_len, block_size * total_cp_world_size) + for block_size in block_sizes + ] + if len(max_num_blocks) != len(block_sizes): raise ValueError( f"max_num_blocks length ({len(max_num_blocks)}) " diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index e4823cd57d5c..98b2457aa1c6 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -90,7 +90,7 @@ def __init__( vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group kernel_block_sizes: list[int], - max_num_blocks_per_req: list[int], + max_num_blocks_per_req: list[int] | None = None, logitsprocs: LogitsProcessors | None = None, logitsprocs_need_output_token_ids: bool = False, is_spec_decode: bool = False, @@ -141,6 +141,7 @@ def __init__( # Block table. self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, + max_model_len=max_model_len, max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f4aa96eafc64..3ed6f0d0a7c4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -491,16 +491,6 @@ def __init__( vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], kernel_block_sizes=[self.cache_config.block_size], - max_num_blocks_per_req=[ - # Note(hc): each dcp rank only store - # (max_model_len//dcp_world_size) tokens in kvcache, - # so the block_size which used for calc max_num_blocks_per_req - # must be multiplied by dcp_world_size. 
- cdiv( - self.max_model_len, - self.cache_config.block_size * get_total_cp_world_size(), - ) - ], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( self.vllm_config, @@ -5292,11 +5282,12 @@ def may_reinitialize_input_batch( if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) ] max_num_blocks = [] + max_model_len = max(self.max_model_len, self.max_encoder_len) for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): continue max_num_blocks_per_req = cdiv( - self.max_model_len, block_sizes[i] * get_total_cp_world_size() + max_model_len, block_sizes[i] * get_total_cp_world_size() ) if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): mamba_blocks_per_req = ( @@ -5319,7 +5310,7 @@ def may_reinitialize_input_batch( ) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=max(self.max_model_len, self.max_encoder_len), + max_model_len=max_model_len, max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 1396c8ad9d5a..3758a73ee496 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -29,7 +29,6 @@ def __init__( vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group kernel_block_sizes: list[int], - max_num_blocks_per_req: list[int], ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -65,12 +64,12 @@ def __init__( # Block table. self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, + max_model_len=max_model_len, max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, block_sizes=block_sizes, kernel_block_sizes=kernel_block_sizes, - max_num_blocks=max_num_blocks_per_req, ) # Sampling-related. 
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index db0d2ae17622..ba5e6200e426 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -77,7 +77,6 @@ ) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler -from vllm.v1.worker.cp_utils import get_total_cp_world_size from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin, KVConnectorOutput, @@ -262,9 +261,6 @@ def __init__( vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.block_size], kernel_block_sizes=[self.cache_config.block_size], - max_num_blocks_per_req=[ - cdiv(self.max_model_len, self.block_size * get_total_cp_world_size()) - ], ) # Cached torch/numpy tensor @@ -1850,13 +1846,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kernel_block_sizes=[ kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size ], - max_num_blocks_per_req=[ - cdiv( - self.max_model_len, - kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size - * get_total_cp_world_size(), - ) - ], ) # Verify dtype compatibility between block_table_cpu and input_batch assert ( From 67d4e035a84612d3b7e26c0a8da9e30af9c117e1 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 6 Jan 2026 17:17:35 +0000 Subject: [PATCH 111/130] revert code Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/kda.py | 5 +---- vllm/v1/attention/backends/gdn_attn.py | 7 ++----- vllm/v1/core/single_type_kv_cache_manager.py | 1 - vllm/v1/worker/gpu_model_runner.py | 1 - vllm/v1/worker/utils.py | 5 +---- 5 files changed, 4 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py index 80a1b32df928..27cc3884517f 100644 --- a/vllm/model_executor/layers/kda.py +++ b/vllm/model_executor/layers/kda.py @@ -31,10 +31,7 @@ RowParallelLinear, ) from .mamba.abstract import MambaBase 
-from .mamba.mamba_utils import ( - MambaStateDtypeCalculator, - MambaStateShapeCalculator, -) +from .mamba.mamba_utils import MambaStateDtypeCalculator, MambaStateShapeCalculator from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update from .quantization.base_config import QuantizationConfig diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 740f282afb52..a69aadf34635 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -9,7 +9,6 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig -from vllm.logger import init_logger from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -20,8 +19,6 @@ ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec -logger = init_logger(__name__) - class GDNAttentionBackend(AttentionBackend): @staticmethod @@ -150,8 +147,8 @@ def build( # type: ignore[override] context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True) nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None block_table_tensor = mamba_get_block_table_tensor( - common_attn_metadata.block_table_tensor, - common_attn_metadata.seq_lens, + m.block_table_tensor, + m.seq_lens, self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, ) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 2729eb06d0b1..98c38d4d4803 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -378,7 +378,6 @@ def remove_skipped_blocks( break removed_blocks.append(blocks[i]) blocks[i] = self._null_block - self.block_pool.free_blocks(removed_blocks) def get_num_skipped_tokens(self, num_computed_tokens: int) -> int: diff --git a/vllm/v1/worker/gpu_model_runner.py 
b/vllm/v1/worker/gpu_model_runner.py index 3ed6f0d0a7c4..459fd04deb54 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5178,7 +5178,6 @@ def calculate_reorder_batch_threshold(self) -> None: just may have a performance penalty due to that backend treating decodes as prefills. """ - min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b) reorder_batch_thresholds: list[int | None] = [ diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 0eedbd68a6f0..31ccf7f15746 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -19,10 +19,7 @@ from vllm.utils.mem_utils import MemorySnapshot from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget -from vllm.v1.kv_cache_interface import ( - KVCacheGroupSpec, - KVCacheSpec, -) +from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec logger = init_logger(__name__) From 8dcf54df4b3eebe4e4cace89d5a33843db5701be Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Tue, 6 Jan 2026 17:18:35 +0000 Subject: [PATCH 112/130] update according to suggestions Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/layers/mamba/mamba_utils.py | 4 ++-- vllm/model_executor/models/config.py | 2 +- vllm/v1/core/single_type_kv_cache_manager.py | 7 +++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 26ddded7914a..49af117c841d 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -276,7 +276,7 @@ def linear_attention_state_copy_func(cls): @classmethod def mamba1_state_copy_func(cls): - return get_conv_copy_spec, get_temporal_copy_spec + return (get_conv_copy_spec, get_temporal_copy_spec) @classmethod def mamba2_state_copy_func(cls): @@ -288,7 +288,7 @@ def 
short_conv_state_copy_func(cls): @classmethod def gated_delta_net_state_copy_func(cls): - return get_conv_copy_spec, get_temporal_copy_spec + return (get_conv_copy_spec, get_temporal_copy_spec) @classmethod def kda_state_copy_func(cls): diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 8530c1eec007..36c415de8002 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -313,7 +313,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: if cache_config.mamba_cache_mode == "none": cache_config.mamba_cache_mode = "align" logger.warning( - "Mamba cache mode is set to 'align' defaultly when prefix " + "Mamba cache mode is set to 'align' by default when prefix " "caching is enabled" ) if cache_config.mamba_cache_mode == "all": diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 98c38d4d4803..358edd9ae26a 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -829,10 +829,9 @@ def get_num_blocks_to_allocate( if self.mamba_cache_mode != "align": # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. - if self.kv_cache_spec.num_speculative_blocks > 0: + if self.num_speculative_blocks > 0: num_tokens += ( - self.kv_cache_spec.block_size - * self.kv_cache_spec.num_speculative_blocks + self.kv_cache_spec.block_size * self.num_speculative_blocks ) return super().get_num_blocks_to_allocate( request_id, @@ -864,7 +863,7 @@ def get_num_blocks_to_allocate( else: # First prefill. Allocate 1 block for running state and the # speculative blocks. 
- num_new_blocks = 1 + self.kv_cache_spec.num_speculative_blocks + num_new_blocks = 1 + self.num_speculative_blocks # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it will be changed from a free block From 0ad13e206012655d23175eb9d2238ad6e7d196d1 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 7 Jan 2026 16:58:45 +0000 Subject: [PATCH 113/130] update mamba cache mode config Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/models/config.py | 49 ++++++++++++++-------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 36c415de8002..e507e314c2f5 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -311,37 +311,36 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: if cache_config.enable_prefix_caching: if cache_config.mamba_cache_mode == "none": - cache_config.mamba_cache_mode = "align" + cache_config.mamba_cache_mode = ( + "all" if model_config.supports_mamba_prefix_caching else "align" + ) logger.warning( - "Mamba cache mode is set to 'align' by default when prefix " - "caching is enabled" + "Mamba cache mode is set to '%s' for %s by default " + "when prefix caching is enabled", + cache_config.mamba_cache_mode, + model_config.architecture, ) - if cache_config.mamba_cache_mode == "all": - if model_config.supports_mamba_prefix_caching: - logger.info( - "Warning: Prefix caching is currently enabled. " - "Its support for Mamba layers is experimental. " - "Please report any issues you may observe." - ) - else: - logger.info( - "Hybrid or mamba-based model detected without " - "support for prefix caching: disabling." - ) - cache_config.enable_prefix_caching = False - elif cache_config.mamba_cache_mode == "align": - logger.info( - "Warning: Mamba cache mode 'align' with prefix caching is" - " currently enabled. Its support is experimental. 
" - "Please report any issues you may observe." + if ( + cache_config.mamba_cache_mode == "all" + and not model_config.supports_mamba_prefix_caching + ): + cache_config.mamba_cache_mode = "align" + logger.warning( + "Hybrid or mamba-based model detected without support " + "for prefix caching with Mamba cache 'all' mode: " + "falling back to 'align' mode." ) + if cache_config.mamba_cache_mode == "align": assert vllm_config.scheduler_config.enable_chunked_prefill, ( "Chunked prefill is required for mamba cache mode 'align'." ) - else: - raise ValueError( - f"unknown mamba cache mode: {cache_config.mamba_cache_mode}" - ) + logger.info( + "Warning: Prefix caching with Mamba cache '%s' " + "mode is currently enabled. " + "Its support for Mamba layers is experimental. " + "Please report any issues you may observe.", + cache_config.mamba_cache_mode, + ) else: if cache_config.mamba_cache_mode != "none": cache_config.mamba_cache_mode = "none" From f0d49ef761a8b9d1a23f27911aea87e0c9beac56 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 7 Jan 2026 17:36:11 +0000 Subject: [PATCH 114/130] fix get_num_blocks_to_allocate tests Signed-off-by: huanghaoyan.hhy --- .../core/test_single_type_kv_cache_manager.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index 3af4a3dce5e0..93606ffcb70a 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -342,11 +342,15 @@ def test_get_num_blocks_to_allocate(): ] assert ( - manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0) + manager.get_num_blocks_to_allocate( + "1", 20 * block_size, cached_blocks_1, 0, 20 * block_size + ) == 20 ) assert ( - manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0) + manager.get_num_blocks_to_allocate( + "2", 20 * block_size, cached_blocks_2, 0, 
20 * block_size + ) == 15 ) @@ -375,6 +379,7 @@ def test_evictable_cached_blocks_not_double_allocated(): num_tokens=2 * block_size, new_computed_blocks=[evictable_block], total_computed_tokens=block_size, + num_tokens_main_model=2 * block_size, ) # Free capacity check should count evictable cached blocks, but allocation # should only allocate the truly new block. @@ -411,10 +416,14 @@ def test_chunked_local_attention_get_num_blocks_to_allocate(): ] assert ( - manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0) + manager.get_num_blocks_to_allocate( + "1", 20 * block_size, cached_blocks_1, 0, 20 * block_size + ) == 20 ) assert ( - manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0) + manager.get_num_blocks_to_allocate( + "2", 20 * block_size, cached_blocks_2, 0, 20 * block_size + ) == 15 ) From f02f2659597b5d989c66c0f4024f93946c88d435 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 7 Jan 2026 18:24:06 +0000 Subject: [PATCH 115/130] add _get_num_evictable_blocks Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/single_type_kv_cache_manager.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 358edd9ae26a..f96a9807fbf2 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -66,6 +66,10 @@ def __init__( self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block + @classmethod + def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]): + return sum(blk.ref_cnt == 0 and not blk.is_null for blk in blocks) + def get_num_blocks_to_allocate( self, request_id: str, @@ -125,9 +129,8 @@ def get_num_blocks_to_allocate( # If a computed block is an eviction candidate (in the free queue and # ref_cnt == 0), it will be removed from the free queue when touched by # the allocated request, so we must count 
it in the free-capacity check. - num_evictable_blocks = sum( - blk.ref_cnt == 0 and not blk.is_null - for blk in new_computed_blocks[num_skipped_new_computed_blocks:] + num_evictable_blocks = self._get_num_evictable_blocks( + new_computed_blocks[num_skipped_new_computed_blocks:] ) return num_new_blocks + num_evictable_blocks @@ -865,12 +868,8 @@ def get_num_blocks_to_allocate( # speculative blocks. num_new_blocks = 1 + self.num_speculative_blocks - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it will be changed from a free block - # to a computed block when the request is allocated, so we also count - # it as needed to be allocated. - num_evictable_computed_blocks = sum( - blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks + num_evictable_computed_blocks = self._get_num_evictable_blocks( + new_computed_blocks ) return num_new_blocks + num_evictable_computed_blocks From 7c2a6f2b7429bdaf3f9d0f96539d2537c74ce8b9 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 7 Jan 2026 18:24:23 +0000 Subject: [PATCH 116/130] add todo Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/sched/scheduler.py | 1 + vllm/v1/core/single_type_kv_cache_manager.py | 1 + 2 files changed, 2 insertions(+) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 238fe085bc4f..65f83a0c8eb2 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -245,6 +245,7 @@ def _mamba_block_aligned_split( assert num_external_computed_tokens == 0, ( "External KV connector is not verified yet" ) + # TODO: need check for resume requests if request.num_output_tokens == 0: # prefill # To enable block-aligned caching of the Mamba state, `num_new_tokens` # must be a multiple of `block_size`. 
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index f96a9807fbf2..63f3633dc56c 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -803,6 +803,7 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> No assert isinstance(self.kv_cache_spec, MambaSpec) super().remove_skipped_blocks(request_id, num_computed_tokens) if self.mamba_cache_mode == "align": + # TODO: need comments last_state_block_idx = self.last_state_block_idx.get(request_id) if ( last_state_block_idx is not None From 21e8cee973a10b958dcd63e4e228ee274a62cedd Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Thu, 8 Jan 2026 16:47:43 +0000 Subject: [PATCH 117/130] fix allocate_new_blocks Signed-off-by: huanghaoyan.hhy --- tests/v1/core/test_single_type_kv_cache_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index 93606ffcb70a..b05040ebe2a6 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -391,7 +391,9 @@ def test_evictable_cached_blocks_not_double_allocated(): num_local_computed_tokens=block_size, num_external_computed_tokens=0, ) - new_blocks = manager.allocate_new_blocks(request_id, num_tokens=4) + new_blocks = manager.allocate_new_blocks( + request_id, num_tokens=4, num_tokens_main_model=4 + ) assert len(new_blocks) == 1 assert len(manager.req_to_blocks[request_id]) == 2 From aa13dc5a9b127a2dea802ee0ee585a0a7f9b8540 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 14 Jan 2026 10:28:33 +0000 Subject: [PATCH 118/130] revert mamba_block_size code Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/models/config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/models/config.py 
b/vllm/model_executor/models/config.py index e507e314c2f5..cd3e0fa12eae 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -341,6 +341,11 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "Please report any issues you may observe.", cache_config.mamba_cache_mode, ) + # By default, mamba block size will be set to max_model_len (see + # below). When enabling prefix caching, we align mamba block size + # to the block size as the basic granularity for prefix caching. + if cache_config.mamba_block_size is None: + cache_config.mamba_block_size = cache_config.block_size else: if cache_config.mamba_cache_mode != "none": cache_config.mamba_cache_mode = "none" From 3481f160482da37cb783d6534dadd39740dbbe60 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 14 Jan 2026 12:24:25 +0000 Subject: [PATCH 119/130] remove unused comments Signed-off-by: huanghaoyan.hhy --- vllm/v1/kv_cache_interface.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index c9eb03a83e41..3c7157487654 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -280,9 +280,6 @@ def page_size_bytes(self) -> int: return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - # We allocate 1 block for each request now, so max_memory_usage_bytes is - # the same as page_size_bytes. - # Need to update this when supporting prefix caching. 
if vllm_config.cache_config.mamba_cache_mode == "all": max_model_len = vllm_config.model_config.max_model_len return cdiv(max_model_len, self.block_size) * self.page_size_bytes From 7502358395dfbad3943cccfcf04fb22983ae59ad Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 14 Jan 2026 15:39:00 +0000 Subject: [PATCH 120/130] add classmethod Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/models/granitemoehybrid.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 6f58df835b05..0b601b4b8941 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -643,6 +643,7 @@ def get_mamba_state_shape_from_config( conv_kernel=hf_config.mamba_d_conv, ) + @classmethod def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: return MambaStateCopyFuncCalculator.mamba2_state_copy_func() From c808dd0f325f405269a8745207f198fd2e78ecad Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 14 Jan 2026 18:00:41 +0000 Subject: [PATCH 121/130] block_aligned_split support resume Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/sched/scheduler.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 65f83a0c8eb2..25cbc3457e74 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -245,19 +245,19 @@ def _mamba_block_aligned_split( assert num_external_computed_tokens == 0, ( "External KV connector is not verified yet" ) + block_size = self.cache_config.block_size # TODO: need check for resume requests - if request.num_output_tokens == 0: # prefill - # To enable block-aligned caching of the Mamba state, `num_new_tokens` - # must be a multiple of `block_size`. 
- # As an exception, if `num_new_tokens` is less than `block_size`, the - # state is simply not cached, requiring no special handling. - # Additionally, when Eagle mode is enabled, FullAttn prunes the last - # matching block. To prevent this from causing a Mamba cache miss, the - # last chunk must be larger than `block_size`. - block_size = self.cache_config.block_size - last_cache_position = ( - request.num_prompt_tokens - request.num_prompt_tokens % block_size - ) + # To enable block-aligned caching of the Mamba state, `num_new_tokens` + # must be a multiple of `block_size`. + # As an exception, if `num_new_tokens` is less than `block_size`, the + # state is simply not cached, requiring no special handling. + # Additionally, when Eagle mode is enabled, FullAttn prunes the last + # matching block. To prevent this from causing a Mamba cache miss, the + # last chunk must be larger than `block_size`. + if num_new_tokens >= block_size: + # Use `num_tokens` instead of `num_prompt_tokens` to handle + # resumed requests. 
+ last_cache_position = request.num_tokens - request.num_tokens % block_size # eagle prune if self.use_eagle: last_cache_position = max(last_cache_position - block_size, 0) @@ -266,14 +266,14 @@ def _mamba_block_aligned_split( + num_new_local_computed_tokens + num_external_computed_tokens ) - num_computed_tokens_after_prefill = num_computed_tokens + num_new_tokens - if num_computed_tokens_after_prefill < last_cache_position: + num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens + if num_computed_tokens_after_sched < last_cache_position: # align to block_size num_new_tokens = num_new_tokens // block_size * block_size elif ( num_computed_tokens < last_cache_position - < num_computed_tokens_after_prefill + < num_computed_tokens_after_sched ): # force to cache the last chunk num_new_tokens = last_cache_position - num_computed_tokens From 976081f0e049389aa0ba045be63663d8c69197db Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 14 Jan 2026 18:45:03 +0000 Subject: [PATCH 122/130] block_aligned_split support resume v2 Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/sched/scheduler.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 25cbc3457e74..a170127aa4a3 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -246,6 +246,11 @@ def _mamba_block_aligned_split( "External KV connector is not verified yet" ) block_size = self.cache_config.block_size + num_computed_tokens = ( + request.num_computed_tokens + + num_new_local_computed_tokens + + num_external_computed_tokens + ) # TODO: need check for resume requests # To enable block-aligned caching of the Mamba state, `num_new_tokens` # must be a multiple of `block_size`. @@ -254,18 +259,15 @@ def _mamba_block_aligned_split( # Additionally, when Eagle mode is enabled, FullAttn prunes the last # matching block. 
To prevent this from causing a Mamba cache miss, the # last chunk must be larger than `block_size`. - if num_new_tokens >= block_size: + if num_computed_tokens < request.num_prompt_tokens: # Use `num_tokens` instead of `num_prompt_tokens` to handle # resumed requests. - last_cache_position = request.num_tokens - request.num_tokens % block_size + last_cache_position = ( + request.num_prompt_tokens - request.num_prompt_tokens % block_size + ) # eagle prune if self.use_eagle: last_cache_position = max(last_cache_position - block_size, 0) - num_computed_tokens = ( - request.num_computed_tokens - + num_new_local_computed_tokens - + num_external_computed_tokens - ) num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens if num_computed_tokens_after_sched < last_cache_position: # align to block_size From e381dd9ea0bb603b12c2c730490cbe81bb2b03f5 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Wed, 14 Jan 2026 19:01:57 +0000 Subject: [PATCH 123/130] revert block_align_split Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/sched/scheduler.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a170127aa4a3..1466a0116498 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -245,29 +245,27 @@ def _mamba_block_aligned_split( assert num_external_computed_tokens == 0, ( "External KV connector is not verified yet" ) - block_size = self.cache_config.block_size - num_computed_tokens = ( - request.num_computed_tokens - + num_new_local_computed_tokens - + num_external_computed_tokens - ) # TODO: need check for resume requests - # To enable block-aligned caching of the Mamba state, `num_new_tokens` - # must be a multiple of `block_size`. - # As an exception, if `num_new_tokens` is less than `block_size`, the - # state is simply not cached, requiring no special handling. 
- # Additionally, when Eagle mode is enabled, FullAttn prunes the last - # matching block. To prevent this from causing a Mamba cache miss, the - # last chunk must be larger than `block_size`. - if num_computed_tokens < request.num_prompt_tokens: - # Use `num_tokens` instead of `num_prompt_tokens` to handle - # resumed requests. + if request.num_output_tokens == 0: # prefill + # To enable block-aligned caching of the Mamba state, `num_new_tokens` + # must be a multiple of `block_size`. + # As an exception, if `num_new_tokens` is less than `block_size`, the + # state is simply not cached, requiring no special handling. + # Additionally, when Eagle mode is enabled, FullAttn prunes the last + # matching block. To prevent this from causing a Mamba cache miss, the + # last chunk must be larger than `block_size`. + block_size = self.cache_config.block_size last_cache_position = ( request.num_prompt_tokens - request.num_prompt_tokens % block_size ) # eagle prune if self.use_eagle: last_cache_position = max(last_cache_position - block_size, 0) + num_computed_tokens = ( + request.num_computed_tokens + + num_new_local_computed_tokens + + num_external_computed_tokens + ) num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens if num_computed_tokens_after_sched < last_cache_position: # align to block_size From f5580ede51f77f77705ab9112b64b53637f3734e Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Thu, 15 Jan 2026 18:33:33 +0000 Subject: [PATCH 124/130] add comments Signed-off-by: huanghaoyan.hhy --- .../layers/mamba/mamba_utils.py | 20 +++++++++++++++++++ vllm/model_executor/models/interfaces.py | 10 +++++++++- vllm/v1/attention/backends/mamba_attn.py | 4 ++-- vllm/v1/core/single_type_kv_cache_manager.py | 12 +++++++++-- vllm/v1/worker/mamba_utils.py | 1 - 5 files changed, 41 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 49af117c841d..816f76bfa069 100644 
--- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -231,6 +231,15 @@ def kda_state_shape( @dataclass class MambaCopySpec: + """ + Data class specifying the memory-copy parameters for Mamba states used for + prefix caching in align mode. + + Attributes: + start_addr (int): Starting address for the memory copy operation. + num_elements (int): Number of elements to copy from the starting address. + """ + start_addr: int num_elements: int @@ -238,6 +247,15 @@ class MambaCopySpec: MambaStateCopyFunc: TypeAlias = Callable[ [torch.Tensor, list[int], int, int], MambaCopySpec ] +""" +Type alias for a function that computes a MambaCopySpec for copying state slices. +Parameters: + state: torch.Tensor - the Mamba state tensor (e.g., conv or temporal states). + block_ids: list[int] - the list of block indices for the state to copy. + cur_block_idx: int - current block index within `block_ids` to copy from. + num_accepted_tokens: int - number of accepted tokens used to compute the copy offset. + Range: 1 .. 1 + num_speculative_tokens (inclusive). 
+""" def get_conv_copy_spec( @@ -246,6 +264,7 @@ def get_conv_copy_spec( cur_block_idx: int, num_accepted_tokens: int, ) -> MambaCopySpec: + """Return a MambaCopySpec for copying a convolutional state slice.""" src_block_id = block_ids[cur_block_idx] src_state = state[src_block_id, num_accepted_tokens - 1 :] return MambaCopySpec( @@ -259,6 +278,7 @@ def get_temporal_copy_spec( cur_block_idx: int, num_accepted_tokens: int, ) -> MambaCopySpec: + """Return a MambaCopySpec for copying a temporal state slice.""" src_block_id = block_ids[cur_block_idx + num_accepted_tokens - 1] src_state = state[src_block_id] return MambaCopySpec( diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 04c54b16f259..ff0b45b2b01a 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -639,7 +639,15 @@ def get_mamba_state_shape_from_config( @classmethod def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]: - # TODO: add notes + """Calculate copy-function callables for each Mamba state. + + Returns: + A tuple of MambaStateCopyFunc callables that correspond, in order, + to the Mamba states produced by the model. Each callable accepts + (state, block_ids, cur_block_idx, num_accepted_tokens) and returns + a MambaCopySpec describing the memory-copy parameters for prefix + caching in align mode. + """ ... 
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 6ac9369b72dc..960af6eb492a 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -40,13 +40,13 @@ class BaseMambaAttentionMetadata: state_indices_tensor: torch.Tensor - # The following tensors are only used for prefix caching with all mode and + # The following tensors are only used for prefix caching in all mode and # are None if disabled block_idx_last_scheduled_token: torch.Tensor | None block_idx_first_scheduled_token_p: torch.Tensor | None block_idx_last_computed_token: torch.Tensor | None - # The following tensor is only used for prefix caching with align mode + # The following tensor is only used for prefix caching in align mode seq_lens: torch.Tensor # The following attributes are for triton implementation of causal_conv1d diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 63f3633dc56c..9918d6ffd2d9 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -747,8 +747,10 @@ def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None: self.mamba_cache_mode = kv_cache_spec.mamba_cache_mode self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks if self.mamba_cache_mode == "align": + # Mapping from request ID to the index of the block + # allocated in the previous step self.last_state_block_idx: dict[str, int] = {} - # the set of the requests that have been allocated blocks + # The set of the requests that have been allocated blocks self._allocated_block_reqs: set[str] = set() @classmethod @@ -803,8 +805,14 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> No assert isinstance(self.kv_cache_spec, MambaSpec) super().remove_skipped_blocks(request_id, num_computed_tokens) if self.mamba_cache_mode == "align": - # TODO: need comments + # 
`last_state_block_idx` refers to the block index allocated two steps ago. + # The block allocated in the previous step is used to copy Mamba states + # into the block allocated in the current step; the earlier block is + # no longer needed and should be freed here. last_state_block_idx = self.last_state_block_idx.get(request_id) + # Blocks allocated during prefill may be non-contiguous. Use + # `last_state_block_idx` to free the appropriate block and replace it + # with a null block. if ( last_state_block_idx is not None and last_state_block_idx diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 20bf442ecae5..a0a1ae224f2a 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -193,7 +193,6 @@ def postprocess_mamba( num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu # NOTE: can be optimized as this function always returns the same result mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config) - # TODO: vectorize this loop src_state_list: list[int] = [] dest_state_list: list[int] = [] num_elements_list: list[int] = [] From 8f347a6b1021f2f7c7fca249fe806ba1c4263092 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 18 Jan 2026 05:25:06 +0000 Subject: [PATCH 125/130] revert the prefill logic back to include draft tokens Signed-off-by: huanghaoyan.hhy --- vllm/v1/core/sched/scheduler.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 1466a0116498..5bc25956d675 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -332,23 +332,11 @@ def schedule(self) -> SchedulerOutput: req_index += 1 continue - # TODO: merge PR#30618 - num_tokens_to_compute = ( - request.num_tokens_with_spec + request.num_output_placeholders + num_new_tokens = ( + request.num_tokens_with_spec + + request.num_output_placeholders + - request.num_computed_tokens ) - # Ensure new 
tokens for a request in the prefill phase do not contain - # draft tokens, especially in the last prefill chunk. For a hybrid-model, - # extra draft tokens would corrupt the generated Mamba state. - # TODO: This logic does not yet handle resumed requests. - if ( - self.has_mamba_layers - and request.num_computed_tokens < request.num_prompt_tokens - ): - num_tokens_to_compute = min( - num_tokens_to_compute, request.num_prompt_tokens - ) - num_new_tokens = num_tokens_to_compute - request.num_computed_tokens - if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) From b29b026312c46f33b877b7035a12171e4dbac829 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 18 Jan 2026 05:27:00 +0000 Subject: [PATCH 126/130] temporarily disable spec decoding Signed-off-by: huanghaoyan.hhy --- vllm/model_executor/models/config.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index cd3e0fa12eae..1f75e873af2c 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -334,8 +334,12 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: assert vllm_config.scheduler_config.enable_chunked_prefill, ( "Chunked prefill is required for mamba cache mode 'align'." ) + assert not vllm_config.speculative_config, ( + "Mamba cache mode 'align' is currently not compatible " + "with speculative decoding." + ) logger.info( - "Warning: Prefix caching with Mamba cache '%s' " + "Warning: Prefix caching in Mamba cache '%s' " "mode is currently enabled. " "Its support for Mamba layers is experimental. 
" "Please report any issues you may observe.", From fd6e24fe8e5aac0be66d808c07897cb81cc02bc2 Mon Sep 17 00:00:00 2001 From: "huanghaoyan.hhy" Date: Sun, 18 Jan 2026 06:28:29 +0000 Subject: [PATCH 127/130] fix pre-commit Signed-off-by: huanghaoyan.hhy --- vllm/v1/attention/backends/linear_attn.py | 4 +++- vllm/v1/attention/backends/utils.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index e24bf934a757..02551e704766 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -10,9 +10,11 @@ AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, +) +from vllm.v1.attention.backends.utils import ( mamba_get_block_table_tensor, + split_decodes_and_prefills, ) -from vllm.v1.attention.backends.utils import split_decodes_and_prefills from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 27704e52184c..7b8757adc522 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -17,7 +17,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils.math_utils import cdiv -from vllm.v1.kv_cache_interface import MambaSpec +from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -824,7 +824,7 @@ def get_dcp_local_seq_lens( def mamba_get_block_table_tensor( block_table: torch.Tensor, seq_lens: torch.Tensor, - kv_cache_spec: MambaSpec, + kv_cache_spec: KVCacheSpec, mamba_cache_mode: str, ) -> torch.Tensor: """ From c200506002246d3beab1d8f4b2e8b81782d4fdec Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 19 Jan 2026 00:23:20 -0800 Subject: [PATCH 128/130] skip test Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 4 ++++ 1 file changed, 4 
insertions(+) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 554d94f9a803..83a233ce5e83 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -473,6 +473,10 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(mamba_utils, "do_mamba_copy_block", fake_copy_fn) +@pytest.skip( + reason="Skipping test_mamba_prefix_cache because it is based on spec " + "decode which is not allowed now." +) def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch): run_ref_mamba_state_in_subprocess() apply_patch(monkeypatch) From 738e7f450da48b9e613ee3de51c87215664df0b6 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 19 Jan 2026 01:01:31 -0800 Subject: [PATCH 129/130] skip test Signed-off-by: Chen Zhang --- tests/v1/e2e/test_mamba_prefix_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 83a233ce5e83..7fe95366b9d5 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -473,7 +473,7 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(mamba_utils, "do_mamba_copy_block", fake_copy_fn) -@pytest.skip( +@pytest.mark.skip( reason="Skipping test_mamba_prefix_cache because it is based on spec " "decode which is not allowed now." 
)
From 74c60f5b484e08dd5fd527818897e614da0b2067 Mon Sep 17 00:00:00 2001
From: "huanghaoyan.hhy" 
Date: Mon, 19 Jan 2026 15:23:54 +0000
Subject: [PATCH 130/130] update comments

Signed-off-by: huanghaoyan.hhy 
---
 vllm/config/cache.py                             |  3 ++-
 vllm/model_executor/layers/mamba/mamba_mixer.py  |  6 +++---
 vllm/model_executor/layers/mamba/mamba_mixer2.py | 12 ++++++------
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/vllm/config/cache.py b/vllm/config/cache.py
index 50853d405583..abf10e21d408 100644
--- a/vllm/config/cache.py
+++ b/vllm/config/cache.py
@@ -128,7 +128,8 @@ class CacheConfig:
     """The cache strategy for Mamba layers.
     - "none": set when prefix caching is disabled.
     - "all": cache the mamba state of all tokens at position i * block_size. This is
-        the default behavior when prefix caching is enabled.
+        the default behavior (for models that support it) when prefix caching is
+        enabled.
     - "align": only cache the mamba state of the last token of each scheduler step
         and when the token is at position i * block_size.
""" diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index d9f7b1c1ee57..134e1dfd6283 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -255,7 +255,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size - return_intermediate_states = self.cache_config.mamba_cache_mode == "all" + is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all" if attn_metadata is not None: assert isinstance(attn_metadata, dict) @@ -304,7 +304,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d - if return_intermediate_states: + if is_mamba_cache_all: block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( torch.split( attn_metadata.block_idx_last_computed_token, @@ -380,7 +380,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): ssm_outputs.append(scan_out_p) if has_decode: - if return_intermediate_states: + if is_mamba_cache_all: state_indices_tensor_d_input = state_indices_tensor_d.gather( 1, block_idx_last_computed_token_d.unsqueeze(1) ).squeeze(1) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 2270a966f42d..7af5e02c29d2 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -570,7 +570,7 @@ def conv_ssm_forward( assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size - return_intermediate_states = self.cache_config.mamba_cache_mode == "all" + is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all" if attn_metadata is not None: assert 
isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] @@ -622,7 +622,7 @@ def conv_ssm_forward( dim=0, ) - if return_intermediate_states: + if is_mamba_cache_all: # If prefix caching is enabled, retrieve the relevant variables # for prefill and decode block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( @@ -701,7 +701,7 @@ def conv_ssm_forward( initial_states = None if has_initial_states_p is not None and prep_initial_states: kernel_ssm_indices = state_indices_tensor_p - if return_intermediate_states: + if is_mamba_cache_all: kernel_ssm_indices = state_indices_tensor_p.gather( 1, block_idx_last_computed_token_p.unsqueeze(1) ).squeeze(1) @@ -729,14 +729,14 @@ def conv_ssm_forward( cu_chunk_seqlens=cu_chunk_seqlen_p, last_chunk_indices=last_chunk_indices_p, initial_states=initial_states, - return_intermediate_states=return_intermediate_states, + return_intermediate_states=is_mamba_cache_all, dt_softplus=True, dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim), state_dtype=ssm_state.dtype, ) - if return_intermediate_states: + if is_mamba_cache_all: # The chunk_stride is the number of chunks per mamba block # e.g., if mamba_block_size = 512 and chunk_size = 256, # then chunk_stride = 2 @@ -815,7 +815,7 @@ def conv_ssm_forward( # Process decode requests if has_decode: - if return_intermediate_states: + if is_mamba_cache_all: state_indices_tensor_d_input = state_indices_tensor_d.gather( 1, block_idx_last_computed_token_d.unsqueeze(1) ).squeeze(1)