diff --git a/hopper/flash.h b/hopper/flash.h
index 91fb5c8127..28997613dc 100644
--- a/hopper/flash.h
+++ b/hopper/flash.h
@@ -158,6 +158,9 @@ struct Flash_fwd_params : public Qkv_params {
 
     int arch;
     int num_sm;
+
+    // The S extra matrix, (num_heads)
+    void *__restrict__ s_aux_ptr;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp
index 9511929b79..07878dea63 100644
--- a/hopper/flash_api.cpp
+++ b/hopper/flash_api.cpp
@@ -666,7 +666,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         std::optional &scheduler_metadata_,  // (b + 1)
         int num_splits,
         std::optional pack_gqa_,
-        int const sm_margin
+        int const sm_margin,
+        std::optional &s_aux_  // (h)
         ) {
 
     auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -1091,6 +1092,18 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         }
     }
 
+    if(s_aux_.has_value()) {
+        auto s_aux = s_aux_.value();
+        TORCH_CHECK(s_aux.scalar_type() == at::ScalarType::BFloat16,
+                    "We only support bf16 dtype for S extra.");
+        CHECK_DEVICE(s_aux);
+        CHECK_SHAPE(s_aux, num_heads);
+        CHECK_CONTIGUOUS(s_aux);
+        params.s_aux_ptr = s_aux.data_ptr();
+    } else {
+        params.s_aux_ptr = nullptr;
+    }
+
 #ifdef FLASHATTENTION_DISABLE_LOCAL
     TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
 #endif
diff --git a/hopper/flash_api_torch_lib.cpp b/hopper/flash_api_torch_lib.cpp
index a2006f3c4e..ad2c515f9d 100644
--- a/hopper/flash_api_torch_lib.cpp
+++ b/hopper/flash_api_torch_lib.cpp
@@ -51,7 +51,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         std::optional &scheduler_metadata_,  // (b + 1)
         int num_splits,
         std::optional pack_gqa_,
-        int const sm_margin
+        int const sm_margin,
+        std::optional &s_aux_
         );
 
 // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
@@ -118,7 +119,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
            " Tensor? scheduler_metadata,"
            " int num_splits,"
            " bool? pack_gqa,"
-           " int sm_margin) -> Tensor[]");
+           " int sm_margin,"
+           " Tensor? s_aux) -> Tensor[]");
     ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 
     ops.def("get_scheduler_metadata("
diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py
index 9e8d6908ef..d3150fbb67 100644
--- a/hopper/flash_attn_interface.py
+++ b/hopper/flash_attn_interface.py
@@ -48,7 +48,8 @@ def _flash_attn_forward(
         scheduler_metadata=None,
         num_splits=1,
         pack_gqa=None,
-        sm_margin=0):
+        sm_margin=0,
+        s_aux=None):
     q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
     v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
     cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
@@ -94,6 +95,7 @@ def _flash_attn_forward(
         num_splits,
         pack_gqa,
         sm_margin,
+        s_aux
     )
     return out, softmax_lse, *rest
 
@@ -233,7 +235,7 @@ def backward(ctx, dout, *args):
             ctx.causal,
             ctx.window_size,
             ctx.softcap,
-            ctx.deterministic,
+            ctx.deterministic,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
         return dqkv, None, None, None, None, None, None, None, None, None, None
@@ -257,6 +259,7 @@ def forward(
         pack_gqa=None,
         deterministic=False,
         sm_margin=0,
+        s_aux=None,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -281,6 +284,7 @@ def forward(
             num_splits=num_splits,
             pack_gqa=pack_gqa,
             sm_margin=sm_margin,
+            s_aux=s_aux,
         )
         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
         ctx.save_for_backward(q, k, v, out, softmax_lse)
@@ -319,7 +323,7 @@ def backward(ctx, dout, *args):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
 
 
 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -346,6 +350,7 @@ def forward(
         pack_gqa=None,
         deterministic=False,
         sm_margin=0,
+        s_aux=None,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -374,6 +379,7 @@ def forward(
             num_splits=num_splits,
             pack_gqa=pack_gqa,
             sm_margin=sm_margin,
+            s_aux=s_aux,
         )
         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
         ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
@@ -417,7 +423,7 @@ def backward(ctx, dout, *args):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_qkvpacked_func(
@@ -490,6 +496,7 @@ def flash_attn_func(
     pack_gqa=None,
     deterministic=False,
     sm_margin=0,
+    s_aux=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -550,6 +557,7 @@ def flash_attn_func(
         pack_gqa,
         deterministic,
         sm_margin,
+        s_aux,
     )
 
 
@@ -573,6 +581,7 @@ def flash_attn_varlen_func(
     pack_gqa=None,
     deterministic=False,
     sm_margin=0,
+    s_aux=None,
 ):
     return FlashAttnVarlenFunc.apply(
         q,
@@ -594,6 +603,7 @@ def flash_attn_varlen_func(
         pack_gqa,
         deterministic,
         sm_margin,
+        s_aux,
     )
 
 
@@ -631,6 +641,7 @@ def flash_attn_with_kvcache(
     pack_gqa=None,  # Can be tuned for speed
     sm_margin=0,  # Can be tuned if some SMs are used for communication
     return_softmax_lse=False,
+    s_aux=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -757,6 +768,7 @@ def flash_attn_with_kvcache(
         num_splits=num_splits,
         pack_gqa=pack_gqa,
         sm_margin=sm_margin,
+        s_aux=s_aux,
     )
     # return (out, softmax_lse) if return_softmax_lse else out
     return (out, softmax_lse, *rest) if return_softmax_lse else out
diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h
index 47b3817cd2..242da9bf8a 100644
--- a/hopper/flash_fwd_kernel_sm90.h
+++ b/hopper/flash_fwd_kernel_sm90.h
@@ -52,6 +52,8 @@ class FlashAttnFwdSm90 {
     static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
     using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
 
+    using SmemLayoutSAux = typename CollectiveMainloop::SmemLayoutSAux;
+
     // Mainloop derived types
     using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;
     using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;
@@ -295,6 +297,14 @@ class FlashAttnFwdSm90 {
         CollectiveMainloop mainloop;
         CollectiveEpilogue epilogue;
 
+        const int num_heads = get<2>(params.mainloop.shape_Q);
+        Tensor gS_aux = make_tensor(make_gmem_ptr(params.mainloop.ptr_S_aux), make_shape(num_heads));
+        Tensor sS_aux = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_s_aux.data()), SmemLayoutSAux{});
+
+        if(params.mainloop.ptr_S_aux && threadIdx.x < num_heads) {
+            sS_aux(threadIdx.x) = gS_aux(threadIdx.x);
+        }
+
         // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
         if constexpr (size(ClusterShape{}) > 1) {
             cute::cluster_arrive_relaxed();
diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h
index 2556321913..2c9363300a 100644
--- a/hopper/flash_fwd_launch_template.h
+++ b/hopper/flash_fwd_launch_template.h
@@ -35,6 +35,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
     static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
     using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
+    using ElementS = cutlass::bfloat16_t;
 
     // Can't use structured binding since it's not compatible with constexpr
     static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg);
@@ -52,8 +53,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     using ClusterShape = cute::Shape, _1, _1>;
     using CollectiveMainloop = std::conditional_t<
         Arch >= 90,
-        flash::CollectiveMainloopFwdSm90,
-        flash::CollectiveMainloopFwdSm80
+        flash::CollectiveMainloopFwdSm90,
+        flash::CollectiveMainloopFwdSm80
     >;
     using CollectiveEpilogue = flash::CollectiveEpilogueFwd;
 
@@ -127,7 +128,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         params.kv_batch_idx,
         params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
         params.seqused_q, params.seqused_k,
-        params.leftpad_k, params.seqlens_rotary
+        params.leftpad_k, params.seqlens_rotary,
+        static_cast(params.s_aux_ptr)
     };
     typename CollectiveEpilogue::Arguments epilogue_args {
         static_cast(params.o_ptr),
diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp
index 905be872dd..4ce024f346 100644
--- a/hopper/mainloop_fwd_sm80.hpp
+++ b/hopper/mainloop_fwd_sm80.hpp
@@ -25,7 +25,7 @@ using namespace cute;
 
 template 
-          bool PackGQA_, bool Split_>
+          bool PackGQA_, bool Split_, class ElementSAux_>
 struct CollectiveMainloopFwdSm80 {
 
     static constexpr int kStages = Stages;
@@ -34,6 +34,7 @@ struct CollectiveMainloopFwdSm80 {
     using TileShape_MNK_PV = Shape<decltype(get<0>(TileShape_MNK{})), Int<kHeadDimV>, decltype(get<1>(TileShape_MNK{}))>;
     using Element = Element_;
     using ElementAccum = ElementAccum_;
+    using ElementSAux = ElementSAux_;
     using ArchTag = ArchTag_;
     static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;;
     static constexpr bool Is_causal = Is_causal_;
@@ -213,6 +214,7 @@ struct CollectiveMainloopFwdSm80 {
         int const* const seqused_k = nullptr;
         int const* const leftpad_k = nullptr;
         int const* const seqlens_rotary = nullptr;
+        ElementSAux const* const ptr_S_aux = nullptr;
     };
 
     // Device side kernel params
@@ -258,6 +260,7 @@ struct CollectiveMainloopFwdSm80 {
         int const* const seqused_k = nullptr;
         int const* const leftpad_k = nullptr;
         int const* const seqlens_rotary = nullptr;
+        ElementSAux const* const ptr_S_aux = nullptr;
     };
 
     static Params
@@ -297,7 +300,8 @@ struct CollectiveMainloopFwdSm80 {
                 !Split ? 1 : args.num_splits,
                 args.kv_batch_idx,
                 args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
-                args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary};
+                args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary,
+                args.ptr_S_aux};
     }
 
     template 
diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
index 68988862e5..0bdd419153 100644
--- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
+++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
@@ -30,7 +30,7 @@ using namespace cute;
 
 template 
-          bool MmaPV_is_RS, bool IntraWGOverlap, bool PackGQA_, bool Split_, bool V_colmajor_>
+          bool MmaPV_is_RS, bool IntraWGOverlap, bool PackGQA_, bool Split_, bool V_colmajor_, class ElementSAux_>
 struct CollectiveMainloopFwdSm90 {
 
     static constexpr int kStages = Stages;
@@ -40,6 +40,7 @@ struct CollectiveMainloopFwdSm90 {
     using TileShape_MNK_QV = Shape<decltype(get<0>(TileShape_MNK{})), decltype(get<1>(TileShape_MNK{})), Int<kHeadDimV>>;
     using Element = Element_;
     using ElementAccum = ElementAccum_;
+    using ElementSAux = ElementSAux_;
     using ArchTag = ArchTag_;
     static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;;
     static constexpr bool Is_causal = Is_causal_;
@@ -174,6 +175,9 @@ struct CollectiveMainloopFwdSm90 {
 
     using SmemCopyAtomP = Copy_Atom;
 
+    // Hardcoded to be at most 64 query heads
+    using SmemLayoutSAux = Layout>;
+
     // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major.
     // For FP16/BF16 we don't do any transposing.
     static_assert(!Transpose_V || (kHeadDimV % 32 == 0 && kBlockN % 32 == 0));
@@ -310,6 +314,7 @@ struct CollectiveMainloopFwdSm90 {
         cute::array_aligned, SmemAlignmentQ> smem_q;
         cute::array_aligned, SmemAlignmentK> smem_k;
         SmemQv_t smem_qv;
+        cute::array_aligned, 128> smem_s_aux;
     };
 
     struct TensorStorageWithPNoTranspose : cute::aligned_struct {
@@ -318,6 +323,7 @@ struct CollectiveMainloopFwdSm90 {
         cute::array_aligned, SmemAlignmentK> smem_k;
         SmemQv_t smem_qv;
         SmemP_t smem_p;
+        cute::array_aligned, 128> smem_s_aux;
     };
     struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct {
         cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v;
@@ -326,6 +332,7 @@ struct CollectiveMainloopFwdSm90 {
         SmemQv_t smem_qv;
         SmemP_t smem_p;
         SmemScale_t smem_scale;
+        cute::array_aligned, 128> smem_s_aux;
     };
 
     using TensorStorageNoTranspose = std::conditional_t<
@@ -344,6 +351,7 @@ struct CollectiveMainloopFwdSm90 {
         cute::array_aligned, SmemAlignmentK> smem_k;
         SmemQv_t smem_qv;
         SmemScale_t smem_scale;
+        cute::array_aligned, 128> smem_s_aux;
     };
 
     using TensorStorage = std::conditional_t;
@@ -396,6 +404,7 @@ struct CollectiveMainloopFwdSm90 {
         int const* const seqused_k = nullptr;
         int const* const leftpad_k = nullptr;
         int const* const seqlens_rotary = nullptr;
+        ElementSAux const* const ptr_S_aux = nullptr;
     };
 
     // Device side kernel params
@@ -451,7 +460,8 @@ struct CollectiveMainloopFwdSm90 {
         int const* const seqused_q = nullptr;
         int const* const seqused_k = nullptr;
         int const* const leftpad_k = nullptr;
-        int const *const seqlens_rotary = nullptr;
+        int const* const seqlens_rotary = nullptr;
+        ElementSAux const* const ptr_S_aux = nullptr;
     };
 
     static Params
@@ -560,7 +570,8 @@ struct CollectiveMainloopFwdSm90 {
                 !Split ? 1 : args.num_splits,
                 args.kv_batch_idx,
                 args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
-                args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary};
+                args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary,
+                args.ptr_S_aux};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -1084,6 +1095,39 @@
             }
         };
 
+        using TensorT = typename Softmax::TensorT;
+        using LayoutT = typename TensorT::layout_type;
+        auto finalize_dispatch = [&](TensorT& scores_scale, float const v_descale) {
+            if (params.ptr_S_aux && (!Split || (split_idx & 0x0000FFFF) == 0)) {
+                Tensor sS_aux = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_s_aux.data()), SmemLayoutSAux{});
+                Tensor tSrS_aux = make_tensor_like(scores_scale);
+                static_assert(is_static::value);
+                static_assert(size(tSrS_aux) == size(LayoutT{}));
+                if constexpr(!PackGQA) {
+                    #pragma unroll
+                    for(int mi = 0; mi < size(tSrS_aux); ++mi) {
+                        tSrS_aux(mi) = static_cast(sS_aux(bidh));
+                    }
+                } else {
+                    Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
+                    auto thread_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx);
+                    Tensor tScS = thread_mma_qk.partition_C(cS);
+                    Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol(tScS.layout()));
+                    static_assert(size<0>(tScS_rowcol) == size(tSrS_aux));
+                    int const qhead_per_khead = params.qhead_per_khead_divmod.divisor;
+                    #pragma unroll
+                    for(int mi = 0; mi < size(tSrS_aux); ++mi) {
+                        int row = m_block * kBlockM + get<0>(tScS_rowcol(mi, _0{}));
+                        int bidh_mi = (row % qhead_per_khead) + bidh_kv * qhead_per_khead;
+                        tSrS_aux(mi) = static_cast(sS_aux(bidh_mi));
+                    }
+                }
+                cute::copy(softmax.finalize_aux(tSrS_aux, v_descale), scores_scale);
+            } else {
+                cute::copy(softmax.finalize(v_descale), scores_scale);
+            }
+        };
+
        auto &barrier_Q = shared_storage.pipelines.barrier_Q;
        if constexpr (!AppendKV) {
            barrier_Q.wait(work_idx % 2);
@@ -1234,7 +1278,8 @@
             if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); }
             flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO);
             float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)];
-            cute::copy(softmax.finalize(v_descale), scores_scale);
+            // cute::copy(softmax.finalize(v_descale), scores_scale);
+            finalize_dispatch(scores_scale, v_descale);
             if constexpr (LargeHeadDimV) {
                 cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/);
                 store_scales(scores_scale, smem_pipe_read.index());
@@ -1334,7 +1379,10 @@
         // Tell producers that smem_q is ready
         cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/);
         float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)];
-        Tensor scores_scale = softmax.finalize(v_descale);
+        // Tensor scores_scale = softmax.finalize(v_descale);
+        Tensor scores_scale = make_tensor_like(softmax.row_max);
+        finalize_dispatch(scores_scale, v_descale);
+
         if constexpr (LargeHeadDimV) {
             cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/);
             store_scales(scores_scale, smem_pipe_read.index());
diff --git a/hopper/softmax.h b/hopper/softmax.h
index 8fcdb6bd07..e167b2e495 100644
--- a/hopper/softmax.h
+++ b/hopper/softmax.h
@@ -153,6 +153,27 @@ struct Softmax {
         return scores_scale;
     };
 
+    __forceinline__ __device__ TensorT finalize_aux(TensorT const& tSrSAux, float const final_scale=1.f) {
+        SumOp sum_op;
+        quad_allreduce_(row_sum, row_sum, sum_op);
+        TensorT scores_scale;
+        #pragma unroll
+        for (int mi = 0; mi < size(row_sum); ++mi) {
+            if (row_max(mi) == -INFINITY) { row_max(mi) = 0.f; }
+            const float max_scaled = row_max(mi) * softmax_scale_log2 - Max_offset;
+            float sum = row_sum(mi) + exp2f(float(M_LOG2E) * tSrSAux(mi) - max_scaled);
+            float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum;
+            scores_scale(mi) = inv_sum * final_scale;
+            // For FP8, we might have scaled the output of exp by 2**8 so we need to divide sum by that amount.
+            if constexpr (Max_offset != 0) {
+                static constexpr float sum_scale = 1.f / float(1 << Max_offset);
+                sum *= sum_scale;
+            }
+            row_sum(mi) = (sum == 0.f || sum != sum) ? -INFINITY : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
+        }
+        return scores_scale;
+    };
+
     template 
     __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
         // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py
index ba21c49d4d..4fc95d56b4 100644
--- a/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm_flash_attn/flash_attn_interface.py
@@ -145,6 +145,7 @@ def flash_attn_varlen_func(
     num_splits: int = 0,
     # Version selector
     fa_version: int = DEFAULT_FA_VERSION,
+    s_aux=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -249,6 +250,7 @@ def flash_attn_varlen_func(
             softcap,
             return_softmax_lse and dropout_p > 0,
             None,
+            s_aux=s_aux,
         )
     elif fa_version == 3:
         assert alibi_slopes is None, "Alibi is not supported in FA3"
@@ -276,6 +278,7 @@ def flash_attn_varlen_func(
             num_splits,
             None,  # pack_gqa
             0,  # sm_margin
+            s_aux  # s_aux
         )
     else:
         raise ValueError(f"Unsupported FA version: {fa_version}")
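
Usage sketch (not part of the diff): one plausible call site for the new `s_aux` argument through the patched hopper interface, assuming a build of this branch. `flash_attn_with_kvcache` is shown because the hunk above confirms it returns `out` unless `return_softmax_lse=True`; the bf16 dtype and `(num_heads,)` shape follow the checks added to `mha_fwd`. The import path and tensor sizes are illustrative assumptions.

```python
import torch
from flash_attn_interface import flash_attn_with_kvcache  # hopper interface patched above

b, s_q, s_k, h, d = 2, 1, 4096, 16, 128
q = torch.randn(b, s_q, h, d, dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(b, s_k, h, d, dtype=torch.bfloat16, device="cuda")
v_cache = torch.randn(b, s_k, h, d, dtype=torch.bfloat16, device="cuda")
cache_seqlens = torch.full((b,), s_k, dtype=torch.int32, device="cuda")
# One auxiliary logit per query head; mha_fwd requires bf16 and shape (num_heads,).
s_aux = torch.randn(h, dtype=torch.bfloat16, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    s_aux=s_aux,  # threaded through _flash_attn_forward down to the kernel
)
```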
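For a numerical cross-check (also not part of the diff): `finalize_aux` adds `exp2f(M_LOG2E * s_aux - max_scaled)` to the accumulated `row_sum`, which amounts to giving every query of head `h` one extra softmax column whose logit is `s_aux[h]` (not multiplied by `softmax_scale`) and which carries no value row. A hypothetical unfused PyTorch reference under those assumptions, with causal/local masking omitted for brevity:

```python
import torch

def attn_with_s_aux_ref(q, k, v, s_aux, softmax_scale=None):
    """Unfused reference: s_aux[h] joins the softmax denominator of head h but contributes no value."""
    # q: (b, s_q, h, d), k/v: (b, s_k, h, d), s_aux: (h,)
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    out_dtype = q.dtype
    q, k, v = (t.transpose(1, 2).float() for t in (q, k, v))        # (b, h, s, d)
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * softmax_scale   # (b, h, s_q, s_k)
    # Append the auxiliary logit unscaled, mirroring exp2f(M_LOG2E * s_aux - max_scaled) in finalize_aux.
    aux = s_aux.float().view(1, -1, 1, 1).expand(*scores.shape[:-1], 1)
    probs = torch.softmax(torch.cat([scores, aux], dim=-1), dim=-1)[..., :-1]  # drop the aux column
    return torch.einsum("bhqk,bhkd->bhqd", probs, v).transpose(1, 2).to(out_dtype)
```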