From d7f8e3fbb76b0f80a09b2e782ba742dcc36466f8 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Fri, 26 Dec 2025 06:04:04 +0000 Subject: [PATCH 01/15] varlen bwd with rounded padded offsets --- flash_attn/cute/flash_bwd_postprocess.py | 290 +---------------------- flash_attn/cute/flash_bwd_preprocess.py | 12 +- flash_attn/cute/flash_bwd_sm100.py | 165 +++++++++---- flash_attn/cute/interface.py | 66 +++++- flash_attn/cute/seqlen_info.py | 36 ++- 5 files changed, 229 insertions(+), 340 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 14d746ba346..8a4ed74e776 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -233,13 +233,15 @@ def __call__( TileScheduler = SingleTileVarlenScheduler num_head = mdQ.shape[1] num_batch = mCuSeqlensQ.shape[0] - 1 + num_block = cute.ceil_div(mdQ.shape[0], self.tile_m) else: TileScheduler = SingleTileScheduler num_head = mdQ.shape[2] num_batch = mdQ.shape[0] + num_block = cute.ceil_div(mdQ.shape[1], self.tile_m) tile_sched_args = TileSchedulerArguments( - num_block=cute.ceil_div(mdQ.shape[1], self.tile_m), + num_block=num_block, num_head=num_head, num_batch=num_batch, num_splits=1, @@ -318,7 +320,7 @@ def kernel( tile_scheduler = TileScheduler.create(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() - m_block, num_head, batch_size, _ = work_tile.tile_idx + m_block, head_idx, batch_idx, _ = work_tile.tile_idx if work_tile.is_valid_tile: # /////////////////////////////////////////////////////////////////////////////// @@ -326,7 +328,7 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// seqlen = SeqlenInfoQK.create( - batch_size, + batch_idx, mdQ.shape[1], 0, mCuSeqlensQ=mCuSeqlensQ, @@ -335,14 +337,16 @@ def kernel( mSeqUsedK=None, ) if const_expr(not seqlen.has_cu_seqlens_q): - mdQ_cur = mdQ[batch_size, None, num_head, None] - mdQaccum_cur = mdQaccum[batch_size, num_head, None] + mdQ_cur = mdQ[batch_idx, None, head_idx, None] + mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] head_dim = mdQ.shape[3] else: - padded_offset_q = seqlen.offset_q + batch_size * self.tile_m - mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, num_head, None]) + padded_offset_q = cute.round_up( + seqlen.offset_q + batch_idx * self.tile_m, self.tile_m + ) + mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None]) mdQaccum_cur = cute.domain_offset( - (padded_offset_q * self.tile_hdim,), mdQaccum[num_head, None] + (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None] ) head_dim = mdQ.shape[2] @@ -457,273 +461,3 @@ def kernel( tdQgdQ[None, rest_m, None], pred=tdQpdQ[None, rest_m, None], ) - - -class FlashAttentionBackwardPostprocess_sm100(FlashAttentionBackwardPostprocess): - def __init__( - self, - dtype: Type[cutlass.Numeric], - head_dim: int, - tile_m: int = 128, - num_threads: int = 256, - AtomLayoutMdQ: int = 1, - dQ_swapAB: bool = False, - ): - super().__init__( - dtype=dtype, - head_dim=head_dim, - arch=90, # tmp dummy placement for now - tile_m=tile_m, - num_threads=num_threads, - AtomLayoutMdQ=AtomLayoutMdQ, - dQ_swapAB=dQ_swapAB, - ) - - def _setup_attributes(self): - self.num_stages = self.tile_hdim // 32 # 2 for D=64, 4 for D=128 - - self.sdQaccum_layout = cute.make_layout( - shape=(self.tile_m * 32, 2), stride=(1, self.tile_m * 32) - ) - self.epi_tile_q = (self.tile_m, self.tile_hdim) - self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi( - self.dtype, - 
LayoutEnum.ROW_MAJOR, - self.epi_tile_q, - 1, - ) - - @cute.jit - def __call__( - self, - mdQaccum: cute.Tensor, - mdQ: cute.Tensor, - scale: cutlass.Float32, - stream: cuda.CUstream, - ): - # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: ( - *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), - t.stride[-1], - ) - mdQaccum, mdQ = [ - cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) - for t in (mdQaccum, mdQ) - ] - # (b, h, s*d) -> (s*d, h, b) - mdQaccum = cute.make_tensor(mdQaccum.iterator, cute.select(mdQaccum.layout, mode=[2, 1, 0])) - # (b, s, h, d) -> (s, d, h, b) - mdQ = cute.make_tensor(mdQ.iterator, cute.select(mdQ.layout, mode=[1, 3, 2, 0])) - - self._setup_attributes() - - grid_dim = [ - cute.ceil_div(mdQ.shape[0], self.tile_m), - cute.size(mdQ.shape[2]), - cute.size(mdQ.shape[3]), - ] - - cta_group = tcgen05.CtaGroup.ONE - self.mma_tiler_dsk = (self.tile_m, self.tile_hdim) - - dS_major_mode = tcgen05.OperandMajorMode.MN - kt_major_mode_dsq = tcgen05.OperandMajorMode.MN - - tiled_mma_dsk = sm100_utils_basic.make_trivial_tiled_mma( - cutlass.BFloat16, - dS_major_mode, - kt_major_mode_dsq, - cutlass.Float32, - cta_group, - self.mma_tiler_dsk, - ) - - dQ_cta_v_layout = cute.composition(cute.make_identity_layout(mdQ.shape), self.mma_tiler_dsk) - tma_store_op = cpasync.CopyBulkTensorTileS2GOp() - tma_atom_dQ, tma_tensor_dQ = cute.nvgpu.cpasync.make_tiled_tma_atom( - tma_store_op, - mdQ, - cute.select(self.sdQ_layout, mode=[0, 1]), - dQ_cta_v_layout, - ) - - buffer_align_bytes = 1024 - - @cute.struct - class SharedStorage: - sdQaccum: cute.struct.Align[ - cute.struct.MemRange[cutlass.Float32, cute.cosize(self.sdQaccum_layout)], - 128, - ] - - sdQ: cute.struct.Align[ - cute.struct.MemRange[self.dtype, cute.cosize(self.sdQ_layout)], - buffer_align_bytes, - ] - - self.shared_storage = SharedStorage - - self.kernel( - mdQaccum, - tma_tensor_dQ, - tma_atom_dQ, - self.sdQaccum_layout, - self.sdQ_layout, - tiled_mma_dsk, - scale, - ).launch( - grid=grid_dim, - block=[self.num_threads, 1, 1], - smem=self.shared_storage.size_in_bytes(), - stream=stream, - ) - - @cute.kernel - def kernel( - self, - mdQaccum: cute.Tensor, - mdQ: cute.Tensor, - tma_atom_dQ: cute.CopyAtom, - sdQaccum_layout: cute.Layout, - sdQ_layout: cute.ComposedLayout, - tiled_mma_dsk: cute.TiledMma, - scale: cutlass.Float32, - ): - tidx = cute.arch.thread_idx()[0] - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - m_block, head_idx, batch_idx = cute.arch.block_idx() - - # SMEM - smem = cutlass.utils.SmemAllocator() - storage = smem.allocate(self.shared_storage) - swz128 = cute.make_swizzle(3, 4, 3) - sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout, swizzle=swz128) - - sdQ = storage.sdQ.get_tensor(sdQ_layout.outer, swizzle=sdQ_layout.inner) - - mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] - mdQ_cur = mdQ[None, None, head_idx, batch_idx] - - thr_mma_dsk = tiled_mma_dsk.get_slice(tidx) - dQacc_shape = thr_mma_dsk.partition_shape_C(self.mma_tiler_dsk[:2]) - tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape) - tdQtdQ = cute.make_tensor(tdQtdQ.iterator, tdQtdQ.layout) - - tmem_ld_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32 - ) - tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdQtdQ) - thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) - - cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1])) - tdQcdQ = 
thr_mma_dsk.partition_C(cdQ) - tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout) - tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor) - - gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,)) - - num_reduce_warps = 4 - num_reduce_threads = cute.arch.WARP_SIZE * num_reduce_warps - - atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=128 - ) - tiler_mn, layout_tv = cute.make_layout_tv( - thr_layout=cute.make_layout(shape=num_reduce_threads, stride=1), - val_layout=cute.make_layout(shape=4, stride=1), - ) - G2S_tiled_copy_dQaccum = cute.make_tiled_copy( - atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn - ) - - smem_thr_copy_g2s = G2S_tiled_copy_dQaccum.get_slice(tidx) - - # S->R - tdQrdQ_t2r = cute.make_fragment(tdQrdQ.shape, cutlass.Float32) - tiled_smem_store_s2r = cute.make_tiled_copy( - atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn - ) - - s2r_thr_copy_dQaccum = tiled_smem_store_s2r.get_slice(tidx) - tdQsdQ_s2r = s2r_thr_copy_dQaccum.partition_S(sdQaccum) - tdQrdQ_s2r = cute.make_tensor(tdQrdQ_t2r.iterator, tdQrdQ_t2r.shape) - - # R->S - smem_copy_atom = sm100_utils_basic.get_smem_store_op( - LayoutEnum.ROW_MAJOR, self.dtype, cutlass.Float32, tiled_tmem_ld - ) - tiled_smem_store_r2s = cute.make_tiled_copy( - smem_copy_atom, - layout_tv=tiled_tmem_ld.layout_dst_tv_tiled, - tiler_mn=tiled_tmem_ld.tiler_mn, - ) - tdQsdQ_r2s = thr_tmem_ld.partition_D(thr_mma_dsk.partition_C(sdQ)) - tdQrdQ_r2s = cute.make_fragment(tdQsdQ_r2s.shape, self.dtype) - - num_stages = cute.size(tdQrdQ_t2r, mode=[1]) - for stage in cutlass.range_constexpr(num_stages): - # G->S - gdQaccum_stage = cute.local_tile( - gdQaccum, - (self.tile_m * 32,), - (stage,), - ) - - gdQaccum_layout_g2s = cute.make_layout(shape=(self.tile_m * 32, 1), stride=(1, 0)) - gdQaccum_stage_g2s = cute.make_tensor( - cute.recast_ptr(gdQaccum_stage.iterator, swizzle_=swz128), gdQaccum_layout_g2s - ) - - tdQgdQ = smem_thr_copy_g2s.partition_S(gdQaccum_stage_g2s) - tdQsdQ = smem_thr_copy_g2s.partition_D(sdQaccum) - - cute.copy(smem_thr_copy_g2s, tdQgdQ[None, None, 0], tdQsdQ[None, None, 0]) - - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) - cute.arch.barrier(barrier_id=6, number_of_threads=num_reduce_threads) - - # S -> R - tdQrdQ_s2r_cpy = tdQrdQ_s2r[None, stage, None, None] - tdQsdQ_s2r_p = tdQsdQ_s2r[None, None, 0] - tdQrdQ_r2s_cpy = cute.make_tensor( - tdQrdQ_s2r_cpy.iterator, cute.make_layout(tdQsdQ_s2r_p.shape) - ) - - cute.copy(s2r_thr_copy_dQaccum, tdQsdQ_s2r_p, tdQrdQ_r2s_cpy) - - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) - cute.arch.barrier(barrier_id=7, number_of_threads=num_reduce_threads) - - # R->S - tdQrdQ_r2s_cpy = cute.make_tensor( - cute.recast_ptr(tdQrdQ_r2s_cpy.iterator), - tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].shape, - ) - dQ_vec = tdQrdQ_r2s_cpy.load() * scale - tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].store(dQ_vec.to(self.dtype)) - - cute.copy( - tiled_smem_store_r2s, - tdQrdQ_r2s[None, None, None, None, 0], - tdQsdQ_r2s[None, None, None, None, 0], - ) - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) - cute.arch.barrier(barrier_id=8, number_of_threads=num_reduce_threads) - - # S-> G - gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (None, 0)) - tdQsdQ, tdQgdQ = cpasync.tma_partition( - tma_atom_dQ, - 0, - 
cute.make_layout(1), - cute.group_modes(sdQ, 0, 2), - cute.group_modes(gdQ, 0, 2), - ) - - cute.copy(tma_atom_dQ, tdQsdQ[None, 0], tdQgdQ[None, m_block]) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index 985391a7898..0596e89d057 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -238,7 +238,9 @@ def kernel( mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, num_head, None]) mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, num_head, None]) - padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size + padded_offset_q = cute.round_up( + seqlen.offset_q + batch_size * self.m_block_size, self.m_block_size + ) mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[num_head, None]) headdim_v = mO.shape[2] @@ -325,7 +327,9 @@ def kernel( if cutlass.const_expr(not seqlen.has_cu_seqlens_q): mdQaccum_cur = mdQaccum[batch_size, num_head, None] else: - padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size + padded_offset_q = cute.round_up( + seqlen.offset_q + batch_size * self.m_block_size, self.m_block_size + ) mdQaccum_cur = cute.domain_offset( (padded_offset_q * self.head_dim_padded,), mdQaccum[num_head, None] ) @@ -354,7 +358,9 @@ def kernel( if cutlass.const_expr(not seqlen.has_cu_seqlens_q): mLSElog2_cur = mLSElog2[batch_size, num_head, None] else: - padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size + padded_offset_q = cute.round_up( + seqlen.offset_q + batch_size * self.m_block_size, self.m_block_size + ) mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[num_head, None]) gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,)) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index fd49e81292d..99de4f9c3aa 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -25,6 +25,7 @@ TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, # noqa + SingleTileVarlenScheduler, ParamsBase, ) @@ -78,7 +79,7 @@ def __init__( self.tile_n = tile_n # CTA tiler - self.cta_tiler = (tile_m, tile_n, self.tile_hdim) + self.cta_tiler = (tile_n, tile_m, self.tile_hdim) # S = K @ Q.T self.mma_tiler_kq = (tile_n, tile_m, self.tile_hdim) # dP = V @ dO.T @@ -391,9 +392,9 @@ def __call__( # Block-sparse tensors (Q direction - for iterating m_blocks per n_block): blocksparse_tensors: Optional[BlockSparseTensors] = None, ): - assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), ( - "Variable sequence length is not supported yet in FlashAttentionBackwardSm100" - ) + # assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), ( + # "Variable sequence length is not supported yet in FlashAttentionBackwardSm100" + # ) self.q_dtype = mQ.element_type self.k_dtype = mK.element_type self.v_dtype = mV.element_type @@ -405,7 +406,9 @@ def __call__( self.dv_dtype = mdV.element_type self.ds_dtype = self.q_dtype - if const_expr(self.qhead_per_kvhead > 1): + self.dKV_postprocess = self.qhead_per_kvhead > 1 or const_expr(mCuSeqlensK is not None) + + if const_expr(self.dKV_postprocess): assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA" assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA" @@ -429,21 +432,30 @@ def __call__( ) ] - layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) - mQ, mK, mV, mdO = [utils.select(t, mode=layout_transpose) for 
t in (mQ, mK, mV, mdO)] - LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b) + # (b, s, n, h) --> (s, h, n, b) or (t, n, h) -> (t, h, n) + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ, mdO = [utils.select(t, mode=QO_layout_transpose) for t in (mQ, mdO)] + + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mK, mV = [utils.select(t, mode=KV_layout_transpose) for t in (mK, mV)] + + # (b, n, s) --> (s, n, b) or (n, t) --> (t, n) + LSE_dPsum_dQaccum_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] mLSE, mdPsum, mdQaccum = [ utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) ] - if const_expr(self.qhead_per_kvhead == 1): - layout_dKV_transpose = layout_transpose + + if const_expr(not self.dKV_postprocess): + layout_dKV_transpose = KV_layout_transpose else: layout_dKV_transpose = LSE_dPsum_dQaccum_transpose mdK, mdV = [utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)] - dO_transpose = [1, 0, 2, 3] # (s, h, n, b) --> (h, s, n, b) + # (s, h, n, b) --> (h, s, n, b) or (t, h, n) -> (h, t, b) + dO_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensQ is None) else [1, 0, 2] mdO = utils.select(mdO, mode=dO_transpose) - semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) + # (b, n, block, stage) -> (block, stage, n, b) or (n, block, stage) -> (block, stage, n) + semaphore_transpose = [2, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 2, 0] if const_expr(self.deterministic): assert mdQ_semaphore is not None mdQ_semaphore = utils.select(mdQ_semaphore, mode=semaphore_transpose) @@ -478,7 +490,7 @@ def __call__( self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) self.is_q_do_mcast = self.num_mcast_ctas_b > 1 - if const_expr(self.qhead_per_kvhead == 1): + if const_expr(not self.dKV_postprocess): self.mdK_layout_enum = LayoutEnum.from_tensor(mdK) self.mdV_layout_enum = LayoutEnum.from_tensor(mdV) dK_major_mode = self.mdK_layout_enum.mma_major_mode() @@ -488,7 +500,7 @@ def __call__( if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mdV is wrong") - if const_expr(self.use_tma_store and self.qhead_per_kvhead == 1): + if const_expr(self.use_tma_store and not self.dKV_postprocess): tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp() tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom( tma_copy_op_dKV, @@ -510,7 +522,7 @@ def __call__( tma_atom_dV = None tma_atom_dK = None - if const_expr(self.qhead_per_kvhead == 1): + if const_expr(not self.dKV_postprocess): thr_layout_r2s_dKV = cute.make_ordered_layout((128, 1), order=(1, 0)) # 128 threads val_layout_r2s_dKV = cute.make_ordered_layout( (1, 128 // self.dk_dtype.width), order=(1, 0) @@ -589,28 +601,32 @@ def __call__( self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8 # TileScheduler = SingleTileScheduler - if const_expr(self.deterministic): + if const_expr(mCuSeqlensK is not None or mSeqUsedK is not None): + TileScheduler = SingleTileVarlenScheduler + elif const_expr(self.deterministic): TileScheduler = SingleTileLPTBwdScheduler else: TileScheduler = SingleTileScheduler # reads n_blocks right-to-left self.spt = (self.is_causal or self.is_local) and self.deterministic tile_sched_args = TileSchedulerArguments( - cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), + cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks cute.size(mQ.shape[2]), # 
num_heads = num_query_heads - cute.size(mK.shape[3]), + cute.size(mK.shape[3]) + if const_expr(mCuSeqlensK is None) + else cute.size(mCuSeqlensK.shape[0] - 1), # num_batches 1, # num_splits - cute.size(mQ.shape[0]), # pass seqlen_q for seqlen_k - mQ.shape[1], - mV.shape[1], - total_q=cute.size(mQ.shape[0]), - tile_shape_mn=self.cta_tiler[:2], + cute.size(mQ.shape[0]), # pass seqlen_q or total_q for seqlen_k + mQ.shape[1], # headdim + mV.shape[1], # headdim_v + total_q=cute.size(mK.shape[0]), # pass total_k for total_q + tile_shape_mn=self.cta_tiler[:2], # (tile_n, tile_m) cluster_shape_mn=self.cluster_shape_mnk[:2], - mCuSeqlensQ=None, - mSeqUsedQ=None, - qhead_per_kvhead_packgqa=1, + mCuSeqlensQ=mCuSeqlensK, + mSeqUsedQ=mSeqUsedK, + qhead_per_kvhead_packgqa=1, # pack_gqa disabled for bwd element_size=self.k_dtype.width // 8, - is_persistent=self.is_persistent, + is_persistent=self.is_persistent, # persistent mode not tested lpt=self.spt, ) @@ -718,6 +734,22 @@ class SharedStorage: fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + if const_expr(self.use_block_sparsity or aux_tensors is not None): + assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), ( + "Variable sequence length is not supported yet for blocksparse or aux tensors in bwd" + ) + + # cute.printf("mQ = {}", tma_tensor_Q.layout) + # cute.printf("mK = {}", tma_tensor_K.layout) + # cute.printf("mV = {}", tma_tensor_V.layout) + # cute.printf("mLSE = {}", mLSE.layout) + # cute.printf("mdPsum = {}", mdPsum.layout) + # cute.printf("tma_tensor_dO = {}", tma_tensor_dO.layout) + # cute.printf("mdV = {}", mdV.layout) + # cute.printf("mdK = {}", mdK.layout) + # cute.printf("mdQaccum = {}", mdQaccum.layout) + # cute.printf("grid_dim = {}", grid_dim) + self.kernel( tma_tensor_Q, tma_tensor_K, @@ -733,6 +765,10 @@ class SharedStorage: mdQ_semaphore, mdK_semaphore, mdV_semaphore, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, tma_atom_Q, tma_atom_K, tma_atom_V, @@ -794,6 +830,10 @@ def kernel( mdQ_semaphore: Optional[cute.Tensor], mdK_semaphore: Optional[cute.Tensor], mdV_semaphore: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, @@ -986,7 +1026,7 @@ def kernel( ) sLSE = storage.sLSE.get_tensor(sLSE_layout) sdPsum = storage.sdPsum.get_tensor(sdPsum_layout) - if const_expr(self.qhead_per_kvhead == 1): + if const_expr(not self.dKV_postprocess): sdV = storage.sdO.get_tensor( sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype ) @@ -1054,10 +1094,12 @@ def kernel( SeqlenInfoQK.create, seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0], - mCuSeqlensQ=None, - mCuSeqlensK=None, - mSeqUsedQ=None, - mSeqUsedK=None, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + tile_m=self.tile_m, + tile_n=self.tile_n, ) TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) @@ -1294,12 +1336,22 @@ def load( seqlen, n_block // self.cluster_shape_mnk[0] ) head_idx_kv = head_idx // self.qhead_per_kvhead - mQ_cur = mQ[None, None, head_idx, batch_idx] - mK_cur = mK[None, None, head_idx_kv, batch_idx] - mV_cur = mV[None, None, head_idx_kv, batch_idx] - mdO_cur = mdO[None, None, head_idx, batch_idx] - mLSE_cur = mLSE[None, head_idx, batch_idx] - mPsum_cur = 
mdPsum[None, head_idx, batch_idx] + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] + if const_expr(not seqlen.has_cu_seqlens_q): + mdO_cur = mdO[None, None, head_idx, batch_idx] + else: + mdO_cur = cute.domain_offset((0, seqlen.offset_q), mdO[None, None, head_idx]) + if const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[None, head_idx, batch_idx] + mdPsum_cur = mdPsum[None, head_idx, batch_idx] + else: + padded_offset_q = cute.round_up( + seqlen.offset_q + batch_idx * self.tile_m, self.tile_m + ) + mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[None, head_idx]) + mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[None, head_idx]) gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0)) tSgK = thr_mma_S.partition_A(gK) @@ -1308,7 +1360,7 @@ def load( gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0)) tSgQ = thr_mma_S.partition_B(gQ) gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,)) - gdPsum = cute.local_tile(mPsum_cur, (self.tile_m,), (None,)) + gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,)) gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) tdPgdO = thr_mma_dV.partition_B(gdO) @@ -2289,6 +2341,7 @@ def compute_loop( batch_idx, head_idx, n_block, + seqlen, thr_mma_dV, tdVtdV, mdV_tma_tensor, @@ -2307,6 +2360,7 @@ def compute_loop( batch_idx, head_idx, n_block, + seqlen, thr_mma_dK, tdKtdK, mdK_tma_tensor, @@ -2315,13 +2369,13 @@ def compute_loop( thr_copy_r2s_dKV, pipeline_dKV, consumer_state_dKV, - softmax_scale if const_expr(self.qhead_per_kvhead == 1) else None, + softmax_scale if const_expr(not self.dKV_postprocess) else None, int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdK_semaphore, ) # Zero dK/dV for empty tiles (local attention or block sparsity) # When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile - if const_expr(self.qhead_per_kvhead == 1): + if const_expr(not self.dKV_postprocess): should_zero_dKV = False if const_expr(self.is_local): should_zero_dKV = m_block_min >= m_block_max @@ -2415,7 +2469,15 @@ def dQacc_reduce( m_block_min, m_block_max = block_info.get_m_block_min_max( seqlen, n_block // self.cluster_shape_mnk[0] ) - mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + if const_expr(not seqlen.has_cu_seqlens_q): + mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + else: + padded_offset_q = cute.round_up( + seqlen.offset_q + batch_idx * self.tile_m, self.tile_m + ) + mdQaccum_cur = cute.domain_offset( + (padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx] + ) gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) # (M * K / STAGE, STAGE, _) gdQaccum = cute.flat_divide( @@ -2715,6 +2777,7 @@ def epilogue_dK_or_dV_tma( batch_idx: Int32, head_idx: Int32, n_block: Int32, + seqlen_info, thr_mma: cute.core.ThrMma, tdKVtdKV: cute.Tensor, mdKV: cute.Tensor, @@ -2734,7 +2797,7 @@ def epilogue_dK_or_dV_tma( num_wg = num_compute_threads // 128 leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 - if const_expr(self.qhead_per_kvhead == 1): + if const_expr(not self.dKV_postprocess): sdKV = sdKV[None, None, wg_idx] # (tile_n, 64) for bf16 else: sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32 @@ -2743,7 +2806,7 @@ def epilogue_dK_or_dV_tma( 
tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV) head_idx_kv = head_idx // self.qhead_per_kvhead - if const_expr(self.qhead_per_kvhead == 1): + if const_expr(not self.dKV_postprocess): mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim) gdKV_p = cute.local_tile( mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0) @@ -2753,7 +2816,15 @@ def epilogue_dK_or_dV_tma( gdKV, self.sdKV_epi_tile, (0, None) ) # (tile_n, 64, epi_stage = (hdim / 2) / 64) else: - mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) + if const_expr(not seqlen_info.has_cu_seqlens_k): + mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) + else: + padded_offset_k = cute.round_up( + seqlen_info.offset_k + batch_idx * self.tile_n, self.tile_n + ) + mdKV_cur = cute.domain_offset( + (padded_offset_k * self.tile_hdim,), mdKV[None, head_idx_kv] + ) gdKV_p = cute.local_tile( mdKV_cur, (self.tile_n * self.tile_hdim,), (n_block,) ) # (tile_n * hdim) @@ -2768,7 +2839,7 @@ def epilogue_dK_or_dV_tma( if const_expr(deterministic_KV): mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] - if const_expr(self.qhead_per_kvhead == 1): + if const_expr(not self.dKV_postprocess): tdKVsdKV, tdKVgdKV = cpasync.tma_partition( tma_atom_dKV, 0, # no multicast @@ -2842,7 +2913,7 @@ def epilogue_dK_or_dV_tma( # SMEM -> GMEM if leader_warp: - if const_expr(self.qhead_per_kvhead == 1): + if const_expr(not self.dKV_postprocess): cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, epi_stage]) else: with cute.arch.elect_one(): diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 6a04ec45dfa..f78b77a36f4 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -742,13 +742,15 @@ def _flash_attn_bwd( total_q_rounded_padded = ( (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size ) + # print("total_q_rounded_padded = ", total_q_rounded_padded) dq_accum = torch.empty( num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device ) dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) - if qhead_per_kvhead > 1: + # todo: allow for mha with cu_seqlens_k to skip dK/dV postprocess + if qhead_per_kvhead > 1 or cu_seqlens_k is not None: head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 if cu_seqlens_k is None: seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size @@ -773,6 +775,7 @@ def _flash_attn_bwd( total_k_rounded_padded = ( (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size ) + # print("total_k_rounded_padded = ", total_k_rounded_padded) num_n_blocks = total_k_rounded_padded // n_block_size if cluster_size == 2 and num_n_blocks % cluster_size != 0: total_k_rounded_padded = total_k_rounded_padded + n_block_size @@ -805,7 +808,15 @@ def _flash_attn_bwd( dV_semaphore = None # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum. 
- compile_key_pre = (compute_capability, dtype, head_dim_v, m_block_size, num_threads) + compile_key_pre = ( + compute_capability, + dtype, + head_dim_v, + m_block_size, + num_threads, + cu_seqlens_q is None, + seqused_q is None, + ) if compile_key_pre not in _flash_attn_bwd.compile_cache_pre: o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)] dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ @@ -871,6 +882,10 @@ def _flash_attn_bwd( AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, ) cute_aux_tensors = None else: @@ -904,6 +919,10 @@ def _flash_attn_bwd( mask_mod_hash, num_aux_tensors, use_block_sparsity, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, ) num_threads = 384 if compile_key not in _flash_attn_bwd.compile_cache: @@ -1069,7 +1088,19 @@ def _flash_attn_bwd( num_threads = 256 if compute_capability == 9 else 128 # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16 - compile_key_post = (dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB) + compile_key_post = ( + compute_capability, + dtype, + head_dim, + m_block_size, + num_threads, + AtomLayoutMdQ, + dQ_swapAB, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, + ) if compile_key_post not in _flash_attn_bwd.compile_cache_post: dq_accum_tensor = to_cute_tensor(dq_accum) dq_tensor = to_cute_tensor(dq) @@ -1101,9 +1132,21 @@ def _flash_attn_bwd( current_stream, ) - if qhead_per_kvhead > 1: + if qhead_per_kvhead > 1 or cu_seqlens_k is not None: # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16 - compile_key_post = (dtype, head_dim, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB) + compile_key_post = ( + compute_capability, + dtype, + head_dim, + n_block_size, + num_threads, + AtomLayoutNdKV, + dKV_swapAB, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, + ) if compile_key_post not in _flash_attn_bwd.compile_cache_post: dk_accum_tensor = to_cute_tensor(dk_accum) dk_tensor = to_cute_tensor(dk) @@ -1111,8 +1154,9 @@ def _flash_attn_bwd( to_cute_tensor(t, assumed_align=4) if t is not None else None for t in (cu_seqlens_k, seqused_k) ] + arch = compute_capability * 10 fa_bwd_post = FlashAttentionBackwardPostprocess( - dtype, head_dim, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB + dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( @@ -1134,12 +1178,17 @@ def _flash_attn_bwd( current_stream, ) compile_key_post = ( + compute_capability, dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, ) if compile_key_post not in _flash_attn_bwd.compile_cache_post: dv_accum_tensor = to_cute_tensor(dv_accum) @@ -1148,8 +1197,9 @@ def _flash_attn_bwd( to_cute_tensor(t, assumed_align=4) if t is not None else None for t in (cu_seqlens_k, seqused_k) ] + arch = compute_capability * 10 fa_bwd_post = FlashAttentionBackwardPostprocess( - dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB + dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( @@ -1322,6 +1372,8 @@ def backward(ctx, dout, *args): 
ctx.softmax_scale, ctx.causal, ctx.softcap, + window_size_left=ctx.window_size[0], + window_size_right=ctx.window_size[1], cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, seqused_q=seqused_q, diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index baa38236a78..c656a079a74 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -44,6 +44,8 @@ class SeqlenInfoQK: has_cu_seqlens_k: cutlass.Constexpr[bool] has_seqused_q: cutlass.Constexpr[bool] has_seqused_k: cutlass.Constexpr[bool] + tile_m: cutlass.Constexpr[cutlass.Int32] + tile_n: cutlass.Constexpr[cutlass.Int32] @staticmethod def create( @@ -54,6 +56,8 @@ def create( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, + tile_m: cutlass.Constexpr[cutlass.Int32] = 128, + tile_n: cutlass.Constexpr[cutlass.Int32] = 128, ): offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] @@ -86,25 +90,47 @@ def create( has_cu_seqlens_k, has_seqused_q, has_seqused_k, + tile_m, + tile_n, ) - def offset_batch_Q(self, mQ: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor: + def offset_batch_Q( + self, + mQ: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + ) -> cute.Tensor: """Seqlen must be the first dimension of mQ""" if const_expr(not self.has_cu_seqlens_q): idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) return mQ[idx] else: - offset = ( - self.offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, self.offset_q) + offset_q = ( + self.offset_q + if const_expr(not padded) + else cute.round_up(self.offset_q + batch_idx * self.tile_m, self.tile_m) ) + offset = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, offset_q) idx = (offset,) + (0,) * (cute.rank(mQ) - 1) return cute.domain_offset(idx, mQ) - def offset_batch_K(self, mK: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor: + def offset_batch_K( + self, + mK: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + ) -> cute.Tensor: """Seqlen must be the first dimension of mK""" if const_expr(not self.has_cu_seqlens_k): idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) return mK[idx] else: - idx = (self.offset_k,) + (0,) * (cute.rank(mK) - 1) + offset_k = ( + self.offset_k + if const_expr(not padded) + else cute.round_up(self.offset_k + batch_idx * self.tile_n, self.tile_n) + ) + idx = (offset_k,) + (0,) * (cute.rank(mK) - 1) return cute.domain_offset(idx, mK) From de146720e870ed8ccd44142e9f5ac3db5dfe3c37 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Fri, 26 Dec 2025 06:14:58 +0000 Subject: [PATCH 02/15] fix mha --- flash_attn/cute/flash_bwd_sm100.py | 2 +- flash_attn/cute/interface.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 99de4f9c3aa..07b55a75f94 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -354,7 +354,7 @@ def _setup_smem_layout(self): self.num_epi_stages = max(1, (self.tile_hdim // 2) // self.sdKV_epi_tile[1]) self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages # TODO: dK and dV could have different shapes - if const_expr(self.qhead_per_kvhead == 1): + if const_expr(not self.dKV_postprocess): self.sdKV_layout = 
sm100_utils_basic.make_smem_layout_epi( self.dk_dtype, LayoutEnum.ROW_MAJOR, diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index f78b77a36f4..adf88e493e3 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -750,7 +750,8 @@ def _flash_attn_bwd( lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) # todo: allow for mha with cu_seqlens_k to skip dK/dV postprocess - if qhead_per_kvhead > 1 or cu_seqlens_k is not None: + dKV_postprocess = qhead_per_kvhead > 1 or cu_seqlens_k is not None + if dKV_postprocess: head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 if cu_seqlens_k is None: seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size @@ -932,7 +933,7 @@ def _flash_attn_bwd( dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2) ] - if qhead_per_kvhead > 1: + if dKV_postprocess: dk_accum_tensor, dv_accum_tensor = [ to_cute_tensor(t) for t in (dk_accum, dv_accum) ] @@ -1030,8 +1031,8 @@ def _flash_attn_bwd( lse_log2_tensor, dpsum_tensor, dq_accum_tensor, - dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, - dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, + dk_tensor if not dKV_postprocess else dk_accum_tensor, + dv_tensor if not dKV_postprocess else dv_accum_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, @@ -1068,8 +1069,8 @@ def _flash_attn_bwd( lse_log2, dpsum, dq_accum, - dk if qhead_per_kvhead == 1 else dk_accum, - dv if qhead_per_kvhead == 1 else dv_accum, + dk if not dKV_postprocess else dk_accum, + dv if not dKV_postprocess else dv_accum, softmax_scale, current_stream, cu_seqlens_q, @@ -1132,7 +1133,7 @@ def _flash_attn_bwd( current_stream, ) - if qhead_per_kvhead > 1 or cu_seqlens_k is not None: + if dKV_postprocess: # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16 compile_key_post = ( compute_capability, From 77303683d760b16227e174635eb213b89331481c Mon Sep 17 00:00:00 2001 From: jayhshah Date: Fri, 26 Dec 2025 06:45:45 +0000 Subject: [PATCH 03/15] change offset mode to round down multiple --- flash_attn/cute/flash_bwd_postprocess.py | 4 +-- flash_attn/cute/flash_bwd_preprocess.py | 46 +++++++++++++----------- flash_attn/cute/flash_bwd_sm100.py | 12 +++---- flash_attn/cute/seqlen_info.py | 4 +-- 4 files changed, 36 insertions(+), 30 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 8a4ed74e776..5211fc3b7ae 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -341,8 +341,8 @@ def kernel( mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] head_dim = mdQ.shape[3] else: - padded_offset_q = cute.round_up( - seqlen.offset_q + batch_idx * self.tile_m, self.tile_m + padded_offset_q = ( + (seqlen.offset_q + batch_idx * self.tile_m) // self.tile_m * self.tile_m ) mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None]) mdQaccum_cur = cute.domain_offset( diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index 0596e89d057..faee1da055f 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -213,14 +213,14 @@ def kernel( tile_scheduler = TileScheduler.create(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() - m_block, num_head, batch_size, _ = work_tile.tile_idx + m_block, head_idx, batch_idx, _ = work_tile.tile_idx if 
work_tile.is_valid_tile: # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// seqlen = SeqlenInfoQK.create( - batch_size, + batch_idx, mO.shape[1], 0, mCuSeqlensQ=mCuSeqlensQ, @@ -230,18 +230,20 @@ def kernel( ) if cutlass.const_expr(not seqlen.has_cu_seqlens_q): - mO_cur = mO[batch_size, None, num_head, None] - mdO_cur = mdO[batch_size, None, num_head, None] - mdPsum_cur = mdPsum[batch_size, num_head, None] + mO_cur = mO[batch_idx, None, head_idx, None] + mdO_cur = mdO[batch_idx, None, head_idx, None] + mdPsum_cur = mdPsum[batch_idx, head_idx, None] headdim_v = mO.shape[3] else: - mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, num_head, None]) - mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, num_head, None]) + mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None]) + mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None]) - padded_offset_q = cute.round_up( - seqlen.offset_q + batch_size * self.m_block_size, self.m_block_size + padded_offset_q = ( + (seqlen.offset_q + batch_idx * self.m_block_size) + // self.m_block_size + * self.m_block_size ) - mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[num_head, None]) + mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None]) headdim_v = mO.shape[2] blkOdO_shape = (self.m_block_size, self.head_dim_padded) @@ -270,9 +272,9 @@ def kernel( if cutlass.const_expr(mLSE is not None): if cutlass.const_expr(not seqlen.has_cu_seqlens_q): - mLSE_cur = mLSE[batch_size, num_head, None] + mLSE_cur = mLSE[batch_idx, head_idx, None] else: - mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[num_head, None]) + mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[head_idx, None]) gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,)) lse = Float32.inf @@ -325,13 +327,15 @@ def kernel( # Clear dQaccum if cutlass.const_expr(mdQaccum is not None): if cutlass.const_expr(not seqlen.has_cu_seqlens_q): - mdQaccum_cur = mdQaccum[batch_size, num_head, None] + mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] else: - padded_offset_q = cute.round_up( - seqlen.offset_q + batch_size * self.m_block_size, self.m_block_size + padded_offset_q = ( + (seqlen.offset_q + batch_idx * self.m_block_size) + // self.m_block_size + * self.m_block_size ) mdQaccum_cur = cute.domain_offset( - (padded_offset_q * self.head_dim_padded,), mdQaccum[num_head, None] + (padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None] ) # HACK: Compiler doesn't seem to recognize that padding @@ -356,12 +360,14 @@ def kernel( if cutlass.const_expr(mLSE is not None): if cutlass.const_expr(not seqlen.has_cu_seqlens_q): - mLSElog2_cur = mLSElog2[batch_size, num_head, None] + mLSElog2_cur = mLSElog2[batch_idx, head_idx, None] else: - padded_offset_q = cute.round_up( - seqlen.offset_q + batch_size * self.m_block_size, self.m_block_size + padded_offset_q = ( + (seqlen.offset_q + batch_idx * self.m_block_size) + // self.m_block_size + * self.m_block_size ) - mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[num_head, None]) + mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[head_idx, None]) gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,)) LOG2_E = math.log2(math.e) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 07b55a75f94..4e77df02b57 
100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1347,8 +1347,8 @@ def load( mLSE_cur = mLSE[None, head_idx, batch_idx] mdPsum_cur = mdPsum[None, head_idx, batch_idx] else: - padded_offset_q = cute.round_up( - seqlen.offset_q + batch_idx * self.tile_m, self.tile_m + padded_offset_q = ( + (seqlen.offset_q + batch_idx * self.tile_m) // self.tile_m * self.tile_m ) mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[None, head_idx]) mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[None, head_idx]) @@ -2472,8 +2472,8 @@ def dQacc_reduce( if const_expr(not seqlen.has_cu_seqlens_q): mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] else: - padded_offset_q = cute.round_up( - seqlen.offset_q + batch_idx * self.tile_m, self.tile_m + padded_offset_q = ( + (seqlen.offset_q + batch_idx * self.tile_m) // self.tile_m * self.tile_m ) mdQaccum_cur = cute.domain_offset( (padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx] @@ -2819,8 +2819,8 @@ def epilogue_dK_or_dV_tma( if const_expr(not seqlen_info.has_cu_seqlens_k): mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) else: - padded_offset_k = cute.round_up( - seqlen_info.offset_k + batch_idx * self.tile_n, self.tile_n + padded_offset_k = ( + (seqlen_info.offset_k + batch_idx * self.tile_n) // self.tile_n * self.tile_n ) mdKV_cur = cute.domain_offset( (padded_offset_k * self.tile_hdim,), mdKV[None, head_idx_kv] diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index c656a079a74..426fb4e7cb6 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -109,7 +109,7 @@ def offset_batch_Q( offset_q = ( self.offset_q if const_expr(not padded) - else cute.round_up(self.offset_q + batch_idx * self.tile_m, self.tile_m) + else (self.offset_q + batch_idx * self.tile_m) // self.tile_m * self.tile_m ) offset = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, offset_q) idx = (offset,) + (0,) * (cute.rank(mQ) - 1) @@ -130,7 +130,7 @@ def offset_batch_K( offset_k = ( self.offset_k if const_expr(not padded) - else cute.round_up(self.offset_k + batch_idx * self.tile_n, self.tile_n) + else (self.offset_k + batch_idx * self.tile_n) // self.tile_n * self.tile_n ) idx = (offset_k,) + (0,) * (cute.rank(mK) - 1) return cute.domain_offset(idx, mK) From 30c6c22b47718ed9ea5872e14f441ccabb328e7f Mon Sep 17 00:00:00 2001 From: jayhshah Date: Fri, 26 Dec 2025 06:46:58 +0000 Subject: [PATCH 04/15] enable varlen bwd tests --- tests/cute/test_flash_attn.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index b2809ab61ec..10e8827d6f0 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -236,7 +236,7 @@ def test_flash_attn_output( print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # num_splits_vals = [1, 3] - pack_gqa_vals = [False, True, None] + pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False] # SplitKV is not supported for hdim >= 192 # pack_gqa_vals = [False] num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] @@ -371,9 +371,9 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# 
@pytest.mark.parametrize("mha_type", ["mqa"]) -@pytest.mark.parametrize("has_learnable_sink", [False, True]) -# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_learnable_sink", [False, True]) +@pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @@ -383,7 +383,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @pytest.mark.parametrize("add_unused_qkv", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -419,6 +419,7 @@ def test_flash_attn_output( (2048, 2048), ], ) +@pytest.mark.parametrize("varlen_mode", ["random", "third", "full"]) def test_flash_attn_varlen_output( seqlen_q, seqlen_k, @@ -432,6 +433,7 @@ def test_flash_attn_varlen_output( has_learnable_sink, mha_type, dtype, + varlen_mode, ): if ( causal or local @@ -442,13 +444,12 @@ def test_flash_attn_varlen_output( torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) batch_size = 49 if seqlen_q <= 1024 else 7 nheads = 6 - # batch_size = 1 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) - if dtype == torch.float8_e4m3fn: + if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] attention_chunk_vals = [0] @@ -505,7 +506,11 @@ def test_flash_attn_varlen_output( q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] qv = qv_ref.detach() if has_qv else None query_padding_mask = generate_random_padding_mask( - seqlen_q, batch_size, device, mode="random", zero_lengths=False + seqlen_q, + batch_size, + device, + mode=varlen_mode, + zero_lengths=False ) # TODO: test zero_lengths key_padding_mask = generate_random_padding_mask( @@ -513,7 +518,7 @@ def test_flash_attn_varlen_output( seqlen_k, batch_size, device, - mode="random", + mode=varlen_mode, zero_lengths=False, ) @@ -570,6 +575,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask, ) + print("cu_seqlens_q = ", cu_seqlens_q) + print("cu_seqlens_k = ", cu_seqlens_k) q_unpad, k_unpad, v_unpad = [ x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) ] @@ -619,11 +626,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() rtol = 2 if softcap == 0.0 else 3 - pack_gqa_vals = [False, True, None] + pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False] # pack_gqa_vals = [False] # num_splits_vals = [1, 3] # SplitKV is not supported for hdim >= 192 - num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] for pack_gqa, 
num_splits in itertools.product(pack_gqa_vals, num_splits_vals): # SplitKV not supported on SM90 - skip this iteration if IS_SM90 and num_splits > 1: @@ -670,7 +677,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): and not attention_chunk != 0 and dv == d and not has_learnable_sink - and False + # and False ): g_unpad = torch.randn_like(out_unpad) do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) From 107bed6f51dc84de6866ecc1fc62ff654f1b7b4c Mon Sep 17 00:00:00 2001 From: jayhshah Date: Fri, 26 Dec 2025 07:18:13 +0000 Subject: [PATCH 05/15] enable deterministic mode --- flash_attn/cute/flash_bwd_sm100.py | 11 ++++++++--- flash_attn/cute/interface.py | 20 ++++++++++++++++++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 4e77df02b57..f9dfc960d16 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -454,8 +454,8 @@ def __call__( dO_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensQ is None) else [1, 0, 2] mdO = utils.select(mdO, mode=dO_transpose) - # (b, n, block, stage) -> (block, stage, n, b) or (n, block, stage) -> (block, stage, n) - semaphore_transpose = [2, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 2, 0] + # (b, n, block, stage) -> (block, stage, n, b) + semaphore_transpose = [2, 3, 1, 0] if const_expr(self.deterministic): assert mdQ_semaphore is not None mdQ_semaphore = utils.select(mdQ_semaphore, mode=semaphore_transpose) @@ -608,7 +608,12 @@ def __call__( else: TileScheduler = SingleTileScheduler # reads n_blocks right-to-left - self.spt = (self.is_causal or self.is_local) and self.deterministic + self.spt = ( + (self.is_causal or self.is_local) + and self.deterministic + and mCuSeqlensK is None + and mSeqUsedK is None + ) tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks cute.size(mQ.shape[2]), # num_heads = num_query_heads diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index adf88e493e3..858b1c223df 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -569,6 +569,8 @@ def _flash_attn_bwd( cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, deterministic: bool = False, dq: Optional[torch.Tensor] = None, dk: Optional[torch.Tensor] = None, @@ -615,7 +617,7 @@ def _flash_attn_bwd( total_q = batch_size * seqlen_q else: batch_size = cu_seqlens_q.shape[0] - 1 - seqlen_q = None + seqlen_q = max_seqlen_q total_q = q.shape[0] if cu_seqlens_k is None: @@ -623,7 +625,7 @@ def _flash_attn_bwd( total_k = batch_size * seqlen_k else: batch_size = cu_seqlens_k.shape[0] - 1 - seqlen_k = None + seqlen_k = max_seqlen_k total_k = k.shape[0] num_head_kv = k.shape[-2] @@ -797,11 +799,15 @@ def _flash_attn_bwd( current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) if deterministic: + assert seqlen_q is not None, "seqlen_q not provided" + seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device="cuda") else: dQ_semaphore = None if deterministic and qhead_per_kvhead > 1: + assert seqlen_k is not None, "seqlen_k not provided" + seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * 
n_block_size dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") else: @@ -1314,6 +1320,8 @@ def forward( cu_seqlens_k: Optional[torch.Tensor], seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -1356,6 +1364,8 @@ def forward( ctx.window_size = window_size ctx.softcap = softcap ctx.deterministic = deterministic + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k return out, lse @staticmethod @@ -1379,6 +1389,8 @@ def backward(ctx, dout, *args): cu_seqlens_k=cu_seqlens_k, seqused_q=seqused_q, seqused_k=seqused_k, + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_k=ctx.max_seqlen_k, deterministic=ctx.deterministic, ) @@ -1431,6 +1443,8 @@ def flash_attn_varlen_func( cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -1453,6 +1467,8 @@ def flash_attn_varlen_func( cu_seqlens_k, seqused_q, seqused_k, + max_seqlen_q, + max_seqlen_k, page_table, softmax_scale, causal, From 34a72acd96fd301394e8f0e8da923796ca41196e Mon Sep 17 00:00:00 2001 From: jayhshah Date: Thu, 8 Jan 2026 08:03:33 +0000 Subject: [PATCH 06/15] fix deadlock and switch mha to no postprocess --- flash_attn/cute/flash_bwd_sm100.py | 63 +-- flash_attn/cute/interface.py | 11 +- flash_attn/cute/testing.py | 7 +- tests/cute/test_flash_attn.py | 114 +++-- tests/cute/test_flash_attn_race_condition.py | 435 +++++++++++++++++++ 5 files changed, 549 insertions(+), 81 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index f9dfc960d16..2aa49342ebd 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -100,7 +100,6 @@ def __init__( self.is_local = is_local self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False - self.use_tma_store = True self.deterministic = deterministic # Score mod and mask mod support @@ -392,9 +391,6 @@ def __call__( # Block-sparse tensors (Q direction - for iterating m_blocks per n_block): blocksparse_tensors: Optional[BlockSparseTensors] = None, ): - # assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), ( - # "Variable sequence length is not supported yet in FlashAttentionBackwardSm100" - # ) self.q_dtype = mQ.element_type self.k_dtype = mK.element_type self.v_dtype = mV.element_type @@ -406,7 +402,10 @@ def __call__( self.dv_dtype = mdV.element_type self.ds_dtype = self.q_dtype - self.dKV_postprocess = self.qhead_per_kvhead > 1 or const_expr(mCuSeqlensK is not None) + self.is_varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None + self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None + self.use_tma_store = not (self.qhead_per_kvhead == 1 and self.is_varlen_k) + self.dKV_postprocess = self.qhead_per_kvhead > 1 if const_expr(self.dKV_postprocess): assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA" @@ -601,19 +600,14 @@ def __call__( self.tma_copy_bytes["dKacc"] = self.tile_n * 
self.dK_reduce_ncol * Float32.width // 8 # TileScheduler = SingleTileScheduler - if const_expr(mCuSeqlensK is not None or mSeqUsedK is not None): + if const_expr(self.is_varlen_k): TileScheduler = SingleTileVarlenScheduler elif const_expr(self.deterministic): TileScheduler = SingleTileLPTBwdScheduler else: TileScheduler = SingleTileScheduler # reads n_blocks right-to-left - self.spt = ( - (self.is_causal or self.is_local) - and self.deterministic - and mCuSeqlensK is None - and mSeqUsedK is None - ) + self.spt = (self.is_causal or self.is_local) and self.deterministic and not self.is_varlen_k tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks cute.size(mQ.shape[2]), # num_heads = num_query_heads @@ -1420,7 +1414,10 @@ def load( ) process_tile = total_m_block_cnt > Int32(0) else: - process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) if process_tile: if const_expr(self.use_block_sparsity): @@ -1673,7 +1670,10 @@ def mma( process_tile = block_iter_count > Int32(0) else: block_iter_count = m_block_max - m_block_min - process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) if process_tile: accumulate_dK = False @@ -2112,7 +2112,10 @@ def compute_loop( ) process_tile = loop_count > Int32(0) else: - process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) loop_count = m_block_max - m_block_min # Mainloop @@ -2328,6 +2331,7 @@ def compute_loop( batch_idx, head_idx, n_block, + seqlen, thr_mma_dV, thr_mma_dK, tdVtdV, @@ -2382,7 +2386,7 @@ def compute_loop( # When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile if const_expr(not self.dKV_postprocess): should_zero_dKV = False - if const_expr(self.is_local): + if const_expr(self.is_local or seqlen.has_cu_seqlens_q): should_zero_dKV = m_block_min >= m_block_max if const_expr(self.use_block_sparsity): # For block sparsity, zero when no m_blocks contribute to this n_block @@ -2397,8 +2401,8 @@ def compute_loop( 128, # num_threads ) gmem_thr_copy_zero_dKV = gmem_tiled_copy_zero_dKV.get_slice(dp_idx) - mdV_cur = mdV[None, None, head_idx, batch_idx] - mdK_cur = mdK[None, None, head_idx, batch_idx] + mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx] + mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx] gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) tdKgdK = gmem_thr_copy_zero_dKV.partition_D(gdK) @@ -2513,7 +2517,10 @@ def dQacc_reduce( ) process_tile = loop_count > Int32(0) else: - process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) loop_count = m_block_max - m_block_min # dQacc_reduce mainloop @@ -2647,6 +2654,7 @@ def epilogue_dKV( batch_idx: Int32, head_idx: Int32, n_block: Int32, + seqlen, thr_mma_dV: cute.core.ThrMma, thr_mma_dK: cute.core.ThrMma, tdVtdV: cute.Tensor, @@ -2663,8 +2671,8 @@ def epilogue_dKV( num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128 assert 
self.qhead_per_kvhead == 1, "This epilogue path is only for MHA" - mdV_cur = mdV[None, None, head_idx, batch_idx] - mdK_cur = mdK[None, None, head_idx, batch_idx] + mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx] + mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx] tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32 @@ -2714,7 +2722,8 @@ def epilogue_dKV( tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV) tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg) - cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g) + if tidx < seqlen.seqlen_k - self.tile_n * n_block: + cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g) cute.arch.sync_warp() with cute.arch.elect_one(): @@ -2767,7 +2776,8 @@ def epilogue_dKV( tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK) tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg) - cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g) + if tidx < seqlen.seqlen_k - self.tile_n * n_block: + cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g) cute.arch.sync_warp() with cute.arch.elect_one(): @@ -2782,7 +2792,7 @@ def epilogue_dK_or_dV_tma( batch_idx: Int32, head_idx: Int32, n_block: Int32, - seqlen_info, + seqlen, thr_mma: cute.core.ThrMma, tdKVtdKV: cute.Tensor, mdKV: cute.Tensor, @@ -2812,6 +2822,7 @@ def epilogue_dK_or_dV_tma( head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.dKV_postprocess): + assert not seqlen.has_cu_seqlens_k, "varlen uses non tma store path" mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim) gdKV_p = cute.local_tile( mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0) @@ -2821,11 +2832,11 @@ def epilogue_dK_or_dV_tma( gdKV, self.sdKV_epi_tile, (0, None) ) # (tile_n, 64, epi_stage = (hdim / 2) / 64) else: - if const_expr(not seqlen_info.has_cu_seqlens_k): + if const_expr(not seqlen.has_cu_seqlens_k): mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) else: padded_offset_k = ( - (seqlen_info.offset_k + batch_idx * self.tile_n) // self.tile_n * self.tile_n + (seqlen.offset_k + batch_idx * self.tile_n) // self.tile_n * self.tile_n ) mdKV_cur = cute.domain_offset( (padded_offset_k * self.tile_hdim,), mdKV[None, head_idx_kv] diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 858b1c223df..5e95690e505 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -617,16 +617,16 @@ def _flash_attn_bwd( total_q = batch_size * seqlen_q else: batch_size = cu_seqlens_q.shape[0] - 1 - seqlen_q = max_seqlen_q total_q = q.shape[0] + seqlen_q = max_seqlen_q if max_seqlen_q is not None else total_q if cu_seqlens_k is None: batch_size, seqlen_k = k.shape[:2] total_k = batch_size * seqlen_k else: batch_size = cu_seqlens_k.shape[0] - 1 - seqlen_k = max_seqlen_k total_k = k.shape[0] + seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k num_head_kv = k.shape[-2] head_dim_v = v.shape[-1] @@ -744,15 +744,13 @@ def _flash_attn_bwd( total_q_rounded_padded = ( (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size ) - # print("total_q_rounded_padded = ", total_q_rounded_padded) dq_accum = torch.empty( num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device ) dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) - # todo: allow for mha with cu_seqlens_k to 
skip dK/dV postprocess - dKV_postprocess = qhead_per_kvhead > 1 or cu_seqlens_k is not None + dKV_postprocess = qhead_per_kvhead > 1 if dKV_postprocess: head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 if cu_seqlens_k is None: @@ -778,7 +776,6 @@ def _flash_attn_bwd( total_k_rounded_padded = ( (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size ) - # print("total_k_rounded_padded = ", total_k_rounded_padded) num_n_blocks = total_k_rounded_padded // n_block_size if cluster_size == 2 and num_n_blocks % cluster_size != 0: total_k_rounded_padded = total_k_rounded_padded + n_block_size @@ -799,14 +796,12 @@ def _flash_attn_bwd( current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) if deterministic: - assert seqlen_q is not None, "seqlen_q not provided" seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device="cuda") else: dQ_semaphore = None if deterministic and qhead_per_kvhead > 1: - assert seqlen_k is not None, "seqlen_k not provided" seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") diff --git a/flash_attn/cute/testing.py b/flash_attn/cute/testing.py index a23a624d059..2897e64fc3d 100644 --- a/flash_attn/cute/testing.py +++ b/flash_attn/cute/testing.py @@ -92,7 +92,12 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", device=device, ) else: - lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + lengths = torch.randint( + max(0 if zero_lengths else 1, max_seqlen // 3), + max_seqlen + 1, + (batch_size, 1), + device=device, + ) if zero_lengths: for i in range(batch_size): diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 10e8827d6f0..d6c5029642d 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -37,20 +37,20 @@ # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# @pytest.mark.parametrize("mha_type", ["mha"]) -@pytest.mark.parametrize("has_learnable_sink", [False, True]) -# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_learnable_sink", [False, True]) +@pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) -# @pytest.mark.parametrize("local_enum", [0]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +@pytest.mark.parametrize("local_enum", [0]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("causal", [False]) # 
@pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) @@ -60,34 +60,35 @@ # @pytest.mark.parametrize("d", [64, 96, 128, 192]) # @pytest.mark.parametrize("d", [64, 128]) # @pytest.mark.parametrize("d", [128, 192]) -@pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ - (1, 1), - (3, 3), - (64, 32), - (64, 128), - (128, 128), - (128, 192), - (256, 256), - (239, 1), - (799, 3), - (113, 203), - (113, 128), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (384, 256), - (640, 128), - (512, 256), + # (1, 1), + # (3, 3), + # (64, 32), + # (64, 128), + # (128, 128), + # (128, 192), + # (256, 256), + # (239, 1), + # (799, 3), + # (113, 203), + # (113, 128), + # (128, 217), + # (113, 211), + # (108, 256), + # (256, 512), + # (384, 256), + # (640, 128), + # (512, 256), (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - (4096, 4096), - (4224, 4224), + # (1023, 1024), + # (1024, 1023), + # (2048, 2048), + # (4096, 4096), + # (4224, 4224), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) @@ -112,8 +113,8 @@ def test_flash_attn_output( torch.random.manual_seed(0) torch.cuda.empty_cache() torch.cuda.synchronize() - batch_size = 9 if seqlen_k <= 2048 else 2 - # batch_size = 1 + # batch_size = 9 if seqlen_k <= 2048 else 2 + batch_size = 2 nheads = 6 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) @@ -372,18 +373,18 @@ def test_flash_attn_output( @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_learnable_sink", [False, True]) -@pytest.mark.parametrize("has_learnable_sink", [False]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local", [False, True]) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @pytest.mark.parametrize("add_unused_qkv", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -420,13 +421,23 @@ def test_flash_attn_output( ], ) @pytest.mark.parametrize("varlen_mode", ["random", "third", "full"]) +# @pytest.mark.parametrize("varlen_mode", ["full"]) +@pytest.mark.parametrize( + "zero_lengths_q, zero_lengths_k", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) def test_flash_attn_varlen_output( seqlen_q, seqlen_k, d, add_unused_qkv, causal, - local, + local_enum, softcap, deterministic, has_qv, @@ -434,7 +445,12 @@ def test_flash_attn_varlen_output( mha_type, dtype, varlen_mode, + zero_lengths_q, + zero_lengths_k, ): + local = local_enum > 0 + if local and 
causal: + pytest.skip() if ( causal or local ): # Right now reference only supports causal attention with seqlen_k == seqlen_q @@ -491,6 +507,12 @@ def test_flash_attn_varlen_output( window_size = ( (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() ) + if local_enum == 2: + window_size = (None, window_size[1]) + elif local_enum == 3: + window_size = (window_size[0], None) + if local: + print("window size = ", window_size) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) else: @@ -510,18 +532,15 @@ def test_flash_attn_varlen_output( batch_size, device, mode=varlen_mode, - zero_lengths=False + zero_lengths=zero_lengths_q, ) - # TODO: test zero_lengths key_padding_mask = generate_random_padding_mask( - # seqlen_k, batch_size, device, mode="random", zero_lengths=True seqlen_k, batch_size, device, mode=varlen_mode, - zero_lengths=False, + zero_lengths=zero_lengths_k, ) - def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): if add_unused: another_mask = generate_random_padding_mask(max_seq_len, bs, device) @@ -644,6 +663,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # max_seqlen_k, # seqused_q=seqused_q, # seqused_k=seqused_k, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, causal=causal, # qv=qv_unpad, # q_descale=q_descale, @@ -654,6 +675,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, + deterministic=deterministic, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: @@ -680,7 +702,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # and False ): g_unpad = torch.randn_like(out_unpad) - do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) # import flash_attn_3_cuda # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( # g_unpad, diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py index 0174040687f..6edf3565c8d 100644 --- a/tests/cute/test_flash_attn_race_condition.py +++ b/tests/cute/test_flash_attn_race_condition.py @@ -342,3 +342,438 @@ def test_flash_attn_output( print(f"✅ Iteration {i} passed!") + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["gqa"]) +# @pytest.mark.parametrize("has_learnable_sink", [False, True]) +@pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [True]) +# @pytest.mark.parametrize("softcap", [0.0, 15.0]) +@pytest.mark.parametrize("softcap", [0.0]) +# @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +@pytest.mark.parametrize("local_enum", [0, 1]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("add_unused_qkv", [False, True]) +@pytest.mark.parametrize("add_unused_qkv", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# 
@pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +# @pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1024, 1024), + (2048, 2048), + ], +) +@pytest.mark.parametrize("varlen_mode", ["random", "third", "full"]) +# @pytest.mark.parametrize("varlen_mode", ["random"]) +@pytest.mark.parametrize( + "zero_lengths_q, zero_lengths_k", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, + seqlen_k, + d, + add_unused_qkv, + causal, + local_enum, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, + varlen_mode, + zero_lengths_q, + zero_lengths_k, +): + local = local_enum > 0 + if local and causal: + pytest.skip() + if ( + causal or local + ): # Right now reference only supports causal attention with seqlen_k == seqlen_q + seqlen_k = seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + batch_size = 49 if seqlen_q <= 1024 else 7 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + dv_vals = [d] # override + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
+ q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + if local_enum == 2: + window_size = (None, window_size[1]) + elif local_enum == 3: + window_size = (window_size[0], None) + if local: + print("window size = ", window_size) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, + batch_size, + device, + mode=varlen_mode, + zero_lengths=zero_lengths_q, + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, + batch_size, + device, + mode=varlen_mode, + zero_lengths=zero_lengths_k, + ) + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + # query_padding_mask[:] = True + # query_unused_mask = None + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + if causal or local: + key_padding_mask = query_padding_mask + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + qv=qv, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + print("cu_seqlens_q = ", cu_seqlens_q) + print("cu_seqlens_k = ", cu_seqlens_k) + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + 
qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + out_unpad, lse = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + # max_seqlen_k, + # seqused_q=seqused_q, + # seqused_k=seqused_k, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + causal=causal, + # qv=qv_unpad, + # q_descale=q_descale, + # k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + num_splits=1, + pack_gqa=False, + deterministic=deterministic, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + if ( + dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and dv == d + and not has_learnable_sink + # and False + ): + g_unpad = torch.randn_like(out_unpad) + # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # import flash_attn_3_cuda + # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( + # g_unpad, + # q_unpad, + # k_unpad, + # v_unpad, + # out_unpad, + # lse, + # None, + # None, + # None, + # cu_seqlens_q, + # cu_seqlens_k, + # None, None, + # max_seqlen_q, + # max_seqlen_k, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( + out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad + ) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = 
torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + num_iters = 10_000 + + for i in range(num_iters): + dq_unpad2, dk_unpad2, dv_unpad2 = _flash_attn_bwd( + q_unpad, k_unpad, v_unpad, out_unpad, g_unpad, lse, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + deterministic=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + ) + + diff_dq = (dq_unpad - dq_unpad2).abs() + max_idx = diff_dq.argmax() + print(f"dQ max diff: {diff_dq.max().item()}") + print(f" at index {max_idx.item()}: dQ={dq_unpad.flatten()[max_idx].item()}, dQ2={dq_unpad2.flatten()[max_idx].item()}") + + diff_dk = (dk_unpad - dk_unpad2).abs() + max_idx = diff_dk.argmax() + print(f"dK max diff: {diff_dk.max().item()}") + print(f" at index {max_idx.item()}: dK={dk_unpad.flatten()[max_idx].item()}, dK2={dk_unpad2.flatten()[max_idx].item()}") + + diff_dv = (dv_unpad - dv_unpad2).abs() + max_idx = diff_dv.argmax() + print(f"dV max diff: {diff_dv.max().item()}") + print(f" at index {max_idx.item()}: dV={dv_unpad.flatten()[max_idx].item()}, dV2={dv_unpad2.flatten()[max_idx].item()}") + + # print(f"dQ max diff with myself: {(dq - dq2).abs().max().item()}") + # print(f"dK max diff with myself: {(dk - dk2).abs().max().item()}") + # print(f"dV max diff with myself: {(dv - dv2).abs().max().item()}") + # print(f"dQ mean diff with myself: {(dq - dq2).abs().mean().item()}") + # print(f"dK mean diff with myself: {(dk - dk2).abs().mean().item()}") + # print(f"dV mean diff with myself: {(dv - dv2).abs().mean().item()}") + + assert torch.equal(dq_unpad, dq_unpad2) + assert torch.equal(dk_unpad, dk_unpad2) + assert torch.equal(dv_unpad, dv_unpad2) + + print(f"✅ Iteration {i} passed!") \ No newline at end of 
file From 09c21b07f3cf48cb54721dafdbdb690a237d02f0 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Thu, 8 Jan 2026 08:13:34 +0000 Subject: [PATCH 07/15] reenable tests --- tests/cute/test_flash_attn.py | 67 +++++++++++++++++------------------ 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index d6c5029642d..382c3460534 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -37,20 +37,20 @@ # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_learnable_sink", [False, True]) -@pytest.mark.parametrize("has_learnable_sink", [False]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) -@pytest.mark.parametrize("local_enum", [0]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) @@ -58,37 +58,36 @@ # @pytest.mark.parametrize("d", [64, 128, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) -# @pytest.mark.parametrize("d", [64, 128]) # @pytest.mark.parametrize("d", [128, 192]) -# @pytest.mark.parametrize("d", [64, 128]) -@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ - # (1, 1), - # (3, 3), - # (64, 32), - # (64, 128), - # (128, 128), - # (128, 192), - # (256, 256), - # (239, 1), - # (799, 3), - # (113, 203), - # (113, 128), - # (128, 217), - # (113, 211), - # (108, 256), - # (256, 512), - # (384, 256), - # (640, 128), - # (512, 256), + (1, 1), + (3, 3), + (64, 32), + (64, 128), + (128, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), (1024, 1024), - # (1023, 1024), - # (1024, 1023), - # (2048, 2048), - # (4096, 4096), - # (4224, 4224), + (1023, 1024), + (1024, 1023), + (2048, 2048), + (4096, 4096), + (4224, 4224), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) From abfe1f0e13988ab825db0807c64b778598f52faf Mon Sep 17 00:00:00 2001 From: jayhshah Date: Thu, 8 Jan 2026 08:14:56 +0000 Subject: [PATCH 08/15] --- tests/cute/test_flash_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cute/test_flash_attn.py 
b/tests/cute/test_flash_attn.py index 382c3460534..06f0e45d3f5 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -112,8 +112,8 @@ def test_flash_attn_output( torch.random.manual_seed(0) torch.cuda.empty_cache() torch.cuda.synchronize() - # batch_size = 9 if seqlen_k <= 2048 else 2 - batch_size = 2 + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 2 nheads = 6 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) From 3dda1053080b20732297dea08377f334420bccb0 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Thu, 8 Jan 2026 08:21:27 +0000 Subject: [PATCH 09/15] fix lint error --- flash_attn/cute/cute_dsl_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py index 6673b155dc4..9d6ee345d00 100644 --- a/flash_attn/cute/cute_dsl_utils.py +++ b/flash_attn/cute/cute_dsl_utils.py @@ -123,6 +123,7 @@ def cute_compile_patched(*args, **kwargs): pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) return output + def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True): """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1.""" tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi) From d7eeffca13b2d468f0858bdc5a6cf16e5758e5d5 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Thu, 8 Jan 2026 17:27:16 +0000 Subject: [PATCH 10/15] use head swizzle/spt for deterministic, update tests --- benchmarks/benchmark_attn.py | 11 ++++-- flash_attn/cute/flash_bwd_sm100.py | 3 +- flash_attn/cute/tile_scheduler.py | 8 +++- tests/cute/test_flash_attn_race_condition.py | 41 +++++++++----------- 4 files changed, 34 insertions(+), 29 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index cb6bc44eae2..6158eddc174 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -325,9 +325,9 @@ def run(*args, **kwargs): else: page_table = None - # for causal in [False, True]: - for causal in [True]: - print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") + for causal in [False, True]: + # for causal in [True]: + print(f"\n### {headdim = }, {causal = }, {seqlen = }, {batch_size = }, {nheads = }, {nheads_kv = }, {varlen = }, {deterministic = } ###") nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) if cudnn is not None: # if False: @@ -395,7 +395,10 @@ def run(*args, **kwargs): # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True) # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy') if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward: - _, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav2 python') + if not varlen: + _, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python') + else: + _, m1b_py = benchmark_backward(flash_attn_varlen_func_python, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, 
verbose=False, desc='Fav4 python') if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None: # if False: diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 2aa49342ebd..91e70a91769 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -607,7 +607,7 @@ def __call__( else: TileScheduler = SingleTileScheduler # reads n_blocks right-to-left - self.spt = (self.is_causal or self.is_local) and self.deterministic and not self.is_varlen_k + self.spt = (self.is_causal or self.is_local) and self.deterministic tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks cute.size(mQ.shape[2]), # num_heads = num_query_heads @@ -627,6 +627,7 @@ def __call__( element_size=self.k_dtype.width // 8, is_persistent=self.is_persistent, # persistent mode not tested lpt=self.spt, + head_swizzle=self.deterministic, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index ef47cedecdf..36a5c6b75ec 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -72,6 +72,7 @@ class TileSchedulerArguments(ParamsBase): is_persistent: cutlass.Constexpr[bool] = False lpt: cutlass.Constexpr[bool] = False is_split_kv: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False class SingleTileScheduler: @@ -512,6 +513,7 @@ class Params(ParamsBase): qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 lpt: cutlass.Constexpr[bool] = False is_split_kv: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False @staticmethod @cute.jit @@ -537,6 +539,7 @@ def create( qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, lpt=args.lpt, is_split_kv=args.is_split_kv, + head_swizzle=args.head_swizzle, ) def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): @@ -638,7 +641,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: ) num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head - if cutlass.const_expr(params.lpt): + if cutlass.const_expr(params.lpt or params.head_swizzle): # This is a version of the SingleTileLPTScheduler, complicated by the fact that # the seqlen can vary per batch. # TODO: is there any case where num_m_blocks is 0? 
@@ -677,7 +680,8 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: block = l2_mod // nheads_in_this_section head_idx_residual = l2_mod - block * nheads_in_this_section head_idx = section_idx * nheads_in_l2 + head_idx_residual - block = num_m_blocks - 1 - block + if cutlass.const_expr(params.lpt): + block = num_m_blocks - 1 - block else: head_idx = mh_block // num_m_blocks block = mh_block - head_idx * num_m_blocks diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py index 6edf3565c8d..c2a649067bf 100644 --- a/tests/cute/test_flash_attn_race_condition.py +++ b/tests/cute/test_flash_attn_race_condition.py @@ -31,7 +31,7 @@ DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" IS_SM90 = torch.cuda.get_device_capability()[0] == 9 - +INCREASED_TRIALS = False # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @@ -304,7 +304,7 @@ def test_flash_attn_output( dv_pt - dv_ref ).abs().max().item() + dv_atol - num_iters = 20_000 + num_iters = 10_000 if INCREASED_TRIALS else 1000 for i in range(num_iters): dq2, dk2, dv2, = _flash_attn_bwd( q, k, v, out, g, lse, @@ -345,8 +345,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize("mha_type", ["gqa"]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["gqa"]) # @pytest.mark.parametrize("has_learnable_sink", [False, True]) @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @@ -355,10 +355,10 @@ def test_flash_attn_output( @pytest.mark.parametrize("deterministic", [True]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) -@pytest.mark.parametrize("local_enum", [0, 1]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0, 1]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @pytest.mark.parametrize("add_unused_qkv", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -735,7 +735,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): dv_pt - dv_ref ).abs().max().item() + dv_atol - num_iters = 10_000 + num_iters = 10_000 if INCREASED_TRIALS else 1000 for i in range(num_iters): dq_unpad2, dk_unpad2, dv_unpad2 = _flash_attn_bwd( @@ -752,28 +752,25 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): diff_dq = (dq_unpad - dq_unpad2).abs() max_idx = diff_dq.argmax() - print(f"dQ max diff: {diff_dq.max().item()}") - print(f" at index {max_idx.item()}: dQ={dq_unpad.flatten()[max_idx].item()}, dQ2={dq_unpad2.flatten()[max_idx].item()}") + if i % 100 == 0: + print(f"dQ max diff: {diff_dq.max().item()}") + print(f" at index {max_idx.item()}: dQ={dq_unpad.flatten()[max_idx].item()}, dQ2={dq_unpad2.flatten()[max_idx].item()}") diff_dk = (dk_unpad - dk_unpad2).abs() max_idx = diff_dk.argmax() - print(f"dK max diff: {diff_dk.max().item()}") - print(f" at index {max_idx.item()}: 
dK={dk_unpad.flatten()[max_idx].item()}, dK2={dk_unpad2.flatten()[max_idx].item()}") + if i % 100 == 0: + print(f"dK max diff: {diff_dk.max().item()}") + print(f" at index {max_idx.item()}: dK={dk_unpad.flatten()[max_idx].item()}, dK2={dk_unpad2.flatten()[max_idx].item()}") diff_dv = (dv_unpad - dv_unpad2).abs() max_idx = diff_dv.argmax() - print(f"dV max diff: {diff_dv.max().item()}") - print(f" at index {max_idx.item()}: dV={dv_unpad.flatten()[max_idx].item()}, dV2={dv_unpad2.flatten()[max_idx].item()}") - - # print(f"dQ max diff with myself: {(dq - dq2).abs().max().item()}") - # print(f"dK max diff with myself: {(dk - dk2).abs().max().item()}") - # print(f"dV max diff with myself: {(dv - dv2).abs().max().item()}") - # print(f"dQ mean diff with myself: {(dq - dq2).abs().mean().item()}") - # print(f"dK mean diff with myself: {(dk - dk2).abs().mean().item()}") - # print(f"dV mean diff with myself: {(dv - dv2).abs().mean().item()}") + if i % 100 == 0: + print(f"dV max diff: {diff_dv.max().item()}") + print(f" at index {max_idx.item()}: dV={dv_unpad.flatten()[max_idx].item()}, dV2={dv_unpad2.flatten()[max_idx].item()}") assert torch.equal(dq_unpad, dq_unpad2) assert torch.equal(dk_unpad, dk_unpad2) assert torch.equal(dv_unpad, dv_unpad2) - print(f"✅ Iteration {i} passed!") \ No newline at end of file + if i % 100 == 0: + print(f"✅ Iteration {i} passed!") \ No newline at end of file From 68d5fbf413e0a6550dc310fd4bcba38190bcfcf3 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Thu, 8 Jan 2026 18:00:46 +0000 Subject: [PATCH 11/15] change padding offset based on arch --- flash_attn/cute/flash_bwd_preprocess.py | 22 ++++++---------------- flash_attn/cute/interface.py | 2 ++ 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index faee1da055f..cd514316f88 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -3,7 +3,7 @@ # from Cutlass C++ to Cute-DSL. 
import math import operator -from typing import Callable, Type, Optional +from typing import Callable, Type, Optional, Literal import cuda.bindings.driver as cuda @@ -27,6 +27,7 @@ def __init__( self, dtype: Type[cutlass.Numeric], head_dim: int, + arch: Literal[80, 90, 100], m_block_size: int = 128, num_threads: int = 128, ): @@ -43,6 +44,7 @@ def __init__( """ self.dtype = dtype self.m_block_size = m_block_size + self.arch = arch # padding head_dim to a multiple of 32 as k_block_size hdim_multiple_of = 32 self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) @@ -238,11 +240,9 @@ def kernel( mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None]) mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None]) - padded_offset_q = ( - (seqlen.offset_q + batch_idx * self.m_block_size) - // self.m_block_size - * self.m_block_size - ) + padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size + if cutlass.const_expr(self.arch >= 90): + padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None]) headdim_v = mO.shape[2] @@ -329,11 +329,6 @@ def kernel( if cutlass.const_expr(not seqlen.has_cu_seqlens_q): mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] else: - padded_offset_q = ( - (seqlen.offset_q + batch_idx * self.m_block_size) - // self.m_block_size - * self.m_block_size - ) mdQaccum_cur = cute.domain_offset( (padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None] ) @@ -362,11 +357,6 @@ def kernel( if cutlass.const_expr(not seqlen.has_cu_seqlens_q): mLSElog2_cur = mLSElog2[batch_idx, head_idx, None] else: - padded_offset_q = ( - (seqlen.offset_q + batch_idx * self.m_block_size) - // self.m_block_size - * self.m_block_size - ) mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[head_idx, None]) gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,)) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 5e95690e505..909d3ac2ebe 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -829,9 +829,11 @@ def _flash_attn_bwd( to_cute_tensor(t, assumed_align=4) if t is not None else None for t in (cu_seqlens_q, seqused_q) ] + arch = compute_capability * 10 fa_bwd_pre = FlashAttentionBackwardPreprocess( dtype, head_dim_v, + arch, m_block_size, num_threads=num_threads, ) From f3a4610ca35b7e6e76cb4eddfedd0c450c565b63 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Thu, 8 Jan 2026 18:50:42 +0000 Subject: [PATCH 12/15] rebase and update interface, tests --- flash_attn/cute/interface.py | 19 ++++++------------- tests/cute/test_flash_attn.py | 11 +++++------ tests/cute/test_flash_attn_varlen.py | 6 +++--- 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 909d3ac2ebe..071738f9098 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -92,6 +92,8 @@ def _flash_attn_fwd( cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -115,8 +117,6 @@ def _flash_attn_fwd( out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, aux_tensors: Optional[list[torch.Tensor]] = None, - max_seqlen_q: 
Optional[int] = None, - max_seqlen_k: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. @@ -1330,8 +1330,6 @@ def forward( deterministic: bool = False, score_mod: Optional[Callable] = None, aux_tensors: Optional[list] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_k: Optional[int] = None, ): out, lse = _flash_attn_fwd( q, @@ -1341,6 +1339,8 @@ def forward( cu_seqlens_k, seqused_q, seqused_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, page_table=page_table, softmax_scale=softmax_scale, causal=causal, @@ -1352,8 +1352,6 @@ def forward( pack_gqa=pack_gqa, score_mod=score_mod, aux_tensors=aux_tensors, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.softmax_scale = softmax_scale @@ -1368,7 +1366,6 @@ def forward( @staticmethod def backward(ctx, dout, *args): q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors - assert seqused_q == seqused_k == None assert ctx.softcap == 0.0 dq, dk, dv = _flash_attn_bwd( q, @@ -1438,10 +1435,10 @@ def flash_attn_varlen_func( v: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, - seqused_q: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -1453,8 +1450,6 @@ def flash_attn_varlen_func( deterministic: bool = False, score_mod: Optional[Callable] = None, aux_tensors: Optional[list] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_k: Optional[int] = None, ): return FlashAttnVarlenFunc.apply( q, @@ -1477,8 +1472,6 @@ def flash_attn_varlen_func( deterministic, score_mod, aux_tensors, - max_seqlen_q, - max_seqlen_k, ) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 06f0e45d3f5..c0cd927be26 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -372,8 +372,8 @@ def test_flash_attn_output( @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -@pytest.mark.parametrize("has_learnable_sink", [False, True]) -# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("has_learnable_sink", [False, True]) +@pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @@ -393,7 +393,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) # @pytest.mark.parametrize("d", [128, 192]) -@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -659,11 +659,10 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): v_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, - # max_seqlen_k, - # seqused_q=seqused_q, - # seqused_k=seqused_k, max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k, + # seqused_q=seqused_q, + # seqused_k=seqused_k, causal=causal, # qv=qv_unpad, # q_descale=q_descale, diff --git 
a/tests/cute/test_flash_attn_varlen.py b/tests/cute/test_flash_attn_varlen.py index 3f726676749..1666a08fb00 100644 --- a/tests/cute/test_flash_attn_varlen.py +++ b/tests/cute/test_flash_attn_varlen.py @@ -43,8 +43,8 @@ def test_varlen( dtype=dtype ) - # SM90/SM100 backward pass doesn't support varlen yet - skip_backward = IS_SM90 or torch.cuda.get_device_capability()[0] == 10 + # SM90 backward pass doesn't support varlen yet + skip_backward = IS_SM90 ok = check_varlen_vs_torch_flash( q, k, v, @@ -128,7 +128,7 @@ def clone_like(t): if not ok_fwd: return False - # Skip backward if not supported (e.g., SM100 varlen) + # Skip backward if not supported (e.g., SM90 varlen) if skip_backward: return True From 7b7c045623687874feb15b580e26cd63d6c3d8eb Mon Sep 17 00:00:00 2001 From: jayhshah Date: Thu, 8 Jan 2026 19:16:43 +0000 Subject: [PATCH 13/15] add arch dispatch for padded offset q to postprocess --- flash_attn/cute/flash_bwd_postprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 5211fc3b7ae..5b1a3acae64 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -341,9 +341,9 @@ def kernel( mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] head_dim = mdQ.shape[3] else: - padded_offset_q = ( - (seqlen.offset_q + batch_idx * self.tile_m) // self.tile_m * self.tile_m - ) + padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m + if cutlass.const_expr(self.arch >= 90): + padded_offset_q = padded_offset_q // self.tile_m * self.tile_m mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None]) mdQaccum_cur = cute.domain_offset( (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None] From f84e7762b763cd0d450593a4311848608c68d041 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Fri, 9 Jan 2026 22:36:51 +0000 Subject: [PATCH 14/15] address comments --- flash_attn/cute/flash_bwd_sm100.py | 40 ++++++++---------------------- flash_attn/cute/interface.py | 7 +++--- flash_attn/cute/seqlen_info.py | 26 +++++++++++-------- 3 files changed, 29 insertions(+), 44 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 91e70a91769..ed4154edbf3 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -404,7 +404,7 @@ def __call__( self.is_varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None - self.use_tma_store = not (self.qhead_per_kvhead == 1 and self.is_varlen_k) + self.use_tma_store = not (self.qhead_per_kvhead == 1 and mCuSeqlensK is not None) self.dKV_postprocess = self.qhead_per_kvhead > 1 if const_expr(self.dKV_postprocess): @@ -618,7 +618,9 @@ def __call__( cute.size(mQ.shape[0]), # pass seqlen_q or total_q for seqlen_k mQ.shape[1], # headdim mV.shape[1], # headdim_v - total_q=cute.size(mK.shape[0]), # pass total_k for total_q + total_q=cute.size(mK.shape[0]) # pass total_k for total_q + if const_expr(mCuSeqlensK is not None) + else cute.size(mK.shape[0]) * cute.size(mK.shape[3]), tile_shape_mn=self.cta_tiler[:2], # (tile_n, tile_m) cluster_shape_mn=self.cluster_shape_mnk[:2], mCuSeqlensQ=mCuSeqlensK, @@ -739,17 +741,6 @@ class SharedStorage: "Variable sequence length is not supported yet for blocksparse or aux tensors in bwd" ) - # cute.printf("mQ = {}", tma_tensor_Q.layout) - # cute.printf("mK = {}", tma_tensor_K.layout) - # cute.printf("mV = {}", 
tma_tensor_V.layout) - # cute.printf("mLSE = {}", mLSE.layout) - # cute.printf("mdPsum = {}", mdPsum.layout) - # cute.printf("tma_tensor_dO = {}", tma_tensor_dO.layout) - # cute.printf("mdV = {}", mdV.layout) - # cute.printf("mdK = {}", mdK.layout) - # cute.printf("mdQaccum = {}", mdQaccum.layout) - # cute.printf("grid_dim = {}", grid_dim) - self.kernel( tma_tensor_Q, tma_tensor_K, @@ -1343,15 +1334,10 @@ def load( mdO_cur = mdO[None, None, head_idx, batch_idx] else: mdO_cur = cute.domain_offset((0, seqlen.offset_q), mdO[None, None, head_idx]) - if const_expr(not seqlen.has_cu_seqlens_q): - mLSE_cur = mLSE[None, head_idx, batch_idx] - mdPsum_cur = mdPsum[None, head_idx, batch_idx] - else: - padded_offset_q = ( - (seqlen.offset_q + batch_idx * self.tile_m) // self.tile_m * self.tile_m - ) - mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[None, head_idx]) - mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[None, head_idx]) + mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[None, head_idx] + mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[ + None, head_idx + ] gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0)) tSgK = thr_mma_S.partition_A(gK) @@ -2482,11 +2468,8 @@ def dQacc_reduce( if const_expr(not seqlen.has_cu_seqlens_q): mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] else: - padded_offset_q = ( - (seqlen.offset_q + batch_idx * self.tile_m) // self.tile_m * self.tile_m - ) mdQaccum_cur = cute.domain_offset( - (padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx] + (seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx] ) gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) # (M * K / STAGE, STAGE, _) @@ -2836,11 +2819,8 @@ def epilogue_dK_or_dV_tma( if const_expr(not seqlen.has_cu_seqlens_k): mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) else: - padded_offset_k = ( - (seqlen.offset_k + batch_idx * self.tile_n) // self.tile_n * self.tile_n - ) mdKV_cur = cute.domain_offset( - (padded_offset_k * self.tile_hdim,), mdKV[None, head_idx_kv] + (seqlen.padded_offset_k * self.tile_hdim,), mdKV[None, head_idx_kv] ) gdKV_p = cute.local_tile( mdKV_cur, (self.tile_n * self.tile_hdim,), (n_block,) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 071738f9098..fff327fc564 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -628,6 +628,9 @@ def _flash_attn_bwd( total_k = k.shape[0] seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k + seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size + seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size + num_head_kv = k.shape[-2] head_dim_v = v.shape[-1] @@ -726,7 +729,6 @@ def _flash_attn_bwd( head_dim_rounded = (head_dim + 32 - 1) // 32 * 32 if cu_seqlens_q is None: - seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size dq_accum = torch.empty( batch_size, num_head, @@ -754,7 +756,6 @@ def _flash_attn_bwd( if dKV_postprocess: head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 if cu_seqlens_k is None: - seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size num_n_blocks = seqlen_k_rounded // n_block_size if cluster_size == 2 and num_n_blocks % cluster_size != 0: seqlen_k_rounded = seqlen_k_rounded + n_block_size @@ -796,13 +797,11 @@ def _flash_attn_bwd( current_stream = 
cuda.CUstream(torch.cuda.current_stream().cuda_stream) if deterministic: - seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device="cuda") else: dQ_semaphore = None if deterministic and qhead_per_kvhead > 1: - seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") else: diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index 426fb4e7cb6..666c9d91e09 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -38,6 +38,8 @@ def create( class SeqlenInfoQK: offset_q: cutlass.Int32 offset_k: cutlass.Int32 + padded_offset_q: cutlass.Int32 + padded_offset_k: cutlass.Int32 seqlen_q: cutlass.Int32 seqlen_k: cutlass.Int32 has_cu_seqlens_q: cutlass.Constexpr[bool] @@ -61,6 +63,16 @@ def create( ): offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] + padded_offset_q = ( + 0 + if const_expr(mCuSeqlensQ is None) + else (offset_q + batch_idx * tile_m) // tile_m * tile_m + ) + padded_offset_k = ( + 0 + if const_expr(mCuSeqlensK is None) + else (offset_k + batch_idx * tile_n) // tile_n * tile_n + ) if const_expr(mSeqUsedQ is not None): seqlen_q = mSeqUsedQ[batch_idx] else: @@ -84,6 +96,8 @@ def create( return SeqlenInfoQK( offset_q, offset_k, + padded_offset_q, + padded_offset_k, seqlen_q, seqlen_k, has_cu_seqlens_q, @@ -106,11 +120,7 @@ def offset_batch_Q( idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) return mQ[idx] else: - offset_q = ( - self.offset_q - if const_expr(not padded) - else (self.offset_q + batch_idx * self.tile_m) // self.tile_m * self.tile_m - ) + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q offset = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, offset_q) idx = (offset,) + (0,) * (cute.rank(mQ) - 1) return cute.domain_offset(idx, mQ) @@ -127,10 +137,6 @@ def offset_batch_K( idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) return mK[idx] else: - offset_k = ( - self.offset_k - if const_expr(not padded) - else (self.offset_k + batch_idx * self.tile_n) // self.tile_n * self.tile_n - ) + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k idx = (offset_k,) + (0,) * (cute.rank(mK) - 1) return cute.domain_offset(idx, mK) From 92744ce4b70f169c569d589d9002fb60b55214fa Mon Sep 17 00:00:00 2001 From: jayhshah Date: Fri, 9 Jan 2026 22:39:40 +0000 Subject: [PATCH 15/15] remove tile sizes from seqlen info class vars --- flash_attn/cute/seqlen_info.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index 666c9d91e09..6d8c6feb279 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -46,8 +46,6 @@ class SeqlenInfoQK: has_cu_seqlens_k: cutlass.Constexpr[bool] has_seqused_q: cutlass.Constexpr[bool] has_seqused_k: cutlass.Constexpr[bool] - tile_m: cutlass.Constexpr[cutlass.Int32] - tile_n: cutlass.Constexpr[cutlass.Int32] @staticmethod def create( @@ -104,8 +102,6 @@ def create( has_cu_seqlens_k, has_seqused_q, has_seqused_k, - tile_m, - 
tile_n, ) def offset_batch_Q(