
Commit e6750d0

Authored by Woosuk Kwon
[V0 Deprecation] Remove unused classes in attention (#25541)
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 8c85305 commit e6750d0

File tree

6 files changed (+11, -716 lines)


vllm/attention/__init__.py

Lines changed: 1 addition & 5 deletions
@@ -2,9 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from vllm.attention.backends.abstract import (AttentionBackend,
-                                              AttentionMetadata,
-                                              AttentionMetadataBuilder,
-                                              AttentionState, AttentionType)
+                                              AttentionMetadata, AttentionType)
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 
@@ -13,7 +11,5 @@
     "AttentionBackend",
     "AttentionMetadata",
     "AttentionType",
-    "AttentionMetadataBuilder",
-    "AttentionState",
     "get_attn_backend",
 ]
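
For orientation, a minimal usage sketch (not part of the commit) of the import surface this file keeps after the change; the downstream call site below is hypothetical:

# Hypothetical downstream call site, assuming only the re-exports kept in
# __all__ above.
from vllm.attention import (AttentionBackend, AttentionMetadata,
                            AttentionType, get_attn_backend)

# Code that previously relied on the removed re-exports, e.g.
#     from vllm.attention import AttentionState, AttentionMetadataBuilder
# now raises ImportError, since those classes are deleted outright in
# vllm/attention/backends/abstract.py below.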

vllm/attention/backends/abstract.py

Lines changed: 3 additions & 142 deletions
@@ -2,10 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from abc import ABC, abstractmethod
-from contextlib import contextmanager
-from dataclasses import dataclass, fields
-from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
-                    Type, TypeVar)
+from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
 
 import torch
 
@@ -49,18 +46,13 @@ def get_impl_cls() -> Type["AttentionImpl"]:
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         raise NotImplementedError
 
-    @staticmethod
-    @abstractmethod
-    def get_state_cls() -> Type["AttentionState"]:
-        raise NotImplementedError
-
     @classmethod
     def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
         return cls.get_metadata_cls()(*args, **kwargs)
 
     @staticmethod
     @abstractmethod
-    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
+    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
         raise NotImplementedError
 
     @staticmethod
@@ -77,149 +69,18 @@ def get_kv_cache_shape(
     def get_kv_cache_stride_order() -> Tuple[int, ...]:
         raise NotImplementedError
 
-    @staticmethod
-    @abstractmethod
-    def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError
-
     @classmethod
     def full_cls_name(cls) -> tuple[str, str]:
         return (cls.__module__, cls.__qualname__)
 
 
-@dataclass
 class AttentionMetadata:
-    """Attention metadata for prefill and decode batched together."""
-    # Total number of prefill requests.
-    num_prefills: int
-    # Number of prefill tokens.
-    num_prefill_tokens: int
-    # Number of decode tokens. Note that it is equivalent to the number of
-    # decode requests.
-    num_decode_tokens: int
-    # (num_tokens,). The indices of the token slots that input tokens will be
-    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-    # in block 0, and 1st slot in block 1, respectively.
-    slot_mapping: torch.Tensor
-
-    # Enable/disable KV scales calculation. This is so that we can disable the
-    # calculation until after prefill and cuda graph capture.
-    enable_kv_scales_calculation: bool
-
-    @property
-    @abstractmethod
-    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
-        """Return the attention metadata that's required to run prefill
-        attention."""
-        pass
-
-    @property
-    @abstractmethod
-    def decode_metadata(self) -> Optional["AttentionMetadata"]:
-        """Return the attention metadata that's required to run decode
-        attention."""
-        pass
-
-    def asdict_zerocopy(self,
-                        skip_fields: Optional[Set[str]] = None
-                        ) -> Dict[str, Any]:
-        """Similar to dataclasses.asdict, but avoids deepcopying."""
-        if skip_fields is None:
-            skip_fields = set()
-        # Note that if we add dataclasses as fields, they will need
-        # similar handling.
-        return {
-            field.name: getattr(self, field.name)
-            for field in fields(self) if field.name not in skip_fields
-        }
+    pass
 
 
 T = TypeVar("T", bound=AttentionMetadata)
 
 
-class AttentionState(ABC, Generic[T]):
-    """Holds attention backend-specific objects reused during the
-    lifetime of the model runner."""
-
-    @abstractmethod
-    def __init__(self, runner: Any):
-        ...
-
-    @abstractmethod
-    @contextmanager
-    def graph_capture(self, max_batch_size: int):
-        """Context manager used when capturing CUDA graphs."""
-        yield
-
-    @abstractmethod
-    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
-        """Clone attention state to save in CUDA graph metadata."""
-        ...
-
-    @abstractmethod
-    def graph_capture_get_metadata_for_batch(
-            self,
-            batch_size: int,
-            is_encoder_decoder_model: bool = False) -> T:
-        """Get attention metadata for CUDA graph capture of batch_size."""
-        ...
-
-    @abstractmethod
-    def get_graph_input_buffers(
-            self,
-            attn_metadata: T,
-            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
-        """Get attention-specific input buffers for CUDA graph capture."""
-        ...
-
-    @abstractmethod
-    def prepare_graph_input_buffers(
-            self,
-            input_buffers: Dict[str, Any],
-            attn_metadata: T,
-            is_encoder_decoder_model: bool = False) -> None:
-        """In-place modify input buffers dict for CUDA graph replay."""
-        ...
-
-    @abstractmethod
-    def begin_forward(self, model_input) -> None:
-        """Prepare state for forward pass."""
-        ...
-
-
-class AttentionMetadataBuilder(ABC, Generic[T]):
-    """Abstract class for attention metadata builders."""
-
-    @abstractmethod
-    def __init__(self, input_builder) -> None:
-        """Create the builder, remember some configuration and parameters."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def prepare(self) -> None:
-        """Prepare for one batch."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def build(self, seq_lens: List[int], query_lens: List[int],
-              cuda_graph_pad_size: int, batch_size: int) -> T:
-        """Build attention metadata with on-device tensors."""
-        raise NotImplementedError
-
-
 class AttentionLayer(Protocol):
 
     _q_scale: torch.Tensor
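
Reassembling the context and + lines above gives a rough picture of the abstract interface that remains after this commit. The sketch below is an approximation, not the full file: the enclosing class is assumed to be AttentionBackend(ABC), and unrelated abstract methods (e.g. get_impl_cls and get_kv_cache_shape, visible only in the hunk headers) are elided.

# Sketch of vllm/attention/backends/abstract.py after this commit.
# Only the parts touched by or adjacent to the diff are reproduced.
from abc import ABC, abstractmethod
from typing import Protocol, Type, TypeVar

import torch


class AttentionBackend(ABC):
    # Other abstract methods of the backend interface are elided here.

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        return (cls.__module__, cls.__qualname__)


class AttentionMetadata:
    # The V0 prefill/decode bookkeeping fields and helpers are removed;
    # this is now just a placeholder base type for the TypeVar below.
    pass


T = TypeVar("T", bound=AttentionMetadata)


class AttentionLayer(Protocol):
    # Structural type for attention layers; only the first attribute visible
    # in the diff context is reproduced here.
    _q_scale: torch.Tensor

Note that AttentionMetadata is kept only as an empty placeholder so that the T = TypeVar(..., bound=AttentionMetadata) bound and the re-export in vllm/attention/__init__.py stay valid while the V0-only AttentionState and AttentionMetadataBuilder machinery disappears.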
