Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2087,8 +2087,32 @@ void FlashMaskV2BaseKernel(
// TODO(umiswing): refine this block constraint (kBlockN % 32), since some
// kBlockN values are not divisible by 32:
//   flashmask_maxmin_shape[2] = (flashmask_maxmin_shape[2] + 31) / 32 * 8;
flashmask_maxmin_shape[2] =
((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4;

int device_id = dev_ctx.GetPlace().GetDeviceId();
auto dprops = paddle::platform::GetDeviceProperties(device_id);
const bool is_sm90 = dprops.major == 9 && dprops.minor == 0;

if (is_sm90) {
// Convert seqlen_k to nblock_seqlen. kBlockN = 64 is used here as a
// conservative estimate (it reduces the allocation size).
flashmask_maxmin_shape[2] =
((flashmask_maxmin_shape[2] + 63) / 64 + 3) / 4 * 4;
// Keep this value the same as the one in the FlashMaskV3 fwd main loop.
static constexpr int flashmask_buffer_length = 16 * 1024;
// estimate the upper bound of the possible chunk size
static constexpr int chunk_padded_length =
((flashmask_buffer_length + 63) / 64 + 31) & 0xffffffe0;
static constexpr int chunk_valid_length =
((flashmask_buffer_length + 63) / 64 + 3) & 0xfffffffc;
const int num_chunk =
(flashmask_maxmin_shape[2] + chunk_valid_length - 1) /
chunk_valid_length;
flashmask_maxmin_shape[2] = num_chunk * chunk_padded_length;
} else {
// seqlen_k to nblock_seqlen
flashmask_maxmin_shape[2] =
((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4;
}
flashmask_maxmin_shape[3] = 8;

flashmask_maxmin.set_type(phi::DataType::INT32);
Expand Down
Loading