# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, ClassVar, Optional
+from typing import Any, ClassVar, Optional

import torch

    chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_unified_attention import unified_attention
+from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata,
    make_local_attention_virtual_batches)
from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable
-
-if TYPE_CHECKING:
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)

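The dropped `TYPE_CHECKING` guard existed only to back the string annotation `"GPUModelRunner"` in the old constructor; once the builder is typed against `VllmConfig` and `torch.device`, the guard and the `BlockTable` import have no remaining users. A minimal sketch of the retired pattern, for context (illustrative only, not part of this change):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for static type checkers (commonly done to avoid a
    # runtime import cycle or a heavyweight import).
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner


def make_builder(runner: "GPUModelRunner") -> None:
    # The string annotation resolves lazily, so GPUModelRunner never has
    # to be imported at runtime.
    ...
```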
@@ -75,12 +72,21 @@ class TritonAttentionMetadataBuilder(
        AttentionMetadataBuilder[TritonAttentionMetadata]):
    full_cudagraph_supported: ClassVar[bool] = True

-    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
-        self.runner = runner
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
+        self.device = device
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
-        self.block_table = block_table
+
+        model_config = vllm_config.model_config
+        self.num_heads_q = model_config.get_num_attention_heads(
+            vllm_config.parallel_config)
+        self.num_heads_kv = model_config.get_num_kv_heads(
+            vllm_config.parallel_config)
+        self.headdim = model_config.get_head_size()
+
+        self.attention_chunk_size = getattr(vllm_config.scheduler_config,
+                                            'attention_chunk_size', None)

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata
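With the runner dependency gone, the builder can be constructed from engine-level objects alone. A hedged usage sketch (the module path and the way `vllm_config`/`kv_cache_spec` are obtained are assumptions, not part of this diff):

```python
import torch

from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import AttentionSpec
# Module path assumed from the class name shown above.
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadataBuilder


def make_triton_builder(vllm_config: VllmConfig,
                        kv_cache_spec: AttentionSpec,
                        device: torch.device) -> TritonAttentionMetadataBuilder:
    # New signature: per-layer KV-cache spec, engine-wide config, target device.
    # Head counts and attention_chunk_size are derived inside __init__ from
    # vllm_config, so no GPUModelRunner reference is needed.
    return TritonAttentionMetadataBuilder(
        kv_cache_spec=kv_cache_spec,
        vllm_config=vllm_config,
        device=device,
    )
```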
@@ -96,42 +102,32 @@ def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> TritonAttentionMetadata:
-        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
+        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
-        block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-
-        block_table.slot_mapping[:num_actual_tokens].copy_(
-            block_table.slot_mapping_cpu[:num_actual_tokens],
-            non_blocking=True)
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
-
-        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+        block_table_tensor = common_attn_metadata.block_table_tensor
+        slot_mapping = common_attn_metadata.slot_mapping

        # for local attention
        local_attn_metadata = None
-        if self.runner.attention_chunk_size is not None:
+        if self.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.runner.attention_chunk_size,
-                    self.runner.query_start_loc_np[:num_reqs + 1],
-                    self.runner.seq_lens_np[:num_reqs],
+                    self.attention_chunk_size,
+                    common_attn_metadata.query_start_loc_cpu.numpy(),
+                    common_attn_metadata.seq_lens_cpu.numpy(),
                    block_table_tensor,
                    self.block_size,
                )
            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.runner.device, non_blocking=True)
+                self.device, non_blocking=True)
            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.runner.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max()
-            local_max_seq_len = virt_k_seqlens_np.max()
+                self.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max().item()
+            local_max_seq_len = virt_k_seqlens_np.max().item()

            local_attn_metadata = TritonAttentionMetadata \
                .LocalAttentionMetadata(
@@ -148,14 +144,13 @@ def build(self,
        if use_cascade:
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
-                                                device=self.runner.device)
+                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
-                                          device=self.runner.device)
-            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
+                                          device=self.device)
+            suffix_kv_lens = (common_attn_metadata.seq_lens_cpu -
                              common_prefix_len)
-            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
-                self.runner.device)
+            suffix_kv_lens = suffix_kv_lens.to(self.device)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
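After this change, `build()` is driven entirely by its arguments: everything it previously read off the runner (`seq_lens_np`, `query_start_loc_np`, slot-mapping buffers, device) now arrives through `CommonAttentionMetadata` or `self`. A hedged sketch of a per-step call (how the metadata object is populated is outside this diff):

```python
from vllm.v1.attention.backends.utils import CommonAttentionMetadata


def build_step_metadata(builder, common_attn_metadata: CommonAttentionMetadata):
    # common_attn_metadata carries the per-step state consumed above:
    # seq_lens / seq_lens_cpu, query_start_loc / query_start_loc_cpu,
    # block_table_tensor and slot_mapping.
    return builder.build(
        common_prefix_len=0,  # 0 = no shared cascade prefix in this sketch
        common_attn_metadata=common_attn_metadata,
    )
```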