Skip to content

Commit c9debf7

Browse files
committed
Optimize FlashMask v3 performance (PaddlePaddle#75737)
* tune bwd tile size
* tune bwd tile size for seqlen <= 8192
* fix CUDA error 700 caused by incorrect bwd tile size
* set scheduler_needs_semaphore to true
* update fa submodule
* update fa submodule
* update fa submodule
* update fa submodule
1 parent 33eff52 commit c9debf7

File tree

3 files changed

+6
-10
lines changed

3 files changed

+6
-10
lines changed

paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,11 +1026,9 @@ void FlashMaskV2GradBaseKernel(
10261026
? 64
10271027
: (is_causal && softcap || is_flashmask > 0.0 ? 96 : 128)
10281028
: (head_size_rounded <= 128
1029-
? (is_flashmask && !is_causal)
1030-
? 64
1031-
: (is_causal || is_local || is_flashmask || softcap > 0.0
1032-
? 64
1033-
: 80)
1029+
? (is_causal || is_local || softcap > 0.0
1030+
? 64
1031+
: 64)
10341032
: 64);
10351033

10361034
int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
@@ -1040,7 +1038,7 @@ void FlashMaskV2GradBaseKernel(
10401038
: (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
10411039
int const kBlockN_sm90 =
10421040
head_size_rounded <= 64 && (is_flashmask && !is_causal) ? 96
1043-
: head_size_rounded <= 128 ? (is_flashmask && !is_causal) ? 64 : 128
1041+
: head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.f || seqlen_q >= 8192 ? 128 : 64)
10441042
: (head_size_rounded <= 192 ? 96 : 80);
10451043
int const kBlockN_sm80 =
10461044
head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 80 : 64);

paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1762,9 +1762,7 @@ void FlashMaskV2BaseKernel(
17621762
const int params_arch =
17631763
phi::dynload::flashmaskv2_fwd_params_get_arch(params_handle);
17641764
bool const scheduler_needs_semaphore =
1765-
params_arch >= 90 ? (((params_is_causal || params_is_local) &&
1766-
(params_num_splits == 1)) ||
1767-
is_varlen)
1765+
params_arch >= 90 ? true
17681766
: ((params_is_causal && !is_varlen) ||
17691767
(is_varlen && params_num_splits > 1));
17701768
if (scheduler_needs_semaphore || use_dynamic_split) {

0 commit comments

Comments
 (0)