 if current_platform.is_cuda():
     from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                       get_scheduler_metadata)
-if current_platform.is_rocm():
-    import aiter
-
-    from vllm.attention.ops.triton_unified_attention import unified_attention
-    from vllm.triton_utils import tl, triton
-    from vllm.utils import direct_register_custom_op
-
-    @triton.jit
-    def _vllm_layout_trans_kernel(
-        k_buffer_ptr,
-        v_buffer_ptr,
-        k_values_ptr,
-        v_values_ptr,
-        b_seq_lens_loc,
-        block_table,
-        block_table_stride_0,
-        E_DIM: tl.constexpr,
-        BLOCK_SIZE: tl.constexpr,
-    ):
-        batch_idx = tl.program_id(0)
-        block_idx = tl.program_id(1)
-        batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
-                                      tl.arange(0, 2))
-        batch_token_start, batch_token_end = tl.split(batch_token_indexes)
-        seq_len = batch_token_end - batch_token_start
-        if block_idx * BLOCK_SIZE < seq_len:
-            block_mask = (block_idx * BLOCK_SIZE +
-                          tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len
-
-            kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 +
-                             block_idx)
-
-            kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange(
-                0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :]
-            k_vals = tl.load(k_buffer_ptr + kv_buffer_off,
-                             mask=block_mask,
-                             other=0.0)
-            v_vals = tl.load(v_buffer_ptr + kv_buffer_off,
-                             mask=block_mask,
-                             other=0.0)
-
-            kv_values_off = batch_token_start * E_DIM + \
-                block_idx * BLOCK_SIZE * E_DIM + \
-                tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \
-                tl.arange(0, E_DIM)[None, :]
-            tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
-            tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
-
-    def vllm_layout_trans(b_seq_lens_loc, block_table, k_buffer, v_buffer,
-                          max_seq_len, total_tokens):
-        H_KV = v_buffer.shape[2]
-        D = v_buffer.shape[3]
-        BLOCK_SIZE = v_buffer.shape[1]
-        dtype = k_buffer.dtype
-        k_values = torch.empty((total_tokens, H_KV, D),
-                               dtype=dtype,
-                               device="cuda")
-        v_values = torch.empty((total_tokens, H_KV, D),
-                               dtype=dtype,
-                               device="cuda")
-
-        grid = (block_table.shape[0],
-                (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
-
-        _vllm_layout_trans_kernel[grid](k_buffer,
-                                        v_buffer,
-                                        k_values,
-                                        v_values,
-                                        b_seq_lens_loc,
-                                        block_table,
-                                        block_table.stride(0),
-                                        E_DIM=H_KV * D,
-                                        BLOCK_SIZE=BLOCK_SIZE)
-
-        return k_values, v_values
-
-    def flash_attn_varlen_func_impl(
-        q: torch.Tensor,
-        k_cache: torch.Tensor,
-        v_cache: torch.Tensor,
-        out: torch.Tensor,
-        cu_seqlens_q: torch.Tensor,
-        cu_seqlens_k: torch.Tensor,
-        total_tokens: int,
-        max_seqlen_q: int,
-        max_seqlen_k: int,
-        softmax_scale: float,
-        window_size: Optional[list[int]],  # -1 means infinite context window
-        alibi_slopes: Optional[list[float]],
-        block_table: torch.Tensor,
-    ) -> torch.Tensor:
-        k, v = vllm_layout_trans(cu_seqlens_k, block_table, k_cache, v_cache,
-                                 max_seqlen_k, total_tokens)
-        output = aiter.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=cu_seqlens_q,
-            max_seqlen_q=max_seqlen_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_k=max_seqlen_k,
-            softmax_scale=softmax_scale,
-            causal=True,
-            alibi_slopes=alibi_slopes,
-            window_size=window_size,
-            out=out,
-        )
-        return output
-
-    def flash_attn_varlen_func_fake(
-        q: torch.Tensor,
-        k_cache: torch.Tensor,
-        v_cache: torch.Tensor,
-        out: torch.Tensor,
-        cu_seqlens_q: torch.Tensor,
-        cu_seqlens_k: torch.Tensor,
-        total_tokens: int,
-        max_seqlen_q: int,
-        max_seqlen_k: int,
-        softmax_scale: float,
-        window_size: Optional[list[int]],  # -1 means infinite context window
-        alibi_slopes: Optional[list[float]],
-        block_table: torch.Tensor,
-    ) -> torch.Tensor:
-        return torch.empty(q.shape[0],
-                           q.shape[1],
-                           v_cache.shape[-2],
-                           dtype=torch.float8_e4m3fnuz,
-                           device="cuda")
-
-    try:
-        direct_register_custom_op("flash_attn_varlen_func",
-                                  flash_attn_varlen_func_impl, ["out"],
-                                  flash_attn_varlen_func_fake)
-        flash_attn_varlen_func = torch.ops.vllm.flash_attn_varlen_func
-
-    except AttributeError:
-        flash_attn_varlen_func = flash_attn_varlen_func_impl

 logger = init_logger(__name__)

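The block removed above existed because `aiter.flash_attn_varlen_func` consumes contiguous varlen-packed K/V tensors rather than vLLM's paged KV cache, so `_vllm_layout_trans_kernel` first gathered the cache pages into `(total_tokens, H_KV, D)` buffers. A minimal PyTorch-only sketch of that gather, useful while reading the deleted Triton code (the name `layout_trans_reference` and the Python loop are illustrative, not vLLM APIs):

```python
import torch


def layout_trans_reference(cu_seq_lens, block_table, kv_buffer):
    # kv_buffer: (num_blocks, BLOCK_SIZE, H_KV, D) paged cache
    # block_table: (batch, max_blocks_per_seq) physical block ids
    # cu_seq_lens: (batch + 1,) cumulative sequence lengths
    _, block_size, h_kv, d = kv_buffer.shape
    out = torch.empty((int(cu_seq_lens[-1]), h_kv, d), dtype=kv_buffer.dtype)
    for b in range(block_table.shape[0]):
        start, end = int(cu_seq_lens[b]), int(cu_seq_lens[b + 1])
        seq_len = end - start
        n_blocks = (seq_len + block_size - 1) // block_size
        # Flatten this sequence's blocks, then trim the ragged tail.
        blocks = kv_buffer[block_table[b, :n_blocks].long()]
        out[start:end] = blocks.reshape(-1, h_kv, d)[:seq_len]
    return out


# Two sequences of lengths 3 and 5 with BLOCK_SIZE = 4:
cache = torch.randn(8, 4, 2, 16)        # (num_blocks, BLOCK_SIZE, H_KV, D)
table = torch.tensor([[0, 0], [1, 2]])  # per-sequence block ids
cu = torch.tensor([0, 3, 8])
k = layout_trans_reference(cu, table, cache)  # shape (8, 2, 16)
```

The Triton kernel performed the same gather with one program per (sequence, block) pair, using `block_mask` to handle the ragged tail of each sequence.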
@@ -223,8 +85,6 @@ class FlashAttentionMetadata:
     query_start_loc: torch.Tensor
     max_seq_len: int
     seq_lens: torch.Tensor
-    cu_seq_lens: torch.Tensor
-    total_tokens: int
     block_table: torch.Tensor
     slot_mapping: torch.Tensor

@@ -466,7 +326,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
-        total_tokens = self.runner.seq_lens_np[:num_reqs].sum()
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
@@ -481,13 +340,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,

         slot_mapping = block_table.slot_mapping[:num_actual_tokens]

-        cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
-                                  dtype=torch.int32,
-                                  device="cuda")
-        torch.cumsum(seq_lens,
-                     dim=0,
-                     dtype=cu_seq_lens.dtype,
-                     out=cu_seq_lens[1:])
         if self.aot_sliding_window is None:
             self.aot_sliding_window = (-1, -1)
         # For the AOT scheduler we need the sliding window value to be
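The lines deleted in this hunk were the only producer of `cu_seq_lens`: a prefix sum of `seq_lens` with a leading zero, written into a preallocated buffer. A standalone illustration of the same pattern (the example values are made up; the real code ran on the GPU):

```python
import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
# cumsum writes into the view cu_seq_lens[1:], leaving the leading 0 intact.
torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])
print(cu_seq_lens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```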
@@ -601,8 +453,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
-            cu_seq_lens=cu_seq_lens,
-            total_tokens=total_tokens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
@@ -768,67 +618,28 @@ def forward(
             scheduler_metadata = attn_metadata.scheduler_metadata

             descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
-            if current_platform.is_rocm():
-                cu_seq_lens = attn_metadata.cu_seq_lens
-                total_tokens = attn_metadata.total_tokens
-                if max_seqlen_q <= 1:
-                    unified_attention(
-                        q=query[:num_actual_tokens],
-                        k=key_cache,
-                        v=value_cache,
-                        out=output[:num_actual_tokens],
-                        cu_seqlens_q=cu_seqlens_q,
-                        max_seqlen_q=max_seqlen_q,
-                        seqused_k=seqused_k,
-                        max_seqlen_k=max_seqlen_k,
-                        softmax_scale=self.scale,
-                        causal=True,
-                        alibi_slopes=self.alibi_slopes,
-                        window_size=self.sliding_window,
-                        block_table=block_table,
-                        softcap=self.logits_soft_cap,
-                        q_descale=None,  # Not supported
-                        k_descale=layer._k_scale.expand(descale_shape),
-                        v_descale=layer._v_scale.expand(descale_shape),
-                    )
-                else:
-                    flash_attn_varlen_func(
-                        query[:num_actual_tokens],
-                        key_cache,
-                        value_cache,
-                        out=output[:num_actual_tokens],
-                        cu_seqlens_q=cu_seqlens_q,
-                        max_seqlen_q=max_seqlen_q,
-                        max_seqlen_k=max_seqlen_k,
-                        total_tokens=total_tokens,
-                        softmax_scale=self.scale,
-                        alibi_slopes=self.alibi_slopes,
-                        window_size=list(self.sliding_window),
-                        block_table=block_table,
-                        cu_seqlens_k=cu_seq_lens,
-                    )
-            else:
-                flash_attn_varlen_func(
-                    q=query[:num_actual_tokens],
-                    k=key_cache,
-                    v=value_cache,
-                    out=output[:num_actual_tokens],
-                    cu_seqlens_q=cu_seqlens_q,
-                    max_seqlen_q=max_seqlen_q,
-                    seqused_k=seqused_k,
-                    max_seqlen_k=max_seqlen_k,
-                    softmax_scale=self.scale,
-                    causal=True,
-                    alibi_slopes=self.alibi_slopes,
-                    window_size=self.sliding_window,
-                    block_table=block_table,
-                    softcap=self.logits_soft_cap,
-                    scheduler_metadata=scheduler_metadata,
-                    fa_version=self.vllm_flash_attn_version,
-                    q_descale=layer._q_scale.expand(descale_shape),
-                    k_descale=layer._k_scale.expand(descale_shape),
-                    v_descale=layer._v_scale.expand(descale_shape),
-                )
+
+            flash_attn_varlen_func(
+                q=query[:num_actual_tokens],
+                k=key_cache,
+                v=value_cache,
+                out=output[:num_actual_tokens],
+                cu_seqlens_q=cu_seqlens_q,
+                max_seqlen_q=max_seqlen_q,
+                seqused_k=seqused_k,
+                max_seqlen_k=max_seqlen_k,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+                window_size=self.sliding_window,
+                block_table=block_table,
+                softcap=self.logits_soft_cap,
+                scheduler_metadata=scheduler_metadata,
+                fa_version=self.vllm_flash_attn_version,
+                q_descale=layer._q_scale.expand(descale_shape),
+                k_descale=layer._k_scale.expand(descale_shape),
+                v_descale=layer._v_scale.expand(descale_shape),
+            )
             return output

         assert not use_local_attn, (
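For context on the `try`/`except AttributeError` removed at the top of the file: registering the aiter wrapper as a torch custom op (with `flash_attn_varlen_func_fake` as its fake implementation) is what lets `torch.compile` trace through the call; on failure the code fell back to the plain Python function. A rough equivalent using only the public `torch.library` API (PyTorch 2.4+; `mylib::toy_attn` is a toy stand-in, not a vLLM op):

```python
import torch


# Register an op that mutates `out` and returns nothing, mirroring the
# ["out"] mutates-list passed to direct_register_custom_op above.
@torch.library.custom_op("mylib::toy_attn", mutates_args=("out",))
def toy_attn(q: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(q)  # stand-in for the real attention kernel


# The fake (meta) implementation lets torch.compile trace the op
# without running it, playing the role of flash_attn_varlen_func_fake.
@toy_attn.register_fake
def _(q, out):
    return None


q = torch.randn(4, 8)
out = torch.empty_like(q)
torch.ops.mylib.toy_attn(q, out)  # dispatch through the registered op
assert torch.equal(out, q)
```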