File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -1026,9 +1026,7 @@ void FlashMaskV2GradBaseKernel(
10261026 ? 64
10271027 : (is_causal && softcap || is_flashmask > 0.0 ? 96 : 128 )
10281028 : (head_size_rounded <= 128
1029- ? (is_causal || is_local || softcap > 0.0
1030- ? 64
1031- : 64 )
1029+ ? (is_causal || is_local || softcap > 0.0 ? 64 : 64 )
10321030 : 64 );
10331031
10341032 int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64 ;
@@ -1038,8 +1036,10 @@ void FlashMaskV2GradBaseKernel(
10381036 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80 );
10391037 int const kBlockN_sm90 =
10401038 head_size_rounded <= 64 && (is_flashmask && !is_causal) ? 96
1041- : head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0 .f || seqlen_q >= 8192 ? 128 : 64 )
1042- : (head_size_rounded <= 192 ? 96 : 80 );
1039+ : head_size_rounded <= 128
1040+ ? (is_causal || is_local || softcap > 0 .f || seqlen_q >= 8192 ? 128
1041+ : 64 )
1042+ : (head_size_rounded <= 192 ? 96 : 80 );
10431043 int const kBlockN_sm80 =
10441044 head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 80 : 64 );
10451045 int const kBlockN_sm86 =
You can’t perform that action at this time.
0 commit comments