diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py
index 0a29ce462a8..fb0e2e9b778 100644
--- a/flash_attn/cute/flash_bwd_sm100.py
+++ b/flash_attn/cute/flash_bwd_sm100.py
@@ -91,6 +91,7 @@ def __init__(
         # Speed optimizations, does not affect correctness
         self.shuffle_LSE = False
         self.shuffle_dPsum = False
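+        # For deterministic causal bwd, feed dS to the dK mma from SMEM rather than
+        # TMEM; mma() then issues dQ = dS @ K before dK += dS.T @ Q (see below),
+        # presumably to unblock the semaphore-ordered dQ reduction as early as possible.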
+        self.use_smem_dS_for_mma_dK = self.deterministic and self.is_causal
         self.reduce_warp_ids = (0, 1, 2, 3)
         self.compute_warp_ids = (4, 5, 6, 7, 8, 9, 10, 11)
@@ -146,7 +147,7 @@ def __init__(
         self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m
         self.tmem_dS_offset = self.tmem_dP_offset  # overlap with dP
 
-        if not is_causal and not is_local:
+        if (not is_causal and not is_local) or deterministic:
             self.num_regs_reduce = 152
             self.num_regs_compute = 136
         else:
@@ -203,6 +204,10 @@ def _get_tiled_mma(self):
             a_source=tcgen05.OperandSource.TMEM,
         )
         # dK += dS.T @ Q
+        if const_expr(self.use_smem_dS_for_mma_dK):
+            mma_dK_a_src = tcgen05.OperandSource.SMEM
+        else:
+            mma_dK_a_src = tcgen05.OperandSource.TMEM
         tiled_mma_dK = sm100_utils_basic.make_trivial_tiled_mma(
             self.do_dtype,
             tcgen05.OperandMajorMode.K,  # dS_major_mode
@@ -210,7 +215,7 @@
             self.acc_dtype,
             cta_group,
             self.mma_tiler_dsq[:2],
-            a_source=tcgen05.OperandSource.TMEM,
+            a_source=mma_dK_a_src,
         )
         # dQ = dS @ K
         tiled_mma_dQ = sm100_utils_basic.make_trivial_tiled_mma(
@@ -403,13 +408,13 @@ def __call__(
         semaphore_transpose = [2, 3, 1, 0]  # (b, n, block, stage) -> (block, stage, n, b)
         if const_expr(self.deterministic):
             assert mdQ_semaphore is not None
-            mdQ_semaphore = utils.select(mdQ_semaphore.layout, mode=semaphore_transpose)
+            mdQ_semaphore = utils.select(mdQ_semaphore, mode=semaphore_transpose)
         if const_expr(self.deterministic and self.qhead_per_kvhead > 1):
             assert mdK_semaphore is not None
             assert mdV_semaphore is not None
             mdK_semaphore, mdV_semaphore = [
-                utils.select(t.layout, mode=semaphore_transpose)
+                utils.select(t, mode=semaphore_transpose)
                 for t in (mdK_semaphore, mdV_semaphore)
             ]
         else:
@@ -546,15 +551,18 @@ def __call__(
         self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8
         self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8
 
-        # TileScheduler = SingleTileScheduler if not self.is_causal else SingleTileLPTBwdScheduler
-        TileScheduler = SingleTileScheduler
-        # TODO -- optimizer scheduler for causal
+        # TileScheduler = SingleTileScheduler
+        if const_expr(self.deterministic):
+            TileScheduler = SingleTileLPTBwdScheduler
+        else:
+            TileScheduler = SingleTileScheduler
+        self.spt = self.is_causal and self.deterministic
         tile_sched_args = TileSchedulerArguments(
             cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]),
             cute.size(mQ.shape[2]),  # num_heads = num_query_heads
             cute.size(mK.shape[3]),
             1,  # num_splits
-            cute.size(mK.shape[0]),
+            cute.size(mQ.shape[0]),  # pass seqlen_q in the seqlen_k slot; the bwd LPT scheduler sizes its L2 swizzle from it
             mQ.shape[1],
             mV.shape[1],
             total_q=cute.size(mQ.shape[0]),
@@ -565,7 +573,7 @@ def __call__(
             qhead_per_kvhead_packgqa=1,
             element_size=self.k_dtype.width // 8,
             is_persistent=self.is_persistent,
-            lpt=False,
+            lpt=self.spt,
         )
         tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
 
@@ -1364,8 +1372,10 @@ def mma(
         tdPrV = tiled_mma_dP.make_fragment_A(sV)
         tdPrdOt = tiled_mma_dP.make_fragment_B(sdOt)
         # dK = dS.T @ Q
-        # tdKrdS = tiled_mma_dK.make_fragment_A(sdSt)
-        tdKrdS = tiled_mma_dK.make_fragment_A(tdS)
+        if const_expr(self.use_smem_dS_for_mma_dK):
+            tdKrdS = tiled_mma_dK.make_fragment_A(sdSt)
+        else:
+            tdKrdS = tiled_mma_dK.make_fragment_A(tdS)
         tdKrQ = tiled_mma_dK.make_fragment_B(sQt)
         # dQ = dS @ K
         tdQrdS = tiled_mma_dQ.make_fragment_A(sdS)
@@ -1404,18 +1414,20 @@ def mma(
         # mma_dsk_fn = partial(
         #     gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, zero_init=True
         # )
-        # mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ)
-        # Need to explicitly pass in tA_addr for correctness
-        mma_dsq_fn = partial(
-            gemm_ptx_w_idx,
-            tiled_mma_dK,
-            tdKtdK,
-            tdKrdS,
-            tdKrQ,
-            sA=None,
-            sB=sQt,
-            tA_addr=self.tmem_dS_offset,
-        )
+        if const_expr(self.use_smem_dS_for_mma_dK):
+            mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ)
+        else:
+            # Need to explicitly pass in tA_addr for correctness
+            mma_dsq_fn = partial(
+                gemm_ptx_w_idx,
+                tiled_mma_dK,
+                tdKtdK,
+                tdKrdS,
+                tdKrQ,
+                sA=None,
+                sB=sQt,
+                tA_addr=self.tmem_dS_offset,
+            )
 
         consumer_state_dO = cutlass.pipeline.make_pipeline_state(
             cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage
@@ -1486,18 +1498,29 @@ def mma(
                     mma_qk_fn(B_idx=handle_Q_next.index)
                     pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)
 
-            # 2) dK = dS.T @ Q
+            # 2-3) If dS sits in tmem for the first mma, do dK = dS.T @ Q and then
+            # dQ = dS @ K; otherwise issue them in the reverse order.
            pipeline_dS.consumer_wait(consumer_state_dS)
-            mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
-            accumulate_dK = True
-            handle_Q.release()
-            # 3) dQ = dS @ K
+            if const_expr(self.use_smem_dS_for_mma_dK):
+                mma_dsk_fn()
+                pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
+                mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
+                accumulate_dK = True
+                handle_Q.release()
+            else:
+                mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
+                accumulate_dK = True
+                handle_Q.release()
+                mma_dsk_fn()
+                pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
             # dP uses the same tmem as dQ
-            # However, if dS is ready, then dP must have been ready, so we don't need to wait
+            # However, if dS is ready, then dP must have been ready,
+            # so we don't need this wait before mma_dsk_fn()
             # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)
-            mma_dsk_fn()
-            pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
+
             pipeline_dS.consumer_release(consumer_state_dS)
             consumer_state_dS.advance()
 
@@ -1823,8 +1846,8 @@ def compute_loop(
             )
             cute.arch.fence_view_async_tmem_store()
+            self.compute_sync_barrier.arrive_and_wait()
-            cute.arch.sync_warp()
             with cute.arch.elect_one():
                 pipeline_S_P.consumer_release(consumer_state_S_P_dP)
                 # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask)
@@ -1847,6 +1870,7 @@ def compute_loop(
             tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32)
             cute.copy(thr_copy_t2r, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r)
             cute.arch.fence_view_async_tmem_load()
+            self.compute_sync_barrier.arrive_and_wait()
             tdPrdP_cur = tdPrdP_t2r[None, 0, 0]
             tSrS_cur = tSrS_t2r[None, stage, 0, 0]
             tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, consumer_state_dPsum.index]
@@ -1875,22 +1899,20 @@ def compute_loop(
             if const_expr(stage == 0):
                 pipeline_dS.producer_acquire(producer_state_dS)
             cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage])
-            tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32)
-            cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0])
+            if const_expr(not self.use_smem_dS_for_mma_dK):
+                tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32)
+                cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0])
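+        # dS was stored to SMEM by generic-proxy copies above; fence the shared proxy
+        # before signaling so the async-proxy mma warp sees sdS / sdSt for the dQ mma
+        # (and for the dK mma when use_smem_dS_for_mma_dK). The TMEM store fence is
+        # only needed when dS was also copied to TMEM.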
-        cute.arch.fence_view_async_tmem_store()
+        if const_expr(not self.use_smem_dS_for_mma_dK):
+            cute.arch.fence_view_async_tmem_store()
+        cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta)
+        self.compute_sync_barrier.arrive_and_wait()
-        cute.arch.sync_warp()
         # with cute.arch.elect_one():
         # The mma warp no longer waits for dP (it waits for dS), so we don't have to arrive
         #     pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask)
         pipeline_dPsum.consumer_release(consumer_state_dPsum)
         consumer_state_dPsum.advance()
-
-        cute.arch.fence_proxy(
-            cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-        )
-        cute.arch.sync_warp()
         with cute.arch.elect_one():
             pipeline_dS.producer_commit(producer_state_dS)
             producer_state_dS.advance()
@@ -2010,10 +2032,13 @@ def dQacc_reduce(
         gdQaccum = cute.flat_divide(
             gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,)
         )
-        mdQ_semaphore_cur = None
+
         if const_expr(self.deterministic):
             mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx]
+        delay_semaphore_release = self.is_causal
+        n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n)
+
         for m_block in cutlass.range(m_block_min, m_block_max, unroll=1):
             pipeline_dQ.consumer_wait(dQ_consumer_state)
             # TMEM -> RMEM
@@ -2025,11 +2050,6 @@ def dQacc_reduce(
             pipeline_dQ.consumer_release(dQ_consumer_state)
             dQ_consumer_state.advance()
 
-            # semaphore acquire
-            if const_expr(self.deterministic):
-                barrier.wait_eq(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, n_block)
-                self.reduce_sync_barrier.arrive_and_wait()
-
             gdQaccum_cur = gdQaccum[None, None, m_block]
 
             for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])):  # 4
@@ -2043,6 +2063,17 @@ def dQacc_reduce(
                 cute.arch.fence_proxy(
                     cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
                 )
+                # semaphore acquire
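+                # With spt, the LPT scheduler hands out n_blocks in reverse, so for a
+                # given m_block the first writer is its largest overlapping n_block
+                # (lock value 0) and the last writer is n_block 0.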
+                if const_expr(self.deterministic and stage == 0):
+                    if const_expr(self.spt):
+                        n_block_max_for_m_block = min(
+                            n_block_global_max,
+                            cute.ceil_div(
+                                (m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q,
+                                self.tile_n,
+                            ),
+                        )
+                        lock_value = n_block_max_for_m_block - 1 - n_block
+                    else:
+                        lock_value = n_block
+                    barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value)
                 self.reduce_sync_barrier.arrive_and_wait()
                 # Copy from shared memory to global memory
                 if is_tma_warp:
@@ -2067,17 +2098,25 @@ def dQacc_reduce(
                 #         tdQrdQ_r2s[4 * i + 3],
                 #         utils.elem_pointer(tdQgdQ, 4 * i),
                 #     )
+                # semaphore release for prior m_block
+                if const_expr(self.deterministic and stage == 0 and delay_semaphore_release):
+                    if m_block > m_block_min:
+                        barrier.arrive_inc(mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1)
             # semaphore release
             # NOTE: arrive_inc calls red_release which issues membar
-            if const_expr(self.deterministic):
-                if tidx == 0:
+            if const_expr(self.deterministic and not delay_semaphore_release):
+                if is_tma_warp:
                     cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
                 self.reduce_sync_barrier.arrive_and_wait()
                 barrier.arrive_inc(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, 1)
-            if warp_idx == 0:
+            if is_tma_warp:
                 cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
+        self.reduce_sync_barrier.arrive_and_wait()
+        # final semaphore release
+        if const_expr(self.deterministic and delay_semaphore_release):
+            barrier.arrive_inc(mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, 0, 1)
 
         tile_scheduler.advance_to_next_work()
         work_tile = tile_scheduler.get_current_work()
@@ -2274,7 +2313,8 @@ def epilogue_dK_or_dV_tma(
             gdKV, (self.sdKV_flat_epi_tile,)
         )  # (tile_n * hdim / 2 / epi_stage, epi_stage)
 
-        if const_expr(self.deterministic and self.qhead_per_kvhead > 1):
+        deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1
+        if const_expr(deterministic_KV):
             mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx]
 
         if const_expr(self.qhead_per_kvhead == 1):
@@ -2296,12 +2336,12 @@ def epilogue_dK_or_dV_tma(
             tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
         )
 
-        read_flag = const_expr(not self.deterministic)
+        read_flag = const_expr(not deterministic_KV)
 
         pipeline_dKV.consumer_wait(consumer_state_dKV)
 
         # semaphore acquire
-        if const_expr(self.deterministic):
+        if const_expr(deterministic_KV):
             barrier.wait_eq(
                 mdKV_semaphore_cur.iterator, tidx, wg_idx, head_idx % self.qhead_per_kvhead
             )
@@ -2377,7 +2417,7 @@ def epilogue_dK_or_dV_tma(
         # semaphore release
         # NOTE: arrive_inc calls red_release which issues membar
-        if const_expr(self.deterministic):
+        if const_expr(deterministic_KV):
             if leader_warp:
                 cute.arch.cp_async_bulk_commit_group()
                 cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index 28bcb994ee7..1e94453252e 100644
--- a/flash_attn/cute/interface.py
+++ b/flash_attn/cute/interface.py
@@ -561,6 +561,7 @@ def _flash_attn_bwd(
     cu_seqlens_k: Optional[torch.Tensor] = None,
     seqused_q: Optional[torch.Tensor] = None,
     seqused_k: Optional[torch.Tensor] = None,
+    deterministic: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     compute_capability = torch.cuda.get_device_capability()[0]
     assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x"
@@ -659,6 +660,8 @@ def _flash_attn_bwd(
     pack_gqa = qhead_per_kvhead > 1
     if compute_capability == 10:
         pack_gqa = False  # override for now
+    if compute_capability != 10:
+        assert deterministic is False, "bwd deterministic only supported for sm100 for now"
 
     device = q.device
     # TODO: check if this is the right rounding
@@ -757,6 +760,22 @@ def _flash_attn_bwd(
         else None
         for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
     ]
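+    # Deterministic bwd serializes cross-CTA accumulation with int32 semaphores laid
+    # out as (batch, heads, blocks, stage); the kernel transposes this to
+    # (block, stage, heads, batch). dQ needs one counter per m_block; dK/dV are only
+    # needed for GQA (qhead_per_kvhead > 1), with 2 stages that appear to index the
+    # two epilogue warpgroups.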
+    if deterministic:
+        dQ_semaphore = torch.zeros(
+            batch_size, num_head, seqlen_q_rounded // m_block_size, 1,
+            dtype=torch.int32, device=device,
+        )
+    else:
+        dQ_semaphore = None
+
+    if deterministic and qhead_per_kvhead > 1:
+        dK_semaphore = torch.zeros(
+            batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2,
+            dtype=torch.int32, device=device,
+        )
+        dV_semaphore = torch.zeros(
+            batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2,
+            dtype=torch.int32, device=device,
+        )
+    else:
+        dK_semaphore = None
+        dV_semaphore = None
+    dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [
+        utils.convert_from_dlpack_leading_static(
+            t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order()
+        )
+        if t is not None
+        else None
+        for t in (dQ_semaphore, dK_semaphore, dV_semaphore)
+    ]
 
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
     # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum.
@@ -831,6 +850,7 @@ def _flash_attn_bwd(
         num_threads,
         pack_gqa,
         cluster_size,
+        deterministic,
     )
     num_threads = 384
     if compile_key not in _flash_attn_bwd.compile_cache:
@@ -885,6 +905,7 @@ def _flash_attn_bwd(
             # tile_n=n_block_size,
             cluster_size=cluster_size,
             # cluster_size=1,
+            deterministic=deterministic,
         )
         # TODO: check @can_implement
         _flash_attn_bwd.compile_cache[compile_key] = cute.compile(
@@ -904,6 +925,9 @@ def _flash_attn_bwd(
             cu_seqlens_k_tensor,
             seqused_q_tensor,
             seqused_k_tensor,
+            mdQ_semaphore=dQ_semaphore_tensor,
+            mdK_semaphore=dK_semaphore_tensor,
+            mdV_semaphore=dV_semaphore_tensor,
         )
     _flash_attn_bwd.compile_cache[compile_key](
         q_tensor,
@@ -921,6 +945,9 @@ def _flash_attn_bwd(
         cu_seqlens_k_tensor,
         seqused_q_tensor,
         seqused_k_tensor,
+        mdQ_semaphore=dQ_semaphore_tensor,
+        mdK_semaphore=dK_semaphore_tensor,
+        mdV_semaphore=dV_semaphore_tensor,
     )
 
     num_threads = 256 if compute_capability == 9 else 128
@@ -1028,6 +1055,7 @@ def forward(
         softcap: float = 0.0,
         num_splits: int = 1,
         pack_gqa: Optional[bool] = None,
+        deterministic: bool = False,
         mask_mod: Optional[Callable] = None,
         full_block_cnt: Optional[torch.Tensor] = None,
         full_block_idx: Optional[torch.Tensor] = None,
@@ -1063,6 +1091,7 @@ def forward(
         ctx.causal = causal
         ctx.window_size = window_size
         ctx.softcap = softcap
+        ctx.deterministic = deterministic
         return out, lse
 
     @staticmethod
@@ -1078,6 +1107,7 @@ def backward(ctx, dout, *args):
             ctx.softmax_scale,
             ctx.causal,
             ctx.softcap,
+            deterministic=ctx.deterministic,
         )
         return dq, dk, dv, *((None,) * 20)  # Extra Nones is fine
 
@@ -1101,6 +1131,7 @@ def forward(
         softcap: float = 0.0,
         num_splits: int = 1,
         pack_gqa: Optional[bool] = None,
+        deterministic: bool = False,
     ):
         out, lse = _flash_attn_fwd(
             q,
@@ -1125,6 +1156,7 @@ def forward(
         ctx.causal = causal
         ctx.window_size = window_size
         ctx.softcap = softcap
+        ctx.deterministic = deterministic
         return out, lse
 
     @staticmethod
@@ -1146,6 +1178,7 @@ def backward(ctx, dout, *args):
             cu_seqlens_k=cu_seqlens_k,
             seqused_q=seqused_q,
             seqused_k=seqused_k,
+            deterministic=ctx.deterministic,
         )
         return dq, dk, dv, *((None,) * 20)
 
@@ -1162,6 +1195,7 @@ def flash_attn_func(
     softcap: float = 0.0,
     num_splits: int = 1,
     pack_gqa: Optional[bool] = None,
+    deterministic: bool = False,
     mask_mod: Optional[Callable] = None,
     full_block_cnt: Optional[torch.Tensor] = None,
     full_block_idx: Optional[torch.Tensor] = None,
@@ -1179,6 +1213,7 @@
         softcap,
         num_splits,
         pack_gqa,
+        deterministic,
         mask_mod,
         full_block_cnt,
         full_block_idx,
@@ -1203,6 +1238,7 @@ def flash_attn_varlen_func(
     softcap: float = 0.0,
     num_splits: int = 1,
     pack_gqa: Optional[bool] = None,
+    deterministic: bool = False,
 ):
     return FlashAttnVarlenFunc.apply(
         q,
@@ -1220,6 +1256,7 @@
         softcap,
         num_splits,
         pack_gqa,
+        deterministic,
     )
 
 
diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py
index f3a06c186e7..ad6ab099b0a 100644
--- a/flash_attn/cute/tile_scheduler.py
+++ b/flash_attn/cute/tile_scheduler.py
@@ -374,19 +374,28 @@ class SingleTileLPTBwdScheduler:
     @dataclass
     class Params(ParamsBase):
         total_blocks: Int32
+        num_block: Int32
         num_head_divmod: FastDivmod
         l2_minor_divmod: FastDivmod
         l2_major_divmod: FastDivmod
         l2_minor_residual_divmod: FastDivmod
         num_hb_quotient: Int32
         cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
+        spt: cutlass.Constexpr[bool] = True
 
     @staticmethod
     @cute.jit
     def create(
         args: TileSchedulerArguments, *, loc=None, ip=None
     ) -> "SingleTileLPTBwdScheduler.Params":
-        swizzle = 8
+        size_l2 = 50 * 1024 * 1024
+        size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
+        # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4
+        size_one_dqaccum_head = 0
+        size_one_head = size_one_qdo_head + size_one_dqaccum_head
+        log2_floor = lambda n: 31 - clz(n)
+        swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
+        # swizzle = 8
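+        # e.g. seqlen 4096, headdim = headdim_v = 128, bf16: size_one_head = 2 MiB and
+        # swizzle = 1 << log2_floor(50 MiB / 2 MiB) = 16 (head, batch) pairs per L2 section.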
         # If we're in the last section (called residual), we don't want to divide by
         # swizzle. Instead we want to divide by the remainder.
         num_hb_quotient = (args.num_head * args.num_batch) // swizzle
@@ -396,6 +405,7 @@ def create(
             total_blocks=(num_block * args.cluster_shape_mn[0])
             * args.num_head
             * args.num_batch,
+            num_block=num_block,
             num_head_divmod=FastDivmod.create(args.num_head),
             l2_minor_divmod=FastDivmod.create(swizzle),
             l2_major_divmod=FastDivmod.create(swizzle * num_block),
             l2_minor_residual_divmod=FastDivmod.create(
                 max((args.num_head * args.num_batch) % swizzle, 1)
             ),  # don't divide by 0
             num_hb_quotient=Int32(num_hb_quotient),
             cluster_shape_mn=args.cluster_shape_mn,
+            spt=args.lpt,
         )
 
     def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
@@ -450,6 +461,8 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
         is_valid = self._tile_idx < params.total_blocks
         bidx_in_cluster = cute.arch.block_in_cluster_idx()
         block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0]
+        if cutlass.const_expr(params.spt):
+            block = params.num_block - 1 - block
         return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid)
 
     def initial_work_tile_info(self, *, loc=None, ip=None):
diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py
index aa50c89c5bf..eb8b86cbe0b 100644
--- a/flash_attn/cute/utils.py
+++ b/flash_attn/cute/utils.py
@@ -71,6 +71,14 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
         )
     )
 
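+# Like convert_from_dlpack, but keep only `leading_dim` (plus any `static_modes`)
+# static and mark every other mode's shape dynamic, so the compiled kernel does not
+# specialize on those extents (semaphore shapes vary with batch/heads/seqlen).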
+def convert_from_dlpack_leading_static(
+    x, leading_dim, alignment=16, static_modes=None, stride_order=None
+) -> cute.Tensor:
+    if stride_order is None:
+        stride_order = x.dim_order()
+    x_ = from_dlpack(x, assumed_align=alignment)
+    for i in range(x.ndim):
+        if i != leading_dim and (static_modes is None or i not in static_modes):
+            x_ = x_.mark_compact_shape_dynamic(mode=i, stride_order=stride_order)
+    return x_
 
 
 def make_tiled_copy_A(
     copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py
new file mode 100644
index 00000000000..5cedc49d3c4
--- /dev/null
+++ b/tests/cute/test_flash_attn_race_condition.py
@@ -0,0 +1,341 @@
+# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+
+import math
+import itertools
+import os
+
+import pytest
+import torch
+
+from einops import rearrange, repeat
+
+try:
+    from flash_attn.layers.rotary import apply_rotary_emb
+except ImportError:
+    apply_rotary_emb = None
+
+from flash_attn.cute.testing import (
+    attention_ref,
+    generate_qkv,
+    generate_random_padding_mask,
+    pad_input,
+    unpad_input,
+)
+from flash_attn.cute.interface import (
+    flash_attn_func,
+    flash_attn_varlen_func,
+    flash_attn_combine,
+    _flash_attn_bwd,
+)
+
+
+DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
+
+
+# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+@pytest.mark.parametrize("mha_type", ["gqa"])
+# @pytest.mark.parametrize("has_learnable_sink", [False, True])
+@pytest.mark.parametrize("has_learnable_sink", [False])
+# @pytest.mark.parametrize("has_qv", [False, True])
+@pytest.mark.parametrize("has_qv", [False])
+# @pytest.mark.parametrize("deterministic", [False, True])
+@pytest.mark.parametrize("deterministic", [True])
+# @pytest.mark.parametrize("softcap", [0.0, 15.0])
+@pytest.mark.parametrize("softcap", [0.0])
+# @pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("local", [False])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [False])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [56, 80])
+# @pytest.mark.parametrize("d", [64, 128, 256])
+# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
+# @pytest.mark.parametrize("d", [64, 96, 128, 192])
+# @pytest.mark.parametrize("d", [64, 128])
+# @pytest.mark.parametrize("d", [128, 192])
+@pytest.mark.parametrize("d", [128])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (4224, 4224),
+        (2048, 4096),
+    ],
+)
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
+def test_flash_attn_race_condition(
+    seqlen_q,
+    seqlen_k,
+    d,
+    causal,
+    local,
+    softcap,
+    deterministic,
+    has_qv,
+    has_learnable_sink,
+    mha_type,
+    dtype,
+):
+    if (causal or local) and seqlen_k < seqlen_q:
+        pytest.skip("Causal attention requires seqlen_k >= seqlen_q")
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+    batch_size = 9 if seqlen_k <= 2048 else 2
+    # batch_size = 1
+    nheads = 6
+    # nheads = 1
+    nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)
+    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
+    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
+    dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
+    if dtype == torch.float8_e4m3fn:
+        dv_vals = [d]
+    dv_vals = [d]
+    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0]
+    attention_chunk_vals = [0]
+    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
+        q_ref = torch.randn(
+            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
+        )
+        if softcap > 0.0:
+            # Ensure the values of qk are at least within softcap range.
+            q_ref = q_ref * softcap / 4
+        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
+        k_ref = (
+            torch.randn(
+                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
+            )
+            .to(dtype)
+            .to(dtype_ref)
+            .requires_grad_()
+        )
+        v_ref = (
+            torch.randn(
+                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref
+            )
+            .to(dtype)
+            .to(dtype_ref)
+            .requires_grad_()
+        )
+        if has_qv:
+            qv_ref = (
+                torch.randn(
+                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
+                )
+                .to(dtype)
+                .to(dtype_ref)
+            )
+        else:
+            qv_ref = None
+        # Put window_size after QKV randn so that window_size changes from test to test
+        window_size = (
+            (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
+        )
+        # window_size = (-1, -1) if not local else (16, 0)
+        if has_learnable_sink:
+            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
+        else:
+            learnable_sink = None
+        if dtype == torch.float8_e4m3fn:
+            q_descale, k_descale, v_descale = [
+                torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
+                * 2
+                for _ in range(3)
+            ]
+        else:
+            q_descale, k_descale, v_descale = None, None, None
+        q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]
+        qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None
+        out_ref, attn_ref = attention_ref(
+            q_ref,
+            k_ref,
+            v_ref,
+            None,
+            None,
+            causal=causal,
+            qv=qv_ref,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+            window_size=window_size,
+            attention_chunk=attention_chunk,
+            learnable_sink=learnable_sink,
+            softcap=softcap,
+        )
+        out_pt, attn_pt = attention_ref(
+            q_ref,
+            k_ref,
+            v_ref,
+            None,
+            None,
+            causal=causal,
+            qv=qv_ref,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+            window_size=window_size,
+            attention_chunk=attention_chunk,
+            learnable_sink=learnable_sink,
+            softcap=softcap,
+            upcast=False,
+            reorder_ops=True,
+            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
+        )
+
+        # k_extended = repeat(k_ref, "b s h d -> b s (h k) d", k=nheads // nheads_kv)
+        # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float()
+        # # if qv is not None:
+        # #     qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float()
+        # m = qk.amax(-1, keepdim=True)
+        # s_tmp = torch.exp((qk - m) / math.sqrt(d))
+        # exp_sum = s_tmp.sum(-1)
+        # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float())
+        # # lse_ref = torch.logsumexp(qk, dim=-1)
+
+        # Numerical error if we just do any arithmetic on out_ref
+        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
+        rtol = 2 if softcap == 0.0 else 3
+
+        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+        # num_splits_vals = [1, 3]
+        # pack_gqa_vals = [False, True, None]
+        # SplitKV is not supported for hdim >= 192
+        pack_gqa_vals = [False]
+        # num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]
+        num_splits_vals = [1]
+        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
+            out, lse = flash_attn_func(
+                q,
+                k,
+                v,
+                causal=causal,
+                # qv=qv,
+                # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
+                window_size=window_size,
+                # attention_chunk=attention_chunk,
+                softcap=softcap,
+                learnable_sink=learnable_sink,
+                # pack_gqa=pack_gqa,
+                num_splits=num_splits,
+                deterministic=deterministic,
+            )
+            print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+            print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+            # if not causal:
+            #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
+            # breakpoint()
+
+            # Check that FlashAttention's numerical error is at most twice the numerical error
+            # of a Pytorch implementation.
+            assert (out - out_ref).abs().max().item() <= rtol * (
+                out_pt - out_ref
+            ).abs().max().item() + fwd_atol
+
+        if (
+            dtype != torch.float8_e4m3fn
+            and not has_qv
+            and not dv > 256
+            and not attention_chunk != 0
+            and softcap == 0.0
+            and not local
+            and dv == d
+            and learnable_sink is None
+            # and mha_type == "mha"
+            # and False
+            and not ((causal or local) and seqlen_k < seqlen_q)
+        ):
+            g = torch.randn_like(out)
+            # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
+            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
+            # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
+            # assert (softmax_d - do_o).abs().max().item() <= 1e-5
+            # assert dq_accum.abs().max().item() == 0.0
+
+            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
+            # P = torch.softmax(qk, -1)
+            # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1))
+            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
+            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
+            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
+            # breakpoint()
+
+            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
+            dq_ref, dk_ref, dv_ref = torch.autograd.grad(
+                out_ref, (q_ref, k_ref, v_ref), g
+            )
+            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
+            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+            # breakpoint()
+            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
+                0 if softcap == 0 else 3e-4
+            )
+            assert (dq - dq_ref).abs().max().item() <= rtol * (
+                dq_pt - dq_ref
+            ).abs().max().item() + dq_atol
+            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (
+                0 if softcap == 0 else 3e-4
+            )
+            assert (dk - dk_ref).abs().max().item() <= rtol * (
+                dk_pt - dk_ref
+            ).abs().max().item() + dk_atol
+            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (
+                0 if softcap == 0 else 3e-4
+            )
+            assert (dv - dv_ref).abs().max().item() <= rtol * (
+                dv_pt - dv_ref
+            ).abs().max().item() + dv_atol
+
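+            # Race-condition check: re-run the deterministic bwd many times and require
+            # bitwise equality with the first result; any mismatch points at an
+            # ordering bug in the semaphore protocol.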
+            num_iters = 100_000
+            for i in range(num_iters):
+                dq2, dk2, dv2 = _flash_attn_bwd(
+                    q, k, v, out, g, lse,
+                    causal=causal,
+                    deterministic=True,
+                )
+
+                diff_dq = (dq - dq2).abs()
+                max_idx = diff_dq.argmax()
+                print(f"dQ max diff: {diff_dq.max().item()}")
+                print(f"  at index {max_idx.item()}: dQ={dq.flatten()[max_idx].item()}, dQ2={dq2.flatten()[max_idx].item()}")
+
+                diff_dk = (dk - dk2).abs()
+                max_idx = diff_dk.argmax()
+                print(f"dK max diff: {diff_dk.max().item()}")
+                print(f"  at index {max_idx.item()}: dK={dk.flatten()[max_idx].item()}, dK2={dk2.flatten()[max_idx].item()}")
+
+                diff_dv = (dv - dv2).abs()
+                max_idx = diff_dv.argmax()
+                print(f"dV max diff: {diff_dv.max().item()}")
+                print(f"  at index {max_idx.item()}: dV={dv.flatten()[max_idx].item()}, dV2={dv2.flatten()[max_idx].item()}")
+
+                # print(f"dQ max diff with myself: {(dq - dq2).abs().max().item()}")
+                # print(f"dK max diff with myself: {(dk - dk2).abs().max().item()}")
+                # print(f"dV max diff with myself: {(dv - dv2).abs().max().item()}")
+                # print(f"dQ mean diff with myself: {(dq - dq2).abs().mean().item()}")
+                # print(f"dK mean diff with myself: {(dk - dk2).abs().mean().item()}")
+                # print(f"dV mean diff with myself: {(dv - dv2).abs().mean().item()}")
+
+                assert torch.equal(dq, dq2)
+                assert torch.equal(dk, dk2)
+                assert torch.equal(dv, dv2)
+
+                print(f"✅ Iteration {i} passed!")