@@ -2,10 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from abc import ABC, abstractmethod
-from contextlib import contextmanager
-from dataclasses import dataclass, fields
-from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
-                    Type, TypeVar)
+from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
 
 import torch
 
@@ -49,18 +46,13 @@ def get_impl_cls() -> Type["AttentionImpl"]:
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         raise NotImplementedError
 
-    @staticmethod
-    @abstractmethod
-    def get_state_cls() -> Type["AttentionState"]:
-        raise NotImplementedError
-
     @classmethod
     def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
         return cls.get_metadata_cls()(*args, **kwargs)
 
     @staticmethod
     @abstractmethod
-    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
+    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
         raise NotImplementedError
 
     @staticmethod
@@ -77,149 +69,18 @@ def get_kv_cache_shape(
     def get_kv_cache_stride_order() -> Tuple[int, ...]:
         raise NotImplementedError
 
-    @staticmethod
-    @abstractmethod
-    def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError
-
     @classmethod
     def full_cls_name(cls) -> tuple[str, str]:
         return (cls.__module__, cls.__qualname__)
 
 
-@dataclass
 class AttentionMetadata:
-    """Attention metadata for prefill and decode batched together."""
-    # Total number of prefill requests.
-    num_prefills: int
-    # Number of prefill tokens.
-    num_prefill_tokens: int
-    # Number of decode tokens. Note that it is equivalent to the number of
-    # decode requests.
-    num_decode_tokens: int
-    # (num_tokens,). The indices of the token slots that input tokens will be
-    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-    # in block 0, and 1st slot in block 1, respectively.
-    slot_mapping: torch.Tensor
-
-    # Enable/disable KV scales calculation. This is so that we can disable the
-    # calculation until after prefill and cuda graph capture.
-    enable_kv_scales_calculation: bool
-
-    @property
-    @abstractmethod
-    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
-        """Return the attention metadata that's required to run prefill
-        attention."""
-        pass
-
-    @property
-    @abstractmethod
-    def decode_metadata(self) -> Optional["AttentionMetadata"]:
-        """Return the attention metadata that's required to run decode
-        attention."""
-        pass
-
-    def asdict_zerocopy(self,
-                        skip_fields: Optional[Set[str]] = None
-                        ) -> Dict[str, Any]:
-        """Similar to dataclasses.asdict, but avoids deepcopying."""
-        if skip_fields is None:
-            skip_fields = set()
-        # Note that if we add dataclasses as fields, they will need
-        # similar handling.
-        return {
-            field.name: getattr(self, field.name)
-            for field in fields(self) if field.name not in skip_fields
-        }
+    pass
 
 
 T = TypeVar("T", bound=AttentionMetadata)
 
 
-class AttentionState(ABC, Generic[T]):
-    """Holds attention backend-specific objects reused during the
-    lifetime of the model runner."""
-
-    @abstractmethod
-    def __init__(self, runner: Any):
-        ...
-
-    @abstractmethod
-    @contextmanager
-    def graph_capture(self, max_batch_size: int):
-        """Context manager used when capturing CUDA graphs."""
-        yield
-
-    @abstractmethod
-    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
-        """Clone attention state to save in CUDA graph metadata."""
-        ...
-
-    @abstractmethod
-    def graph_capture_get_metadata_for_batch(
-            self,
-            batch_size: int,
-            is_encoder_decoder_model: bool = False) -> T:
-        """Get attention metadata for CUDA graph capture of batch_size."""
-        ...
-
-    @abstractmethod
-    def get_graph_input_buffers(
-            self,
-            attn_metadata: T,
-            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
-        """Get attention-specific input buffers for CUDA graph capture."""
-        ...
-
-    @abstractmethod
-    def prepare_graph_input_buffers(
-            self,
-            input_buffers: Dict[str, Any],
-            attn_metadata: T,
-            is_encoder_decoder_model: bool = False) -> None:
-        """In-place modify input buffers dict for CUDA graph replay."""
-        ...
-
-    @abstractmethod
-    def begin_forward(self, model_input) -> None:
-        """Prepare state for forward pass."""
-        ...
-
-
-class AttentionMetadataBuilder(ABC, Generic[T]):
-    """Abstract class for attention metadata builders."""
-
-    @abstractmethod
-    def __init__(self, input_builder) -> None:
-        """Create the builder, remember some configuration and parameters."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def prepare(self) -> None:
-        """Prepare for one batch."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def build(self, seq_lens: List[int], query_lens: List[int],
-              cuda_graph_pad_size: int, batch_size: int) -> T:
-        """Build attention metadata with on-device tensors."""
-        raise NotImplementedError
-
-
 class AttentionLayer(Protocol):
 
     _q_scale: torch.Tensor
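
Context for the interface this diff keeps: AttentionBackend.make_metadata() is a small factory that instantiates whatever class the backend's get_metadata_cls() returns, so each backend stays paired with its own metadata type while callers only go through the shared classmethod. Below is a minimal, self-contained sketch of that pattern; ToyAttentionBackend, ToyBackend, and ToyMetadata are hypothetical names used only for illustration and are not vLLM classes.

# Sketch of the get_metadata_cls()/make_metadata() factory pattern retained
# above. ToyAttentionBackend/ToyBackend/ToyMetadata are hypothetical names.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Type

import torch


class ToyAttentionBackend(ABC):

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["ToyMetadata"]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "ToyMetadata":
        # Instantiate whichever metadata class the concrete backend declares.
        return cls.get_metadata_cls()(*args, **kwargs)


@dataclass
class ToyMetadata:
    # (num_tokens,) flat slot indices into a paged KV cache.
    slot_mapping: torch.Tensor


class ToyBackend(ToyAttentionBackend):

    @staticmethod
    def get_metadata_cls() -> Type[ToyMetadata]:
        return ToyMetadata


# Calling the shared classmethod on the subclass yields a ToyMetadata.
meta = ToyBackend.make_metadata(slot_mapping=torch.tensor([35, 2, 17]))
assert isinstance(meta, ToyMetadata)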