diff --git a/3rdparty/googletest b/3rdparty/googletest new file mode 160000 index 0000000000..5a37b517ad --- /dev/null +++ b/3rdparty/googletest @@ -0,0 +1 @@ +Subproject commit 5a37b517ad4ab6738556f0284c256cae1466c5b4 diff --git a/3rdparty/nvbench b/3rdparty/nvbench new file mode 160000 index 0000000000..555d628e9b --- /dev/null +++ b/3rdparty/nvbench @@ -0,0 +1 @@ +Subproject commit 555d628e9b250868c9da003e4407087ff1982e8e diff --git a/aot_build_utils/generate.py b/aot_build_utils/generate.py new file mode 100644 index 0000000000..5daa47c08b --- /dev/null +++ b/aot_build_utils/generate.py @@ -0,0 +1,388 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import argparse +from itertools import product +from pathlib import Path +from typing import List + +from . import ( + generate_aot_default_additional_params_header, + generate_batch_paged_decode_inst, + generate_batch_paged_prefill_inst, + generate_batch_ragged_prefill_inst, + generate_pod_inst, + generate_single_decode_inst, + generate_single_prefill_inst, +) + + +def get_instantiation_cu(args: argparse.Namespace) -> List[str]: + def write_if_different(path: Path, content: str) -> None: + if path.exists() and path.read_text() == content: + return + path.write_text(content) + + path: Path = args.path + head_dims: List[int] = args.head_dims + pos_encoding_modes: List[int] = args.pos_encoding_modes + use_fp16_qk_reductions: List[int] = args.use_fp16_qk_reductions + mask_modes: List[int] = args.mask_modes + enable_f16: bool = args.enable_f16 + enable_bf16: bool = args.enable_bf16 + enable_fp8_e4m3: bool = args.enable_fp8_e4m3 + enable_fp8_e5m2: bool = args.enable_fp8_e5m2 + + path.mkdir(parents=True, exist_ok=True) + + write_if_different( + path / "aot_default_additional_params.h", + generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(), + ) + + idtypes = ["i32"] + prefill_dtypes = [] + decode_dtypes = [] + fp16_dtypes = [] + fp8_dtypes = [] + if enable_f16: + prefill_dtypes.append("f16") + decode_dtypes.append("f16") + fp16_dtypes.append("f16") + if enable_bf16: + prefill_dtypes.append("bf16") + decode_dtypes.append("bf16") + fp16_dtypes.append("bf16") + if enable_fp8_e4m3: + fp8_dtypes.extend(["e4m3"]) + decode_dtypes.extend(["e4m3"]) + if enable_fp8_e5m2: + fp8_dtypes.extend(["e5m2"]) + decode_dtypes.extend(["e5m2"]) + + single_decode_uris = [] + # single decode files + for head_dim, pos_encoding_mode in product(head_dims, pos_encoding_modes): + for dtype_q, dtype_kv in list(zip(decode_dtypes, decode_dtypes)) + list( + product(fp16_dtypes, fp8_dtypes) + ): + dtype_out = dtype_q + fname = f"single_decode_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" + content = generate_single_decode_inst.get_cu_file_str( + head_dim, # head_dim_qk + head_dim, # head_dim_vo + pos_encoding_mode, + dtype_q, + dtype_kv, + dtype_out, + ) + for use_sliding_window in [True, False]: + for use_logits_soft_cap in [True, 
False]: + single_decode_uris.append( + f"single_decode_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_out}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}" + ) + write_if_different(path / fname, content) + + # batch decode files + batch_decode_uris = [] + for ( + head_dim, + pos_encoding_mode, + ) in product( + head_dims, + pos_encoding_modes, + ): + for idtype in idtypes: + for dtype_q, dtype_kv in list(zip(decode_dtypes, decode_dtypes)) + list( + product(fp16_dtypes, fp8_dtypes) + ): + dtype_out = dtype_q + fname = f"batch_paged_decode_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" + content = generate_batch_paged_decode_inst.get_cu_file_str( + head_dim, # head_dim_qk + head_dim, # head_dim_vo + pos_encoding_mode, + dtype_q, + dtype_kv, + dtype_out, + idtype, + ) + for use_sliding_window in [True, False]: + for use_logits_soft_cap in [True, False]: + batch_decode_uris.append( + f"batch_decode_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_out}_" + f"dtype_idx_{idtype}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}" + ) + write_if_different(path / fname, content) + + # single prefill files + single_prefill_uris = [] + for ( + head_dim, + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode, + ) in product( + head_dims, + pos_encoding_modes, + use_fp16_qk_reductions, + mask_modes, + ): + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list( + product(prefill_dtypes, fp8_dtypes) + ): + fname = f"single_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu" + content = generate_single_prefill_inst.get_cu_file_str( + head_dim, # head_dim_qk + head_dim, # head_dim_vo + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + ) + for use_sliding_window in [True, False]: + for use_logits_soft_cap in [True, False]: + if ( + mask_mode == 0 + ): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris + single_prefill_uris.append( + f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}_" + f"f16qk_{bool(use_fp16_qk_reduction)}" + ) + write_if_different(path / fname, content) + + # batch prefill files + batch_prefill_uris = [] + for ( + head_dim, + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode, + idtype, + ) in product( + head_dims, + pos_encoding_modes, + use_fp16_qk_reductions, + mask_modes, + idtypes, + ): + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list( + product(prefill_dtypes, fp8_dtypes) + ): + fname = f"batch_paged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" + content = generate_batch_paged_prefill_inst.get_cu_file_str( + head_dim, # head_dim_qk + head_dim, # head_dim_vo + 
pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + idtype, + ) + write_if_different(path / fname, content) + + fname = f"batch_ragged_prefill_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" + content = generate_batch_ragged_prefill_inst.get_cu_file_str( + head_dim, # head_dim_qk + head_dim, # head_dim_vo + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + idtype, + ) + write_if_different(path / fname, content) + + for sliding_window in [True, False]: + for logits_soft_cap in [True, False]: + if ( + mask_mode == 0 + ): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris + batch_prefill_uris.append( + f"batch_prefill_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"dtype_idx_{idtype}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{sliding_window}_" + f"use_logits_cap_{logits_soft_cap}_" + f"f16qk_{bool(use_fp16_qk_reduction)}" + ) + + # POD files + pod_uris = [] + for ( + head_dim, + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode_p, + mask_mode_d, + idtype, + ) in product( + head_dims, + pos_encoding_modes, + use_fp16_qk_reductions, + mask_modes, + mask_modes, # mask_mode_d can be different from mask_mode_p + idtypes, + ): + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list( + product(prefill_dtypes, fp8_dtypes) + ): + fname = f"pod_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_maskp_{mask_mode_p}_maskd_{mask_mode_d}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" + content = generate_pod_inst.get_cu_file_str( + head_dim, # head_dim_qk + head_dim, # head_dim_vo + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode_p, + mask_mode_d, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + idtype, + ) + write_if_different(path / fname, content) + + for sliding_window_p in [True, False]: + for sliding_window_d in [True, False]: + for logits_soft_cap_p in [True, False]: + for logits_soft_cap_d in [True, False]: + if ( + mask_mode_p == 0 and mask_mode_d == 0 + ): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris + pod_uris.append( + f"pod_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"dtype_idx_{idtype}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_p_{sliding_window_p}_" + f"use_swa_d_{sliding_window_d}_" + f"use_logits_cap_p_{logits_soft_cap_p}_" + f"use_logits_cap_d_{logits_soft_cap_d}_" + f"f16qk_{bool(use_fp16_qk_reduction)}" + ) + + return ( + single_decode_uris + + batch_decode_uris + + single_prefill_uris + + batch_prefill_uris + + pod_uris + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Generate cuda files") + parser.add_argument( + "--path", type=Path, required=True, help="Path to the dispatch inc file" + ) + parser.add_argument( + "--head_dims", type=int, required=True, nargs="+", help="Head dimensions" + ) + parser.add_argument( + "--pos_encoding_modes", + type=int, + required=True, + nargs="+", + help="Position encoding modes", + ) + parser.add_argument( + "--use_fp16_qk_reductions", + type=lambda x: ( + x if isinstance(x, int) else 
int(x.lower() == "true" or x.lower() == "on") + ), + required=True, + nargs="+", + help="Allow fp16 qk reductions", + ) + parser.add_argument( + "--mask_modes", + type=int, + required=True, + nargs="+", + help="Mask modes", + ) + parser.add_argument( + "--enable_f16", + type=lambda x: ( + x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on") + ), + required=True, + nargs="?", + help="Enable fp16", + ) + parser.add_argument( + "--enable_bf16", + type=lambda x: ( + x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on") + ), + required=True, + nargs="?", + help="Enable bf16", + ) + parser.add_argument( + "--enable_fp8_e4m3", + type=lambda x: ( + x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on") + ), + default=True, + nargs="?", + help="Enable fp8_e4m3", + ) + parser.add_argument( + "--enable_fp8_e5m2", + type=lambda x: ( + x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on") + ), + default=True, + nargs="?", + help="Enable fp8_e5m2", + ) + args = parser.parse_args() + get_instantiation_cu(args) diff --git a/aot_build_utils/generate_pod_inst.py b/aot_build_utils/generate_pod_inst.py new file mode 100644 index 0000000000..53e061c09c --- /dev/null +++ b/aot_build_utils/generate_pod_inst.py @@ -0,0 +1,121 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import re +import sys +from pathlib import Path + +from .literal_map import ( + dtype_literal, + idtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) + + +def get_cu_file_str( + head_dim_qk, + head_dim_vo, + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode_p, + mask_mode_d, + dtype_q, + dtype_kv, + dtype_out, + idtype, +): + cta_tile_q_choice = [128, 64, 16] + cta_tile_q_d = 16 + + def get_insts(attention_variant_p, attention_variant_d, dtype_out): + return "\n".join( + [ + """template cudaError_t PODWithKVCacheTensorDispatched<{head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode_p}, {cta_tile_q_p}, {cta_tile_q_d}, {mask_mode_d}, {attention_variant_p}, {attention_variant_d}, PrefillParams, DecodeParams>( + PrefillParams prefill_params, DecodeParams decode_params, + {dtype_out}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream); + """.format( + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], + use_fp16_qk_reduction=use_fp16_qk_reduction, + mask_mode_p=mask_mode_literal[int(mask_mode_p)], + cta_tile_q_p=cta_tile_q_p, + cta_tile_q_d=cta_tile_q_d, + mask_mode_d=mask_mode_literal[int(mask_mode_d)], + attention_variant_p=attention_variant_p, + attention_variant_d=attention_variant_d, + dtype_out=dtype_out, + ) + for cta_tile_q_p in cta_tile_q_choice + ] + ) + + use_custom_mask_p = "true" if int(mask_mode_p) == 2 else "false" + use_custom_mask_d = "true" if int(mask_mode_d) == 2 else "false" + dtype_q = dtype_literal[dtype_q] + dtype_kv = dtype_literal[dtype_kv] + dtype_out = dtype_literal[dtype_out] + idtype = idtype_literal[idtype] + + content = f"""#include +#include "pod_config.inc" + +namespace flashinfer {{ + +constexpr auto use_custom_mask_p = MaskMode::kNone == MaskMode::kCustom; +constexpr auto use_custom_mask_d = MaskMode::kNone == MaskMode::kCustom; + +using PrefillParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}>; +using DecodeParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>; + +using AttentionVariant1_P = DefaultAttention; +using AttentionVariant1_D = DefaultAttention; + +{get_insts("AttentionVariant1_P", "AttentionVariant1_D", dtype_out)} + +using AttentionVariant2_P = DefaultAttention; +using AttentionVariant2_D = DefaultAttention; + +{get_insts("AttentionVariant2_P", "AttentionVariant2_D", dtype_out)} + +using AttentionVariant3_P = DefaultAttention; +using AttentionVariant3_D = DefaultAttention; + +{get_insts("AttentionVariant3_P", "AttentionVariant3_D", dtype_out)} + +using AttentionVariant4_P = DefaultAttention; +using AttentionVariant4_D = DefaultAttention; + +{get_insts("AttentionVariant4_P", "AttentionVariant4_D", dtype_out)} + +}}""" + return content + + +if __name__ == "__main__": + pattern = ( + r"pod_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_" + r"fp16qkred_([a-z]+)_maskp_([0-9]+)_maskd_([0-9]+)_" + r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" + ) + compiled_pattern = re.compile(pattern) + path = Path(sys.argv[1]) + fname = path.name + match = compiled_pattern.match(fname) + + with open(path, "w") as f: + f.write(get_cu_file_str(*match.groups())) diff --git a/benchmarks/bench_mixed_attention.py b/benchmarks/bench_mixed_attention.py index 85753a71f9..9773e8f37d 100644 --- a/benchmarks/bench_mixed_attention.py +++ b/benchmarks/bench_mixed_attention.py @@ -72,6 +72,24 @@ def run_bench( measurements = bench_gpu_time(lambda: 
wrapper_old.run(q, kv_data)) ms_old = np.median(measurements) + wrapper_persistent = flashinfer.BatchAttention(kv_layout="NHD") + wrapper_persistent.plan( + q_indptr.to(device), + kv_indptr.to(device), + torch.arange(num_blocks, dtype=torch.int32, device=device), + seq_lens.to(device), + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + page_block_size, + causal=causal, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + ) + o_persistent, _ = wrapper_persistent.run(q, kv_data) + measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data)) + ms_persistent = np.mean(measurements_persistent) if len(p_kv_lens) == 1: q_d = q[: d_q_indptr[-1]] kv_d = kv_data[: d_kv_indptr[-1]].unbind(1) @@ -123,9 +141,46 @@ def run_bench( ) ) ms_pod = np.median(measurements) + + # Sequential two kernels: single prefill + batch decode (tensor cores) + # Prefill using single_prefill_with_kv_cache + def _run_single_prefill(): + return flashinfer.prefill.single_prefill_with_kv_cache( + q_p, + k_p, + v_p, + causal=causal, + pos_encoding_mode="NONE", + backend="fa2", + ) + + measurements_prefill = bench_gpu_time(lambda: _run_single_prefill()) + ms_prefill = np.median(measurements_prefill) + + # Batch decode using tensor cores + wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True + ) + wrapper_decode.plan( + d_kv_indptr.to(device), + kv_indices_d.to(device), + last_page_len_d, + num_qo_heads, + num_kv_heads, + head_dim, + page_block_size, + data_type=torch.bfloat16, + q_data_type=torch.bfloat16, + ) + measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d)) + ms_decode = np.median(measurements_decode) + ms_seq_two_kernels = ms_prefill + ms_decode + print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms") if len(p_kv_lens) == 1: print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms") + print(f"Elapsed time (Sequential two kernels): {ms_seq_two_kernels:.2f} ms") + print(f"Elapsed time (Persistent BatchAttention): {ms_persistent:.2f} ms") total_bytes = ( q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size() ) @@ -137,6 +192,14 @@ def run_bench( if len(p_kv_lens) == 1: bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3) print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s") + bandwidth_seq_gb_s = total_bytes / (ms_seq_two_kernels * 1e-3) / (1024**3) + print( + f"Memory bandwidth (Sequential two kernels): {bandwidth_seq_gb_s:.2f} GB/s" + ) + bandwidth_persistent_gb_s = total_bytes / (ms_persistent * 1e-3) / (1024**3) + print( + f"Memory bandwidth (Persistent BatchAttention): {bandwidth_persistent_gb_s:.2f} GB/s" + ) if __name__ == "__main__": @@ -144,70 +207,14 @@ def run_bench( torch.random.manual_seed(42) # Irregular sequence lengths for prefill and decode - d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256] - d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256] - p_q_configs = [[17] * 1, [10000], [17] * 1, []] - p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []] - - # construct random length testcases - for _ in range(1): - bsz = 256 - stride = 16 - sparsity = 0.05 - - full_kv_len = np.random.randint(1000, 8192, size=bsz) - p_q_lens = [] - p_kv_lens = [] - d_q_lens = [] - d_kv_lens = [] - for i in range(bsz): - if i % stride == 0: - kv_len = full_kv_len[i] - qo_len = stride + 1 - p_q_lens.append(qo_len) - p_kv_lens.append(kv_len) - else: - kv_len = int(full_kv_len[i] * sparsity) - qo_len = 1 - 
d_q_lens.append(qo_len) - d_kv_lens.append(kv_len) - - p_q_configs.append(p_q_lens) - p_kv_configs.append(p_kv_lens) - d_q_len_configs.append(d_q_lens) - d_kv_len_configs.append(d_kv_lens) - - for _ in range(1): - bsz = 128 - stride = 16 - sparsity = 0.05 - - full_kv_len = np.random.randint(2000, 16000, size=bsz) - p_q_lens = [] - p_kv_lens = [] - d_q_lens = [] - d_kv_lens = [] - - for i in range(bsz): - if i % stride == 0: - kv_len = full_kv_len[i] - qo_len = stride + 1 - p_q_lens.append(qo_len) - p_kv_lens.append(kv_len) - else: - kv_len = int(full_kv_len[i] * sparsity) - qo_len = 1 - d_q_lens.append(qo_len) - d_kv_lens.append(kv_len) - - p_q_configs.append(p_q_lens) - p_kv_configs.append(p_kv_lens) - d_q_len_configs.append(d_q_lens) - d_kv_len_configs.append(d_kv_lens) + d_q_len_configs = [[1] * 128, [1] * 128, [1] * 128, [1] * 128] + d_kv_len_configs = [[2048] * 128, [4096] * 128, [8192] * 128, [8192] * 128] + p_q_configs = [[2048], [4096], [4096], [6000]] + p_kv_configs = [[2048], [4096], [4096], [7000]] page_block_size = 1 - num_kv_heads = 4 - num_qo_heads = 28 + num_kv_heads = 8 + num_qo_heads = 32 head_dim = 128 for idx, (p_q_lens, p_kv_lens, d_q_len, d_kv_len) in enumerate( diff --git a/csrc/batch_attention.cu b/csrc/batch_attention.cu index a3d36b7981..a233df29d2 100644 --- a/csrc/batch_attention.cu +++ b/csrc/batch_attention.cu @@ -152,7 +152,7 @@ void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_wo GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); params[i].merge_o_indices = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_o_indices_offset); - params[i].num_packed_qo_len = + params[i].num_to_merge_qo_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.num_qo_len_offset); params[i].num_kv_heads = num_kv_heads; diff --git a/csrc/batch_attention_customize_config.jinja b/csrc/batch_attention_customize_config.jinja index 3cf9312748..c7de93ff87 100644 --- a/csrc/batch_attention_customize_config.jinja +++ b/csrc/batch_attention_customize_config.jinja @@ -93,7 +93,7 @@ struct PersistentParams { // for state reduction IdType* merge_indptr; IdType* merge_o_indices; - IdType* num_packed_qo_len; + IdType* num_to_merge_qo_len; uint32_t num_kv_heads; uint_fastdiv gqa_group_size; diff --git a/csrc/batch_decode.cu b/csrc/batch_decode.cu index c3ce1e2ecf..244ce8bede 100644 --- a/csrc/batch_decode.cu +++ b/csrc/batch_decode.cu @@ -143,7 +143,6 @@ void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer, static_cast(paged_kv_indices.data_ptr()), static_cast(paged_kv_indptr.data_ptr()), static_cast(paged_kv_last_page_len.data_ptr())); - Params params; params.q = static_cast(q.data_ptr()); params.paged_kv = paged_kv; diff --git a/csrc/flashinfer_ops.cu b/csrc/flashinfer_ops.cu new file mode 100644 index 0000000000..90d63c94aa --- /dev/null +++ b/csrc/flashinfer_ops.cu @@ -0,0 +1,338 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "aot_default_additional_params.h" +#include "pytorch_extension_utils.h" + +//========== activation ========== + +void silu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl); +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl); +void gelu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl); + +//========== cascade ========== + +void merge_state(at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, + at::Tensor v_merged, at::Tensor s_merged); + +void merge_state_in_place(at::Tensor v, at::Tensor s, at::Tensor v_other, at::Tensor s_other, + std::optional mask); + +void merge_states(at::Tensor v, at::Tensor s, at::Tensor v_merged, at::Tensor s_merged); + +//========== decode ========== + +void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, + at::Tensor o, std::optional maybe_lse, int64_t layout, + int64_t window_left SINGLE_DECODE_ADDITIONAL_FUNC_PARAMS); + +at::Tensor BatchDecodeWithPagedKVCachePlan( + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, + int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo, + at::Tensor empty_q_data, at::Tensor empty_kv_data); + +void BatchDecodeWithPagedKVCacheRun(at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, at::Tensor plan_info_vec, + at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, + at::Tensor o, std::optional maybe_lse, + int64_t kv_layout_code, int64_t window_left, + bool enable_pdl BATCH_DECODE_ADDITIONAL_FUNC_PARAMS); + +//========== gemm ========== + +void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, + at::Tensor workspace_buffer, int64_t cublas_handle); + +void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr, + at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld, + at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major); + +//========== norm ========== + +void rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl); + +void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, + bool enable_pdl); + +void gemma_rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, + bool enable_pdl); + +void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, + double eps, bool enable_pdl); + +//========== page ========== + +void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::Tensor batch_indices, + at::Tensor positions, at::Tensor paged_k_cache, at::Tensor paged_v_cache, + at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len, + int64_t layout); + +void append_paged_mla_kv_cache(at::Tensor append_ckv, at::Tensor append_kpe, + at::Tensor batch_indices, at::Tensor positions, at::Tensor ckv_cache, + at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor kv_indptr, + at::Tensor kv_last_page_len); + +void block_sparse_indices_to_vector_sparse_offsets( + at::Tensor block_sparse_indices, at::Tensor block_sparse_indptr, + at::Tensor vector_sparse_offsets, at::Tensor vector_sparse_indptr, at::Tensor kv_len_arr, + int64_t 
stride_block, int64_t stride_n, int64_t batch_size, int64_t block_size);
+
+//========== prefill ==========
+
+void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp,
+                                  at::Tensor o, std::optional<at::Tensor> maybe_lse,
+                                  int64_t mask_mode_code, int64_t layout,
+                                  int64_t window_left SINGLE_PREFILL_ADDITIONAL_FUNC_PARAMS);
+
+at::Tensor BatchPrefillWithKVCachePlan(
+    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
+    at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
+    at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
+    int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
+    int64_t head_dim_vo, bool causal);
+
+void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
+                                      at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
+                                      at::Tensor q, at::Tensor k, at::Tensor v,
+                                      at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
+                                      std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
+                                      int64_t layout,
+                                      int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);
+
+void BatchPrefillWithPagedKVCacheRun(
+    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
+    at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
+    at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
+    at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
+    int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);
+
+//========== pod-attention =========
+void PODWithKVCacheTensorRun(
+    // Shared params
+    at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d,
+    at::Tensor plan_info_vec, at::Tensor qo_indptr, at::Tensor paged_kv_indptr,
+    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o,
+    // Prefill params
+    at::Tensor q_p, at::Tensor paged_k_p, at::Tensor paged_v_p,
+    std::optional<at::Tensor> maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p,
+    int64_t window_left_p, std::optional<at::Tensor> maybe_custom_mask_p,
+    std::optional<at::Tensor> maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p,
+    double rope_rcp_scale_p, double rope_rcp_theta_p,
+    // Decode params
+    at::Tensor q_d, at::Tensor paged_k_cache_d, at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d,
+    at::Tensor paged_kv_indptr_d, at::Tensor paged_kv_indices_d,
+    at::Tensor paged_kv_last_page_len_d, std::optional<at::Tensor> maybe_lse_d,
+    int64_t mask_mode_code_d, int64_t layout_d, int64_t window_left_d,
+    std::optional<at::Tensor> maybe_custom_mask_d, std::optional<at::Tensor> maybe_mask_indptr_d,
+    std::optional<at::Tensor> maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d,
+    double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl);
+
+at::Tensor PODWithKVCachePlan(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
+                              at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr_p,
+                              at::Tensor kv_indptr_p, int64_t total_num_rows_p,
+                              int64_t batch_size_p, at::Tensor qo_indptr_d, at::Tensor kv_indptr_d,
+                              int64_t total_num_rows_d, int64_t batch_size_d, int64_t num_qo_heads,
+                              int64_t num_kv_heads, int64_t head_dim_qk, int64_t head_dim_vo,
+                              int64_t page_size, bool enable_cuda_graph);
+//========== quantization ==========
+
+void packbits(at::Tensor x, const std::string& bitorder, at::Tensor y);
+
+void segment_packbits(at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr,
+                      const std::string& bitorder, at::Tensor y);
+
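// A minimal sketch of the binding pattern used throughout this file: every operator is
// forward-declared once here and then exposed through the TORCH_LIBRARY_FRAGMENT block at
// the bottom of the file. Taking the `packbits` declaration just above as the example:
//
//   void packbits(at::Tensor x, const std::string& bitorder, at::Tensor y);  // declaration
//
//   TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
//     m.def("packbits", packbits);  // becomes callable as torch.ops.<extension>.packbits
//   }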
+//========== rope ==========
+
+void apply_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor indptr,
+                at::Tensor offsets, int64_t rotary_dim, bool interleave, double rope_scale,
+                double rope_theta);
+
+void apply_llama31_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
+                        at::Tensor indptr, at::Tensor offsets, int64_t rotary_dim, bool interleave,
+                        double rope_scale, double rope_theta, double low_freq_factor,
+                        double high_freq_factor, double old_context_length);
+
+void apply_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
+                        at::Tensor pos_ids, int64_t rotary_dim, bool interleave, double rope_scale,
+                        double rope_theta);
+
+void apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
+                                at::Tensor pos_ids, int64_t rotary_dim, bool interleave,
+                                double rope_scale, double rope_theta, double low_freq_factor,
+                                double high_freq_factor, double old_context_length);
+
+void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope,
+                                      at::Tensor k_rope, at::Tensor cos_sin_cache,
+                                      at::Tensor pos_ids, bool interleave);
+
+//========== sampling ==========
+
+void softmax(at::Tensor workspace_buffer, at::Tensor logits, at::Tensor output,
+             std::optional<at::Tensor> maybe_temperature_arr, double temperature_val,
+             bool enable_pdl);
+
+void sampling_from_probs(at::Tensor probs, at::Tensor output,
+                         std::optional<at::Tensor> maybe_indices, bool deterministic,
+                         std::optional<at::Generator> gen);
+
+void sampling_from_logits(at::Tensor logits, at::Tensor output,
+                          std::optional<at::Tensor> maybe_indices, bool deterministic,
+                          std::optional<at::Generator> gen);
+
+void top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                               std::optional<at::Tensor> maybe_indices,
+                               std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
+                               bool deterministic, std::optional<at::Generator> gen);
+
+void top_k_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                               std::optional<at::Tensor> maybe_indices,
+                               std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
+                               bool deterministic, std::optional<at::Generator> gen);
+
+void min_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                               std::optional<at::Tensor> maybe_indices,
+                               std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
+                               bool deterministic, std::optional<at::Generator> gen);
+
+void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                                     std::optional<at::Tensor> maybe_indices,
+                                     std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
+                                     std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
+                                     bool deterministic, std::optional<at::Generator> gen);
+
+void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
+                        std::optional<at::Tensor> maybe_top_p_arr, double top_p_val);
+
+void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
+                        std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
+
+void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
+                       std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
+
+void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
+                                at::Tensor target_probs, at::Tensor output_token_ids,
+                                at::Tensor output_accepted_token_num,
+                                at::Tensor output_emitted_draft_token_num, bool deterministic,
+                                std::optional<at::Generator> gen);
+
+//========== Torch Library ==========
+
+TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
+  // activation
+  // Fused SiLU and Mul
+  m.def("silu_and_mul", silu_and_mul);
+  // Fused GeLU Tanh and Mul
+  m.def("gelu_tanh_and_mul", gelu_tanh_and_mul);
+  // Fused GeLU and Mul
+  m.def("gelu_and_mul", gelu_and_mul);
+
+  // cascade
+  // Merge two self-attention states
+  m.def("merge_state", merge_state);
+  // Merge another self-attention state in-place.
+  m.def("merge_state_in_place", merge_state_in_place);
+  // "Merge multiple self-attention states"
+  m.def("merge_states", merge_states);
+
+  // decode
+  // "Single-request decode with KV-Cache operator"
+  m.def("single_decode_with_kv_cache", single_decode_with_kv_cache);
+  m.def("batch_decode_with_paged_kv_cache_plan", BatchDecodeWithPagedKVCachePlan);
+  m.def("batch_decode_with_paged_kv_cache_run", BatchDecodeWithPagedKVCacheRun);
+
+  // gemm
+  // BMM FP8
+  m.def("bmm_fp8", bmm_fp8);
+  // Cutlass Segment GEMM operator
+  m.def("cutlass_segment_gemm", CutlassSegmentGEMM);
+
+  // norm
+  // Root mean square normalization
+  m.def("rmsnorm", rmsnorm);
+  // Fused add root mean square normalization
+  m.def("fused_add_rmsnorm", fused_add_rmsnorm);
+  // Gemma Root mean square normalization
+  m.def("gemma_rmsnorm", gemma_rmsnorm);
+  // Gemma Fused add root mean square normalization
+  m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm);
+
+  // page
+  // Append paged KV-Cache operator
+  m.def("append_paged_kv_cache", append_paged_kv_cache);
+  // Append paged MLA KV-Cache operator
+  m.def("append_paged_mla_kv_cache", append_paged_mla_kv_cache);
+  // Precompute block sparse offsets
+  m.def("block_sparse_indices_to_vector_sparse_offsets",
+        block_sparse_indices_to_vector_sparse_offsets);
+
+  // prefill
+  // Single-request prefill attention with KV-Cache operator
+  m.def("single_prefill_with_kv_cache", single_prefill_with_kv_cache);
+  m.def("batch_prefill_with_kv_cache_plan", BatchPrefillWithKVCachePlan);
+  m.def("batch_prefill_with_ragged_kv_cache_run", BatchPrefillWithRaggedKVCacheRun);
+  m.def("batch_prefill_with_paged_kv_cache_run", BatchPrefillWithPagedKVCacheRun);
+
+  // pod-attention
+  // Temporarily disabled because we don't generate the implementation yet.
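  // (The POD plan/run entry points themselves are still registered below under their
  //  snake_case names; only the legacy PODWithKVCacheTensor binding is left commented out.)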
+ // m.def("PODWithKVCacheTensor", PODWithKVCacheTensorRun); + m.def("pod_with_kv_cache_plan", PODWithKVCachePlan); + m.def("pod_with_kv_cache_tensor_run", PODWithKVCacheTensorRun); + + // quantization + // GPU packbits operator + m.def("packbits", packbits); + // GPU segment packbits operator + m.def("segment_packbits", segment_packbits); + + // rope + // "Apply RoPE" + m.def("apply_rope", apply_rope); + // "Apply Llama 3.1 style RoPE" + m.def("apply_llama31_rope", apply_llama31_rope); + // "Apply RoPE with positional ids" + m.def("apply_rope_pos_ids", apply_rope_pos_ids); + // "Apply Llama 3.1 style RoPE with positional ids" + m.def("apply_llama31_rope_pos_ids", apply_llama31_rope_pos_ids); + // "Apply RoPE with positional ids and cosine/sine cache" + m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache); + + // sampling + // Softmax + m.def("softmax", softmax); + // Sample from probabilities + m.def("sampling_from_probs", sampling_from_probs); + // Sample from logits + m.def("sampling_from_logits", sampling_from_logits); + // Top-k sampling from probabilities + m.def("top_k_sampling_from_probs", top_k_sampling_from_probs); + // Min-p sampling from probabilities + m.def("min_p_sampling_from_probs", min_p_sampling_from_probs); + // Top-p sampling from probabilities + m.def("top_p_sampling_from_probs", top_p_sampling_from_probs); + // Top-k and top-p sampling from probabilities + m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs); + // Renormalize probabilities by top-k mask + m.def("top_k_renorm_probs", top_k_renorm_probs); + // Renormalize probabilities by top-p mask + m.def("top_p_renorm_probs", top_p_renorm_probs); + // Mask logits by top-k mask + m.def("top_k_mask_logits", top_k_mask_logits); + // Speculative sampling from sequence of probabilities + m.def("chain_speculative_sampling", chain_speculative_sampling); +} diff --git a/csrc/pod.cu b/csrc/pod.cu index e38796e0e3..1a5c2e217b 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -21,149 +21,180 @@ namespace flashinfer { template -cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, - typename PrefillParams::DTypeO* tmp, - DecodeParams decode_params, + bool USE_FP16_QK_REDUCTION, MaskMode MASK_MODE_P, uint32_t CTA_TILE_Q_P, + uint32_t CTA_TILE_Q_D, MaskMode MASK_MODE_D, typename PrefillAttentionVariant, + typename DecodeAttentionVariant, typename PrefillParams, typename DecodeParams> +cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeParams decode_params, typename DecodeParams::DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream); } // namespace flashinfer using namespace flashinfer; - using tvm::ffi::Array; using tvm::ffi::Optional; -void pod_with_kv_cache_tensor( +Array PODWithKVCachePlan( + TensorView float_workspace_buffer, TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, TensorView qo_indptr_p, TensorView kv_indptr_p, + int64_t total_num_rows_p, int64_t batch_size_p, TensorView qo_indptr_d, TensorView kv_indptr_d, + int64_t total_num_rows_d, int64_t batch_size_d, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, bool enable_cuda_graph) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer); + + PODPlanInfo plan_info; + + 
cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); + cudaError_t status = PODPlan( + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, static_cast(qo_indptr_p.data_ptr()), + static_cast(kv_indptr_p.data_ptr()), total_num_rows_p, batch_size_p, + static_cast(qo_indptr_d.data_ptr()), static_cast(kv_indptr_d.data_ptr()), + total_num_rows_d, batch_size_d, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, + page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); + + TVM_FFI_ICHECK(status == cudaSuccess) + << "Failed to plan prefill with error: " << cudaGetErrorString(status); + + return Array(plan_info.ToVector()); +} + +void PODWithKVCacheTensorRun( + // Shared params + TensorView float_workspace_buffer_d, TensorView int_workspace_buffer_d, + Array plan_info_vec, TensorView paged_k_cache, TensorView paged_v_cache, + TensorView qo_indptr, TensorView paged_kv_indptr, TensorView paged_kv_indices, + TensorView paged_kv_last_page_len, TensorView o, Optional maybe_lse, int64_t layout, // Prefill params - TensorView q_p, TensorView k_p, TensorView v_p, TensorView tmp_p, TensorView o_p, - Optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, - int64_t window_left_p, Optional maybe_custom_mask_p, - Optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, - double rope_rcp_scale_p, double rope_rcp_theta_p, + TensorView q_p, int64_t mask_mode_code_p, int64_t window_left_p, + Optional maybe_custom_mask_p, Optional maybe_alibi_slopes_p, + double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params - TensorView float_workspace_buffer_d, TensorView int_workspace_buffer_d, - Array plan_info_vec, TensorView q_d, TensorView paged_k_cache_d, - TensorView paged_v_cache_d, TensorView qo_indptr_d, TensorView paged_kv_indptr_d, - TensorView paged_kv_indices_d, TensorView paged_kv_last_page_len_d, TensorView o_d, - Optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, - int64_t window_left_d, Optional maybe_custom_mask_d, - Optional maybe_mask_indptr_d, Optional maybe_alibi_slopes_d, - double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, - bool enable_pdl) { + TensorView q_d, int64_t mask_mode_code_d, int64_t window_left_d, + Optional maybe_custom_mask_d, Optional maybe_mask_indptr_d, + Optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, + double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl) { + PODPlanInfo plan_info; + plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); + uint32_t batch_size = paged_kv_indptr.size(0) - 1; + void* float_buffer_ptr = static_cast(float_workspace_buffer_d.data_ptr()); + void* int_buffer_ptr = static_cast(int_workspace_buffer_d.data_ptr()); + // Prefill setup - unsigned int head_dim_qk = q_p.size(2); - unsigned int kv_len_p, qo_len_p, num_kv_heads, num_qo_heads; - QKVLayout kv_layout_p = static_cast(layout_p); - qo_len_p = q_p.size(0); - num_qo_heads = q_p.size(1); - uint32_t q_stride_n_p = q_p.stride(0), q_stride_h_p = q_p.stride(1), k_stride_n_p, k_stride_h_p, - v_stride_n_p, v_stride_h_p; - if (kv_layout_p == QKVLayout::kNHD) { - kv_len_p = k_p.size(0); - num_kv_heads = k_p.size(1); - k_stride_n_p = k_p.stride(0); - k_stride_h_p = k_p.stride(1); - v_stride_n_p = 
v_p.stride(0); - v_stride_h_p = v_p.stride(1); - } else { - kv_len_p = k_p.size(1); - num_kv_heads = k_p.size(0); - k_stride_h_p = k_p.stride(0); - k_stride_n_p = k_p.stride(1); - v_stride_h_p = v_p.stride(0); - v_stride_n_p = v_p.stride(1); - } - if (maybe_lse_p.has_value()) { - const auto& lse = maybe_lse_p.value(); - TVM_FFI_ICHECK_EQ(lse.size(0), qo_len_p); - TVM_FFI_ICHECK_EQ(lse.size(1), num_qo_heads); + uint32_t head_dim_qk = q_p.size(2); + uint32_t qo_len, num_qo_heads_p; + QKVLayout kv_layout = static_cast(layout); + qo_len = q_p.size(0) + q_d.size(0); + num_qo_heads_p = q_p.size(1); + uint32_t q_stride_n_p = q_p.stride(0), q_stride_h_p = q_p.stride(1); + if (maybe_lse.has_value()) { + const auto& lse = maybe_lse.value(); + TVM_FFI_ICHECK_EQ(lse.size(0), qo_len); + TVM_FFI_ICHECK_EQ(lse.size(1), num_qo_heads_p); } const MaskMode mask_mode_p = static_cast(mask_mode_code_p); - // Decode setup (TensorView decode = batched prefill) - PrefillPlanInfo plan_info; - plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); - QKVLayout kv_layout_d = static_cast(layout_d); - int64_t batch_size = paged_kv_indptr_d.size(0) - 1; - int64_t num_qo_heads_d = q_d.size(1); - - TVM_FFI_ICHECK_EQ(num_qo_heads, num_qo_heads_d) + // Decode setup (Tensor decode = batched prefill) + uint32_t num_qo_heads = q_d.size(1); + TVM_FFI_ICHECK_EQ(num_qo_heads_p, num_qo_heads) << "POD currently requires same # Query heads for prefill and decode"; - int64_t num_kv_heads_d, page_size_d; - uint32_t head_dim_qk_d = q_d.size(2); - if (kv_layout_d == QKVLayout::kHND) { - num_kv_heads_d = paged_k_cache_d.size(1); - page_size_d = paged_k_cache_d.size(2); + uint32_t num_kv_heads_d, num_kv_heads, page_size; + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache.size(1); + num_kv_heads_d = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - page_size_d = paged_k_cache_d.size(1); - num_kv_heads_d = paged_k_cache_d.size(2); + num_kv_heads = paged_k_cache.size(2); + num_kv_heads_d = paged_k_cache.size(2); + page_size = paged_k_cache.size(1); } TVM_FFI_ICHECK_EQ(num_kv_heads, num_kv_heads_d) << "POD currently requires same # KV heads for prefill and decode; Prefill: " << num_kv_heads << ", Decode: " << num_kv_heads_d; - if (maybe_lse_d.has_value()) { - const auto& lse = maybe_lse_d.value(); - TVM_FFI_ICHECK_EQ(lse.size(0), q_d.size(0)); - TVM_FFI_ICHECK_EQ(lse.size(1), q_d.size(1)); - } - - void* float_buffer_ptr = static_cast(float_workspace_buffer_d.data_ptr()); - void* int_buffer_ptr = static_cast(int_workspace_buffer_d.data_ptr()); - const MaskMode mask_mode_d = static_cast(mask_mode_code_d); // get q_stride_n and q_stride_h const auto q_stride_n_d = q_d.stride(0); const auto q_stride_h_d = q_d.stride(1); - // get kv_cache_strides - const int64_t* kv_cache_strides_d = nullptr; - auto k_strides_d = paged_k_cache_d.strides(); - auto v_strides_d = paged_v_cache_d.strides(); - TVM_FFI_ICHECK_EQ(k_strides_d.size(), v_strides_d.size()); - for (int i = 0; i < k_strides_d.size(); ++i) { - TVM_FFI_ICHECK_EQ(k_strides_d[i], v_strides_d[i]); - } - kv_cache_strides_d = k_strides_d.data(); - cudaSetDevice(float_workspace_buffer_d.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer_d.device()); DISPATCH_context( MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, [&] { + // Compute kv_cache_strides from tensor strides + // paged_kv_t expects [stride_page, stride_n, stride_h] where: + // 
- stride_page is stride(0) + // - stride_n and stride_h depend on layout + int64_t kv_strides[3]; + kv_strides[0] = paged_k_cache.stride(0); // stride_page + if (kv_layout == QKVLayout::kHND) { + kv_strides[1] = paged_k_cache.stride(1); // stride_h + kv_strides[2] = paged_k_cache.stride(2); // stride_n + } else { + kv_strides[1] = paged_k_cache.stride(1); // stride_n + kv_strides[2] = paged_k_cache.stride(2); // stride_h + } + TVM_FFI_ICHECK(paged_k_cache.stride(0) == paged_v_cache.stride(0) && + paged_k_cache.stride(1) == paged_v_cache.stride(1) && + paged_k_cache.stride(2) == paged_v_cache.stride(2)) + << "k/v strides must be identical"; + + paged_kv_t paged_kv( + num_kv_heads, page_size, HEAD_DIM_VO, batch_size, kv_layout, + static_cast(paged_k_cache.data_ptr()), + static_cast(paged_v_cache.data_ptr()), kv_strides, + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + IdType* q_indptr = static_cast(qo_indptr.data_ptr()); + + // debug indices PrefillParams prefill_params; { // Make params a reference to prefill_params to set values PrefillParams& params = prefill_params; params.q = static_cast(q_p.data_ptr()); - params.k = static_cast(k_p.data_ptr()); - params.v = static_cast(v_p.data_ptr()); - params.o = static_cast(o_p.data_ptr()); - params.lse = maybe_lse_p.has_value() ? static_cast(maybe_lse_p.value().data_ptr()) - : nullptr; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = num_kv_heads; - params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); - params.qo_len = qo_len_p; - params.kv_len = kv_len_p; + params.paged_kv = paged_kv; + params.q_indptr = static_cast(qo_indptr.data_ptr()); + + params.o = static_cast(o.data_ptr()); + params.lse = + maybe_lse.has_value() ? static_cast(maybe_lse.value().data_ptr()) : nullptr; + params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); params.q_stride_n = q_stride_n_p; params.q_stride_h = q_stride_h_p; - params.k_stride_n = k_stride_n_p; - params.k_stride_h = k_stride_h_p; - params.v_stride_n = v_stride_n_p; - params.v_stride_h = v_stride_h_p; - params.window_left = window_left_p; - params.partition_kv = false; + params.paged_kv.num_heads = num_kv_heads; + params.num_qo_heads = num_qo_heads; + params.request_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_tile_indices_offset); + params.o_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.o_indptr_offset); + if (plan_info.split_kv) { + params.merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + } + } + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset_p); + params.padded_batch_size = plan_info.padded_batch_size_p; params.maybe_custom_mask = maybe_custom_mask_p.has_value() ? 
static_cast(maybe_custom_mask_p.value().data_ptr()) @@ -176,6 +207,18 @@ void pod_with_kv_cache_tensor( params.sm_scale = sm_scale_p; params.rope_rcp_scale = rope_rcp_scale_p; params.rope_rcp_theta = rope_rcp_theta_p; + params.max_total_num_rows = plan_info.total_num_rows; + if (plan_info.enable_cuda_graph) { + params.total_num_rows = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.total_num_rows_offset); + } + params.partition_kv = plan_info.split_kv; + if (plan_info.split_kv) { + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + } + } } DecodeParams decode_params; @@ -184,36 +227,36 @@ void pod_with_kv_cache_tensor( { DecodeParams& params = decode_params; params.q = static_cast(q_d.data_ptr()); - paged_kv_t paged_kv( - num_kv_heads, page_size_d, HEAD_DIM_VO, batch_size, kv_layout_d, - static_cast(paged_k_cache_d.data_ptr()), - static_cast(paged_v_cache_d.data_ptr()), kv_cache_strides_d, - static_cast(paged_kv_indices_d.data_ptr()), - static_cast(paged_kv_indptr_d.data_ptr()), - static_cast(paged_kv_last_page_len_d.data_ptr())); params.paged_kv = paged_kv; - params.q_indptr = static_cast(qo_indptr_d.data_ptr()); - params.o = static_cast(o_d.data_ptr()); - - params.lse = maybe_lse_d.has_value() ? static_cast(maybe_lse_d.value().data_ptr()) - : nullptr; - params.num_qo_heads = num_qo_heads; + params.q_indptr = static_cast(qo_indptr.data_ptr()); + params.o = static_cast(o.data_ptr()); + params.lse = + maybe_lse.has_value() ? static_cast(maybe_lse.value().data_ptr()) : nullptr; params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); params.q_stride_n = q_stride_n_d; params.q_stride_h = q_stride_h_d; params.window_left = window_left_d; + params.paged_kv.num_heads = num_kv_heads; + params.num_qo_heads = num_qo_heads; + + params.request_indices = prefill_params.request_indices; + params.qo_tile_indices = prefill_params.qo_tile_indices; + params.kv_tile_indices = prefill_params.kv_tile_indices; + params.o_indptr = prefill_params.o_indptr; + params.kv_chunk_size_ptr = prefill_params.kv_chunk_size_ptr; - params.request_indices = nullptr; - params.qo_tile_indices = nullptr; - params.kv_tile_indices = nullptr; - params.merge_indptr = nullptr; - params.o_indptr = nullptr; - params.kv_chunk_size_ptr = nullptr; - params.block_valid_mask = nullptr; - params.total_num_rows = nullptr; - params.max_total_num_rows = 0; - params.padded_batch_size = 0; - params.partition_kv = false; + params.partition_kv = plan_info.split_kv; + if (plan_info.split_kv) { + params.merge_indptr = prefill_params.merge_indptr; + // These should be assigned from plan info, not from prefill_params + tmp_v = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.v_offset); + tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = prefill_params.block_valid_mask; + } + } + params.padded_batch_size = plan_info.padded_batch_size_d; + params.max_total_num_rows = plan_info.total_num_rows; params.maybe_mask_indptr = maybe_mask_indptr_d.has_value() @@ -228,30 +271,8 @@ void pod_with_kv_cache_tensor( params.rope_rcp_scale = rope_rcp_scale_d; params.rope_rcp_theta = rope_rcp_theta_d; - params.request_indices = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); - params.qo_tile_indices = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); - params.kv_tile_indices = - GetPtrFromBaseOffset(int_buffer_ptr, 
plan_info.kv_tile_indices_offset); - params.o_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.o_indptr_offset); - params.kv_chunk_size_ptr = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset); - if (plan_info.split_kv) { - params.merge_indptr = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); - tmp_v = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.v_offset); - tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); - if (plan_info.enable_cuda_graph) { - params.block_valid_mask = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); - } - } - params.padded_batch_size = plan_info.padded_batch_size; - params.max_total_num_rows = plan_info.total_num_rows; if (plan_info.enable_cuda_graph) { - params.total_num_rows = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.total_num_rows_offset); + params.total_num_rows = prefill_params.total_num_rows; } } @@ -263,15 +284,17 @@ void pod_with_kv_cache_tensor( using DecodeAttentionVariant = DefaultAttention; - // DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { - constexpr size_t CTA_TILE_Q = 16; - cudaError_t status = flashinfer::PODWithKVCacheTensorDispatched< - HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, MASK_MODE_P, - CTA_TILE_Q, MASK_MODE_D, PrefillAttentionVariant, DecodeAttentionVariant>( - prefill_params, static_cast(tmp_p.data_ptr()), decode_params, tmp_v, tmp_s, - enable_pdl, stream); - TVM_FFI_ICHECK(status == cudaSuccess) - << "PODWithKVCache kernel launch failed, error: " << cudaGetErrorString(status); - //}); + DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q_p, CTA_TILE_Q_P, { + TVM_FFI_ICHECK(plan_info.cta_tile_q_d == 16) + << "Decode tile size should be 16 for POD. Check planner."; + constexpr size_t CTA_TILE_Q_D = 16; + cudaError_t status = flashinfer::PODWithKVCacheTensorDispatched< + HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, MASK_MODE_P, + CTA_TILE_Q_P, CTA_TILE_Q_D, MASK_MODE_D, PrefillAttentionVariant, + DecodeAttentionVariant>(prefill_params, decode_params, tmp_v, tmp_s, enable_pdl, + stream); + TVM_FFI_ICHECK(status == cudaSuccess) + << "PODWithKVCache kernel launch failed, error: " << cudaGetErrorString(status); + }); }); } diff --git a/csrc/pod_config.inc b/csrc/pod_config.inc new file mode 100644 index 0000000000..3c95f4286b --- /dev/null +++ b/csrc/pod_config.inc @@ -0,0 +1,45 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "aot_default_additional_params.h" +#include "aot_extension_utils.h" + +using namespace flashinfer; + +#define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \ + USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, ...) 
\ +{ \ + DISPATCH_mask_mode(mask_mode_p, MASK_MODE_P, [&] { \ + return DISPATCH_mask_mode(mask_mode_d, MASK_MODE_D, [&] { \ + return DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE( \ + q_scalar_type, kv_scalar_type, DTypeQ, DTypeKV, [&] { \ + using DTypeO = DTypeQ; \ + constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \ + constexpr bool USE_FP16_QK_REDUCTION = false; \ + return DISPATCH_head_dim(head_dim_qk, HEAD_DIM_QK, [&] { \ + [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \ + return DISPATCH_BOOL(window_left_p > -1, USE_SLIDING_WINDOW_P, [&] { \ + return DISPATCH_BOOL(window_left_d > -1, USE_SLIDING_WINDOW_D, [&] { \ + return DISPATCH_BOOL(false, USE_LOGITS_SOFT_CAP, [&] { \ + using IdType = int32_t; \ + using PrefillParams = BatchPrefillPagedParams;\ + using DecodeParams = BatchPrefillPagedParams; \ + __VA_ARGS__(); \ + return true; \ + }); \ + }); \ + }); \ + }); \ + }); \ + }); \ + }); \ +} diff --git a/csrc/pod_customize_config.jinja b/csrc/pod_customize_config.jinja index b4c56a0e82..bde2ed3af4 100644 --- a/csrc/pod_customize_config.jinja +++ b/csrc/pod_customize_config.jinja @@ -30,7 +30,7 @@ constexpr auto USE_SLIDING_WINDOW_D = {{ use_sliding_window_d }}; constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; constexpr bool USE_LOGITS_SOFT_CAP = false; -using PrefillParams = SinglePrefillParams; +using PrefillParams = BatchPrefillPagedParams; using DecodeParams = BatchPrefillPagedParams; #define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \ diff --git a/csrc/pod_jit_binding.cu b/csrc/pod_jit_binding.cu index 915e4bcdbf..a0ebd2cbfc 100644 --- a/csrc/pod_jit_binding.cu +++ b/csrc/pod_jit_binding.cu @@ -19,23 +19,29 @@ using tvm::ffi::Array; using tvm::ffi::Optional; -void pod_with_kv_cache_tensor( +Array PODWithKVCachePlan( + TensorView float_workspace_buffer, TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, TensorView qo_indptr_p, TensorView kv_indptr_p, + int64_t total_num_rows_p, int64_t batch_size_p, TensorView qo_indptr_d, TensorView kv_indptr_d, + int64_t total_num_rows_d, int64_t batch_size_d, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, bool enable_cuda_graph); + +void PODWithKVCacheTensorRun( + // Shared params (match implementation in pod.cu) + TensorView float_workspace_buffer_d, TensorView int_workspace_buffer_d, + Array plan_info_vec, TensorView paged_k_cache, TensorView paged_v_cache, + TensorView qo_indptr, TensorView paged_kv_indptr, TensorView paged_kv_indices, + TensorView paged_kv_last_page_len, TensorView o, Optional maybe_lse, int64_t layout, // Prefill params - TensorView q_p, TensorView k_p, TensorView v_p, TensorView tmp_p, TensorView o_p, - Optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, - int64_t window_left_p, Optional maybe_custom_mask_p, - Optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, - double rope_rcp_scale_p, double rope_rcp_theta_p, + TensorView q_p, int64_t mask_mode_code_p, int64_t window_left_p, + Optional maybe_custom_mask_p, Optional maybe_alibi_slopes_p, + double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params - TensorView float_workspace_buffer_d, TensorView int_workspace_buffer_d, - Array plan_info_vec, TensorView q_d, TensorView paged_k_cache_d, - TensorView paged_v_cache_d, TensorView qo_indptr_d, TensorView paged_kv_indptr_d, - TensorView paged_kv_indices_d, TensorView paged_kv_last_page_len_d, 
TensorView o_d, - Optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, - int64_t window_left_d, Optional maybe_custom_mask_d, - Optional maybe_mask_indptr_d, Optional maybe_alibi_slopes_d, - double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, - bool enable_pdl); + TensorView q_d, int64_t mask_mode_code_d, int64_t window_left_d, + Optional maybe_custom_mask_d, Optional maybe_mask_indptr_d, + Optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, + double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl); // Batch-request prefill attention with KV-Cache operator -TVM_FFI_DLL_EXPORT_TYPED_FUNC(pod_with_kv_cache_tensor, pod_with_kv_cache_tensor); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(PODWithKVCachePlan, PODWithKVCachePlan); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(PODWithKVCacheTensorRun, PODWithKVCacheTensorRun); diff --git a/csrc/pod_jit_pybind.cu b/csrc/pod_jit_pybind.cu new file mode 100644 index 0000000000..6f156c4037 --- /dev/null +++ b/csrc/pod_jit_pybind.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2023-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pod_config.inc" +#include "pytorch_extension_utils.h" + +void PODWithKVCacheTensorRun( + // Shared params + at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, + at::Tensor plan_info_vec, at::Tensor paged_k_cache, at::Tensor paged_v_cache, + at::Tensor qo_indptr, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional maybe_lse, + int64_t layout, + // Prefill params + at::Tensor q_p, int64_t mask_mode_code_p, int64_t window_left_p, + std::optional maybe_custom_mask_p, std::optional maybe_alibi_slopes_p, + double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, + // Decode params + at::Tensor q_d, int64_t mask_mode_code_d, int64_t window_left_d, + std::optional maybe_custom_mask_d, std::optional maybe_mask_indptr_d, + std::optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, + double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl); + +at::Tensor PODWithKVCachePlan(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr_p, + at::Tensor kv_indptr_p, int64_t total_num_rows_p, + int64_t batch_size_p, at::Tensor qo_indptr_d, at::Tensor kv_indptr_d, + int64_t total_num_rows_d, int64_t batch_size_d, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t head_dim_qk, int64_t head_dim_vo, + int64_t page_size, bool enable_cuda_graph); + +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // Batch-request prefill attention with KV-Cache operator + m.def("PODWithKVCacheTensor", PODWithKVCacheTensorRun); + m.def("PODWithKVCachePlan", PODWithKVCachePlan); +} diff --git a/csrc/pod_kernel_inst.jinja b/csrc/pod_kernel_inst.jinja index 926e9bce8d..b0e6d9513a 100644 --- a/csrc/pod_kernel_inst.jinja +++ 
b/csrc/pod_kernel_inst.jinja @@ -11,20 +11,25 @@ #include "pod_config.inc" -using namespace flashinfer; - namespace flashinfer { + +using PrefillParams = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ idtype }}>; +using DecodeParams = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ idtype }}>; + +constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; + constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom; constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom; -// Not sure about the below declaration -constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; +using PrefillAttentionVariant = DefaultAttention; +using DecodeAttentionVariant = DefaultAttention; + +{% for cta_tile_q_p in [16, 64, 128] %} template cudaError_t PODWithKVCacheTensorDispatched< - {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE, - {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, 16, - {{ mask_mode_d }}, {{ variant_name_p }}, - {{ variant_name_d }}, PrefillParams, DecodeParams>( - PrefillParams prefill_params, {{ dtype_o }}* tmp, - DecodeParams decode_params, {{ dtype_o }}* tmp_v, - float *tmp_s, bool enable_pdl, cudaStream_t stream); -}; + {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE, {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, /*CTA_TILE_Q_P=*/{{cta_tile_q_p}}, 16, {{ mask_mode_d }}, + PrefillAttentionVariant, DecodeAttentionVariant, PrefillParams, DecodeParams>( + PrefillParams prefill_params, DecodeParams decode_params, + {{ dtype_o }}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream); +{% endfor %} + +}; // namespace flashinfer diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index 2c3a56d92b..1f244f0edf 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -121,6 +121,12 @@ def generate_ninja_build_for_op( if not sysconfig.get_config_var("Py_GIL_DISABLED"): common_cflags.append("-DPy_LIMITED_API=0x03090000") common_cflags += _get_glibcxx_abi_build_flags() + + # Add debug flags if FLASHINFER_DEBUG is set + debug = os.environ.get("FLASHINFER_DEBUG", "0") == "1" + if debug: + common_cflags.append("-g") + if extra_include_dirs is not None: for extra_dir in extra_include_dirs: common_cflags.append(f"-I{extra_dir.resolve()}") diff --git a/flashinfer/pod.py b/flashinfer/pod.py index 59e113f238..721cb758df 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -21,9 +21,7 @@ import torch -from .jit import gen_pod_module -from .page import get_seq_lens -from .prefill import get_batch_prefill_module +from .jit import gen_pod_module, get_pod_uri from .quantization import packbits from .utils import ( MaskMode, @@ -33,18 +31,126 @@ _check_kv_layout, _check_pos_encoding_mode, _get_cache_alibi_slopes_buf, - _get_cache_buf, _get_range_buf, _unpack_paged_kv_cache, canonicalize_torch_dtype, device_support_pdl, + register_custom_op, + register_fake_op, ) @functools.cache def get_pod_module(*args): + """Get POD module with cached compilation.""" + # Use the proper JIT compilation system like batch prefill + uri = get_pod_uri(*args) module = gen_pod_module(*args).build_and_load() - return SimpleNamespace(run_tensor=module.pod_with_kv_cache_tensor) + plan_func = module.PODWithKVCachePlan + run_tensor_func = module.PODWithKVCacheTensorRun + + # Register custom op for POD tensor run + @register_custom_op( + f"flashinfer::{uri}_pod_run", + mutates_args=( + "float_workspace_buffer", + "int_workspace_buffer", + "paged_k_cache", + "paged_v_cache", + "o", + "maybe_lse", + 
), + ) + def pod_run( + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + paged_k_cache: torch.Tensor, + paged_v_cache: torch.Tensor, + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + o: torch.Tensor, + maybe_lse: Optional[torch.Tensor], + layout: int, + # Prefill params + q_p: torch.Tensor, + mask_mode_code_p: int, + window_left_p: int, + maybe_custom_mask_p: Optional[torch.Tensor], + maybe_alibi_slopes_p: Optional[torch.Tensor], + logits_soft_cap_p: float, + sm_scale_p: float, + rope_rcp_scale_p: float, + rope_rcp_theta_p: float, + # Decode params + q_d: torch.Tensor, + mask_mode_code_d: int, + window_left_d: int, + maybe_custom_mask_d: Optional[torch.Tensor], + maybe_mask_indptr_d: Optional[torch.Tensor], + maybe_alibi_slopes_d: Optional[torch.Tensor], + logits_soft_cap_d: float, + sm_scale_d: float, + rope_rcp_scale_d: float, + rope_rcp_theta_d: float, + enable_pdl: bool, + ) -> None: + run_tensor_func( + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + paged_k_cache, + paged_v_cache, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + o, + maybe_lse, + layout, + q_p, + mask_mode_code_p, + window_left_p, + maybe_custom_mask_p, + maybe_alibi_slopes_p, + logits_soft_cap_p, + sm_scale_p, + rope_rcp_scale_p, + rope_rcp_theta_p, + q_d, + mask_mode_code_d, + window_left_d, + maybe_custom_mask_d, + maybe_mask_indptr_d, + maybe_alibi_slopes_d, + logits_soft_cap_d, + sm_scale_d, + rope_rcp_scale_d, + rope_rcp_theta_d, + enable_pdl, + ) + + @register_fake_op(f"flashinfer::{uri}_pod_run") + def _fake_pod_run(*args) -> None: + pass + + # # Create a simple namespace that wraps the JIT module functions + # class PODModule: + # def __init__(self): + # pass + + # def plan(self, *args): + # """Call the POD plan function.""" + # return plan_func(*args) + + # def run_tensor(self, *args): + # """Call the POD tensor run function.""" + # return pod_run(*args) + + # return PODModule() + return SimpleNamespace(run_tensor=pod_run, plan=plan_func) class PODWithPagedKVCacheWrapper: @@ -65,7 +171,7 @@ class PODWithPagedKVCacheWrapper: >>> page_size = 16 >>> # allocate 128MB workspace buffer >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") - >>> decode_wrapper = flashinfer.PODWithPagedKVCacheWrapper( + >>> wrapper = flashinfer.PODWithPagedKVCacheWrapper( ... workspace_buffer, "NHD" ... ) >>> batch_size = 7 @@ -83,7 +189,7 @@ class PODWithPagedKVCacheWrapper: ... ) for _ in range(num_layers) ... ] >>> # create auxiliary data structures for batch decode attention - >>> decode_wrapper.plan( + >>> wrapper.plan( ... kv_page_indptr, ... kv_page_indices, ... kv_last_page_len, @@ -118,9 +224,12 @@ def __init__( float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, + qo_indptr_buffer: Optional[torch.Tensor] = None, paged_kv_indptr_buffer: Optional[torch.Tensor] = None, paged_kv_indices_buffer: Optional[torch.Tensor] = None, paged_kv_last_page_len_buffer: Optional[torch.Tensor] = None, + custom_mask_buf_p: Optional[torch.Tensor] = None, + mask_indptr_buf_p: Optional[torch.Tensor] = None, jit_args: Optional[List[Any]] = None, ) -> None: r"""Constructor of :class:`PODWithPagedKVCacheWrapper`. @@ -140,19 +249,24 @@ def __init__( auxiliary data structures will be stored as the provided buffers. 
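When ``use_cuda_graph`` is enabled, the reserved buffers described above must be allocated once, before capture, and reused across replays. A minimal sketch of such an allocation is shown below; the sizes, device, and the ``flashinfer`` import are illustrative assumptions, and the combined prefill-plus-decode batch size and the page-index capacity must stay fixed for the lifetime of the graph:

    import torch
    import flashinfer  # assumes a FlashInfer build that ships the POD wrapper

    batch_size = 7        # combined prefill + decode batch size, fixed for the graph's lifetime
    max_num_pages = 128   # upper bound on total page indices across both phases
    device = torch.device("cuda:0")

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    wrapper = flashinfer.PODWithPagedKVCacheWrapper(
        workspace,
        "NHD",
        use_cuda_graph=True,
        qo_indptr_buffer=torch.empty(batch_size + 1, dtype=torch.int32, device=device),
        paged_kv_indptr_buffer=torch.empty(batch_size + 1, dtype=torch.int32, device=device),
        paged_kv_indices_buffer=torch.empty(max_num_pages, dtype=torch.int32, device=device),
        paged_kv_last_page_len_buffer=torch.empty(batch_size, dtype=torch.int32, device=device),
    )
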
The ``batch_size`` cannot change during the lifecycle of this wrapper when CUDAGraph is enabled. - indptr_buffer : Optional[torch.Tensor] - The user reserved buffer on GPU to store the indptr of the paged kv cache, the size + qo_indptr_buffer: Optional[torch.Tensor] + The user reserved buffer to store the ``qo_indptr`` array, the size of the buffer + should be ``[batch_size + 1]``. + This argument is only effective when ``use_cuda_graph`` is ``True``. + + paged_kv_indptr_buffer: Optional[torch.Tensor] + The user reserved buffer on GPU to store the indptr of the prefill paged kv cache, the size of the buffer should be ``[batch_size + 1]``. Only needed when ``use_cuda_graph`` is ``True``. - indices_buffer : Optional[torch.Tensor] - The user reserved buffer on GPU to store the page indices of the paged kv cache, + paged_kv_indices_buffer: Optional[torch.Tensor] + The user reserved buffer on GPU to store the page indices of the prefill paged kv cache, should be large enough to store the maximum number of page indices (``max_num_pages``) during the lifecycle of this wrapper. Only needed when ``use_cuda_graph`` is ``True``. - last_page_len_buffer : Optional[torch.Tensor] - The user reserved buffer on GPU to store the number of entries in the last page, the + paged_kv_last_page_len_buffer: Optional[torch.Tensor] + The user reserved buffer on GPU to store the number of entries in the last page for prefill, the size of the buffer should be ``[batch_size]``. Only needed when ``use_cuda_graph`` is ``True``. @@ -176,10 +290,14 @@ def __init__( # Override options. Only tensor core version is performant. use_tensor_cores = True self._jit_module: SimpleNamespace = None + assert custom_mask_buf_p is None and mask_indptr_buf_p is None, ( + "custom_mask_buf_p and mask_indptr_buf_p are not supported yet" + ) self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device + self._qo_indptr_buf = qo_indptr_buffer self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) @@ -191,22 +309,36 @@ def __init__( ) if use_cuda_graph: - if not torch.is_tensor(paged_kv_indptr_buffer): + if not torch.is_tensor(qo_indptr_buffer): + raise ValueError( + "qo_indptr_buffer should be a torch.Tensor in CUDA graph mode" + ) + if not torch.is_tensor(paged_kv_indptr_buffer) or not torch.is_tensor( + paged_kv_indptr_buffer + ): raise ValueError( "paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_indices_buffer): + if not torch.is_tensor(paged_kv_indices_buffer) or not torch.is_tensor( + paged_kv_indices_buffer + ): raise ValueError( "paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_last_page_len_buffer): + if not torch.is_tensor( + paged_kv_last_page_len_buffer + ) or not torch.is_tensor(paged_kv_last_page_len_buffer): raise ValueError( "paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode" ) self._fixed_batch_size = len(paged_kv_last_page_len_buffer) if len(paged_kv_indptr_buffer) != self._fixed_batch_size + 1: raise ValueError( - "The size of paged_kv_indptr_buffer should be batch_size + 1" + "The length of paged_kv_indptr_buffer_p should be batch_size + 1" + ) + if len(paged_kv_last_page_len_buffer) != self._fixed_batch_size: + raise ValueError( + "The length of paged_kv_last_page_len_buffer_p should be batch_size" ) else: self._fixed_batch_size = 0 @@ -216,7 +348,6 @@ def __init__( 
self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer self._use_tensor_cores = use_tensor_cores self._use_cuda_graph = use_cuda_graph - if use_cuda_graph: # NOTE(Zihao): if once created, no need to update it in plan/run self._qo_indptr_buf = torch.arange( @@ -255,9 +386,13 @@ def reset_workspace_buffer( def plan( self, - indptr: torch.Tensor, - indices: torch.Tensor, - last_page_len: torch.Tensor, + qo_indptr_p: torch.Tensor, + kv_indptr_p: torch.Tensor, + kv_indices_p: torch.Tensor, + last_page_len_p: torch.Tensor, + kv_indptr_d: torch.Tensor, + kv_indices_d: torch.Tensor, + last_page_len_d: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, @@ -268,6 +403,7 @@ def plan( kv_data_type: Optional[Union[str, torch.dtype]] = None, data_type: Optional[Union[str, torch.dtype]] = None, sm_scale: Optional[float] = None, + logits_soft_cap: Optional[float] = 0.0, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, non_blocking: bool = True, @@ -276,12 +412,21 @@ def plan( Parameters ---------- - indptr : torch.Tensor - The indptr of the paged kv cache, shape: ``[batch_size + 1]`` - indices : torch.Tensor - The page indices of the paged kv cache, shape: ``[qo_indptr[-1]]`` - last_page_len : torch.Tensor - The number of entries in the last page of each request in the paged kv + qo_indptr_p: torch.Tensor + The indptr of the query/output tensor for prefill, shape: ``[batch_size + 1]``. + kv_indptr_p: torch.Tensor + The indptr of the paged kv cache for prefill, shape: ``[batch_size + 1]``. + kv_indices_p: torch.Tensor + The page indices of the paged kv cache for prefill, shape: ``[kv_indptr[-1]]``. + last_page_len_p : torch.Tensor + The number of entries in the last page of each request in the kv + cache, shape: ``[batch_size]`` + kv_indptr_d : torch.Tensor + The indptr of the paged kv cache for decode, shape: ``[batch_size + 1]`` + kv_indices_d : torch.Tensor + The page indices of the paged kv cache for decode, shape: ``[kv_indptr[-1]]`` + last_page_len_d : torch.Tensor + The number of entries in the last page of each request in the kv cache, shape: ``[batch_size]`` num_qo_heads : int The number of query/output heads @@ -322,46 +467,59 @@ def plan( The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``. 
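As a usage sketch matching the parameter list above: keyword names follow the ``plan()`` signature in this file, while the shapes, dtypes, and the pre-existing ``wrapper``, ``q_p``, ``q_d``, and paged KV-cache tensors are assumptions chosen for illustration only.

    import torch

    # Two prefill requests (33 and 11 query tokens) and three decode requests.
    qo_indptr_p = torch.tensor([0, 33, 44], dtype=torch.int32)
    kv_indptr_p = torch.tensor([0, 3, 6], dtype=torch.int32)
    kv_indices_p = torch.arange(6, dtype=torch.int32)
    last_page_len_p = torch.tensor([1, 12], dtype=torch.int32)

    kv_indptr_d = torch.tensor([0, 2, 5, 7], dtype=torch.int32)
    kv_indices_d = torch.arange(7, dtype=torch.int32)
    last_page_len_d = torch.tensor([4, 16, 9], dtype=torch.int32)

    wrapper.plan(
        qo_indptr_p, kv_indptr_p, kv_indices_p, last_page_len_p,
        kv_indptr_d, kv_indices_d, last_page_len_d,
        num_qo_heads=32, num_kv_heads=8, head_dim=128, page_size=16,
        q_data_type=torch.float16, kv_data_type=torch.float16,
    )
    # q_p: [44, 32, 128], q_d: [3, 32, 128]; run() returns per-phase views of one output buffer.
    o_p, o_d = wrapper.run(q_p, q_d, (paged_k_cache, paged_v_cache))
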
""" - # Logits soft cap is not supported currently - batch_size = len(last_page_len) - logits_soft_cap = 0.0 - - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") + # Logits soft cap is not supported currently; keep a float for typing + batch_size_p = len(last_page_len_p) + batch_size_d = len(last_page_len_d) + batch_size = batch_size_p + batch_size_d + # keep logits_soft_cap as float consistently + + qo_indptr_host_p = qo_indptr_p.to("cpu", non_blocking=True) + qo_indptr_host_d = _get_range_buf(batch_size_d + 1, "cpu") + to_device = lambda x: x.to(self.device, non_blocking=non_blocking) + qo_indptr_p = to_device(qo_indptr_p) + qo_indptr = torch.cat( + [qo_indptr_p, to_device(qo_indptr_host_d)[1:] + qo_indptr_p[-1]] + ) + kv_indptr_p = to_device(kv_indptr_p) + kv_indptr = torch.cat( + [kv_indptr_p, to_device(kv_indptr_d)[1:] + kv_indptr_p[-1]] + ) + kv_indices_p = to_device(kv_indices_p) + kv_indices = torch.cat( + [kv_indices_p, to_device(kv_indices_d)[1:] + kv_indices_p[-1]] + ) + last_page_len = torch.cat( + [to_device(last_page_len_p), to_device(last_page_len_d)] + ) if self.is_cuda_graph_enabled: if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime batch size {} " " mismatches the batch size set during initialization {}".format( - batch_size, self._fixed_batch_size + batch_size_d, self._fixed_batch_size ) ) - if len(indices) > len(self._paged_kv_indices_buf): + if len(kv_indices_d) + len(kv_indices_p) > len(self._paged_kv_indices_buf): raise ValueError( "The size of indices should be less than or equal to the allocated buffer" ) - self._paged_kv_indptr_buf.copy_(indptr, non_blocking=non_blocking) - self._paged_kv_last_page_len_buf.copy_( + self._paged_kv_indptr_buf[: batch_size + 1].copy_( + kv_indptr, non_blocking=non_blocking + ) + self._paged_kv_last_page_len_buf[: batch_size + 1].copy_( last_page_len, non_blocking=non_blocking ) - self._paged_kv_indices_buf[: len(indices)].copy_( - indices, non_blocking=(indices.device == self.device) and non_blocking + self._paged_kv_indices_buf[: len(kv_indices)].copy_( + kv_indices, non_blocking=non_blocking ) else: - self._paged_kv_indptr_buf = indptr.to( - self.device, non_blocking=non_blocking - ) - self._paged_kv_indices_buf = indices.to( - self.device, non_blocking=non_blocking - ) - self._paged_kv_last_page_len_buf = last_page_len.to( - self.device, non_blocking=non_blocking - ) - self._qo_indptr_buf = qo_indptr_host.to( - self.device, non_blocking=non_blocking - ) + self._qo_indptr_buf = qo_indptr + self._paged_kv_indptr_buf = kv_indptr + self._paged_kv_indices_buf = kv_indices + self._paged_kv_last_page_len_buf = last_page_len - indptr_host = indptr.to("cpu") - last_page_len_host = last_page_len.to("cpu") + kv_indptr_host_p = kv_indptr_p.to("cpu") + kv_indptr_host_d = kv_indptr_d.to("cpu") if data_type is not None: if q_data_type is None: @@ -376,79 +534,81 @@ def plan( self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type - kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size) + self._indptr_type = kv_indptr_d.dtype + self._pos_encoding_mode = pos_encoding_mode + self._window_left = window_left + self._logits_soft_cap = logits_soft_cap + self._sm_scale = sm_scale + self._rope_scale = rope_scale + self._rope_theta = rope_theta + assert ( + qo_indptr_p.dtype == self._indptr_type + and kv_indptr_p.dtype == self._indptr_type + and qo_indptr_host_d.dtype == self._indptr_type + and kv_indptr_d.dtype == self._indptr_type + 
and f"Indices dtype mismatch: {qo_indptr_p.dtype}, {kv_indptr_p.dtype}, {qo_indptr_host_d.dtype}, {kv_indptr_d.dtype}" + ) + if self._jit_module is not None: self._cached_module = self._jit_module else: - self._cached_module = get_batch_prefill_module( - "fa2", + self._cached_module = get_pod_module( + # Prefill params q_data_type, kv_data_type, q_data_type, - indptr.dtype, head_dim, # head_dim_qk - head_dim, # head_dim_vo PosEncodingMode[pos_encoding_mode].value, window_left != -1, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap False, # use_fp16_qk_reduction + # Decode params + self._indptr_type, + PosEncodingMode[pos_encoding_mode].value, + window_left != -1, # use_sliding_window + logits_soft_cap > 0, # use_logits_soft_cap ) - self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, self._pin_memory_int_workspace_buffer, - qo_indptr_host, - indptr_host, - kv_lens_arr_host, - batch_size, # total_num_rows - batch_size, + qo_indptr_host_p, + kv_indptr_host_p, + int(qo_indptr_host_p[-1]), # total_num_rows_p + batch_size_p, + qo_indptr_host_d, + kv_indptr_host_d, + int(qo_indptr_host_d[-1]), # total_num_rows_d + batch_size_d, num_qo_heads, num_kv_heads, + head_dim, # head_dim_qk + head_dim, # head_dim_vo page_size, self.is_cuda_graph_enabled, - head_dim, - head_dim, - False, # causal - window_left, - -1, # fixed_split_size - False, # disable_split_kv ) - self._indptr_type = indptr.dtype - self._pos_encoding_mode = pos_encoding_mode - self._window_left = window_left - self._logits_soft_cap = logits_soft_cap - self._sm_scale = sm_scale - self._rope_scale = rope_scale - self._rope_theta = rope_theta - begin_forward = plan def run( self, # Main params (prefill and decode) q_p: torch.Tensor, - k_p: torch.Tensor, - v_p: torch.Tensor, q_d: torch.Tensor, - paged_kv_cache_d: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], # Prefill options custom_mask_p: Optional[torch.Tensor] = None, packed_custom_mask_p: Optional[torch.Tensor] = None, causal_p: bool = False, - kv_layout_p: str = "NHD", pos_encoding_mode_p: str = "NONE", sm_scale_p: Optional[float] = None, window_left_p: int = -1, rope_scale_p: Optional[float] = None, rope_theta_p: Optional[float] = None, - return_lse_p: bool = False, # Decode options custom_mask_d: Optional[torch.Tensor] = None, packed_custom_mask_d: Optional[torch.Tensor] = None, causal_d: bool = False, - kv_layout_d: str = "NHD", pos_encoding_mode_d: str = "NONE", sm_scale_d: Optional[float] = None, window_left_d: int = -1, @@ -457,7 +617,7 @@ def run( q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, - return_lse_d: bool = False, + return_lse: bool = False, use_fp16_qk_reduction: bool = False, enable_pdl: Optional[bool] = None, *args, @@ -471,8 +631,6 @@ def run( logits_soft_cap_d = None # Prefill setup _check_pos_encoding_mode(pos_encoding_mode_p) - _check_kv_layout(kv_layout_p) - tmp_p = _get_cache_buf("pod_with_kv_cache_tmp", 32 * 1024 * 1024, q_p.device) if logits_soft_cap_p is None: logits_soft_cap_p = 0.0 if sm_scale_p is None: @@ -495,18 +653,27 @@ def run( else: mask_mode_p = MaskMode.NON_CAUSAL.value - lse_p = None - if return_lse_p: - lse_p = torch.empty( - (q_p.size(0), q_p.size(1)), dtype=torch.float32, device=q_p.device + lse = None + if return_lse: + lse = torch.empty( + (q_p.size(0) + q_d.size(0), q_p.size(1)), + dtype=torch.float32, + device=q_p.device, ) - - out_p = 
torch.empty_like(q_p) + qo_len_p, num_qo_heads, head_dim = q_p.shape + qo_len_d, _, _ = q_d.shape + out = torch.empty( + qo_len_p + qo_len_d, + num_qo_heads, + head_dim, + device=q_p.device, + dtype=q_p.dtype, + ) # Decode setup - k_cache_d, v_cache_d = _unpack_paged_kv_cache(paged_kv_cache_d, self._kv_layout) + k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) _check_cached_qkv_data_type( - q_d, k_cache_d, self._cached_q_data_type, self._cached_kv_data_type + q_d, k_cache, self._cached_q_data_type, self._cached_kv_data_type ) # TODO_AK: Where are these coming from? pos_encoding_mode_d = self._pos_encoding_mode @@ -531,17 +698,10 @@ def run( if rope_theta_d is None: rope_theta_d = 1e4 - lse_d = None - if return_lse_d: - lse_d = torch.empty( - (q_d.size(0), q_d.size(1)), dtype=torch.float32, device=q_d.device - ) - out_d = torch.empty_like(q_d) - module_getter = get_pod_module( # Prefill params q_p.dtype, - k_p.dtype, + k_cache.dtype, q_p.dtype, q_p.shape[-1], PosEncodingMode[pos_encoding_mode_p].value, @@ -559,16 +719,24 @@ def run( window_left_d != -1, # use_sliding_window logits_soft_cap_d > 0, # use_logits_soft_cap ) + module_getter.run_tensor( + # Shared params + self._float_workspace_buffer, + self._int_workspace_buffer, + self._plan_info, + k_cache, + v_cache, + self._qo_indptr_buf, # contains both prefill and decode indptr + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len_buf, + out, + lse, + TensorLayout[self._kv_layout].value, # Prefill params q_p, - k_p, - v_p, - tmp_p, - out_p, - lse_p, mask_mode_p, - TensorLayout[kv_layout_p].value, window_left_p, packed_custom_mask_p, _get_cache_alibi_slopes_buf(q_p.shape[1], q_p.device), @@ -577,20 +745,8 @@ def run( 1.0 / rope_scale_p, 1.0 / rope_theta_p, # Decode params - self._float_workspace_buffer, - self._int_workspace_buffer, - self._plan_info, q_d, - k_cache_d, - v_cache_d, - self._qo_indptr_buf, - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len_buf, - out_d, - lse_d, MaskMode.NON_CAUSAL.value, - TensorLayout[self._kv_layout].value, window_left_d, None, # packed_custom_mask None, # mask_indptr_buf @@ -603,9 +759,9 @@ def run( ) if v_scale is not None: - out_d *= v_scale + out *= v_scale - return (out_p, out_d) + return out[:qo_len_p], out[qo_len_p:] def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect.""" diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh index 09f00f6852..640694bd85 100644 --- a/include/flashinfer/attention/cascade.cuh +++ b/include/flashinfer/attention/cascade.cuh @@ -354,11 +354,11 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa * \tparam DTypeO The data type of v_merged. * \param V The partial v of index sets. (nnz, h, d) * \param S The logsumexp value of index sets. (nnz, h) - * \param indptr The start offsets of each position in the variable length array. + * \param merge_indptr The start offsets of each position in the variable length array. * \param v_merged The merged v of index sets union. (n, h, d) * \param s_merged The merged logsumexp value of index sets union. (n, h) * \param max_seq_len The maximum sequence length supported by the kernel. - * \param seq_len_ptr The current sequence length (number of positions populated in indptr). + * \param seq_len_ptr The current sequence length (number of positions populated in merge_indptr). * \param num_heads The number of heads of v. 
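For clarity, the state combination these variable-length kernels perform can be written as a short CPU reference. This is a sketch only: ``merge_states_ref`` is a hypothetical helper, S is assumed to hold base-2 logsumexp values as noted above, and the kernel's sentinel value for empty index sets may differ.

    import torch

    def merge_states_ref(V, S, merge_indptr):
        # V: (nnz, num_heads, head_dim), S: (nnz, num_heads) base-2 logsumexp values,
        # merge_indptr: (n + 1,) offsets into the nnz dimension.
        n = merge_indptr.numel() - 1
        num_heads, head_dim = V.shape[1], V.shape[2]
        v_merged = torch.zeros(n, num_heads, head_dim, dtype=V.dtype)
        s_merged = torch.full((n, num_heads), float("-inf"))
        for pos in range(n):
            lo, hi = int(merge_indptr[pos]), int(merge_indptr[pos + 1])
            if hi == lo:
                continue  # empty index set: zero output (the kernel's s sentinel may differ)
            s = S[lo:hi].float()                      # (m, num_heads)
            v = V[lo:hi].float()                      # (m, num_heads, head_dim)
            s_max = s.max(dim=0).values               # (num_heads,)
            w = torch.exp2(s - s_max)                 # softmax-style weights in base 2
            v_merged[pos] = ((w.unsqueeze(-1) * v).sum(0) / w.sum(0).unsqueeze(-1)).to(V.dtype)
            s_merged[pos] = s_max + torch.log2(w.sum(0))
        return v_merged, s_merged
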
* \param head_dim The dimension of each head. * \note s are logsumexp values with base 2. @@ -366,9 +366,9 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa template __global__ void PersistentVariableLengthMergeStatesKernel( - DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeO* __restrict__ v_merged, - float* __restrict__ s_merged, uint32_t max_seq_len, uint32_t* __restrict__ seq_len_ptr, - uint32_t num_heads) { + DTypeIn* __restrict__ V, float* __restrict__ S, IdType* merge_indptr, + DTypeO* __restrict__ v_merged, float* __restrict__ s_merged, uint32_t max_seq_len, + uint32_t* __restrict__ seq_len_ptr, uint32_t num_heads) { uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; @@ -383,7 +383,6 @@ __global__ void PersistentVariableLengthMergeStatesKernel( #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif - #pragma unroll 1 for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) { // NOTE (Yilong): necessary to prevent hazard on smaller `num_index_sets` @@ -392,7 +391,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel( uint32_t pos = i / num_heads; uint32_t head_idx = i % num_heads; state_t st; - const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; + const uint32_t num_index_sets = merge_indptr[pos + 1] - merge_indptr[pos]; if (num_index_sets == 0) { vec_t v; @@ -406,10 +405,10 @@ __global__ void PersistentVariableLengthMergeStatesKernel( if (num_index_sets == 1) { vec_t v; - v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); + v.cast_load(V + (merge_indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx]; + s_merged[pos * num_heads + head_idx] = S[merge_indptr[pos] * num_heads + head_idx]; } continue; } @@ -418,7 +417,8 @@ __global__ void PersistentVariableLengthMergeStatesKernel( for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { cp_async::pred_load( v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, - V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, + V + ((merge_indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + + tx * vec_size, (iter * bdy + ty) < num_index_sets); cp_async::commit_group(); } @@ -427,7 +427,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel( if (iter % bdx == 0) { s_smem[ty * bdx + tx] = iter * bdy + (ty * bdx + tx) < num_index_sets - ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx] + ? 
S[(merge_indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx] : 0.f; __syncthreads(); } @@ -437,13 +437,15 @@ __global__ void PersistentVariableLengthMergeStatesKernel( v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size); if (iter * bdy + ty < num_index_sets) { float s = s_smem[(iter % bdx) * bdy + ty]; + // printf("Debug: qo_id: %d, head_idx: %d, kv_id: %d, output: %f\n", pos, head_idx, + // iter * bdy + ty, s); st.merge(v, s, 1); } __syncthreads(); cp_async::pred_load( v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size, V + - ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * + ((merge_indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, (iter + num_smem_stages) * bdy + ty < num_index_sets); @@ -468,11 +470,9 @@ __global__ void PersistentVariableLengthMergeStatesKernel( template -__global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ V, IdType* indptr, - DTypeO* __restrict__ v_sum, - uint32_t max_seq_len, - uint32_t* __restrict__ seq_len_ptr, - uint32_t num_heads) { +__global__ void PersistentVariableLengthAttentionSumKernel( + DTypeIn* __restrict__ V, IdType* merge_indptr, DTypeO* __restrict__ v_sum, uint32_t max_seq_len, + uint32_t* __restrict__ seq_len_ptr, uint32_t num_heads) { uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; @@ -494,7 +494,7 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ uint32_t pos = i / num_heads; uint32_t head_idx = i % num_heads; - const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; + const uint32_t num_index_sets = merge_indptr[pos + 1] - merge_indptr[pos]; if (num_index_sets == 0) { vec_t v; @@ -505,7 +505,7 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ if (num_index_sets == 1) { vec_t v; - v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); + v.cast_load(V + (merge_indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); v.store(v_sum + (pos * num_heads + head_idx) * head_dim + tx * vec_size); continue; } @@ -514,7 +514,8 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { cp_async::pred_load( v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, - V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, + V + ((merge_indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + + tx * vec_size, (iter * bdy + ty) < num_index_sets); cp_async::commit_group(); } @@ -534,7 +535,7 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ cp_async::pred_load( v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size, V + - ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * + ((merge_indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, (iter + num_smem_stages) * bdy + ty < num_index_sets); @@ -684,7 +685,7 @@ cudaError_t AttentionSum(DTypeIn* v, DTypeO* v_sum, uint32_t num_index_sets, uin } template -cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeO* v_merged, +cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* merge_indptr, DTypeO* v_merged, float* s_merged, uint32_t max_seq_len, uint32_t* seq_len, uint32_t 
num_heads, uint32_t head_dim, bool enable_pdl, cudaStream_t stream = nullptr) { @@ -707,10 +708,10 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, num_threads, smem_size)); num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms)); - dim3 nblks(num_sms * num_blocks_per_sm); dim3 nthrs(bdx, bdy); - void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &max_seq_len, &seq_len, &num_heads}; + void* args[] = {&v, &s, &merge_indptr, &v_merged, + &s_merged, &max_seq_len, &seq_len, &num_heads}; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -726,8 +727,8 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp config.blockDim = nthrs; config.dynamicSmemBytes = smem_size; config.stream = stream; - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, v, s, indptr, v_merged, s_merged, - max_seq_len, seq_len, num_heads)); + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, v, s, merge_indptr, v_merged, + s_merged, max_seq_len, seq_len, num_heads)); } else { FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } @@ -736,7 +737,7 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp } template -cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum, +cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* merge_indptr, DTypeO* v_sum, uint32_t max_seq_len, uint32_t* seq_len, uint32_t num_heads, uint32_t head_dim, bool enable_pdl, cudaStream_t stream = nullptr) { @@ -761,7 +762,7 @@ cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum dim3 nblks(num_sms * num_blocks_per_sm); dim3 nthrs(bdx, bdy); - void* args[] = {&v, &indptr, &v_sum, &max_seq_len, &seq_len, &num_heads}; + void* args[] = {&v, &merge_indptr, &v_sum, &max_seq_len, &seq_len, &num_heads}; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -777,8 +778,8 @@ cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum config.blockDim = nthrs; config.dynamicSmemBytes = smem_size; config.stream = stream; - FLASHINFER_CUDA_CALL( - cudaLaunchKernelEx(&config, kernel, v, indptr, v_sum, max_seq_len, seq_len, num_heads)); + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, v, merge_indptr, v_sum, max_seq_len, + seq_len, num_heads)); } else { FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } diff --git a/include/flashinfer/attention/persistent.cuh b/include/flashinfer/attention/persistent.cuh index 86c65e1843..f073ff0c00 100644 --- a/include/flashinfer/attention/persistent.cuh +++ b/include/flashinfer/attention/persistent.cuh @@ -488,7 +488,7 @@ struct BlockBatchReductionPersistent { static __device__ __forceinline__ void Run( typename KTraits::DTypeIn* __restrict__ V, typename KTraits::DTypeO* __restrict__ v_merged, float* __restrict__ S, float* __restrict__ s_merged, - const typename KTraits::IdType num_packed_qo_len, const uint_fastdiv gqa_group_size, + const typename KTraits::IdType num_to_merge_qo_len, const uint_fastdiv gqa_group_size, const uint32_t num_kv_heads, const typename KTraits::IdType* indptr, const typename KTraits::IdType* o_indices, uint8_t* smem PROFILER_CLOSURE_FUNC_PARAMS) { __syncthreads(); // NOTE(Zihao): required 
for guarantee correctness on blackwell @@ -517,10 +517,10 @@ struct BlockBatchReductionPersistent { float* s_smem = (float*)(smem + num_warps * num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + warp_idx * 32 * sizeof(float)); - // V: [num_packed_qo_len x num_kv_tiles, num_kv_heads, head_dim] + // V: [num_to_merge_qo_len x num_kv_tiles, num_kv_heads, head_dim] // v_merged: [qo_len, num_kv_heads, gqa_group_size, head_dim] #pragma unroll 1 - for (uint32_t i = worker_id; i < num_packed_qo_len * num_kv_heads; i += num_workers) { + for (uint32_t i = worker_id; i < num_to_merge_qo_len * num_kv_heads; i += num_workers) { PROFILER_EVENT_START(profiler_closure, PersistentProfileEventType::kReduction); __syncwarp(); // avoid data hazard due to reordering st.cast_store // remap workload diff --git a/include/flashinfer/attention/persistent_template.cuh b/include/flashinfer/attention/persistent_template.cuh index 3bd2331b3e..3c367cc8f8 100644 --- a/include/flashinfer/attention/persistent_template.cuh +++ b/include/flashinfer/attention/persistent_template.cuh @@ -81,7 +81,7 @@ __global__ __launch_bounds__( grid.sync(); BlockReductionRunner::Run(params_1.partial_o, params_1.final_o, params_1.partial_lse, - params_1.final_lse, *(params_1.num_packed_qo_len), + params_1.final_lse, *(params_1.num_to_merge_qo_len), params_1.gqa_group_size, params_1.num_kv_heads, params_1.merge_indptr, params_1.merge_o_indices, smem); #else @@ -90,7 +90,7 @@ __global__ __launch_bounds__( grid.sync(); BlockReductionRunner::Run(params_1.partial_o, params_1.final_o, params_1.partial_lse, - params_1.final_lse, *(params_1.num_packed_qo_len), + params_1.final_lse, *(params_1.num_to_merge_qo_len), params_1.gqa_group_size, params_1.num_kv_heads, params_1.merge_indptr, params_1.merge_o_indices, smem, profiler_closure); #endif diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index 03ffdb7551..e7cf7128fe 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -34,29 +34,25 @@ enum Operation { DECODE = 1, }; -template +template __global__ __launch_bounds__(std::max( KTraits_P::NUM_THREADS, - KTraits_D::NUM_THREADS)) void PODWithKVCacheTensorKernel(const uint32_t xsize, - const __grid_constant__ PrefillParams + KTraits_D::NUM_THREADS)) void PODWithKVCacheTensorKernel(const __grid_constant__ PrefillParams prefill_params, const __grid_constant__ DecodeParams decode_params, int* tbAssign) { extern __shared__ uint8_t smem[]; + const uint32_t num_kv_heads = prefill_params.paged_kv.num_heads; // PREFILL VARS - const uint32_t num_kv_heads_p = prefill_params.num_kv_heads; - const uint32_t num_chunks = prefill_params.partition_kv; - const uint32_t qo_len = prefill_params.qo_len; + const uint32_t padded_bsize_p = prefill_params.padded_batch_size; // DECODE VARS - const uint32_t padded_bsize = decode_params.padded_batch_size; - const uint32_t num_kv_heads_d = decode_params.paged_kv.num_heads; + const uint32_t padded_bsize_d = decode_params.padded_batch_size; // THREADBLOCKS - const uint32_t prefill_blocks = num_kv_heads_p * xsize * (PartitionKV_P ? 
num_chunks : 1); - const uint32_t decode_blocks = padded_bsize * num_kv_heads_d; + const uint32_t prefill_blocks = padded_bsize_p * num_kv_heads; + const uint32_t decode_blocks = padded_bsize_d * num_kv_heads; int op; int linear_bid; @@ -77,14 +73,26 @@ __global__ __launch_bounds__(std::max( const int prefill_slots = (prefill_blocks + blk_factor_p - 1) / blk_factor_p; const int decode_slots = (decode_blocks + blk_factor_d - 1) / blk_factor_d; - if (prefill_slots <= decode_slots) { + if (blockIdx.x == 0) { + printf("Debug: prefill_slots: %d, decode_slots: %d\n", prefill_slots, decode_slots); + } + + // Dispatch op type + if (prefill_slots == 0 && decode_slots == 0) + FLASHINFER_RUNTIME_ASSERT( + "Number of prefill and decode slots are both 0. Check your kv indices."); + else if (prefill_slots == 0) { + op = DECODE; + } else if (decode_slots == 0) { + op = PREFILL; + } else if (prefill_slots <= decode_slots) { // Total tags = (decode + prefill) / min(decode, prefill) // = 1 + decode / prefill; when prefill < decode const int total_tags = decode_slots / prefill_slots + 1; // For this SM, what's the next operation we want to run? op = (atomicAdd(&tbAssign[linear_bid], 1) % total_tags); if (op > 0) { - op = 1; + op = DECODE; } } else { // Total tags = (decode + prefill) / min(decode, prefill) @@ -94,11 +102,16 @@ __global__ __launch_bounds__(std::max( // For this SM, what's the next operation we want to run? op = (atomicAdd(&tbAssign[linear_bid], 1) % (pref_tags + 1)); if (op < pref_tags) { - op = 0; + op = PREFILL; } else { - op = 1; + op = DECODE; } } + if (op == 0) { + printf("Debug: block %d running prefill. op: %d\n", blockIdx.x, op); + } else { + printf("Debug: block %d running decode. op: %d\n", blockIdx.x, op); + } // Get the next blockId for that operation linear_bid = atomicAdd(&tbAssign[num_SMs + op], 1); @@ -110,72 +123,59 @@ __global__ __launch_bounds__(std::max( op = !op; linear_bid = atomicAdd(&tbAssign[num_SMs + 0], 1); } - // Write the blockId and operation to shared memory + // Write the global blockId and operation to shared memory ((int*)smem)[0] = linear_bid; ((int*)smem)[1] = op; } - // Sync to wait for dynamic scheduler to finish + // Sync to wait for dynamic scheduler to write to smem __syncthreads(); // Fetch from shared memory the assigned blockId and operation. 
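The tag-based choice made above can be modelled in a few lines of Python. This is a sketch of the policy only: ``counter`` stands in for the per-SM atomicAdd on ``tbAssign``, and ``pref_tags`` is taken to be ``prefill_slots // decode_slots`` following the ratio comment in the kernel.

    PREFILL, DECODE = 0, 1

    def choose_op(counter: int, prefill_slots: int, decode_slots: int) -> int:
        # Interleave the two ops in proportion to their slot counts.
        if prefill_slots == 0 and decode_slots == 0:
            raise ValueError("both slot counts are zero")
        if prefill_slots == 0:
            return DECODE
        if decode_slots == 0:
            return PREFILL
        if prefill_slots <= decode_slots:
            total_tags = decode_slots // prefill_slots + 1
            return PREFILL if counter % total_tags == 0 else DECODE
        pref_tags = prefill_slots // decode_slots
        return PREFILL if counter % (pref_tags + 1) < pref_tags else DECODE
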
linear_bid = ((int*)smem)[0]; op = ((int*)smem)[1]; // Sync to force all threads to wait - __syncthreads(); + // __syncthreads(); if (op == PREFILL) { const uint32_t linear_tid = threadIdx.x; // Return if threadId exceeds number of threads for this op if (linear_tid >= 32 * KTraits_P::NUM_WARPS_Q * KTraits_P::NUM_WARPS_KV) return; + if (linear_bid >= prefill_blocks) return; const dim3 tid = dim3(linear_tid % 32, (linear_tid / 32) % KTraits_P::NUM_WARPS_Q, (linear_tid / 32) / KTraits_P::NUM_WARPS_Q); - // dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, num_kv_heads); - // dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), num_chunks, num_kv_heads); - // BlockID exceeds limit - if (linear_bid >= prefill_blocks) return; - - const uint32_t bx = linear_bid % xsize; auto& smem_storage = reinterpret_cast(smem); - // Not partition_kv - if constexpr (!PartitionKV_P) { - const uint32_t chunk_idx = 0; - const uint32_t kv_head_idx = linear_bid / xsize; - SinglePrefillWithKVCacheDevice(prefill_params, smem_storage, tid, bx, chunk_idx, - kv_head_idx, 1, num_kv_heads_p); - } else { - const uint32_t chunk_idx = (linear_bid / xsize) % num_chunks; - const uint32_t kv_head_idx = linear_bid / (xsize * num_chunks); - SinglePrefillWithKVCacheDevice(prefill_params, smem_storage, tid, bx, chunk_idx, - kv_head_idx, num_chunks, num_kv_heads_p); - } - } else /* OP == DECODE */ { - auto& smem_storage = reinterpret_cast(smem); - // dim3 nblks_d(padded_batch_size_d, 1, num_kv_heads); - if (linear_bid >= decode_blocks) return; + const uint32_t bx = linear_bid % padded_bsize_p; + const uint32_t kv_head_idx = linear_bid / padded_bsize_p; - const uint32_t bx = linear_bid % padded_bsize; - const uint32_t kv_head_idx = linear_bid / padded_bsize; + BatchPrefillWithPagedKVCacheDevice(prefill_params, smem_storage, tid, bx, + kv_head_idx, num_kv_heads); - // dim3 nthrs_d(32, NUM_WARPS_Q_D, NUM_WARPS_KV_D); + } else /* OP == DECODE */ { const uint32_t linear_tid = threadIdx.x; // Return if threadId exceeds number of threads for this op if (linear_tid >= 32 * KTraits_D::NUM_WARPS_Q * KTraits_D::NUM_WARPS_KV) return; + if (linear_bid >= decode_blocks) return; const dim3 tid = dim3(linear_tid % 32, (linear_tid / 32) % KTraits_D::NUM_WARPS_Q, (linear_tid / 32) / KTraits_D::NUM_WARPS_Q); + auto& smem_storage = reinterpret_cast(smem); + // dim3 nblks_d(padded_batch_size_d, 1, num_kv_heads); + const uint32_t bx = linear_bid % padded_bsize_d; + const uint32_t kv_head_idx = linear_bid / padded_bsize_d; + // dim3 nthrs_d(32, NUM_WARPS_Q_D, NUM_WARPS_KV_D); + + // Decode is faster with tensor cores, which are usually not saturated by prefill BatchPrefillWithPagedKVCacheDevice(decode_params, smem_storage, tid, bx, kv_head_idx, - num_kv_heads_d); + num_kv_heads); } } template -cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, - typename PrefillParams::DTypeO* tmp_p, - DecodeParams decode_params, + bool USE_FP16_QK_REDUCTION, MaskMode MASK_MODE_P, uint32_t CTA_TILE_Q_P, + uint32_t CTA_TILE_Q_D, MaskMode MASK_MODE_D, typename PrefillAttentionVariant, + typename DecodeAttentionVariant, typename PrefillParams, typename DecodeParams> +cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeParams decode_params, typename DecodeParams::DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream) { static_assert(std::is_same::value); @@ -183,50 +183,17 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, std::is_same::value); static_assert(std::is_same::value); // 
Ensure heads match - assert(prefill_params.num_kv_heads == decode_params.paged_kv.num_heads); + assert(prefill_params.paged_kv.num_heads == decode_params.paged_kv.num_heads); assert(prefill_params.num_qo_heads == decode_params.num_qo_heads); // Prefill variable setup using DTypeQ_P = typename PrefillParams::DTypeQ; using DTypeKV_P = typename PrefillParams::DTypeKV; using DTypeO_P = typename PrefillParams::DTypeO; const uint32_t num_qo_heads = prefill_params.num_qo_heads; - const uint32_t num_kv_heads = prefill_params.num_kv_heads; - const uint32_t qo_len = prefill_params.qo_len; - const uint32_t kv_len = prefill_params.kv_len; - if (kv_len < qo_len && MASK_MODE_P == MaskMode::kCausal) { - std::ostringstream err_msg; - err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal " - "to qo_len, got kv_len" - << kv_len << " and qo_len " << qo_len; - FLASHINFER_ERROR(err_msg.str()); - } + const uint32_t num_kv_heads = prefill_params.paged_kv.num_heads; - const uint32_t group_size = num_qo_heads / num_kv_heads; - const uint_fastdiv group_size_fastdiv(group_size); constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; - - uint32_t cta_tile_q_p = 0; - int64_t unpacked_qo_len = qo_len * group_size; - if (unpacked_qo_len > 64 && HEAD_DIM_VO < 256) { - cta_tile_q_p = 128; - } else { - auto compute_capacity = GetCudaComputeCapability(); - if (compute_capacity.first >= 8) { - // Ampere or newer - if (unpacked_qo_len > 16) { - // avg_packed_qo_len <= 64 - cta_tile_q_p = 64; - } else { - // avg_packed_qo_len <= 16 - cta_tile_q_p = 16; - } - } else { - // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout - cta_tile_q_p = 64; - } - } - // Decode vars setup using DTypeQ_D = typename DecodeParams::DTypeQ; using DTypeKV_D = typename DecodeParams::DTypeKV; @@ -269,207 +236,176 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, NUM_MMA_Q_D * NUM_WARPS_Q_D) / (2 * NUM_WARPS_KV_D); - DISPATCH_CTA_TILE_Q(cta_tile_q_p, CTA_TILE_Q_P, { - constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P); - constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P); - constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P); - - using DTypeQKAccum_P = - typename std::conditional, half, - float>::type; - - // we expect each sm execute two threadblocks - // TODO(Zihao): fix the following computation - const int num_ctas_per_sm_p = - max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1; - const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p; - - constexpr uint32_t max_num_mma_kv_reg_p = - (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 && - POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION) - ? 
2 - : (8 / NUM_MMA_Q_P); - // TODO(Zihao): fix the following computation - const uint32_t max_num_mma_kv_smem_p = - (max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) - - NUM_MMA_Q_P * NUM_WARPS_Q_P) / - (2 * NUM_WARPS_KV_P); - - // control NUM_MMA_KV for maximum warp occupancy - DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, { - using KTraits_P = - KernelTraits; - - if constexpr (KTraits_P::IsInvalid()) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P - << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO - << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P - << " NUM_WARPS_KV=" << NUM_WARPS_KV_P - << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" - " and report the issue to the developers."; - FLASHINFER_ERROR(err_msg.str()); - } else { - // Decode stuff - // TODO: Is there a way to avoid this nested dispatch? - DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, { - using KTraits_D = - KernelTraits; - if constexpr (KTraits_D::IsInvalid()) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg - << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D - << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO - << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D - << " NUM_WARPS_KV=" << NUM_WARPS_KV_D - << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" - " and report the issue to the developers."; - FLASHINFER_ERROR(err_msg.str()); + constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P); + constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P); + constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P); + + using DTypeQKAccum_P = + typename std::conditional, half, + float>::type; + + // we expect each sm execute two threadblocks + // TODO(Zihao): fix the following computation + const int num_ctas_per_sm_p = + max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1; + const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p; + + constexpr uint32_t max_num_mma_kv_reg_p = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 
2 + : (8 / NUM_MMA_Q_P); + // TODO(Zihao): fix the following computation + const uint32_t max_num_mma_kv_smem_p = + (max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) - + NUM_MMA_Q_P * NUM_WARPS_Q_P) / + (2 * NUM_WARPS_KV_P); + + // control NUM_MMA_KV for maximum warp occupancy + uint32_t max_num_mma_kv_p = std::min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p); + uint32_t max_num_mma_kv_d = std::min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d); + + DISPATCH_NUM_MMA_KV(max_num_mma_kv_p, NUM_MMA_KV_P, { + using KTraits_P = KernelTraits; + + if constexpr (KTraits_P::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P + << " NUM_WARPS_KV=" << NUM_WARPS_KV_P + << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + // Decode stuff + // TODO: Is there a way to avoid this nested dispatch? + DISPATCH_NUM_MMA_KV(max_num_mma_kv_d, NUM_MMA_KV_D, { + using KTraits_D = + KernelTraits; + if constexpr (KTraits_D::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D + << " NUM_WARPS_KV=" << NUM_WARPS_KV_D + << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + // End decode stuff + constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE; + size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage); + size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage); + + auto kernel = + PODWithKVCacheTensorKernel; + // Prefill: decide num_splits for split-kv + int num_blocks_per_sm = 0; + int num_sm = 0; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + // FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &num_blocks_per_sm, kernel, num_threads_p, smem_size_p)); + // Above function returns 0 for some reason, so we use a workaround + num_blocks_per_sm = std::max( + 1, std::min((int)(max_smem_per_sm / smem_size_p), (int)(256 / num_threads_p))); + + // Setup new prefill params if (not) split + auto o = prefill_params.o; + auto lse = prefill_params.lse; + if (prefill_params.partition_kv) { + // Use cooperative groups to increase occupancy + assert(tmp_v != nullptr); + prefill_params.o = tmp_v; + prefill_params.lse = tmp_s; + } + + // Setup new decode params if (not) split + if (prefill_params.partition_kv) { + assert(tmp_v != nullptr); + decode_params.o = tmp_v; + decode_params.lse = tmp_s; + } + + uint32_t padded_batch_size_p = prefill_params.padded_batch_size; + uint32_t padded_batch_size_d = decode_params.padded_batch_size; + printf("Debug: launching prefill with padded_batch_size_p %d, num_kv_heads %d\n", + padded_batch_size_p, num_kv_heads); + int nblks_p(padded_batch_size_p * num_kv_heads); + int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P); + printf("Debug: launching decode with padded_batch_size_d %d, num_kv_heads %d\n", + 
padded_batch_size_d, num_kv_heads); + int nblks_d(padded_batch_size_d * num_kv_heads); + int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D); + + // ******* Select final combined sizes here ******* / + size_t smem_size = max(smem_size_p, smem_size_d); + int nblks = nblks_p + nblks_d; + int nthrs = max(nthrs_p, nthrs_d); + + // printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d, + // smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d, + // nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d, + // nthrs); + // ************************************************ / + + static int* tbAssign = nullptr; + if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2)); + cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2)); + + // Setup kernel arguments + + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // Launch kernel + if (enable_pdl) { + cudaLaunchAttribute attribute[1]; + cudaLaunchConfig_t config; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = 1; + config.attrs = attribute; + config.numAttrs = 1; + config.gridDim = nblks; + config.blockDim = nthrs; + config.dynamicSmemBytes = smem_size; + config.stream = stream; + FLASHINFER_CUDA_CALL( + cudaLaunchKernelEx(&config, kernel, prefill_params, decode_params, tbAssign)); } else { - // End decode stuff - constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE; - size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage); - size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage); - - auto kernel = - PODWithKVCacheTensorKernel; - // Prefill: decide num_splits for split-kv - int num_blocks_per_sm = 0; - int num_sm = 0; + void* args[] = {(void*)&prefill_params, (void*)&decode_params, (void*)&tbAssign}; FLASHINFER_CUDA_CALL( - cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - // FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &num_blocks_per_sm, kernel, num_threads_p, smem_size_p)); - // Above function returns 0 for some reason, so we use a workaround - num_blocks_per_sm = std::max( - 1, std::min((int)(max_smem_per_sm / smem_size_p), (int)(256 / num_threads_p))); - uint32_t max_num_kv_chunks = - (num_blocks_per_sm * num_sm) / - (num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q)); - uint32_t num_chunks; - if (max_num_kv_chunks > 0) { - uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); - num_chunks = ceil_div(kv_len, chunk_size); - } else { - num_chunks = 0; - } - - // Setup new prefill params if (not) split - auto o_p = prefill_params.o; - auto lse_p = prefill_params.lse; - float* tmp_lse = (float*)(tmp_p + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO); - if (num_chunks <= 1 || tmp_p == nullptr) { - // Enough parallelism, do not split-kv - prefill_params.partition_kv = 0; - kernel = PODWithKVCacheTensorKernel; - } else { - // Use cooperative groups to increase occupancy - prefill_params.partition_kv = num_chunks; - prefill_params.o = tmp_p; - prefill_params.lse = tmp_lse; - kernel = PODWithKVCacheTensorKernel; - } + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } - // Setup new decode params if (not) split - auto o_d = decode_params.o; - auto lse_d = decode_params.lse; - if (tmp_v == nullptr) { - // do not partition kv - decode_params.partition_kv = false; - } else 
{ - decode_params.partition_kv = true; - decode_params.o = tmp_v; - decode_params.lse = tmp_s; - } - uint32_t xsize = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q); - int nblks_p(xsize * (prefill_params.partition_kv ? prefill_params.partition_kv : 1) * - num_kv_heads); - int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P); - - int nblks_d(padded_batch_size_d * 1 * num_kv_heads); - int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D); - - // ******* Select final combined sizes here ******* / - size_t smem_size = max(smem_size_p, smem_size_d); - int nblks = nblks_p + nblks_d; - int nthrs = max(nthrs_p, nthrs_d); - - // printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d, - // smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d, - // nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d, - // nthrs); - // ************************************************ / - - static int* tbAssign = nullptr; - if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2)); - cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2)); - - // Setup kernel arguments - void* args[] = {(void*)&xsize, (void*)&prefill_params, (void*)&decode_params, - (void*)&tbAssign}; - FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - // Launch kernel - if (enable_pdl) { - cudaLaunchAttribute attribute[1]; - cudaLaunchConfig_t config; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = 1; - config.attrs = attribute; - config.numAttrs = 1; - config.gridDim = nblks; - config.blockDim = nthrs; - config.dynamicSmemBytes = smem_size; - config.stream = stream; - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, xsize, prefill_params, - decode_params, tbAssign)); + // Post-kernel stuff for split-kv + if (tmp_v != nullptr) { + if constexpr (DecodeAttentionVariant::use_softmax) { + FLASHINFER_CUDA_CALL(VariableLengthMergeStates( + tmp_v, tmp_s, decode_params.merge_indptr, o, lse, + decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads, + HEAD_DIM_VO, enable_pdl, stream)); } else { - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - } - - // Post-kernel stuff for split-kv prefill - if (!(num_chunks <= 1 || tmp_p == nullptr)) { - if constexpr (PrefillAttentionVariant::use_softmax) { - FLASHINFER_CUDA_CALL(MergeStates(tmp_p, tmp_lse, o_p, lse_p, num_chunks, qo_len, - num_qo_heads, HEAD_DIM_VO, stream)); - } else { - FLASHINFER_CUDA_CALL(AttentionSum(tmp_p, o_p, num_chunks, qo_len, num_qo_heads, - HEAD_DIM_VO, stream)); - } - } - // Post-kernel stuff for split-kv decode - if (tmp_v != nullptr) { - if constexpr (DecodeAttentionVariant::use_softmax) { - FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, decode_params.merge_indptr, o_d, lse_d, - decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads, - HEAD_DIM_VO, enable_pdl, stream)); - } else { - FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( - tmp_v, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows, - decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, enable_pdl, stream)); - } + FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( + tmp_v, decode_params.merge_indptr, o, decode_params.max_total_num_rows, + decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, enable_pdl, stream)); } } - }); - } - }); + } + }); + } }); return 
cudaSuccess; } - } // namespace flashinfer #endif // FLASHINFER_PREFILL_CUH_ diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 5db013bb03..aedb1eb688 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -2103,6 +2103,7 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, window_left = variant.window_left; + const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; const uint32_t qo_upper_bound = min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size)); @@ -2114,7 +2115,13 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( partition_kv ? min(kv_tile_idx * max_chunk_size + kv_start_idx, kv_len) : kv_start_idx; const uint32_t chunk_end = partition_kv ? min((kv_tile_idx + 1) * max_chunk_size + kv_start_idx, kv_len) : kv_len; + if (chunk_end < chunk_start) { + FLASHINFER_RUNTIME_ASSERT("chunk_end must >= chunk_start. Check your paged kv indices."); + } const uint32_t chunk_size = chunk_end - chunk_start; + if (chunk_size == 0) { + return; // no kv data + } DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; DTypeQKAccum m[NUM_MMA_Q][2]; @@ -2221,6 +2228,17 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( chunk_start)) : chunk_size), CTA_TILE_KV); + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && num_iterations > 100000) { + uint32_t seq_len = kv_len - qo_len + ceil_div(((qo_tile_idx + 1) * CTA_TILE_Q), group_size); + printf( + "Debug: num_iterations: %d, request_idx: %d, chunk_size: %d, chunk_start: %d, " + "chunk_end: %d, max_chunk_size: %d, partition_kv: %d, qo_len: %d, kv_len: %d, " + "sub_if_greater_or_zero: %d, divided by CTA_TILE_KV: %d\n", + num_iterations, request_idx, chunk_size, chunk_start, chunk_end, max_chunk_size, + partition_kv, qo_len, kv_len, sub_if_greater_or_zero(seq_len, chunk_start), + sub_if_greater_or_zero(seq_len, chunk_start) / CTA_TILE_KV); + } + } else if constexpr (MASK_MODE == MaskMode::kMultiItemScoring) { num_iterations_prefix = ceil_div( min(min(chunk_size, @@ -2247,7 +2265,12 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( chunk_start)), CTA_TILE_KV)); } - + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + printf( + "Debug: block %d request_idx: %d num_iterations: %d, qo_len: %d, kv_len: %d, " + "chunk_size:%d\n", + blockIdx.x, request_idx, num_iterations, qo_len, kv_len, chunk_size); + } const uint32_t window_iteration = ceil_div( sub_if_greater_or_zero(kv_len + ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size), qo_len + window_left + chunk_start), @@ -2566,7 +2589,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param // this won't happen in CUDAGraph mode because we fixed the padded_batch_size return cudaSuccess; } - + // bs = num_qo_tiles * num_kv_tiles * gqa_group_size dim3 nblks(padded_batch_size, 1, num_kv_heads); dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 4f888e716b..4540976bb9 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -58,7 +59,7 @@ inline void CopyToPageLockedBuffer(void* 
page_locked_int_buffer, int64_t offset, } /*! - * \brief Compute the maximum number of pages per batch and the new batch size + * \brief Compute the maximum number of pages per batch and the new batch size (grid dim x) * after we partition Paged KV-Cache into multiple chunks on KV sequence length * dimension. * \tparam IdType A template type indicates the index data type @@ -100,23 +101,23 @@ inline auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( inline auto PrefillBinarySearchKVChunkSize(const bool enable_cuda_graph, const uint32_t max_batch_size_if_split, - const std::vector& packed_qo_len_arr, - const std::vector& kv_len_arr, + const std::vector& packed_qo_len_arr, + const std::vector& kv_len_arr, const uint32_t qo_chunk_size, const uint32_t min_kv_chunk_size = 1) { - const int64_t batch_size = packed_qo_len_arr.size(); - int64_t max_kv_len = 1; - for (const int64_t& kv_len : kv_len_arr) { + const int32_t batch_size = packed_qo_len_arr.size(); + int32_t max_kv_len = 1; + for (const int32_t& kv_len : kv_len_arr) { max_kv_len = std::max(max_kv_len, kv_len); } - int64_t low = min_kv_chunk_size; - int64_t high = max_kv_len; - constexpr int64_t min_kv_len = 1; + int32_t low = min_kv_chunk_size; + int32_t high = max_kv_len; + constexpr int32_t min_kv_len = 1; while (low < high) { - const int64_t mid = (low + high) / 2; - int64_t new_batch_size = 0; - for (uint32_t i = 0; i < batch_size; ++i) { + const int32_t mid = (low + high) / 2; + int32_t new_batch_size = 0; + for (int32_t i = 0; i < batch_size; ++i) { new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * ceil_div(std::max(kv_len_arr[i], min_kv_len), mid); } @@ -165,7 +166,7 @@ inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U; const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; gdy = num_kv_heads; - const uint32_t smem_size = + const uint32_t smem_size = // kv + max + denominator 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); @@ -493,29 +494,19 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in } template -inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, - uint32_t total_num_rows, uint32_t batch_size, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, - uint32_t page_size, uint32_t max_batch_size_if_split, - bool enable_cuda_graph, int32_t window_left, - int32_t fixed_split_size, bool disable_split_kv) { - std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; - merge_indptr.push_back(0); - o_indptr.push_back(0); - - const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; - - // step 1: determine packed_qo_len_arr and verify qo_indptr contents. 
- std::vector packed_qo_len_arr(batch_size), kv_len_arr(batch_size); +inline auto get_qkv_len_arr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t gqa_group_size) { + std::vector packed_qo_len_arr(batch_size), kv_len_arr(batch_size); for (uint32_t i = 0; i < batch_size; ++i) { - packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size); + packed_qo_len_arr[i] = int32_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int32_t(gqa_group_size); if (packed_qo_len_arr[i] < 0) { std::ostringstream err_msg; err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]" << qo_indptr_h[i] << " should be non-negative"; FLASHINFER_ERROR(err_msg.str()); } - kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]); + kv_len_arr[i] = int32_t(kv_indptr_h[i + 1] - kv_indptr_h[i]); + printf("Debug: request %d qo_len: %d, kv_len: %d\n", i, packed_qo_len_arr[i], kv_len_arr[i]); if (kv_len_arr[i] < 0) { std::ostringstream err_msg; err_msg << "kv_indptr[" << i + 1 << "]" << kv_indptr_h[i + 1] << " - kv_indptr[" << i << "]" @@ -523,94 +514,187 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, FLASHINFER_ERROR(err_msg.str()); } } + return std::make_tuple(packed_qo_len_arr, kv_len_arr); +} - // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q - const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); +inline auto get_q_tiles(std::vector packed_qo_len_arr, uint32_t batch_size, + uint32_t head_dim, uint32_t page_size, uint32_t total_num_rows, + uint32_t gqa_group_size, bool enable_cuda_graph, bool is_decode = false) { uint32_t cta_tile_q; uint32_t total_num_tiles_q; if (enable_cuda_graph) { // When CUDA graphs are enabled, the lengths of sequences determined by // qo_indptr_h can vary. We assume that the dummy data based on which // the CUDA graph is created fixes the maximum number of tokens. - const uint64_t max_seq_len = total_num_rows - batch_size + 1; - uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size; - cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim); - + if (is_decode) { + cta_tile_q = 16; + } else { + const uint64_t max_seq_len = total_num_rows - batch_size + 1; + uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size; + cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim); + } // Find an upper bound for the number of tiles, derived from the total // number of rows and the batch size. The sum of qo lengths rounded // up to cta_tile_q will not exceed this number derived from the total // number of rows. 
total_num_tiles_q = ceil_div(total_num_rows * gqa_group_size, cta_tile_q) + batch_size - 1; } else { - int64_t sum_packed_qo_len = 0; - for (uint32_t i = 0; i < batch_size; ++i) { - sum_packed_qo_len += packed_qo_len_arr[i]; + if (is_decode) { + cta_tile_q = 16; + } else { + int64_t sum_packed_qo_len = 0; + for (uint32_t i = 0; i < batch_size; ++i) { + sum_packed_qo_len += packed_qo_len_arr[i]; + } + const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; + cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim); } - const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; - cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim); total_num_tiles_q = 0; for (uint32_t i = 0; i < batch_size; ++i) { total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q); } } + return std::make_tuple(cta_tile_q, total_num_tiles_q); +} - // Calculate the actual needed CTA when considering sliding window - std::vector effective_kv_len_arr(batch_size); - for (uint32_t i = 0; i < batch_size; ++i) { - // pad CTA_TILE_Q to consider the causal kv-len - effective_kv_len_arr[i] = - std::min(window_left >= 0 ? ceil_div(window_left + cta_tile_q, page_size) : kv_len_arr[i], - kv_len_arr[i]); - } +template +inline auto get_qkv_tile_indices( + const std::vector& packed_qo_len_arr, const std::vector& kv_len_arr, + uint32_t batch_size, uint32_t cta_tile_q, uint32_t kv_chunk_size, uint32_t gqa_group_size, + std::vector* request_indices = nullptr, std::vector* qo_tile_indices = nullptr, + std::vector* kv_tile_indices = nullptr, std::vector* merge_indptr = nullptr, + std::vector* o_indptr = nullptr, int32_t fixed_split_size = -1, + bool disable_split_kv = false) { + std::vector local_req; + std::vector local_qo; + std::vector local_kv; + std::vector local_merge{0}; + std::vector local_o{0}; + + auto* out_req = request_indices ? request_indices : &local_req; + auto* out_qo = qo_tile_indices ? qo_tile_indices : &local_qo; + auto* out_kv = kv_tile_indices ? kv_tile_indices : &local_kv; + auto* out_merge = merge_indptr ? merge_indptr : &local_merge; + auto* out_o = o_indptr ? o_indptr : &local_o; + uint32_t start_req_idx = 0; // for global q,k,v,o indexing bool split_kv = false; - int64_t kv_chunk_size; - if (disable_split_kv) { - kv_chunk_size = std::numeric_limits::max(); - } else if (!disable_split_kv && fixed_split_size > 0) { - kv_chunk_size = fixed_split_size; - } else { - std::tie(split_kv, kv_chunk_size) = PrefillBinarySearchKVChunkSize( - enable_cuda_graph, max_batch_size_if_split, packed_qo_len_arr, effective_kv_len_arr, - cta_tile_q, min_kv_chunk_size); + if (request_indices && !request_indices->empty()) { + start_req_idx = request_indices->back(); } - // step 3: split qo_indptr and kv_indptr + uint32_t new_batch_size = 0; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { - const int64_t packed_qo_len = packed_qo_len_arr[request_idx]; - const int64_t num_tiles_q = ceil_div(packed_qo_len, cta_tile_q); - const int64_t kv_len = std::max(int(effective_kv_len_arr[request_idx]), 1); - const int64_t num_chunks_kv = disable_split_kv ? 
1 : ceil_div(kv_len, kv_chunk_size); + const int32_t packed_qo_len = packed_qo_len_arr[request_idx]; + const int32_t kv_len = std::max(kv_len_arr[request_idx], 1); + const int32_t num_tiles_q = ceil_div(packed_qo_len, cta_tile_q); + const int32_t num_chunks_kv = ceil_div(kv_len, kv_chunk_size); if (fixed_split_size > 0 && !disable_split_kv) { split_kv = split_kv || num_chunks_kv > 1; } for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) { for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_chunks_kv; ++kv_tile_idx) { new_batch_size += 1; - request_indices.push_back(request_idx); - qo_tile_indices.push_back(q_tile_idx); - kv_tile_indices.push_back(kv_tile_idx); + out_req->push_back(request_idx + start_req_idx); + out_qo->push_back(q_tile_idx); + out_kv->push_back(kv_tile_idx); + printf("Debug: num_tiles_q: %d, num_chunks_kv: %d, q_tile_idx: %d, kv_tile_idx: %d\n", + num_tiles_q, num_chunks_kv, q_tile_idx, kv_tile_idx); } } int64_t qo_len = packed_qo_len / gqa_group_size; for (uint32_t row = 0; row < qo_len; ++row) { - merge_indptr.push_back(merge_indptr.back() + num_chunks_kv); + out_merge->push_back(out_merge->back() + num_chunks_kv); + printf("Debug: merge_indptr[%d]: %d\n", out_merge->size() - 1, out_merge->back()); } - o_indptr.push_back(o_indptr.back() + qo_len * num_chunks_kv); + out_o->push_back(out_o->back() + qo_len * num_chunks_kv); + printf("Debug: o_indptr[%d]: %d, num_kv_tiles: %d, qo_len: %d\n", out_o->size() - 1, + out_o->back(), num_chunks_kv, qo_len); + } + printf("Debug: o_indptr.size(): %d, merge_indptr.size(): %d, batch_size: %d\n", out_o->size(), + out_merge->size(), batch_size); + return std::make_tuple(std::move(local_req), std::move(local_qo), std::move(local_kv), + std::move(local_merge), std::move(local_o), new_batch_size); +} + +template +inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, + uint32_t total_num_rows, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + uint32_t page_size, uint32_t max_batch_size_if_split, + bool enable_cuda_graph, int32_t window_left, + int32_t fixed_split_size, bool disable_split_kv) { + const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; + const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); + + // step 1: determine packed_qo_len_arr and verify qo_indptr contents. + auto [packed_qo_len_arr, kv_len_arr] = + get_qkv_len_arr(qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, gqa_group_size); + + // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q + auto [cta_tile_q, total_num_tiles_q] = + get_q_tiles(packed_qo_len_arr, batch_size, head_dim, page_size, total_num_rows, + gqa_group_size, enable_cuda_graph); + + // Calculate the actual needed CTA when considering sliding window + std::vector effective_kv_len_arr(batch_size); + for (uint32_t i = 0; i < batch_size; ++i) { + // pad by CTA_TILE_Q to consider the causal kv-len + effective_kv_len_arr[i] = std::min( + window_left >= 0 ? 
static_cast(ceil_div(window_left + cta_tile_q, page_size)) + : kv_len_arr[i], + kv_len_arr[i]); + } + + bool split_kv = false; + int64_t kv_chunk_size; + if (disable_split_kv) { + kv_chunk_size = std::numeric_limits::max(); + } else if (!disable_split_kv && fixed_split_size > 0) { + kv_chunk_size = fixed_split_size; + } else { + std::tie(split_kv, kv_chunk_size) = PrefillBinarySearchKVChunkSize( + enable_cuda_graph, max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q, + min_kv_chunk_size); + } + // step 3: split qo_indptr and kv_indptr + uint32_t kv_chunk_size_u32 = + (kv_chunk_size > static_cast(std::numeric_limits::max())) + ? std::numeric_limits::max() + : static_cast(kv_chunk_size); + auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, new_batch_size] = + get_qkv_tile_indices(packed_qo_len_arr, effective_kv_len_arr, batch_size, cta_tile_q, + kv_chunk_size_u32, gqa_group_size, nullptr, nullptr, nullptr, + nullptr, nullptr, fixed_split_size, disable_split_kv); + // print indices + printf("Debug: -------------------------------\n"); + for (int i = 0; i < request_indices.size(); i++) { + printf("Debug: request_indices[%d]: %d\n", i, request_indices[i]); + } + for (int i = 0; i < merge_indptr.size(); i++) { + printf("Debug: merge_indptr[%d]: %d\n", i, merge_indptr[i]); + } + for (int i = 0; i < o_indptr.size(); i++) { + printf("Debug: o_indptr[%d]: %d\n", i, o_indptr[i]); + } + for (int i = 0; i < qo_tile_indices.size(); i++) { + printf("Debug: qo_tile_indices[%d]: %d\n", i, qo_tile_indices[i]); + } + for (int i = 0; i < kv_tile_indices.size(); i++) { + printf("Debug: kv_tile_indices[%d]: %d\n", i, kv_tile_indices[i]); } const size_t padded_batch_size = enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size; FLASHINFER_CHECK(new_batch_size <= padded_batch_size, - "new batch size should not exceed padded batch size. If you are using fixed " - "split size, please consider disabling cuda graph."); + "new batch size should not exceed padded batch size"); // step 4: multiply kv_chunk_size by page_size kv_chunk_size *= page_size; + return std::make_tuple(split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, - std::move(request_indices), std::move(qo_tile_indices), - std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr)); + request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr); } struct PrefillPlanInfo { @@ -794,6 +878,311 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i return cudaSuccess; } +/* +Modified to support two tile sizes, and assign blocks proportional to +the number of tiles. +*/ +template +inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_t total_num_rows_p, + uint32_t batch_size_p, IdType* qo_indptr_d, IdType* kv_indptr_d, + uint32_t total_num_rows_d, uint32_t batch_size_d, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + uint32_t page_size, uint32_t max_batch_size_if_split, + bool enable_cuda_graph) { + const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; + const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); + // step 1: determine packed_qo_len_arr and verify qo_indptr contents. 
+ auto [packed_qo_len_arr_p, kv_len_arr_p] = + get_qkv_len_arr(qo_indptr_p, kv_indptr_p, batch_size_p, num_qo_heads, gqa_group_size); + auto [packed_qo_len_arr_d, kv_len_arr_d] = + get_qkv_len_arr(qo_indptr_d, kv_indptr_d, batch_size_d, num_qo_heads, gqa_group_size); + + // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q + auto [cta_tile_q_p, num_tiles_q_p] = + get_q_tiles(packed_qo_len_arr_p, batch_size_p, head_dim, page_size, total_num_rows_p, + gqa_group_size, enable_cuda_graph); + auto [cta_tile_q_d, num_tiles_q_d] = + get_q_tiles(packed_qo_len_arr_d, batch_size_d, head_dim, page_size, total_num_rows_d, + gqa_group_size, enable_cuda_graph, /*is_decode=*/true); + + uint32_t total_num_tiles_q = num_tiles_q_p + num_tiles_q_d; + // Allocate CTAs proportionally to the number of query tiles in prefill and decode. + // TODO(Wenxuan): explore a more balanced cost function considering kv len. + // See discussion: https://github.com/flashinfer-ai/flashinfer/issues/1175 + uint32_t max_bs_p = max_batch_size_if_split * num_tiles_q_p / total_num_tiles_q; + uint32_t max_bs_d = max_batch_size_if_split - max_bs_p; + auto [split_kv_p, kv_chunk_size_p] = + PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_bs_p, packed_qo_len_arr_p, kv_len_arr_p, + cta_tile_q_p, min_kv_chunk_size); + + auto [split_kv_d, kv_chunk_size_d] = + PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_bs_d, packed_qo_len_arr_d, kv_len_arr_d, + cta_tile_q_d, min_kv_chunk_size); + printf("Debug: max_bs_p: %d, max_bs_d: %d, kv_chunk_size_p: %d, kv_chunk_size_d: %d\n", max_bs_p, + max_bs_d, kv_chunk_size_p, kv_chunk_size_d); + // step 3: split qo_indptr and kv_indptr + // Use one set of qkv indices, merge_indptr and o_indptr to simplify merging. + auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, real_bs_p] = + get_qkv_tile_indices(packed_qo_len_arr_p, kv_len_arr_p, batch_size_p, cta_tile_q_p, + kv_chunk_size_p, gqa_group_size); + auto [_, __, _____, _______, _________, real_bs_d] = + get_qkv_tile_indices(packed_qo_len_arr_d, kv_len_arr_d, batch_size_d, cta_tile_q_d, + kv_chunk_size_d, gqa_group_size, &request_indices, + &qo_tile_indices, &kv_tile_indices, &merge_indptr, &o_indptr); + // print indices + printf("Debug: -------------------------------\n"); + for (int i = 0; i < request_indices.size(); i++) { + printf("Debug: request_indices[%d]: %d\n", i, request_indices[i]); + } + for (int i = 0; i < merge_indptr.size(); i++) { + printf("Debug: merge_indptr[%d]: %d\n", i, merge_indptr[i]); + } + for (int i = 0; i < o_indptr.size(); i++) { + printf("Debug: o_indptr[%d]: %d\n", i, o_indptr[i]); + } + for (int i = 0; i < qo_tile_indices.size(); i++) { + printf("Debug: qo_tile_indices[%d]: %d\n", i, qo_tile_indices[i]); + } + for (int i = 0; i < kv_tile_indices.size(); i++) { + printf("Debug: kv_tile_indices[%d]: %d\n", i, kv_tile_indices[i]); + } + + bool split_kv = split_kv_p || split_kv_d; + uint32_t new_batch_size = real_bs_p + real_bs_d; + const size_t padded_batch_size_p = + enable_cuda_graph ? std::max(max_bs_p, num_tiles_q_p) : real_bs_p; + const size_t padded_batch_size_d = + enable_cuda_graph ? 
std::max(max_bs_d, num_tiles_q_d) : real_bs_d; + FLASHINFER_CHECK(new_batch_size <= padded_batch_size_p + padded_batch_size_d, + "new batch size should not exceed padded batch size"); + + // step 4: multiply kv_chunk_size by page_size + kv_chunk_size_p *= page_size; + kv_chunk_size_d *= page_size; + + return std::make_tuple(split_kv, new_batch_size, padded_batch_size_p, padded_batch_size_d, + cta_tile_q_p, cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, + std::move(request_indices), std::move(qo_tile_indices), + std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr)); +} + +struct PODPlanInfo { + int64_t padded_batch_size_p; + int64_t padded_batch_size_d; + int64_t total_num_rows; + int64_t total_num_rows_p; + int64_t total_num_rows_d; + int64_t total_num_rows_offset; + uint16_t cta_tile_q_p; + uint16_t cta_tile_q_d; + int64_t request_indices_offset; + int64_t qo_tile_indices_offset; + int64_t kv_tile_indices_offset; + int64_t merge_indptr_offset; + int64_t o_indptr_offset; + int64_t kv_chunk_size_ptr_offset_p; + int64_t kv_chunk_size_ptr_offset_d; + int64_t v_offset; + int64_t s_offset; + int64_t block_valid_mask_offset; + bool enable_cuda_graph; + bool split_kv; + + PODPlanInfo() + : padded_batch_size_p(0), + padded_batch_size_d(0), + total_num_rows(0), + total_num_rows_p(0), + total_num_rows_d(0), + total_num_rows_offset(0), + cta_tile_q_p(0), + cta_tile_q_d(0), + request_indices_offset(0), + qo_tile_indices_offset(0), + kv_tile_indices_offset(0), + merge_indptr_offset(0), + o_indptr_offset(0), + kv_chunk_size_ptr_offset_p(0), + kv_chunk_size_ptr_offset_d(0), + v_offset(0), + s_offset(0), + block_valid_mask_offset(0), + enable_cuda_graph(false), + split_kv(false) {} + + // convert PrefillPlanInfo to std::vector + std::vector ToVector() const { + return {padded_batch_size_p, + padded_batch_size_d, + total_num_rows, + total_num_rows_p, + total_num_rows_d, + total_num_rows_offset, + cta_tile_q_p, + cta_tile_q_d, + request_indices_offset, + qo_tile_indices_offset, + kv_tile_indices_offset, + merge_indptr_offset, + o_indptr_offset, + kv_chunk_size_ptr_offset_p, + kv_chunk_size_ptr_offset_d, + v_offset, + s_offset, + block_valid_mask_offset, + enable_cuda_graph, + split_kv}; + } + + // From std::vector to PodPlanInfo + void FromVector(const std::vector& vec) { + if (vec.size() != 20) { + std::ostringstream err_msg; + err_msg << "PodPlanInfo::FromVector: vec.size() should be 20, but got " << vec.size(); + FLASHINFER_ERROR(err_msg.str()); + } + padded_batch_size_p = vec[0]; + padded_batch_size_d = vec[1]; + total_num_rows = vec[2]; + total_num_rows_p = vec[3]; + total_num_rows_d = vec[4]; + total_num_rows_offset = vec[5]; + cta_tile_q_p = vec[6]; + cta_tile_q_d = vec[7]; + request_indices_offset = vec[8]; + qo_tile_indices_offset = vec[9]; + kv_tile_indices_offset = vec[10]; + merge_indptr_offset = vec[11]; + o_indptr_offset = vec[12]; + kv_chunk_size_ptr_offset_p = vec[13]; + kv_chunk_size_ptr_offset_d = vec[14]; + v_offset = vec[15]; + s_offset = vec[16]; + block_valid_mask_offset = vec[17]; + enable_cuda_graph = vec[18]; + split_kv = vec[19]; + } +}; + +template +inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, void* page_locked_int_buffer, + size_t int_workspace_size_in_bytes, PODPlanInfo& plan_info, + IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_t total_num_rows_p, + uint32_t batch_size_p, IdType* qo_indptr_d, IdType* kv_indptr_d, + uint32_t total_num_rows_d, uint32_t batch_size_d, uint32_t num_qo_heads, + 
uint32_t num_kv_heads, uint32_t head_dim_qk, uint32_t head_dim_vo, + uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o, + cudaStream_t stream) { + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " + << num_kv_heads; + FLASHINFER_ERROR(err_msg.str()); + } + + // step 0: get the number of SMs + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + int num_blocks_per_sm = 3; // TODO(Wenxuan): increase this to reduce wave quantization? + int max_grid_size = num_blocks_per_sm * num_sm; + uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; + printf("Debug: max_batch_size_if_split: %d\n", max_batch_size_if_split); + // step 2: determine kv_chunk_size + auto [split_kv, new_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p, + cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, request_indices_vec, qo_tile_indices_vec, + kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = + PODSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, total_num_rows_p, batch_size_p, qo_indptr_d, + kv_indptr_d, total_num_rows_d, batch_size_d, num_qo_heads, num_kv_heads, + head_dim_vo, page_size, max_batch_size_if_split, enable_cuda_graph); + uint32_t padded_batch_size = padded_batch_size_p + padded_batch_size_d; + uint32_t batch_size = batch_size_p + batch_size_d; + uint32_t total_num_rows = total_num_rows_p + total_num_rows_d; + + plan_info.padded_batch_size_p = padded_batch_size_p; + plan_info.padded_batch_size_d = padded_batch_size_d; + plan_info.total_num_rows_p = total_num_rows_p; + plan_info.total_num_rows_d = total_num_rows_d; + plan_info.total_num_rows = total_num_rows; + plan_info.cta_tile_q_p = cta_tile_q_p; + plan_info.cta_tile_q_d = cta_tile_q_d; + plan_info.enable_cuda_graph = enable_cuda_graph; + plan_info.split_kv = split_kv; + + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + plan_info.request_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size, 16, "pod_prefill_request_indices"); + plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size, 16, "pod_prefill_qo_tile_indices"); + plan_info.kv_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size, 16, "pod_prefill_kv_tile_indices"); + plan_info.o_indptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * (batch_size + 1), 16, "pod_o_indptr"); + plan_info.kv_chunk_size_ptr_offset_p = + int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "pod_prefill_kv_chunk_size_ptr"); + plan_info.kv_chunk_size_ptr_offset_d = + int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "pod_decode_kv_chunk_size_ptr"); + + if (plan_info.enable_cuda_graph) { + plan_info.total_num_rows_offset = + int_allocator.aligned_alloc_offset(sizeof(uint32_t), 16, "batch_prefill_total_num_rows"); + uint32_t* total_num_rows_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.total_num_rows_offset); + *total_num_rows_h = total_num_rows_p + total_num_rows_d; + } + + IdType* request_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.request_indices_offset); + IdType* qo_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_tile_indices_offset); + IdType* kv_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, 
plan_info.kv_tile_indices_offset); + IdType* o_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.o_indptr_offset); + IdType* kv_chunk_size_ptr_p = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset_p); + IdType* kv_chunk_size_ptr_d = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset_d); + std::copy(request_indices_vec.begin(), request_indices_vec.end(), request_indices_h); + std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), qo_tile_indices_h); + std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), kv_tile_indices_h); + std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), o_indptr_h); + kv_chunk_size_ptr_p[0] = kv_chunk_size_p; + kv_chunk_size_ptr_d[0] = kv_chunk_size_d; + + if (split_kv) { + // TODO(Wenxuan): write through for non-split-kv requests + uint32_t num_outputs_p = num_qo_heads * padded_batch_size_p * cta_tile_q_p; + uint32_t num_outputs_d = num_qo_heads * padded_batch_size_d * cta_tile_q_d; + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + plan_info.v_offset = float_allocator.aligned_alloc_offset( + (num_outputs_p + num_outputs_d) * head_dim_vo * sizeof(float), 16, "pod_tmp_v"); + plan_info.s_offset = float_allocator.aligned_alloc_offset( + (num_outputs_p + num_outputs_d) * sizeof(float), 16, "pod_tmp_s"); + plan_info.merge_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "pod_merge_indptr"); + plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset( + sizeof(bool) * padded_batch_size, 16, "pod_block_valid_mask"); + + IdType* merge_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.merge_indptr_offset); + bool* block_valid_mask_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.block_valid_mask_offset); + std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), merge_indptr_h); + for (uint32_t i = 0; i < padded_batch_size; ++i) { + block_valid_mask_h[i] = i < new_batch_size; + } + } + + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + cudaMemcpyHostToDevice, stream)); + + return cudaSuccess; +} + inline float cost_function(int qo_len, int kv_len) { return 2 * float(qo_len) + kv_len; } template diff --git a/tests/utils/test_pod_kernels.py b/tests/utils/test_pod_kernels.py index 8900cc1b6c..0bf08bccf1 100644 --- a/tests/utils/test_pod_kernels.py +++ b/tests/utils/test_pod_kernels.py @@ -72,12 +72,12 @@ def warmup_jit(): yield -@pytest.mark.parametrize("kv_len_p", [127, 12288]) -@pytest.mark.parametrize("qo_len_p", [127, 12288]) +@pytest.mark.parametrize("kv_len_p", [128, 12288]) +@pytest.mark.parametrize("qo_len_p", [128, 12288]) @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("batch_size_d", [1, 17, 127]) -@pytest.mark.parametrize("kv_len_d", [127, 12288]) -@pytest.mark.parametrize("page_size_d", [1, 16]) +@pytest.mark.parametrize("batch_size_d", [80, 220, 250]) +@pytest.mark.parametrize("kv_len_d", [128, 12288]) +@pytest.mark.parametrize("page_size", [1, 16]) @pytest.mark.parametrize("kv_layout_d", ["NHD"]) @pytest.mark.parametrize("num_kv_heads", [8]) @pytest.mark.parametrize("num_qo_heads", [8, 32]) @@ -94,7 +94,7 @@ def test_pod_with_paged_kv_cache( # Decode params batch_size_d, kv_len_d, - page_size_d, + page_size, kv_layout_d, # Shared params num_kv_heads, @@ -105,36 +105,69 @@ def 
test_pod_with_paged_kv_cache( kv_dtype, contiguous_kv, ): + device = "cuda:0" if causal and qo_len_p > kv_len_p: pytest.skip("Causal prefill with qo_len_p > kv_len_p is not supported") + # Prefill inputs + kv_layout_p = kv_layout_d + kv_len_p_padded = (kv_len_p + page_size - 1) // page_size * page_size q_p = torch.randn( - qo_len_p, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 + qo_len_p, num_qo_heads, head_dim, device=device, dtype=torch.float16 ) k_p = torch.randn( - kv_len_p, num_kv_heads, head_dim, device="cuda:0", dtype=torch.float16 + kv_len_p_padded, num_kv_heads, head_dim, device=device, dtype=torch.float16 ) v_p = torch.randn( - kv_len_p, num_kv_heads, head_dim, device="cuda:0", dtype=torch.float16 + kv_len_p_padded, num_kv_heads, head_dim, device=device, dtype=torch.float16 ) - # Generate prefill reference output - o_ref_p = flashinfer.prefill.single_prefill_with_kv_cache( - q_p, - k_p, - v_p, - causal=causal, - pos_encoding_mode=pos_encoding_mode, + # Generate paged prefill inputs for POD + qo_indptr_p = torch.cat( + [ + torch.tensor([0], device=device), + ( + torch.cumsum(torch.tensor([qo_len_p], device=device), 0) + if qo_len_p > 0 + else torch.empty(0, device=device) + ), + ], + dim=0, + ).int() + total_num_pages_p = kv_len_p_padded // page_size + kv_indptr_p = torch.tensor( + [0, total_num_pages_p] if total_num_pages_p > 0 else [0], + device=device, + dtype=torch.int32, + ) + kv_indices_p = torch.arange( + 0, max(1, total_num_pages_p), device=device, dtype=torch.int32 + ) + last_page_len_p = torch.tensor( + [(kv_len_p_padded - 1) % page_size + 1], device=device, dtype=torch.int32 ) + + if kv_layout_p == "NHD": + k_p = k_p.reshape(total_num_pages_p, 1, page_size, num_kv_heads, head_dim) + v_p = v_p.reshape(total_num_pages_p, 1, page_size, num_kv_heads, head_dim) + else: + k_p = k_p.reshape( + total_num_pages_p, 1, page_size, num_kv_heads, head_dim + ).transpose(2, 3) + v_p = v_p.reshape( + total_num_pages_p, 1, page_size, num_kv_heads, head_dim + ).transpose(2, 3) + kv_data_p = torch.cat([k_p, v_p], dim=1) + # Decode inputs q_d = torch.randn( - batch_size_d, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 + batch_size_d, num_qo_heads, head_dim, device=device, dtype=torch.float16 ) - num_pages_per_seq = (kv_len_d + page_size_d - 1) // page_size_d + num_pages_per_seq = (kv_len_d + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size_d - if kv_layout_d == "HND": - kv_shape = [total_num_pages, 2, num_kv_heads, page_size_d, head_dim] + if kv_layout_d == "NHD": + kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim] else: - kv_shape = [total_num_pages, 2, page_size_d, num_kv_heads, head_dim] + kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim] if not contiguous_kv: tmp = [kv_shape[0]] for v_d in kv_shape[1:]: @@ -142,78 +175,92 @@ def test_pod_with_paged_kv_cache( tmp.append(v_d) kv_shape = tmp kv_data_fp32 = torch.randn(*kv_shape, device="cuda:0", dtype=torch.float32) - kv_data = kv_data_fp32.to(kv_dtype) - kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :] + kv_data_d = kv_data_fp32.to(kv_dtype) + kv_data_d = kv_data_d[:, 1, :, 1, :, 1, :, 1, :] kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :] # actual data is stored in non-contiguous memory assert ( - kv_data.stride(-4) - != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1] + kv_data_d.stride(-4) + != kv_data_d.shape[-3] * kv_data_d.shape[-2] * kv_data_d.shape[-1] ) else: kv_data_fp32 = torch.randn(*kv_shape, device="cuda:0", 
dtype=torch.float32) - kv_data = kv_data_fp32.to(kv_dtype) + kv_data_d = kv_data_fp32.to(kv_dtype) kv_indptr_d = ( - torch.arange(0, batch_size_d + 1, device="cuda:0", dtype=torch.int32) + torch.arange(0, batch_size_d + 1, device=device, dtype=torch.int32) * num_pages_per_seq ) - kv_indices_d = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) - kv_last_page_len = torch.full( + kv_indices_d = torch.arange(0, total_num_pages, device=device, dtype=torch.int32) + last_page_len_d = torch.full( (batch_size_d,), - (kv_len_d - 1) % page_size_d + 1, - device="cuda:0", + (kv_len_d - 1) % page_size + 1, + device=device, dtype=torch.int32, ) - # Generate decode reference output - decode_workspace_buffer = torch.empty( - 32 * 1024 * 1024, device="cuda:0", dtype=torch.int8 - ) - decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( - decode_workspace_buffer, kv_layout_d + workspace_buffer = torch.empty(32 * 1024 * 1024, device=device, dtype=torch.int8) + pod_wrapper = flashinfer.PODWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout_d, ) - decode_wrapper.plan( + + kv_data = torch.cat([kv_data_p, kv_data_d]) + + pod_wrapper.plan( + qo_indptr_p, + kv_indptr_p, + kv_indices_p, + last_page_len_p, kv_indptr_d, kv_indices_d, - kv_last_page_len, + last_page_len_d, num_qo_heads, num_kv_heads, head_dim, - page_size_d, + page_size, pos_encoding_mode=pos_encoding_mode, data_type=kv_dtype, q_data_type=q_dtype, ) - o_ref_d = decode_wrapper.run(q_d, kv_data) + o_p, o_d = pod_wrapper.run( + q_p, + q_d, + kv_data, + pos_encoding_mode_p=pos_encoding_mode, + causal_p=causal, + ) - workspace_buffer = torch.empty(32 * 1024 * 1024, device="cuda:0", dtype=torch.int8) - pod_wrapper = flashinfer.PODWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout_d, + # Generate prefill reference output + o_ref_p = flashinfer.prefill.single_prefill_with_kv_cache( + q_p, + k_p, + v_p, + causal=causal, + pos_encoding_mode=pos_encoding_mode, ) - pod_wrapper.plan( + # Generate decode reference output + decode_workspace_buffer = torch.empty( + 32 * 1024 * 1024, device=device, dtype=torch.int8 + ) + decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + decode_workspace_buffer, kv_layout_d + ) + decode_wrapper.plan( kv_indptr_d, kv_indices_d, - kv_last_page_len, + last_page_len_d, num_qo_heads, num_kv_heads, head_dim, - page_size_d, + page_size, pos_encoding_mode=pos_encoding_mode, data_type=kv_dtype, q_data_type=q_dtype, ) + o_ref_d = decode_wrapper.run(q_d, kv_data_d) - o_p, o_d = pod_wrapper.run( - q_p, - k_p, - v_p, - q_d, - kv_data, - pos_encoding_mode_p=pos_encoding_mode, - causal_p=causal, - ) # Prefill is run with batch size 1 + torch.cuda.synchronize() torch.testing.assert_close( o_p, o_ref_p, rtol=1e-3, atol=1e-3, msg="Prefill mismatch" )
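
Note (not part of the patch): the planning logic added in scheduler.cuh above boils down to two pieces of arithmetic that are easy to check in isolation: the CTA budget is split between prefill and decode in proportion to their query-tile counts, and each side then binary-searches the smallest KV chunk size whose resulting grid still fits its share of CTAs. The sketch below reproduces that arithmetic in plain Python under assumed toy numbers (SM count, CTAs per SM, head counts, sequence lengths are all made up, and the helper names are local to this sketch); it is an illustration, not the library's API.

from math import ceil

def pick_kv_chunk_size(packed_qo_lens, kv_lens, cta_tile_q, max_batch_size, min_chunk=1):
    """Smallest KV chunk (in pages) whose resulting grid still fits in max_batch_size
    CTAs, mirroring the binary search in PrefillBinarySearchKVChunkSize."""
    low, high = min_chunk, max(max(kv_lens), 1)
    while low < high:
        mid = (low + high) // 2
        batch = sum(ceil(q / cta_tile_q) * ceil(max(k, 1) / mid)
                    for q, k in zip(packed_qo_lens, kv_lens))
        if batch > max_batch_size:
            low = mid + 1
        else:
            high = mid
    return low

# Assumed device/config: 132 SMs, 3 CTAs per SM, 8 KV heads, GQA group size 4.
max_batch_size_if_split = 132 * 3 // 8
group_size = 4
cta_tile_q_p, cta_tile_q_d = 128, 16

# One prefill request (qo_len 512, kv 512 pages) and 8 decode requests (8192 KV pages each).
packed_qo_p, kv_p = [512 * group_size], [512]
packed_qo_d, kv_d = [group_size] * 8, [8192] * 8

tiles_p = sum(ceil(q / cta_tile_q_p) for q in packed_qo_p)
tiles_d = sum(ceil(q / cta_tile_q_d) for q in packed_qo_d)

# CTA budget split proportionally to the number of query tiles, as in PODSplitQOKVIndptr.
max_bs_p = max_batch_size_if_split * tiles_p // (tiles_p + tiles_d)
max_bs_d = max_batch_size_if_split - max_bs_p

chunk_p = pick_kv_chunk_size(packed_qo_p, kv_p, cta_tile_q_p, max_bs_p)
chunk_d = pick_kv_chunk_size(packed_qo_d, kv_d, cta_tile_q_d, max_bs_d)
print(f"CTA budget p/d = {max_bs_p}/{max_bs_d}, kv chunk p/d = {chunk_p}/{chunk_d} pages")

With these toy numbers the prefill side gets 32 of the 49 CTAs and its 512-page KV is split into two 256-page chunks, while each decode KV is split into two 4096-page chunks; the resulting partial outputs are then combined by the split-kv merge path (VariableLengthMergeStates / VariableLengthAttentionSum) shown earlier in the kernel launcher.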