Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 89 additions & 4 deletions hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ struct CollectiveMainloopFwdSm90 {
using ElementSAux = ElementSAux_;
using ArchTag = ArchTag_;
static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;;
static constexpr int kFP8TwoLevelInterval = FLASHATTENTION_FP8_TWO_LEVEL_INTERVAL;
static constexpr bool UseIntervalTwoLevel = Is_FP8 && (kFP8TwoLevelInterval >= 1);
static constexpr bool Is_causal = Is_causal_;
static constexpr bool Is_local = Is_local_;
static constexpr bool Has_softcap = Has_softcap_;
Expand Down Expand Up @@ -1240,6 +1242,25 @@ struct CollectiveMainloopFwdSm90 {
clear(tOrO);
// tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero;

// For interval-based two-level accumulation
using AccumFragment_t = std::conditional_t<
UseIntervalTwoLevel,
decltype(cute::make_fragment_like(tOrO)),
cute::array<float, 0>>;
using AccumScale_t = std::conditional_t<
UseIntervalTwoLevel,
decltype(cute::make_tensor_like(scores_scale)),
cute::array<float, 0>>;
[[maybe_unused]] AccumFragment_t tOrO_accum{};
[[maybe_unused]] AccumScale_t accum_scale{};
if constexpr (UseIntervalTwoLevel) {
tOrO_accum = cute::make_fragment_like(tOrO);
accum_scale = cute::make_tensor_like(scores_scale);
cute::clear(tOrO_accum);
cute::fill(accum_scale, 1.0f);
}
[[maybe_unused]] int tile_idx = 0;

// Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block.
auto fwd_step = [&](int const n_block, auto mask_fn, auto check_inf_type) {
static constexpr bool Check_inf = decltype(check_inf_type)::value;
Expand All @@ -1249,7 +1270,12 @@ struct CollectiveMainloopFwdSm90 {
if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); }
warp_scheduler_barrier_sync();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }
if constexpr (RescaleOBeforeGemm) {
softmax.rescale_o(tOrO, scores_scale);
if constexpr (UseIntervalTwoLevel) {
softmax.update_accum_scale(accum_scale, scores_scale);
}
}
if constexpr(!HasQv) {
if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); }
}
Expand All @@ -1276,8 +1302,20 @@ struct CollectiveMainloopFwdSm90 {
convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP);
if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }
if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); }
if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }
if constexpr (!RescaleOBeforeGemm) {
softmax.rescale_o(tOrO, scores_scale);
if constexpr (UseIntervalTwoLevel) {
softmax.update_accum_scale(accum_scale, scores_scale);
}
}
if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); }
if constexpr (UseIntervalTwoLevel) {
if (((tile_idx + 1) % kFP8TwoLevelInterval) == 0) {
softmax.merge_accum_with_scale(tOrO, tOrO_accum, accum_scale);
cute::clear(tOrO);
}
++tile_idx;
}
};

if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking
Expand Down Expand Up @@ -1323,8 +1361,19 @@ struct CollectiveMainloopFwdSm90 {
}
// Tell producers that smem_q is ready
cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);
if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }
if constexpr (RescaleOBeforeGemm) {
softmax.rescale_o(tOrO, scores_scale);
if constexpr (UseIntervalTwoLevel) {
softmax.update_accum_scale(accum_scale, scores_scale);
}
}
if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); }
if constexpr (UseIntervalTwoLevel) {
if (((tile_idx + 1) % kFP8TwoLevelInterval) == 0) {
softmax.merge_accum_with_scale(tOrO, tOrO_accum, accum_scale);
cute::clear(tOrO);
}
}
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_pv, cute::conditional_return<MmaPV_is_RS>(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO);
float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)];
// cute::copy(softmax.finalize(v_descale), scores_scale);
Expand All @@ -1336,6 +1385,9 @@ struct CollectiveMainloopFwdSm90 {
}
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read); // release V, otherwise producers will hang
if constexpr (UseIntervalTwoLevel) {
softmax.final_merge_accum(tOrO, tOrO_accum, accum_scale);
}
softmax.rescale_o(tOrO, scores_scale);
if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); }
++smem_pipe_read;
Expand All @@ -1344,6 +1396,24 @@ struct CollectiveMainloopFwdSm90 {

warp_scheduler_barrier_sync();

using AccumFragmentNoOverlap_t = std::conditional_t<
UseIntervalTwoLevel,
decltype(cute::make_fragment_like(tOrO)),
cute::array<float, 0>>;
using AccumScaleNoOverlap_t = std::conditional_t<
UseIntervalTwoLevel,
decltype(cute::make_tensor_like(softmax.row_max)),
cute::array<float, 0>>;
[[maybe_unused]] AccumFragmentNoOverlap_t tOrO_accum_no_overlap{};
[[maybe_unused]] AccumScaleNoOverlap_t accum_scale_no_overlap{};
if constexpr (UseIntervalTwoLevel) {
tOrO_accum_no_overlap = cute::make_fragment_like(tOrO);
accum_scale_no_overlap = cute::make_tensor_like(softmax.row_max);
cute::clear(tOrO_accum_no_overlap);
cute::fill(accum_scale_no_overlap, 1.0f);
}
[[maybe_unused]] int tile_idx_no_overlap = 0;

auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) {
static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value;
static constexpr bool Check_inf = decltype(check_inf_type)::value;
Expand Down Expand Up @@ -1378,7 +1448,12 @@ struct CollectiveMainloopFwdSm90 {
convert_type_out(tOrP_acc, tOrP);
if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }
if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); }
if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); }
if constexpr (!Is_first_iter) {
softmax.rescale_o(tOrO, scores_scale);
if constexpr (UseIntervalTwoLevel) {
softmax.update_accum_scale(accum_scale_no_overlap, scores_scale);
}
}
if constexpr (!MmaPV_is_RS && !MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); }
if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); }
warp_scheduler_barrier_sync();
Expand All @@ -1391,6 +1466,13 @@ struct CollectiveMainloopFwdSm90 {
if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); }
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read); // release V
if constexpr (UseIntervalTwoLevel && !Is_first_iter) {
if (((tile_idx_no_overlap + 1) % kFP8TwoLevelInterval) == 0) {
softmax.merge_accum_with_scale(tOrO, tOrO_accum_no_overlap, accum_scale_no_overlap);
cute::clear(tOrO);
}
++tile_idx_no_overlap;
}
};

auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
Expand Down Expand Up @@ -1439,6 +1521,9 @@ struct CollectiveMainloopFwdSm90 {
store_scales(scores_scale, smem_pipe_read.index());
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PFull) /*id*/);
}
if constexpr (UseIntervalTwoLevel) {
softmax.final_merge_accum(tOrO, tOrO_accum_no_overlap, accum_scale_no_overlap);
}
softmax.rescale_o(tOrO, scores_scale);
if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); }
++smem_pipe_read;
Expand Down
5 changes: 3 additions & 2 deletions hopper/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE"
DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE"
DISABLE_FP8_TWO_LEVEL_ACCUMULATION = os.getenv("FLASH_ATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION", "FALSE") == "TRUE"
DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE"
DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE"
DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE"
Expand All @@ -69,6 +68,8 @@
DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE"

PACKGQA_ONLY = os.getenv("FLASH_ATTENTION_PACKGQA_ONLY", "FALSE") == "TRUE"
# Controls two level accumulation interval in fp8, 0 disables entirely.
FP8_TWO_LEVEL_INTERVAL = int(os.getenv("FLASH_ATTENTION_FP8_TWO_LEVEL_INTERVAL", "4"))

DISABLE_BACKWARD = True
# DISABLE_SPLIT = True
Expand Down Expand Up @@ -481,7 +482,6 @@ def nvcc_threads_args():
+ (["-DFLASHATTENTION_DISABLE_PAGEDKV"] if DISABLE_PAGEDKV else [])
+ (["-DFLASHATTENTION_DISABLE_SPLIT"] if DISABLE_SPLIT else [])
+ (["-DFLASHATTENTION_DISABLE_APPENDKV"] if DISABLE_APPENDKV else [])
+ (["-DFLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION"] if DISABLE_FP8_TWO_LEVEL_ACCUMULATION else [])
+ (["-DFLASHATTENTION_DISABLE_LOCAL"] if DISABLE_LOCAL else [])
+ (["-DFLASHATTENTION_DISABLE_SOFTCAP"] if DISABLE_SOFTCAP else [])
+ (["-DFLASHATTENTION_DISABLE_PACKGQA"] if DISABLE_PACKGQA else [])
Expand All @@ -499,6 +499,7 @@ def nvcc_threads_args():
+ (["-DFLASHATTENTION_DISABLE_HDIMDIFF64"] if DISABLE_HDIMDIFF64 else [])
+ (["-DFLASHATTENTION_DISABLE_HDIMDIFF192"] if DISABLE_HDIMDIFF192 else [])
+ (["-DFLASHATTENTION_PACKGQA_ONLY"] if PACKGQA_ONLY else [])
+ [f"-DFLASHATTENTION_FP8_TWO_LEVEL_INTERVAL={FP8_TWO_LEVEL_INTERVAL}"]
+ (["-DFLASHATTENTION_VARLEN_ONLY"])
)

Expand Down
42 changes: 42 additions & 0 deletions hopper/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,48 @@ struct Softmax {
}
};

