Skip to content

Commit a6644f4

Browse files
authored
[Bug Fix] Correct FlashMask V3 Backward Tile Size and Refine Tile Size Configuration (#75995)
* fix mistach tile size in phi, and refine bwd interface * refine * refine
1 parent 2b9ec1e commit a6644f4

File tree

1 file changed

+106
-70
lines changed

1 file changed

+106
-70
lines changed

paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu

Lines changed: 106 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,27 +1020,119 @@ void FlashMaskV2GradBaseKernel(
10201020
bool const is_local =
10211021
(window_size_left >= 0 || window_size_right >= 0) && !is_causal;
10221022
bool const is_flashmask = startend_row_indices_.is_initialized();
1023+
DenseTensor startend_row_indices;
1024+
if (is_flashmask) startend_row_indices = startend_row_indices_.get();
1025+
bool const has_softcap = softcap > 0.0;
10231026

1024-
int const kBlockM_sm90 =
1025-
head_size_rounded <= 64
1026-
? (is_flashmask && !is_causal)
1027-
? 64
1028-
: (is_causal && softcap || is_flashmask > 0.0 ? 96 : 128)
1029-
: (head_size_rounded <= 128
1030-
? (is_causal || is_local || softcap > 0.0
1031-
? 64
1032-
: 64)
1033-
: 64);
1027+
// flashmask
1028+
DenseTensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices,
1029+
ut_start_row_indices, ut_end_row_indices;
1030+
if (is_flashmask) {
1031+
PADDLE_ENFORCE_EQ(
1032+
startend_row_indices.dtype(),
1033+
phi::DataType::INT32,
1034+
common::errors::InvalidArgument(
1035+
"flashmask_attention startend_row_indices must be INT32 type"));
1036+
PADDLE_ENFORCE_EQ(
1037+
startend_row_indices.dims().size(),
1038+
4,
1039+
common::errors::InvalidArgument(
1040+
"flashmask_attention receive startend_row_indices with dim "
1041+
"[batch_size, num_heads,seq_len, mask_bounds]"));
1042+
PADDLE_ENFORCE_EQ(startend_row_indices.dims()[3] == 1 ||
1043+
startend_row_indices.dims()[3] == 2 ||
1044+
startend_row_indices.dims()[3] == 4,
1045+
true,
1046+
common::errors::InvalidArgument(
1047+
"flashmask_attention startend_row_indices "
1048+
"mask_bounds must in [1,2,4]"));
1049+
1050+
auto flashmask_maxmin_shape = startend_row_indices.dims();
1051+
// TODO(umiswing): refine this block constraint (kBlockN % 32), since some
1052+
// of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] =
1053+
// (flashmask_maxmin_shape[2] + 31) / 32 * 8;
1054+
flashmask_maxmin_shape[2] =
1055+
((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4;
1056+
flashmask_maxmin_shape[3] = 8;
1057+
1058+
flashmask_maxmin.set_type(phi::DataType::INT32);
1059+
flashmask_maxmin.Resize(flashmask_maxmin_shape);
1060+
dev_ctx.template Alloc<int32_t>(&flashmask_maxmin);
1061+
1062+
lt_start_row_indices =
1063+
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {0}, {1});
1064+
if (startend_row_indices.dims()[3] == 2) {
1065+
if (!is_causal) {
1066+
ut_end_row_indices =
1067+
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {1}, {2});
1068+
} else {
1069+
lt_end_row_indices =
1070+
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {1}, {2});
1071+
}
1072+
} else if (startend_row_indices.dims()[3] == 4) {
1073+
ut_end_row_indices =
1074+
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {3}, {4});
1075+
lt_end_row_indices =
1076+
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {1}, {2});
1077+
ut_start_row_indices =
1078+
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {2}, {3});
1079+
}
1080+
}
1081+
1082+
const bool has_lt_start = lt_start_row_indices.initialized();
1083+
const bool has_lt_end = lt_end_row_indices.initialized();
1084+
const bool has_ut_start = ut_start_row_indices.initialized();
1085+
const bool has_ut_end = ut_end_row_indices.initialized();
1086+
1087+
// umiswing: The tile dispatch for flashmask is now different from fa3.
1088+
// Replacing the original ternary operator with lambda makes the code
1089+
// easier to reason about and less error-prone.
1090+
const auto [kBlockM_sm90, kBlockN_sm90] = [&]() -> std::pair<int, int> {
1091+
if (head_size_rounded <= 64) {
1092+
if (is_flashmask && !is_causal) {
1093+
return {64, 96};
1094+
} else if (is_causal && has_softcap || is_flashmask) {
1095+
return {96, 128};
1096+
} else {
1097+
return {128, 128};
1098+
}
1099+
} else if (head_size_rounded <= 128) {
1100+
// umiswing: by now, we resue template instantiation of head dim 128 for head dim in range (64, 128],
1101+
// and therefore no separate dispatch for head dim in range (64, 96]
1102+
if (is_causal || is_local || has_softcap) {
1103+
return {64, 128};
1104+
} else {
1105+
if ((seqlen_q >= 1024 || seqlen_k >= 1024) && !(has_lt_end && has_ut_start)) {
1106+
return {64, 128};
1107+
} else {
1108+
return {64, 64};
1109+
}
1110+
}
1111+
} else if (head_size_rounded <= 192) {
1112+
// umiswing: head dim > 128 is not supported now
1113+
PADDLE_THROW(common::errors::Unimplemented(
1114+
"head dim is rounded to %d, which is not supported in FlashMask V3 now.",
1115+
head_size_rounded));
1116+
return {0, 0};
1117+
} else if (head_size_rounded <= 256) {
1118+
// umiswing: head dim > 128 is not supported now
1119+
PADDLE_THROW(common::errors::Unimplemented(
1120+
"head dim is rounded to %d, which is not supported in FlashMask V3 now.",
1121+
head_size_rounded));
1122+
return {0, 0};
1123+
} else {
1124+
PADDLE_THROW(common::errors::Unimplemented(
1125+
"head dim is rounded to %d, which is not supported in FlashMask V3 now.",
1126+
head_size_rounded));
1127+
return {0, 0};
1128+
}
1129+
}();
10341130

