from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
+from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)
+from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
@@ -273,20 +275,35 @@ def make_local_attention_virtual_batches(
        block_table_local


+def _get_sliding_window_configs(
+        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
+    """Get the set of all sliding window configs used in the model."""
+    sliding_window_configs: set[Optional[tuple[int, int]]] = set()
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    for layer in layers.values():
+        assert isinstance(layer.impl, FlashAttentionImpl)
+        sliding_window_configs.add(layer.impl.sliding_window)
+    return sliding_window_configs
+
+
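
A minimal sketch (stand-in helper, not part of the diff; resolve_aot_window is a hypothetical name) of the three outcomes the builder derives from this helper's result further down: a single concrete window is usable as-is, an all-None set means no window, and mixed windows rule out AOT scheduling.

    from typing import Optional

    def resolve_aot_window(
            configs: set[Optional[tuple[int, int]]]
    ) -> tuple[bool, tuple[int, int]]:
        # Mirrors the build()-time logic below: (aot_schedule_ok, window).
        aot_schedule, window = True, (-1, -1)  # (-1, -1) means "no window"
        if len(configs) == 1:
            only = next(iter(configs))
            if only is not None:
                window = only  # one concrete window shared by all layers
        elif len(configs) > 1:
            aot_schedule = False  # mixed windows: AOT scheduling is off
        return aot_schedule, window

    assert resolve_aot_window({None}) == (True, (-1, -1))
    assert resolve_aot_window({(4096, 0)}) == (True, (4096, 0))
    assert resolve_aot_window({None, (4096, 0)}) == (False, (-1, -1))
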
class FlashAttentionMetadataBuilder:

    def __init__(self, runner: "GPUModelRunner"):
        model_config = runner.model_config

        self.runner = runner
-        self.aot_schedule = (get_flash_attn_version() == 3)
        self.num_heads_q = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(
            runner.parallel_config)
        self.headdim = model_config.get_head_size()
        self.page_size = self.runner.block_size

+        self.aot_schedule = (get_flash_attn_version() == 3)
+        # Sliding window size to be used with the AOT scheduler will be
+        # populated on the first build() call.
+        self.aot_sliding_window: Optional[tuple[int, int]] = None
+
    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        return False
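
The new attribute is deliberately left as None here and resolved lazily: the Attention layers it depends on are only registered during model construction, after this builder's __init__ has run. A tiny sketch of that pattern (LazyBuilder and _resolve_window are illustrative names, not vLLM API):

    from typing import Optional

    class LazyBuilder:
        def __init__(self):
            # Cannot be computed yet; the layers do not exist at this point.
            self.window: Optional[tuple[int, int]] = None

        def build(self) -> tuple[int, int]:
            if self.window is None:      # runs exactly once, on first build()
                self.window = self._resolve_window()
            return self.window

        def _resolve_window(self) -> tuple[int, int]:
            return (-1, -1)              # stand-in for the real global lookup
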
@@ -304,6 +321,22 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
            self.runner.device, non_blocking=True).long()

+        if self.aot_sliding_window is None:
+            self.aot_sliding_window = (-1, -1)
+            # For the AOT scheduler we need the sliding window value to be
+            # constant for all layers. We have to populate this on the first
+            # build() call, once the layers have been constructed (it cannot
+            # be populated in __init__).
+            if self.aot_schedule:
+                sliding_window_configs = _get_sliding_window_configs(
+                    self.runner.vllm_config)
+                if len(sliding_window_configs) == 1:
+                    sliding_window_config = sliding_window_configs.pop()
+                    if sliding_window_config is not None:
+                        self.aot_sliding_window = sliding_window_config
+                elif len(sliding_window_configs) > 1:
+                    self.aot_schedule = False
+
        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            if self.aot_schedule:
@@ -318,6 +351,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                    page_size=self.page_size,
                    cu_seqlens_q=cu_query_lens,
                    causal=causal,
+                    window_size=self.aot_sliding_window,
                )
            return None

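For context on the tuple being forwarded: FlashAttention expresses sliding windows as (left, right), where (-1, -1) disables windowing and non-negative values bound how far a query may look in each direction. A brief, self-contained illustration of that convention (not vLLM code):

    def in_window(q_idx: int, k_idx: int, window: tuple[int, int]) -> bool:
        # (left, right): query i may attend to keys in [i - left, i + right];
        # -1 on either side means unbounded in that direction.
        left, right = window
        if left >= 0 and q_idx - k_idx > left:
            return False
        if right >= 0 and k_idx - q_idx > right:
            return False
        return True

    assert in_window(100, 40, (-1, -1))     # no window: everything visible
    assert in_window(100, 40, (64, 0))      # within the 64-token left window
    assert not in_window(100, 20, (64, 0))  # beyond the left window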