diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py index 96b7a8f58b..320cfbe020 100644 --- a/benchmarks/routines/attention.py +++ b/benchmarks/routines/attention.py @@ -396,7 +396,7 @@ def testBatchDecodeWithPagedKVCacheWrapper(args): # Now initialize the page tables block_tables = torch.tensor( [ - [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + [k + i * num_pages_per_seq for k in torch.randperm(num_pages_per_seq)] for i in range(batch_size) ], dtype=torch.int, @@ -421,11 +421,7 @@ def testBatchDecodeWithPagedKVCacheWrapper(args): for i in range(len(kv_indptr) - 1): start_idx = kv_indptr[i] end_idx = kv_indptr[i + 1] - kv_indices[start_idx:end_idx] = torch.arange( - i * num_pages_per_seq, - i * num_pages_per_seq + (end_idx - start_idx), - device=device, - ) + kv_indices[start_idx:end_idx] = block_tables[i, : end_idx - start_idx] kv_last_page_len = ( torch.where( @@ -837,7 +833,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): # Now initialize the page tables block_tables = torch.tensor( [ - [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + [k + i * num_pages_per_seq for k in torch.randperm(num_pages_per_seq)] for i in range(batch_size) ], dtype=torch.int, @@ -887,11 +883,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): for i in range(len(kv_indptr) - 1): start_idx = kv_indptr[i] end_idx = kv_indptr[i + 1] - kv_indices[start_idx:end_idx] = torch.arange( - i * num_pages_per_seq, - i * num_pages_per_seq + (end_idx - start_idx), - device=device, - ) + kv_indices[start_idx:end_idx] = block_tables[i, : end_idx - start_idx] kv_last_page_len = ( torch.where( actual_seq_lens_kv_device.flatten() % page_size == 0, @@ -1711,7 +1703,7 @@ def testBatchMLAPagedAttentionWrapper(args): # Now initialize the page tables block_tables = torch.tensor( [ - [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + [k + i * num_pages_per_seq for k in torch.randperm(num_pages_per_seq)] for i in range(batch_size) ], dtype=torch.int, @@ -1758,11 +1750,7 @@ def testBatchMLAPagedAttentionWrapper(args): for i in range(len(kv_indptr) - 1): start_idx = kv_indptr[i] end_idx = kv_indptr[i + 1] - kv_indices[start_idx:end_idx] = torch.arange( - i * num_pages_per_seq, - i * num_pages_per_seq + (end_idx - start_idx), - device=device, - ) + kv_indices[start_idx:end_idx] = block_tables[i, : end_idx - start_idx] sm_scale = 1.0 / ((128 + 64) ** 0.5) # For DeepSeek-R1 workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)