Skip to content
Merged
3 changes: 2 additions & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def main(
if version == "v2":
if current_platform.is_rocm():
global PARTITION_SIZE
if not args.custom_paged_attn:
if not args.custom_paged_attn and not current_platform.is_navi():
PARTITION_SIZE = 1024
else:
PARTITION_SIZE = PARTITION_SIZE_ROCM
Expand Down Expand Up @@ -166,6 +166,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
scale,
block_tables,
seq_lens,
None,
block_size,
max_seq_len,
alibi_slopes,
Expand Down
Loading