diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 6e0d8b768b7..1f72fb33aef 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -43,6 +43,8 @@ struct CollectiveMainloopFwdSm90 { using ElementSAux = ElementSAux_; using ArchTag = ArchTag_; static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v;; + static constexpr int kFP8TwoLevelInterval = FLASHATTENTION_FP8_TWO_LEVEL_INTERVAL; + static constexpr bool UseIntervalTwoLevel = Is_FP8 && (kFP8TwoLevelInterval >= 1); static constexpr bool Is_causal = Is_causal_; static constexpr bool Is_local = Is_local_; static constexpr bool Has_softcap = Has_softcap_; @@ -1240,6 +1242,25 @@ struct CollectiveMainloopFwdSm90 { clear(tOrO); // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + // For interval-based two-level accumulation + using AccumFragment_t = std::conditional_t< + UseIntervalTwoLevel, + decltype(cute::make_fragment_like(tOrO)), + cute::array>; + using AccumScale_t = std::conditional_t< + UseIntervalTwoLevel, + decltype(cute::make_tensor_like(scores_scale)), + cute::array>; + [[maybe_unused]] AccumFragment_t tOrO_accum{}; + [[maybe_unused]] AccumScale_t accum_scale{}; + if constexpr (UseIntervalTwoLevel) { + tOrO_accum = cute::make_fragment_like(tOrO); + accum_scale = cute::make_tensor_like(scores_scale); + cute::clear(tOrO_accum); + cute::fill(accum_scale, 1.0f); + } + [[maybe_unused]] int tile_idx = 0; + // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block. auto fwd_step = [&](int const n_block, auto mask_fn, auto check_inf_type) { static constexpr bool Check_inf = decltype(check_inf_type)::value; @@ -1249,7 +1270,12 @@ struct CollectiveMainloopFwdSm90 { if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); } warp_scheduler_barrier_sync(); flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr (RescaleOBeforeGemm) { + softmax.rescale_o(tOrO, scores_scale); + if constexpr (UseIntervalTwoLevel) { + softmax.update_accum_scale(accum_scale, scores_scale); + } + } if constexpr(!HasQv) { if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } } @@ -1276,8 +1302,20 @@ struct CollectiveMainloopFwdSm90 { convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } - if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr (!RescaleOBeforeGemm) { + softmax.rescale_o(tOrO, scores_scale); + if constexpr (UseIntervalTwoLevel) { + softmax.update_accum_scale(accum_scale, scores_scale); + } + } if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } + if constexpr (UseIntervalTwoLevel) { + if (((tile_idx + 1) % kFP8TwoLevelInterval) == 0) { + softmax.merge_accum_with_scale(tOrO, tOrO_accum, accum_scale); + cute::clear(tOrO); + } + ++tile_idx; + } }; if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking @@ -1323,8 +1361,19 @@ struct CollectiveMainloopFwdSm90 { } // Tell producers that smem_q is ready cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr (RescaleOBeforeGemm) { + softmax.rescale_o(tOrO, scores_scale); + if constexpr (UseIntervalTwoLevel) { + softmax.update_accum_scale(accum_scale, scores_scale); + } + } if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } + if constexpr (UseIntervalTwoLevel) { + if (((tile_idx + 1) % kFP8TwoLevelInterval) == 0) { + softmax.merge_accum_with_scale(tOrO, tOrO_accum, accum_scale); + cute::clear(tOrO); + } + } flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; // cute::copy(softmax.finalize(v_descale), scores_scale); @@ -1336,6 +1385,9 @@ struct CollectiveMainloopFwdSm90 { } warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // release V, otherwise producers will hang + if constexpr (UseIntervalTwoLevel) { + softmax.final_merge_accum(tOrO, tOrO_accum, accum_scale); + } softmax.rescale_o(tOrO, scores_scale); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } ++smem_pipe_read; @@ -1344,6 +1396,24 @@ struct CollectiveMainloopFwdSm90 { warp_scheduler_barrier_sync(); + using AccumFragmentNoOverlap_t = std::conditional_t< + UseIntervalTwoLevel, + decltype(cute::make_fragment_like(tOrO)), + cute::array>; + using AccumScaleNoOverlap_t = std::conditional_t< + UseIntervalTwoLevel, + decltype(cute::make_tensor_like(softmax.row_max)), + cute::array>; + [[maybe_unused]] AccumFragmentNoOverlap_t tOrO_accum_no_overlap{}; + [[maybe_unused]] AccumScaleNoOverlap_t accum_scale_no_overlap{}; + if constexpr (UseIntervalTwoLevel) { + tOrO_accum_no_overlap = cute::make_fragment_like(tOrO); + accum_scale_no_overlap = cute::make_tensor_like(softmax.row_max); + cute::clear(tOrO_accum_no_overlap); + cute::fill(accum_scale_no_overlap, 1.0f); + } + [[maybe_unused]] int tile_idx_no_overlap = 0; + auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; static constexpr bool Check_inf = decltype(check_inf_type)::value; @@ -1378,7 +1448,12 @@ struct CollectiveMainloopFwdSm90 { convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } - if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr (!Is_first_iter) { + softmax.rescale_o(tOrO, scores_scale); + if constexpr (UseIntervalTwoLevel) { + softmax.update_accum_scale(accum_scale_no_overlap, scores_scale); + } + } if constexpr (!MmaPV_is_RS && !MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } warp_scheduler_barrier_sync(); @@ -1391,6 +1466,13 @@ struct CollectiveMainloopFwdSm90 { if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // release V + if constexpr (UseIntervalTwoLevel && !Is_first_iter) { + if (((tile_idx_no_overlap + 1) % kFP8TwoLevelInterval) == 0) { + softmax.merge_accum_with_scale(tOrO, tOrO_accum_no_overlap, accum_scale_no_overlap); + cute::clear(tOrO); + } + ++tile_idx_no_overlap; + } }; auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; @@ -1439,6 +1521,9 @@ struct CollectiveMainloopFwdSm90 { store_scales(scores_scale, smem_pipe_read.index()); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); } + if constexpr (UseIntervalTwoLevel) { + softmax.final_merge_accum(tOrO, tOrO_accum_no_overlap, accum_scale_no_overlap); + } softmax.rescale_o(tOrO, scores_scale); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } ++smem_pipe_read; diff --git a/hopper/setup.py b/hopper/setup.py index 887b6339023..4b35445c8d9 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -48,7 +48,6 @@ DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" -DISABLE_FP8_TWO_LEVEL_ACCUMULATION = os.getenv("FLASH_ATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION", "FALSE") == "TRUE" DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" @@ -69,6 +68,8 @@ DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE" PACKGQA_ONLY = os.getenv("FLASH_ATTENTION_PACKGQA_ONLY", "FALSE") == "TRUE" +# Controls two level accumulation interval in fp8, 0 disables entirely. +FP8_TWO_LEVEL_INTERVAL = int(os.getenv("FLASH_ATTENTION_FP8_TWO_LEVEL_INTERVAL", "4")) DISABLE_BACKWARD = True # DISABLE_SPLIT = True @@ -481,7 +482,6 @@ def nvcc_threads_args(): + (["-DFLASHATTENTION_DISABLE_PAGEDKV"] if DISABLE_PAGEDKV else []) + (["-DFLASHATTENTION_DISABLE_SPLIT"] if DISABLE_SPLIT else []) + (["-DFLASHATTENTION_DISABLE_APPENDKV"] if DISABLE_APPENDKV else []) - + (["-DFLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION"] if DISABLE_FP8_TWO_LEVEL_ACCUMULATION else []) + (["-DFLASHATTENTION_DISABLE_LOCAL"] if DISABLE_LOCAL else []) + (["-DFLASHATTENTION_DISABLE_SOFTCAP"] if DISABLE_SOFTCAP else []) + (["-DFLASHATTENTION_DISABLE_PACKGQA"] if DISABLE_PACKGQA else []) @@ -499,6 +499,7 @@ def nvcc_threads_args(): + (["-DFLASHATTENTION_DISABLE_HDIMDIFF64"] if DISABLE_HDIMDIFF64 else []) + (["-DFLASHATTENTION_DISABLE_HDIMDIFF192"] if DISABLE_HDIMDIFF192 else []) + (["-DFLASHATTENTION_PACKGQA_ONLY"] if PACKGQA_ONLY else []) + + [f"-DFLASHATTENTION_FP8_TWO_LEVEL_INTERVAL={FP8_TWO_LEVEL_INTERVAL}"] + (["-DFLASHATTENTION_VARLEN_ONLY"]) ) diff --git a/hopper/softmax.h b/hopper/softmax.h index e167b2e4955..83dd1b8a07d 100644 --- a/hopper/softmax.h +++ b/hopper/softmax.h @@ -186,6 +186,48 @@ struct Softmax { } }; + // Update cumulative accum scale factors (for deferred scaling) + template + __forceinline__ __device__ void update_accum_scale(TensorScale &accum_scale, TensorT const &scores_scale) { + #pragma unroll + for (int mi = 0; mi < kNRows; ++mi) { + accum_scale(mi) *= scores_scale(mi); + } + } + + // Apply deferred scale to accum and merge with current accumulator + template + __forceinline__ __device__ void merge_accum_with_scale(Tensor1 &acc_o, Tensor2 &acc_o_accum, TensorScale &accum_scale) { + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + Tensor accum_rowcol = make_tensor(acc_o_accum.data(), flash::convert_layout_acc_rowcol(acc_o_accum.layout())); + static_assert(CUTE_STATIC_V(size<0>(acc_o_rowcol)) == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float const scale = accum_scale(mi); + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + accum_rowcol(mi, ni) = scale * accum_rowcol(mi, ni) + acc_o_rowcol(mi, ni); + } + accum_scale(mi) = 1.0f; + } + } + + // Apply scale to accum, add to accumulator, result in accumulator + template + __forceinline__ __device__ void final_merge_accum(Tensor1 &acc_o, Tensor2 const &acc_o_accum, TensorScale const &accum_scale) { + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + Tensor accum_rowcol = make_tensor(acc_o_accum.data(), flash::convert_layout_acc_rowcol(acc_o_accum.layout())); + static_assert(CUTE_STATIC_V(size<0>(acc_o_rowcol)) == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float const scale = accum_scale(mi); + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) = scale * accum_rowcol(mi, ni) + acc_o_rowcol(mi, ni); + } + } + }; + }; } // namespace flash diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 2e437559fc5..376f3f7cd53 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -63,7 +63,7 @@ constexpr std::tuple tile_size_fwd_sm90( // Prefill tiles — two-level accumulation needs smaller tiles to reduce // register pressure from the separate fp32 accumulator (tOrO_accum). // Currently just optimized for causal case. -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION +#if FLASHATTENTION_FP8_TWO_LEVEL_INTERVAL >= 1 if (headdim <= 64) { return {192, 128, true, true}; } else if (headdim <= 96) { diff --git a/hopper/utils.h b/hopper/utils.h index efc755f7d7f..3719eab9202 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -235,21 +235,6 @@ auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) { template CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) { -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - static constexpr bool Is_FP8 = cute::is_same_v - || cute::is_same_v; - static constexpr bool Use_Two_Level = Is_FP8 && !zero_init; - - auto tCrC_original = cute::make_fragment_like(tCrC); - if constexpr (Use_Two_Level) { - // Copy original values to backup - #pragma unroll - for (int i = 0; i < cute::size(tCrC); ++i) { - tCrC_original(i) = tCrC(i); - } - cute::clear(tCrC); - } -#endif if constexpr (M_slice >= 0) { static constexpr int MMA_M = decltype(size<1>(tCrC))::value; static_assert(M_slice < MMA_M); @@ -316,15 +301,6 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const } } } - -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - if constexpr (Use_Two_Level) { - #pragma unroll - for (int i = 0; i < cute::size(tCrC); ++i) { - tCrC(i) = tCrC_original(i) + tCrC(i); // Add temp results to original - } - } -#endif } ////////////////////////////////////////////////////////////////////////////////////////////////////