Skip to content

Commit 70d687f

Browse files
committed
revert
Signed-off-by: fsx950223 <[email protected]>
1 parent 60a931e commit 70d687f

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -913,7 +913,8 @@ def forward(
913913
)
914914
max_logits = torch.empty_like(exp_sums)
915915

916-
torch.ops.aiter.paged_attention_rocm(
916+
query_start_loc = None
917+
ops.paged_attention_rocm(
917918
output[num_prefill_tokens:],
918919
exp_sums,
919920
max_logits,
@@ -929,6 +930,7 @@ def forward(
929930
decode_meta.seq_lens_tensor
930931
if self.attn_type != AttentionType.ENCODER_DECODER else
931932
decode_meta.encoder_seq_lens_tensor,
933+
query_start_loc,
932934
block_size,
933935
max_seq_len,
934936
self.alibi_slopes,

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def chunked_prefill_paged_decode(
286286
num_queries_per_kv,
287287
max_seq_len, sliding_window,
288288
kv_cache_dtype, alibi_slopes)
289-
if use_custom and head_size <= 128 and num_queries_per_kv <= 16:
289+
if use_custom:
290290
_PARTITION_SIZE_ROCM = 256
291291
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
292292
_PARTITION_SIZE_ROCM)

vllm/platforms/rocm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ def use_rocm_custom_paged_attention(
138138
return ((not envs.VLLM_USE_V1 or sliding_window == 0
139139
or sliding_window == (-1, -1))
140140
and (qtype == torch.half or qtype == torch.bfloat16)
141-
and (head_size in [64, 128, 256])
141+
and (head_size == 64 or head_size == 128)
142142
and (block_size == 16 or block_size == 32)
143-
and (gqa_ratio >= 1 and gqa_ratio <= 32)
143+
and (gqa_ratio >= 1 and gqa_ratio <= 16)
144144
and max_seq_len <= 128 * 1024
145145
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
146146
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN

0 commit comments

Comments
 (0)