// Fold the latest per-row softmax rescale factors into the running
// deferred-scale vector. The scaling of the second-level accumulator is
// thereby postponed until merge time instead of being applied every tile.
template<typename TensorScale>
__forceinline__ __device__ void update_accum_scale(TensorScale &accum_scale, TensorT const &scores_scale) {
    #pragma unroll
    for (int row = 0; row < kNRows; ++row) {
        float const folded = accum_scale(row) * scores_scale(row);
        accum_scale(row) = folded;
    }
}

// Apply the deferred per-row scale to the second-level accumulator and fold
// the current (first-level) accumulator into it:
//   acc_o_accum = accum_scale * acc_o_accum + acc_o
// accum_scale is reset to 1.0f so the next interval starts fresh.
// NOTE(review): acc_o is only read here, so it is taken by const& for
// consistency with final_merge_accum, which does the same for its
// read-only operand.
template<typename Tensor1, typename Tensor2, typename TensorScale>
__forceinline__ __device__ void merge_accum_with_scale(Tensor1 const &acc_o, Tensor2 &acc_o_accum, TensorScale &accum_scale) {
    // Reinterpret both MMA accumulator layouts as (row, col) so the scale can
    // be applied per row.
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
    Tensor accum_rowcol = make_tensor(acc_o_accum.data(), flash::convert_layout_acc_rowcol(acc_o_accum.layout()));
    static_assert(CUTE_STATIC_V(size<0>(acc_o_rowcol)) == kNRows);
    #pragma unroll
    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
        float const scale = accum_scale(mi);
        #pragma unroll
        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
            accum_rowcol(mi, ni) = scale * accum_rowcol(mi, ni) + acc_o_rowcol(mi, ni);
        }
        accum_scale(mi) = 1.0f;  // interval consumed; restart deferred scaling
    }
}

// Final merge at the end of the mainloop: apply the remaining deferred scale
// to the second-level accumulator and add it into the live accumulator,
// leaving the result in acc_o:
//   acc_o = accum_scale * acc_o_accum + acc_o
// Unlike merge_accum_with_scale, the second-level side is read-only and
// accum_scale is not reset (no further tiles follow).
template<typename Tensor1, typename Tensor2, typename TensorScale>
__forceinline__ __device__ void final_merge_accum(Tensor1 &acc_o, Tensor2 const &acc_o_accum, TensorScale const &accum_scale) {
    // Reinterpret both MMA accumulator layouts as (row, col) so the scale can
    // be applied per row.
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
    Tensor accum_rowcol = make_tensor(acc_o_accum.data(), flash::convert_layout_acc_rowcol(acc_o_accum.layout()));
    static_assert(CUTE_STATIC_V(size<0>(acc_o_rowcol)) == kNRows);
    #pragma unroll
    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
        float const scale = accum_scale(mi);
        #pragma unroll
        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
            acc_o_rowcol(mi, ni) = scale * accum_rowcol(mi, ni) + acc_o_rowcol(mi, ni);
        }
    }
}

};

} // namespace flash
2 changes: 1 addition & 1 deletion hopper/tile_size.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ constexpr std::tuple<int, int, bool, bool> tile_size_fwd_sm90(
// Prefill tiles — two-level accumulation needs smaller tiles to reduce
// register pressure from the separate fp32 accumulator (tOrO_accum).
// Currently just optimized for causal case.
#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION
#if FLASHATTENTION_FP8_TWO_LEVEL_INTERVAL >= 1
if (headdim <= 64) {
return {192, 128, true, true};
} else if (headdim <= 96) {
Expand Down
24 changes: 0 additions & 24 deletions hopper/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,21 +235,6 @@ auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) {
template <bool zero_init=false, int wg_wait=0, bool SwapAB=false, int M_slice=-1,
typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) {
#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION
static constexpr bool Is_FP8 = cute::is_same_v<typename Tensor0::value_type, cutlass::float_e4m3_t>
|| cute::is_same_v<typename Tensor0::value_type, cutlass::float_e5m2_t>;
static constexpr bool Use_Two_Level = Is_FP8 && !zero_init;

auto tCrC_original = cute::make_fragment_like(tCrC);
if constexpr (Use_Two_Level) {
// Copy original values to backup
#pragma unroll
for (int i = 0; i < cute::size(tCrC); ++i) {
tCrC_original(i) = tCrC(i);
}
cute::clear(tCrC);
}
#endif
if constexpr (M_slice >= 0) {
static constexpr int MMA_M = decltype(size<1>(tCrC))::value;
static_assert(M_slice < MMA_M);
Expand Down Expand Up @@ -316,15 +301,6 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const
}
}
}

#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION
if constexpr (Use_Two_Level) {
#pragma unroll
for (int i = 0; i < cute::size(tCrC); ++i) {
tCrC(i) = tCrC_original(i) + tCrC(i); // Add temp results to original
}
}
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down