Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 111 additions & 72 deletions paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1019,29 +1019,124 @@ void FlashMaskV2GradBaseKernel(
bool const is_local =
(window_size_left >= 0 || window_size_right >= 0) && !is_causal;
bool const is_flashmask = startend_row_indices_.is_initialized();
DenseTensor startend_row_indices;
if (is_flashmask) startend_row_indices = startend_row_indices_.get();
bool const has_softcap = softcap > 0.0;

int const kBlockM_sm90 =
head_size_rounded <= 64
? (is_flashmask && !is_causal)
? 64
: (is_causal && softcap || is_flashmask > 0.0 ? 96 : 128)
: (head_size_rounded <= 128
? (is_flashmask && !is_causal)
? 64
: (is_causal || is_local || is_flashmask || softcap > 0.0
? 64
: 80)
: 64);
// flashmask
DenseTensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices,
ut_start_row_indices, ut_end_row_indices;
if (is_flashmask) {
PADDLE_ENFORCE_EQ(
startend_row_indices.dtype(),
phi::DataType::INT32,
common::errors::InvalidArgument(
"flashmask_attention startend_row_indices must be INT32 type"));
PADDLE_ENFORCE_EQ(
startend_row_indices.dims().size(),
4,
common::errors::InvalidArgument(
"flashmask_attention receive startend_row_indices with dim "
"[batch_size, num_heads,seq_len, mask_bounds]"));
PADDLE_ENFORCE_EQ(startend_row_indices.dims()[3] == 1 ||
startend_row_indices.dims()[3] == 2 ||
startend_row_indices.dims()[3] == 4,
true,
common::errors::InvalidArgument(
"flashmask_attention startend_row_indices "
"mask_bounds must in [1,2,4]"));

auto flashmask_maxmin_shape = startend_row_indices.dims();
// TODO(umiswing): refine this block constraint (kBlockN % 32), since some
// of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] =
// (flashmask_maxmin_shape[2] + 31) / 32 * 8;
flashmask_maxmin_shape[2] =
((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4;
flashmask_maxmin_shape[3] = 8;

flashmask_maxmin.set_type(phi::DataType::INT32);
flashmask_maxmin.Resize(flashmask_maxmin_shape);
dev_ctx.template Alloc<int32_t>(&flashmask_maxmin);

lt_start_row_indices =
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {0}, {1});
if (startend_row_indices.dims()[3] == 2) {
if (!is_causal) {
ut_end_row_indices =
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {1}, {2});
} else {
lt_end_row_indices =
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {1}, {2});
}
} else if (startend_row_indices.dims()[3] == 4) {
ut_end_row_indices =
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {3}, {4});
lt_end_row_indices =
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {1}, {2});
ut_start_row_indices =
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {2}, {3});
}
}

const bool has_lt_start = lt_start_row_indices.initialized();
const bool has_lt_end = lt_end_row_indices.initialized();
const bool has_ut_start = ut_start_row_indices.initialized();
const bool has_ut_end = ut_end_row_indices.initialized();

// umiswing: The tile dispatch for flashmask is now different from fa3.
// Replacing the original ternary operator with lambda makes the code
// easier to reason about and less error-prone.
const auto [kBlockM_sm90, kBlockN_sm90] = [&]() -> std::pair<int, int> {
if (head_size_rounded <= 64) {
if (is_flashmask && !is_causal) {
return {64, 96};
} else if (is_causal && has_softcap || is_flashmask) {
return {96, 128};
} else {
return {128, 128};
}
} else if (head_size_rounded <= 128) {
// umiswing: by now, we reuse template instantiation of head dim 128 for
// head dim in range (64, 128], and therefore no separate dispatch for
// head dim in range (64, 96]
if (is_causal || is_local || has_softcap) {
return {64, 128};
} else {
if ((seqlen_q >= 1024 || seqlen_k >= 1024) &&
!(has_lt_end && has_ut_start)) {
return {64, 128};
} else {
return {64, 64};
}
}
} else if (head_size_rounded <= 192) {
// umiswing: head dim > 128 is not supported now
PADDLE_THROW(
common::errors::Unimplemented("head dim is rounded to %d, which is "
"not supported in FlashMask V3 now.",
head_size_rounded));
return {0, 0};
} else if (head_size_rounded <= 256) {
// umiswing: head dim > 128 is not supported now
PADDLE_THROW(
common::errors::Unimplemented("head dim is rounded to %d, which is "
"not supported in FlashMask V3 now.",
head_size_rounded));
return {0, 0};
} else {
PADDLE_THROW(
common::errors::Unimplemented("head dim is rounded to %d, which is "
"not supported in FlashMask V3 now.",
head_size_rounded));
return {0, 0};
}
}();

