Skip to content

Commit e14a08e

Browse files
committed
fix codestyle
1 parent c9debf7 commit e14a08e

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,9 +1026,7 @@ void FlashMaskV2GradBaseKernel(
10261026
? 64
10271027
: (is_causal && softcap || is_flashmask > 0.0 ? 96 : 128)
10281028
: (head_size_rounded <= 128
1029-
? (is_causal || is_local || softcap > 0.0
1030-
? 64
1031-
: 64)
1029+
? (is_causal || is_local || softcap > 0.0 ? 64 : 64)
10321030
: 64);
10331031

10341032
int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
@@ -1038,8 +1036,10 @@ void FlashMaskV2GradBaseKernel(
10381036
: (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
10391037
int const kBlockN_sm90 =
10401038
head_size_rounded <= 64 && (is_flashmask && !is_causal) ? 96
1041-
: head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.f || seqlen_q >= 8192 ? 128 : 64)
1042-
: (head_size_rounded <= 192 ? 96 : 80);
1039+
: head_size_rounded <= 128
1040+
? (is_causal || is_local || softcap > 0.f || seqlen_q >= 8192 ? 128
1041+
: 64)
1042+
: (head_size_rounded <= 192 ? 96 : 80);
10431043
int const kBlockN_sm80 =
10441044
head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 80 : 64);
10451045
int const kBlockN_sm86 =

0 commit comments

Comments
 (0)