@@ -1020,27 +1020,119 @@ void FlashMaskV2GradBaseKernel(
10201020 bool const is_local =
10211021 (window_size_left >= 0 || window_size_right >= 0 ) && !is_causal;
10221022 bool const is_flashmask = startend_row_indices_.is_initialized ();
1023+ DenseTensor startend_row_indices;
1024+ if (is_flashmask) startend_row_indices = startend_row_indices_.get ();
1025+ bool const has_softcap = softcap > 0.0 ;
10231026
1024- int const kBlockM_sm90 =
1025- head_size_rounded <= 64
1026- ? (is_flashmask && !is_causal)
1027- ? 64
1028- : (is_causal && softcap || is_flashmask > 0.0 ? 96 : 128 )
1029- : (head_size_rounded <= 128
1030- ? (is_causal || is_local || softcap > 0.0
1031- ? 64
1032- : 64 )
1033- : 64 );
1027+ // flashmask
1028+ DenseTensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices,
1029+ ut_start_row_indices, ut_end_row_indices;
1030+ if (is_flashmask) {
1031+ PADDLE_ENFORCE_EQ (
1032+ startend_row_indices.dtype (),
1033+ phi::DataType::INT32,
1034+ common::errors::InvalidArgument (
1035+ " flashmask_attention startend_row_indices must be INT32 type" ));
1036+ PADDLE_ENFORCE_EQ (
1037+ startend_row_indices.dims ().size (),
1038+ 4 ,
1039+ common::errors::InvalidArgument (
1040+ " flashmask_attention receive startend_row_indices with dim "
1041+ " [batch_size, num_heads,seq_len, mask_bounds]" ));
1042+ PADDLE_ENFORCE_EQ (startend_row_indices.dims ()[3 ] == 1 ||
1043+ startend_row_indices.dims ()[3 ] == 2 ||
1044+ startend_row_indices.dims ()[3 ] == 4 ,
1045+ true ,
1046+ common::errors::InvalidArgument (
1047+ " flashmask_attention startend_row_indices "
1048+ " mask_bounds must in [1,2,4]" ));
1049+
1050+ auto flashmask_maxmin_shape = startend_row_indices.dims ();
1051+ // TODO(umiswing): refine this block constraint (kBlockN % 32), since some
1052+ // of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] =
1053+ // (flashmask_maxmin_shape[2] + 31) / 32 * 8;
1054+ flashmask_maxmin_shape[2 ] =
1055+ ((flashmask_maxmin_shape[2 ] + 31 ) / 32 + 3 ) / 4 * 4 ;
1056+ flashmask_maxmin_shape[3 ] = 8 ;
1057+
1058+ flashmask_maxmin.set_type (phi::DataType::INT32);
1059+ flashmask_maxmin.Resize (flashmask_maxmin_shape);
1060+ dev_ctx.template Alloc <int32_t >(&flashmask_maxmin);
1061+
1062+ lt_start_row_indices =
1063+ phi::Slice<int32_t >(dev_ctx, startend_row_indices, {3 }, {0 }, {1 });
1064+ if (startend_row_indices.dims ()[3 ] == 2 ) {
1065+ if (!is_causal) {
1066+ ut_end_row_indices =
1067+ phi::Slice<int32_t >(dev_ctx, startend_row_indices, {3 }, {1 }, {2 });
1068+ } else {
1069+ lt_end_row_indices =
1070+ phi::Slice<int32_t >(dev_ctx, startend_row_indices, {3 }, {1 }, {2 });
1071+ }
1072+ } else if (startend_row_indices.dims ()[3 ] == 4 ) {
1073+ ut_end_row_indices =
1074+ phi::Slice<int32_t >(dev_ctx, startend_row_indices, {3 }, {3 }, {4 });
1075+ lt_end_row_indices =
1076+ phi::Slice<int32_t >(dev_ctx, startend_row_indices, {3 }, {1 }, {2 });
1077+ ut_start_row_indices =
1078+ phi::Slice<int32_t >(dev_ctx, startend_row_indices, {3 }, {2 }, {3 });
1079+ }
1080+ }
1081+
1082+ const bool has_lt_start = lt_start_row_indices.initialized ();
1083+ const bool has_lt_end = lt_end_row_indices.initialized ();
1084+ const bool has_ut_start = ut_start_row_indices.initialized ();
1085+ const bool has_ut_end = ut_end_row_indices.initialized ();
1086+
1087+ // umiswing: The tile dispatch for flashmask is now different from fa3.
1088+ // Replacing the original ternary operator with lambda makes the code
1089+ // easier to reason about and less error-prone.
1090+ const auto [kBlockM_sm90 , kBlockN_sm90 ] = [&]() -> std::pair<int , int > {
1091+ if (head_size_rounded <= 64 ) {
1092+ if (is_flashmask && !is_causal) {
1093+ return {64 , 96 };
1094+ } else if (is_causal && has_softcap || is_flashmask) {
1095+ return {96 , 128 };
1096+ } else {
1097+ return {128 , 128 };
1098+ }
1099+ } else if (head_size_rounded <= 128 ) {
1100+ // umiswing: by now, we resue template instantiation of head dim 128 for head dim in range (64, 128],
1101+ // and therefore no separate dispatch for head dim in range (64, 96]
1102+ if (is_causal || is_local || has_softcap) {
1103+ return {64 , 128 };
1104+ } else {
1105+ if ((seqlen_q >= 1024 || seqlen_k >= 1024 ) && !(has_lt_end && has_ut_start)) {
1106+ return {64 , 128 };
1107+ } else {
1108+ return {64 , 64 };
1109+ }
1110+ }
1111+ } else if (head_size_rounded <= 192 ) {
1112+ // umiswing: head dim > 128 is not supported now
1113+ PADDLE_THROW (common::errors::Unimplemented (
1114+ " head dim is rounded to %d, which is not supported in FlashMask V3 now." ,
1115+ head_size_rounded));
1116+ return {0 , 0 };
1117+ } else if (head_size_rounded <= 256 ) {
1118+ // umiswing: head dim > 128 is not supported now
1119+ PADDLE_THROW (common::errors::Unimplemented (
1120+ " head dim is rounded to %d, which is not supported in FlashMask V3 now." ,
1121+ head_size_rounded));
1122+ return {0 , 0 };
1123+ } else {
1124+ PADDLE_THROW (common::errors::Unimplemented (
1125+ " head dim is rounded to %d, which is not supported in FlashMask V3 now." ,
1126+ head_size_rounded));
1127+ return {0 , 0 };
1128+ }
1129+ }();
10341130
10351131 int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64 ;
10361132 int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32 ;
10371133 int const kBlockM =
10381134 arch >= 90 ? kBlockM_sm90
10391135 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80 );
1040- int const kBlockN_sm90 =
1041- head_size_rounded <= 64 && (is_flashmask && !is_causal) ? 96
1042- : head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0 .f || seqlen_q >= 8192 ? 128 : 64 )
1043- : (head_size_rounded <= 192 ? 96 : 80 );
10441136 int const kBlockN_sm80 =
10451137 head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 80 : 64 );
10461138 int const kBlockN_sm86 =
@@ -1307,62 +1399,6 @@ void FlashMaskV2GradBaseKernel(
13071399 dynload::flashmaskv2_bwd_params_set_dv_semaphore (params_handle,
13081400 dv_semaphore.data <int >());
13091401 }
1310- // flashmask
1311- DenseTensor startend_row_indices;
1312- if (is_flashmask) startend_row_indices = startend_row_indices_.get ();
1313- DenseTensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices,
1314- ut_start_row_indices, ut_end_row_indices;
1315- if (is_flashmask) {
1316- PADDLE_ENFORCE_EQ (
1317- startend_row_indices.dtype (),
1318- phi::DataType::INT32,
1319- common::errors::InvalidArgument (
1320- " flashmask_attention startend_row_indices must be INT32 type" ));
1321- PADDLE_ENFORCE_EQ (
1322- startend_row_indices.dims ().size (),
1323- 4 ,
1324- common::errors::InvalidArgument (
1325- " flashmask_attention receive startend_row_indices with dim "
1326- " [batch_size, num_heads,seq_len, mask_bounds]" ));
1327- PADDLE_ENFORCE_EQ (startend_row_indices.dims ()[3 ] == 1 ||
1328- startend_row_indices.dims ()[3 ] == 2 ||
1329- startend_row_indices.dims ()[3 ] == 4 ,
1330- true ,
1331- common::errors::InvalidArgument (
1332- " flashmask_attention startend_row_indices "
1333- " mask_bounds must in [1,2,4]" ));
1334-
1335- auto flashmask_maxmin_shape = startend_row_indices.dims ();
1336- // TODO(umiswing): refine this block constraint (kBlockN % 32), since some
1337- // of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] =
1338- // (flashmask_maxmin_shape[2] + 31) / 32 * 8;
1339- flashmask_maxmin_shape[2 ] =
1340- ((flashmask_maxmin_shape[2 ] + 31 ) / 32 + 3 ) / 4 * 4 ;
1341- flashmask_maxmin_shape[3 ] = 8 ;
1342-
1343- flashmask_maxmin.set_type (phi::DataType::INT32);
1344- flashmask_maxmin.Resize (flashmask_maxmin_shape);
1345- dev_ctx.template Alloc <int32_t >(&flashmask_maxmin);
1346-
1347- lt_start_row_indices =
1348- phi::Slice<int32_t >(dev_ctx, startend_row_indices, {3 }, {0 }, {1 });
1349- if (startend_row_indices.dims ()[3 ] == 2 ) {
1350- if (!is_causal) {
1351- ut_end_row_indices =
1352- phi::Slice<int32_t >(dev_ctx, startend_row_indices, {3 }, {1 }, {2 });
1353- } else {
1354- lt_end_row_indices =
1355- phi::Slice<int32_t >(dev_ctx, startend_row_indices, {3 }, {1 }, {2 });
1356- }
1357- } else if (startend_row_indices.dims ()[3 ] == 4 ) {
1358- ut_end_row_indices =
1359- phi::Slice<int32_t >(dev_ctx, startend_row_indices, {3 }, {3 }, {4 });
1360- lt_end_row_indices =
1361- phi::Slice<int32_t >(dev_ctx, startend_row_indices, {3 }, {1 }, {2 });
1362- ut_start_row_indices =
1363- phi::Slice<int32_t >(dev_ctx, startend_row_indices, {3 }, {2 }, {3 });
1364- }
1365- }
13661402
13671403 if (is_flashmask) {
13681404 if (lt_start_row_indices.initialized ())
0 commit comments