-
Notifications
You must be signed in to change notification settings - Fork 573
[Feature] Support batch prefill for POD Attention #1231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 31 commits
Commits
Show all changes
52 commits
Select commit
Hold shift + click to select a range
06ef37d
use FA2DetermineCtaTileQ for pod
Edenzzzz 14e3f13
modify wrapper..
Edenzzzz 22dc539
fix
Edenzzzz 3257899
bench against persistent
Edenzzzz 82f1550
rename xsize to num_qo_tiles
Edenzzzz 06fee31
fix
Edenzzzz f47f73e
fix
Edenzzzz c51cecc
fix
Edenzzzz e73566c
add mixed scheduler
Edenzzzz 1b2d4c0
rename to num_to_merge_qo_len
Edenzzzz 78e1266
add params
Edenzzzz 4979a2a
plan to use one reduction kernel for prefill and decode
Edenzzzz 2102e22
fix
Edenzzzz fab82ae
use unifed qkv indptr
Edenzzzz 13b6b19
fix
Edenzzzz 7d29232
fix plan func upper call interface
Edenzzzz 106bfdc
rename new_batch_size to real_batch_size
Edenzzzz 5e3e896
concat request_indices
Edenzzzz ac07253
unifed indices in wrapper.plan
Edenzzzz eb8f719
fixes
Edenzzzz e8b266d
fix params
Edenzzzz 560918b
fix some indices and params
Edenzzzz 2105101
update PODWithKVCacheTensorRun args
Edenzzzz 1a82b17
add paged kv params
Edenzzzz dd80a06
complete PODWithKVCacheTensorRun params
Edenzzzz 0bb164b
share lse
Edenzzzz 32d762b
templaterize CTA_TILE_Q_P
Edenzzzz 870b0b2
update dispatch logic
Edenzzzz a7eb44b
add pod template inst and .cu gen
Edenzzzz f57fde0
fixes
Edenzzzz 0832bac
trying to fix template
Edenzzzz a570bff
fix PODSplitQOKVIndptr param
Edenzzzz 965d073
fix template and vec type errors
Edenzzzz c2fcc54
use get_pod_module jit in wrapper
Edenzzzz 708d1be
attempt to fix some templates
Edenzzzz 9b9f8f4
fix
Edenzzzz 6f75dae
Fix compilation errors
AKKamath 332e4bf
Various other fixes for POD
AKKamath d7707f5
update pod test
Edenzzzz 766e8a8
update pod test
Edenzzzz ac408f4
fix
Edenzzzz cb36898
fix tests
Edenzzzz ff0e4c3
remove kv_len_arr_p
Edenzzzz 1a1323d
fix negative kv len bug
Edenzzzz ab5d4f1
add chunk_size check
Edenzzzz 9927dd2
fix
Edenzzzz 19cad42
remove mistaken causal
Edenzzzz f39be0b
use int32 for indices
Edenzzzz 875f170
add some checks
Edenzzzz 81fc9a7
fix dispatch logic when prefill request is empty
Edenzzzz b604637
Merge main Oct 30
Edenzzzz 458bcce
fix 3rdparty
Edenzzzz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| """ | ||
| Copyright (c) 2024 by FlashInfer team. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| """ | ||
|
|
||
| import re | ||
| import sys | ||
| from pathlib import Path | ||
|
|
||
| from .literal_map import ( | ||
| dtype_literal, | ||
| idtype_literal, | ||
| mask_mode_literal, | ||
| pos_encoding_mode_literal, | ||
| ) | ||
|
|
||
|
|
||
| def get_cu_file_str( | ||
| head_dim_qk, | ||
| head_dim_vo, | ||
| pos_encoding_mode, | ||
| use_fp16_qk_reduction, | ||
| mask_mode_p, | ||
| mask_mode_d, | ||
| dtype_q, | ||
| dtype_kv, | ||
| dtype_out, | ||
| idtype, | ||
| ): | ||
| cta_tile_q_choice = [128, 64, 16] | ||
| cta_tile_q_d = 16 | ||
|
|
||
| def get_insts(attention_variant_p, attention_variant_d, dtype_out): | ||
| return "\n".join( | ||
| [ | ||
| """template cudaError_t PODWithKVCacheTensorDispatched<{head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode_p}, {cta_tile_q_p}, {cta_tile_q_d}, {mask_mode_d}, {attention_variant_p}, {attention_variant_d}, PrefillParams, DecodeParams>( | ||
| PrefillParams prefill_params, DecodeParams decode_params, | ||
| {dtype_out}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream); | ||
| """.format( | ||
| head_dim_qk=head_dim_qk, | ||
| head_dim_vo=head_dim_vo, | ||
| pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], | ||
| use_fp16_qk_reduction=use_fp16_qk_reduction, | ||
| mask_mode_p=mask_mode_literal[int(mask_mode_p)], | ||
| cta_tile_q_p=cta_tile_q_p, | ||
| cta_tile_q_d=cta_tile_q_d, | ||
| mask_mode_d=mask_mode_literal[int(mask_mode_d)], | ||
| attention_variant_p=attention_variant_p, | ||
| attention_variant_d=attention_variant_d, | ||
| dtype_out=dtype_out, | ||
| ) | ||
| for cta_tile_q_p in cta_tile_q_choice | ||
| ] | ||
| ) | ||
|
|
||
| use_custom_mask_p = "true" if int(mask_mode_p) == 2 else "false" | ||
| use_custom_mask_d = "true" if int(mask_mode_d) == 2 else "false" | ||
| dtype_q = dtype_literal[dtype_q] | ||
| dtype_kv = dtype_literal[dtype_kv] | ||
| dtype_out = dtype_literal[dtype_out] | ||
| idtype = idtype_literal[idtype] | ||
|
|
||
| content = f"""#include <flashinfer/attention/pod.cuh> | ||
| #include "pod_config.inc" | ||
|
|
||
| namespace flashinfer {{ | ||
|
|
||
| constexpr auto use_custom_mask_p = MaskMode::kNone == MaskMode::kCustom; | ||
| constexpr auto use_custom_mask_d = MaskMode::kNone == MaskMode::kCustom; | ||
|
|
||
| using PrefillParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}>; | ||
| using DecodeParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>; | ||
|
|
||
| using AttentionVariant1_P = DefaultAttention<use_custom_mask_p, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>; | ||
| using AttentionVariant1_D = DefaultAttention<use_custom_mask_d, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>; | ||
|
|
||
| {get_insts("AttentionVariant1_P", "AttentionVariant1_D", dtype_out)} | ||
|
|
||
| using AttentionVariant2_P = DefaultAttention<use_custom_mask_p, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>; | ||
| using AttentionVariant2_D = DefaultAttention<use_custom_mask_d, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>; | ||
|
|
||
| {get_insts("AttentionVariant2_P", "AttentionVariant2_D", dtype_out)} | ||
|
|
||
| using AttentionVariant3_P = DefaultAttention<use_custom_mask_p, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>; | ||
| using AttentionVariant3_D = DefaultAttention<use_custom_mask_d, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>; | ||
|
|
||
| {get_insts("AttentionVariant3_P", "AttentionVariant3_D", dtype_out)} | ||
|
|
||
| using AttentionVariant4_P = DefaultAttention<use_custom_mask_p, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>; | ||
| using AttentionVariant4_D = DefaultAttention<use_custom_mask_d, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>; | ||
|
|
||
| {get_insts("AttentionVariant4_P", "AttentionVariant4_D", dtype_out)} | ||
|
|
||
| }}""" | ||
| return content | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pattern = ( | ||
| r"pod_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_" | ||
| r"fp16qkred_([a-z]+)_maskp_([0-9]+)_maskd_([0-9]+)_" | ||
| r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" | ||
| ) | ||
| compiled_pattern = re.compile(pattern) | ||
| path = Path(sys.argv[1]) | ||
| fname = path.name | ||
| match = compiled_pattern.match(fname) | ||
|
|
||
| with open(path, "w") as f: | ||
| f.write(get_cu_file_str(*match.groups())) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.