@@ -2087,8 +2087,32 @@ void FlashMaskV2BaseKernel(
20872087 // TODO(umiswing): refine this block constraint (kBlockN % 32), since some
20882088 // values of kBlockN are not divisible by 32. Commented-out formula:
20892089 // flashmask_maxmin_shape[2] = (flashmask_maxmin_shape[2] + 31) / 32 * 8;
2090- flashmask_maxmin_shape[2 ] =
2091- ((flashmask_maxmin_shape[2 ] + 31 ) / 32 + 3 ) / 4 * 4 ;
2090+
2091+ int device_id = dev_ctx.GetPlace ().GetDeviceId ();
2092+ auto dprops = paddle::platform::GetDeviceProperties (device_id);
2093+ const bool is_sm90 = dprops.major == 9 && dprops.minor == 0 ;
2094+
2095+ if (is_sm90) {
2096+ // Convert seqlen_k to nblock_seqlen. kBlockN = 64 is used here as a
2097+ // conservative estimate (reduces the allocation size).
2098+ flashmask_maxmin_shape[2 ] =
2099+ ((flashmask_maxmin_shape[2 ] + 63 ) / 64 + 3 ) / 4 * 4 ;
2100+ // keep this value the same as in the FlashMaskV3 fwd main loop
2101+ static constexpr int flashmask_buffer_length = 16 * 1024 ;
2102+ // estimate the upper bound of the possible chunk size
2103+ static constexpr int chunk_padded_length =
2104+ ((flashmask_buffer_length + 63 ) / 64 + 31 ) & 0xffffffe0 ;
2105+ static constexpr int chunk_valid_length =
2106+ ((flashmask_buffer_length + 63 ) / 64 + 3 ) & 0xfffffffc ;
2107+ const int num_chunk =
2108+ (flashmask_maxmin_shape[2 ] + chunk_valid_length - 1 ) /
2109+ chunk_valid_length;
2110+ flashmask_maxmin_shape[2 ] = num_chunk * chunk_padded_length;
2111+ } else {
2112+ // seqlen_k to nblock_seqlen
2113+ flashmask_maxmin_shape[2 ] =
2114+ ((flashmask_maxmin_shape[2 ] + 31 ) / 32 + 3 ) / 4 * 4 ;
2115+ }
20922116 flashmask_maxmin_shape[3 ] = 8 ;
20932117
20942118 flashmask_maxmin.set_type (phi::DataType::INT32);
0 commit comments