From 02118100a35e837a5f66c9b1f2edc5d45a4e797a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 23 Sep 2025 21:03:09 -0700 Subject: [PATCH 1/4] [V0 Deprecation] Remove unused classes in attention Signed-off-by: Woosuk Kwon --- vllm/attention/__init__.py | 6 +- vllm/attention/backends/abstract.py | 145 +------ vllm/attention/backends/utils.py | 547 +------------------------ vllm/v1/attention/backends/cpu_attn.py | 15 - vllm/v1/attention/backends/pallas.py | 5 - 5 files changed, 5 insertions(+), 713 deletions(-) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index dcb2aa68fbee..1b37bd1f6100 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -2,9 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionState, AttentionType) + AttentionMetadata, AttentionType) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -13,7 +11,5 @@ "AttentionBackend", "AttentionMetadata", "AttentionType", - "AttentionMetadataBuilder", - "AttentionState", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 1b392cd7c88d..0f51ef4b2e51 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,10 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from contextlib import contextmanager -from dataclasses import dataclass, fields -from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple, - Type, TypeVar) +from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar import torch @@ -49,18 +46,13 @@ def get_impl_cls() -> Type["AttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: raise NotImplementedError - @staticmethod - @abstractmethod - def get_state_cls() -> Type["AttentionState"]: - raise NotImplementedError - @classmethod def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": return cls.get_metadata_cls()(*args, **kwargs) @staticmethod @abstractmethod - def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]: raise NotImplementedError @staticmethod @@ -77,149 +69,18 @@ def get_kv_cache_shape( def get_kv_cache_stride_order() -> Tuple[int, ...]: raise NotImplementedError - @staticmethod - @abstractmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - raise NotImplementedError - - @staticmethod - @abstractmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - raise NotImplementedError - @classmethod def full_cls_name(cls) -> tuple[str, str]: return (cls.__module__, cls.__qualname__) -@dataclass class AttentionMetadata: - """Attention metadata for prefill and decode batched together.""" - # Total number of prefill requests. - num_prefills: int - # Number of prefill tokens. - num_prefill_tokens: int - # Number of decode tokens. Note that it is equivalent to the number of - # decode requests. - num_decode_tokens: int - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. 
E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor - - # Enable/disable KV scales calculation. This is so that we can disable the - # calculation until after prefill and cuda graph capture. - enable_kv_scales_calculation: bool - - @property - @abstractmethod - def prefill_metadata(self) -> Optional["AttentionMetadata"]: - """Return the attention metadata that's required to run prefill - attention.""" - pass - - @property - @abstractmethod - def decode_metadata(self) -> Optional["AttentionMetadata"]: - """Return the attention metadata that's required to run decode - attention.""" - pass - - def asdict_zerocopy(self, - skip_fields: Optional[Set[str]] = None - ) -> Dict[str, Any]: - """Similar to dataclasses.asdict, but avoids deepcopying.""" - if skip_fields is None: - skip_fields = set() - # Note that if we add dataclasses as fields, they will need - # similar handling. - return { - field.name: getattr(self, field.name) - for field in fields(self) if field.name not in skip_fields - } + pass T = TypeVar("T", bound=AttentionMetadata) -class AttentionState(ABC, Generic[T]): - """Holds attention backend-specific objects reused during the - lifetime of the model runner.""" - - @abstractmethod - def __init__(self, runner: Any): - ... - - @abstractmethod - @contextmanager - def graph_capture(self, max_batch_size: int): - """Context manager used when capturing CUDA graphs.""" - yield - - @abstractmethod - def graph_clone(self, batch_size: int) -> "AttentionState[T]": - """Clone attention state to save in CUDA graph metadata.""" - ... - - @abstractmethod - def graph_capture_get_metadata_for_batch( - self, - batch_size: int, - is_encoder_decoder_model: bool = False) -> T: - """Get attention metadata for CUDA graph capture of batch_size.""" - ... - - @abstractmethod - def get_graph_input_buffers( - self, - attn_metadata: T, - is_encoder_decoder_model: bool = False) -> Dict[str, Any]: - """Get attention-specific input buffers for CUDA graph capture.""" - ... - - @abstractmethod - def prepare_graph_input_buffers( - self, - input_buffers: Dict[str, Any], - attn_metadata: T, - is_encoder_decoder_model: bool = False) -> None: - """In-place modify input buffers dict for CUDA graph replay.""" - ... - - @abstractmethod - def begin_forward(self, model_input) -> None: - """Prepare state for forward pass.""" - ... 
- - -class AttentionMetadataBuilder(ABC, Generic[T]): - """Abstract class for attention metadata builders.""" - - @abstractmethod - def __init__(self, input_builder) -> None: - """Create the builder, remember some configuration and parameters.""" - raise NotImplementedError - - @abstractmethod - def prepare(self) -> None: - """Prepare for one batch.""" - raise NotImplementedError - - @abstractmethod - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int) -> T: - """Build attention metadata with on-device tensors.""" - raise NotImplementedError - - class AttentionLayer(Protocol): _q_scale: torch.Tensor diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 63ee8f50825c..86998ddbea0e 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,559 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention backend utils""" -from contextlib import contextmanager from dataclasses import dataclass -from itertools import accumulate -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Optional -import numpy as np -import torch - -from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, - AttentionState) -from vllm.attention.backends.abstract import AttentionType from vllm.config import ModelConfig from vllm.logger import init_logger -from vllm.utils import async_tensor_h2d, make_tensor_with_pad logger = init_logger(__name__) -PAD_SLOT_ID = -1 - -# Switch to numpy implementation of compute_slot_mapping -# if we have at least this many elements. Could be tuned further. -_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256 - - -def is_block_tables_empty(block_tables: Union[None, Dict]): - """ - Check if block_tables is None or a dictionary with all None values. - """ - if block_tables is None: - return True - return (isinstance(block_tables, dict) - and all(value is None for value in block_tables.values())) - - -def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, - context_len: int, sliding_window: int): - """ - Compute the start index of slot mapping. - """ - start_idx = 0 - if is_prompt and sliding_window is not None: - start_idx = max(0, query_len - sliding_window) - return start_idx - - -def _compute_slot_mapping_python(slot_mapping: List[int], - block_table: List[int], range_start: int, - range_end: int, block_size: int): - for i in range(range_start, range_end): - block_number = block_table[i // block_size] - block_offset = i % block_size - slot = block_number * block_size + block_offset - slot_mapping.append(slot) - - -def _compute_slot_mapping_numpy(slot_mapping: List[int], - block_table: List[int], range_start: int, - range_end: int, block_size: int): - block_table_array = np.array(block_table) - idx = np.arange(range_start, range_end) - block_offset = idx % block_size - idx //= block_size - seq_slot_mapping_array = block_table_array[idx] - seq_slot_mapping_array *= block_size - seq_slot_mapping_array += block_offset - slot_mapping.extend(seq_slot_mapping_array) - - -def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], - seq_id: int, seq_len: int, context_len: int, - start_idx: int, block_size: int, - block_tables: Dict[int, List[int]]): - """ - Compute slot mapping. - """ - if is_profile_run: - # During memory profiling, the block tables are not - # initialized yet. In this case, we just use a dummy - # slot mapping. 
- # In embeddings, the block tables are {seq_id: None}. - slot_mapping.extend([PAD_SLOT_ID] * seq_len) - return - - # Mask the [0, start_idx) tokens of the prompt with - # PAD_SLOT_ID, where start_idx is max(0, seq_len - - # sliding_window). For example, if the prompt len is 10, - # sliding window is 8, and block size is 4, the first two - # tokens are masked and the slot mapping will be - # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - padding_mask_len = max(0, start_idx - context_len) - slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len) - - range_start = max(start_idx, context_len) - range_end = seq_len - numel = range_end - range_start - block_table = block_tables[seq_id] - - # numpy implementation will be faster than python if we have - # many elements, otherwise it will be slower. - if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL: - _compute_slot_mapping_python(slot_mapping, block_table, range_start, - range_end, block_size) - else: - _compute_slot_mapping_numpy(slot_mapping, block_table, range_start, - range_end, block_size) - - -TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') - - -class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): - - _metadata_cls: Type[TAttentionMetadata] - - def __init__(self, input_builder): - self.input_builder = input_builder - self.runner = input_builder.runner - - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool): - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - if is_prompt: - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if inter_data.prefix_cache_hit: - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. 
- is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. - """ - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - input_block_tables = self.runner.graph_block_tables[:batch_size] - for i, block_table in enumerate(self.block_tables): - if block_table: - input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.from_numpy(input_block_tables).to( - device, non_blocking=True) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, "query_lens: {}".format(query_lens) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - - return self._metadata_cls( # type: ignore - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - enable_kv_scales_calculation=True, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) - - -class CommonAttentionState(AttentionState): - - def __init__(self, runner): - self.runner = runner - self._is_graph_capturing = False - - @contextmanager - def graph_capture(self, max_batch_size: int): - - self._is_graph_capturing = True - - self._graph_slot_mapping = torch.full((max_batch_size, ), - PAD_SLOT_ID, - dtype=torch.long, - device=self.runner.device) - self._graph_seq_lens = 
torch.ones(max_batch_size, - dtype=torch.int32, - device=self.runner.device) - self._graph_block_tables = torch.from_numpy( - self.runner.graph_block_tables).to(device=self.runner.device) - - yield - - self._is_graph_capturing = False - del self._graph_slot_mapping - del self._graph_seq_lens - del self._graph_block_tables - - def graph_clone(self, batch_size: int) -> "CommonAttentionState": - assert self._is_graph_capturing - return self.__class__(self.runner) - - def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): - assert self._is_graph_capturing - attn_metadata = self.runner.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=self._graph_slot_mapping[:batch_size], - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=self._graph_seq_lens[:batch_size], - max_query_len=1, - max_decode_query_len=1, - max_prefill_seq_len=0, - max_decode_seq_len=self.runner.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self._graph_block_tables[:batch_size], - use_cuda_graph=True, - ) - if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers and - # Flash Attention backend. Assert the same. - assert self.runner.attn_backend.get_name() in \ - ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or " \ - f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'" - self._update_captured_metadata_for_enc_dec_model( - batch_size=batch_size, attn_metadata=attn_metadata) - - return attn_metadata - - def get_graph_input_buffers( - self, - attn_metadata, - is_encoder_decoder_model: bool = False) -> Dict[str, Any]: - input_buffers = { - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - } - if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers and - # Flash Attention backend. Assert the same. - assert self.runner.attn_backend.get_name() in \ - ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or " \ - f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'" - self._add_additional_input_buffers_for_enc_dec_model( - attn_metadata=attn_metadata, input_buffers=input_buffers) - return input_buffers - - def prepare_graph_input_buffers( - self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False) -> None: - input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) - if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers and - # Flash Attention backend. Assert the same. - assert self.runner.attn_backend.get_name() in\ - ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or "\ - f"'FLASH_ATTN', but "\ - f"got '{self.runner.attn_backend.get_name()}'" - self._prepare_input_buffers_for_enc_dec_model( - attn_metadata, input_buffers) - - def begin_forward(self, model_input) -> None: - return - - def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, - attn_metadata): - """ - Updates the attention metadata parameters for CUDA graph capture in an - encoder-decoder model. 
- - This method modifies attention-related tensors and metadata required - for CUDA graph capture in encoder-decoder models. Specifically, it - updates the cross-attention and encoder sequence tensors in the - AttentionMetadata object. - """ - # During decode phase the cross_slot_mapping will be empty. Hence set - # an empty tensor for CUDA Graph capture. - attn_metadata.cross_slot_mapping = torch.tensor( - [], dtype=torch.int).cuda() - attn_metadata.cross_block_tables = torch.full( - (batch_size, self.runner.get_max_block_per_batch()), - 1, - dtype=torch.int).cuda() - attn_metadata.encoder_seq_lens = torch.full((batch_size, ), - 1, - dtype=torch.int).cuda() - attn_metadata.encoder_seq_lens_tensor = torch.full( - (batch_size, ), 1, dtype=torch.int).cuda() - attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture - attn_metadata.num_encoder_tokens = 0 - - def _add_additional_input_buffers_for_enc_dec_model( - self, attn_metadata, input_buffers: Dict[str, Any]): - """ - Saves additional input buffers specific to the encoder-decoder model - from the attention metadata. - - This method extracts and stores encoder-decoder related input buffers - from the `attn_metadata` into the `input_buffers` dictionary. The - buffers include encoder sequence lengths, cross-slot mappings, and - cross-block tables, which are essential for the encoder-decoder model - during CUDA graph replay. - """ - input_buffers["encoder_seq_lens_tensor"] = ( - attn_metadata.decode_metadata.encoder_seq_lens_tensor) - input_buffers["cross_slot_mapping"] = ( - attn_metadata.decode_metadata.cross_slot_mapping) - input_buffers["cross_block_tables"] = ( - attn_metadata.decode_metadata.cross_block_tables) - - def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, - input_buffers: Dict[str, - Any]): - """ - Populates input buffers with data from the encoder-decoder model's - attention metadata. - - This method fills the input buffers with encoder-decoder specific - tensors. It copies data from the `attn_metadata` and keyword arguments - (`kwargs`) into corresponding buffers in the `input_buffers` dictionary. - The copied data includes attention-related metadata as well as input - IDs and positional information for the encoder. - """ - input_buffers["encoder_seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.encoder_seq_lens_tensor, - non_blocking=True) - input_buffers["cross_slot_mapping"].copy_( - attn_metadata.decode_metadata.cross_slot_mapping, - non_blocking=True) - input_buffers["cross_block_tables"].copy_( - attn_metadata.decode_metadata.cross_block_tables, - non_blocking=True) - - -def is_all_encoder_attn_metadata_set(attn_metadata): - ''' - All attention metadata required for encoder attention is set. - ''' - return ((attn_metadata.encoder_seq_lens is not None) - and (attn_metadata.encoder_seq_lens_tensor is not None) - and (attn_metadata.max_encoder_seq_len is not None)) - - -def is_all_cross_attn_metadata_set(attn_metadata): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. 
- ''' - return (attn_metadata.is_all_encoder_attn_metadata_set - and (attn_metadata.cross_slot_mapping is not None) - and (attn_metadata.cross_block_tables is not None)) - - -def get_seq_len_block_table_args( - attn_metadata, - is_prompt: bool, - attn_type: str, -) -> tuple: - ''' - The particular choice of sequence-length- and block-table-related - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. - - Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths & - cross-attn block-tables fields - Encoder attn -> select encoder sequence lengths fields & no block tables - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention op - * is_prompt: True if prefill, False otherwise - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - - * Appropriate sequence-lengths tensor - * Appropriate max sequence-length scalar - * Appropriate block tables (or None) - ''' - - if attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_lens_tensor, max_seq_len, - attn_metadata.block_tables) - elif attn_type == AttentionType.ENCODER_DECODER: - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, - attn_metadata.cross_block_tables) - elif attn_type == AttentionType.ENCODER: - # No block tables associated with encoder attention - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, None) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -def get_num_prefill_decode_query_kv_tokens( - attn_metadata, - attn_type: str, -) -> Tuple[int, int, int]: - """ - Calculate the number of prefill and decode tokens for query, key/value - based on the attention metadata and the specified attention type. - - Args: - attn_metadata (AttentionMetadata): Attention Metadata object. - attn_type (AttentionType): The type of attention being used. - Returns: - Tuple[int, int, int]: A tuple containing three integers: - - The number of prefill query tokens. - - The number of prefill key/value tokens. - - The number of decode query tokens. - - Raises: - AssertionError: If the number of encoder tokens in `attn_metadata` - is `None` when required for the calculations. - """ - num_prefill_query_tokens = 0 - num_decode_query_tokens = 0 - num_prefill_kv_tokens = 0 - if attn_type == AttentionType.ENCODER: - # Encoder attention is only invoked during prefill phase. - # The same input servers a both query and key. - assert attn_metadata.num_encoder_tokens is not None - num_prefill_query_tokens = attn_metadata.num_encoder_tokens - num_prefill_kv_tokens = attn_metadata.num_encoder_tokens - num_decode_query_tokens = 0 - elif attn_type == AttentionType.ENCODER_DECODER: - assert attn_metadata.num_encoder_tokens is not None - num_prefill_query_tokens = attn_metadata.num_prefill_tokens - # The key is the encoder/cross-attention. 
- num_prefill_kv_tokens = attn_metadata.num_encoder_tokens - num_decode_query_tokens = attn_metadata.num_decode_tokens - else: # attn_type == AttentionType.DECODER or - # attn_type == AttentionType.ENCODER_ONLY - num_prefill_query_tokens = attn_metadata.num_prefill_tokens - num_prefill_kv_tokens = attn_metadata.num_prefill_tokens - num_decode_query_tokens = attn_metadata.num_decode_tokens - - return (num_prefill_query_tokens, num_prefill_kv_tokens, - num_decode_query_tokens) - @dataclass class MLADims: diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 7e485fea2689..72f26c23b60b 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -11,7 +11,6 @@ AttentionLayer, AttentionMetadata, AttentionType, is_quantized_kv_cache) -from vllm.attention.backends.utils import CommonAttentionState from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, @@ -65,10 +64,6 @@ def get_impl_cls() -> type["TorchSDPABackendImpl"]: def get_metadata_cls() -> type["AttentionMetadata"]: return TorchSDPAMetadata - @staticmethod - def get_state_cls() -> type["CommonAttentionState"]: - return CommonAttentionState - @staticmethod def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]: return TorchSDPAMetadataBuilderV1 @@ -835,16 +830,6 @@ def forward_decode( blocksparse_head_sliding_step, ) - @staticmethod - def copy_blocks( - kv_caches: list[torch.Tensor], - src_to_dists: torch.Tensor, - *args, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - ops.copy_blocks(key_caches, value_caches, src_to_dists) - class _IPEXPagedAttention(_PagedAttention): diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 26f9abf13d0e..4ae0634e082a 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -8,7 +8,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) -from vllm.attention.backends.utils import CommonAttentionState from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv, next_power_of_2 @@ -97,10 +96,6 @@ def get_impl_cls() -> type["PallasAttentionBackendImpl"]: def get_metadata_cls() -> type["PallasMetadata"]: return PallasMetadata - @staticmethod - def get_state_cls() -> type["CommonAttentionState"]: - return CommonAttentionState - @staticmethod def get_kv_cache_shape( num_blocks: int, From a1512f9bd3585706a69fc75a03046dee1f163f45 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 24 Sep 2025 01:54:29 -0700 Subject: [PATCH 2/4] pad_slot_id Signed-off-by: Woosuk Kwon --- vllm/attention/backends/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 86998ddbea0e..6b8d97be7050 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -9,6 +9,8 @@ logger = init_logger(__name__) +PAD_SLOT_ID = -1 + @dataclass class MLADims: From b43e7719e8118a59dc08c9d54151c72e4c42d2d9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 24 Sep 2025 02:44:19 -0700 Subject: [PATCH 3/4] fix eagle Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 
5cae7df70470..8a0489fe20d8 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadataBuilder from vllm.attention.layer import Attention from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) @@ -25,7 +24,8 @@ from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, TreeAttentionMetadataBuilder) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata From 46829c8b45ba6fff65d2777d2f7ff6cf1d1675e4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 24 Sep 2025 20:19:34 +0000 Subject: [PATCH 4/4] precommit Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 8a0489fe20d8..b30e4dab956a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -184,8 +184,9 @@ def propose( builder = (self._get_attention_metadata_builder() if self.attn_metadata_builder is None else self.attn_metadata_builder) - attn_metadata = builder.build_for_drafting( - common_attn_metadata=common_attn_metadata, draft_index=0) + attn_metadata = builder.build_for_drafting( # type: ignore + common_attn_metadata=common_attn_metadata, + draft_index=0) # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. @@ -319,7 +320,7 @@ def propose( exceeds_max_model_len, PADDING_SLOT_ID) # Rebuild attention metadata - attn_metadata = builder.build_for_drafting( + attn_metadata = builder.build_for_drafting( # type: ignore common_attn_metadata=common_attn_metadata, draft_index=token_index + 1) for layer_name in self.attn_layer_names:
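
The slot-mapping arithmetic described in the comments removed in PATCH 1/4 (slot = block_number * block_size + block_offset, with PAD_SLOT_ID masking for sliding-window prompts) is easy to lose track of once the helpers are gone. Below is a minimal, dependency-free Python sketch of that logic; it is not part of the patch, and the function name slot_mapping_for_seq and the block table [0, 1, 0] are illustrative assumptions rather than vLLM APIs. It reproduces the two worked examples from the removed comments.

# Minimal sketch of the slot-mapping logic removed in PATCH 1/4. It mirrors
# compute_slot_mapping_start_idx() and compute_slot_mapping(): each token
# position i is written to slot block_table[i // block_size] * block_size +
# (i % block_size), and prompt tokens that fall outside the sliding window
# are masked with PAD_SLOT_ID instead of being written to the KV cache.
from typing import List, Optional

PAD_SLOT_ID = -1


def slot_mapping_for_seq(block_table: List[int], seq_len: int,
                         context_len: int, block_size: int,
                         sliding_window: Optional[int],
                         is_prompt: bool) -> List[int]:
    query_len = seq_len - context_len
    # Start index of the tokens that are actually written to the cache.
    start_idx = 0
    if is_prompt and sliding_window is not None:
        start_idx = max(0, query_len - sliding_window)

    # Mask positions [context_len, start_idx) with PAD_SLOT_ID, then map the
    # remaining positions to their flat slot indices.
    slot_mapping = [PAD_SLOT_ID] * max(0, start_idx - context_len)
    for i in range(max(start_idx, context_len), seq_len):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot_mapping.append(block_number * block_size + block_offset)
    return slot_mapping


if __name__ == "__main__":
    # Sliding-window example from the removed compute_slot_mapping() comment:
    # prompt length 10, window 8, block size 4. Assuming the sequence cycles
    # over two physical blocks (block table [0, 1, 0]), the first two tokens
    # are masked.
    assert slot_mapping_for_seq([0, 1, 0], 10, 0, 4, 8, True) == \
        [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]

    # Decode-side example from the removed AttentionMetadata comment: with
    # block size 16, slots [35, 2, 17] decode to (block, offset) pairs.
    print([divmod(slot, 16) for slot in [35, 2, 17]])  # [(2, 3), (0, 2), (1, 1)]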
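
The removed get_num_prefill_decode_query_kv_tokens() docstring also explains how query and key/value token counts are selected per attention type; the standalone sketch below condenses that branching. The function name, the string attention-type tags, and the sample counts are illustrative stand-ins rather than vLLM definitions.

# Condensed sketch of the branching performed by the removed
# get_num_prefill_decode_query_kv_tokens() helper. The string tags stand in
# for the AttentionType constants, and plain ints replace the
# AttentionMetadata object; the counts in __main__ are made up.
from typing import Tuple


def split_query_kv_token_counts(
        attn_type: str, num_prefill_tokens: int, num_decode_tokens: int,
        num_encoder_tokens: int) -> Tuple[int, int, int]:
    """Return (prefill query tokens, prefill KV tokens, decode query tokens)."""
    if attn_type == "encoder":
        # Encoder attention runs only at prefill; the same encoder input
        # serves as both query and key/value.
        return num_encoder_tokens, num_encoder_tokens, 0
    if attn_type == "encoder_decoder":
        # Cross-attention: queries come from the decoder side, keys/values
        # from the encoder output.
        return num_prefill_tokens, num_encoder_tokens, num_decode_tokens
    # Decoder self-attention (and encoder-only models).
    return num_prefill_tokens, num_prefill_tokens, num_decode_tokens


if __name__ == "__main__":
    # A hypothetical batch with 8 prefill, 3 decode and 5 encoder tokens.
    print(split_query_kv_token_counts("decoder", 8, 3, 5))          # (8, 8, 3)
    print(split_query_kv_token_counts("encoder", 8, 3, 5))          # (5, 5, 0)
    print(split_query_kv_token_counts("encoder_decoder", 8, 3, 5))  # (8, 5, 3)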