Skip to content

Commit 1e004a6

Browse files
author
Aditya K Kamath
committed
Merge branch 'pod_batched_new' of github.com:AKKamath/flashinfer into pod_batched_new
2 parents bc11239 + 5f1e346 commit 1e004a6

File tree

2 files changed

+2
-7
lines changed

2 files changed

+2
-7
lines changed

benchmarks/bench_mixed_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def run_bench(
110110
kv_indices_d = torch.arange(0, d_kv_indptr[-1], device=device, dtype=torch.int32)
111111
kv_indices_p = torch.arange(0, p_kv_indptr[-1], device=device, dtype=torch.int32)
112112

113-
last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
114-
last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1
113+
last_page_len_d = (torch.tensor(d_kv_lens, device=device) - 1) % page_block_size + 1
114+
last_page_len_p = (torch.tensor(p_kv_lens, device=device) - 1) % page_block_size + 1
115115
wrapper_pod = flashinfer.BatchPODWithPagedKVCacheWrapper(
116116
workspace_buffer,
117117
kv_layout=kv_layout,

flashinfer/pod.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,12 +1077,7 @@ def run(
10771077
logits_soft_cap_p > 0, # use_logits_soft_cap
10781078
use_fp16_qk_reduction,
10791079
# Decode params
1080-
# q_d.dtype,
1081-
# self._cached_kv_data_type,
1082-
# self._cached_q_data_type,
10831080
self._indptr_type,
1084-
# head_dim, # head_dim_qk
1085-
# head_dim, # head_dim_vo
10861081
PosEncodingMode[pos_encoding_mode_d].value,
10871082
window_left_d != -1, # use_sliding_window
10881083
logits_soft_cap_d > 0, # use_logits_soft_cap

0 commit comments

Comments
 (0)