 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with AiterFlashAttention."""
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import Any, Optional

 import torch

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     make_local_attention_virtual_batches)
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable
-
-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 if current_platform.is_rocm():
     import aiter
@@ -172,56 +167,49 @@ def flash_attn_varlen_func_fake(

 class AiterFlashAttentionMetadataBuilder:

-    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
-        model_config = runner.model_config
-
-        self.runner = runner
-        self.num_heads_q = model_config.get_num_attention_heads(
-            runner.parallel_config)
-        self.num_heads_kv = model_config.get_num_kv_heads(
-            runner.parallel_config)
-        self.headdim = model_config.get_head_size()
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.parallel_config = vllm_config.parallel_config
+        self.cache_config = vllm_config.cache_config
+        self.device = device
+
+        self.num_heads_q = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        self.num_heads_kv = self.model_config.get_num_kv_heads(
+            self.parallel_config)
+        self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
-        self.block_table = block_table

         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None

-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
+    def reorder_batch(self, input_batch, scheduler_output) -> bool:
         return False

     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> 'AiterFlashAttentionMetadata':

-        num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
-        total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
+        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max().item())
+        total_tokens = int(common_attn_metadata.seq_lens_cpu.sum().item())
         query_start_loc = common_attn_metadata.query_start_loc
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
-        block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-
-        block_table.slot_mapping[:num_actual_tokens].copy_(
-            block_table.slot_mapping_cpu[:num_actual_tokens],
-            non_blocking=True)
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
-
-        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        block_table_tensor = common_attn_metadata.block_table_tensor
+        slot_mapping = common_attn_metadata.slot_mapping

         cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                                   dtype=torch.int32,
-                                  device="cuda")
+                                  device=self.device)
         torch.cumsum(seq_lens,
                      dim=0,
                      dtype=cu_seq_lens.dtype,
@@ -233,21 +221,21 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,

         # for local attention
         local_attn_metadata = None
-        if self.runner.attention_chunk_size is not None:
+        if self.model_config.attention_chunk_size is not None:
             seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                 virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.runner.attention_chunk_size,
-                    self.runner.query_start_loc_np[:num_reqs + 1],
-                    self.runner.seq_lens_np[:num_reqs],
+                    self.model_config.attention_chunk_size,
+                    query_start_loc_cpu.numpy(),
+                    seq_lens_cpu.numpy(),
                     block_table_tensor,
                     self.block_size,
                 )
             local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.runner.device, non_blocking=True)
+                self.device, non_blocking=True)
             local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.runner.device, non_blocking=True)
-            local_max_query_len = int(seqlens_q_local_np.max())
-            local_max_seq_len = int(virt_k_seqlens_np.max())
+                self.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max().item()
+            local_max_seq_len = virt_k_seqlens_np.max().item()
             local_scheduler_metadata = schedule(
                 batch_size=local_query_start_loc.shape[0] - 1,
                 cu_query_lens=local_query_start_loc,
@@ -258,12 +246,11 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,

             local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
                                             dtype=torch.int32,
-                                            device=self.runner.device)
+                                            device=self.device)
             local_cu_seq_lens[1:] = torch.cumsum(
-                torch.from_numpy(virt_k_seqlens_np).to(
-                    device=self.runner.device,
-                    dtype=torch.int32,
-                    non_blocking=True),
+                torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
+                                                       dtype=torch.int32,
+                                                       non_blocking=True),
                 dim=0)


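For context, a minimal usage sketch of the rewritten builder, based only on the signature visible in this diff. The make_builder helper, the explicit torch.device construction, and the import path are illustrative assumptions, not part of the commit; only the keyword arguments kv_cache_spec, vllm_config, and device come from the new __init__:

# Illustrative sketch (assumed wiring, not from this commit): constructing the
# builder with its new signature. The import path below is assumed to be the
# ROCm AITER flash-attention backend module that this diff modifies.
import torch

from vllm.config import VllmConfig
from vllm.v1.attention.backends.rocm_aiter_fa import (
    AiterFlashAttentionMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec


def make_builder(kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                 device_index: int = 0) -> AiterFlashAttentionMetadataBuilder:
    # On ROCm, PyTorch exposes the HIP device through the "cuda" device type.
    device = torch.device(f"cuda:{device_index}")
    # The whole VllmConfig plus an explicit device replace the old
    # GPUModelRunner/BlockTable pair taken by the previous constructor.
    return AiterFlashAttentionMetadataBuilder(kv_cache_spec=kv_cache_spec,
                                              vllm_config=vllm_config,
                                              device=device)

With this change the builder no longer depends on GPUModelRunner: build() reads seq_lens_cpu, query_start_loc_cpu, block_table_tensor, and slot_mapping directly from CommonAttentionMetadata instead of runner-owned buffers.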