diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py
index c4a569fa0d1..915315d461b 100644
--- a/flash_attn/cute/flash_fwd_sm100.py
+++ b/flash_attn/cute/flash_fwd_sm100.py
@@ -27,6 +27,7 @@
 import cutlass.cute.nvgpu.tcgen05 as tcgen05
 import cutlass.utils.blackwell_helpers as sm100_utils_basic

+from flash_attn.cute.paged_kv import PagedKVManager
 import flash_attn.cute.utils as utils
 from flash_attn.cute import copy_utils
 import flash_attn.cute.pipeline as pipeline
@@ -76,7 +77,9 @@ def __init__(
         is_persistent: bool = True,
         score_mod: cutlass.Constexpr | None = None,
         has_aux_tensors: cutlass.Constexpr = False,
+        paged_kv_non_tma: bool = False,
     ):
+        self.use_tma_KV = not paged_kv_non_tma
         # self.dtype = dtype
         # padding head_dim to a multiple of 16 as k_block_size
         hdim_multiple_of = 16
@@ -127,11 +130,15 @@ def __init__(
         if self.overlap_sO_sQ:
             self.is_persistent = False

+        assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), (
+            "Paged KV does not support irregular head dim"
+        )
+
         self.softmax0_warp_ids = (0, 1, 2, 3)
         self.softmax1_warp_ids = (4, 5, 6, 7)
         self.correction_warp_ids = (8, 9, 10, 11)
         self.mma_warp_id = 12
-        self.load_warp_id = 13
+        self.load_warp_ids = (13,)
         self.epilogue_warp_ids = (14,)
         self.empty_warp_ids = (15,)
         SM100_TMEM_CAPACITY_COLUMNS = 512
@@ -143,7 +150,7 @@ def __init__(
             *self.softmax1_warp_ids,
             *self.correction_warp_ids,
             self.mma_warp_id,
-            self.load_warp_id,
+            *self.load_warp_ids,
             *self.epilogue_warp_ids,
             *self.empty_warp_ids,
         )
@@ -449,11 +456,20 @@ def __call__(
                 mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)
             )

+        self.tma_copy_bytes = {
+            name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))
+            for name, mX, layout in [
+                ("Q", mQ, sQ_layout),
+                ("K", mK, sK_layout),
+                ("V", mV, sV_layout),
+            ]
+        }
+
         # TMA load for Q
         tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)
         tma_store_op = cpasync.CopyBulkTensorTileS2GOp()
-        tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_A(
+        tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A(
             tma_load_op,
             mQ,
             cute.select(sQ_layout, mode=[0, 1, 2]),
@@ -462,24 +478,32 @@ def __call__(
-        # TMA load for K
-        tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_B(
-            tma_load_op,
-            mK,
-            cute.select(sK_layout, mode=[0, 1, 2]),
-            self.mma_tiler_qk,
-            tiled_mma_qk,
-            self.cluster_layout_vmnk.shape,
-        )
-        # TMA load for V
-        tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_B(
-            tma_load_op,
-            mV,
-            cute.select(sV_layout, mode=[0, 1, 2]),
-            self.mma_tiler_pv,
-            tiled_mma_pv,
-            self.cluster_layout_vmnk.shape,
-        )
+        if const_expr(self.use_tma_KV):
+            # TMA load for K
+            tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B(
+                tma_load_op,
+                mK,
+                cute.select(sK_layout, mode=[0, 1, 2]),
+                self.mma_tiler_qk,
+                tiled_mma_qk,
+                self.cluster_layout_vmnk.shape,
+            )
+            # TMA load for V
+            tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B(
+                tma_load_op,
+                mV,
+                cute.select(sV_layout, mode=[0, 1, 2]),
+                self.mma_tiler_pv,
+                tiled_mma_pv,
+                self.cluster_layout_vmnk.shape,
+            )
+        else:
+            assert self.use_tma_O, "Loading O and K/V will contend for the empty warp."
+            self.epilogue_warp_ids = (13,)
+            self.load_warp_ids = (14, 15)
+            self.empty_warp_ids = ()
+            tma_atom_K = None
+            tma_atom_V = None

         o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile)

@@ -514,15 +538,7 @@ def __call__(
         assert self.m_block_size % tO_layout.shape[0] == 0
         vO_layout = cute.make_layout((1, async_copy_elems))
         gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)
-
-        self.tma_copy_bytes = {
-            name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))
-            for name, mX, layout in [
-                ("Q", mQ, sQ_layout),
-                ("K", mK, sK_layout),
-                ("V", mV, sV_layout),
-            ]
-        }
+        print("gmem_tiled_copy_O: ", gmem_tiled_copy_O)

         if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
             TileScheduler = SingleTileVarlenScheduler
@@ -638,9 +654,9 @@ class SharedStorage:

         # Launch the kernel synchronously
         self.kernel(
-            tma_tensor_Q,
-            tma_tensor_K,
-            tma_tensor_V,
+            mQ,
+            mK,
+            mV,
             mO,
             mLSE,
             mCuSeqlensQ,
@@ -693,8 +709,8 @@ def kernel(
         mSeqUsedK: Optional[cute.Tensor],
         mPageTable: Optional[cute.Tensor],
         tma_atom_Q: cute.CopyAtom,
-        tma_atom_K: cute.CopyAtom,
-        tma_atom_V: cute.CopyAtom,
+        tma_atom_K: Optional[cute.CopyAtom],
+        tma_atom_V: Optional[cute.CopyAtom],
         tma_atom_O: Optional[cute.CopyAtom],
         softmax_scale_log2: Float32,
         softmax_scale: Float32 | None,
@@ -733,8 +749,10 @@ def kernel(
         # Prefetch tma descriptor
         if warp_idx == 0:
             cpasync.prefetch_descriptor(tma_atom_Q)
-            cpasync.prefetch_descriptor(tma_atom_K)
-            cpasync.prefetch_descriptor(tma_atom_V)
+            if const_expr(tma_atom_K is not None):
+                cpasync.prefetch_descriptor(tma_atom_K)
+            if const_expr(tma_atom_V is not None):
+                cpasync.prefetch_descriptor(tma_atom_V)
             if const_expr(tma_atom_O is not None):
                 cpasync.prefetch_descriptor(tma_atom_O)

@@ -748,7 +766,7 @@ def kernel(
         # Init "full" barrier with number of producers, "empty" barrier with number of consumers
         for i in cutlass.range_constexpr(self.q_stage):
            cute.arch.mbarrier_init(
-                mbar_ptr + self.mbar_load_q_full_offset + i, len([self.load_warp_id])
+                mbar_ptr + self.mbar_load_q_full_offset + i, 1
            )
            cute.arch.mbarrier_init(
                mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id])
            )
@@ -902,7 +920,7 @@ def kernel(
        # ///////////////////////////////////////////////////////////////////////////////
        #  LOAD
        # ///////////////////////////////////////////////////////////////////////////////
-        if warp_idx == self.load_warp_id:
+        if warp_idx >= self.load_warp_ids[0] and warp_idx <= self.load_warp_ids[-1]:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
            self.load(
                thr_mma_qk,
@@ -1070,8 +1088,8 @@ def load(
        sV: cute.Tensor,
        mPageTable: Optional[cute.Tensor],
        tma_atom_Q: cute.CopyAtom,
-        tma_atom_K: cute.CopyAtom,
-        tma_atom_V: cute.CopyAtom,
+        tma_atom_K: Optional[cute.CopyAtom],
+        tma_atom_V: Optional[cute.CopyAtom],
        pipeline_kv: cutlass.pipeline.PipelineAsync,
        mbar_ptr: cute.Pointer,
        block_info: BlockInfo,
@@ -1079,6 +1097,8 @@ def load(
        SeqlenInfoCls: Callable,
        TileSchedulerCls: Callable,
    ):
+        num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE
+        tidx = cute.arch.thread_idx()[0] % num_load_threads
        q_producer_phase = Int32(1)
        kv_producer_state = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Producer, self.kv_stage
        )
@@ -1117,20 +1137,43 @@ def load(
        load_Q_fn, _, _ = copy_utils.tma_get_copy_fn(
            tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ
        )
-        tKsK, tKgK = cpasync.tma_partition(
-            tma_atom_K,
-            0,  # no multicast
-            cute.make_layout(1),
-            cute.group_modes(sK, 0, 3),
-            cute.group_modes(tSgK, 0, 3),
-        )
-        tVsV, tVgV = cpasync.tma_partition(
-            tma_atom_V,
-            0,  # no multicast
-            cute.make_layout(1),
-            cute.group_modes(sV, 0, 3),
-            cute.group_modes(tOgV, 0, 3),
-        )
+
+        if const_expr(self.use_tma_KV):
+            tKsK, tKgK = cpasync.tma_partition(
+                tma_atom_K,
+                0,  # no multicast
+                cute.make_layout(1),
+                cute.group_modes(sK, 0, 3),
+                cute.group_modes(tSgK, 0, 3),
+            )
+            tVsV, tVgV = cpasync.tma_partition(
+                tma_atom_V,
+                0,  # no multicast
+                cute.make_layout(1),
+                cute.group_modes(sV, 0, 3),
+                cute.group_modes(tOgV, 0, 3),
+            )
+            paged_kv_manager = None
+        else:
+            page_size = mK.shape[0]
+            paged_kv_manager = PagedKVManager.create(
+                mPageTable,
+                mK,
+                mV,
+                FastDivmod.create(page_size),
+                batch_idx,
+                head_idx_kv,
+                tidx,
+                seqlen.seqlen_k,
+                0,  # leftpad_k
+                self.n_block_size,
+                self.head_dim_padded,
+                self.head_dim_v_padded,
+                num_load_threads,
+                mK.element_type,
+            )
+            tKsK, tKgK = None, None
+            tVsV, tVgV = None, None

        load_Q = partial(
            self.load_Q,
@@ -1146,6 +1189,8 @@ def load(
            tma_atom_K,
            tKgK,
            tKsK,
+            paged_kv_manager,
+            sK,
            mbar_ptr + self.mbar_load_kv_full_offset,
            mbar_ptr + self.mbar_load_kv_empty_offset,
            K_or_V="K",
@@ -1155,6 +1200,8 @@ def load(
            tma_atom_V,
            tVgV,
            tVsV,
+            paged_kv_manager,
+            sV,
            mbar_ptr + self.mbar_load_kv_full_offset,
            mbar_ptr + self.mbar_load_kv_empty_offset,
            K_or_V="V",
@@ -1163,15 +1210,19 @@ def load(
            n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)
            if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
-                load_Q(block=self.q_stage * m_block + 0, stage=0)  # Q0
+                if const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE:
+                    load_Q(block=self.q_stage * m_block + 0, stage=0)  # Q0
+                n_block_first = n_block_max - 1 if n_block_max > 0 else 0
                page_idx = (
-                    mPageTable[batch_idx, n_block_max - 1]
-                    if const_expr(mPageTable is not None)
+                    mPageTable[batch_idx, n_block_first]
+                    if const_expr(mPageTable is not None and self.use_tma_KV)
                    else None
                )
+                if const_expr(not self.use_tma_KV):
+                    paged_kv_manager.load_page_table(n_block_first)
                load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx)  # K0
                kv_producer_state.advance()
-                if const_expr(self.q_stage == 2):
+                if const_expr(self.q_stage == 2) and (const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE):
                    load_Q(block=self.q_stage * m_block + 1, stage=1)  # Q1
                    q_producer_phase ^= 1
                load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx)  # V0
@@ -1179,8 +1230,12 @@ def load(
            for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
                n_block = n_block_max - 2 - i
                page_idx = (
-                    mPageTable[batch_idx, n_block] if const_expr(mPageTable is not None) else None
+                    mPageTable[batch_idx, n_block]
+                    if const_expr(mPageTable is not None and self.use_tma_KV)
+                    else None
                )
+                if const_expr(not self.use_tma_KV):
+                    paged_kv_manager.load_page_table(n_block)
                # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx)
                load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)  # Ki
                kv_producer_state.advance()
@@ -2235,9 +2290,11 @@ def load_Q(
    @cute.jit
    def load_KV(
        self,
-        tma_atom: cute.CopyAtom,
-        tXgX: cute.Tensor,
-        tXsX: cute.Tensor,
+        tma_atom: Optional[cute.CopyAtom],
+        tXgX: Optional[cute.Tensor],
+        tXsX: Optional[cute.Tensor],
+        paged_kv_manager: Optional[PagedKVManager],
+        sX: cute.Tensor,
        mbar_full_ptr: cute.Pointer,
        mbar_empty_ptr: cute.Pointer,
        block: Int32,
@@ -2253,17 +2310,29 @@ def load_KV(
        # K.  So we need to wait for the stage after that (stage 1) to be empty as well.
        if stage == 0:
            cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase)
-        with cute.arch.elect_one():
-            cute.arch.mbarrier_arrive_and_expect_tx(
-                mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V]
+
+        if const_expr(self.use_tma_KV):
+            assert (
+                tXgX is not None and
+                tXsX is not None and
+                tma_atom is not None
            )
-        tXsX_cur = tXsX[None, stage]
-        if const_expr(self.uneven_kv_smem):
-            # Since this is the producer_state, the phase starts at 1, so we have to invert it
-            tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1)
-        # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0
-        tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx]
-        cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage)
+            with cute.arch.elect_one():
+                cute.arch.mbarrier_arrive_and_expect_tx(
+                    mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V],
+                )
+            tXsX_cur = tXsX[None, stage]
+            if const_expr(self.uneven_kv_smem):
+                # Since this is the producer_state, the phase starts at 1, so we have to invert it
+                tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1)
+            # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0
+            tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx]
+            cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage)
+        else:
+            assert paged_kv_manager is not None
+            paged_kv_manager.load_KV(block, sX[None, None, None, stage], K_or_V)
+            cute.arch.cp_async_commit_group()
+            cute.arch.cp_async_mbarrier_arrive_noinc(mbar_full_ptr + stage)

    @cute.jit
    def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32):
@@ -2277,19 +2346,30 @@ def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32):
        return sX

    def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr):
-        load_kv_producer_group = cutlass.pipeline.CooperativeGroup(
-            cutlass.pipeline.Agent.Thread, len([self.load_warp_id])
-        )
        load_kv_consumer_group = cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])
        )
-        return cutlass.pipeline.PipelineTmaUmma.create(
-            barrier_storage=load_kv_mbar_ptr,
-            num_stages=self.kv_stage,
-            producer_group=load_kv_producer_group,
-            consumer_group=load_kv_consumer_group,
-            tx_count=self.tma_copy_bytes["K"],
-        )
+        if self.use_tma_KV:
+            load_kv_producer_group = cutlass.pipeline.CooperativeGroup(
+                cutlass.pipeline.Agent.Thread, len(self.load_warp_ids)
+            )
+            return cutlass.pipeline.PipelineTmaUmma.create(
+                barrier_storage=load_kv_mbar_ptr,
+                num_stages=self.kv_stage,
+                producer_group=load_kv_producer_group,
+                consumer_group=load_kv_consumer_group,
+                tx_count=self.tma_copy_bytes["K"],
+            )
+        else:
+            load_kv_producer_group = cutlass.pipeline.CooperativeGroup(
+                cutlass.pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE
+            )
+            return cutlass.pipeline.PipelineAsyncUmma.create(
+                num_stages=self.kv_stage,
+                producer_group=load_kv_producer_group,
+                consumer_group=load_kv_consumer_group,
+                barrier_storage=load_kv_mbar_ptr,
+            )

    # @cute.jit
    # def warp_scheduler_barrier_init(self):
diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index 4989067b8c1..fb36bfd492b 100644
--- a/flash_attn/cute/interface.py
+++ b/flash_attn/cute/interface.py
@@ -413,6 +413,7 @@ def _flash_attn_fwd(
        is_split_kv,
        pack_gqa,
        compute_capability,
+        page_size not in [None, 128],  # paged KV non-TMA
    )
    if compile_key not in _flash_attn_fwd.compile_cache:
@@ -441,9 +442,6 @@ def _flash_attn_fwd(
                has_aux_tensors=aux_tensors is not None,
            )
        elif compute_capability == 10:
-            assert page_size in [None, 128], (
-                "Only page_size=128 is supported for paged KV on SM 10.0"
-            )
            if sparse_tensors is not None:
                raise NotImplementedError("BlockSparsity not yet supported on SM 10.0")
            fa_fwd = FlashAttentionForwardSm100(
@@ -461,6 +459,7 @@ def _flash_attn_fwd(
                and not is_split_kv,
                score_mod=score_mod,
                has_aux_tensors=aux_tensors is not None,
+                paged_kv_non_tma=page_size not in [None, 128],
            )
        else:
            raise ValueError(
diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py
index 6f92d0835ac..aa18566cb23 100644
--- a/flash_attn/cute/mask.py
+++ b/flash_attn/cute/mask.py
@@ -106,6 +106,11 @@ def apply_mask(
        ROW = 0 if const_expr(not self.swap_AB) else 1
        COL = 1 if const_expr(not self.swap_AB) else 0
        thr_col_offset = tScS_mn[0][COL]
+        # To handle the edge case of completely masked-out rows where n_block_max = 0,
+        # we treat negative n_blocks as the 0th n_block.
+        # TODO: find a more transparent solution
+        if n_block < 0:
+            n_block = 0
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
        if const_expr(not mask_causal and not mask_local and mask_mod is None):
            if const_expr(mask_seqlen):
@@ -299,6 +304,11 @@ def apply_mask_sm100(
        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
        tScS = thr_mma.partition_C(cS)
        tScS_t2r = thr_tmem_load.partition_D(tScS)
+        # To handle the edge case of completely masked-out rows where n_block_max = 0,
+        # we treat negative n_blocks as the 0th n_block.
+        # TODO: find a more transparent solution
+        if n_block < 0:
+            n_block = 0
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n
        r2p = True
        if const_expr(not mask_causal and not mask_local):
diff --git a/flash_attn/cute/paged_kv.py b/flash_attn/cute/paged_kv.py
new file mode 100644
index 00000000000..ccb2296b4a7
--- /dev/null
+++ b/flash_attn/cute/paged_kv.py
@@ -0,0 +1,176 @@
+from typing import Type
+from dataclasses import dataclass
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.nvgpu import cpasync
+from cutlass import Int32, const_expr
+
+from flash_attn.cute import utils
+from flash_attn.cute.fast_math import FastDivmod
+from flash_attn.cute.cute_dsl_utils import ParamsBase
+
+
+@dataclass
+class PagedKVManager(ParamsBase):
+    mPageTable: cute.Tensor
+    mK_paged: cute.Tensor
+    mV_paged: cute.Tensor
+    thread_idx: Int32
+
+    page_size_divmod: FastDivmod
+    seqlen_k: Int32
+    leftpad_k: Int32
+    n_block_size: Int32
+    num_threads: cutlass.Constexpr[Int32]
+    head_dim_padded: cutlass.Constexpr[Int32]
+    head_dim_v_padded: cutlass.Constexpr[Int32]
+
+    gmem_threads_per_row: cutlass.Constexpr[Int32]
+    page_entry_per_thread: Int32
+    async_copy_elems: Int32
+
+    gmem_tiled_copy_KV: cute.TiledCopy
+    gmem_thr_copy_KV: cute.TiledCopy
+    tPrPage: cute.Tensor
+    tPrPageOffset: cute.Tensor
+    tKpK: cute.Tensor
+    tVpV: cute.Tensor
+
+    @staticmethod
+    def create(
+        mPageTable: cute.Tensor,
+        mK_paged: cute.Tensor,
+        mV_paged: cute.Tensor,
+        page_size_divmod: FastDivmod,
+        bidb: Int32,
+        bidh: Int32,
+        thread_idx: Int32,
+        seqlen_k: Int32,
+        leftpad_k: Int32,
+        n_block_size: cutlass.Constexpr[Int32],
+        head_dim_padded: cutlass.Constexpr[Int32],
+        head_dim_v_padded: cutlass.Constexpr[Int32],
+        num_threads: cutlass.Constexpr[Int32],
+        dtype: Type[cutlass.Numeric],
+    ):
+        universal_copy_bits = 128
+        gmem_threads_per_row = 8  # 8 threads * 128 bits = 128 bytes = 1 cache line
+        async_copy_elems = universal_copy_bits // dtype.width
+        atom_async_copy = cute.make_copy_atom(
+            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
+            dtype,
+            num_bits_per_copy=universal_copy_bits,
+        )
+        thr_layout = cute.make_ordered_layout(
+            (num_threads // gmem_threads_per_row, gmem_threads_per_row),
+            order=(1, 0),
+        )
+        val_layout = cute.make_layout((1, async_copy_elems))
+        gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout)
+        gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx)
+        page_entry_per_thread = n_block_size * gmem_threads_per_row // num_threads
+
+        tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
+        tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
+
+        mPageTable = mPageTable[bidb, None]
+        mK_paged = mK_paged[None, None, bidh, None]
+        mV_paged = mV_paged[None, None, bidh, None]
+
+        cK = cute.make_identity_tensor((n_block_size, head_dim_padded))
+        tKcK = gmem_thr_copy_KV.partition_S(cK)
+        tKpK = utils.predicate_k(tKcK, limit=mK_paged.shape[1])
+
+        if const_expr(head_dim_padded == head_dim_v_padded):
+            tVpV = tKpK
+        else:
+            cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded))
+            tVcV = gmem_thr_copy_KV.partition_S(cV)
+            tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0])
+
+        return PagedKVManager(
+            mPageTable,
+            mK_paged,
+            mV_paged,
+            thread_idx,
+            page_size_divmod,
+            seqlen_k,
+            leftpad_k,
+            n_block_size,
+            num_threads,
+            head_dim_padded,
+            head_dim_v_padded,
+            gmem_threads_per_row,
+            page_entry_per_thread,
+            async_copy_elems,
+            gmem_tiled_copy_KV,
+            gmem_thr_copy_KV,
+            tPrPage,
+            tPrPageOffset,
+            tKpK,
+            tVpV,
+        )
+
+    @cute.jit
+    def load_page_table(self, n_block: Int32):
+        for i in cutlass.range(self.page_entry_per_thread, unroll=1):
+            row = (i * self.num_threads + self.thread_idx) // self.gmem_threads_per_row
+            row_idx = n_block * self.n_block_size + row
+
+            page_idx, page_offset = self.page_size_divmod.divmod(row_idx + self.leftpad_k)
+
+            is_valid = (
+                (i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size
+            ) and row_idx < self.seqlen_k
+            page = self.mPageTable[page_idx] if is_valid else 0
+
+            self.tPrPage[i] = page
+            self.tPrPageOffset[i] = page_offset
+
+    @cute.jit
+    def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str):
+        assert K_or_V in ("K", "V")
+
+        # Finesse sX layout to be (M, N).
+        sX_pi = cute.make_tensor(
+            sX.iterator,
+            cute.make_layout(
+                (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])),
+                stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])),
+            ),
+        )
+
+        if const_expr(K_or_V == "V"):
+            # Need to transpose V
+            sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0]))
+
+        head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded
+        cX = cute.make_identity_tensor((self.n_block_size, head_dim))
+        tXsX = self.gmem_thr_copy_KV.partition_D(sX_pi)
+        tXcX = self.gmem_thr_copy_KV.partition_S(cX)
+
+        seqlenk_row_limit = self.seqlen_k - n_block * self.n_block_size if n_block >= 0 else 0
+        for m in cutlass.range(cute.size(tXsX, mode=[1]), unroll=1):
+            should_load = tXcX[0, m, 0][0] < seqlenk_row_limit
+
+            page = self.tPrPage[m]
+            page_offset = self.tPrPageOffset[m]
+            mX_paged_cur = (
+                self.mK_paged[page_offset, None, page]
+                if const_expr(K_or_V == "K")
+                else self.mV_paged[None, page_offset, page]
+            )
+            mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,))
+
+            if should_load:
+                for k in cutlass.range(cute.size(tXsX, mode=[2]), unroll=1):
+                    ki = tXcX[0, 0, k][1] // self.async_copy_elems
+                    cute.copy(
+                        self.gmem_tiled_copy_KV,
+                        mX_paged_cur_copy[None, ki],
+                        tXsX[None, m, k],
+                    )
+            elif const_expr(K_or_V == "V"):
+                # Don't need to clear out the rest of the smem for K since we'll mask out the scores anyway.
+                tXsX[None, m, None].fill(0)
diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py
index 6c264c30f55..14034fa9fd2 100644
--- a/tests/cute/test_flash_attn.py
+++ b/tests/cute/test_flash_attn.py
@@ -731,8 +731,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
 @pytest.mark.parametrize("rotary_interleaved", [True])
 # @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 @pytest.mark.parametrize("rotary_fraction", [0.0])
-# @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128]))
-@pytest.mark.parametrize("page_size", [None, 128])
+@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128]))
+# @pytest.mark.parametrize("page_size", [None, 128])
 # @pytest.mark.parametrize("page_size", [128])
 # @pytest.mark.parametrize("has_leftpad", [False, True])
 @pytest.mark.parametrize("has_leftpad", [False])
@@ -1154,7 +1154,7 @@ def test_flash_attn_kvcache(
        # attention_chunk=attention_chunk,
        # rotary_interleaved=rotary_interleaved,
        # scheduler_metadata=scheduler_metadata,
-        # num_splits=num_splits,
+        num_splits=num_splits,
        # return_softmax_lse=True
    )
    if varlen_q:
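
Note on the non-TMA paged-KV path introduced above: for page_size values other than 128 the loader no longer relies on the TMA-path assumption that page_size == n_block_size, so every K/V row of an n_block tile is translated through the page table individually (PagedKVManager.load_page_table does the divmod, load_KV then issues one 128-bit cp.async per chunk). Below is a minimal host-side sketch of that address math in plain Python, with illustrative names that are not part of the library API and leftpad_k assumed to be 0.

    def gather_k_rows(page_table, k_pages, n_block, n_block_size, page_size, seqlen_k):
        # page_table: physical page indices for one batch element
        # k_pages[page][offset] is one K row (head_dim elements)
        rows = []
        for r in range(n_block_size):
            row_idx = n_block * n_block_size + r  # absolute K position for this tile row
            if row_idx >= seqlen_k:
                break  # rows past seqlen_k are skipped here; the kernel masks them instead
            page_idx, page_offset = divmod(row_idx, page_size)  # FastDivmod on the device
            rows.append(k_pages[page_table[page_idx]][page_offset])
        return rows

The kernel performs the same translation per producer thread (page_entry_per_thread rows each) and copies 128-bit chunks into shared memory with cp.async instead of materializing a list.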