
Commit

Add template when seqlen_q is equal to seqlen_k with causal mask (Paddle…
AnnaTrainingG authored Nov 1, 2023
1 parent 0fa5933 commit 0598fa2
Showing 2 changed files with 26 additions and 20 deletions.
13 changes: 8 additions & 5 deletions csrc/flash_attn/src/flash_fwd_kernel.h
@@ -118,7 +118,7 @@ inline __device__ void write_softmax_to_gmem(

////////////////////////////////////////////////////////////////////////////////////////////////////

- template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask, typename Params>
+ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask, bool Is_equal_seq_qk, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

using Element = typename Kernel_traits::Element;
@@ -500,8 +500,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
params.unscale_softmax);
tPgMask.data() = tPgMask.data() + (-kBlockN);
}

- softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_attn_mask>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+ if (Is_equal_seq_qk) {
+ softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+ } else {
+ softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_attn_mask>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+ }

Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -609,7 +612,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

////////////////////////////////////////////////////////////////////////////////////////////////////

- template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask, typename Params>
+ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask, bool Is_equal_seq_qk, typename Params>
inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
@@ -625,7 +628,7 @@ inline __device__ void compute_attn(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

- flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax, Is_attn_mask>(params, bidb, bidh, m_block);
+ flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax, Is_attn_mask, Is_equal_seq_qk>(params, bidb, bidh, m_block);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
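For context on the kernel change above: the Check_inf flag on softmax_rescale_o guards against a query row whose scores have all been masked to -infinity, in which case the row max is -infinity and exp(score - max) would produce NaN; the guarded path substitutes a finite value for the max so the row's softmax stays defined. When seqlen_q equals seqlen_k and the causal mask is the only mask, every query row keeps at least its diagonal key unmasked, so the row max is always finite and the check can be compiled out, which is what the new Is_equal_seq_qk branch selects. The following is a minimal host-side sketch of that rescaling idea; rescale_row_sketch and its float-vector arguments are illustrative assumptions, not the kernel's real register-tile layout or the actual softmax_rescale_o signature.

#include <algorithm>
#include <cmath>
#include <vector>

// Simplified stand-in for rescaling one query row's running softmax state as
// successive K blocks are processed. Check_inf mirrors the template flag
// discussed above.
template <bool Check_inf>
void rescale_row_sketch(std::vector<float>& block_scores,  // current K block's scores (-inf where masked)
                        float& running_max,                // running row max over K blocks seen so far
                        float& running_sum,                // running sum of exp(score - max)
                        std::vector<float>& acc_out,       // running, un-normalized output accumulator
                        float scale) {
    const float prev_max = running_max;
    for (float s : block_scores) running_max = std::max(running_max, s);

    float cur_max = running_max;
    // If a mask wiped out the whole row, cur_max is -inf and the exponentials
    // below would be NaN; the Check_inf path clamps the max to a finite value.
    if (Check_inf && std::isinf(cur_max)) cur_max = 0.0f;

    // Rescale what was accumulated under the previous max to the new max.
    const float correction = std::exp((prev_max - cur_max) * scale);
    running_sum *= correction;
    for (float& o : acc_out) o *= correction;

    // Fold in the current block under the new max.
    for (float& s : block_scores) {
        s = std::exp((s - cur_max) * scale);
        running_sum += s;
    }
}

Because Is_equal_seq_qk is a compile-time constant, the if/else added in the kernel is resolved by the compiler, so the equal-length causal case never pays for the infinity check.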
33 changes: 18 additions & 15 deletions csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -13,9 +13,9 @@
#include "flash_fwd_kernel.h"
#include "cuda_utils.h"

- template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask>
+ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask, bool Is_equal_seq_qk>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
- flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax, Is_attn_mask>(params);
+ flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax, Is_attn_mask, Is_equal_seq_qk>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -35,23 +35,26 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
const bool is_attn_mask = params.attn_mask_ptr != nullptr;
+ const bool is_equal_qk = (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask);
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
- // Will only return softmax if dropout, to reduce compilation time.
- auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout, Is_attn_mask && !Is_causal>;
- // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
- if (smem_size >= 48 * 1024) {
- C10_CUDA_CHECK(cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
- }
- int ctas_per_sm;
- cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
- &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
- // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
- kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
+ BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] {
+ // Will only return softmax if dropout, to reduce compilation time.
+ auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout, Is_attn_mask && !Is_causal, Is_equal_seq_qk>;
+ // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
+ if (smem_size >= 48 * 1024) {
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+ }
+ int ctas_per_sm;
+ cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+ &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+ // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+ kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
});
});
});
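The launcher follows the file's existing dispatch pattern: each runtime flag is lifted into a compile-time template parameter through a BOOL_SWITCH, so the new is_equal_qk condition produces a dedicated kernel instantiation in which the Check_inf branch disappears entirely, at the cost of doubling the number of specializations compiled by this launcher. A rough sketch of that dispatch pattern is below; SKETCH_BOOL_SWITCH, launch_sketch, and dispatch_sketch are hypothetical names standing in for the repository's actual BOOL_SWITCH macro and kernel launch, not their real definitions.

#include <cstdio>

// Hypothetical stand-in for a BOOL_SWITCH-style macro: it turns a runtime bool
// into a constexpr bool that the lambda body can use as a template argument,
// because the macro expands the body textually inside each branch.
#define SKETCH_BOOL_SWITCH(COND, CONST_NAME, ...)  \
    [&] {                                          \
        if (COND) {                                \
            constexpr bool CONST_NAME = true;      \
            return __VA_ARGS__();                  \
        } else {                                   \
            constexpr bool CONST_NAME = false;     \
            return __VA_ARGS__();                  \
        }                                          \
    }()

template <bool Is_equal_seq_qk>
void launch_sketch() {
    // In the real launcher this is where the selected flash_fwd_kernel
    // instantiation would be configured and launched.
    std::printf("launch kernel with Is_equal_seq_qk = %d\n", int(Is_equal_seq_qk));
}

void dispatch_sketch(int seqlen_q, int seqlen_k, bool is_causal, bool has_attn_mask) {
    // Same condition as is_equal_qk in run_flash_fwd above.
    const bool is_equal_qk = (seqlen_q == seqlen_k) && is_causal && !has_attn_mask;
    SKETCH_BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] {
        launch_sketch<Is_equal_seq_qk>();
    });
}

Nesting one more BOOL_SWITCH is the same trade-off the surrounding switches already make: more template instantiations and longer compile times in exchange for branch-free specialized kernels.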