10351131
int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
10361132
int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
10371133
int const kBlockM =
10381134
arch >= 90 ? kBlockM_sm90
10391135
: (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
1040-
int const kBlockN_sm90 =
1041-
head_size_rounded <= 64 && (is_flashmask && !is_causal) ? 96
1042-
: head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.f || seqlen_q >= 8192 ? 128 : 64)
1043-
: (head_size_rounded <= 192 ? 96 : 80);
10441136
int const kBlockN_sm80 =
10451137
head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 80 : 64);
10461138
int const kBlockN_sm86 =
@@ -1307,62 +1399,6 @@ void FlashMaskV2GradBaseKernel(
13071399
dynload::flashmaskv2_bwd_params_set_dv_semaphore(params_handle,
13081400
dv_semaphore.data<int>());
13091401
}
1310-
// flashmask
1311-
DenseTensor startend_row_indices;
1312-
if (is_flashmask) startend_row_indices = startend_row_indices_.get();
1313-
DenseTensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices,
1314-
ut_start_row_indices, ut_end_row_indices;
1315-
if (is_flashmask) {
1316-
PADDLE_ENFORCE_EQ(
1317-
startend_row_indices.dtype(),
1318-
phi::DataType::INT32,
1319-
common::errors::InvalidArgument(
1320-
"flashmask_attention startend_row_indices must be INT32 type"));
1321-
PADDLE_ENFORCE_EQ(
1322-
startend_row_indices.dims().size(),
1323-
4,
1324-
common::errors::InvalidArgument(
1325-
"flashmask_attention receive startend_row_indices with dim "
1326-
"[batch_size, num_heads,seq_len, mask_bounds]"));
1327-
PADDLE_ENFORCE_EQ(startend_row_indices.dims()[3] == 1 ||
1328-
startend_row_indices.dims()[3] == 2 ||
1329-
startend_row_indices.dims()[3] == 4,
1330-
true,
1331-
common::errors::InvalidArgument(
1332-
"flashmask_attention startend_row_indices "
1333-
"mask_bounds must in [1,2,4]"));
1334-
1335-
auto flashmask_maxmin_shape = startend_row_indices.dims();
1336-
// TODO(umiswing): refine this block constraint (kBlockN % 32), since some
1337-
// of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] =
1338-
// (flashmask_maxmin_shape[2] + 31) / 32 * 8;
1339-
flashmask_maxmin_shape[2] =
1340-
((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4;
1341-
flashmask_maxmin_shape[3] = 8;
1342-
1343-
flashmask_maxmin.set_type(phi::DataType::INT32);
1344-
flashmask_maxmin.Resize(flashmask_maxmin_shape);
1345-
dev_ctx.template Alloc<int32_t>(&flashmask_maxmin);
1346-
1347-
lt_start_row_indices =
1348-
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {0}, {1});
1349-
if (startend_row_indices.dims()[3] == 2) {
1350-
if (!is_causal) {
1351-
ut_end_row_indices =
1352-
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {1}, {2});
1353-
} else {
1354-
lt_end_row_indices =
1355-
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {1}, {2});
1356-
}
1357-
} else if (startend_row_indices.dims()[3] == 4) {
1358-
ut_end_row_indices =
1359-
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {3}, {4});
1360-
lt_end_row_indices =
1361-
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {1}, {2});
1362-
ut_start_row_indices =
1363-
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {2}, {3});
1364-
}
1365-
}
13661402

13671403
if (is_flashmask) {
13681404
if (lt_start_row_indices.initialized())

0 commit comments

Comments
 (0)