3 files changed: +6 −4 lines.

@@ -913,7 +913,8 @@ def forward(
             )
             max_logits = torch.empty_like(exp_sums)

-            torch.ops.aiter.paged_attention_rocm(
+            query_start_loc = None
+            ops.paged_attention_rocm(
                 output[num_prefill_tokens:],
                 exp_sums,
                 max_logits,
@@ -929,6 +930,7 @@ def forward(
                 decode_meta.seq_lens_tensor
                 if self.attn_type != AttentionType.ENCODER_DECODER else
                 decode_meta.encoder_seq_lens_tensor,
+                query_start_loc,
                 block_size,
                 max_seq_len,
                 self.alibi_slopes,
@@ -286,7 +286,7 @@ def chunked_prefill_paged_decode(
                                           num_queries_per_kv,
                                           max_seq_len, sliding_window,
                                           kv_cache_dtype, alibi_slopes)
-    if use_custom and head_size <= 128 and num_queries_per_kv <= 16:
+    if use_custom:
        _PARTITION_SIZE_ROCM = 256
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                              _PARTITION_SIZE_ROCM)
@@ -138,9 +138,9 @@ def use_rocm_custom_paged_attention(
    return ((not envs.VLLM_USE_V1 or sliding_window == 0
             or sliding_window == (-1, -1))
            and (qtype == torch.half or qtype == torch.bfloat16)
-           and (head_size in [64, 128, 256])
+           and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
-           and (gqa_ratio >= 1 and gqa_ratio <= 32)
+           and (gqa_ratio >= 1 and gqa_ratio <= 16)
            and max_seq_len <= 128 * 1024
            and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
            and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
You can’t perform that action at this time.
0 commit comments