Skip to content

Commit 882adc7

Browse files
[PHI] Flash Attention V3 128B aligned chunking load/store (#76003) (#76071)
* [PHI] Flash Attention V3 128B aligned chunking load/store * Update flashattn version Co-authored-by: Qianyue He <[email protected]>
1 parent d844a80 commit 882adc7

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2087,8 +2087,32 @@ void FlashMaskV2BaseKernel(
20872087
// TODO(umiswing): refine this block constraint (kBlockN % 32), since some
20882088
// of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] =
20892089
// (flashmask_maxmin_shape[2] + 31) / 32 * 8;
2090-
flashmask_maxmin_shape[2] =
2091-
((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4;
2090+
2091+
int device_id = dev_ctx.GetPlace().GetDeviceId();
2092+
auto dprops = paddle::platform::GetDeviceProperties(device_id);
2093+
const bool is_sm90 = dprops.major == 9 && dprops.minor == 0;
2094+
2095+
if (is_sm90) {
2096+
// seqlen_k to nblock_seqlen, here we use kBlockN = 64
2097+
// as a conservative estimation (reduce allocation size)
2098+
flashmask_maxmin_shape[2] =
2099+
((flashmask_maxmin_shape[2] + 63) / 64 + 3) / 4 * 4;
2100+
// make sure this is the same with FlashMaskV3 fwd main loop
2101+
static constexpr int flashmask_buffer_length = 16 * 1024;
2102+
// estimate the upper bound of the possible chunk size
2103+
static constexpr int chunk_padded_length =
2104+
((flashmask_buffer_length + 63) / 64 + 31) & 0xffffffe0;
2105+
static constexpr int chunk_valid_length =
2106+
((flashmask_buffer_length + 63) / 64 + 3) & 0xfffffffc;
2107+
const int num_chunk =
2108+
(flashmask_maxmin_shape[2] + chunk_valid_length - 1) /
2109+
chunk_valid_length;
2110+
flashmask_maxmin_shape[2] = num_chunk * chunk_padded_length;
2111+
} else {
2112+
// seqlen_k to nblock_seqlen
2113+
flashmask_maxmin_shape[2] =
2114+
((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4;
2115+
}
20922116
flashmask_maxmin_shape[3] = 8;
20932117

20942118
flashmask_maxmin.set_type(phi::DataType::INT32);

0 commit comments

Comments
 (0)