
Commit 42424b8: Fix formatting required by pre-commit

Author: Aditya K Kamath (committed)
1 parent: 37ae780

7 files changed: 86 additions & 79 deletions

benchmarks/bench_mixed_attention.py

Lines changed: 12 additions & 8 deletions
@@ -107,12 +107,8 @@ def run_bench(
  kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
  q_p = q[d_q_indptr[-1] :]
  kv_p = kv_data[d_kv_indptr[-1] :].unbind(1)
- kv_indices_d = torch.arange(
-     0, d_kv_indptr[-1], device=device, dtype=torch.int32
- )
- kv_indices_p = torch.arange(
-     0, p_kv_indptr[-1], device=device, dtype=torch.int32
- )
+ kv_indices_d = torch.arange(0, d_kv_indptr[-1], device=device, dtype=torch.int32)
+ kv_indices_p = torch.arange(0, p_kv_indptr[-1], device=device, dtype=torch.int32)

  last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
  last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1
@@ -266,7 +262,9 @@ def _run_single_prefill():

  print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s")
  bandwidth_batch_pod_gb_s = total_bytes / (ms_batch_pod * 1e-3) / (1024**3)
- print(f"Memory bandwidth (Batched POD Attention): {bandwidth_batch_pod_gb_s:.2f} GB/s")
+ print(
+     f"Memory bandwidth (Batched POD Attention): {bandwidth_batch_pod_gb_s:.2f} GB/s"
+ )
  if len(p_kv_lens) == 1:
      bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3)
      print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s")
@@ -286,7 +284,13 @@ def _run_single_prefill():

  # Irregular sequence lengths for prefill and decode
  d_q_len_configs = [[1] * 128, [1] * 128, [1] * 128, [1] * 128, [1] * 128]
- d_kv_len_configs = [[2048] * 128, [2048] * 128, [4096] * 128, [8192] * 128, [8192] * 128]
+ d_kv_len_configs = [
+     [2048] * 128,
+     [2048] * 128,
+     [4096] * 128,
+     [8192] * 128,
+     [8192] * 128,
+ ]
  p_q_configs = [[2048] * 2, [2048], [4096], [4096], [6000]]
  p_kv_configs = [[2048] * 2, [2048], [4096], [4096], [7000]]
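
Note: the paged-KV arithmetic this benchmark relies on is easy to check by hand. A sequence of L tokens split into pages of size P occupies ceil(L / P) pages, and its final page holds (L - 1) % P + 1 valid slots, which is the `last_page_len` expression above; with contiguous page allocation, `kv_indices` is just an arange over the total page count. A small self-contained sketch, illustrative only and using made-up sequence lengths rather than anything from the benchmark:

import torch

# Three sequences of 10, 8 and 1 tokens with a page size of 4.
page_block_size = 4
seq_lens = torch.tensor([10, 8, 1])

num_pages = (seq_lens + page_block_size - 1) // page_block_size  # ceil-division
last_page_len = (seq_lens - 1) % page_block_size + 1             # slots used in the final page
assert num_pages.tolist() == [3, 2, 1]
assert last_page_len.tolist() == [2, 4, 1]                       # 8 tokens fill their last page exactly

# With contiguous page allocation, kv_indptr is the running page count and
# kv_indices enumerates pages 0 .. kv_indptr[-1]-1, as the torch.arange above does.
kv_indptr = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(num_pages, dim=0)])
kv_indices = torch.arange(0, int(kv_indptr[-1]), dtype=torch.int32)
assert kv_indptr.tolist() == [0, 3, 5, 6]
assert kv_indices.numel() == 6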

csrc/batch_pod.cu

Lines changed: 20 additions & 17 deletions
@@ -21,14 +21,15 @@

  namespace flashinfer {
  template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, PosEncodingMode POS_ENCODING_MODE,
- bool USE_FP16_QK_REDUCTION, uint32_t CTA_TILE_Q_P, MaskMode MASK_MODE_P, uint32_t CTA_TILE_Q_D,
- MaskMode MASK_MODE_D, typename PrefillAttentionVariant, typename DecodeAttentionVariant,
- typename PrefillParams, typename DecodeParams>
+ bool USE_FP16_QK_REDUCTION, uint32_t CTA_TILE_Q_P, MaskMode MASK_MODE_P,
+ uint32_t CTA_TILE_Q_D, MaskMode MASK_MODE_D, typename PrefillAttentionVariant,
+ typename DecodeAttentionVariant, typename PrefillParams, typename DecodeParams>
  cudaError_t BatchPODWithKVCacheTensorDispatched(PrefillParams prefill_params,
- typename PrefillParams::DTypeO* tmp_v_p, float* tmp_s_p,
- DecodeParams decode_params,
- typename DecodeParams::DTypeO* tmp_v_d, float* tmp_s_d,
- bool enable_pdl, cudaStream_t stream);
+ typename PrefillParams::DTypeO* tmp_v_p,
+ float* tmp_s_p, DecodeParams decode_params,
+ typename DecodeParams::DTypeO* tmp_v_d,
+ float* tmp_s_d, bool enable_pdl,
+ cudaStream_t stream);

  } // namespace flashinfer

@@ -122,8 +123,8 @@ void batch_pod_with_kv_cache_tensor(
  num_kv_heads_d = paged_k_cache_d.size(2);
  }
  TVM_FFI_ICHECK_EQ(num_kv_heads_p, num_kv_heads_d)
- << "POD currently requires same # KV heads for prefill and decode; Prefill: " << num_kv_heads_p
- << ", Decode: " << num_kv_heads_d;
+ << "POD currently requires same # KV heads for prefill and decode; Prefill: "
+ << num_kv_heads_p << ", Decode: " << num_kv_heads_d;

  if (maybe_lse_d.has_value()) {
  const auto& lse = maybe_lse_d.value();
@@ -151,8 +152,8 @@ void batch_pod_with_kv_cache_tensor(
  kv_cache_strides_d = k_strides_d.data();

  // Already handled by prefill
- //cudaSetDevice(float_workspace_buffer_d.device().device_id);
- //const cudaStream_t stream = get_stream(float_workspace_buffer_d.device());
+ // cudaSetDevice(float_workspace_buffer_d.device().device_id);
+ // const cudaStream_t stream = get_stream(float_workspace_buffer_d.device());

  DISPATCH_context(
  MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, USE_SLIDING_WINDOW_P,
@@ -213,7 +214,8 @@ void batch_pod_with_kv_cache_tensor(
  GetPtrFromBaseOffset<IdType>(int_buffer_ptr_p, plan_info_p.qo_tile_indices_offset);
  params.kv_tile_indices =
  GetPtrFromBaseOffset<IdType>(int_buffer_ptr_p, plan_info_p.kv_tile_indices_offset);
- params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr_p, plan_info_p.o_indptr_offset);
+ params.o_indptr =
+     GetPtrFromBaseOffset<IdType>(int_buffer_ptr_p, plan_info_p.o_indptr_offset);
  params.kv_chunk_size_ptr =
  GetPtrFromBaseOffset<IdType>(int_buffer_ptr_p, plan_info_p.kv_chunk_size_ptr_offset);
  if (plan_info_p.split_kv) {
@@ -290,7 +292,8 @@ void batch_pod_with_kv_cache_tensor(
  GetPtrFromBaseOffset<IdType>(int_buffer_ptr_d, plan_info_d.qo_tile_indices_offset);
  params.kv_tile_indices =
  GetPtrFromBaseOffset<IdType>(int_buffer_ptr_d, plan_info_d.kv_tile_indices_offset);
- params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr_d, plan_info_d.o_indptr_offset);
+ params.o_indptr =
+     GetPtrFromBaseOffset<IdType>(int_buffer_ptr_d, plan_info_d.o_indptr_offset);
  params.kv_chunk_size_ptr =
  GetPtrFromBaseOffset<IdType>(int_buffer_ptr_d, plan_info_d.kv_chunk_size_ptr_offset);
  if (plan_info_d.split_kv) {
@@ -322,10 +325,10 @@ void batch_pod_with_kv_cache_tensor(
  DISPATCH_CTA_TILE_Q(plan_info_p.cta_tile_q, CTA_TILE_Q_P, {
  constexpr size_t CTA_TILE_Q_D = 16;
  cudaError_t status = flashinfer::BatchPODWithKVCacheTensorDispatched<
- HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, CTA_TILE_Q_P, MASK_MODE_P,
- CTA_TILE_Q_D, MASK_MODE_D, PrefillAttentionVariant, DecodeAttentionVariant>(
- prefill_params, tmp_v_p, tmp_s_p, decode_params, tmp_v_d, tmp_s_d,
- enable_pdl, stream);
+ HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, CTA_TILE_Q_P,
+ MASK_MODE_P, CTA_TILE_Q_D, MASK_MODE_D, PrefillAttentionVariant,
+ DecodeAttentionVariant>(prefill_params, tmp_v_p, tmp_s_p, decode_params, tmp_v_d,
+ tmp_s_d, enable_pdl, stream);
  TVM_FFI_ICHECK(status == cudaSuccess)
  << "BatchPODWithKVCache kernel launch failed, error: " << cudaGetErrorString(status);
  });

csrc/batch_pod_kernel_inst.jinja

Lines changed: 2 additions & 2 deletions
@@ -21,11 +21,11 @@ constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;
  {% for cta_tile_q in [16, 64, 128] %}
  template cudaError_t BatchPODWithKVCacheTensorDispatched<
  {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE,
- {{ use_fp16_qk_reduction }}, /*CTA_TILE_Q_P=*/{{cta_tile_q}}, {{ mask_mode_p }},
+ {{ use_fp16_qk_reduction }}, /*CTA_TILE_Q_P=*/{{cta_tile_q}}, {{ mask_mode_p }},
  /*CTA_TILE_Q_D=*/16, {{ mask_mode_d }}, {{ variant_name_p }},
  {{ variant_name_d }}, PrefillParams, DecodeParams>(
  PrefillParams prefill_params, {{ dtype_o }}* tmp_v_p, float *tmp_s_p,
- DecodeParams decode_params, {{ dtype_o }}* tmp_v_d, float *tmp_s_d,
+ DecodeParams decode_params, {{ dtype_o }}* tmp_v_d, float *tmp_s_d,
  bool enable_pdl, cudaStream_t stream);
  {% endfor %}
  };

csrc/pod_customize_config.jinja

Lines changed: 0 additions & 1 deletion
@@ -40,4 +40,3 @@ using DecodeParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
  __VA_ARGS__(); \
  }); \
  });
-

flashinfer/jit/attention/modules.py

Lines changed: 5 additions & 1 deletion
@@ -629,6 +629,7 @@ def gen_pod_module(
  use_fp16_qk_reduction=use_fp16_qk_reduction,
  )

+
  def gen_batch_pod_module(
  dtype_q: torch.dtype,
  dtype_kv: torch.dtype,
@@ -643,7 +644,7 @@ def gen_batch_pod_module(
  use_sliding_window_d: bool,
  use_logits_soft_cap_d: bool,
  ) -> JitSpec:
- uri = 'batch_' + get_pod_uri(
+ uri = "batch_" + get_pod_uri(
  dtype_q,
  dtype_kv,
  dtype_o,
@@ -693,6 +694,7 @@ def gen_batch_pod_module(
  use_fp16_qk_reduction=use_fp16_qk_reduction,
  )

+
  def gen_customize_pod_module(
  uri: str,
  dtype_q: torch.dtype,
@@ -792,6 +794,7 @@ def gen_customize_pod_module(

  return gen_jit_spec(uri, source_paths)

+
  def gen_customize_batch_pod_module(
  uri: str,
  dtype_q: torch.dtype,
@@ -891,6 +894,7 @@ def gen_customize_batch_pod_module(

  return gen_jit_spec(uri, source_paths)

+
  def gen_batch_decode_module(
  dtype_q: torch.dtype,
  dtype_kv: torch.dtype,
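
Note on how these generators are consumed: each gen_* function builds a JitSpec keyed by a URI (here prefixed with "batch_" via get_pod_uri), and the wrapper in flashinfer/pod.py (next file) memoizes the built module with @functools.cache so identical argument tuples reuse one compiled extension. Below is a minimal stand-in sketch of that pattern; the fake generator and its arguments are placeholders, not the real gen_batch_pod_module API:

import functools
from types import SimpleNamespace

def gen_fake_module(uri: str):
    # Stand-in for gen_batch_pod_module(...).build_and_load(): pretend to compile once per URI.
    print(f"building {uri}")
    return SimpleNamespace(run=lambda *args: f"ran {uri}")

@functools.cache
def get_fake_module(*key):
    uri = "batch_" + "_".join(map(str, key))  # mirrors uri = "batch_" + get_pod_uri(...)
    return gen_fake_module(uri)

m1 = get_fake_module("f16", "f16", 128)
m2 = get_fake_module("f16", "f16", 128)  # cache hit: no second build
assert m1 is m2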

flashinfer/pod.py

Lines changed: 13 additions & 14 deletions
@@ -46,11 +46,13 @@ def get_pod_module(*args):
  module = gen_pod_module(*args).build_and_load()
  return SimpleNamespace(run_tensor=module.pod_with_kv_cache_tensor)

+
  @functools.cache
  def get_batch_pod_module(*args):
  module = gen_batch_pod_module(*args).build_and_load()
  return SimpleNamespace(run_tensor=module.batch_pod_with_kv_cache_tensor)

+
  class PODWithPagedKVCacheWrapper:
  r"""Wrapper class for POD-Attention with paged kv-cache (first proposed in
  `<https://arxiv.org/abs/2410.18038>`_) for batch of requests.
@@ -615,6 +617,7 @@ def end_forward(self) -> None:
  r"""Warning: this function is deprecated and has no effect."""
  pass

+
  class BatchPODWithPagedKVCacheWrapper:
  r"""Wrapper class for POD-Attention with paged kv-cache (first proposed in
  `<https://arxiv.org/abs/2410.18038>`_) for batch of requests.
@@ -837,12 +840,8 @@ def plan(
  batch_size_p = len(last_page_len_p)
  qo_indptr_host_p = qo_indptr_p.to("cpu")
  total_num_rows_p = int(qo_indptr_host_p[-1])
- self._kv_indptr_buf_p = kv_indptr_p.to(
-     self.device, non_blocking=non_blocking
- )
- self._kv_indices_buf_p = kv_indices_p.to(
-     self.device, non_blocking=non_blocking
- )
+ self._kv_indptr_buf_p = kv_indptr_p.to(self.device, non_blocking=non_blocking)
+ self._kv_indices_buf_p = kv_indices_p.to(self.device, non_blocking=non_blocking)
  self._kv_last_page_len_buf_p = last_page_len_p.to(
      self.device, non_blocking=non_blocking
  )
@@ -851,7 +850,9 @@ def plan(
  )
  kv_indptr_host_p = kv_indptr_p.to("cpu")
  last_page_len_host_p = last_page_len_p.to("cpu")
- kv_lens_arr_host_p = get_seq_lens(kv_indptr_host_p, last_page_len_host_p, page_size)
+ kv_lens_arr_host_p = get_seq_lens(
+     kv_indptr_host_p, last_page_len_host_p, page_size
+ )

  if data_type is not None:
      if q_data_type is None:
@@ -908,12 +909,8 @@ def plan(
  batch_size_d = len(last_page_len_d)
  qo_indptr_host_d = qo_indptr_d.to("cpu")
  total_num_rows_d = int(qo_indptr_host_d[-1])
- self._kv_indptr_buf_d = kv_indptr_d.to(
-     self.device, non_blocking=non_blocking
- )
- self._kv_indices_buf_d = kv_indices_d.to(
-     self.device, non_blocking=non_blocking
- )
+ self._kv_indptr_buf_d = kv_indptr_d.to(self.device, non_blocking=non_blocking)
+ self._kv_indices_buf_d = kv_indices_d.to(self.device, non_blocking=non_blocking)
  self._kv_last_page_len_buf_d = last_page_len_d.to(
      self.device, non_blocking=non_blocking
  )
@@ -922,7 +919,9 @@ def plan(
  )
  kv_indptr_host_d = kv_indptr_d.to("cpu")
  last_page_len_host_d = last_page_len_d.to("cpu")
- kv_lens_arr_host_d = get_seq_lens(kv_indptr_host_d, last_page_len_host_d, page_size)
+ kv_lens_arr_host_d = get_seq_lens(
+     kv_indptr_host_d, last_page_len_host_d, page_size
+ )

  self._plan_info_d = self._cached_module.plan(
      self._float_workspace_buffer_d,
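
Note on the get_seq_lens calls that get re-wrapped above: with a paged KV cache, a request's sequence length can be recovered from its page count, the page size, and how many slots its last page holds. The helper below is only an illustrative re-derivation of that relationship, not the library's implementation, and it round-trips the numbers from the benchmark sketch earlier:

import torch

def seq_lens_from_pages(kv_indptr, last_page_len, page_size):
    # Every page except the last holds `page_size` tokens; the last holds `last_page_len`.
    num_pages = kv_indptr[1:] - kv_indptr[:-1]
    return torch.clamp(num_pages - 1, min=0) * page_size + last_page_len

kv_indptr = torch.tensor([0, 3, 5, 6])      # 3, 2 and 1 pages per request
last_page_len = torch.tensor([2, 4, 1])
assert seq_lens_from_pages(kv_indptr, last_page_len, 4).tolist() == [10, 8, 1]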
