Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions benchmarks/routines/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
# Now initialize the page tables
block_tables = torch.tensor(
[
[k + i * num_pages_per_seq for k in range(num_pages_per_seq)]
[k + i * num_pages_per_seq for k in torch.randperm(num_pages_per_seq)]
for i in range(batch_size)
],
dtype=torch.int,
Expand All @@ -421,11 +421,7 @@ def testBatchDecodeWithPagedKVCacheWrapper(args):
for i in range(len(kv_indptr) - 1):
start_idx = kv_indptr[i]
end_idx = kv_indptr[i + 1]
kv_indices[start_idx:end_idx] = torch.arange(
i * num_pages_per_seq,
i * num_pages_per_seq + (end_idx - start_idx),
device=device,
)
kv_indices[start_idx:end_idx] = block_tables[i, : end_idx - start_idx]

kv_last_page_len = (
torch.where(
Expand Down Expand Up @@ -837,7 +833,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
# Now initialize the page tables
block_tables = torch.tensor(
[
[k + i * num_pages_per_seq for k in range(num_pages_per_seq)]
[k + i * num_pages_per_seq for k in torch.randperm(num_pages_per_seq)]
for i in range(batch_size)
],
dtype=torch.int,
Expand Down Expand Up @@ -887,11 +883,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
for i in range(len(kv_indptr) - 1):
start_idx = kv_indptr[i]
end_idx = kv_indptr[i + 1]
kv_indices[start_idx:end_idx] = torch.arange(
i * num_pages_per_seq,
i * num_pages_per_seq + (end_idx - start_idx),
device=device,
)
kv_indices[start_idx:end_idx] = block_tables[i, : end_idx - start_idx]
kv_last_page_len = (
torch.where(
actual_seq_lens_kv_device.flatten() % page_size == 0,
Expand Down Expand Up @@ -1711,7 +1703,7 @@ def testBatchMLAPagedAttentionWrapper(args):
# Now initialize the page tables
block_tables = torch.tensor(
[
[k + i * num_pages_per_seq for k in range(num_pages_per_seq)]
[k + i * num_pages_per_seq for k in torch.randperm(num_pages_per_seq)]
for i in range(batch_size)
],
dtype=torch.int,
Expand Down Expand Up @@ -1758,11 +1750,7 @@ def testBatchMLAPagedAttentionWrapper(args):
for i in range(len(kv_indptr) - 1):
start_idx = kv_indptr[i]
end_idx = kv_indptr[i + 1]
kv_indices[start_idx:end_idx] = torch.arange(
i * num_pages_per_seq,
i * num_pages_per_seq + (end_idx - start_idx),
device=device,
)
kv_indices[start_idx:end_idx] = block_tables[i, : end_idx - start_idx]

sm_scale = 1.0 / ((128 + 64) ** 0.5) # For DeepSeek-R1
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
Expand Down