
Commit c5ef2ca

Commit message: "add"
1 parent b9287c9 commit c5ef2ca


benchmarks/bench_mixed_attention.py

Lines changed: 69 additions & 62 deletions
@@ -72,6 +72,24 @@ def run_bench(
     measurements = bench_gpu_time(lambda: wrapper_old.run(q, kv_data))
     ms_old = np.median(measurements)

+    wrapper_persistent = flashinfer.BatchAttention(kv_layout="NHD")
+    wrapper_persistent.plan(
+        q_indptr.to(device),
+        kv_indptr.to(device),
+        torch.arange(num_blocks, dtype=torch.int32, device=device),
+        seq_lens.to(device),
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        head_dim,
+        page_block_size,
+        causal=causal,
+        q_data_type=torch.bfloat16,
+        kv_data_type=torch.bfloat16,
+    )
+    o_persistent, _ = wrapper_persistent.run(q, kv_data)
+    measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data))
+    ms_persistent = np.median(measurements_persistent)
     if len(p_kv_lens) == 1:
         q_d = q[: d_q_indptr[-1]]
         kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
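
Note: the persistent `BatchAttention` path added above plans once over CSR-style index arrays and is then re-run for timing. As a reading aid, here is a minimal sketch, not part of the commit, of how `q_indptr`/`kv_indptr`-style tensors are conventionally built from per-request lengths; the `q_lens`/`kv_lens` values are illustrative.

import torch

# Hypothetical per-request lengths; the benchmark constructs these elsewhere.
q_lens = torch.tensor([17, 1, 1, 1], dtype=torch.int32)
kv_lens = torch.tensor([4096, 2048, 512, 128], dtype=torch.int32)

def to_indptr(lens: torch.Tensor) -> torch.Tensor:
    # indptr[i] is the start offset of request i; indptr[-1] is the total length,
    # which is why the benchmark can slice contiguous buffers with q[: d_q_indptr[-1]].
    return torch.cat([torch.zeros(1, dtype=torch.int32), torch.cumsum(lens, 0).int()])

q_indptr = to_indptr(q_lens)    # tensor([0, 17, 18, 19, 20], dtype=torch.int32)
kv_indptr = to_indptr(kv_lens)  # tensor([0, 4096, 6144, 6656, 6784], dtype=torch.int32)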
@@ -123,9 +141,46 @@ def run_bench(
             )
         )
         ms_pod = np.median(measurements)
+
+        # Sequential two kernels: single prefill + batch decode (tensor cores)
+        # Prefill using single_prefill_with_kv_cache
+        def _run_single_prefill():
+            return flashinfer.prefill.single_prefill_with_kv_cache(
+                q_p,
+                k_p,
+                v_p,
+                causal=causal,
+                pos_encoding_mode="NONE",
+                backend="fa2",
+            )
+
+        measurements_prefill = bench_gpu_time(lambda: _run_single_prefill())
+        ms_prefill = np.median(measurements_prefill)
+
+        # Batch decode using tensor cores
+        wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+            workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True
+        )
+        wrapper_decode.plan(
+            d_kv_indptr.to(device),
+            kv_indices_d.to(device),
+            last_page_len_d,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            page_block_size,
+            data_type=torch.bfloat16,
+            q_data_type=torch.bfloat16,
+        )
+        measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d))
+        ms_decode = np.median(measurements_decode)
+        ms_seq_two_kernels = ms_prefill + ms_decode
+
     print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
     if len(p_kv_lens) == 1:
         print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
+        print(f"Elapsed time (Sequential two kernels): {ms_seq_two_kernels:.2f} ms")
+    print(f"Elapsed time (Persistent BatchAttention): {ms_persistent:.2f} ms")
     total_bytes = (
         q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
     )
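
The sequential baseline above sums two independently measured medians: one single-kernel prefill plus one tensor-core batch decode. As a shape sketch for the prefill inputs, assuming flashinfer's NHD layout of (sequence, heads, head_dim); the sizes are illustrative and only the tensor names mirror the diff.

import torch

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128  # matches the new config below
qo_len = kv_len = 4096

# Grouped-query attention: num_qo_heads (32) may exceed num_kv_heads (8).
q_p = torch.randn(qo_len, num_qo_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k_p = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v_p = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
# flashinfer.single_prefill_with_kv_cache(q_p, k_p, v_p, causal=True) returns
# an output of shape (qo_len, num_qo_heads, head_dim).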
@@ -137,77 +192,29 @@ def run_bench(
     if len(p_kv_lens) == 1:
         bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3)
         print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s")
+        bandwidth_seq_gb_s = total_bytes / (ms_seq_two_kernels * 1e-3) / (1024**3)
+        print(
+            f"Memory bandwidth (Sequential two kernels): {bandwidth_seq_gb_s:.2f} GB/s"
+        )
+    bandwidth_persistent_gb_s = total_bytes / (ms_persistent * 1e-3) / (1024**3)
+    print(
+        f"Memory bandwidth (Persistent BatchAttention): {bandwidth_persistent_gb_s:.2f} GB/s"
+    )


 if __name__ == "__main__":
     np.random.seed(42)
     torch.random.manual_seed(42)

     # Irregular sequence lengths for prefill and decode
-    d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256]
-    d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256]
-    p_q_configs = [[17] * 1, [10000], [17] * 1, []]
-    p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []]
-
-    # construct random length testcases
-    for _ in range(1):
-        bsz = 256
-        stride = 16
-        sparsity = 0.05
-
-        full_kv_len = np.random.randint(1000, 8192, size=bsz)
-        p_q_lens = []
-        p_kv_lens = []
-        d_q_lens = []
-        d_kv_lens = []
-        for i in range(bsz):
-            if i % stride == 0:
-                kv_len = full_kv_len[i]
-                qo_len = stride + 1
-                p_q_lens.append(qo_len)
-                p_kv_lens.append(kv_len)
-            else:
-                kv_len = int(full_kv_len[i] * sparsity)
-                qo_len = 1
-                d_q_lens.append(qo_len)
-                d_kv_lens.append(kv_len)
-
-        p_q_configs.append(p_q_lens)
-        p_kv_configs.append(p_kv_lens)
-        d_q_len_configs.append(d_q_lens)
-        d_kv_len_configs.append(d_kv_lens)
-
-    for _ in range(1):
-        bsz = 128
-        stride = 16
-        sparsity = 0.05
-
-        full_kv_len = np.random.randint(2000, 16000, size=bsz)
-        p_q_lens = []
-        p_kv_lens = []
-        d_q_lens = []
-        d_kv_lens = []
-
-        for i in range(bsz):
-            if i % stride == 0:
-                kv_len = full_kv_len[i]
-                qo_len = stride + 1
-                p_q_lens.append(qo_len)
-                p_kv_lens.append(kv_len)
-            else:
-                kv_len = int(full_kv_len[i] * sparsity)
-                qo_len = 1
-                d_q_lens.append(qo_len)
-                d_kv_lens.append(kv_len)
-
-        p_q_configs.append(p_q_lens)
-        p_kv_configs.append(p_kv_lens)
-        d_q_len_configs.append(d_q_lens)
-        d_kv_len_configs.append(d_kv_lens)
+    d_q_len_configs = [[1] * 128, [1] * 128, [1] * 128, [1] * 128]
+    d_kv_len_configs = [[2048] * 128, [4096] * 128, [8192] * 128, [8192] * 128]
+    p_q_configs = [[2048], [4096], [4096], [6000]]
+    p_kv_configs = [[2048], [4096], [4096], [7000]]

     page_block_size = 1
-    num_kv_heads = 4
-    num_qo_heads = 28
+    num_kv_heads = 8
+    num_qo_heads = 32
     head_dim = 128

     for idx, (p_q_lens, p_kv_lens, d_q_len, d_kv_len) in enumerate(
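
All three bandwidth figures use the same formula, bytes moved divided by elapsed time, reported in GiB/s. A compact way to read them is as a single helper; this is an illustrative refactoring, not code from the commit.

def bandwidth_gb_s(total_bytes: int, ms: float) -> float:
    # ms * 1e-3 converts milliseconds to seconds; 1024**3 converts bytes to GiB.
    return total_bytes / (ms * 1e-3) / (1024**3)

# Example: moving 2 GiB of q/kv data in 5 ms is exactly 400 GB/s.
print(f"{bandwidth_gb_s(2 * 1024**3, 5.0):.2f} GB/s")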
