diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index eeaa0e3e740..be13e70f892 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -58,12 +58,16 @@ def get_n_block_min_max( def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]: m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m) m_block_min = 0 - if const_expr(self.is_causal): - m_block_min = max( - m_block_min, - (n_block * self.tile_n + seqlen_info.seqlen_q - seqlen_info.seqlen_k) - // self.tile_m, - ) + if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)): + n_idx_min = n_block * self.tile_n + m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k + m_idx_right = m_idx if const_expr(self.is_causal) else m_idx - self.window_size_right + m_block_min = max(m_block_min, m_idx_right // self.tile_m) + if const_expr(self.is_local and self.window_size_left is not None): + n_idx_max = (n_block + 1) * self.tile_n + m_idx = n_idx_max + seqlen_info.seqlen_q - seqlen_info.seqlen_k + m_idx_left = m_idx + self.window_size_left + m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m)) return m_block_min, m_block_max @cute.jit diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 78506b77dba..00c8cbf66d7 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -82,7 +82,7 @@ def __init__( self.cluster_shape_mn = (cluster_size, 1) self.is_persistent = is_persistent self.is_causal = is_causal - self.is_local = False + self.is_local = is_local self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False self.use_tma_store = True @@ -384,11 +384,19 @@ def __call__( *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1], ) - (mdQaccum,) = [ + ( + mdQaccum, + mdK, + mdV, + ) = [ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None - for t in (mdQaccum,) + for t in ( + mdQaccum, + mdK, + mdV, + ) ] layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) @@ -555,7 +563,8 @@ def __call__( TileScheduler = SingleTileLPTBwdScheduler else: TileScheduler = SingleTileScheduler - self.spt = self.is_causal and self.deterministic + # reads n_blocks right-to-left + self.spt = (self.is_causal or self.is_local) and self.deterministic tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), cute.size(mQ.shape[2]), # num_heads = num_query_heads @@ -657,6 +666,12 @@ class SharedStorage: LOG2_E = math.log2(math.e) softmax_scale_log2 = softmax_scale * LOG2_E + + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + self.kernel( tma_tensor_Q, tma_tensor_K, @@ -701,6 +716,8 @@ class SharedStorage: tiled_copy_r2s_dKV, softmax_scale, softmax_scale_log2, + window_size_left, + window_size_right, tile_sched_params, ).launch( grid=grid_dim, @@ -757,6 +774,8 @@ def kernel( tiled_copy_r2s_dKV: cute.TiledCopy, softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], tile_sched_params: ParamsBase, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -975,8 +994,8 @@ def kernel( self.is_causal, self.is_local, False, # is_split_kv - None, - None, + window_size_left, + window_size_right, 
qhead_per_kvhead_packgqa=1, ) SeqlenInfoCls = partial( @@ -990,12 +1009,13 @@ def kernel( ) TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) - # TODO: support local AttentionMaskCls = partial( AttentionMask, self.tile_m, self.tile_n, swap_AB=True, + window_size_left=window_size_left, + window_size_right=window_size_right, ) # EMPTY @@ -1228,8 +1248,8 @@ def load( tdPgV = thr_mma_dP.partition_A(gV) gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0)) tSgQ = thr_mma_S.partition_B(gQ) - gLSE = cute.local_tile(mLSE_cur, (self.tile_n,), (None,)) - gdPsum = cute.local_tile(mPsum_cur, (self.tile_n,), (None,)) + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,)) + gdPsum = cute.local_tile(mPsum_cur, (self.tile_m,), (None,)) gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) tdPgdO = thr_mma_dV.partition_B(gdO) @@ -1272,80 +1292,83 @@ def load( # gdPsum = cute.logical_divide(gdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] # copy_stats = partial(cute.copy, copy_atom_stats, mcast_mask=q_do_mcast_mask) - # First iteration: load K together w Q & LSE, then V together w dO & dPsum - if const_expr(should_load_Q): - # K & Q - pipeline_Q.producer_acquire( - producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] - ) - load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) - load_Q(m_block_min, producer_state=producer_state_Q_LSE) - pipeline_Q.producer_commit(producer_state_Q_LSE) - # LSE - pipeline_LSE.producer_acquire(producer_state_Q_LSE) - with cute.arch.elect_one(): - copy_stats( - gLSE[None, m_block_min], - sLSE[None, producer_state_Q_LSE.index], - mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), - ) - producer_state_Q_LSE.advance() - if const_expr(should_load_dO): - # V & dO - pipeline_dO.producer_acquire( - producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"] - ) - load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum)) - load_dO(m_block_min, producer_state=producer_state_dO_dPsum) - pipeline_dO.producer_commit(producer_state_dO_dPsum) - # dPsum - pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) - with cute.arch.elect_one(): - copy_stats( - gdPsum[None, m_block_min], - sdPsum[None, producer_state_dO_dPsum.index], - mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), - ) - producer_state_dO_dPsum.advance() - - for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + if const_expr(not self.is_local) or m_block_min < m_block_max: + # First iteration: load K together w Q & LSE, then V together w dO & dPsum if const_expr(should_load_Q): - # Q - pipeline_Q.producer_acquire(producer_state_Q_LSE) - load_Q(m_block, producer_state=producer_state_Q_LSE) + # K & Q + pipeline_Q.producer_acquire( + producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) + load_Q(m_block_min, producer_state=producer_state_Q_LSE) pipeline_Q.producer_commit(producer_state_Q_LSE) # LSE pipeline_LSE.producer_acquire(producer_state_Q_LSE) with cute.arch.elect_one(): copy_stats( - gLSE[None, m_block], + gLSE[None, m_block_min], sLSE[None, producer_state_Q_LSE.index], mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), ) producer_state_Q_LSE.advance() if const_expr(should_load_dO): - # dO - pipeline_dO.producer_acquire(producer_state_dO_dPsum) - load_dO(m_block, producer_state=producer_state_dO_dPsum) + # V & 
dO + pipeline_dO.producer_acquire( + producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"] + ) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum)) + load_dO(m_block_min, producer_state=producer_state_dO_dPsum) pipeline_dO.producer_commit(producer_state_dO_dPsum) # dPsum pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) with cute.arch.elect_one(): copy_stats( - gdPsum[None, m_block], + gdPsum[None, m_block_min], sdPsum[None, producer_state_dO_dPsum.index], mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), ) producer_state_dO_dPsum.advance() - if const_expr(should_load_Q): - pipeline_Q.producer_tail( - producer_state_Q_LSE.clone() - ) # will hang if we don't clone - pipeline_LSE.producer_tail(producer_state_Q_LSE) - if const_expr(should_load_dO): - pipeline_dO.producer_tail(producer_state_dO_dPsum.clone()) - pipeline_dPsum.producer_tail(producer_state_dO_dPsum) + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + if const_expr(should_load_Q): + # Q + pipeline_Q.producer_acquire(producer_state_Q_LSE) + load_Q(m_block, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) + # LSE + pipeline_LSE.producer_acquire(producer_state_Q_LSE) + with cute.arch.elect_one(): + copy_stats( + gLSE[None, m_block], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), + ) + producer_state_Q_LSE.advance() + if const_expr(should_load_dO): + # dO + pipeline_dO.producer_acquire(producer_state_dO_dPsum) + load_dO(m_block, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) + # dPsum + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) + with cute.arch.elect_one(): + copy_stats( + gdPsum[None, m_block], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier( + producer_state_dO_dPsum + ), + ) + producer_state_dO_dPsum.advance() + + if const_expr(should_load_Q): + pipeline_Q.producer_tail( + producer_state_Q_LSE.clone() + ) # will hang if we don't clone + pipeline_LSE.producer_tail(producer_state_Q_LSE) + if const_expr(should_load_dO): + pipeline_dO.producer_tail(producer_state_dO_dPsum.clone()) + pipeline_dPsum.producer_tail(producer_state_dO_dPsum) tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -1474,130 +1497,129 @@ def mma( m_block_min, m_block_max = block_info.get_m_block_min_max( seqlen, n_block // self.cluster_shape_mnk[0] ) - - accumulate_dK = False - # ----------------------------------------------------------- - ###### Prologue - # ----------------------------------------------------------- - # 1. S = Q0 @ K.T - # 2. dP = V @ dO.T - # 3. 
dV = P @ dO - - # 1) S = Q0 @ K.T - handle_Q = pipeline_Q_consumer.wait_and_advance() - pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) - mma_qk_fn(B_idx=handle_Q.index) - # Don't release Q yet - pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) - - # 2) dP = V @ dO.T - pipeline_dO.consumer_wait(consumer_state_dO) - pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) - # dQ uses the same tmem as dP - pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) - mma_dov_fn(B_idx=consumer_state_dO.index) - # Don't release dO yet - pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) - - producer_phase_acc ^= 1 - # 3) dV = P.T @ dO - # wait for P to be ready, which uses the same tmem as S - pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) - mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True) - pipeline_dO.consumer_release(consumer_state_dO) - consumer_state_dO.advance() - # ----------------------------------------------------------- - ###### MAIN LOOP - # ----------------------------------------------------------- - # 1. S = K @ Q.T - # 2. dQ = dS @ K - # 3. dK = dS.T @ Q - # 4. dP = V @ dO.T - # 5. dV = P.T @ dO - - for _ in cutlass.range(m_block_min + 1, m_block_max, unroll=1): - # 1) S = K @ Q_i - handle_Q_next = pipeline_Q_consumer.wait_and_advance() - # Don't need to wait for S, as P must have been ready ealier, i.e., S is ready - mma_qk_fn(B_idx=handle_Q_next.index) + if const_expr(not self.is_local) or m_block_min < m_block_max: + accumulate_dK = False + # ----------------------------------------------------------- + ###### Prologue + # ----------------------------------------------------------- + # 1. S = Q0 @ K.T + # 2. dP = V @ dO.T + # 3. dV = P @ dO + # 1) S = Q0 @ K.T + handle_Q = pipeline_Q_consumer.wait_and_advance() + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_qk_fn(B_idx=handle_Q.index) + # Don't release Q yet pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) - # 2-3) - # Do dK = dS.T @ Q, then dQ = dS @ K if dS in tmem for first mma - # Otherwise, reverse order - pipeline_dS.consumer_wait(consumer_state_dS) - - if const_expr(self.use_smem_dS_for_mma_dK): - mma_dsk_fn() - pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) - mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) - accumulate_dK = True - handle_Q.release() - else: - mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) - accumulate_dK = True - handle_Q.release() - mma_dsk_fn() - pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) - - # dP uses the same tmem as dQ - # However, if dS is ready, then dP must have been ready, - # so we don't need this wait before mma_dsk_fn() - # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) - - pipeline_dS.consumer_release(consumer_state_dS) - consumer_state_dS.advance() - - # 4) dP = V @ dO.T + # 2) dP = V @ dO.T pipeline_dO.consumer_wait(consumer_state_dO) + pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) # dQ uses the same tmem as dP pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) mma_dov_fn(B_idx=consumer_state_dO.index) + # Don't release dO yet pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) producer_phase_acc ^= 1 - # 5) dV += P @ dO + # 3) dV = P.T @ dO # wait for P to be ready, which uses the same tmem as S pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) - mma_pdo_fn(B_idx=consumer_state_dO.index, 
zero_init=False)
+                mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True)
             pipeline_dO.consumer_release(consumer_state_dO)
             consumer_state_dO.advance()
+            # -----------------------------------------------------------
+            ###### MAIN LOOP
+            # -----------------------------------------------------------
+            # 1. S = K @ Q.T
+            # 2. dQ = dS @ K
+            # 3. dK = dS.T @ Q
+            # 4. dP = V @ dO.T
+            # 5. dV = P.T @ dO
+
+            for _ in cutlass.range(m_block_min + 1, m_block_max, unroll=1):
+                # 1) S = K @ Q_i
+                handle_Q_next = pipeline_Q_consumer.wait_and_advance()
+                # Don't need to wait for S, as P must have been ready earlier, i.e., S is ready
+                mma_qk_fn(B_idx=handle_Q_next.index)
+                pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)
+
+                # 2-3)
+                # Do dK = dS.T @ Q, then dQ = dS @ K if dS in tmem for first mma
+                # Otherwise, reverse order
+                pipeline_dS.consumer_wait(consumer_state_dS)
+
+                if const_expr(self.use_smem_dS_for_mma_dK):
+                    mma_dsk_fn()
+                    pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
+                    mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
+                    accumulate_dK = True
+                    handle_Q.release()
+                else:
+                    mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
+                    accumulate_dK = True
+                    handle_Q.release()
+                    mma_dsk_fn()
+                    pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
+
+                # dP uses the same tmem as dQ
+                # However, if dS is ready, then dP must have been ready,
+                # so we don't need this wait before mma_dsk_fn()
+                # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)
+
+                pipeline_dS.consumer_release(consumer_state_dS)
+                consumer_state_dS.advance()
+
+                # 4) dP = V @ dO.T
+                pipeline_dO.consumer_wait(consumer_state_dO)
+                # dQ uses the same tmem as dP
+                pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc)
+                mma_dov_fn(B_idx=consumer_state_dO.index)
+                pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)
+
+                producer_phase_acc ^= 1
+                # 5) dV += P @ dO
+                # wait for P to be ready, which uses the same tmem as S
+                pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
+                mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False)
+                pipeline_dO.consumer_release(consumer_state_dO)
+                consumer_state_dO.advance()
+
+                handle_Q = handle_Q_next
-            handle_Q = handle_Q_next
-
-        pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)
-
-        # signal to the epilogue that dV is ready
-        # pipeline_dKV.producer_acquire(producer_state_dKV)
-        pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV)
-        # pipeline_dKV.producer_commit(producer_state_dKV)
-        pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group)
-        # producer_state_dKV.advance()
-        # pipeline_dKV.producer_acquire(producer_state_dKV)
-        pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV)
-
-        # -----------------------------------------------------------
-        ###### Remaining 2
-        # -----------------------------------------------------------
-        # 1) dK += dS.T @ Q
-        pipeline_dS.consumer_wait(consumer_state_dS)
-        mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
-        # signal to the epilogue that dK is ready
-        # pipeline_dKV.producer_commit(producer_state_dKV)
-        pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group)
-        # producer_state_dKV.advance()
-        producer_phase_dKV ^= 1
-
-        # 2) dQ = dS @ K
-        # dS is done, so dP must have been ready, we don't need to wait
-        mma_dsk_fn()
-        pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
-        # Wait until dQ is done before releasing Q, since K and Q0 uses the same mbarrier
-        handle_Q.release()
-        pipeline_dS.consumer_release(consumer_state_dS)
-        consumer_state_dS.advance()
-
-        producer_phase_acc ^= 1
+            pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)
+
+            # signal to the epilogue that dV is ready
+            # pipeline_dKV.producer_acquire(producer_state_dKV)
+            pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV)
+            # pipeline_dKV.producer_commit(producer_state_dKV)
+            pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group)
+            # producer_state_dKV.advance()
+            # pipeline_dKV.producer_acquire(producer_state_dKV)
+            pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV)
+
+            # -----------------------------------------------------------
+            ###### Remaining 2
+            # -----------------------------------------------------------
+            # 1) dK += dS.T @ Q
+            pipeline_dS.consumer_wait(consumer_state_dS)
+            mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
+            # signal to the epilogue that dK is ready
+            # pipeline_dKV.producer_commit(producer_state_dKV)
+            pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group)
+            # producer_state_dKV.advance()
+            producer_phase_dKV ^= 1
+
+            # 2) dQ = dS @ K
+            # dS is done, so dP must have been ready, we don't need to wait
+            mma_dsk_fn()
+            pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
+            # Wait until dQ is done before releasing Q, since K and Q0 use the same mbarrier
+            handle_Q.release()
+            pipeline_dS.consumer_release(consumer_state_dS)
+            consumer_state_dS.advance()
+
+            producer_phase_acc ^= 1
         tile_scheduler.advance_to_next_work()
         work_tile = tile_scheduler.get_current_work()
@@ -1717,7 +1739,7 @@ def compute_loop(
         # 0: [256...384]
         # 1: [128...256]
-        tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width  # (128, 64)
+        tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width  # 64 for tile_n = 128
         # tStS has shape ((128, 128), 1, 1), tStP has shape ((128, 64), 1, 1)
         # tP overlap with tS
         tStP = cute.composition(tStS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))
@@ -1943,61 +1965,96 @@ def compute_loop(
             pipeline_dS.producer_commit(producer_state_dS)
             producer_state_dS.advance()
-        if const_expr(not self.use_tma_store):
-            consumer_state_dKV = self.epilogue_dKV(
-                dp_idx,
-                warp_idx,
-                batch_idx,
-                head_idx,
-                n_block,
-                thr_mma_dV,
-                thr_mma_dK,
-                tdVtdV,
-                tdKtdK,
-                mdV,
-                mdK,
-                pipeline_dKV,
-                consumer_state_dKV,
-                softmax_scale,
-            )
-        else:
-            thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(dp_idx)
-            #### STORE dV
-            consumer_state_dKV = self.epilogue_dK_or_dV_tma(
-                dp_idx,
-                batch_idx,
-                head_idx,
-                n_block,
-                thr_mma_dV,
-                tdVtdV,
-                mdV_tma_tensor,
-                sdV,
-                tma_atom_dV,
-                thr_copy_r2s_dKV,
-                pipeline_dKV,
-                consumer_state_dKV,
-                None,  # Don't scale
-                int(NamedBarrierBwdSm100.EpilogueWG1),  # barrier_id
-                mdV_semaphore,
-            )
-            #### STORE dK
-            consumer_state_dKV = self.epilogue_dK_or_dV_tma(
-                dp_idx,
-                batch_idx,
-                head_idx,
-                n_block,
-                thr_mma_dK,
-                tdKtdK,
-                mdK_tma_tensor,
-                sdK,
-                tma_atom_dK,
-                thr_copy_r2s_dKV,
-                pipeline_dKV,
-                consumer_state_dKV,
-                softmax_scale if const_expr(self.qhead_per_kvhead == 1) else None,
-                int(NamedBarrierBwdSm100.EpilogueWG1),  # barrier_id
-                mdK_semaphore,
-            )
+        # Epilogue
+        if const_expr(not self.is_local) or m_block_min < m_block_max:
+            if const_expr(not self.use_tma_store):
+                consumer_state_dKV = self.epilogue_dKV(
+                    dp_idx,
+                    warp_idx,
+                    batch_idx,
+                    head_idx,
+                    n_block,
+                    thr_mma_dV,
+                    thr_mma_dK,
+                    tdVtdV,
+                    tdKtdK,
+                    
mdV, + mdK, + pipeline_dKV, + consumer_state_dKV, + softmax_scale, + ) + else: + thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(dp_idx) + #### STORE dV + consumer_state_dKV = self.epilogue_dK_or_dV_tma( + dp_idx, + batch_idx, + head_idx, + n_block, + thr_mma_dV, + tdVtdV, + mdV_tma_tensor, + sdV, + tma_atom_dV, + thr_copy_r2s_dKV, + pipeline_dKV, + consumer_state_dKV, + None, # Don't scale + int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id + mdV_semaphore, + ) + #### STORE dK + consumer_state_dKV = self.epilogue_dK_or_dV_tma( + dp_idx, + batch_idx, + head_idx, + n_block, + thr_mma_dK, + tdKtdK, + mdK_tma_tensor, + sdK, + tma_atom_dK, + thr_copy_r2s_dKV, + pipeline_dKV, + consumer_state_dKV, + softmax_scale if const_expr(self.qhead_per_kvhead == 1) else None, + int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id + mdK_semaphore, + ) + if const_expr(self.qhead_per_kvhead == 1 and self.is_local): + if m_block_min >= m_block_max: + # if tidx == 0: + # cute.printf("m_block_min = {}, m_block_max = {}", m_block_min, m_block_max) + # like other epis, currently assumes hdim == hdimv + gmem_tiled_copy_zero_dKV = copy_utils.tiled_copy_2d( + self.dk_dtype, + self.tile_hdim, + 128, # num_threads + ) + gmem_thr_copy_zero_dKV = gmem_tiled_copy_zero_dKV.get_slice(dp_idx) + mdV_cur = mdV[None, None, head_idx, batch_idx] + mdK_cur = mdK[None, None, head_idx, batch_idx] + gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) + gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) + tdKgdK = gmem_thr_copy_zero_dKV.partition_D(gdK) + tdVgdV = gmem_thr_copy_zero_dKV.partition_D(gdV) + assert tdKgdK.shape[2] == 1 + assert tdVgdV.shape[2] == 1 + cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) + tdKVcdKV = gmem_thr_copy_zero_dKV.partition_D(cdKV) + zero = cute.make_fragment_like(tdKgdK[None, 0, 0]) + zero.fill(0.0) + if tidx < 128: + for i in cutlass.range_constexpr(tdKgdK.shape[1]): + row_idx = tdKVcdKV[0, i, 0][0] + if row_idx < seqlen.seqlen_k - self.tile_n * n_block: + cute.copy(gmem_tiled_copy_zero_dKV, zero, tdKgdK[None, i, 0]) + else: + for i in cutlass.range_constexpr(tdVgdV.shape[1]): + row_idx = tdKVcdKV[0, i, 0][0] + if row_idx < seqlen.seqlen_k - self.tile_n * n_block: + cute.copy(gmem_tiled_copy_zero_dKV, zero, tdVgdV[None, i, 0]) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -2092,13 +2149,20 @@ def dQacc_reduce( # semaphore acquire if const_expr(self.deterministic and stage == 0): if const_expr(self.spt): - n_block_max_for_m_block = min( - n_block_global_max, - cute.ceil_div( - (m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q, - self.tile_n, - ), - ) + if const_expr( + self.is_causal or block_info.window_size_right is not None + ): + n_idx_right = ( + (m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q + ) + if const_expr(block_info.window_size_right is not None): + n_idx_right += block_info.window_size_right + n_block_max_for_m_block = min( + n_block_global_max, + cute.ceil_div(n_idx_right, self.tile_n), + ) + else: + n_block_max_for_m_block = n_block_global_max lock_value = n_block_max_for_m_block - 1 - n_block else: lock_value = n_block @@ -2144,12 +2208,22 @@ def dQacc_reduce( self.reduce_sync_barrier.arrive_and_wait() barrier.arrive_inc(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, 1) - if is_tma_warp: - cute.arch.cp_async_bulk_wait_group(0, read=read_flag) - self.reduce_sync_barrier.arrive_and_wait() - # final semaphore release - if 
const_expr(self.deterministic and delay_semaphore_release): - barrier.arrive_inc(mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, 0, 1) + if const_expr(not self.is_local) or m_block_min < m_block_max: + if is_tma_warp: + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + self.reduce_sync_barrier.arrive_and_wait() + # final semaphore release + if const_expr(self.deterministic and delay_semaphore_release): + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, 0, 1 + ) + + if const_expr( + self.deterministic and not self.spt and block_info.window_size_left is not None + ): + m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m) + for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1): + barrier.arrive_inc(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, 1) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -2222,7 +2296,7 @@ def epilogue_dKV( dV_vec = tdVrdV_t2r[(None, i, 0, 0)].load() tdVrdV_r2s[(None, i, 0, 0)].store(dV_vec.to(self.dv_dtype)) - gdV = cute.local_tile(mdV_cur, (self.tile_m, self.tile_hdimv), (None, 0)) + gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) gdV_tile = gdV[None, None, n_block] tdVgdV = thr_mma_dV.partition_C(gdV_tile) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 4c3e52f46d5..651e9393135 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -295,6 +295,7 @@ def _flash_attn_fwd( if window_size_left is not None or window_size_right is not None: if window_size_left is None and window_size_right == 0: causal, local = True, False + window_size_right = None else: causal, local = False, True else: @@ -540,6 +541,8 @@ def _flash_attn_bwd( softmax_scale: Optional[float] = None, causal: bool = False, softcap: float = 0.0, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, m_block_size: int = 64, n_block_size: int = 128, num_threads: int = 256, @@ -575,6 +578,7 @@ def _flash_attn_bwd( AtomLayoutNdKV = 2 AtomLayoutMdQ = 1 cluster_size = 1 + assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x" else: m_block_size = 128 n_block_size = 128 @@ -608,6 +612,16 @@ def _flash_attn_bwd( num_head_kv = k.shape[-2] head_dim_v = v.shape[-1] + if causal: + window_size_right = 0 + local = window_size_left is not None or window_size_right is not None + if local: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + window_size_right = None + else: + causal, local = False, True + if cu_seqlens_k is None: assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) @@ -840,6 +854,8 @@ def _flash_attn_bwd( head_dim_v, qhead_per_kvhead, causal, + window_size_left is not None, + window_size_right is not None, softcap != 0.0, m_block_size, n_block_size, @@ -896,6 +912,7 @@ def _flash_attn_bwd( head_dim, head_dim_v, is_causal=causal, + is_local=local, qhead_per_kvhead=qhead_per_kvhead, # tile_m=m_block_size, # tile_n=n_block_size, @@ -921,6 +938,8 @@ def _flash_attn_bwd( cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + window_size_left=window_size_left, + window_size_right=window_size_right, mdQ_semaphore=dQ_semaphore_tensor, mdK_semaphore=dK_semaphore_tensor, mdV_semaphore=dV_semaphore_tensor, @@ -941,6 +960,8 @@ def _flash_attn_bwd( cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + 
window_size_left=window_size_left, + window_size_right=window_size_right, mdQ_semaphore=dQ_semaphore_tensor, mdK_semaphore=dK_semaphore_tensor, mdV_semaphore=dV_semaphore_tensor, @@ -1103,6 +1124,8 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.softcap, + window_size_left=ctx.window_size[0], + window_size_right=ctx.window_size[1], deterministic=ctx.deterministic, ) return dq, dk, dv, *((None,) * 20) # Extra Nones is fine diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index da3ed8fb2d3..430c7d26fc5 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -239,10 +239,10 @@ def apply_mask( ) if const_expr(self.window_size_right is not None): col_limit_right = row_idx + local_row_offset_right - if const_expr(mask_seqlen): - col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) else: col_limit_right = self.tile_n + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) col_limit_left = ( row_idx + local_row_offset_left if const_expr(self.window_size_left is not None) @@ -411,10 +411,10 @@ def apply_mask_sm100( ) if const_expr(self.window_size_right is not None): col_limit_right = row_idx + local_row_offset_right - if const_expr(mask_seqlen): - col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) else: col_limit_right = self.tile_n + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) col_limit_left = ( row_idx + local_row_offset_left if const_expr(self.window_size_left is not None) @@ -447,28 +447,27 @@ def apply_mask_sm100_transposed( assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" ROW = 0 if const_expr(not self.swap_AB) else 1 COL = 1 if const_expr(not self.swap_AB) else 0 + assert t0ScS_t2r[0][COL] == 0, "col0 == 0" thr_col_offset = tScS_t2r[0][COL] seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset if const_expr(not mask_causal and not mask_local): if const_expr(mask_seqlen): - if t0ScS_t2r[0][COL] >= seqlenk_col_limit: + if seqlenk_col_limit <= 0: for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): acc_S[i] = -cutlass.Float32.inf else: # Causal or local thr_row_offset = tScS_t2r[0][ROW] - causal_row_offset = ( - seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset - ) + seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset + causal_offset = seqlenq_row_limit - seqlenk_col_limit if const_expr(mask_causal): - col0 = t0ScS_t2r[0][COL] - row_limit_top = col0 - causal_row_offset # tidx = cute.arch.thread_idx()[0] % 256 # if tidx < 32: - # cute.printf("tidx = {}, {} {}, {} {}, col0 = {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1], col0) + # cute.printf("tidx = {}, {} {}, {} {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1]) + row_limit_top = causal_offset if const_expr(mask_seqlen): # If col is beyond the column limit, we want to mask out the entire # column, by setting row limit to be self.tile_m. 
- if t0ScS_t2r[0][COL] >= seqlenk_col_limit: + if seqlenk_col_limit <= 0: row_limit_top = self.tile_m r2p = True if const_expr(not r2p): @@ -480,4 +479,18 @@ def apply_mask_sm100_transposed( num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32 mask_r2p_transposed(acc_S, row_limit_top, num_rep) else: - assert False, "Local masking isn't supported yet" + if const_expr(self.window_size_right is not None): + row_limit_top = causal_offset - self.window_size_right + else: + row_limit_top = 0 + if const_expr(self.window_size_left is not None): + row_limit_bot = causal_offset + self.window_size_left + if const_expr(mask_seqlen): + if seqlenk_col_limit <= 0: + row_limit_top = self.tile_m + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + row_idx = t0ScS_t2r[i][ROW] + local_mask = row_idx < row_limit_top + if const_expr(self.window_size_left is not None): + local_mask |= row_idx > row_limit_bot + acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i] diff --git a/flash_attn/cute/testing.py b/flash_attn/cute/testing.py index 214ed09bc9e..a23a624d059 100644 --- a/flash_attn/cute/testing.py +++ b/flash_attn/cute/testing.py @@ -260,8 +260,12 @@ def construct_local_mask( return col_idx > row_idx + sk - sq + window_size[1] else: sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + if window_size[1] is None: + local_mask_left = col_idx > sk + else: + local_mask_left = col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk) return torch.logical_or( - col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + local_mask_left, torch.logical_and( col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length ), diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index fc26fb34af8..fe1d18afb6d 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -29,7 +29,8 @@ DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" - +TEST_BWD_ONLY = False +VERBOSE = True # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @@ -43,8 +44,8 @@ @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -92,7 +93,7 @@ def test_flash_attn_output( seqlen_k, d, causal, - local, + local_enum, softcap, deterministic, has_qv, @@ -100,8 +101,9 @@ def test_flash_attn_output( mha_type, dtype, ): - # if (causal or local) and seqlen_k < seqlen_q: - # pytest.skip("Causal attention requires seqlen_k >= seqlen_q") + local = local_enum > 0 + if local and causal: + pytest.skip() device = "cuda" # set seed torch.random.manual_seed(0) @@ -115,7 +117,7 @@ def test_flash_attn_output( dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) - if dtype == torch.float8_e4m3fn: + if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] 
attention_chunk_vals = [0] @@ -157,6 +159,12 @@ def test_flash_attn_output( window_size = ( (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() ) + if local_enum == 2: + window_size = (None, -window_size[1]) + elif local_enum == 3: + window_size = (-window_size[0], None) + if local: + print("window size = ", window_size) # window_size = (-1, -1) if not local else (16, 0) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) @@ -228,7 +236,7 @@ def test_flash_attn_output( # pack_gqa_vals = [False, True, None] # SplitKV is not supported for hdim >= 192 pack_gqa_vals = [False] - num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out, lse = flash_attn_func( q, @@ -241,8 +249,9 @@ def test_flash_attn_output( # attention_chunk=attention_chunk, softcap=softcap, learnable_sink=learnable_sink, - # pack_gqa=pack_gqa, + pack_gqa=pack_gqa, num_splits=num_splits, + deterministic=deterministic, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") @@ -262,12 +271,9 @@ def test_flash_attn_output( and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 - and not local and dv == d and learnable_sink is None - # and mha_type == "mha" # and False - and not ((causal or local) and seqlen_k < seqlen_q) ): g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) @@ -301,6 +307,26 @@ def test_flash_attn_output( print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + if VERBOSE: + diff_dq = (dq - dq_ref).abs() + max_idx = diff_dq.argmax() + coords = torch.unravel_index(max_idx, diff_dq.shape) + print(f"dQ max diff: {diff_dq.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dQ={dq[coords].item()}, dQ_ref={dq_ref[coords].item()}") + + diff_dk = (dk - dk_ref).abs() + max_idx = diff_dk.argmax() + coords = torch.unravel_index(max_idx, diff_dk.shape) + print(f"dK max diff: {diff_dk.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dK={dk[coords].item()}, dK_ref={dk_ref[coords].item()}") + + diff_dv = (dv - dv_ref).abs() + max_idx = diff_dv.argmax() + coords = torch.unravel_index(max_idx, diff_dv.shape) + print(f"dV max diff: {diff_dv.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dV={dv[coords].item()}, dV_ref={dv_ref[coords].item()}") + # breakpoint() dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( 0 if softcap == 0 else 3e-4 diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py index 101e058d60e..520cf6466a7 100644 --- a/tests/cute/test_flash_attn_race_condition.py +++ b/tests/cute/test_flash_attn_race_condition.py @@ -44,25 +44,17 @@ @pytest.mark.parametrize("deterministic", [True]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local", [False, True]) -@pytest.mark.parametrize("local", [False]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) -# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# 
@pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) -# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) -# @pytest.mark.parametrize('d', [56, 80]) -# @pytest.mark.parametrize("d", [64, 128, 256]) -# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) -# @pytest.mark.parametrize("d", [64, 96, 128, 192]) -# @pytest.mark.parametrize("d", [64, 128]) -# @pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (4224, 4224), - (2048, 4096), + (2000, 4000), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) @@ -71,7 +63,7 @@ def test_flash_attn_output( seqlen_k, d, causal, - local, + local_enum, softcap, deterministic, has_qv, @@ -79,8 +71,9 @@ def test_flash_attn_output( mha_type, dtype, ): - if (causal or local) and seqlen_k < seqlen_q: - pytest.skip("Causal attention requires seqlen_k >= seqlen_q") + local = local_enum > 0 + if local and causal: + pytest.skip() device = "cuda" # set seed torch.random.manual_seed(0) @@ -137,6 +130,12 @@ def test_flash_attn_output( window_size = ( (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() ) + if local_enum == 2: + window_size = (None, -window_size[1]) + elif local_enum == 3: + window_size = (-window_size[0], None) + if local: + print("window size = ", window_size) # window_size = (-1, -1) if not local else (16, 0) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) @@ -222,7 +221,7 @@ def test_flash_attn_output( # attention_chunk=attention_chunk, softcap=softcap, learnable_sink=learnable_sink, - # pack_gqa=pack_gqa, + pack_gqa=pack_gqa, num_splits=num_splits, deterministic=deterministic, ) @@ -244,12 +243,9 @@ def test_flash_attn_output( and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 - and not local and dv == d and learnable_sink is None - # and mha_type == "mha" # and False - and not ((causal or local) and seqlen_k < seqlen_q) ): g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) @@ -303,11 +299,13 @@ def test_flash_attn_output( dv_pt - dv_ref ).abs().max().item() + dv_atol - num_iters = 100_000 + num_iters = 20_000 for i in range(num_iters): dq2, dk2, dv2, = _flash_attn_bwd( q, k, v, out, g, lse, causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], deterministic=True, )
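
Reviewer note (illustration, not part of the patch): the causal/local block-range arithmetic that block_info.get_m_block_min_max now implements can be sanity-checked on the host. Below is a minimal pure-Python sketch of the same math, assuming the kernel's bottom-right alignment (the seqlen_q - seqlen_k shift); the function name and the example numbers are mine, not part of the library.

import math

def m_block_range(seqlen_q, seqlen_k, tile_m, tile_n, n_block,
                  is_causal=False, window_size_left=None, window_size_right=None):
    # Full range by default: all m-blocks that cover seqlen_q.
    m_block_max = math.ceil(seqlen_q / tile_m)
    m_block_min = 0
    is_local = window_size_left is not None or window_size_right is not None
    # Rows entirely above the diagonal band contribute nothing to this K/V block.
    if is_causal or (is_local and window_size_right is not None):
        m_idx = n_block * tile_n + seqlen_q - seqlen_k
        m_idx_right = m_idx if is_causal else m_idx - window_size_right
        m_block_min = max(m_block_min, m_idx_right // tile_m)
    # With a finite left window, rows far below the band are skipped too.
    if is_local and window_size_left is not None:
        m_idx_left = (n_block + 1) * tile_n + seqlen_q - seqlen_k + window_size_left
        m_block_max = min(m_block_max, math.ceil(m_idx_left / tile_m))
    return m_block_min, m_block_max

# seqlen 4096, 128x128 tiles, window (left=256, right=0): K/V block 8 only
# needs m-blocks [8, 11). This is also why a fully skipped block
# (m_block_min >= m_block_max) must zero-fill its dK/dV tile in the epilogue.
assert m_block_range(4096, 4096, 128, 128, 8,
                     window_size_left=256, window_size_right=0) == (8, 11)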
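
The Optional window convention the patch threads through the kernels (None = unbounded on that side) matches the construct_local_mask change in testing.py. A dense reference mask under that convention can be written as the sketch below, assuming no key/query padding and no sink tokens; the helper name is mine.

import torch

def local_attn_mask(seqlen_q, seqlen_k, window_size_left=None, window_size_right=None):
    # Returns True where attention is masked out, with bottom-right alignment:
    # row i may attend col j iff
    #   i + shift - window_size_left <= j <= i + shift + window_size_right,
    # where shift = seqlen_k - seqlen_q and a None window size drops that
    # side of the constraint entirely.
    row = torch.arange(seqlen_q).unsqueeze(-1)
    col = torch.arange(seqlen_k)
    shift = seqlen_k - seqlen_q
    masked = torch.zeros(seqlen_q, seqlen_k, dtype=torch.bool)
    if window_size_right is not None:
        masked |= col > row + shift + window_size_right
    if window_size_left is not None:
        masked |= col < row + shift - window_size_left
    return masked

# (None, 0) is exactly the causal mask, which is why interface.py folds that
# case back into causal=True and clears window_size_right.
assert torch.equal(local_attn_mask(5, 5, None, 0),
                   torch.ones(5, 5).triu(1).bool())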
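
For the deterministic path, dQacc_reduce now derives the semaphore lock value from how many n-blocks actually accumulate into a given m-block's dQ. A standalone restatement of that computation (names are mine; "spt" is the right-to-left schedule the diff enables for causal/local + deterministic, per the "reads n_blocks right-to-left" comment):

def dq_lock_value(m_block, n_block, n_block_global_max, seqlen_q, seqlen_k,
                  tile_m, tile_n, spt, is_causal, window_size_right=None):
    if not spt:
        # Left-to-right order: this CTA is the n_block-th writer.
        return n_block
    if is_causal or window_size_right is not None:
        n_idx_right = (m_block + 1) * tile_m + seqlen_k - seqlen_q
        if window_size_right is not None:
            n_idx_right += window_size_right
        n_block_max_for_m_block = min(n_block_global_max,
                                      -(-n_idx_right // tile_n))  # ceil_div
    else:
        n_block_max_for_m_block = n_block_global_max
    # Right-to-left order: the last n_block that touches this m_block goes first.
    return n_block_max_for_m_block - 1 - n_block

# Causal, seqlen 4096, 128x128 tiles: m_block 0 is touched only by n_block 0,
# which therefore acquires the semaphore at lock value 0.
assert dq_lock_value(0, 0, 32, 4096, 4096, 128, 128, True, True) == 0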