6 changes: 1 addition & 5 deletions vllm/attention/__init__.py
@@ -2,9 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from vllm.attention.backends.abstract import (AttentionBackend,
-                                              AttentionMetadata,
-                                              AttentionMetadataBuilder,
-                                              AttentionState, AttentionType)
+                                              AttentionMetadata, AttentionType)
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 
@@ -13,7 +11,5 @@
     "AttentionBackend",
     "AttentionMetadata",
     "AttentionType",
-    "AttentionMetadataBuilder",
-    "AttentionState",
     "get_attn_backend",
 ]
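
With this change, vllm.attention no longer re-exports AttentionState or AttentionMetadataBuilder. A minimal sketch of the import surface that survives, assuming a checkout that includes this diff (the __all__ entries above are the authoritative list):

# Illustrative sketch: the names still exported from the package root.
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
                            AttentionType, get_attn_backend)

# The removed abstractions are gone from the package root, so an import like
# `from vllm.attention import AttentionState` now raises ImportError.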
145 changes: 3 additions & 142 deletions vllm/attention/backends/abstract.py
@@ -2,10 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from abc import ABC, abstractmethod
-from contextlib import contextmanager
-from dataclasses import dataclass, fields
-from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
-                    Type, TypeVar)
+from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
 
 import torch
@@ -49,18 +46,13 @@ def get_impl_cls() -> Type["AttentionImpl"]:
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         raise NotImplementedError
 
-    @staticmethod
-    @abstractmethod
-    def get_state_cls() -> Type["AttentionState"]:
-        raise NotImplementedError
-
     @classmethod
     def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
         return cls.get_metadata_cls()(*args, **kwargs)
 
     @staticmethod
     @abstractmethod
-    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
+    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
         raise NotImplementedError
 
     @staticmethod
@@ -77,149 +69,18 @@ def get_kv_cache_shape(
     def get_kv_cache_stride_order() -> Tuple[int, ...]:
         raise NotImplementedError
 
-    @staticmethod
-    @abstractmethod
-    def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError
-
     @classmethod
     def full_cls_name(cls) -> tuple[str, str]:
         return (cls.__module__, cls.__qualname__)
 
 
-@dataclass
 class AttentionMetadata:
-    """Attention metadata for prefill and decode batched together."""
-    # Total number of prefill requests.
-    num_prefills: int
-    # Number of prefill tokens.
-    num_prefill_tokens: int
-    # Number of decode tokens. Note that it is equivalent to the number of
-    # decode requests.
-    num_decode_tokens: int
-    # (num_tokens,). The indices of the token slots that input tokens will be
-    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-    # in block 0, and 1st slot in block 1, respectively.
-    slot_mapping: torch.Tensor
-
-    # Enable/disable KV scales calculation. This is so that we can disable the
-    # calculation until after prefill and cuda graph capture.
-    enable_kv_scales_calculation: bool
-
-    @property
-    @abstractmethod
-    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
-        """Return the attention metadata that's required to run prefill
-        attention."""
-        pass
-
-    @property
-    @abstractmethod
-    def decode_metadata(self) -> Optional["AttentionMetadata"]:
-        """Return the attention metadata that's required to run decode
-        attention."""
-        pass
-
-    def asdict_zerocopy(self,
-                        skip_fields: Optional[Set[str]] = None
-                        ) -> Dict[str, Any]:
-        """Similar to dataclasses.asdict, but avoids deepcopying."""
-        if skip_fields is None:
-            skip_fields = set()
-        # Note that if we add dataclasses as fields, they will need
-        # similar handling.
-        return {
-            field.name: getattr(self, field.name)
-            for field in fields(self) if field.name not in skip_fields
-        }
+    pass
 
 
 T = TypeVar("T", bound=AttentionMetadata)
-
-
-class AttentionState(ABC, Generic[T]):
-    """Holds attention backend-specific objects reused during the
-    lifetime of the model runner."""
-
-    @abstractmethod
-    def __init__(self, runner: Any):
-        ...
-
-    @abstractmethod
-    @contextmanager
-    def graph_capture(self, max_batch_size: int):
-        """Context manager used when capturing CUDA graphs."""
-        yield
-
-    @abstractmethod
-    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
-        """Clone attention state to save in CUDA graph metadata."""
-        ...
-
-    @abstractmethod
-    def graph_capture_get_metadata_for_batch(
-            self,
-            batch_size: int,
-            is_encoder_decoder_model: bool = False) -> T:
-        """Get attention metadata for CUDA graph capture of batch_size."""
-        ...
-
-    @abstractmethod
-    def get_graph_input_buffers(
-            self,
-            attn_metadata: T,
-            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
-        """Get attention-specific input buffers for CUDA graph capture."""
-        ...
-
-    @abstractmethod
-    def prepare_graph_input_buffers(
-            self,
-            input_buffers: Dict[str, Any],
-            attn_metadata: T,
-            is_encoder_decoder_model: bool = False) -> None:
-        """In-place modify input buffers dict for CUDA graph replay."""
-        ...
-
-    @abstractmethod
-    def begin_forward(self, model_input) -> None:
-        """Prepare state for forward pass."""
-        ...
-
-
-class AttentionMetadataBuilder(ABC, Generic[T]):
-    """Abstract class for attention metadata builders."""
-
-    @abstractmethod
-    def __init__(self, input_builder) -> None:
-        """Create the builder, remember some configuration and parameters."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def prepare(self) -> None:
-        """Prepare for one batch."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def build(self, seq_lens: List[int], query_lens: List[int],
-              cuda_graph_pad_size: int, batch_size: int) -> T:
-        """Build attention metadata with on-device tensors."""
-        raise NotImplementedError
 
 
 class AttentionLayer(Protocol):
 
     _q_scale: torch.Tensor
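
After this diff, AttentionMetadata is a plain empty base class, and a backend subclass no longer provides get_state_cls, swap_blocks, or copy_blocks. A minimal sketch of a stub against the slimmed interface, assuming only the abstract methods visible in this truncated diff (StubAttentionBackend and the KV-cache shape below are hypothetical; a real backend must also implement any abstract methods outside the shown hunks):

from typing import Tuple, Type

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionImpl,
                                              AttentionMetadata)


class StubAttentionBackend(AttentionBackend):
    """Hypothetical stub showing the remaining abstract surface."""

    @staticmethod
    def get_name() -> str:
        return "STUB"

    @staticmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError  # a real backend returns its impl class here

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata  # the base class is now a plain empty class

    @staticmethod
    def get_builder_cls():  # annotation commented out upstream, per this diff
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int,
                           head_size: int) -> Tuple[int, ...]:
        # Assumed parameter names: the signature is truncated in this view.
        # The leading 2 packs K and V caches together, as an example layout.
        return (2, num_blocks, block_size, num_kv_heads, head_size)


# Note: no get_state_cls, swap_blocks, or copy_blocks overrides are needed
# after this change; those hooks were deleted from AttentionBackend above.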