diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index ddd376eff95..fa7743f834d 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -13,7 +13,6 @@ # https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py -import enum import math from typing import Type, Tuple, Callable, Optional, Literal from functools import partial @@ -50,6 +49,7 @@ from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils +from flash_attn.cute.named_barrier import NamedBarrierFwdSm100 from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( @@ -60,22 +60,6 @@ SingleTileVarlenScheduler, ) - -class NamedBarrierFwd(enum.IntEnum): - Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() - TmemPtr = enum.auto() - SoftmaxStatsW0 = enum.auto() - SoftmaxStatsW1 = enum.auto() - SoftmaxStatsW2 = enum.auto() - SoftmaxStatsW3 = enum.auto() - SoftmaxStatsW4 = enum.auto() - SoftmaxStatsW5 = enum.auto() - SoftmaxStatsW6 = enum.auto() - SoftmaxStatsW7 = enum.auto() -# WarpSchedulerWG1 = enum.auto() -# WarpSchedulerWG2 = enum.auto() - - class FlashAttentionForwardSm100: def __init__( @@ -814,7 +798,7 @@ def kernel( storage = smem.allocate(self.shared_storage) tmem_alloc_barrier = pipeline.NamedBarrier( - barrier_id=int(NamedBarrierFwd.TmemPtr), + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), num_threads=cute.arch.WARP_SIZE * len( (self.mma_warp_id, *self.softmax0_warp_ids, @@ -938,7 +922,7 @@ def kernel( ) # Should put the NamedBarrier inside the pipeline class so we'll just have pipeline_sm_stats sm_stats_barrier = pipeline_custom.NamedBarrier( - barrier_id=int(NamedBarrierFwd.SoftmaxStatsW0), num_threads=cute.arch.WARP_SIZE * 2 + barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), num_threads=cute.arch.WARP_SIZE * 2 ) pipeline_o_epi = None if const_expr(not self.use_correction_warps_for_epi): @@ -2574,7 +2558,7 @@ def correction_epilogue( if const_expr(self.use_correction_warps_for_epi): assert(not self.use_tma_O) assert(gmem_tiled_copy_O is not None) - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), + cute.arch.barrier(barrier_id=int(NamedBarrierFwdSm100.Epilogue), number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE) mma_tile_coord_v = thr_mma.thr_idx m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v @@ -2774,12 +2758,12 @@ def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): # warp_group_idx = utils.canonical_warp_group_idx(sync=False) # if warp_group_idx == 0: # cute.arch.barrier_arrive( - # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * 128, + # barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1), number_of_threads=2 * 128, # ) # def warp_scheduler_barrier_sync(self): # cute.arch.barrier( - # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False), + # barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False), # number_of_threads=2 * 128 # ) @@ -2787,7 +2771,7 @@ def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): # cur_wg = utils.canonical_warp_group_idx(sync=False) # next_wg = 1 - cur_wg # cute.arch.barrier_arrive( - # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, + # barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, # ) @cute.jit diff --git a/flash_attn/cute/named_barrier.py b/flash_attn/cute/named_barrier.py index eadac4b926c..09949fe9856 100644 --- a/flash_attn/cute/named_barrier.py +++ b/flash_attn/cute/named_barrier.py @@ -12,6 +12,19 @@ class NamedBarrierFwd(enum.IntEnum): PEmpty = enum.auto() +class NamedBarrierFwdSm100(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() + TmemPtr = enum.auto() + SoftmaxStatsW0 = enum.auto() + SoftmaxStatsW1 = enum.auto() + SoftmaxStatsW2 = enum.auto() + SoftmaxStatsW3 = enum.auto() + SoftmaxStatsW4 = enum.auto() + SoftmaxStatsW5 = enum.auto() + SoftmaxStatsW6 = enum.auto() + SoftmaxStatsW7 = enum.auto() + + class NamedBarrierBwd(enum.IntEnum): Epilogue = enum.auto() WarpSchedulerWG1 = enum.auto()