diff --git a/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu index 2c7ed18d50ebf0..f2629f872d3d85 100644 --- a/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu @@ -1019,29 +1019,124 @@ void FlashMaskV2GradBaseKernel( bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; bool const is_flashmask = startend_row_indices_.is_initialized(); + DenseTensor startend_row_indices; + if (is_flashmask) startend_row_indices = startend_row_indices_.get(); + bool const has_softcap = softcap > 0.0; - int const kBlockM_sm90 = - head_size_rounded <= 64 - ? (is_flashmask && !is_causal) - ? 64 - : (is_causal && softcap || is_flashmask > 0.0 ? 96 : 128) - : (head_size_rounded <= 128 - ? (is_flashmask && !is_causal) - ? 64 - : (is_causal || is_local || is_flashmask || softcap > 0.0 - ? 64 - : 80) - : 64); + // flashmask + DenseTensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices, + ut_start_row_indices, ut_end_row_indices; + if (is_flashmask) { + PADDLE_ENFORCE_EQ( + startend_row_indices.dtype(), + phi::DataType::INT32, + common::errors::InvalidArgument( + "flashmask_attention startend_row_indices must be INT32 type")); + PADDLE_ENFORCE_EQ( + startend_row_indices.dims().size(), + 4, + common::errors::InvalidArgument( + "flashmask_attention receive startend_row_indices with dim " + "[batch_size, num_heads,seq_len, mask_bounds]")); + PADDLE_ENFORCE_EQ(startend_row_indices.dims()[3] == 1 || + startend_row_indices.dims()[3] == 2 || + startend_row_indices.dims()[3] == 4, + true, + common::errors::InvalidArgument( + "flashmask_attention startend_row_indices " + "mask_bounds must in [1,2,4]")); + + auto flashmask_maxmin_shape = startend_row_indices.dims(); + // TODO(umiswing): refine this block constraint (kBlockN % 32), since some + // of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] = + // (flashmask_maxmin_shape[2] + 31) / 32 * 8; + flashmask_maxmin_shape[2] = + ((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4; + flashmask_maxmin_shape[3] = 8; + + flashmask_maxmin.set_type(phi::DataType::INT32); + flashmask_maxmin.Resize(flashmask_maxmin_shape); + dev_ctx.template Alloc(&flashmask_maxmin); + + lt_start_row_indices = + phi::Slice(dev_ctx, startend_row_indices, {3}, {0}, {1}); + if (startend_row_indices.dims()[3] == 2) { + if (!is_causal) { + ut_end_row_indices = + phi::Slice(dev_ctx, startend_row_indices, {3}, {1}, {2}); + } else { + lt_end_row_indices = + phi::Slice(dev_ctx, startend_row_indices, {3}, {1}, {2}); + } + } else if (startend_row_indices.dims()[3] == 4) { + ut_end_row_indices = + phi::Slice(dev_ctx, startend_row_indices, {3}, {3}, {4}); + lt_end_row_indices = + phi::Slice(dev_ctx, startend_row_indices, {3}, {1}, {2}); + ut_start_row_indices = + phi::Slice(dev_ctx, startend_row_indices, {3}, {2}, {3}); + } + } + + const bool has_lt_start = lt_start_row_indices.initialized(); + const bool has_lt_end = lt_end_row_indices.initialized(); + const bool has_ut_start = ut_start_row_indices.initialized(); + const bool has_ut_end = ut_end_row_indices.initialized(); + + // umiswing: The tile dispatch for flashmask is now different from fa3. + // Replacing the original ternary operator with lambda makes the code + // easier to reason about and less error-prone. + const auto [kBlockM_sm90, kBlockN_sm90] = [&]() -> std::pair { + if (head_size_rounded <= 64) { + if (is_flashmask && !is_causal) { + return {64, 96}; + } else if (is_causal && has_softcap || is_flashmask) { + return {96, 128}; + } else { + return {128, 128}; + } + } else if (head_size_rounded <= 128) { + // umiswing: by now, we reuse template instantiation of head dim 128 for + // head dim in range (64, 128], and therefore no separate dispatch for + // head dim in range (64, 96] + if (is_causal || is_local || has_softcap) { + return {64, 128}; + } else { + if ((seqlen_q >= 1024 || seqlen_k >= 1024) && + !(has_lt_end && has_ut_start)) { + return {64, 128}; + } else { + return {64, 64}; + } + } + } else if (head_size_rounded <= 192) { + // umiswing: head dim > 128 is not supported now + PADDLE_THROW( + common::errors::Unimplemented("head dim is rounded to %d, which is " + "not supported in FlashMask V3 now.", + head_size_rounded)); + return {0, 0}; + } else if (head_size_rounded <= 256) { + // umiswing: head dim > 128 is not supported now + PADDLE_THROW( + common::errors::Unimplemented("head dim is rounded to %d, which is " + "not supported in FlashMask V3 now.", + head_size_rounded)); + return {0, 0}; + } else { + PADDLE_THROW( + common::errors::Unimplemented("head dim is rounded to %d, which is " + "not supported in FlashMask V3 now.", + head_size_rounded)); + return {0, 0}; + } + }(); int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64; int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32; int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80); - int const kBlockN_sm90 = - head_size_rounded <= 64 && (is_flashmask && !is_causal) ? 96 - : head_size_rounded <= 128 ? (is_flashmask && !is_causal) ? 64 : 128 - : (head_size_rounded <= 192 ? 96 : 80); int const kBlockN_sm80 = head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 80 : 64); int const kBlockN_sm86 = @@ -1308,62 +1403,6 @@ void FlashMaskV2GradBaseKernel( dynload::flashmaskv2_bwd_params_set_dv_semaphore(params_handle, dv_semaphore.data()); } - // flashmask - DenseTensor startend_row_indices; - if (is_flashmask) startend_row_indices = startend_row_indices_.get(); - DenseTensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices, - ut_start_row_indices, ut_end_row_indices; - if (is_flashmask) { - PADDLE_ENFORCE_EQ( - startend_row_indices.dtype(), - phi::DataType::INT32, - common::errors::InvalidArgument( - "flashmask_attention startend_row_indices must be INT32 type")); - PADDLE_ENFORCE_EQ( - startend_row_indices.dims().size(), - 4, - common::errors::InvalidArgument( - "flashmask_attention receive startend_row_indices with dim " - "[batch_size, num_heads,seq_len, mask_bounds]")); - PADDLE_ENFORCE_EQ(startend_row_indices.dims()[3] == 1 || - startend_row_indices.dims()[3] == 2 || - startend_row_indices.dims()[3] == 4, - true, - common::errors::InvalidArgument( - "flashmask_attention startend_row_indices " - "mask_bounds must in [1,2,4]")); - - auto flashmask_maxmin_shape = startend_row_indices.dims(); - // TODO(umiswing): refine this block constraint (kBlockN % 32), since some - // of kBlockN is not divisible by 32 flashmask_maxmin_shape[2] = - // (flashmask_maxmin_shape[2] + 31) / 32 * 8; - flashmask_maxmin_shape[2] = - ((flashmask_maxmin_shape[2] + 31) / 32 + 3) / 4 * 4; - flashmask_maxmin_shape[3] = 8; - - flashmask_maxmin.set_type(phi::DataType::INT32); - flashmask_maxmin.Resize(flashmask_maxmin_shape); - dev_ctx.template Alloc(&flashmask_maxmin); - - lt_start_row_indices = - phi::Slice(dev_ctx, startend_row_indices, {3}, {0}, {1}); - if (startend_row_indices.dims()[3] == 2) { - if (!is_causal) { - ut_end_row_indices = - phi::Slice(dev_ctx, startend_row_indices, {3}, {1}, {2}); - } else { - lt_end_row_indices = - phi::Slice(dev_ctx, startend_row_indices, {3}, {1}, {2}); - } - } else if (startend_row_indices.dims()[3] == 4) { - ut_end_row_indices = - phi::Slice(dev_ctx, startend_row_indices, {3}, {3}, {4}); - lt_end_row_indices = - phi::Slice(dev_ctx, startend_row_indices, {3}, {1}, {2}); - ut_start_row_indices = - phi::Slice(dev_ctx, startend_row_indices, {3}, {2}, {3}); - } - } if (is_flashmask) { if (lt_start_row_indices.initialized()) diff --git a/paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu index a2bbc66d5abf2a..1f90117c545e77 100644 --- a/paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu @@ -1762,9 +1762,7 @@ void FlashMaskV2BaseKernel( const int params_arch = phi::dynload::flashmaskv2_fwd_params_get_arch(params_handle); bool const scheduler_needs_semaphore = - params_arch >= 90 ? (((params_is_causal || params_is_local) && - (params_num_splits == 1)) || - is_varlen) + params_arch >= 90 ? true : ((params_is_causal && !is_varlen) || (is_varlen && params_num_splits > 1)); if (scheduler_needs_semaphore || use_dynamic_split) { diff --git a/third_party/flashattn b/third_party/flashattn index 649d81c12f895e..bb1563a1403f78 160000 --- a/third_party/flashattn +++ b/third_party/flashattn @@ -1 +1 @@ -Subproject commit 649d81c12f895e38742dfd3cfa2e7c5db3f882e3 +Subproject commit bb1563a1403f78c519edaac9fc49142a04635f21