int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
int const kBlockM =
arch >= 90 ? kBlockM_sm90
: (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
int const kBlockN_sm90 =
head_size_rounded <= 64 && (is_flashmask && !is_causal) ? 96
: head_size_rounded <= 128 ? (is_flashmask && !is_causal) ? 64 : 128
: (head_size_rounded <= 192 ? 96 : 80);
int const kBlockN_sm80 =
head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 80 : 64);
int const kBlockN_sm86 =
Expand Down Expand Up @@ -1308,62 +1403,6 @@ void FlashMaskV2GradBaseKernel(
dynload::flashmaskv2_bwd_params_set_dv_semaphore(params_handle,
dv_semaphore.data<int>());
}
// flashmask
DenseTensor startend_row_indices;
if (is_flashmask) startend_row_indices = startend_row_indices_.get();
DenseTensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices,
ut_start_row_indices, ut_end_row_indices;
if (is_flashmask) {
PADDLE_ENFORCE_EQ(
startend_row_indices.dtype(),
phi::DataType::INT32,
common::errors::InvalidArgument(
"flashmask_attention startend_row_indices must be INT32 type"));
PADDLE_ENFORCE_EQ(
startend_row_indices.dims().size(),
4,
common::errors::InvalidArgument(
"flashmask_attention receive startend_row_indices with dim "
"[batch_size, num_heads,seq_len, mask_bounds]"));
PADDLE_ENFORCE_EQ(startend_row_indices.dims()[3] == 1 ||
startend_row_indices.dims()[3] == 2 ||
startend_row_indices.dims()[3] == 4,
true,
common::errors::InvalidArgument(
"flashmask_attention startend_row_indices "
"mask_bounds must in [1,2,4]"));

auto flashmask_maxmin_shape = startend_row_indices.dims();
// TODO(umiswing): refine this block constraint (kBlockN % 32), since some
// of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] =
// (flashmask_maxmin_shape[2] + 31) / 32 * 8;
flashmask_maxmin_shape[2] =
((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4;
flashmask_maxmin_shape[3] = 8;

flashmask_maxmin.set_type(phi::DataType::INT32);
flashmask_maxmin.Resize(flashmask_maxmin_shape);
dev_ctx.template Alloc<int32_t>(&flashmask_maxmin);

lt_start_row_indices =
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {0}, {1});
if (startend_row_indices.dims()[3] == 2) {
if (!is_causal) {
ut_end_row_indices =
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {1}, {2});
} else {
lt_end_row_indices =
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {1}, {2});
}
} else if (startend_row_indices.dims()[3] == 4) {
ut_end_row_indices =
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {3}, {4});
lt_end_row_indices =
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {1}, {2});
ut_start_row_indices =
phi::Slice<int32_t>(dev_ctx, startend_row_indices, {3}, {2}, {3});
}
}

if (is_flashmask) {
if (lt_start_row_indices.initialized())
Expand Down
4 changes: 1 addition & 3 deletions paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1762,9 +1762,7 @@ void FlashMaskV2BaseKernel(
const int params_arch =
phi::dynload::flashmaskv2_fwd_params_get_arch(params_handle);
bool const scheduler_needs_semaphore =
params_arch >= 90 ? (((params_is_causal || params_is_local) &&
(params_num_splits == 1)) ||
is_varlen)
params_arch >= 90 ? true
: ((params_is_causal && !is_varlen) ||
(is_varlen && params_num_splits > 1));
if (scheduler_needs_semaphore || use_dynamic_split) {
Expand Down
Loading