diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py
index 67a28458649..85f4b1bfac2 100644
--- a/flash_attn/cute/flash_fwd_sm100.py
+++ b/flash_attn/cute/flash_fwd_sm100.py
@@ -130,10 +130,6 @@ def __init__(
         self.is_split_kv = is_split_kv
         self.pack_gqa = pack_gqa
         self.q_subtile_factor = q_subtile_factor
-        if pack_gqa:
-            assert m_block_size % self.qhead_per_kvhead == 0, (
-                "For PackGQA, m_block_size must be divisible by qhead_per_kvhead"
-            )
         assert not (self.is_split_kv and self.head_dim_v_padded >= 192), (
             "SplitKV is not supported for hdim >= 192"
         )
@@ -180,8 +176,10 @@ def __init__(
             )
         )

+        self.use_tma_Q = not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0)
+
         if self.q_stage == 1:
-            if not self.use_tma_KV:
+            if not self.use_tma_KV or not self.use_tma_Q:
                 self.empty_warp_ids = self.empty_warp_ids + self.load_warp_ids
                 self.load_warp_ids = self.softmax1_warp_ids
         else:
@@ -352,7 +350,13 @@ def __call__(
         if const_expr(self.q_dtype != self.v_dtype):
             raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}")
         self._setup_attributes()
-        self.use_tma_O = self.arch >= Arch.sm_90 and mCuSeqlensQ is None and mSeqUsedQ is None
+        self.use_tma_O = (
+            self.arch >= Arch.sm_90
+            and mCuSeqlensQ is None
+            and mSeqUsedQ is None
+            and not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0)
+            and not (self.pack_gqa and self.is_split_kv)
+        )
         # This can be tuned
         # This is currently very ad-hoc, we should tune it systematically
         self.ex2_emu_freq = 0
@@ -468,14 +472,24 @@ def __call__(
         tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)
         tma_store_op = cpasync.CopyBulkTensorTileS2GOp()

-        tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A(
-            tma_load_op,
-            mQ,
-            cute.select(sQ_layout, mode=[0, 1, 2]),
-            self.mma_tiler_qk,
-            tiled_mma_qk,
-            cta_layout_vmnk.shape,
-        )
+        if const_expr(self.use_tma_Q):
+            tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A(
+                tma_load_op,
+                mQ,
+                cute.select(sQ_layout, mode=[0, 1, 2]),
+                self.mma_tiler_qk,
+                tiled_mma_qk,
+                cta_layout_vmnk.shape,
+            )
+            gmem_tiled_copy_Q = None
+        else:
+            tma_atom_Q = None
+            async_copy_elems = 128 // self.q_dtype.width
+            num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids)
+            threads_per_row = self.head_dim_padded // async_copy_elems
+            gmem_tiled_copy_Q = copy_utils.tiled_copy_2d(
+                self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True
+            )

         tma_atom_K = None
         tma_atom_V = None
@@ -643,6 +657,7 @@ class SharedStorage:
             tP_layout,
             sV_layout,
             sO_layout,
+            gmem_tiled_copy_Q,
             gmem_tiled_copy_O,
             tiled_mma_qk,
             tiled_mma_pv,
@@ -673,7 +688,7 @@ def kernel(
        mSeqUsedQ: Optional[cute.Tensor],
        mSeqUsedK: Optional[cute.Tensor],
        mPageTable: Optional[cute.Tensor],
-        tma_atom_Q: cute.CopyAtom,
+        tma_atom_Q: Optional[cute.CopyAtom],
        tma_atom_K: Optional[cute.CopyAtom],
        tma_atom_V: Optional[cute.CopyAtom],
        tma_atom_O: Optional[cute.CopyAtom],
@@ -688,6 +703,7 @@ def kernel(
        tP_layout: cute.ComposedLayout,
        sV_layout: cute.ComposedLayout,
        sO_layout: cute.ComposedLayout,
+        gmem_tiled_copy_Q: Optional[cute.TiledCopy],
        gmem_tiled_copy_O: Optional[cute.TiledCopy],
        tiled_mma_qk: cute.TiledMma,
        tiled_mma_pv: cute.TiledMma,
@@ -778,15 +794,28 @@ def kernel(
        softmax_correction_threads_cluster = ThreadCooperativeGroup(
            cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids) * self.cta_group_size
        )
-        pipeline_q = pipeline_custom.PipelineTmaUmma.create(
-            barrier_storage=storage.mbar_load_Q.data_ptr(),
-            num_stages=self.q_stage,
-            producer_group=tma_warp,
-            consumer_group=mma_warp,
-            tx_count=self.tma_copy_bytes["Q"],
-            cta_layout_vmnk=cta_layout_vmnk,
-            defer_sync=True,
-        )
+        if const_expr(self.use_tma_Q):
+            pipeline_q = pipeline_custom.PipelineTmaUmma.create(
+                barrier_storage=storage.mbar_load_Q.data_ptr(),
+                num_stages=self.q_stage,
+                producer_group=tma_warp,
+                consumer_group=mma_warp,
+                tx_count=self.tma_copy_bytes["Q"],
+                cta_layout_vmnk=cta_layout_vmnk,
+                defer_sync=True,
+            )
+        else:
+            cpasync_producer_group_q = pipeline.CooperativeGroup(
+                pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE
+            )
+            pipeline_q = pipeline_custom.PipelineAsyncUmma.create(
+                barrier_storage=storage.mbar_load_Q.data_ptr(),
+                num_stages=self.q_stage,
+                producer_group=cpasync_producer_group_q,
+                consumer_group=mma_warp,
+                cta_layout_vmnk=cta_layout_vmnk,
+                defer_sync=True,
+            )
        if const_expr(self.use_tma_KV):
            pipeline_kv = pipeline_custom.PipelineTmaUmma.create(
                barrier_storage=storage.mbar_load_KV.data_ptr(),
@@ -970,6 +999,7 @@ def kernel(
                tma_atom_Q,
                tma_atom_K,
                tma_atom_V,
+                gmem_tiled_copy_Q,
                pipeline_q,
                pipeline_kv,
                block_info,
@@ -1129,9 +1159,10 @@ def load(
        sK: cute.Tensor,
        sV: cute.Tensor,
        mPageTable: Optional[cute.Tensor],
-        tma_atom_Q: cute.CopyAtom,
+        tma_atom_Q: Optional[cute.CopyAtom],
        tma_atom_K: Optional[cute.CopyAtom],
        tma_atom_V: Optional[cute.CopyAtom],
+        gmem_tiled_copy_Q: Optional[cute.TiledCopy],
        pipeline_q: pipeline.PipelineAsync,
        pipeline_kv: pipeline.PipelineAsync,
        block_info: BlockInfo,
@@ -1143,6 +1174,14 @@ def load(
        num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE
        tidx = cute.arch.thread_idx()[0] % num_load_threads
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
+        issue_kv_for_this_warp = (
+            const_expr(not self.use_tma_KV or len(self.load_warp_ids) == 1) or
+            warp_idx == self.load_warp_ids[0]
+        )
+        issue_q_for_this_warp = (
+            const_expr(not self.use_tma_Q or len(self.load_warp_ids) == 1) or
+            warp_idx == self.load_warp_ids[0]
+        )
        q_producer_phase = Int32(1)
        kv_producer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Producer, self.kv_stage
@@ -1153,11 +1192,6 @@ def load(
            m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)
            mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
-            tiler_gQ = ((self.mma_tiler_qk[0] * self.q_stage), self.head_dim_padded)
-            gQ = cute.local_tile(mQ_cur, tiler_gQ, (m_block, 0))  # (128 * 2, 128)
-            gQ = layout_utils.select(
-                cute.flat_divide(gQ, (self.mma_tiler_qk[0],)), mode=[0, 2, 1]
-            )  # (128, 128, 2)
            head_idx_kv = (
                head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
            )
@@ -1179,12 +1213,32 @@ def load(
            gV = cute.local_tile(
                mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)
            )
-            tSgQ = thr_mma_qk.partition_A(gQ)
            tSgK = thr_mma_qk.partition_B(gK)
            tOgV = thr_mma_pv.partition_B(gV)
-            load_Q_fn, _, _ = copy_utils.tma_get_copy_fn(
-                tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ
-            )
+            if const_expr(self.use_tma_Q):
+                tiler_gQ = ((self.mma_tiler_qk[0] * self.q_stage), self.head_dim_padded)
+                gQ = cute.local_tile(mQ_cur, tiler_gQ, (m_block, 0))  # (128 * 2, 128)
+                gQ = layout_utils.select(
+                    cute.flat_divide(gQ, (self.mma_tiler_qk[0],)), mode=[0, 2, 1]
+                )  # (128, 128, 2)
+                tSgQ = thr_mma_qk.partition_A(gQ)
+                load_Q_fn, _, _ = copy_utils.tma_get_copy_fn(
+                    tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ
+                )
+                load_Q = partial(self.load_Q, load_Q_fn, pipeline_q=pipeline_q, phase=q_producer_phase)
+            else:
+                assert gmem_tiled_copy_Q is not None
+                load_Q = partial(
+                    self.load_Q_non_tma,
+                    mQ_cur,
+                    sQ,
+                    gmem_tiled_copy_Q,
+                    pipeline_q,
+                    tidx,
+                    seqlen.seqlen_q,
+                    m_block,
+                    phase=q_producer_phase,
+                )

            if const_expr(self.use_tma_KV):
                tKsK, tKgK = cpasync.tma_partition(
@@ -1223,7 +1277,6 @@ def load(
                tKsK, tKgK = None, None
                tVsV, tVgV = None, None

-            load_Q = partial(self.load_Q, load_Q_fn, pipeline_q=pipeline_q, phase=q_producer_phase)
            load_K = partial(
                self.load_KV,
                tma_atom_K,
@@ -1258,24 +1311,19 @@ def load(
                )
                if const_expr(not self.use_tma_KV):
                    paged_kv_manager.load_page_table(n_block_first)
-                load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx)  # K0
+                if issue_kv_for_this_warp:
+                    load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx)  # K0
                # load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx, extra_tx_count=self.tma_copy_bytes["Q"]) # K0
-                if const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]:
-                    # load_Q(block=0, stage=0)  # Q0
-                    pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
-                    # pipeline_q.sync_object_empty.wait(0, q_producer_phase)
-                    tma_bar_ptr = pipeline_q.sync_object_full.get_barrier(0)
-                    # tma_bar_ptr = pipeline_kv.producer_get_barrier(kv_producer_state)
-                    load_Q_fn(src_idx=0, dst_idx=0, tma_bar_ptr=tma_bar_ptr)
-                kv_producer_state.advance()
-                if const_expr(self.q_stage == 2) and (const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]):
-                    # load_Q(block=1, stage=1)  # Q1
-                    pipeline_q.producer_acquire_w_index_phase(1, q_producer_phase)
-                    tma_bar_ptr = pipeline_q.sync_object_full.get_barrier(1)
-                    load_Q_fn(src_idx=1, dst_idx=1, tma_bar_ptr=tma_bar_ptr)
+                if issue_q_for_this_warp:
+                    load_Q(block=0, stage=0)
+                if issue_kv_for_this_warp:
+                    kv_producer_state.advance()
+                if const_expr(self.q_stage == 2) and issue_q_for_this_warp:
+                    load_Q(block=1, stage=1)
                q_producer_phase ^= 1
-                load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx)  # V0
-                kv_producer_state.advance()
+                if issue_kv_for_this_warp:
+                    load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx)  # V0
+                    kv_producer_state.advance()
                for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
                    n_block = n_block_max - 2 - i
                    page_idx = (
@@ -1286,10 +1334,11 @@ def load(
                    if const_expr(not self.use_tma_KV):
                        paged_kv_manager.load_page_table(n_block)
                    # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx)
-                    load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)  # Ki
-                    kv_producer_state.advance()
-                    load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)  # Vi
-                    kv_producer_state.advance()
+                    if issue_kv_for_this_warp:
+                        load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)  # Ki
+                        kv_producer_state.advance()
+                        load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)  # Vi
+                        kv_producer_state.advance()
            else:
                kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100(
@@ -1313,9 +1362,10 @@ def load(
            work_tile = tile_scheduler.get_current_work()
        # End of persistent scheduler loop
-        pipeline_kv.producer_tail(kv_producer_state)
-        # This is equivalent to pipeline_q.producer_tail
-        if const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]:
+        if issue_kv_for_this_warp:
+            pipeline_kv.producer_tail(kv_producer_state)
+        # This is equivalent to pipeline_q.producer_tail for the TMA-Q producer warp.
+        if issue_q_for_this_warp:
            pipeline_q.producer_acquire_w_index_phase(self.q_stage - 1, q_producer_phase)

    @cute.jit
@@ -2149,12 +2199,14 @@ def correction_loop(
                mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
            else:
                mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
-            tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)
-            gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0))  # (128 * 2, 128)
-            gO = layout_utils.select(
-                cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]
-            )  # (128, 128, 2)
-            gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]
+            gO = None
+            if const_expr(self.use_tma_O or not self.pack_gqa):
+                tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)
+                gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0))  # (128 * 2, 128)
+                gO = layout_utils.select(
+                    cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]
+                )  # (128, 128, 2)
+                gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]

            # Default LSE to -inf for invalid split_idx tiles
            stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage
@@ -2255,6 +2307,7 @@ def correction_loop(
                pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase)
                if const_expr(not self.use_correction_warps_for_epi):
                    pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase)
+                gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None
                self.correction_epilogue(
                    thr_mma_pv,
                    tOtO[None, None, None, stage],
@@ -2265,7 +2318,7 @@ def correction_loop(
                    scale,
                    sO[None, None, stage],
                    mO_cur,
-                    gO[None, None, stage],
+                    gO_stage,
                    gmem_tiled_copy_O,
                )
                # Signal for the next work tile that O buffers in tmem are already read, so
@@ -2335,7 +2388,6 @@ def correction_loop(
            mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])
            for stage in cutlass.range_constexpr(self.q_stage):
                m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
-                gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_tile_idx,))
                row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage]
                # if tidx == 0 and stage <= 1:
                #     cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
@@ -2350,9 +2402,21 @@ def correction_loop(
                    if const_expr(not self.pack_gqa)
                    else seqlen.seqlen_q * self.qhead_per_kvhead
                )
-                if tidx < seqlen_q - m_tile_idx * self.m_block_size:
-                    # This actually just works with PackGQA too
-                    gLSE[tidx] = lse
+                if const_expr(not self.pack_gqa or self.m_block_size % self.qhead_per_kvhead == 0):
+                    gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_tile_idx,))
+                    if tidx < seqlen_q - m_tile_idx * self.m_block_size:
+                        # This actually just works with PackGQA too
+                        gLSE[tidx] = lse
+                else:
+                    idx = m_tile_idx * self.m_block_size + tidx
+                    if idx < seqlen_q:
+                        m_idx = idx // self.qhead_per_kvhead
+                        h_idx = idx - m_idx * self.qhead_per_kvhead
+                        lse_ptr_i64 = utils.elem_pointer(mLSE_cur, ((h_idx, m_idx),)).toint()
+                        lse_gmem_ptr = cute.make_ptr(
+                            mLSE_cur.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4
+                        )
+                        cute.make_tensor(lse_gmem_ptr, (1,))[0] = lse

            # Advance to next tile
            tile_scheduler.advance_to_next_work()
@@ -2507,7 +2571,7 @@ def correction_epilogue(
    def _store_O_to_gmem(
        self,
        sO_stage: cute.Tensor,
-        gO: cute.Tensor,
+        gO: Optional[cute.Tensor],
        mO_cur: cute.Tensor,
        gmem_tiled_copy_O: cute.TiledCopy,
        tidx: Int32,
@@ -2518,7 +2582,6 @@ def _store_O_to_gmem(
        gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
        tOsO = gmem_thr_copy_O.partition_S(sO_stage)
        cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
-        tOgO = gmem_thr_copy_O.partition_D(gO)
        tOcO = gmem_thr_copy_O.partition_S(cO)
        t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
        tOpO = copy_utils.predicate_k(tOcO, limit=mO_cur.shape[1])
@@ -2534,6 +2597,8 @@ def _store_O_to_gmem(
        cute.autovec_copy(tOsO, tOrO)
        # copy acc O from rmem to gmem
        if const_expr(not self.pack_gqa):
+            assert gO is not None
+            tOgO = gmem_thr_copy_O.partition_D(gO)
            for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
                if (
                    t0OcO[0, rest_m, 0][0] < seqlen_q - m_tile_idx * self.m_block_size - tOcO[0][0]
@@ -2578,12 +2643,14 @@ def epilogue_s2g(
                mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
            else:
                mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
-            tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)
-            gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0))  # (128 * 2, 128)
-            gO = layout_utils.select(
-                cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]
-            )  # (128, 128, 2)
-            gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]
+            gO = None
+            if const_expr(self.use_tma_O or not self.pack_gqa):
+                tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)
+                gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0))  # (128 * 2, 128)
+                gO = layout_utils.select(
+                    cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]
+                )  # (128, 128, 2)
+                gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]

            if const_expr(self.use_tma_O):
                store_O, _, _ = copy_utils.tma_get_copy_fn(
@@ -2610,8 +2677,9 @@ def epilogue_s2g(
            pipeline_o_epi.consumer_wait_w_index_phase(stage, epi_consumer_phase)
            # 2. copy O0 / O1 to gmem
            m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
+            gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None
            self._store_O_to_gmem(
-                sO[None, None, stage], gO[None, None, stage], mO_cur, gmem_tiled_copy_O,
+                sO[None, None, stage], gO_stage, mO_cur, gmem_tiled_copy_O,
                tidx, seqlen.seqlen_q, m_tile_idx,
            )
            pipeline_o_epi.consumer_release_w_index(stage)
@@ -2633,6 +2701,39 @@ def load_Q(
        pipeline_q.producer_acquire_w_index_phase(stage, phase)
        load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage))

+    def load_Q_non_tma(
+        self,
+        mQ: cute.Tensor,
+        sQ: cute.Tensor,
+        gmem_tiled_copy_Q: cute.TiledCopy,
+        pipeline_q: pipeline.PipelineAsync,
+        tidx: Int32,
+        seqlen_q: Int32,
+        m_block: Int32,
+        block: Int32,
+        stage: int,
+        phase: Int32,
+    ):
+        assert self.cta_group_size == 1, "cta_group_size must be 1 for non-tma Q load"
+        pipeline_q.producer_acquire_w_index_phase(stage, phase)
+        pack_gqa = PackGQA(
+            self.m_block_size,
+            self.head_dim_padded,
+            self.check_hdim_oob,
+            self.qhead_per_kvhead,
+        )
+        sQ_stage = sQ[None, None, None, stage]
+        sQ_pi = cute.make_tensor(
+            sQ_stage.iterator,
+            cute.make_layout(
+                (sQ_stage.shape[0][0], (sQ_stage.shape[0][1], sQ_stage.shape[2])),
+                stride=(sQ_stage.stride[0][0], (sQ_stage.stride[0][1], sQ_stage.stride[2])),
+            ),
+        )
+        pack_gqa.load_Q(mQ, sQ_pi, gmem_tiled_copy_Q, tidx, m_block * self.q_stage + block, seqlen_q)
+        cute.arch.cp_async_commit_group()
+        pipeline_q.sync_object_full.arrive_cp_async_mbarrier(stage)
+
    @cute.jit
    def load_KV(
        self,
diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index 8e56a2688a3..fa1d434c6fd 100644
--- a/flash_attn/cute/interface.py
+++ b/flash_attn/cute/interface.py
@@ -485,10 +485,6 @@ def _flash_attn_fwd(
    if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
        pack_gqa = False

-    if arch // 10 in [10, 11]:
-        if pack_gqa and (128 % qhead_per_kvhead != 0):
-            pack_gqa = False
-
    if max_seqlen_q is None:
        max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q
    if max_seqlen_k is None:
@@ -535,6 +531,7 @@ def _flash_attn_fwd(
        and int(math.ceil(head_dim / 16) * 16) == 128
        and int(math.ceil(head_dim_v / 16) * 16) == 128
        and seqlen_q_packgqa > 2 * tile_m
+        and (tile_m % qhead_per_kvhead == 0 or not pack_gqa)
    )

    # hash score and mask mods for compile cache
diff --git a/tests/cute/test_flash_attn_fast.py b/tests/cute/test_flash_attn_fast.py
index f93af9bb44a..433859d94d8 100644
--- a/tests/cute/test_flash_attn_fast.py
+++ b/tests/cute/test_flash_attn_fast.py
@@ -32,7 +32,7 @@

 # ---------------------------------------------------------------------------
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("mha_type", ["mha", "gqa"])
+@pytest.mark.parametrize("mha_type", ["mha", "gqa", "mqa"])
 @pytest.mark.parametrize("num_splits", [1, 3])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("d", [64, 128])
@@ -55,7 +55,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, causal, num_splits, mha_type,
    torch.cuda.empty_cache()
    batch_size = 4
    nheads = 6
-    nheads_kv = nheads if mha_type == "mha" else 3
+    nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)

    q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype).to(dtype).requires_grad_()
    k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_()
@@ -108,7 +108,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, causal, num_splits, mha_type,

 # ---------------------------------------------------------------------------
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("mha_type", ["mha", "gqa"])
+@pytest.mark.parametrize("mha_type", ["mha", "gqa", "mqa"])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("d", [64, 128])
 @pytest.mark.parametrize("seqlen", [128, 256, 1024])
@@ -121,7 +121,7 @@ def test_flash_attn_varlen_output(seqlen, d, causal, mha_type, dtype):
    random.seed(seed)
    batch_size = 9
    nheads = 6
-    nheads_kv = nheads if mha_type == "mha" else 3
+    nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)

    q_ref = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype).to(dtype).requires_grad_()
    k_ref = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_()
@@ -177,7 +177,7 @@ def test_flash_attn_varlen_output(seqlen, d, causal, mha_type, dtype):

 # ---------------------------------------------------------------------------
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("mha_type", ["mha", "gqa"])
+@pytest.mark.parametrize("mha_type", ["mha", "gqa", "mqa"])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("d", [64, 128])
 @pytest.mark.parametrize("seqlen", [128, 256])
@@ -194,7 +194,7 @@ def test_flash_attn_varlen_unpad_output(seqlen, d, causal, mha_type, unpad_q, un
    random.seed(seed)
    batch_size = 9
    nheads = 6
-    nheads_kv = nheads if mha_type == "mha" else 3
+    nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)

    q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
    k = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype)
@@ -272,6 +272,14 @@ def test_flash_attn_varlen_unpad_output(seqlen, d, causal, mha_type, unpad_q, un
    g = torch.randn_like(out_unpad)
    dq_in, dk_in, dv_in = torch.autograd.grad(out_unpad, (q_in, k_in, v_in), g)
+    # Mask out padding positions again
+    k_mask = rearrange(key_padding_mask, "b s -> b s 1 1")
+    if not unpad_q:
+        dq_in = dq_in.clone().masked_fill_(~q_mask, 0.0)
+    if not unpad_kv:
+        dk_in = dk_in.clone().masked_fill_(~k_mask, 0.0)
+        dv_in = dv_in.clone().masked_fill_(~k_mask, 0.0)
+
    assert dq_in.isfinite().all(), "dq contains non-finite values"
    assert dk_in.isfinite().all(), "dk contains non-finite values"
    assert dv_in.isfinite().all(), "dv contains non-finite values"
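
A note on the `use_tma_Q` predicate introduced in `__init__`: a TMA descriptor can only express a regular tile of rows, and with PackGQA an m-block of packed Q rows is only regular when it covers whole groups of `qhead_per_kvhead` rows. A minimal sketch mirroring the diff's logic (plain Python; the standalone function is written for this description only and is not part of the flash_attn API):

```python
# Mirrors: self.use_tma_Q = not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0)
def use_tma_q(pack_gqa: bool, m_block_size: int, qhead_per_kvhead: int) -> bool:
    return not (pack_gqa and m_block_size % qhead_per_kvhead != 0)

assert use_tma_q(False, 128, 6)      # no PackGQA: TMA is always usable for Q
assert use_tma_q(True, 128, 4)       # 128 % 4 == 0: packed tile stays regular
assert not use_tma_q(True, 128, 6)   # 128 % 6 != 0: fall back to cp.async loads
```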
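
For the cp.async fallback, the tiled copy for Q is sized from a 128-bit copy granularity, matching the `async_copy_elems` / `threads_per_row` arithmetic in `__call__`. A sketch of that arithmetic under assumed values (bf16 Q, head_dim 128, four load warps; `rows_per_iter` is a derived quantity added here purely for illustration):

```python
WARP_SIZE = 32  # threads per warp

def q_copy_geometry(dtype_width_bits: int, head_dim_padded: int, num_load_warps: int):
    # Each cp.async in the fallback path moves 128 bits per thread.
    async_copy_elems = 128 // dtype_width_bits
    num_load_threads = WARP_SIZE * num_load_warps
    # Threads needed to cover one head_dim row with 128-bit chunks.
    threads_per_row = head_dim_padded // async_copy_elems
    rows_per_iter = num_load_threads // threads_per_row
    return async_copy_elems, threads_per_row, rows_per_iter

print(q_copy_geometry(16, 128, 4))  # bf16 -> (8, 16, 8): 8 Q rows per copy iteration
```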
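
The scalar LSE store added to `correction_loop` relies on the PackGQA row order implied by the diff's arithmetic: along the packed axis the query head varies fastest, so a packed row index splits as `m_idx = idx // qhead_per_kvhead`, `h_idx = idx % qhead_per_kvhead`. A self-contained round-trip check of that mapping, with hypothetical sizes:

```python
qhead_per_kvhead = 6  # hypothetical GQA ratio (e.g. nheads=6, nheads_kv=1)
seqlen_q = 5          # hypothetical sequence length

for idx in range(seqlen_q * qhead_per_kvhead):  # packed row index, as each thread sees it
    m_idx = idx // qhead_per_kvhead             # position in the original sequence
    h_idx = idx - m_idx * qhead_per_kvhead      # query head within its KV group
    assert idx == m_idx * qhead_per_kvhead + h_idx

# Each thread writes one LSE value at mLSE[(h_idx, m_idx)], so no contiguous
# local_tile view is needed when m_block_size % qhead_per_kvhead != 0.
```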