From 0ae26dd30f74105fb8133d44f7665041d1297c12 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Tue, 5 May 2026 23:54:35 +0000 Subject: [PATCH 01/15] add dynamicpersistentvarlenscheduler to flash_fwd_sm100 and prepare kernel --- flash_attn/cute/block_info.py | 23 +- flash_attn/cute/flash_fwd_combine.py | 5 +- flash_attn/cute/flash_fwd_sm100.py | 313 +++++++++++---- flash_attn/cute/interface.py | 323 ++++++++++++++- flash_attn/cute/prepare_scheduler.py | 392 ++++++++++++++++++ flash_attn/cute/tile_scheduler.py | 570 ++++++++++++++++++++++++--- flash_attn/cute/utils.py | 7 + 7 files changed, 1487 insertions(+), 146 deletions(-) create mode 100644 flash_attn/cute/prepare_scheduler.py diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index 35bb4365ff6..9abdac17a4e 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -19,6 +19,9 @@ class BlockInfo: window_size_left: Optional[Int32] = None window_size_right: Optional[Int32] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + num_splits: Int32 = 1 + num_splits_dynamic_ptr: Optional[cute.Tensor] = None + num_n_blocks_per_split: Optional[cutlass.Constexpr[Int32]] = None @cute.jit def get_n_block_min_max( @@ -26,6 +29,7 @@ def get_n_block_min_max( seqlen_info: SeqlenInfoQK, m_block: Int32, split_idx: Int32 = 0, + batch_idx: Int32 = 0, num_splits: Int32 = 1, ) -> Tuple[Int32, Int32]: n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) @@ -45,11 +49,20 @@ def get_n_block_min_max( n_idx_left = n_idx - self.window_size_left n_block_min = cutlass.max(n_idx_left // self.tile_n, 0) if cutlass.const_expr(self.is_split_kv): - num_n_blocks_per_split = ( - Int32(0) - if n_block_max <= n_block_min - else (n_block_max - n_block_min + num_splits - 1) // num_splits - ) + if const_expr(self.num_splits_dynamic_ptr is not None): + # Unpack num_splits from top 16 bits of split_idx (packed by scheduler) + num_splits = split_idx >> 16 + split_idx = split_idx & 0xFFFF + else: + num_splits = self.num_splits + if const_expr(self.num_n_blocks_per_split is not None): + num_n_blocks_per_split = self.num_n_blocks_per_split + else: + num_n_blocks_per_split = ( + Int32(0) + if n_block_max <= n_block_min + else (n_block_max - n_block_min + num_splits - 1) // num_splits + ) n_block_min = n_block_min + split_idx * num_n_blocks_per_split n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max) return n_block_min, n_block_max diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py index 493620235ec..0d9f7985e70 100644 --- a/flash_attn/cute/flash_fwd_combine.py +++ b/flash_attn/cute/flash_fwd_combine.py @@ -394,8 +394,9 @@ def kernel( num_head = mO_partial.shape[3] max_idx = seqlen * num_head - # Early exit for single split if dynamic - if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and ( + # TODO: early exit for single split if dynamic — for now always merge so the + # num_splits_dynamic == 1 case still writes mO from mO_partial[0]. + if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 0) and ( const_expr(not varlen) or m_block * self.tile_m < max_idx ): # Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 55a92f690bd..575b1d45bc7 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -56,7 +56,7 @@ from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( - ClcState, + SchedulerState, SchedulingMode, TileSchedulerArguments, TileSchedulerProtocol, @@ -64,6 +64,7 @@ StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, + DynamicPersistentVarlenScheduler, ) from flash_attn.cute.fa_logging import fa_log, fa_printf from flash_attn.cute.utils import smid @@ -134,6 +135,8 @@ def __init__( is_varlen_q: bool = False, use_2cta_instrs: bool = False, use_clc_scheduler: bool = False, + has_tile_count_semaphore: bool = False, + seqlen_k_per_split: Optional[int] = None, ): self.use_tma_KV = not paged_kv_non_tma # self.dtype = dtype @@ -158,6 +161,10 @@ def __init__( self.split_P_arrive = int(self.split_P_arrive / 32) * 32 # multiple of 32 assert self.split_P_arrive % 32 == 0 assert self.split_P_arrive < self.n_block_size + assert seqlen_k_per_split is None or seqlen_k_per_split % n_block_size == 0 + self.num_n_blocks_per_split = ( + seqlen_k_per_split // n_block_size if seqlen_k_per_split is not None else None + ) self.arch = BaseDSL._get_dsl().get_arch_enum() assert self.arch >= Arch.sm_100 and self.arch <= Arch.sm_110f, "Only SM 10.x and 11.x are supported" @@ -188,6 +195,15 @@ def __init__( self.vec_size: cutlass.Constexpr = getattr( score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 ) + self.dynamic_persistent = has_tile_count_semaphore and is_varlen_q + if self.dynamic_persistent: + self.is_persistent = True + assert not use_clc_scheduler, ( + "use_clc_scheduler and dynamic_persistent (varlen + tile_count_semaphore) " + "are not currently composable; pick one. TODO: future revision could let " + "DynamicPersistentVarlenScheduler use CLC for tile distribution while " + "keeping prepare_scheduler's per-batch num_splits and LPT batch-sort." + ) # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f @@ -200,8 +216,6 @@ def __init__( (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or (self.head_dim_v_padded >= 128 and self.is_split_kv) ) - if self.overlap_sO_sQ: - self.is_persistent = False assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), ( "Paged KV does not support irregular head dim" @@ -221,10 +235,17 @@ def __init__( f"CLC cluster M != cta_group_size: {self.cluster_shape_mn}, {self.cta_group_size}" ) - self.scheduling_mode = SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC + self.scheduling_mode = ( + SchedulingMode.CLC if self.use_clc_scheduler + else SchedulingMode.DYNAMIC if self.dynamic_persistent + else SchedulingMode.STATIC + ) if is_varlen_q: - self.TileScheduler = SingleTileVarlenScheduler + if self.dynamic_persistent: + self.TileScheduler = DynamicPersistentVarlenScheduler + else: + self.TileScheduler = SingleTileVarlenScheduler elif self.is_causal or self.is_local or self.use_clc_scheduler: self.TileScheduler = SingleTileLPTScheduler elif self.is_persistent: @@ -274,7 +295,10 @@ def __init__( elif self.is_varlen_q: # fallback self.epilogue_warp_ids = (13, 14) - self.clc_scheduler_warp_id = self.empty_warp_ids[0] if self.use_clc_scheduler else None + self.scheduler_warp_id = ( + self.empty_warp_ids[0] + if (self.use_clc_scheduler or self.dynamic_persistent) else None + ) self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 self.tmem_o_offset = [ @@ -374,6 +398,12 @@ def __call__( descale_tensors: Optional[DescaleTensors] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, + num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + tile_count_semaphore: Optional[cute.Tensor] = None, + num_m_blocks_ptr: Optional[cute.Tensor] = None, + varlen_batch_idx_ptr: Optional[cute.Tensor] = None, + num_nheads_in_l2_ptr: Optional[cute.Tensor] = None, + max_seqlen_q: Int32 | int | None = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): @@ -627,10 +657,14 @@ def __call__( vO_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) + if const_expr(max_seqlen_q is None): + eff_seqlen_q = cute.size(mQ.shape[0]) + else: + eff_seqlen_q = max_seqlen_q if const_expr(not self.pack_gqa) else max_seqlen_q * self.qhead_per_kvhead TileScheduler = self.TileScheduler _num_block_divisor = self.cta_tiler[0] * (self.cta_group_size if not self.is_persistent and self.cta_group_size > 1 else 1) tile_sched_args = TileSchedulerArguments( - cute.ceil_div(cute.size(mQ.shape[0]), _num_block_divisor), + cute.ceil_div(eff_seqlen_q, _num_block_divisor), cute.size(mQ.shape[2]), cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) @@ -654,6 +688,11 @@ def __call__( is_split_kv=self.is_split_kv, cluster_shape_mn=self.cluster_shape_mn, use_cluster_idx=not self.is_persistent and self.cta_group_size > 1, + num_splits_dynamic_ptr=num_splits_dynamic_ptr, + num_m_blocks_ptr=num_m_blocks_ptr, + varlen_batch_idx_ptr=varlen_batch_idx_ptr, + num_nheads_in_l2_ptr=num_nheads_in_l2_ptr, + tile_count_semaphore=tile_count_semaphore.iterator if tile_count_semaphore is not None else None, ) tile_sched_params = TileScheduler.to_underlying_arguments( tile_sched_args, scheduling_mode=self.scheduling_mode @@ -667,8 +706,15 @@ def __call__( cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width) ) - clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 - clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + sched_response_size = ( + self.sched_stages * 4 + if (self.use_clc_scheduler or self.dynamic_persistent) else 0 + ) + sched_mbar_size = ( + self.sched_stages * 2 + if (self.use_clc_scheduler or self.dynamic_persistent) else 0 + ) + load_epi_mbar_size = 2 if const_expr(self.overlap_sO_sQ) else 0 @cute.struct class SharedStorage: @@ -681,6 +727,7 @@ class SharedStorage: mbar_softmax_stats: cute.struct.MemRange[Int64, self.q_stage * 2] # mbar_softmax_stats: cute.struct.MemRange[Int64, self.q_stage * 4 * 2] mbar_O_epi: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_load_epi: cute.struct.MemRange[Int64, load_epi_mbar_size] mbar_s0_s1_sequence: cute.struct.MemRange[Int64, 2 * 2] # Tmem dealloc cluster barrier tmem_dealloc_mbar_ptr: Int64 @@ -689,12 +736,12 @@ class SharedStorage: # Smem tensors # store row max and row sum sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] - # CLC buffers placed here to utilize padding before sO's 1024-byte alignment. - # This avoids adding bytes at the end when we're at the smem limit. - # PipelineClcFetchAsync expects 2 * sched_stages mbarriers (full + empty). - clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] - # CLC response storage (16 bytes per stage, stored as 4 Int32s). - clc_response: cute.struct.MemRange[Int32, clc_response_size] + # Scheduler buffers placed here to utilize padding before sO's 1024-byte + # alignment. PipelineClcFetchAsync / PipelineAsync both expect + # 2 * sched_stages mbarriers (full + empty). Response is 4 Int32 per stage + # (CLC HW response, or work_info written by dynamic persistent producer). + sched_mbar_ptr: cute.struct.MemRange[Int64, sched_mbar_size] + sched_response: cute.struct.MemRange[Int32, sched_response_size] # Large TMA buffers with 1024-byte alignment sO: cute.struct.Align[ cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes @@ -761,6 +808,11 @@ class SharedStorage: tiled_mma_pv, tile_sched_params, num_splits, + num_splits_dynamic_ptr, + tile_count_semaphore, + num_m_blocks_ptr, + varlen_batch_idx_ptr, + num_nheads_in_l2_ptr, aux_tensors, fastdiv_mods, head_divmod, @@ -808,6 +860,11 @@ def kernel( tiled_mma_pv: cute.TiledMma, tile_sched_params: ParamsBase, num_splits: Int32, + num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + tile_count_semaphore: Optional[cute.Tensor] = None, + num_m_blocks_ptr: Optional[cute.Tensor] = None, + varlen_batch_idx_ptr: Optional[cute.Tensor] = None, + num_nheads_in_l2_ptr: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), head_divmod=None, @@ -870,6 +927,7 @@ def kernel( ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread) mma_warp = ThreadCooperativeGroup(len([self.mma_warp_id])) tma_warp = ThreadCooperativeGroup(1) + load_warps = ThreadCooperativeGroup(len(self.load_warp_ids)) load_threads = ThreadCooperativeGroup(len(self.load_warp_ids) * cute.arch.WARP_SIZE) softmax_warps = ThreadCooperativeGroup(len(self.softmax0_warp_ids)) softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.softmax0_warp_ids)) @@ -881,6 +939,7 @@ def kernel( softmax_correction_threads = ThreadCooperativeGroup( cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids) ) + epilogue_warps = ThreadCooperativeGroup(len(self.epilogue_warp_ids)) epilogue_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) # For UMMA-bridging pipelines: the non-MMA side spans both CTAs in the cluster, # so the thread count must include warps from both CTAs. @@ -993,6 +1052,22 @@ def kernel( defer_sync=True, ) + pipeline_load_epi = None + if const_expr(self.overlap_sO_sQ and self.is_persistent): + epi_warps_for_release = ( + ThreadCooperativeGroup(len(self.correction_warp_ids)) + if self.use_correction_warps_for_epi + else epilogue_warps + ) + pipeline_load_epi = pipeline_custom.PipelineAsync.create( + barrier_storage=storage.mbar_load_epi.data_ptr(), + num_stages=1, + producer_group=epi_warps_for_release, + consumer_group=load_warps, + defer_sync=True, + ) + + # Cluster arrive after barrier init pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) @@ -1042,6 +1117,9 @@ def kernel( window_size_left, window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + num_splits=num_splits, + num_splits_dynamic_ptr=num_splits_dynamic_ptr, + num_n_blocks_per_split=self.num_n_blocks_per_split, ) SeqlenInfoCls = partial( SeqlenInfoQK.create, @@ -1071,60 +1149,79 @@ def kernel( # Cluster wait before tensor memory alloc pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) - if const_expr(self.use_clc_scheduler): - clc_response_ptr = storage.clc_response.data_ptr() - clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() - - clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup( + sched_ctx = None + if const_expr(self.use_clc_scheduler or self.dynamic_persistent): + sched_response_ptr = storage.sched_response.data_ptr() + sched_mbar_ptr = storage.sched_mbar_ptr.data_ptr() + sched_producer_group = cutlass_pipeline.CooperativeGroup( cutlass_pipeline.Agent.Thread ) - num_clc_consumer_warps_per_cta = self.threads_per_cta // cute.arch.WARP_SIZE - # NB on CTA0 warp15 == scheduler on CTA1 == empty but still both consume - num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size - clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup( - cutlass_pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps - ) - - block_idx = cute.arch.block_idx() - clc = ClcState.create( - hw_scheduler=ClcDynamicPersistentTileScheduler.create( - self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), - block_idx, - cute.arch.grid_dim(), - clc_response_ptr, - ), - pipeline=cutlass_pipeline.PipelineClcFetchAsync.create( - barrier_storage=clc_mbar_ptr, - num_stages=self.sched_stages, - producer_group=clc_pipeline_producer_group, - consumer_group=clc_pipeline_consumer_group, - tx_count=16, - cta_layout_vmnk=cta_layout_vmnk, - ), - consumer_state=cutlass_pipeline.make_pipeline_state( - cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages - ), - producer_state=cutlass_pipeline.make_pipeline_state( - cutlass_pipeline.PipelineUserType.Producer, self.sched_stages - ), + num_sched_consumer_warps_per_cta = self.threads_per_cta // cute.arch.WARP_SIZE + num_sched_consumer_warps = num_sched_consumer_warps_per_cta * self.cta_group_size + sched_consumer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread, + cute.arch.WARP_SIZE * num_sched_consumer_warps, ) - tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc=clc) + if const_expr(self.use_clc_scheduler): + _block_idx = cute.arch.block_idx() + sched_ctx = SchedulerState.create_clc( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), + _block_idx, + cute.arch.grid_dim(), + sched_response_ptr, + ), + pipeline=cutlass_pipeline.PipelineClcFetchAsync.create( + barrier_storage=sched_mbar_ptr, + num_stages=self.sched_stages, + producer_group=sched_producer_group, + consumer_group=sched_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ), + consumer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages + ), + producer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, self.sched_stages + ), + ) + else: + assert tile_count_semaphore is not None + sched_ctx = SchedulerState.create_dynamic_persistent( + work_info=storage.sched_response.get_tensor((4,)), + pipeline=cutlass_pipeline.PipelineAsync.create( + barrier_storage=sched_mbar_ptr, + num_stages=self.sched_stages, + producer_group=sched_producer_group, + consumer_group=sched_consumer_group, + ), + consumer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages + ), + producer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, self.sched_stages + ), + ) + if const_expr(self.use_clc_scheduler or self.dynamic_persistent): + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, ctx=sched_ctx) else: tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) assert isinstance(tile_scheduler, TileSchedulerProtocol), f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}" # /////////////////////////////////////////////////////////////////////////////// - # EMPTY / CLC SCHEDULER WARP + # EMPTY / SCHEDULER WARP # /////////////////////////////////////////////////////////////////////////////// - if const_expr(self.use_clc_scheduler): - if warp_idx == self.clc_scheduler_warp_id: + if const_expr(self.use_clc_scheduler or self.dynamic_persistent): + if warp_idx == self.scheduler_warp_id: cute.arch.setmaxregister_decrease(self.num_regs_other) - if is_leader_cta: - self.clc_scheduler_warp(tile_scheduler) + # CLC: only leader CTA produces. + if const_expr(self.dynamic_persistent) or is_leader_cta: + self.scheduler_warp(tile_scheduler) else: self.empty_warp(tile_scheduler) for i in cutlass.range_constexpr(len(self.empty_warp_ids)): - if warp_idx == self.empty_warp_ids[i] and warp_idx != self.clc_scheduler_warp_id: + if warp_idx == self.empty_warp_ids[i] and warp_idx != self.scheduler_warp_id: cute.arch.setmaxregister_decrease(self.num_regs_other) self.empty_warp(tile_scheduler) else: @@ -1153,6 +1250,7 @@ def kernel( gmem_tiled_copy_Q, pipeline_q, pipeline_kv, + pipeline_load_epi, block_info, num_splits, SeqlenInfoCls, @@ -1207,6 +1305,7 @@ def kernel( gmem_tiled_copy_O, tma_atom_O, pipeline_o_epi, + pipeline_load_epi, block_info, num_splits, SeqlenInfoCls, @@ -1286,6 +1385,7 @@ def kernel( pipeline_sm_stats, sm_stats_barrier, pipeline_o_epi, + pipeline_load_epi, learnable_sink, descale_tensors, gmem_tiled_copy_O, @@ -1319,6 +1419,7 @@ def load( gmem_tiled_copy_Q: Optional[cute.TiledCopy], pipeline_q: pipeline.PipelineAsync, pipeline_kv: pipeline.PipelineAsync, + pipeline_load_epi: Optional[pipeline.PipelineAsync], block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, @@ -1328,6 +1429,9 @@ def load( num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE tidx = cute.arch.thread_idx()[0] % num_load_threads warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + load_epi_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, 1 + ) issue_kv_for_this_warp = ( const_expr(not self.use_tma_KV or len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0] @@ -1453,8 +1557,10 @@ def load( if const_expr(not self.use_block_sparsity): n_block_min, n_block_max = block_info.get_n_block_min_max( - seqlen, m_block, split_idx, num_splits + seqlen, m_block, split_idx=split_idx, batch_idx=batch_idx, num_splits=num_splits, ) + if const_expr(self.is_split_kv and block_info.num_splits_dynamic_ptr is not None): + split_idx = split_idx & 0xFFFF if const_expr(not self.is_split_kv) or n_block_min < n_block_max: n_block_first = n_block_max - 1 if n_block_max > 0 else 0 page_idx = ( @@ -1515,7 +1621,11 @@ def load( work_tile = tile_scheduler.advance_to_next_work() - # End of persistent scheduler loop + if const_expr(pipeline_load_epi is not None): + pipeline_load_epi.consumer_wait(load_epi_consumer_state) + with cute.arch.elect_one(): + pipeline_load_epi.consumer_release(load_epi_consumer_state) + load_epi_consumer_state.advance() if issue_kv_for_this_warp: pipeline_kv.producer_tail(kv_producer_state) @@ -1639,7 +1749,6 @@ def mma( while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - block_iter_count = Int32(0) process_tile = False @@ -1657,7 +1766,11 @@ def mma( ) process_tile = block_iter_count > Int32(0) else: - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, split_idx=split_idx, batch_idx=batch_idx, num_splits=num_splits, + ) + if const_expr(self.is_split_kv and block_info.num_splits_dynamic_ptr is not None): + split_idx = split_idx & 0xFFFF block_iter_count = n_block_max - n_block_min if const_expr(not self.is_split_kv): process_tile = True @@ -1936,7 +2049,11 @@ def softmax_loop( m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx kv_head_idx = self._kv_head_idx(head_idx) seqlen = SeqlenInfoCls(batch_idx) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, split_idx=split_idx, batch_idx=batch_idx, num_splits=num_splits, + ) + if const_expr(self.is_split_kv and block_info.num_splits_dynamic_ptr is not None): + split_idx = split_idx & 0xFFFF mask = AttentionMaskCls(seqlen) shared_mask_kwargs = dict( @@ -2358,6 +2475,7 @@ def correction_loop( pipeline_sm_stats: pipeline.PipelineAsync, sm_stats_barrier: pipeline.NamedBarrier, pipeline_o_epi: pipeline.PipelineAsync, + pipeline_load_epi: Optional[pipeline.PipelineAsync], learnable_sink: Optional[cute.Tensor], descale_tensors: Optional[DescaleTensors], gmem_tiled_copy_O: cute.TiledCopy, @@ -2396,6 +2514,9 @@ def correction_loop( sm_stats_consumer_phase = Int32(0) o_corr_consumer_phase = Int32(0) corr_epi_producer_phase = Int32(1) + load_epi_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, 1 + ) work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: @@ -2412,7 +2533,11 @@ def correction_loop( Float32(256.0) if cutlass.const_expr(self.q_dtype.width == 8) else Float32(1.0) ) seqlen = SeqlenInfoCls(batch_idx) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, split_idx=split_idx, batch_idx=batch_idx, num_splits=num_splits, + ) + if const_expr(self.is_split_kv and block_info.num_splits_dynamic_ptr is not None): + split_idx = split_idx & 0xFFFF if const_expr(self.is_split_kv): mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] @@ -2643,6 +2768,12 @@ def correction_loop( ) cute.make_tensor(lse_gmem_ptr, (1,))[0] = lse + if const_expr(pipeline_load_epi is not None and self.use_correction_warps_for_epi): + pipeline_load_epi.producer_acquire(load_epi_producer_state) + with cute.arch.elect_one(): + pipeline_load_epi.producer_commit(load_epi_producer_state) + load_epi_producer_state.advance() + # Advance to next tile work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop @@ -2848,6 +2979,7 @@ def epilogue_s2g( gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: Optional[cute.CopyAtom], pipeline_o_epi: pipeline.PipelineAsync, + pipeline_load_epi: Optional[pipeline.PipelineAsync], block_info: BlockInfo, num_splits: int, SeqlenInfoCls: Callable, @@ -2856,11 +2988,18 @@ def epilogue_s2g( tile_scheduler=None, ): epi_consumer_phase = Int32(0) + load_epi_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, 1 + ) work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, split_idx=split_idx, batch_idx=batch_idx, num_splits=num_splits, + ) + if const_expr(self.is_split_kv and block_info.num_splits_dynamic_ptr is not None): + split_idx = split_idx & 0xFFFF has_work = const_expr(self.use_block_sparsity or not self.is_split_kv) or n_block_min < n_block_max if has_work: @@ -2911,30 +3050,44 @@ def epilogue_s2g( epi_consumer_phase ^= 1 + if const_expr(pipeline_load_epi is not None): + pipeline_load_epi.producer_acquire(load_epi_producer_state) + with cute.arch.elect_one(): + pipeline_load_epi.producer_commit(load_epi_producer_state) + load_epi_producer_state.advance() + # Advance to next tile work_tile = tile_scheduler.advance_to_next_work() @cute.jit - def clc_scheduler_warp( + def scheduler_warp( self, tile_scheduler: TileSchedulerProtocol, ): - work_tile = tile_scheduler.initial_work_tile_info() - while work_tile.is_valid_tile: - tile_scheduler.prefetch_next_work() - work_tile = tile_scheduler.advance_to_next_work() - if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE: - fa_printf( - 3, - "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", - smid(), - cute.arch.block_idx()[0], - work_tile.tile_idx[0], - work_tile.tile_idx[1], - work_tile.tile_idx[2], - work_tile.tile_idx[3], - work_tile.is_valid_tile, - ) + if const_expr(self.dynamic_persistent): + work_tile, group_start_tile = tile_scheduler.initial_sched_state() + batch_idx = Int32(work_tile.tile_idx[2]) + while work_tile.is_valid_tile: + group_start_tile = tile_scheduler.prefetch_next_work(batch_idx, group_start_tile) + work_tile = tile_scheduler.advance_to_next_work() + batch_idx = Int32(work_tile.tile_idx[2]) + else: + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.advance_to_next_work() + if cute.arch.thread_idx()[0] == self.scheduler_warp_id * cute.arch.WARP_SIZE: + fa_printf( + 3, + "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", + smid(), + cute.arch.block_idx()[0], + work_tile.tile_idx[0], + work_tile.tile_idx[1], + work_tile.tile_idx[2], + work_tile.tile_idx[3], + work_tile.is_valid_tile, + ) tile_scheduler.producer_tail() @cute.jit diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 189ae1faca7..7bcbb7fff17 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -44,6 +44,7 @@ from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine from flash_attn.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 +from flash_attn.cute.prepare_scheduler import FlashPrepareScheduler, SchedulerMetadataTensorsTorch # SM100 head_dim=256 2CTA kernel imports from flash_attn.cute.sm100_hd256_2cta_fmha_forward import BlackwellFusedMultiHeadAttentionForward @@ -326,6 +327,10 @@ def _flash_attn_fwd( k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, gather_kv_indices: Optional[torch.Tensor] = None, + scheduler_metadata: Optional[SchedulerMetadataTensorsTorch] = None, + seqlen_k_per_split: Optional[int] = None, + disable_scheduler_metadata: bool = False, + zfill_padded_output: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. @@ -675,6 +680,61 @@ def _flash_attn_fwd( sparse_kv = None disable_sparse_kv_bitmask = None + + reuse_scheduler_metadata = scheduler_metadata is not None + is_varlen_q = cu_seqlens_q is not None or seqused_q is not None + if is_split_kv and is_varlen_q and scheduler_metadata is None and not disable_scheduler_metadata: + scheduler_metadata = get_scheduler_metadata( + num_batch=batch_size, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + nheads=num_head, + nheads_kv=num_head_kv, + headdim=head_dim, + num_splits=num_splits, + tile_m=tile_m, + tile_n=tile_n, + headdim_v=head_dim_v, + pack_gqa=pack_gqa, + causal=causal, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + seqlen_k_per_split=seqlen_k_per_split, + ) + + has_scheduler_metadata = scheduler_metadata is not None and not disable_scheduler_metadata + if has_scheduler_metadata: + ( + num_m_blocks, + num_splits_dynamic, + varlen_batch_idx, + num_nheads_in_l2, + tile_count_semaphore, + ) = scheduler_metadata + assert all( + t is None or t.is_cuda + for t in scheduler_metadata + ), "scheduler metadata must be on CUDA device" + assert all( + t is None or t.shape == (batch_size,) + for t in ( + num_m_blocks, + num_splits_dynamic, + varlen_batch_idx, + num_nheads_in_l2, + ) + ), "these scheduler metadata tensors must have shape (batch_size,)" + if tile_count_semaphore is not None: + assert tile_count_semaphore.shape == (1,), "semaphore must have size 1" + else: + num_m_blocks = None + num_splits_dynamic = None + varlen_batch_idx = None + num_nheads_in_l2 = None + tile_count_semaphore = None + compile_key = ( dtype, head_dim, @@ -713,6 +773,12 @@ def _flash_attn_fwd( mma_pv_is_rs, intra_wg_overlap, use_clc_scheduler, + has_scheduler_metadata, + num_m_blocks is not None, + num_splits_dynamic is not None, + varlen_batch_idx is not None, + num_nheads_in_l2 is not None, + tile_count_semaphore is not None, qv is not None, gather_kv_length, sparse_kv, @@ -784,6 +850,27 @@ def _flash_attn_fwd( if aux_tensors is not None: cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] + num_splits_dynamic_tensor = ( + to_cute_tensor(num_splits_dynamic, assumed_align=4, leading_dim=0) + if num_splits_dynamic is not None else None + ) + tile_count_semaphore_tensor = ( + to_cute_tensor(tile_count_semaphore, assumed_align=4, leading_dim=0) + if tile_count_semaphore is not None else None + ) + num_m_blocks_tensor = ( + to_cute_tensor(num_m_blocks, assumed_align=4, leading_dim=0) + if num_m_blocks is not None else None + ) + varlen_batch_idx_tensor = ( + to_cute_tensor(varlen_batch_idx, assumed_align=4, leading_dim=0) + if varlen_batch_idx is not None else None + ) + num_nheads_in_l2_tensor = ( + to_cute_tensor(num_nheads_in_l2, assumed_align=4, leading_dim=0) + if num_nheads_in_l2 is not None else None + ) + qv_tensor = to_cute_tensor(qv) if qv is not None else None gather_kv_indices_tensor = to_cute_tensor(gather_kv_indices) if gather_kv_indices is not None else None @@ -904,6 +991,7 @@ def _flash_attn_fwd( q_subtile_factor=q_subtile_factor, use_2cta_instrs=use_2cta_instrs, use_clc_scheduler=use_clc_scheduler, + has_tile_count_semaphore=tile_count_semaphore is not None, ) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity @@ -970,17 +1058,19 @@ def _flash_attn_fwd( window_size_left, window_size_right, learnable_sink_tensor, - ] - if arch // 10 in [10, 11]: - compile_args.append(descale_tensors_tensor) - compile_args.extend([ sparse_tensors, cute_aux_tensors, - ]) - compile_args.append(current_stream) - _flash_attn_fwd.compile_cache[compile_key] = cute.compile( - *compile_args, options="--enable-tvm-ffi" - ) + num_splits_dynamic_tensor, + tile_count_semaphore_tensor, + num_m_blocks_tensor, + varlen_batch_idx_tensor, + num_nheads_in_l2_tensor, + max_seqlen_q, + current_stream, + ] + if arch // 10 in [10, 11]: + compile_args.insert(-9, descale_tensors_tensor) + _flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi") if not is_fake_mode(): q_call, k_call, v_call = q.detach(), k.detach(), v.detach() @@ -1048,6 +1138,12 @@ def _flash_attn_fwd( if normalized_block_sparse_tensors is not None else None, aux_tensors, + num_splits_dynamic, + tile_count_semaphore, + num_m_blocks, + varlen_batch_idx, + num_nheads_in_l2, + max_seqlen_q, ]) _flash_attn_fwd.compile_cache[compile_key](*call_args) if is_split_kv: @@ -1058,7 +1154,11 @@ def _flash_attn_fwd( lse.transpose(-1, -2) if lse is not None else None, cu_seqlens_q, seqused_q, + num_splits_dynamic_ptr=num_splits_dynamic if has_scheduler_metadata else None, + varlen_batch_idx=varlen_batch_idx if has_scheduler_metadata else None, ) + if reuse_scheduler_metadata and tile_count_semaphore is not None: + tile_count_semaphore.zero_() return out, lse @@ -2036,6 +2136,9 @@ def forward( block_sparse_tensors: Optional[list] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, + scheduler_metadata: Optional["SchedulerMetadataTensorsTorch"] = None, + seqlen_k_per_split: Optional[int] = None, + disable_scheduler_metadata: bool = False, ): out, lse = _flash_attn_fwd( q, @@ -2064,6 +2167,9 @@ def forward( aux_tensors=aux_tensors, return_lse=return_lse, gather_kv_indices=gather_kv_indices, + scheduler_metadata=scheduler_metadata, + seqlen_k_per_split=seqlen_k_per_split, + disable_scheduler_metadata=disable_scheduler_metadata, ) ctx.save_for_backward( q, @@ -2200,6 +2306,9 @@ def flash_attn_varlen_func( block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, + scheduler_metadata: Optional[SchedulerMetadataTensorsTorch] = None, + seqlen_k_per_split: Optional[int] = None, + disable_scheduler_metadata: bool = False, ): """ Explanation of some optional arguments: @@ -2243,12 +2352,16 @@ def flash_attn_varlen_func( block_sparse_tensors, aux_tensors, return_lse, + scheduler_metadata, + seqlen_k_per_split, + disable_scheduler_metadata, ) def _compile_fwd_combine( dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits, has_cu_seqlens, has_seqused, has_lse, has_varlen_batch_idx, + has_num_splits_dynamic, has_semaphore_to_reset, ): """Compile fwd combine kernel using cute fake tensors (no real GPU tensors needed).""" sym = cute.sym_int @@ -2290,9 +2403,9 @@ def _compile_fwd_combine( batchp1 = sym() mCuSeqlens = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_seqlens else None mSeqused = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_seqused else None - mNumSplitsDynamic = None # Not parametrized in compile_key + mNumSplitsDynamic = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_num_splits_dynamic else None mVarlenBatchIdx = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_varlen_batch_idx else None - mSemaphore = None # Not parametrized in compile_key + mSemaphore = fake_tensor(Int32, (1,), divisibility=1) if has_semaphore_to_reset else None return cute.compile( fa_combine, @@ -2381,6 +2494,8 @@ def _flash_attn_fwd_combine( seqused is not None, lse is not None, varlen_batch_idx is not None, + num_splits_dynamic_ptr is not None, + semaphore_to_reset is not None, ) if compile_key not in _flash_attn_fwd_combine.compile_cache: _flash_attn_fwd_combine.compile_cache[compile_key] = _compile_fwd_combine( @@ -2484,3 +2599,189 @@ def flash_attn_combine( varlen_batch_idx=varlen_batch_idx, ) return out, lse + + +def get_scheduler_metadata( + num_batch: int, + max_seqlen_q: int, + max_seqlen_k: int, + nheads: int, + nheads_kv: int, + headdim: int, + num_splits: int, + tile_m: int, + tile_n: int, + headdim_v: Optional[int] = None, + pack_gqa: Optional[bool] = False, + causal: bool = False, + enable_pdl: bool = False, + sort: bool = False, + seqlen_k_new: int = 0, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + seqlen_k_per_split: Optional[int] = None, + zfill_padded_output: bool = True, +) -> SchedulerMetadataTensorsTorch: + device = None + for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: + if t is not None: + device = t.device + break + if device is None: + raise ValueError( + "At least one of cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be provided on device" + ) + if headdim_v is None: + headdim_v = headdim + + # Override enable_pdl (not supported yet) + enable_pdl = False + + # Override sort (not supported yet) + sort = False + + if seqlen_k_per_split is not None: + assert seqlen_k_per_split % tile_n == 0, "seqlen per split must be divisible by tile_n" + n_blocks_per_split = seqlen_k_per_split // tile_n + else: + n_blocks_per_split = None + + # Allocate metadata torch tensors + num_m_blocks = torch.empty(num_batch, dtype=torch.int32, device=device) + num_splits_dynamic = torch.empty(num_batch, dtype=torch.int32, device=device) + varlen_batch_idx = torch.empty(num_batch, dtype=torch.int32, device=device) if sort else None + num_nheads_in_l2 = torch.empty(num_batch, dtype=torch.int32, device=device) if causal else None + tile_count_semaphore = torch.empty(1, dtype=torch.int32, device=device) + + # Compute num_warps based on num_batch (capped at 32) + num_warps = min((num_batch + 30) // 31, 32) + # Round up to the nearest power of 2 + num_warps = 1 << (num_warps - 1).bit_length() + + cache_key = ( + num_warps, + tile_m, + tile_n, + nheads, + nheads_kv, + headdim, + headdim_v, + causal, + pack_gqa, + enable_pdl, + sort, + cu_seqlens_q is not None, + cu_seqlens_k is not None, + cu_seqlens_k_new is not None, + seqused_q is not None, + seqused_k is not None, + leftpad_k is not None, + num_m_blocks is not None, + num_splits_dynamic is not None, + varlen_batch_idx is not None, + num_nheads_in_l2 is not None, + tile_count_semaphore is not None, + n_blocks_per_split is not None, + zfill_padded_output, + ) + + if cache_key not in get_scheduler_metadata.compile_cache: + ( + num_m_blocks_cute, + num_splits_dynamic_cute, + varlen_batch_idx_cute, + num_nheads_in_l2_cute, + tile_count_semaphore_cute, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + cu_seqlens_k_new_cute, + seqused_q_cute, + seqused_k_cute, + leftpad_k_cute, + ) = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in ( + num_m_blocks, + num_splits_dynamic, + varlen_batch_idx, + num_nheads_in_l2, + tile_count_semaphore, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_k_new, + seqused_q, + seqused_k, + leftpad_k, + ) + ] + scheduler = FlashPrepareScheduler( + num_warps, + tile_m, + tile_n, + nheads, + nheads_kv, + headdim, + headdim_v, + causal, + packgqa=pack_gqa, + sort=sort, + zfill_padded_output=zfill_padded_output, + ) + get_scheduler_metadata.compile_cache[cache_key] = cute.compile( + scheduler, + max_seqlen_q, + max_seqlen_k, + seqlen_k_new, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + cu_seqlens_k_new_cute, + seqused_q_cute, + seqused_k_cute, + leftpad_k_cute, + num_batch, + num_splits, + tile_count_semaphore_cute, + num_m_blocks_cute, + num_splits_dynamic_cute, + varlen_batch_idx_cute, + num_nheads_in_l2_cute, + n_blocks_per_split, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + if not is_fake_mode(): + get_scheduler_metadata.compile_cache[cache_key]( + max_seqlen_q, + max_seqlen_k, + seqlen_k_new, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_k_new, + seqused_q, + seqused_k, + leftpad_k, + num_batch, + num_splits, + tile_count_semaphore, + num_m_blocks, + num_splits_dynamic, + varlen_batch_idx, + num_nheads_in_l2, + n_blocks_per_split, + ) + + return SchedulerMetadataTensorsTorch( + num_m_blocks_ptr=num_m_blocks, + num_splits_dynamic_ptr=num_splits_dynamic, + varlen_batch_idx_ptr=varlen_batch_idx, + num_nheads_in_l2_ptr=num_nheads_in_l2, + tile_count_semaphore=tile_count_semaphore, + ) + + +get_scheduler_metadata.compile_cache = get_jit_cache("scheduler_metadata") diff --git a/flash_attn/cute/prepare_scheduler.py b/flash_attn/cute/prepare_scheduler.py new file mode 100644 index 00000000000..f56662aaea0 --- /dev/null +++ b/flash_attn/cute/prepare_scheduler.py @@ -0,0 +1,392 @@ +# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_prepare_scheduler.cu +# from CUTLASS C++ to Cute-DSL. + +from typing import Tuple, Optional, Callable, List, NamedTuple +import operator +import torch +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +from cutlass import Boolean, Int32, const_expr, Constexpr, Float32 +from cutlass.cute import FastDivmodDivisor +import flash_attn.cute.utils as utils + + +class SchedulerMetadataTensorsTorch(NamedTuple): + """Class to store scheduler metadata for varlen""" + + # tensors of shape (batch) + num_m_blocks_ptr: Optional[torch.Tensor] + num_splits_dynamic_ptr: Optional[torch.Tensor] + varlen_batch_idx_ptr: Optional[torch.Tensor] + num_nheads_in_l2_ptr: Optional[torch.Tensor] + # tensor of shape (1) + tile_count_semaphore: Optional[torch.Tensor] + + +class FlashPrepareScheduler: + def __init__( + self, + num_warps: int, + tile_m: int, + tile_n: int, + nheads: int, + nheads_kv: int, + headdim: int, + headdim_v: Optional[int] = None, + is_causal: bool = False, + packgqa: bool = False, + sort: bool = False, + zfill_padded_output: bool = False, + ): + self.num_warps = num_warps + self.is_causal = is_causal + self.packgqa = packgqa + # TODO: Implement batch sort for LPT. + self.sort = False + self.num_threads_per_warp = 32 + self.tile_m = tile_m + self.tile_n = tile_n + self.d = headdim + self.dv = headdim_v if headdim_v is not None else headdim + self.k_num_batch_per_warp = 31 + self.k_smem_size = 1 + self.zfill_padded_output = zfill_padded_output + + # for pack gqa, query heads per kv head is combined with seqlen_q + self.nheads_computed = nheads if not self.packgqa else nheads_kv + + # L2 cache calculations + self.qhead_per_khead = nheads // nheads_kv + self.size_l2_divisor = ( + 1 + if self.qhead_per_khead == 1 + else ( + 2 + if self.qhead_per_khead <= 2 + else (4 if self.qhead_per_khead <= 4 else (8 if self.qhead_per_khead <= 8 else 16)) + ) + ) + self.size_l2 = (32 * 1024 * 1024) // self.size_l2_divisor + element_size = 2 + self.size_one_kvblock = self.tile_n * (self.d + self.dv) * element_size + self.max_kvblocks_in_l2 = ( + self.size_l2 + self.size_one_kvblock - 1 + ) // self.size_one_kvblock + + @staticmethod + def get_grid_shape(num_batch: int) -> Tuple[int, int, int]: + num_ctas = (num_batch + (31 * 32 - 1)) // (31 * 32) + return (num_ctas, 1, 1) + + @cute.jit + def __call__( + self, + seqlen_q_static: int, + seqlen_k_static: int, + seqlen_k_new_static: int, + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mCuSeqlensKNew: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mLeftPadK: Optional[cute.Tensor], + num_batch: int, + num_splits_static: int, + tile_count_semaphore: Optional[cute.Tensor], + num_m_blocks_ptr: Optional[cute.Tensor], + num_splits_dynamic_ptr: Optional[cute.Tensor], + varlen_batch_idx_ptr: Optional[cute.Tensor], + num_nheads_in_l2_ptr: Optional[cute.Tensor], + n_blocks_per_split: Optional[int], # overrides heuristic + stream: cuda.CUstream, + ): + tile_m_divmod = FastDivmodDivisor(self.tile_m) + tile_n_divmod = FastDivmodDivisor(self.tile_n) + + @cute.struct + class SharedStorage: + total_blocks_smem: cute.struct.MemRange[Int32, self.k_smem_size] + + self.shared_storage = SharedStorage + + block = (32 * self.num_warps, 1, 1) + grid = self.get_grid_shape(num_batch) + + hardware_info = cutlass.utils.HardwareInfo() + num_sm = hardware_info.get_device_multiprocessor_count() + + self.kernel( + seqlen_q_static, + seqlen_k_static, + seqlen_k_new_static, + mCuSeqlensQ, + mCuSeqlensK, + mCuSeqlensKNew, + mSeqUsedQ, + mSeqUsedK, + mLeftPadK, + num_batch, + num_sm, + num_splits_static, + tile_m_divmod, + tile_n_divmod, + tile_count_semaphore, + num_m_blocks_ptr, + num_splits_dynamic_ptr, + varlen_batch_idx_ptr, + num_nheads_in_l2_ptr, + n_blocks_per_split, + ).launch( + grid=grid, + block=block, + stream=stream, + smem=self.shared_storage.size_in_bytes(), + ) + + @cute.kernel + def kernel( + self, + seqlen_q_static: Int32, + seqlen_k_static: Int32, + seqlen_k_new_static: Int32, + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mCuSeqlensKNew: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mLeftPadK: Optional[cute.Tensor], + num_batch: Int32, + num_sm: Int32, + num_splits_static: Int32, + tile_m_divmod: FastDivmodDivisor, + tile_n_divmod: FastDivmodDivisor, + tile_count_semaphore: Optional[cute.Tensor], + num_m_blocks_ptr: Optional[cute.Tensor], + num_splits_dynamic_ptr: Optional[cute.Tensor], + varlen_batch_idx_ptr: Optional[cute.Tensor], + num_nheads_in_l2_ptr: Optional[cute.Tensor], + n_blocks_per_split: Optional[Int32], + ): + bidx, _, _ = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + grid_dimx, _, _ = cute.arch.grid_dim() + warp_idx = cute.arch.warp_idx() + lane_idx = cute.arch.lane_idx() + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + total_blocks_smem = storage.total_blocks_smem.get_tensor((1,)) + + if tidx == 0: + total_blocks_smem[0] = Int32(0) + cute.arch.sync_threads() + + if const_expr(tile_count_semaphore is not None): + if tidx == 0: + tile_count_semaphore[0] = Int32(0) + + batch_cta_idx_offset = bidx * 992 + bidb_start = batch_cta_idx_offset + self.k_num_batch_per_warp * warp_idx + batch_idx = lane_idx + bidb_start + + num_m_blocks, seqlen_q = self.get_num_m_blocks_and_seqlen( + lane_idx, + batch_idx, + mSeqUsedQ, + mCuSeqlensQ, + seqlen_q_static, + tile_m_divmod, + num_batch, + ) + + num_n_blocks = self.get_num_n_blocks( + lane_idx, + batch_idx, + mSeqUsedK, + mCuSeqlensK, + mCuSeqlensKNew, + seqlen_k_static, + seqlen_k_new_static, + mLeftPadK, + tile_n_divmod, + num_batch, + ) + + num_splits_dynamic = Int32(1) + if const_expr(n_blocks_per_split is not None): + # print("n_blocks_per_splits = ", n_blocks_per_split) + num_splits_dynamic = cutlass.min( + cute.ceil_div(num_n_blocks, n_blocks_per_split), num_splits_static + ) + if const_expr(self.zfill_padded_output): + num_splits_dynamic = cutlass.max(num_splits_dynamic, Int32(1)) + if num_splits_dynamic > 0: + num_n_blocks = cute.ceil_div(num_n_blocks, num_splits_dynamic) + else: + if grid_dimx > 1 or num_splits_static == 1: + num_splits_dynamic = Int32(1) + else: + total_blocks = num_m_blocks * num_n_blocks + total_blocks = utils.warp_reduce(total_blocks, operator.add) + if lane_idx == 0: + utils.atomic_add_i32(total_blocks, total_blocks_smem.iterator) + + cute.arch.sync_threads() + + total_blocks = total_blocks_smem[0] + + sm_margin = max(Float32(num_sm) / 128 + 0.001, 1.1) # e.g. 148/128 = 1.15625 + blocks_per_sm = cutlass.max( + Int32( + ( + Float32(total_blocks) + * sm_margin + * Float32(self.nheads_computed) + / Float32(num_sm) + ) + ), + Int32(1), + ) + # blocks_per_sm = cute.ceil_div(total_blocks * self.nheads_computed, num_sm) + num_splits_dynamic = cutlass.min( + cute.ceil_div(num_n_blocks, blocks_per_sm), num_splits_static + ) + if const_expr(self.zfill_padded_output): + num_splits_dynamic = cutlass.max(num_splits_dynamic, Int32(1)) + if num_splits_dynamic > 0: + num_n_blocks = cute.ceil_div(num_n_blocks, num_splits_dynamic) + # if tidx == 0: + # cute.printf("num_batch = {}", num_batch) + # cute.printf("num_m_blocks = {}", num_m_blocks) + # cute.printf("num_n_blocks = {}", num_n_blocks) + # cute.printf("total_blocks = {}", total_blocks) + # cute.printf("numerator = {}", total_blocks * self.nheads_computed) + # cute.printf("denominator num_sm = {}", num_sm) + # cute.printf("blocks_per_sm = {}", blocks_per_sm) + # cute.printf("sm margin = {}", sm_margin) + # cute.printf("num_splits_dynamic = {}", num_splits_dynamic) + + if const_expr(self.sort): + # TODO: Implement sort logic + pass + + if batch_idx < num_batch and lane_idx < self.k_num_batch_per_warp: + if const_expr(num_m_blocks_ptr is not None): + num_m_blocks_ptr[batch_idx] = num_m_blocks + if const_expr(num_splits_dynamic_ptr is not None): + num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic + if const_expr(num_nheads_in_l2_ptr is not None): + nheads_in_l2 = self.get_num_nheads_in_l2(num_n_blocks) + num_nheads_in_l2_ptr[batch_idx] = nheads_in_l2 + + @cute.jit + def get_num_m_blocks_and_seqlen( + self, + lane_idx: Int32, + batch_idx: Int32, + mSeqUsedQ: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + seqlen_q_static: Int32, + tile_m_divmod: FastDivmodDivisor, + num_batch: Int32, + ): + seqlen = Int32(0) + if const_expr(mSeqUsedQ is not None): + seqlen = mSeqUsedQ[batch_idx] if batch_idx < num_batch else Int32(0) + elif const_expr(mCuSeqlensQ is not None): + # Since k_num_batch_per_warp = 31, lane 31 never processes batches + # So shuffle_down is safe: lane 30 gets lane 31's value (which is 0) + # Only access cu_seqlens if batch_idx is valid (0 to num_batch inclusive) + cur_cu_seqlen = Int32(0) + if batch_idx <= num_batch: + cur_cu_seqlen = mCuSeqlensQ[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + else: + seqlen = seqlen_q_static + + seqlen_for_blocks = seqlen + if const_expr(self.packgqa): + seqlen_for_blocks = seqlen * self.qhead_per_khead + num_m_blocks = ( + (seqlen_for_blocks + self.tile_m - 1) // tile_m_divmod + if batch_idx < num_batch and lane_idx < self.k_num_batch_per_warp + else Int32(0) + ) + return (num_m_blocks, seqlen) + + @cute.jit + def get_num_n_blocks( + self, + lane_idx: Int32, + batch_idx: Int32, + mSeqUsedK: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mCuSeqlensKNew: Optional[cute.Tensor], + seqlen_k_static: Int32, + seqlen_k_new_static: Int32, + mLeftPadK: Optional[cute.Tensor], + tile_n_divmod: FastDivmodDivisor, + num_batch: Int32, + ): + leftpad_k = ( + mLeftPadK[batch_idx] + if const_expr(mLeftPadK is not None) and batch_idx < num_batch + else Int32(0) + ) + seqlen = Int32(0) + if const_expr(mSeqUsedK is not None): + seqlen = mSeqUsedK[batch_idx] if batch_idx < num_batch else Int32(0) + elif const_expr(mCuSeqlensK is not None): + # Since k_num_batch_per_warp = 31, lane 31 never processes batches + # So shuffle_down is safe: lane 30 gets lane 31's value (which is 0) + # Only access cu_seqlens if batch_idx is valid (0 to num_batch inclusive) + cur_cu_seqlen = Int32(0) + if batch_idx <= num_batch: + cur_cu_seqlen = mCuSeqlensK[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + else: + seqlen = seqlen_k_static + + seqlen_new = Int32(0) + if const_expr(mCuSeqlensKNew is not None): + # Since k_num_batch_per_warp = 31, lane 31 never processes batches + # So shuffle_down is safe: lane 30 gets lane 31's value (which is 0) + # Only access cu_seqlens if batch_idx is valid (0 to num_batch inclusive) + cur_cu_seqlen_new = Int32(0) + if batch_idx <= num_batch: + cur_cu_seqlen_new = mCuSeqlensKNew[batch_idx] + next_cu_seqlen_new = cute.arch.shuffle_sync_down(cur_cu_seqlen_new, offset=1) + seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new + else: + seqlen_new = seqlen_k_new_static + seqlen = seqlen - leftpad_k + seqlen_new + return ( + (seqlen + self.tile_n - 1) // tile_n_divmod + if batch_idx < num_batch and lane_idx < self.k_num_batch_per_warp + else Int32(0) + ) + + @cute.jit + def get_num_nheads_in_l2( + self, + num_n_blocks: Int32, + ): + max_kvblocks_in_l2 = self.max_kvblocks_in_l2 + qhead_per_khead = self.qhead_per_khead + nheads_in_l2 = Int32(16) + if num_n_blocks * Int32(16) <= max_kvblocks_in_l2: + nheads_in_l2 = Int32(16) + elif num_n_blocks * Int32(8) <= max_kvblocks_in_l2: + nheads_in_l2 = Int32(8) + elif num_n_blocks * Int32(4) <= max_kvblocks_in_l2: + nheads_in_l2 = Int32(4) + elif num_n_blocks * Int32(2) <= max_kvblocks_in_l2: + nheads_in_l2 = Int32(2) + else: + nheads_in_l2 = Int32(1) + if const_expr(not self.packgqa): + nheads_in_l2 *= qhead_per_khead + return cutlass.min(nheads_in_l2, self.nheads_computed) diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index ff820e59626..31f39b75e63 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -38,35 +38,64 @@ class SchedulingMode(IntEnum): @dataclass -class ClcState(ParamsBase): - """Owns the runtime state shared by CLC-capable tile schedulers. - - `FlashAttentionForwardSm100` constructs this state because it owns the CLC - response buffer, mbarrier storage, and launch geometry needed to initialize - the hardware scheduler and async pipeline. Individual tile schedulers then - consume this state and map the returned hardware work tiles into their own - logical `WorkTileInfo` coordinates. - - To add CLC support to a scheduler: - - implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler - - accept `clc: ClcState | None` in `create(...)` / `__init__` - - map `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates +class SchedulerState(ParamsBase): + """Owns the runtime state shared by CLC and dynamic persistent tile schedulers. + + Main kernel constructs this state because it owns the + response buffer / work_info smem, mbarrier storage, and launch geometry + needed to initialize the backend (CLC hardware scheduler or atomic-counter + work_info region) and the async pipeline. Individual tile schedulers then + consume this state and map the returned work tiles into their own logical + `WorkTileInfo` coordinates. + + Tagged by `scheduling_mode`: + - CLC: `_hw_scheduler` is set; `prefetch_next_work` issues the HW query. + - DYNAMIC: `_work_info` is set; the scheduler class does its own + atomicAdd + warp-prefix-sum and writes via `write_work_info`. """ - _hw_scheduler: ClcDynamicPersistentTileScheduler - _pipeline: PipelineClcFetchAsync + scheduling_mode: cutlass.Constexpr[SchedulingMode] + _pipeline: cutlass.pipeline.PipelineAsync _consumer_state: PipelineState _producer_state: PipelineState + _hw_scheduler: Optional[ClcDynamicPersistentTileScheduler] = None + _work_info: Optional[cute.Tensor] = None @staticmethod - def create( + def create_clc( *, hw_scheduler: ClcDynamicPersistentTileScheduler, pipeline: PipelineClcFetchAsync, consumer_state: PipelineState, producer_state: PipelineState, - ) -> "ClcState": - return ClcState(hw_scheduler, pipeline, consumer_state, producer_state) + ) -> "SchedulerState": + return SchedulerState( + SchedulingMode.CLC, + pipeline, + consumer_state, + producer_state, + hw_scheduler, + None, + ) + + @staticmethod + def create_dynamic_persistent( + *, + work_info: cute.Tensor, + pipeline: cutlass.pipeline.PipelineAsync, + consumer_state: PipelineState, + producer_state: PipelineState, + ) -> "SchedulerState": + return SchedulerState( + SchedulingMode.DYNAMIC, + pipeline, + consumer_state, + producer_state, + None, + work_info, + ) + + # ---- CLC-mode ---- def initial_work_tile_info(self): return self._hw_scheduler.initial_work_tile_info() @@ -80,6 +109,25 @@ def prefetch_next_work(self, *, loc=None, ip=None): self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip) self._producer_state.advance(loc=loc, ip=ip) + # ---- Dynamic-persistent ---- + + def producer_acquire(self, *, loc=None, ip=None): + self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip) + + def producer_commit(self, *, loc=None, ip=None): + self._pipeline.producer_commit(self._producer_state, loc=loc, ip=ip) + + def advance_producer_state(self, *, loc=None, ip=None): + self._producer_state.advance(loc=loc, ip=ip) + + def write_work_info(self, block: Int32, head: Int32, batch: Int32, split: Int32): + self._work_info[0] = block + self._work_info[1] = head + self._work_info[2] = batch + self._work_info[3] = split + + # ---- Common ---- + def consumer_wait(self, *, loc=None, ip=None): self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip) @@ -87,9 +135,62 @@ def consumer_release(self, *, loc=None, ip=None): self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip) self._consumer_state.advance(loc=loc, ip=ip) + def advance_consumer_state(self, *, loc=None, ip=None): + self._consumer_state.advance(loc=loc, ip=ip) + def producer_tail(self, *, loc=None, ip=None): self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip) + def __extract_mlir_values__(self): + ordered = [ + self.scheduling_mode, + self._pipeline, + self._consumer_state, + self._producer_state, + self._hw_scheduler, + self._work_info, + ] + values, self._values_pos = [], [] + for obj in ordered: + if obj is None or isinstance( + obj, (cutlass.Constexpr, int, bool, str, float, type(None)) + ): + self._values_pos.append(0) + continue + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + ordered = [ + self.scheduling_mode, + self._pipeline, + self._consumer_state, + self._producer_state, + self._hw_scheduler, + self._work_info, + ] + rebuilt = [] + for obj, n_items in zip(ordered, self._values_pos): + if n_items == 0: + rebuilt.append(obj) + else: + rebuilt.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SchedulerState( + scheduling_mode=rebuilt[0], + _pipeline=rebuilt[1], + _consumer_state=rebuilt[2], + _producer_state=rebuilt[3], + _hw_scheduler=rebuilt[4], + _work_info=rebuilt[5], + ) + + +# Deprecated alias; remove after downstream call sites are updated. +ClcState = SchedulerState + class WorkTileInfo(cutlass.utils.WorkTileInfo): """Altered WorkTileInfo which includes four axes: (block, head, batch, split)""" @@ -108,7 +209,7 @@ class TileSchedulerProtocol(Protocol): Schedulers are responsible for: 1. Coordinate mapping: linear tile index -> (m_block, head, batch, split) - 2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic) + 2. Work distribution: how to get the next tile (static grid-stride vs dynamic) """ def get_current_work(self) -> WorkTileInfo: @@ -123,14 +224,14 @@ def advance_to_next_work(self, *, loc=None, ip=None): """Consumer-side advance: move to next tile and return it. For static schedulers: grid-stride increment + get_current_work. - For CLC schedulers: consumer wait + get_current_work + consumer release + state advance. + For dynamic schedulers: consumer wait + get_current_work + consumer release + state advance. """ ... def prefetch_next_work(self, *, loc=None, ip=None) -> None: """Producer-side prefetch of next work tile (no-op for static schedulers). - For CLC schedulers: producer acquire + issue CLC query + producer state advance. + For dynamic schedulers: producer acquire (+ issue CLC query) + producer state advance. Only called by the scheduler warp. """ ... @@ -138,7 +239,7 @@ def prefetch_next_work(self, *, loc=None, ip=None) -> None: def producer_tail(self, *, loc=None, ip=None) -> None: """Producer-side cleanup after the last tile. - No-op for static schedulers. For CLC schedulers: pipeline producer_tail. + No-op for static schedulers. For dynamic schedulers: pipeline producer_tail. """ ... @@ -164,6 +265,12 @@ class TileSchedulerArguments(ParamsBase): is_split_kv: cutlass.Constexpr[bool] = False head_swizzle: cutlass.Constexpr[bool] = False use_cluster_idx: cutlass.Constexpr[bool] = False + num_splits_dynamic_ptr: Optional[cute.Tensor] = None + num_m_blocks_ptr: Optional[cute.Tensor] = None + varlen_batch_idx_ptr: Optional[cute.Tensor] = None + num_nheads_in_l2_ptr: Optional[cute.Tensor] = None + tile_count_semaphore: Optional[cute.Pointer] = None + persistent_cta_multiplier: cutlass.Constexpr[int] = 1 class SingleTileScheduler: @@ -215,7 +322,7 @@ def to_underlying_arguments( @staticmethod def create( - params: Params, clc: ClcState | None = None, *, loc=None, ip=None + params: Params, ctx: SchedulerState | None = None, *, loc=None, ip=None ) -> "SingleTileScheduler": if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx): blk_coord = cute.arch.block_idx() @@ -326,7 +433,7 @@ def to_underlying_arguments( @staticmethod def create( - params: Params, clc: ClcState | None = None, *, loc=None, ip=None + params: Params, ctx: SchedulerState | None = None, *, loc=None, ip=None ) -> "StaticPersistentTileScheduler": if const_expr(cute.size(params.cluster_shape_m) == 1): tile_idx = cute.arch.block_idx()[0] @@ -461,7 +568,7 @@ def __init__( params: Params, tile_idx: Int32, split_idx: Int32, - clc: ClcState | None = None, + ctx: SchedulerState | None = None, *, loc=None, ip=None, @@ -469,7 +576,7 @@ def __init__( self.params = params self._tile_idx = tile_idx self._split_idx = split_idx - self.clc = clc + self.ctx = ctx self._loc = loc self._ip = ip @@ -509,11 +616,11 @@ def clc_problem_shape(params: Params): @staticmethod @cute.jit def create( - params: Params, clc: ClcState | None = None, *, loc=None, ip=None + params: Params, ctx: SchedulerState | None = None, *, loc=None, ip=None ) -> "SingleTileLPTScheduler": if const_expr(params.scheduling_mode == SchedulingMode.CLC): return SingleTileLPTScheduler( - params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip + params, cute.arch.block_idx()[0], Int32(0), ctx, loc=loc, ip=ip ) tile_idx, split_idx, _ = cute.arch.block_idx() return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) @@ -562,7 +669,7 @@ def clc_work_to_coords(self, work) -> WorkTileInfo: @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - work = self.clc.get_current_work() + work = self.ctx.get_current_work() self._tile_idx = work.tile_idx[0] return self.clc_work_to_coords(work) # Static path: L2-swizzled coordinate mapping @@ -589,20 +696,20 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - work = self.clc.initial_work_tile_info() + work = self.ctx.initial_work_tile_info() self._tile_idx = work.tile_idx[0] return self.clc_work_to_coords(work) return self.get_current_work(loc=loc, ip=ip) def prefetch_next_work(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.clc.prefetch_next_work(loc=loc, ip=ip) + self.ctx.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.clc.consumer_wait(loc=loc, ip=ip) + self.ctx.consumer_wait(loc=loc, ip=ip) work = self.get_current_work() - self.clc.consumer_release(loc=loc, ip=ip) + self.ctx.consumer_release(loc=loc, ip=ip) return work # Single tile scheduler - set to invalid tile_idx to indicate no more work self._tile_idx = self.params.total_blocks @@ -610,13 +717,13 @@ def advance_to_next_work(self, *, loc=None, ip=None): def producer_tail(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.clc.producer_tail(loc=loc, ip=ip) + self.ctx.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] objs = [self.params, self._tile_idx, self._split_idx] if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - objs += [self.clc] + objs += [self.ctx] for obj in objs: obj_values = cutlass.extract_mlir_values(obj) values += obj_values @@ -627,7 +734,7 @@ def __new_from_mlir_values__(self, values): obj_list = [] objs = [self.params, self._tile_idx, self._split_idx] if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - objs += [self.clc] + objs += [self.ctx] for obj, n_items in zip(objs, self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] @@ -838,7 +945,7 @@ def __init__( params: Params, tile_idx: Int32, split_idx: Int32, - clc: ClcState | None = None, + ctx: SchedulerState | None = None, *, loc=None, ip=None, @@ -847,7 +954,7 @@ def __init__( self._tile_idx = tile_idx self._split_idx = split_idx self._is_first_block = True - self.clc = clc + self.ctx = ctx self._loc = loc self._ip = ip @@ -874,7 +981,7 @@ def clc_problem_shape(params: Params): @staticmethod @cute.jit def create( - params: Params, clc: ClcState | None = None, *, loc=None, ip=None + params: Params, ctx: SchedulerState | None = None, *, loc=None, ip=None ) -> "SingleTileVarlenScheduler": if const_expr(params.scheduling_mode == SchedulingMode.CLC): block_idx = cute.arch.block_idx() @@ -885,7 +992,7 @@ def create( params, block_idx[0], split_idx, - clc, + ctx, loc=loc, ip=ip, ) @@ -1034,7 +1141,7 @@ def _varlen_coord_map(self) -> WorkTileInfo: @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - clc_work = self.clc.get_current_work() + clc_work = self.ctx.get_current_work() # Default to grid_dim (one past last valid flat index) so _varlen_coord_map # returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when # invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural @@ -1052,7 +1159,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - clc_work = self.clc.initial_work_tile_info() + clc_work = self.ctx.initial_work_tile_info() # See get_current_work for why grid_dim and local-then-assign. new_tile_idx = cute.arch.grid_dim()[0] new_split_idx = Int32(0) @@ -1066,26 +1173,26 @@ def initial_work_tile_info(self, *, loc=None, ip=None): def prefetch_next_work(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.clc.prefetch_next_work(loc=loc, ip=ip) + self.ctx.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.clc.consumer_wait(loc=loc, ip=ip) + self.ctx.consumer_wait(loc=loc, ip=ip) work = self.get_current_work() - self.clc.consumer_release(loc=loc, ip=ip) + self.ctx.consumer_release(loc=loc, ip=ip) return work self._is_first_block = False return self.get_current_work() def producer_tail(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.clc.producer_tail(loc=loc, ip=ip) + self.ctx.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] objs = [self.params, self._tile_idx, self._split_idx] if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - objs += [self.clc] + objs += [self.ctx] for obj in objs: obj_values = cutlass.extract_mlir_values(obj) values += obj_values @@ -1096,13 +1203,380 @@ def __new_from_mlir_values__(self, values): obj_list = [] objs = [self.params, self._tile_idx, self._split_idx] if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - objs += [self.clc] + objs += [self.ctx] for obj, n_items in zip(objs, self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return self.__class__(*obj_list, loc=self._loc) +class DynamicPersistentVarlenScheduler: + @dataclass + class Params(ParamsBase): + num_head: Int32 + num_batch: Int32 + total_q: Int32 + num_splits: Int32 + max_kvblock_in_l2: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + num_splits_dynamic_ptr: Optional[cute.Tensor] = None + num_m_blocks_ptr: Optional[cute.Tensor] = None + varlen_batch_idx_ptr: Optional[cute.Tensor] = None + num_nheads_in_l2_ptr: Optional[cute.Tensor] = None + tile_count_semaphore: Optional[cute.Pointer] = None + persistent_cta_multiplier: cutlass.Constexpr[int] = 1 + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "DynamicPersistentVarlenScheduler.Params": + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + max_kvblock_in_l2 = size_l2 // ( + (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + ) + assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( + "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" + ) + return DynamicPersistentVarlenScheduler.Params( + num_head=args.num_head, + num_batch=args.num_batch, + total_q=args.total_q, + num_splits=args.num_splits, + max_kvblock_in_l2=max_kvblock_in_l2, + tile_shape_mn=args.tile_shape_mn, + mCuSeqlensQ=args.mCuSeqlensQ, + mSeqUsedQ=args.mSeqUsedQ, + qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + lpt=args.lpt, + is_split_kv=args.is_split_kv, + num_splits_dynamic_ptr=args.num_splits_dynamic_ptr, + num_m_blocks_ptr=args.num_m_blocks_ptr, + varlen_batch_idx_ptr=args.varlen_batch_idx_ptr, + num_nheads_in_l2_ptr=args.num_nheads_in_l2_ptr, + tile_count_semaphore=args.tile_count_semaphore, + persistent_cta_multiplier=args.persistent_cta_multiplier, + ) + + def __init__( + self, + params: Params, + ctx: SchedulerState, + *, + loc=None, + ip=None, + ): + self.params = params + self._ctx = ctx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.DYNAMIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.DYNAMIC, ( + f"DynamicPersistentVarlenScheduler only supports DYNAMIC, got {scheduling_mode!r}" + ) + return DynamicPersistentVarlenScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create( + params: Params, + ctx: SchedulerState, + *, + loc=None, + ip=None, + ) -> "DynamicPersistentVarlenScheduler": + return DynamicPersistentVarlenScheduler(params, ctx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + total_blocks_max = ( + params.total_q + params.num_batch * (params.tile_shape_mn[0] - 1) + ) // params.tile_shape_mn[0] + total_blocks = total_blocks_max * params.num_head * params.num_splits + hardware_info = HardwareInfo() + sm_count = ( + hardware_info.get_device_multiprocessor_count() * params.persistent_cta_multiplier + ) + return (cutlass.min(sm_count, total_blocks), Int32(1), Int32(1)) + + @cute.jit + def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + params = self.params + batch_idx = lane + bidb_start + if cutlass.const_expr(params.varlen_batch_idx_ptr is not None): + if cutlass.const_expr(params.num_m_blocks_ptr is not None): + # num_m_blocks is at virtual idx (prepare_scheduler writes by vbidx) + n = Int32(0) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1: + n = params.num_m_blocks_ptr[batch_idx] + return n + seqlen = Int32(0) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1: + real_batch_idx = params.varlen_batch_idx_ptr[batch_idx] + seqlen = params.mCuSeqlensQ[real_batch_idx + 1] - params.mCuSeqlensQ[real_batch_idx] + if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): + seqlen *= params.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(seqlen, params.tile_shape_mn[0]) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + if cutlass.const_expr(params.num_m_blocks_ptr is not None): + n = Int32(0) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1: + n = params.num_m_blocks_ptr[batch_idx] + return n + if cutlass.const_expr(params.mSeqUsedQ is not None): + seqlen = Int32(0) + if batch_idx < params.num_batch: + seqlen = params.mSeqUsedQ[batch_idx] + else: + assert params.mCuSeqlensQ is not None + cur_cu_seqlen = Int32(0) + if batch_idx <= params.num_batch: + cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): + seqlen *= params.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(seqlen, params.tile_shape_mn[0]) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + + @cute.jit + def _get_num_splits(self, lane: Int32, bidb_start: Int32) -> Int32: + params = self.params + batch_idx = lane + bidb_start + is_valid = batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 + if cutlass.const_expr(not params.is_split_kv): + return Int32(1) + elif cutlass.const_expr(params.num_splits_dynamic_ptr is not None): + num_splits = Int32(0) + if is_valid: + num_splits = params.num_splits_dynamic_ptr[batch_idx] + return num_splits + else: + return Int32(0) if not is_valid else params.num_splits + + @cute.jit + def get_current_work( + self, + next_tile_idx: Int32, + bidb_start: Int32, + group_start_tile: Int32, + *, + loc=None, + ip=None, + ) -> WorkTileInfo: + params = self.params + lane_idx = cute.arch.lane_idx() + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=bidb_start) + num_splits = self._get_num_splits(lane_idx, bidb_start=bidb_start) + num_splits_m_blocks = ( + num_m_blocks * num_splits if const_expr(params.is_split_kv) else num_m_blocks + ) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_splits_m_blocks, lane_idx) + # Total number of blocks for the next 31 batches + m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1) + # Same for all lanes + group_end_tile = m_blocks_in_group * params.num_head + group_start_tile + + block, head_idx, batch_idx, split_idx = Int32(0), Int32(0), bidb_start, Int32(0) + while group_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= params.num_batch: + batch_idx = Int32(params.num_batch) + group_end_tile = next_tile_idx + 1 + else: + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) + num_splits = self._get_num_splits(lane_idx, bidb_start=batch_idx) + num_splits_m_blocks = ( + num_m_blocks * num_splits if const_expr(params.is_split_kv) else num_m_blocks + ) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_splits_m_blocks, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync( + num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 + ) + group_end_tile += m_blocks_in_group * params.num_head + is_valid = batch_idx < params.num_batch + if is_valid: + group_start_tile = group_end_tile - m_blocks_in_group * params.num_head + # The next problem to process is the first one that does not have ending tile position + # that is less than or equal to tile index. + batch_idx_in_group = cute.arch.popc( + cute.arch.vote_ballot_sync( + group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_group + num_m_blocks_prev_lane = ( + 0 + if batch_idx_in_group == 0 + else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) + ) + group_start_tile += num_m_blocks_prev_lane * params.num_head + num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) + if const_expr(params.is_split_kv): + num_splits = cute.arch.shuffle_sync(num_splits, batch_idx_in_group) + mh_block = next_tile_idx - group_start_tile + if const_expr(params.lpt): + if const_expr(not params.is_split_kv) or num_splits == 1: + # This is a version of the SingleTileLPTScheduler, complicated by the fact that + # the seqlen can vary per batch. + if const_expr(params.num_nheads_in_l2_ptr is not None): + nheads_in_l2 = Int32(params.num_nheads_in_l2_ptr[batch_idx]) + else: + # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here + num_n_blocks = ( + num_m_blocks + * params.tile_shape_mn[0] + // params.qhead_per_kvhead_packgqa + // params.tile_shape_mn[1] + ) + # Seems faster to have this be a power of 2 + nheads_in_l2 = ( + 16 + if num_n_blocks * 16 <= params.max_kvblock_in_l2 + else ( + 8 + if num_n_blocks * 8 <= params.max_kvblock_in_l2 + else ( + 4 + if num_n_blocks * 4 <= params.max_kvblock_in_l2 + else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) + ) + ) + ) + nheads_in_l2 = min(nheads_in_l2, params.num_head) + mh_in_l2 = nheads_in_l2 * num_m_blocks + section_idx = mh_block // mh_in_l2 + l2_mod = mh_block - section_idx * mh_in_l2 + # Deal with tail section + nheads_in_this_section = ( + nheads_in_l2 + if nheads_in_l2 * (section_idx + 1) <= params.num_head + else params.num_head - section_idx * nheads_in_l2 + ) + block = l2_mod // nheads_in_this_section + head_idx_residual = l2_mod - block * nheads_in_this_section + head_idx = section_idx * nheads_in_l2 + head_idx_residual + else: + head_split_idx = mh_block // num_m_blocks + block = mh_block - head_split_idx * num_m_blocks + if const_expr(params.is_split_kv): + head_idx = head_split_idx // num_splits + split_idx = head_split_idx - head_idx * num_splits + else: + head_idx = head_split_idx + + block = num_m_blocks - 1 - block + else: + head_split_idx = mh_block // num_m_blocks + block = mh_block - head_split_idx * num_m_blocks + if const_expr(params.is_split_kv): + head_idx = head_split_idx // num_splits + split_idx = head_split_idx - head_idx * num_splits + else: + head_idx = head_split_idx + + # Pack num_splits into top 16 bits of split_idx + if const_expr(params.is_split_kv and params.num_splits_dynamic_ptr is not None): + if is_valid: + split_idx = split_idx | (num_splits << 16) + if const_expr(params.varlen_batch_idx_ptr is not None): + if is_valid: + batch_idx = params.varlen_batch_idx_ptr[batch_idx] + + return ( + WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(split_idx)), is_valid + ), + group_start_tile, + ) + + @cute.jit + def prefetch_next_work(self, batch_idx, tile_idx, *, loc=None, ip=None): + ctx = self._ctx + next_tile_idx = Int32(0) + if cute.arch.lane_idx() == 0: + next_tile_idx = cute.arch.grid_dim()[0] + utils.atomic_add_i32( + 1, + self.params.tile_count_semaphore, + ) + next_tile_idx = cute.arch.shuffle_sync(next_tile_idx, 0) + work_info, new_tile_idx = self.get_current_work(next_tile_idx, batch_idx, tile_idx) + ctx.producer_acquire() + with cute.arch.elect_one(): + block, head_idx, batch_idx, split_idx = work_info.tile_idx + ctx.write_work_info(block, head_idx, batch_idx, split_idx) + ctx.producer_commit() + ctx.advance_producer_state() + return new_tile_idx + + @cute.jit + def advance_to_next_work(self, *, loc=None, ip=None) -> WorkTileInfo: + ctx = self._ctx + ctx.consumer_wait() + block = ctx._work_info[0] + head_idx = ctx._work_info[1] + batch_idx = ctx._work_info[2] + split_idx = ctx._work_info[3] + is_valid = batch_idx < self.params.num_batch + work_info = WorkTileInfo((block, head_idx, batch_idx, split_idx), is_valid) + ctx.consumer_release() + return work_info + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + cta_tile_idx, _, _ = cute.arch.block_idx() + work_info, _ = self.get_current_work(cta_tile_idx, Int32(0), Int32(0)) + return work_info + + @cute.jit + def initial_sched_state(self, *, loc=None, ip=None): + cta_tile_idx, _, _ = cute.arch.block_idx() + work_info, group_start_tile = self.get_current_work(cta_tile_idx, Int32(0), Int32(0)) + return work_info, group_start_tile + + def producer_tail(self, *, loc=None, ip=None): + self._ctx.producer_tail(loc=loc, ip=ip) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._ctx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._ctx], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return DynamicPersistentVarlenScheduler(*(tuple(obj_list)), loc=self._loc) + + # ----------------------------------------------------------------------------- # SM100 FMHA-specific schedulers (kept separate from generic schedulers). # ----------------------------------------------------------------------------- diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index c8398c9a78d..7ffc85a8092 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -466,6 +466,13 @@ def fadd_reduce( return local_sum[0][0] + local_sum[0][1] +@dsl_user_op +def atomic_add_i32(a: int | Int32, ptr: cute.Pointer, *, loc=None, ip=None) -> Int32: + return nvvm.atomicrmw( + res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=ptr.llvm_ptr, a=Int32(a).ir_value() + ) + + @dsl_user_op def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None: # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() From e601530eced6b4013e08e6ba51b97f1d8cc96168 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Wed, 6 May 2026 19:48:19 +0000 Subject: [PATCH 02/15] mild refactor to tile scheduler protocol, guard num_m_blocks_ptr for sm100, update tests to use scheduler metadata --- flash_attn/cute/flash_fwd_sm100.py | 21 ++------ flash_attn/cute/interface.py | 81 +++++++++++++++++++----------- flash_attn/cute/tile_scheduler.py | 50 ++++++++---------- tests/cute/test_flash_attn.py | 43 ++++++++++------ 4 files changed, 106 insertions(+), 89 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 575b1d45bc7..7a89e720bbb 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -400,7 +400,6 @@ def __call__( aux_tensors: Optional[list] = None, num_splits_dynamic_ptr: Optional[cute.Tensor] = None, tile_count_semaphore: Optional[cute.Tensor] = None, - num_m_blocks_ptr: Optional[cute.Tensor] = None, varlen_batch_idx_ptr: Optional[cute.Tensor] = None, num_nheads_in_l2_ptr: Optional[cute.Tensor] = None, max_seqlen_q: Int32 | int | None = None, @@ -689,7 +688,6 @@ def __call__( cluster_shape_mn=self.cluster_shape_mn, use_cluster_idx=not self.is_persistent and self.cta_group_size > 1, num_splits_dynamic_ptr=num_splits_dynamic_ptr, - num_m_blocks_ptr=num_m_blocks_ptr, varlen_batch_idx_ptr=varlen_batch_idx_ptr, num_nheads_in_l2_ptr=num_nheads_in_l2_ptr, tile_count_semaphore=tile_count_semaphore.iterator if tile_count_semaphore is not None else None, @@ -810,7 +808,6 @@ class SharedStorage: num_splits, num_splits_dynamic_ptr, tile_count_semaphore, - num_m_blocks_ptr, varlen_batch_idx_ptr, num_nheads_in_l2_ptr, aux_tensors, @@ -862,7 +859,6 @@ def kernel( num_splits: Int32, num_splits_dynamic_ptr: Optional[cute.Tensor] = None, tile_count_semaphore: Optional[cute.Tensor] = None, - num_m_blocks_ptr: Optional[cute.Tensor] = None, varlen_batch_idx_ptr: Optional[cute.Tensor] = None, num_nheads_in_l2_ptr: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, @@ -3064,18 +3060,11 @@ def scheduler_warp( self, tile_scheduler: TileSchedulerProtocol, ): - if const_expr(self.dynamic_persistent): - work_tile, group_start_tile = tile_scheduler.initial_sched_state() - batch_idx = Int32(work_tile.tile_idx[2]) - while work_tile.is_valid_tile: - group_start_tile = tile_scheduler.prefetch_next_work(batch_idx, group_start_tile) - work_tile = tile_scheduler.advance_to_next_work() - batch_idx = Int32(work_tile.tile_idx[2]) - else: - work_tile = tile_scheduler.initial_work_tile_info() - while work_tile.is_valid_tile: - tile_scheduler.prefetch_next_work() - work_tile = tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.advance_to_next_work() + if const_expr(self.use_clc_scheduler): if cute.arch.thread_idx()[0] == self.scheduler_warp_id * cute.arch.WARP_SIZE: fa_printf( 3, diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 7bcbb7fff17..2fbfb51b667 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -683,7 +683,17 @@ def _flash_attn_fwd( reuse_scheduler_metadata = scheduler_metadata is not None is_varlen_q = cu_seqlens_q is not None or seqused_q is not None - if is_split_kv and is_varlen_q and scheduler_metadata is None and not disable_scheduler_metadata: + if use_dedicated_hd256_kernel: + # The hd=256 2CTA fwd kernel does not support the dynamic-persistent scheduler. + scheduler_metadata = None + reuse_scheduler_metadata = False + if ( + is_split_kv + and is_varlen_q + and scheduler_metadata is None + and not disable_scheduler_metadata + and not use_dedicated_hd256_kernel + ): scheduler_metadata = get_scheduler_metadata( num_batch=batch_size, max_seqlen_q=max_seqlen_q, @@ -712,9 +722,9 @@ def _flash_attn_fwd( varlen_batch_idx, num_nheads_in_l2, tile_count_semaphore, - ) = scheduler_metadata + ) = scheduler_metadata assert all( - t is None or t.is_cuda + t is None or t.is_cuda for t in scheduler_metadata ), "scheduler metadata must be on CUDA device" assert all( @@ -729,11 +739,11 @@ def _flash_attn_fwd( if tile_count_semaphore is not None: assert tile_count_semaphore.shape == (1,), "semaphore must have size 1" else: - num_m_blocks = None - num_splits_dynamic = None - varlen_batch_idx = None - num_nheads_in_l2 = None - tile_count_semaphore = None + num_m_blocks = None + num_splits_dynamic = None + varlen_batch_idx = None + num_nheads_in_l2 = None + tile_count_semaphore = None compile_key = ( dtype, @@ -773,7 +783,6 @@ def _flash_attn_fwd( mma_pv_is_rs, intra_wg_overlap, use_clc_scheduler, - has_scheduler_metadata, num_m_blocks is not None, num_splits_dynamic is not None, varlen_batch_idx is not None, @@ -967,9 +976,7 @@ def _flash_attn_fwd( else FlashAttentionForwardSm100 ) - fa_fwd = flash_fwd_obj_cls( - head_dim, - head_dim_v, + fa_fwd_kwargs = dict( qhead_per_kvhead=qhead_per_kvhead, is_causal=causal, is_local=local, @@ -991,8 +998,10 @@ def _flash_attn_fwd( q_subtile_factor=q_subtile_factor, use_2cta_instrs=use_2cta_instrs, use_clc_scheduler=use_clc_scheduler, - has_tile_count_semaphore=tile_count_semaphore is not None, ) + if not use_dedicated_hd256_kernel: + fa_fwd_kwargs["has_tile_count_semaphore"] = tile_count_semaphore is not None + fa_fwd = flash_fwd_obj_cls(head_dim, head_dim_v, **fa_fwd_kwargs) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity assert not use_block_sparsity, "Block sparsity not supported on SM 12.0" @@ -1058,18 +1067,26 @@ def _flash_attn_fwd( window_size_left, window_size_right, learnable_sink_tensor, - sparse_tensors, - cute_aux_tensors, - num_splits_dynamic_tensor, - tile_count_semaphore_tensor, - num_m_blocks_tensor, - varlen_batch_idx_tensor, - num_nheads_in_l2_tensor, - max_seqlen_q, - current_stream, ] if arch // 10 in [10, 11]: - compile_args.insert(-9, descale_tensors_tensor) + compile_args.append(descale_tensors_tensor) + compile_args.extend([ + sparse_tensors, + cute_aux_tensors, + ]) + if not use_dedicated_hd256_kernel: + compile_args.extend([ + num_splits_dynamic_tensor, + tile_count_semaphore_tensor, + ]) + if arch // 10 == 9: + compile_args.append(num_m_blocks_tensor) + compile_args.extend([ + varlen_batch_idx_tensor, + num_nheads_in_l2_tensor, + max_seqlen_q, + ]) + compile_args.append(current_stream) _flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi") if not is_fake_mode(): @@ -1138,13 +1155,19 @@ def _flash_attn_fwd( if normalized_block_sparse_tensors is not None else None, aux_tensors, - num_splits_dynamic, - tile_count_semaphore, - num_m_blocks, - varlen_batch_idx, - num_nheads_in_l2, - max_seqlen_q, ]) + if not use_dedicated_hd256_kernel: + call_args.extend([ + num_splits_dynamic, + tile_count_semaphore, + ]) + if arch // 10 == 9: + call_args.append(num_m_blocks) + call_args.extend([ + varlen_batch_idx, + num_nheads_in_l2, + max_seqlen_q, + ]) _flash_attn_fwd.compile_cache[compile_key](*call_args) if is_split_kv: _flash_attn_fwd_combine( diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 31f39b75e63..853dfcfc815 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -212,10 +212,6 @@ class TileSchedulerProtocol(Protocol): 2. Work distribution: how to get the next tile (static grid-stride vs dynamic) """ - def get_current_work(self) -> WorkTileInfo: - """Get the current work tile coordinates.""" - ... - def initial_work_tile_info(self) -> WorkTileInfo: """Get the initial work tile for this CTA.""" ... @@ -1267,12 +1263,16 @@ def __init__( self, params: Params, ctx: SchedulerState, + bidb_start: Int32, + group_start_tile: Int32, *, loc=None, ip=None, ): self.params = params self._ctx = ctx + self._bidb_start = bidb_start + self._group_start_tile = group_start_tile self._loc = loc self._ip = ip @@ -1290,6 +1290,7 @@ def to_underlying_arguments( return DynamicPersistentVarlenScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod + @cute.jit def create( params: Params, ctx: SchedulerState, @@ -1297,7 +1298,7 @@ def create( loc=None, ip=None, ) -> "DynamicPersistentVarlenScheduler": - return DynamicPersistentVarlenScheduler(params, ctx, loc=loc, ip=ip) + return DynamicPersistentVarlenScheduler(params, ctx, Int32(0), Int32(0), loc=loc, ip=ip) # called by host @staticmethod @@ -1322,12 +1323,6 @@ def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: params = self.params batch_idx = lane + bidb_start if cutlass.const_expr(params.varlen_batch_idx_ptr is not None): - if cutlass.const_expr(params.num_m_blocks_ptr is not None): - # num_m_blocks is at virtual idx (prepare_scheduler writes by vbidx) - n = Int32(0) - if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1: - n = params.num_m_blocks_ptr[batch_idx] - return n seqlen = Int32(0) if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1: real_batch_idx = params.varlen_batch_idx_ptr[batch_idx] @@ -1339,11 +1334,6 @@ def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 else Int32(0) ) - if cutlass.const_expr(params.num_m_blocks_ptr is not None): - n = Int32(0) - if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1: - n = params.num_m_blocks_ptr[batch_idx] - return n if cutlass.const_expr(params.mSeqUsedQ is not None): seqlen = Int32(0) if batch_idx < params.num_batch: @@ -1515,7 +1505,7 @@ def get_current_work( ) @cute.jit - def prefetch_next_work(self, batch_idx, tile_idx, *, loc=None, ip=None): + def prefetch_next_work(self, *, loc=None, ip=None): ctx = self._ctx next_tile_idx = Int32(0) if cute.arch.lane_idx() == 0: @@ -1524,14 +1514,19 @@ def prefetch_next_work(self, batch_idx, tile_idx, *, loc=None, ip=None): self.params.tile_count_semaphore, ) next_tile_idx = cute.arch.shuffle_sync(next_tile_idx, 0) - work_info, new_tile_idx = self.get_current_work(next_tile_idx, batch_idx, tile_idx) + work_info, new_group_start_tile = self.get_current_work( + next_tile_idx, self._bidb_start, self._group_start_tile + ) + # Advance scan state so the next prefetch resumes from this tile's batch + # group instead of restarting at batch 0. + self._bidb_start = Int32(work_info.tile_idx[2]) + self._group_start_tile = new_group_start_tile ctx.producer_acquire() with cute.arch.elect_one(): block, head_idx, batch_idx, split_idx = work_info.tile_idx ctx.write_work_info(block, head_idx, batch_idx, split_idx) ctx.producer_commit() ctx.advance_producer_state() - return new_tile_idx @cute.jit def advance_to_next_work(self, *, loc=None, ip=None) -> WorkTileInfo: @@ -1549,21 +1544,17 @@ def advance_to_next_work(self, *, loc=None, ip=None) -> WorkTileInfo: @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: cta_tile_idx, _, _ = cute.arch.block_idx() - work_info, _ = self.get_current_work(cta_tile_idx, Int32(0), Int32(0)) + work_info, new_group_start_tile = self.get_current_work(cta_tile_idx, Int32(0), Int32(0)) + self._bidb_start = Int32(work_info.tile_idx[2]) + self._group_start_tile = new_group_start_tile return work_info - @cute.jit - def initial_sched_state(self, *, loc=None, ip=None): - cta_tile_idx, _, _ = cute.arch.block_idx() - work_info, group_start_tile = self.get_current_work(cta_tile_idx, Int32(0), Int32(0)) - return work_info, group_start_tile - def producer_tail(self, *, loc=None, ip=None): self._ctx.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.params, self._ctx]: + for obj in [self.params, self._ctx, self._bidb_start, self._group_start_tile]: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -1571,7 +1562,10 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip([self.params, self._ctx], self._values_pos): + for obj, n_items in zip( + [self.params, self._ctx, self._bidb_start, self._group_start_tile], + self._values_pos, + ): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return DynamicPersistentVarlenScheduler(*(tuple(obj_list)), loc=self._loc) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 764d7123681..e73f4959d86 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -30,9 +30,11 @@ from flash_attn.cute.interface import ( flash_attn_func, flash_attn_varlen_func, + get_scheduler_metadata, _flash_attn_fwd, _flash_attn_bwd, ) +from flash_attn.cute.prepare_scheduler import SchedulerMetadataTensorsTorch def retry_on_oom(func): @wraps(func) @@ -993,7 +995,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_leftpad", [False]) # @pytest.mark.parametrize("has_batch_idx", [False, True]) @pytest.mark.parametrize("has_batch_idx", [False]) -@pytest.mark.parametrize("varlen_q", [False, True]) +@pytest.mark.parametrize("varlen_q", [True]) # @pytest.mark.parametrize("varlen_q", [False]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -1381,26 +1383,35 @@ def test_flash_attn_kvcache( # num_splits_vals = [1, 0] # SplitKV is not supported for hdim >= 192 num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] - # precompute_metadata_vals = [False, True] - precompute_metadata_vals = [False] + precompute_metadata_vals = [False, True] + # precompute_metadata_vals = [False] for num_splits, precompute_metadata in itertools.product( num_splits_vals, precompute_metadata_vals ): # SplitKV not supported on SM90 - skip this iteration if IS_SM90 and num_splits > 1: continue - # if precompute_metadata: - # scheduler_metadata = get_scheduler_metadata( - # batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, - # cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, - # cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, - # max_seqlen_k_new=seqlen_new, page_size=page_size, - # causal=causal, window_size=window_size, attention_chunk=attention_chunk, - # num_splits=num_splits - # ) - # else: - # scheduler_metadata = None - scheduler_metadata = None + if precompute_metadata and is_fake_mode(): + continue + if precompute_metadata: + scheduler_metadata = get_scheduler_metadata( + num_batch=batch_size, + max_seqlen_q=max_seqlen_q if varlen_q else seqlen_q, + max_seqlen_k=seqlen_k, + nheads=nheads, + nheads_kv=nheads_k, + headdim=d, + headdim_v=dv, + num_splits=num_splits, + tile_m=128, + tile_n=128, + causal=causal, + sort=True, + cu_seqlens_q=cu_seqlens_q, + seqused_k=cache_seqlens, + ) + else: + scheduler_metadata = None # Repeat to test metadata reuse for _ in range(1 if not precompute_metadata else 2): if page_size is None: @@ -1431,7 +1442,7 @@ def test_flash_attn_kvcache( learnable_sink=learnable_sink, # attention_chunk=attention_chunk, # rotary_interleaved=rotary_interleaved, - # scheduler_metadata=scheduler_metadata, + scheduler_metadata=scheduler_metadata, num_splits=num_splits, # return_softmax_lse=True ) From 2c14ab75f27825d11126cbb1c308f8512d69f0f7 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Thu, 7 May 2026 01:11:09 +0000 Subject: [PATCH 03/15] rename varlen_batch_idx -> virtual_batch_idx, because it is relevant for non-varlen blocksparse batch sorting --- flash_attn/cute/flash_fwd_sm100.py | 8 ++--- flash_attn/cute/interface.py | 52 ++++++++++++++-------------- flash_attn/cute/prepare_scheduler.py | 8 ++--- flash_attn/cute/tile_scheduler.py | 14 ++++---- 4 files changed, 41 insertions(+), 41 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 7a89e720bbb..fe878a69265 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -400,7 +400,7 @@ def __call__( aux_tensors: Optional[list] = None, num_splits_dynamic_ptr: Optional[cute.Tensor] = None, tile_count_semaphore: Optional[cute.Tensor] = None, - varlen_batch_idx_ptr: Optional[cute.Tensor] = None, + virtual_batch_idx_ptr: Optional[cute.Tensor] = None, num_nheads_in_l2_ptr: Optional[cute.Tensor] = None, max_seqlen_q: Int32 | int | None = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). @@ -688,7 +688,7 @@ def __call__( cluster_shape_mn=self.cluster_shape_mn, use_cluster_idx=not self.is_persistent and self.cta_group_size > 1, num_splits_dynamic_ptr=num_splits_dynamic_ptr, - varlen_batch_idx_ptr=varlen_batch_idx_ptr, + virtual_batch_idx_ptr=virtual_batch_idx_ptr, num_nheads_in_l2_ptr=num_nheads_in_l2_ptr, tile_count_semaphore=tile_count_semaphore.iterator if tile_count_semaphore is not None else None, ) @@ -808,7 +808,7 @@ class SharedStorage: num_splits, num_splits_dynamic_ptr, tile_count_semaphore, - varlen_batch_idx_ptr, + virtual_batch_idx_ptr, num_nheads_in_l2_ptr, aux_tensors, fastdiv_mods, @@ -859,7 +859,7 @@ def kernel( num_splits: Int32, num_splits_dynamic_ptr: Optional[cute.Tensor] = None, tile_count_semaphore: Optional[cute.Tensor] = None, - varlen_batch_idx_ptr: Optional[cute.Tensor] = None, + virtual_batch_idx_ptr: Optional[cute.Tensor] = None, num_nheads_in_l2_ptr: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 2fbfb51b667..4651de90c62 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -719,7 +719,7 @@ def _flash_attn_fwd( ( num_m_blocks, num_splits_dynamic, - varlen_batch_idx, + virtual_batch_idx, num_nheads_in_l2, tile_count_semaphore, ) = scheduler_metadata @@ -732,7 +732,7 @@ def _flash_attn_fwd( for t in ( num_m_blocks, num_splits_dynamic, - varlen_batch_idx, + virtual_batch_idx, num_nheads_in_l2, ) ), "these scheduler metadata tensors must have shape (batch_size,)" @@ -741,7 +741,7 @@ def _flash_attn_fwd( else: num_m_blocks = None num_splits_dynamic = None - varlen_batch_idx = None + virtual_batch_idx = None num_nheads_in_l2 = None tile_count_semaphore = None @@ -785,7 +785,7 @@ def _flash_attn_fwd( use_clc_scheduler, num_m_blocks is not None, num_splits_dynamic is not None, - varlen_batch_idx is not None, + virtual_batch_idx is not None, num_nheads_in_l2 is not None, tile_count_semaphore is not None, qv is not None, @@ -871,9 +871,9 @@ def _flash_attn_fwd( to_cute_tensor(num_m_blocks, assumed_align=4, leading_dim=0) if num_m_blocks is not None else None ) - varlen_batch_idx_tensor = ( - to_cute_tensor(varlen_batch_idx, assumed_align=4, leading_dim=0) - if varlen_batch_idx is not None else None + virtual_batch_idx_tensor = ( + to_cute_tensor(virtual_batch_idx, assumed_align=4, leading_dim=0) + if virtual_batch_idx is not None else None ) num_nheads_in_l2_tensor = ( to_cute_tensor(num_nheads_in_l2, assumed_align=4, leading_dim=0) @@ -1082,7 +1082,7 @@ def _flash_attn_fwd( if arch // 10 == 9: compile_args.append(num_m_blocks_tensor) compile_args.extend([ - varlen_batch_idx_tensor, + virtual_batch_idx_tensor, num_nheads_in_l2_tensor, max_seqlen_q, ]) @@ -1164,7 +1164,7 @@ def _flash_attn_fwd( if arch // 10 == 9: call_args.append(num_m_blocks) call_args.extend([ - varlen_batch_idx, + virtual_batch_idx, num_nheads_in_l2, max_seqlen_q, ]) @@ -1178,7 +1178,7 @@ def _flash_attn_fwd( cu_seqlens_q, seqused_q, num_splits_dynamic_ptr=num_splits_dynamic if has_scheduler_metadata else None, - varlen_batch_idx=varlen_batch_idx if has_scheduler_metadata else None, + virtual_batch_idx=virtual_batch_idx if has_scheduler_metadata else None, ) if reuse_scheduler_metadata and tile_count_semaphore is not None: tile_count_semaphore.zero_() @@ -2383,7 +2383,7 @@ def flash_attn_varlen_func( def _compile_fwd_combine( dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits, - has_cu_seqlens, has_seqused, has_lse, has_varlen_batch_idx, + has_cu_seqlens, has_seqused, has_lse, has_virtual_batch_idx, has_num_splits_dynamic, has_semaphore_to_reset, ): """Compile fwd combine kernel using cute fake tensors (no real GPU tensors needed).""" @@ -2427,13 +2427,13 @@ def _compile_fwd_combine( mCuSeqlens = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_seqlens else None mSeqused = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_seqused else None mNumSplitsDynamic = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_num_splits_dynamic else None - mVarlenBatchIdx = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_varlen_batch_idx else None + mVirtualBatchIdx = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_virtual_batch_idx else None mSemaphore = fake_tensor(Int32, (1,), divisibility=1) if has_semaphore_to_reset else None return cute.compile( fa_combine, mO_partial, mLSE_partial, mO, mLSE, - mCuSeqlens, mSeqused, mNumSplitsDynamic, mVarlenBatchIdx, mSemaphore, + mCuSeqlens, mSeqused, mNumSplitsDynamic, mVirtualBatchIdx, mSemaphore, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", ) @@ -2447,7 +2447,7 @@ def _flash_attn_fwd_combine( cu_seqlens: Optional[torch.Tensor] = None, seqused: Optional[torch.Tensor] = None, num_splits_dynamic_ptr: Optional[torch.Tensor] = None, - varlen_batch_idx: Optional[torch.Tensor] = None, + virtual_batch_idx: Optional[torch.Tensor] = None, semaphore_to_reset: Optional[torch.Tensor] = None, ) -> None: """Forward combine kernel for split attention computation. @@ -2516,7 +2516,7 @@ def _flash_attn_fwd_combine( cu_seqlens is not None, seqused is not None, lse is not None, - varlen_batch_idx is not None, + virtual_batch_idx is not None, num_splits_dynamic_ptr is not None, semaphore_to_reset is not None, ) @@ -2527,7 +2527,7 @@ def _flash_attn_fwd_combine( if not is_fake_mode(): _flash_attn_fwd_combine.compile_cache[compile_key]( out_partial, lse_partial, out, lse, - cu_seqlens, seqused, num_splits_dynamic_ptr, varlen_batch_idx, + cu_seqlens, seqused, num_splits_dynamic_ptr, virtual_batch_idx, semaphore_to_reset, ) @@ -2542,7 +2542,7 @@ def flash_attn_combine( out_dtype: Optional[torch.dtype] = None, cu_seqlens: Optional[torch.Tensor] = None, seqused: Optional[torch.Tensor] = None, - varlen_batch_idx: Optional[torch.Tensor] = None, + virtual_batch_idx: Optional[torch.Tensor] = None, return_lse: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Flash Attention combine function for split attention computation. @@ -2562,7 +2562,7 @@ def flash_attn_combine( out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input. cu_seqlens: Cumulative sequence lengths for variable length sequences seqused: Used sequence lengths for each batch - varlen_batch_idx: Optional mapping from virtual batch index to real batch index + virtual_batch_idx: Optional mapping from virtual batch index to real batch index (int32 tensor of shape (batch_size,)). Used by persistent tile schedulers that reorder batch processing for load balancing. return_lse: Whether to return the combined LSE tensor. Default is True. @@ -2619,7 +2619,7 @@ def flash_attn_combine( lse, cu_seqlens, seqused, - varlen_batch_idx=varlen_batch_idx, + virtual_batch_idx=virtual_batch_idx, ) return out, lse @@ -2676,7 +2676,7 @@ def get_scheduler_metadata( # Allocate metadata torch tensors num_m_blocks = torch.empty(num_batch, dtype=torch.int32, device=device) num_splits_dynamic = torch.empty(num_batch, dtype=torch.int32, device=device) - varlen_batch_idx = torch.empty(num_batch, dtype=torch.int32, device=device) if sort else None + virtual_batch_idx = torch.empty(num_batch, dtype=torch.int32, device=device) if sort else None num_nheads_in_l2 = torch.empty(num_batch, dtype=torch.int32, device=device) if causal else None tile_count_semaphore = torch.empty(1, dtype=torch.int32, device=device) @@ -2705,7 +2705,7 @@ def get_scheduler_metadata( leftpad_k is not None, num_m_blocks is not None, num_splits_dynamic is not None, - varlen_batch_idx is not None, + virtual_batch_idx is not None, num_nheads_in_l2 is not None, tile_count_semaphore is not None, n_blocks_per_split is not None, @@ -2716,7 +2716,7 @@ def get_scheduler_metadata( ( num_m_blocks_cute, num_splits_dynamic_cute, - varlen_batch_idx_cute, + virtual_batch_idx_cute, num_nheads_in_l2_cute, tile_count_semaphore_cute, cu_seqlens_q_cute, @@ -2730,7 +2730,7 @@ def get_scheduler_metadata( for t in ( num_m_blocks, num_splits_dynamic, - varlen_batch_idx, + virtual_batch_idx, num_nheads_in_l2, tile_count_semaphore, cu_seqlens_q, @@ -2770,7 +2770,7 @@ def get_scheduler_metadata( tile_count_semaphore_cute, num_m_blocks_cute, num_splits_dynamic_cute, - varlen_batch_idx_cute, + virtual_batch_idx_cute, num_nheads_in_l2_cute, n_blocks_per_split, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), @@ -2793,7 +2793,7 @@ def get_scheduler_metadata( tile_count_semaphore, num_m_blocks, num_splits_dynamic, - varlen_batch_idx, + virtual_batch_idx, num_nheads_in_l2, n_blocks_per_split, ) @@ -2801,7 +2801,7 @@ def get_scheduler_metadata( return SchedulerMetadataTensorsTorch( num_m_blocks_ptr=num_m_blocks, num_splits_dynamic_ptr=num_splits_dynamic, - varlen_batch_idx_ptr=varlen_batch_idx, + virtual_batch_idx_ptr=virtual_batch_idx, num_nheads_in_l2_ptr=num_nheads_in_l2, tile_count_semaphore=tile_count_semaphore, ) diff --git a/flash_attn/cute/prepare_scheduler.py b/flash_attn/cute/prepare_scheduler.py index f56662aaea0..796f147a5d8 100644 --- a/flash_attn/cute/prepare_scheduler.py +++ b/flash_attn/cute/prepare_scheduler.py @@ -18,7 +18,7 @@ class SchedulerMetadataTensorsTorch(NamedTuple): # tensors of shape (batch) num_m_blocks_ptr: Optional[torch.Tensor] num_splits_dynamic_ptr: Optional[torch.Tensor] - varlen_batch_idx_ptr: Optional[torch.Tensor] + virtual_batch_idx_ptr: Optional[torch.Tensor] num_nheads_in_l2_ptr: Optional[torch.Tensor] # tensor of shape (1) tile_count_semaphore: Optional[torch.Tensor] @@ -96,7 +96,7 @@ def __call__( tile_count_semaphore: Optional[cute.Tensor], num_m_blocks_ptr: Optional[cute.Tensor], num_splits_dynamic_ptr: Optional[cute.Tensor], - varlen_batch_idx_ptr: Optional[cute.Tensor], + virtual_batch_idx_ptr: Optional[cute.Tensor], num_nheads_in_l2_ptr: Optional[cute.Tensor], n_blocks_per_split: Optional[int], # overrides heuristic stream: cuda.CUstream, @@ -134,7 +134,7 @@ class SharedStorage: tile_count_semaphore, num_m_blocks_ptr, num_splits_dynamic_ptr, - varlen_batch_idx_ptr, + virtual_batch_idx_ptr, num_nheads_in_l2_ptr, n_blocks_per_split, ).launch( @@ -164,7 +164,7 @@ def kernel( tile_count_semaphore: Optional[cute.Tensor], num_m_blocks_ptr: Optional[cute.Tensor], num_splits_dynamic_ptr: Optional[cute.Tensor], - varlen_batch_idx_ptr: Optional[cute.Tensor], + virtual_batch_idx_ptr: Optional[cute.Tensor], num_nheads_in_l2_ptr: Optional[cute.Tensor], n_blocks_per_split: Optional[Int32], ): diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 853dfcfc815..dc529093db2 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -263,7 +263,7 @@ class TileSchedulerArguments(ParamsBase): use_cluster_idx: cutlass.Constexpr[bool] = False num_splits_dynamic_ptr: Optional[cute.Tensor] = None num_m_blocks_ptr: Optional[cute.Tensor] = None - varlen_batch_idx_ptr: Optional[cute.Tensor] = None + virtual_batch_idx_ptr: Optional[cute.Tensor] = None num_nheads_in_l2_ptr: Optional[cute.Tensor] = None tile_count_semaphore: Optional[cute.Pointer] = None persistent_cta_multiplier: cutlass.Constexpr[int] = 1 @@ -1222,7 +1222,7 @@ class Params(ParamsBase): is_split_kv: cutlass.Constexpr[bool] = False num_splits_dynamic_ptr: Optional[cute.Tensor] = None num_m_blocks_ptr: Optional[cute.Tensor] = None - varlen_batch_idx_ptr: Optional[cute.Tensor] = None + virtual_batch_idx_ptr: Optional[cute.Tensor] = None num_nheads_in_l2_ptr: Optional[cute.Tensor] = None tile_count_semaphore: Optional[cute.Pointer] = None persistent_cta_multiplier: cutlass.Constexpr[int] = 1 @@ -1253,7 +1253,7 @@ def create( is_split_kv=args.is_split_kv, num_splits_dynamic_ptr=args.num_splits_dynamic_ptr, num_m_blocks_ptr=args.num_m_blocks_ptr, - varlen_batch_idx_ptr=args.varlen_batch_idx_ptr, + virtual_batch_idx_ptr=args.virtual_batch_idx_ptr, num_nheads_in_l2_ptr=args.num_nheads_in_l2_ptr, tile_count_semaphore=args.tile_count_semaphore, persistent_cta_multiplier=args.persistent_cta_multiplier, @@ -1322,10 +1322,10 @@ def get_grid_shape( def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: params = self.params batch_idx = lane + bidb_start - if cutlass.const_expr(params.varlen_batch_idx_ptr is not None): + if cutlass.const_expr(params.virtual_batch_idx_ptr is not None): seqlen = Int32(0) if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1: - real_batch_idx = params.varlen_batch_idx_ptr[batch_idx] + real_batch_idx = params.virtual_batch_idx_ptr[batch_idx] seqlen = params.mCuSeqlensQ[real_batch_idx + 1] - params.mCuSeqlensQ[real_batch_idx] if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): seqlen *= params.qhead_per_kvhead_packgqa @@ -1493,9 +1493,9 @@ def get_current_work( if const_expr(params.is_split_kv and params.num_splits_dynamic_ptr is not None): if is_valid: split_idx = split_idx | (num_splits << 16) - if const_expr(params.varlen_batch_idx_ptr is not None): + if const_expr(params.virtual_batch_idx_ptr is not None): if is_valid: - batch_idx = params.varlen_batch_idx_ptr[batch_idx] + batch_idx = params.virtual_batch_idx_ptr[batch_idx] return ( WorkTileInfo( From 79aea1439a738f61f34a07c5d46d51b337f771c1 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Tue, 12 May 2026 02:04:53 +0000 Subject: [PATCH 04/15] split out VarlenSchedulerBase to share code between SingleTile and DynamicPersistent schedulers --- flash_attn/cute/flash_fwd_combine.py | 11 +- flash_attn/cute/flash_fwd_sm100.py | 28 +- flash_attn/cute/interface.py | 3 +- .../cute/sm100_hd256_2cta_fmha_forward.py | 4 +- flash_attn/cute/tile_scheduler.py | 585 +++++++++--------- tests/cute/test_flash_attn_combine.py | 14 +- 6 files changed, 324 insertions(+), 321 deletions(-) diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py index 0d9f7985e70..e831fb0ce88 100644 --- a/flash_attn/cute/flash_fwd_combine.py +++ b/flash_attn/cute/flash_fwd_combine.py @@ -197,7 +197,7 @@ def __call__( cu_seqlens: Optional[cute.Tensor] = None, seqused: Optional[cute.Tensor] = None, num_splits_dynamic_ptr: Optional[cute.Tensor] = None, - varlen_batch_idx: Optional[cute.Tensor] = None, + virtual_batch_idx: Optional[cute.Tensor] = None, semaphore_to_reset: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, @@ -301,7 +301,7 @@ class SharedStorage: cu_seqlens, seqused, num_splits_dynamic_ptr, - varlen_batch_idx, + virtual_batch_idx, semaphore_to_reset, SharedStorage, self.smem_layout_lse, @@ -330,7 +330,7 @@ def kernel( cu_seqlens: Optional[cute.Tensor], seqused: Optional[cute.Tensor], num_splits_dynamic_ptr: Optional[cute.Tensor], - varlen_batch_idx: Optional[cute.Tensor], + virtual_batch_idx: Optional[cute.Tensor], semaphore_to_reset: Optional[cute.Tensor], SharedStorage: cutlass.Constexpr, smem_layout_lse: cute.Layout | cute.ComposedLayout, @@ -349,8 +349,8 @@ def kernel( # Map virtual batch index to real batch index (for persistent tile schedulers) batch_idx = ( - varlen_batch_idx[maybe_virtual_batch] - if const_expr(varlen_batch_idx is not None) + virtual_batch_idx[maybe_virtual_batch] + if const_expr(virtual_batch_idx is not None) else maybe_virtual_batch ) @@ -365,6 +365,7 @@ def kernel( # Handle semaphore reset — wait for dependent grids first if const_expr(semaphore_to_reset is not None): + # maybe handle on first CTA? if ( tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index fe878a69265..de653f6b53d 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -127,7 +127,7 @@ def __init__( m_block_size: int = 128, n_block_size: int = 128, q_stage: cutlass.Constexpr[int] = 2, - is_persistent: bool = True, + is_static_persistent: bool = True, score_mod: cutlass.Constexpr | None = None, mask_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, @@ -178,7 +178,7 @@ def __init__( self.qk_acc_dtype = Float32 self.pv_acc_dtype = Float32 self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1) - self.is_persistent = is_persistent + self.is_static_persistent = is_static_persistent self.is_causal = is_causal self.is_local = is_local self.is_varlen_q = is_varlen_q @@ -195,15 +195,7 @@ def __init__( self.vec_size: cutlass.Constexpr = getattr( score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 ) - self.dynamic_persistent = has_tile_count_semaphore and is_varlen_q - if self.dynamic_persistent: - self.is_persistent = True - assert not use_clc_scheduler, ( - "use_clc_scheduler and dynamic_persistent (varlen + tile_count_semaphore) " - "are not currently composable; pick one. TODO: future revision could let " - "DynamicPersistentVarlenScheduler use CLC for tile distribution while " - "keeping prepare_scheduler's per-batch num_splits and LPT batch-sort." - ) + # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f @@ -221,12 +213,15 @@ def __init__( "Paged KV does not support irregular head dim" ) + self.use_clc_scheduler = use_clc_scheduler + self.dynamic_persistent = (has_tile_count_semaphore and is_varlen_q) or use_clc_scheduler # ClC does not compose with these other features, so disable even if requested self.use_clc_scheduler = ( use_clc_scheduler and self.use_tma_KV - and not self.overlap_sO_sQ ) + self.static_persistent = is_static_persistent + self.is_persistent = self.dynamic_persistent or self.static_persistent self.sched_stages = 1 if self.use_clc_scheduler: assert self.cluster_shape_mn[1] == 1, f"CLC requires cluster N == 1: {self.cluster_shape_mn}" @@ -242,13 +237,13 @@ def __init__( ) if is_varlen_q: - if self.dynamic_persistent: + if self.dynamic_persistent and not self.use_clc_scheduler: self.TileScheduler = DynamicPersistentVarlenScheduler else: self.TileScheduler = SingleTileVarlenScheduler elif self.is_causal or self.is_local or self.use_clc_scheduler: self.TileScheduler = SingleTileLPTScheduler - elif self.is_persistent: + elif self.static_persistent: self.TileScheduler = StaticPersistentTileScheduler else: self.TileScheduler = SingleTileScheduler @@ -295,10 +290,7 @@ def __init__( elif self.is_varlen_q: # fallback self.epilogue_warp_ids = (13, 14) - self.scheduler_warp_id = ( - self.empty_warp_ids[0] - if (self.use_clc_scheduler or self.dynamic_persistent) else None - ) + self.scheduler_warp_id = self.empty_warp_ids[0] if self.dynamic_persistent else None self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 self.tmem_o_offset = [ diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 4651de90c62..ddedf0c50be 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -985,7 +985,7 @@ def _flash_attn_fwd( m_block_size=tile_m, n_block_size=tile_n, q_stage=q_stage, - is_persistent=not causal + is_static_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None @@ -1181,6 +1181,7 @@ def _flash_attn_fwd( virtual_batch_idx=virtual_batch_idx if has_scheduler_metadata else None, ) if reuse_scheduler_metadata and tile_count_semaphore is not None: + # combine kernel does this for us tile_count_semaphore.zero_() return out, lse diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py index 379cebc1905..c2abfab3dfc 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py @@ -46,7 +46,7 @@ def __init__( m_block_size: int = 128, n_block_size: int = 128, q_stage: int = 2, - is_persistent: bool = True, + is_static_persistent: bool = True, score_mod=None, mask_mod=None, has_aux_tensors: bool = False, @@ -54,6 +54,8 @@ def __init__( is_varlen_q: bool = False, use_2cta_instrs: bool = False, use_clc_scheduler: bool = False, + has_tile_count_semaphore: bool = False, + seqlen_k_per_split: Optional[int] = None, ): head_dim_v = head_dim if head_dim_v is None else head_dim_v assert head_dim == 256 and head_dim_v == 256, ( diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index dc529093db2..c232759b577 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -280,6 +280,7 @@ class Params(ParamsBase): is_split_kv: cutlass.Constexpr[bool] = False cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) use_cluster_idx: cutlass.Constexpr[bool] = False + num_splits_dynamic_ptr: Optional[cute.Tensor] = None @staticmethod def create( @@ -294,6 +295,7 @@ def create( args.is_split_kv, args.cluster_shape_mn, args.use_cluster_idx, + args.num_splits_dynamic_ptr, ) def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): @@ -349,13 +351,21 @@ def get_grid_shape( def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: block_idx, head_idx, batch_idx = self._blk_coord + is_valid = self._is_first_block if const_expr(self.params.is_split_kv): head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod) else: split_idx = Int32(0) + # Pack dynamic per-batch num_splits into high 16 bits of split_idx + if const_expr( + self.params.is_split_kv and self.params.num_splits_dynamic_ptr is not None + ): + if is_valid: + num_splits = Int32(self.params.num_splits_dynamic_ptr[batch_idx]) + split_idx = split_idx | (num_splits << 16) return WorkTileInfo( (block_idx, head_idx, batch_idx, split_idx), - self._is_first_block, + is_valid, ) def initial_work_tile_info(self, *, loc=None, ip=None): @@ -513,6 +523,7 @@ class Params(ParamsBase): scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC lpt: cutlass.Constexpr[bool] = True use_cluster_idx: cutlass.Constexpr[bool] = True + num_splits_dynamic_ptr: Optional[cute.Tensor] = None @staticmethod @cute.jit @@ -557,6 +568,7 @@ def create( scheduling_mode=scheduling_mode, lpt=args.lpt, use_cluster_idx=args.use_cluster_idx, + num_splits_dynamic_ptr=args.num_splits_dynamic_ptr, ) def __init__( @@ -572,7 +584,7 @@ def __init__( self.params = params self._tile_idx = tile_idx self._split_idx = split_idx - self.ctx = ctx + self._ctx = ctx self._loc = loc self._ip = ip @@ -657,6 +669,13 @@ def clc_work_to_coords(self, work) -> WorkTileInfo: if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): bidx_in_cluster = cute.arch.block_in_cluster_idx() block_idx = block_idx * self.params.cluster_shape_m + bidx_in_cluster[0] + # Pack dynamic per-batch num_splits into high 16 bits of split_idx + if const_expr( + self.params.is_split_kv and self.params.num_splits_dynamic_ptr is not None + ): + if work.is_valid_tile: + num_splits = Int32(self.params.num_splits_dynamic_ptr[batch_idx]) + split_idx = split_idx | (num_splits << 16) return WorkTileInfo( (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), work.is_valid_tile, @@ -665,7 +684,7 @@ def clc_work_to_coords(self, work) -> WorkTileInfo: @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - work = self.ctx.get_current_work() + work = self._ctx.get_current_work() self._tile_idx = work.tile_idx[0] return self.clc_work_to_coords(work) # Static path: L2-swizzled coordinate mapping @@ -685,27 +704,33 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: if const_expr(params.lpt): block = params.num_block - 1 - block is_valid = self._tile_idx < params.total_blocks + split_idx = self._split_idx + # Pack dynamic per-batch num_splits into high 16 bits of split_idx + if const_expr(params.is_split_kv and params.num_splits_dynamic_ptr is not None): + if is_valid: + num_splits = Int32(params.num_splits_dynamic_ptr[batch_idx]) + split_idx = split_idx | (num_splits << 16) return WorkTileInfo( - (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid + (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(split_idx)), is_valid ) @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - work = self.ctx.initial_work_tile_info() + work = self._ctx.initial_work_tile_info() self._tile_idx = work.tile_idx[0] return self.clc_work_to_coords(work) return self.get_current_work(loc=loc, ip=ip) def prefetch_next_work(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.ctx.prefetch_next_work(loc=loc, ip=ip) + self._ctx.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.ctx.consumer_wait(loc=loc, ip=ip) + self._ctx.consumer_wait(loc=loc, ip=ip) work = self.get_current_work() - self.ctx.consumer_release(loc=loc, ip=ip) + self._ctx.consumer_release(loc=loc, ip=ip) return work # Single tile scheduler - set to invalid tile_idx to indicate no more work self._tile_idx = self.params.total_blocks @@ -713,13 +738,13 @@ def advance_to_next_work(self, *, loc=None, ip=None): def producer_tail(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.ctx.producer_tail(loc=loc, ip=ip) + self._ctx.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] objs = [self.params, self._tile_idx, self._split_idx] if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - objs += [self.ctx] + objs += [self._ctx] for obj in objs: obj_values = cutlass.extract_mlir_values(obj) values += obj_values @@ -730,7 +755,7 @@ def __new_from_mlir_values__(self, values): obj_list = [] objs = [self.params, self._tile_idx, self._split_idx] if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - objs += [self.ctx] + objs += [self._ctx] for obj, n_items in zip(objs, self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] @@ -871,7 +896,218 @@ def __new_from_mlir_values__(self, values): return self.__class__(*(tuple(obj_list)), loc=self._loc) -class SingleTileVarlenScheduler: +class VarlenSchedulerBase: + """Base for varlen tile schedulers (SingleTileVarlenScheduler, + DynamicPersistentVarlenScheduler). Owns the shared per-batch m-block lookup + and the warp-prefix-sum search-and-decode of the work tile. + + Subclasses must expose: + - self.params (ParamsBase) with the fields documented on each method. + - _get_num_splits(lane, bidb_start) -> Int32 + """ + + @cute.jit + def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + """Per-batch m-block count""" + params = self.params + cluster_shape_m = getattr(params, "cluster_shape_m", 1) + batch_idx = lane + bidb_start + is_valid = batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 + if cutlass.const_expr(params.num_m_blocks_ptr is not None): + num_m_blocks_raw = Int32(0) + if is_valid: + if cutlass.const_expr(params.virtual_batch_idx_ptr is not None): + real_batch_idx = params.virtual_batch_idx_ptr[batch_idx] + else: + real_batch_idx = batch_idx + num_m_blocks_raw = Int32(params.num_m_blocks_ptr[real_batch_idx]) + return cute.ceil_div(num_m_blocks_raw, cluster_shape_m) if is_valid else Int32(0) + if cutlass.const_expr(params.virtual_batch_idx_ptr is not None): + seqlen = Int32(0) + if is_valid: + real_batch_idx = params.virtual_batch_idx_ptr[batch_idx] + seqlen = ( + params.mCuSeqlensQ[real_batch_idx + 1] + - params.mCuSeqlensQ[real_batch_idx] + ) + if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): + seqlen *= params.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), cluster_shape_m) + if is_valid + else Int32(0) + ) + if cutlass.const_expr(params.mSeqUsedQ is not None): + seqlen = Int32(0) + if batch_idx < params.num_batch: + seqlen = params.mSeqUsedQ[batch_idx] + else: + assert params.mCuSeqlensQ is not None + cur_cu_seqlen = Int32(0) + if batch_idx <= params.num_batch: + cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): + seqlen *= params.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), cluster_shape_m) + if is_valid + else Int32(0) + ) + + @cute.jit + def _varlen_coord_map( + self, + next_tile_idx: Int32, + bidb_start: Int32, + group_start_tile: Int32, + ) -> Tuple[Int32, Int32, Int32, Int32, Int32, Int32, Boolean]: + """Search varlen batches via warp-level prefix sums and decode the work tile. + + Returns + - block + - head_idx + - batch_idx + - split_idx + - num_splits + - group_start_tile + - is_valid + + self.params must expose: + - num_head: Int32 + - num_batch: Int32 + - tile_shape_mn: Constexpr[tuple] + - qhead_per_kvhead_packgqa: Constexpr[int] + - max_kvblock_in_l2: Int32 + - is_split_kv: Constexpr[bool] + - lpt: Constexpr[bool] + Optionally: + - head_swizzle: Constexpr[bool] | None + - cluster_shape_m: Constexpr[int] | None + - num_nheads_in_l2_ptr: cute.Tensor | None + - virtual_batch_idx_ptr: cute.Tensor | None + """ + params = self.params + head_swizzle = getattr(params, "head_swizzle", False) + cluster_shape_m = getattr(params, "cluster_shape_m", 1) + num_nheads_in_l2_ptr = getattr(params, "num_nheads_in_l2_ptr", None) + virtual_batch_idx_ptr = getattr(params, "virtual_batch_idx_ptr", None) + + lane_idx = cute.arch.lane_idx() + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=bidb_start) + num_splits = self._get_num_splits(lane_idx, bidb_start=bidb_start) + per_batch = num_m_blocks * num_splits if const_expr(params.is_split_kv) else num_m_blocks + cumulative = utils.warp_prefix_sum(per_batch, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync(cumulative, cute.arch.WARP_SIZE - 1) + group_end_tile = m_blocks_in_group * params.num_head + group_start_tile + + batch_idx = bidb_start + while group_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= params.num_batch: + batch_idx = Int32(params.num_batch) + group_end_tile = next_tile_idx + 1 + else: + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) + num_splits = self._get_num_splits(lane_idx, bidb_start=batch_idx) + per_batch = ( + num_m_blocks * num_splits if const_expr(params.is_split_kv) else num_m_blocks + ) + cumulative = utils.warp_prefix_sum(per_batch, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync(cumulative, cute.arch.WARP_SIZE - 1) + group_end_tile += m_blocks_in_group * params.num_head + + is_valid = batch_idx < params.num_batch + block, head_idx, split_idx = Int32(0), Int32(0), Int32(0) + if is_valid: + group_start_tile = group_end_tile - m_blocks_in_group * params.num_head + batch_idx_in_group = cute.arch.popc( + cute.arch.vote_ballot_sync( + group_start_tile + cumulative * params.num_head <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_group + num_m_blocks_prev_lane = ( + Int32(0) + if batch_idx_in_group == 0 + else cute.arch.shuffle_sync(cumulative, batch_idx_in_group - 1) + ) + group_start_tile += num_m_blocks_prev_lane * params.num_head + num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) + if const_expr(params.is_split_kv): + num_splits = cute.arch.shuffle_sync(num_splits, batch_idx_in_group) + mh_block = next_tile_idx - group_start_tile + + if const_expr(params.lpt or head_swizzle): + # This is a version of the SingleTileLPTScheduler, complicated by the fact that + # the seqlen can vary per batch. + # TODO: is there any case where num_m_blocks is 0? + if const_expr(not params.is_split_kv) or num_splits == 1: + if const_expr(num_nheads_in_l2_ptr is not None): + if const_expr(virtual_batch_idx_ptr is not None): + nheads_in_l2 = Int32(num_nheads_in_l2_ptr[virtual_batch_idx_ptr[batch_idx]]) + else: + nheads_in_l2 = Int32(num_nheads_in_l2_ptr[batch_idx]) + else: + # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here + num_n_blocks = ( + num_m_blocks + * params.tile_shape_mn[0] + * cluster_shape_m + // params.qhead_per_kvhead_packgqa + // params.tile_shape_mn[1] + ) + # Seems faster to have nheads_in_l2 be a power of 2 + nheads_in_l2 = ( + 16 + if num_n_blocks * 16 <= params.max_kvblock_in_l2 + else ( + 8 + if num_n_blocks * 8 <= params.max_kvblock_in_l2 + else ( + 4 + if num_n_blocks * 4 <= params.max_kvblock_in_l2 + else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) + ) + ) + ) + nheads_in_l2 = min(nheads_in_l2, params.num_head) + mh_in_l2 = nheads_in_l2 * num_m_blocks + section_idx = mh_block // mh_in_l2 + l2_mod = mh_block - section_idx * mh_in_l2 + nheads_in_this_section = ( + nheads_in_l2 + if nheads_in_l2 * (section_idx + 1) <= params.num_head + else params.num_head - section_idx * nheads_in_l2 + ) + block = l2_mod // nheads_in_this_section + head_idx_residual = l2_mod - block * nheads_in_this_section + head_idx = section_idx * nheads_in_l2 + head_idx_residual + else: + head_split_idx = mh_block // num_m_blocks + block = mh_block - head_split_idx * num_m_blocks + head_idx = head_split_idx // num_splits + split_idx = head_split_idx - head_idx * num_splits + if const_expr(params.lpt): + block = num_m_blocks - 1 - block + else: + head_split_idx = mh_block // num_m_blocks + block = mh_block - head_split_idx * num_m_blocks + if const_expr(params.is_split_kv): + head_idx = head_split_idx // num_splits + split_idx = head_split_idx - head_idx * num_splits + else: + head_idx = head_split_idx + + if const_expr(cluster_shape_m > 1): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block = block * cluster_shape_m + bidx_in_cluster[0] + + return block, head_idx, batch_idx, split_idx, num_splits, group_start_tile, is_valid + + +class SingleTileVarlenScheduler(VarlenSchedulerBase): @dataclass class Params(ParamsBase): num_head: Int32 @@ -888,6 +1124,10 @@ class Params(ParamsBase): head_swizzle: cutlass.Constexpr[bool] = False cluster_shape_m: cutlass.Constexpr[int] = 1 scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + num_splits_dynamic_ptr: Optional[cute.Tensor] = None + num_m_blocks_ptr: Optional[cute.Tensor] = None + virtual_batch_idx_ptr: Optional[cute.Tensor] = None + num_nheads_in_l2_ptr: Optional[cute.Tensor] = None @staticmethod @cute.jit @@ -934,6 +1174,10 @@ def create( head_swizzle=args.head_swizzle, cluster_shape_m=args.cluster_shape_mn[0], scheduling_mode=scheduling_mode, + num_splits_dynamic_ptr=args.num_splits_dynamic_ptr, + num_m_blocks_ptr=args.num_m_blocks_ptr, + virtual_batch_idx_ptr=args.virtual_batch_idx_ptr, + num_nheads_in_l2_ptr=args.num_nheads_in_l2_ptr, ) def __init__( @@ -950,7 +1194,7 @@ def __init__( self._tile_idx = tile_idx self._split_idx = split_idx self._is_first_block = True - self.ctx = ctx + self._ctx = ctx self._loc = loc self._ip = ip @@ -1012,133 +1256,37 @@ def get_grid_shape( return (total_blocks_max * params.num_head, params.num_splits, Int32(1)) @cute.jit - def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: - params = self.params - batch_idx = lane + bidb_start - if cutlass.const_expr(params.mSeqUsedQ is not None): - seqlen = Int32(0) - if batch_idx < params.num_batch: - seqlen = params.mSeqUsedQ[batch_idx] - else: - assert params.mCuSeqlensQ is not None - cur_cu_seqlen = Int32(0) - if batch_idx <= params.num_batch: - cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] - next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) - seqlen = next_cu_seqlen - cur_cu_seqlen - if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): - seqlen *= params.qhead_per_kvhead_packgqa - return ( - cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), params.cluster_shape_m) - if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 - else Int32(0) - ) + def _get_num_splits(self, lane: Int32, bidb_start: Int32) -> Int32: + return Int32(1) @cute.jit - def _varlen_coord_map(self) -> WorkTileInfo: - """Map self._tile_idx to (block, head, batch) via warp-level prefix sums.""" + def _decode_work_tile(self) -> WorkTileInfo: + """Map self._tile_idx to (block, head, batch, split) via warp-level prefix sums.""" params = self.params - lane_idx = cute.arch.lane_idx() - num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) - num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) - # Total number of blocks for the next 31 batches - m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1) - # Same for all lanes - group_end_tile = m_blocks_in_group * params.num_head - # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group) - block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0) next_tile_idx = self._tile_idx // params.cluster_shape_m - while group_end_tile <= next_tile_idx: - batch_idx += cute.arch.WARP_SIZE - 1 - if batch_idx >= params.num_batch: - batch_idx = Int32(params.num_batch) - group_end_tile = next_tile_idx + 1 - else: - num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) - num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) - m_blocks_in_group = cute.arch.shuffle_sync( - num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 - ) - group_end_tile += m_blocks_in_group * params.num_head - is_valid = False - if batch_idx >= params.num_batch: - block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch) - else: - group_start_tile = group_end_tile - m_blocks_in_group * params.num_head - # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx) - # The next problem to process is the first one that does not have ending tile position - # that is greater than or equal to tile index. - batch_idx_in_group = cute.arch.popc( - cute.arch.vote_ballot_sync( - group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx - ) - ) - batch_idx += batch_idx_in_group - num_m_blocks_prev_lane = ( - 0 - if batch_idx_in_group == 0 - else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) - ) - num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) - mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head - if cutlass.const_expr(params.lpt or params.head_swizzle): - # This is a version of the SingleTileLPTScheduler, complicated by the fact that - # the seqlen can vary per batch. - # TODO: is there any case where num_m_blocks is 0? - # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here - num_n_blocks = ( - num_m_blocks - * params.tile_shape_mn[0] - * params.cluster_shape_m - // params.qhead_per_kvhead_packgqa - // params.tile_shape_mn[1] - ) - # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) - # Seems faster to have this be a power of 2 - nheads_in_l2 = ( - 16 - if num_n_blocks * 16 <= params.max_kvblock_in_l2 - else ( - 8 - if num_n_blocks * 8 <= params.max_kvblock_in_l2 - else ( - 4 - if num_n_blocks * 4 <= params.max_kvblock_in_l2 - else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) - ) - ) - ) - nheads_in_l2 = min(nheads_in_l2, params.num_head) - mh_in_l2 = nheads_in_l2 * num_m_blocks - section_idx = mh_block // mh_in_l2 - l2_mod = mh_block - section_idx * mh_in_l2 - # Deal with tail section - nheads_in_this_section = ( - nheads_in_l2 - if nheads_in_l2 * (section_idx + 1) <= params.num_head - else params.num_head - section_idx * nheads_in_l2 - ) - block = l2_mod // nheads_in_this_section - head_idx_residual = l2_mod - block * nheads_in_this_section - head_idx = section_idx * nheads_in_l2 + head_idx_residual - if cutlass.const_expr(params.lpt): - block = num_m_blocks - 1 - block - else: - head_idx = mh_block // num_m_blocks - block = mh_block - head_idx * num_m_blocks - is_valid = self._is_first_block and batch_idx < params.num_batch - if cutlass.const_expr(params.cluster_shape_m > 1): - bidx_in_cluster = cute.arch.block_in_cluster_idx() - block = block * params.cluster_shape_m + bidx_in_cluster[0] - # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) + block, head_idx, batch_idx, _, _, _, is_valid = self._varlen_coord_map( + next_tile_idx, Int32(0), Int32(0) + ) + is_valid = is_valid and self._is_first_block split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) - return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid) + if const_expr(params.virtual_batch_idx_ptr is not None): + if is_valid: + batch_idx = params.virtual_batch_idx_ptr[batch_idx] + # Pack dynamic per-batch num_splits into high 16 bits of split_idx + if const_expr(params.is_split_kv and params.num_splits_dynamic_ptr is not None): + if is_valid: + num_splits = Int32(params.num_splits_dynamic_ptr[batch_idx]) + split_idx = split_idx | (num_splits << 16) + return WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(split_idx)), + is_valid, + ) @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - clc_work = self.ctx.get_current_work() - # Default to grid_dim (one past last valid flat index) so _varlen_coord_map + clc_work = self._ctx.get_current_work() + # Default to grid_dim (one past last valid flat index) so _decode_work_tile # returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when # invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural # mismatch on self inside the runtime if. @@ -1150,12 +1298,12 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: new_split_idx = clc_work.tile_idx[1] self._tile_idx = new_tile_idx self._split_idx = new_split_idx - return self._varlen_coord_map() + return self._decode_work_tile() @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - clc_work = self.ctx.initial_work_tile_info() + clc_work = self._ctx.initial_work_tile_info() # See get_current_work for why grid_dim and local-then-assign. new_tile_idx = cute.arch.grid_dim()[0] new_split_idx = Int32(0) @@ -1165,30 +1313,30 @@ def initial_work_tile_info(self, *, loc=None, ip=None): new_split_idx = clc_work.tile_idx[1] self._tile_idx = new_tile_idx self._split_idx = new_split_idx - return self._varlen_coord_map() + return self._decode_work_tile() def prefetch_next_work(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.ctx.prefetch_next_work(loc=loc, ip=ip) + self._ctx.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.ctx.consumer_wait(loc=loc, ip=ip) + self._ctx.consumer_wait(loc=loc, ip=ip) work = self.get_current_work() - self.ctx.consumer_release(loc=loc, ip=ip) + self._ctx.consumer_release(loc=loc, ip=ip) return work self._is_first_block = False return self.get_current_work() def producer_tail(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.ctx.producer_tail(loc=loc, ip=ip) + self._ctx.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] objs = [self.params, self._tile_idx, self._split_idx] if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - objs += [self.ctx] + objs += [self._ctx] for obj in objs: obj_values = cutlass.extract_mlir_values(obj) values += obj_values @@ -1199,14 +1347,14 @@ def __new_from_mlir_values__(self, values): obj_list = [] objs = [self.params, self._tile_idx, self._split_idx] if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - objs += [self.ctx] + objs += [self._ctx] for obj, n_items in zip(objs, self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return self.__class__(*obj_list, loc=self._loc) -class DynamicPersistentVarlenScheduler: +class DynamicPersistentVarlenScheduler(VarlenSchedulerBase): @dataclass class Params(ParamsBase): num_head: Int32 @@ -1318,41 +1466,6 @@ def get_grid_shape( ) return (cutlass.min(sm_count, total_blocks), Int32(1), Int32(1)) - @cute.jit - def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: - params = self.params - batch_idx = lane + bidb_start - if cutlass.const_expr(params.virtual_batch_idx_ptr is not None): - seqlen = Int32(0) - if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1: - real_batch_idx = params.virtual_batch_idx_ptr[batch_idx] - seqlen = params.mCuSeqlensQ[real_batch_idx + 1] - params.mCuSeqlensQ[real_batch_idx] - if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): - seqlen *= params.qhead_per_kvhead_packgqa - return ( - cute.ceil_div(seqlen, params.tile_shape_mn[0]) - if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 - else Int32(0) - ) - if cutlass.const_expr(params.mSeqUsedQ is not None): - seqlen = Int32(0) - if batch_idx < params.num_batch: - seqlen = params.mSeqUsedQ[batch_idx] - else: - assert params.mCuSeqlensQ is not None - cur_cu_seqlen = Int32(0) - if batch_idx <= params.num_batch: - cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] - next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) - seqlen = next_cu_seqlen - cur_cu_seqlen - if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): - seqlen *= params.qhead_per_kvhead_packgqa - return ( - cute.ceil_div(seqlen, params.tile_shape_mn[0]) - if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 - else Int32(0) - ) - @cute.jit def _get_num_splits(self, lane: Int32, bidb_start: Int32) -> Int32: params = self.params @@ -1363,6 +1476,8 @@ def _get_num_splits(self, lane: Int32, bidb_start: Int32) -> Int32: elif cutlass.const_expr(params.num_splits_dynamic_ptr is not None): num_splits = Int32(0) if is_valid: + if cutlass.const_expr(params.virtual_batch_idx_ptr is not None): + batch_idx = params.virtual_batch_idx_ptr[batch_idx] num_splits = params.num_splits_dynamic_ptr[batch_idx] return num_splits else: @@ -1379,127 +1494,19 @@ def get_current_work( ip=None, ) -> WorkTileInfo: params = self.params - lane_idx = cute.arch.lane_idx() - num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=bidb_start) - num_splits = self._get_num_splits(lane_idx, bidb_start=bidb_start) - num_splits_m_blocks = ( - num_m_blocks * num_splits if const_expr(params.is_split_kv) else num_m_blocks + block, head_idx, batch_idx, split_idx, num_splits, group_start_tile, is_valid = ( + self._varlen_coord_map(next_tile_idx, bidb_start, group_start_tile) ) - num_m_blocks_cumulative = utils.warp_prefix_sum(num_splits_m_blocks, lane_idx) - # Total number of blocks for the next 31 batches - m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1) - # Same for all lanes - group_end_tile = m_blocks_in_group * params.num_head + group_start_tile - - block, head_idx, batch_idx, split_idx = Int32(0), Int32(0), bidb_start, Int32(0) - while group_end_tile <= next_tile_idx: - batch_idx += cute.arch.WARP_SIZE - 1 - if batch_idx >= params.num_batch: - batch_idx = Int32(params.num_batch) - group_end_tile = next_tile_idx + 1 - else: - num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) - num_splits = self._get_num_splits(lane_idx, bidb_start=batch_idx) - num_splits_m_blocks = ( - num_m_blocks * num_splits if const_expr(params.is_split_kv) else num_m_blocks - ) - num_m_blocks_cumulative = utils.warp_prefix_sum(num_splits_m_blocks, lane_idx) - m_blocks_in_group = cute.arch.shuffle_sync( - num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 - ) - group_end_tile += m_blocks_in_group * params.num_head - is_valid = batch_idx < params.num_batch - if is_valid: - group_start_tile = group_end_tile - m_blocks_in_group * params.num_head - # The next problem to process is the first one that does not have ending tile position - # that is less than or equal to tile index. - batch_idx_in_group = cute.arch.popc( - cute.arch.vote_ballot_sync( - group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx - ) - ) - batch_idx += batch_idx_in_group - num_m_blocks_prev_lane = ( - 0 - if batch_idx_in_group == 0 - else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) - ) - group_start_tile += num_m_blocks_prev_lane * params.num_head - num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) - if const_expr(params.is_split_kv): - num_splits = cute.arch.shuffle_sync(num_splits, batch_idx_in_group) - mh_block = next_tile_idx - group_start_tile - if const_expr(params.lpt): - if const_expr(not params.is_split_kv) or num_splits == 1: - # This is a version of the SingleTileLPTScheduler, complicated by the fact that - # the seqlen can vary per batch. - if const_expr(params.num_nheads_in_l2_ptr is not None): - nheads_in_l2 = Int32(params.num_nheads_in_l2_ptr[batch_idx]) - else: - # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here - num_n_blocks = ( - num_m_blocks - * params.tile_shape_mn[0] - // params.qhead_per_kvhead_packgqa - // params.tile_shape_mn[1] - ) - # Seems faster to have this be a power of 2 - nheads_in_l2 = ( - 16 - if num_n_blocks * 16 <= params.max_kvblock_in_l2 - else ( - 8 - if num_n_blocks * 8 <= params.max_kvblock_in_l2 - else ( - 4 - if num_n_blocks * 4 <= params.max_kvblock_in_l2 - else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) - ) - ) - ) - nheads_in_l2 = min(nheads_in_l2, params.num_head) - mh_in_l2 = nheads_in_l2 * num_m_blocks - section_idx = mh_block // mh_in_l2 - l2_mod = mh_block - section_idx * mh_in_l2 - # Deal with tail section - nheads_in_this_section = ( - nheads_in_l2 - if nheads_in_l2 * (section_idx + 1) <= params.num_head - else params.num_head - section_idx * nheads_in_l2 - ) - block = l2_mod // nheads_in_this_section - head_idx_residual = l2_mod - block * nheads_in_this_section - head_idx = section_idx * nheads_in_l2 + head_idx_residual - else: - head_split_idx = mh_block // num_m_blocks - block = mh_block - head_split_idx * num_m_blocks - if const_expr(params.is_split_kv): - head_idx = head_split_idx // num_splits - split_idx = head_split_idx - head_idx * num_splits - else: - head_idx = head_split_idx - - block = num_m_blocks - 1 - block - else: - head_split_idx = mh_block // num_m_blocks - block = mh_block - head_split_idx * num_m_blocks - if const_expr(params.is_split_kv): - head_idx = head_split_idx // num_splits - split_idx = head_split_idx - head_idx * num_splits - else: - head_idx = head_split_idx - - # Pack num_splits into top 16 bits of split_idx if const_expr(params.is_split_kv and params.num_splits_dynamic_ptr is not None): if is_valid: split_idx = split_idx | (num_splits << 16) if const_expr(params.virtual_batch_idx_ptr is not None): if is_valid: batch_idx = params.virtual_batch_idx_ptr[batch_idx] - return ( WorkTileInfo( - (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(split_idx)), is_valid + (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(split_idx)), + is_valid, ), group_start_tile, ) @@ -1568,7 +1575,7 @@ def __new_from_mlir_values__(self, values): ): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return DynamicPersistentVarlenScheduler(*(tuple(obj_list)), loc=self._loc) + return self.__class__(*obj_list, loc=self._loc) # ----------------------------------------------------------------------------- diff --git a/tests/cute/test_flash_attn_combine.py b/tests/cute/test_flash_attn_combine.py index 6344f96ab4b..ea77ded67ed 100644 --- a/tests/cute/test_flash_attn_combine.py +++ b/tests/cute/test_flash_attn_combine.py @@ -228,12 +228,12 @@ def test_flash_attn_combine_varlen(varlen_mode, num_splits, seqlen, d, dtype): @pytest.mark.parametrize("num_splits", [2, 5, 17]) # @pytest.mark.parametrize("num_splits", [5]) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) -def test_flash_attn_combine_varlen_batch_idx(num_splits, seqlen, d, dtype): - """Test that varlen_batch_idx correctly remaps virtual batch indices to real batch indices. +def test_flash_attn_combine_virtual_batch_idx(num_splits, seqlen, d, dtype): + """Test that virtual_batch_idx correctly remaps virtual batch indices to real batch indices. - varlen_batch_idx maps blockIdx.z (virtual batch) -> real batch index. The kernel + virtual_batch_idx maps blockIdx.z (virtual batch) -> real batch index. The kernel reads AND writes using the remapped batch_idx, so with a permutation the output - should match running without varlen_batch_idx (each real batch is processed once). + should match running without virtual_batch_idx (each real batch is processed once). We also test with seqused to verify interaction with variable-length sequences. """ @@ -255,18 +255,18 @@ def test_flash_attn_combine_varlen_batch_idx(num_splits, seqlen, d, dtype): perm = torch.tensor([2, 0, 3, 1], device=device, dtype=torch.int32) assert perm.shape[0] == batch_size - # Also test with seqused to verify interaction with varlen_batch_idx + # Also test with seqused to verify interaction with virtual_batch_idx seqused = torch.randint(1, seqlen + 1, (batch_size,), device=device, dtype=torch.int32) # Zero out / -inf beyond seqused so reference matches kernel for i in range(batch_size): out_partial[:, i, seqused[i]:] = 0 lse_partial[:, i, seqused[i]:] = -float("inf") - # Run with varlen_batch_idx and seqused via public API + # Run with virtual_batch_idx and seqused via public API out, lse = flash_attn_combine( out_partial, lse_partial, out_dtype=dtype, seqused=seqused, - varlen_batch_idx=perm, + virtual_batch_idx=perm, return_lse=True, ) if is_fake_mode(): From e4071c0de58efdcbb06f415ce78fddaa6bbf8e08 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Wed, 13 May 2026 00:31:37 +0000 Subject: [PATCH 05/15] add benchmark script for varlen dynamic persistent scheduler --- benchmarks/benchmark_varlen_sched.py | 457 +++++++++++++++++++++++++++ 1 file changed, 457 insertions(+) create mode 100644 benchmarks/benchmark_varlen_sched.py diff --git a/benchmarks/benchmark_varlen_sched.py b/benchmarks/benchmark_varlen_sched.py new file mode 100644 index 00000000000..2e99eb27d99 --- /dev/null +++ b/benchmarks/benchmark_varlen_sched.py @@ -0,0 +1,457 @@ +"""Benchmark the dynamic-persistent varlen scheduler against the prior default +(`SingleTileVarlenScheduler`), CLC (if available), and — on constant-seqlen workloads — the +non-varlen `flash_attn_func` baseline. + +Examples: + python benchmarks/benchmark_varlen_sched.py --total-tokens 32k --patterns longtail + python benchmarks/benchmark_varlen_sched.py --total-tokens 32k,64k --shapes 32x1k,16x2k \\ + --patterns constant longtail --csv > out.csv +""" + +import argparse +import time +from itertools import product + +import torch +from triton.testing import do_bench + +from flash_attn.cute import utils as fa_utils +from flash_attn.cute.interface import ( + flash_attn_func, + flash_attn_varlen_func, + get_scheduler_metadata, +) + + +_CLC_MODES = {"clc", "clc-prep"} + + +def _supports_clc(device): + return torch.cuda.get_device_capability(device)[0] == 10 + + +def parse_int_k(s): + """Parse an integer with optional k/K/m/M suffix, e.g. '8k' -> 8192, '1m' -> 1048576.""" + s = str(s).strip().lower() + if s.endswith("m"): + return int(s[:-1]) * 1024 * 1024 + if s.endswith("k"): + return int(s[:-1]) * 1024 + return int(s) + + +def csv_ints(s): + return [parse_int_k(x) for x in s.split(",")] + + +def parse_shape(s): + """Parse 'x' (seqlen accepts k suffix). Returns (batch, seqlen).""" + b, sl = s.lower().split("x") + return int(b), parse_int_k(sl) + + +def parse_shapes(s): + return [parse_shape(x) for x in s.split(",")] + + +def _make_seqlens(batch, seqlen, pattern, seed): + g = torch.Generator(device="cpu").manual_seed(seed) + if pattern == "constant": + return [seqlen] * batch + if pattern == "uniform": + lo = max(1, seqlen // 2) + return torch.randint(lo, seqlen + 1, (batch,), generator=g).tolist() + if pattern == "wide": + return torch.randint(1, seqlen + 1, (batch,), generator=g).tolist() + if pattern == "longtail": + n_long = max(1, batch // 8) + out = torch.randint( + max(1, seqlen // 16), max(2, seqlen // 8), (batch,), generator=g + ).tolist() + for i in torch.randperm(batch, generator=g)[:n_long].tolist(): + out[i] = seqlen + return out + if pattern == "bimodal": + return [seqlen if i % 2 == 0 else max(1, seqlen // 8) for i in range(batch)] + if pattern == "skew": + return [max(1, int(seqlen * i / max(1, batch - 1))) for i in range(batch)] + if pattern == "skew_shuffled": + out = [max(1, int(seqlen * i / max(1, batch - 1))) for i in range(batch)] + return [out[i] for i in torch.randperm(batch, generator=g).tolist()] + raise ValueError(f"unknown pattern {pattern!r}") + + +def _causal_tiles(sq, sk, tile_m=128, tile_n=128): + if sq <= 0 or sk <= 0: + return 0 + nq = (sq + tile_m - 1) // tile_m + nk = (sk + tile_n - 1) // tile_n + if nq <= 1: + return nk + return nq * nk - (nq * (nq - 1)) // 2 + + +def _apply_sort(seqlens_q, seqlens_k, sort): + if sort == "none": + return seqlens_q, seqlens_k + pairs = list(zip(seqlens_q, seqlens_k)) + keyfn = { + "asc": lambda p: _causal_tiles(*p), + "desc": lambda p: -_causal_tiles(*p), + }.get(sort) + if keyfn is None: + raise ValueError(f"unknown sort {sort!r}") + pairs.sort(key=keyfn) + return [p[0] for p in pairs], [p[1] for p in pairs] + + +def _override_random_subset( + seqlens_q, seqlens_k, frac, seed_salt, sq_value, sk_value, seed +): + """Pick `frac` of batches at random and overwrite their seqlens to the given values. + `sk_value=None` leaves seqlens_k untouched (used for decode-mix).""" + if frac <= 0: + return seqlens_q, seqlens_k + g = torch.Generator(device="cpu").manual_seed(seed + seed_salt) + B = len(seqlens_q) + n = int(round(frac * B)) + if n <= 0: + return seqlens_q, seqlens_k + idx = torch.randperm(B, generator=g)[:n].tolist() + sq, sk = list(seqlens_q), list(seqlens_k) + for i in idx: + sq[i] = sq_value + if sk_value is not None: + sk[i] = sk_value + return sq, sk + + +def build_ctx( + args, batch, seqlen, pattern, sort, decode_frac, zero_frac, num_splits, seed +): + seqlens_k = _make_seqlens(batch, seqlen, pattern, seed) + seqlens_q = list(seqlens_k) + seqlens_q, seqlens_k = _override_random_subset( + seqlens_q, seqlens_k, decode_frac, 7919, sq_value=1, sk_value=None, seed=seed + ) + seqlens_q, seqlens_k = _override_random_subset( + seqlens_q, seqlens_k, zero_frac, 31337, sq_value=0, sk_value=0, seed=seed + ) + seqlens_q, seqlens_k = _apply_sort(seqlens_q, seqlens_k, sort) + + dtype, device = torch.bfloat16, "cuda" + nheads, nheads_kv, headdim = args.nheads, args.nheads_kv, args.headdim + + cu_q = torch.zeros(batch + 1, dtype=torch.int32, device=device) + cu_q[1:] = torch.tensor(seqlens_q, dtype=torch.int32, device=device).cumsum(0) + cu_k = torch.zeros(batch + 1, dtype=torch.int32, device=device) + cu_k[1:] = torch.tensor(seqlens_k, dtype=torch.int32, device=device).cumsum(0) + q_unpad = torch.randn( + max(sum(seqlens_q), 1), nheads, headdim, device=device, dtype=dtype + ) + k_unpad = torch.randn( + max(sum(seqlens_k), 1), nheads_kv, headdim, device=device, dtype=dtype + ) + v_unpad = torch.randn( + max(sum(seqlens_k), 1), nheads_kv, headdim, device=device, dtype=dtype + ) + + return dict( + batch=batch, + seqlen=seqlen, + pattern=pattern, + decode_frac=decode_frac, + zero_frac=zero_frac, + nheads=nheads, + nheads_kv=nheads_kv, + headdim=headdim, + seqlens_q=seqlens_q, + seqlens_k=seqlens_k, + q_unpad=q_unpad, + k_unpad=k_unpad, + v_unpad=v_unpad, + cu_q=cu_q, + cu_k=cu_k, + max_seqlen_q=max(seqlens_q) if seqlens_q else 0, + max_seqlen_k=max(seqlens_k) if seqlens_k else 0, + causal=True, + num_splits=num_splits, + ) + + +def _make_meta(ctx): + return get_scheduler_metadata( + num_batch=ctx["batch"], + max_seqlen_q=ctx["max_seqlen_q"], + max_seqlen_k=ctx["max_seqlen_k"], + nheads=ctx["nheads"], + nheads_kv=ctx["nheads_kv"], + headdim=ctx["headdim"], + num_splits=ctx["num_splits"], + tile_m=128, + tile_n=128, + causal=ctx["causal"], + cu_seqlens_q=ctx["cu_q"], + cu_seqlens_k=ctx["cu_k"], + ) + + +def setup_dense(ctx): + """Non-varlen baseline; only meaningful when every batch has the same seqlen.""" + if ctx["pattern"] != "constant" or ctx["decode_frac"] != 0 or ctx["zero_frac"] != 0: + return None + batch, seqlen = ctx["batch"], ctx["seqlen"] + nheads, nheads_kv, headdim = ctx["nheads"], ctx["nheads_kv"], ctx["headdim"] + dtype, device = torch.bfloat16, "cuda" + q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype) + k = torch.randn(batch, seqlen, nheads_kv, headdim, device=device, dtype=dtype) + v = torch.randn(batch, seqlen, nheads_kv, headdim, device=device, dtype=dtype) + return lambda: flash_attn_func( + q, k, v, causal=ctx["causal"], num_splits=ctx["num_splits"] + ) + + +def make_varlen_setup(*, clc: bool, prep: str): + """`prep` is one of 'none', 'precompute', 'recompute'.""" + assert prep in ("none", "precompute", "recompute") + + def setup(ctx): + meta_precomputed = _make_meta(ctx) if prep == "precompute" else None + + def fn(): + fa_utils._fa_clc_enabled = clc + meta = _make_meta(ctx) if prep == "recompute" else meta_precomputed + return flash_attn_varlen_func( + ctx["q_unpad"], + ctx["k_unpad"], + ctx["v_unpad"], + cu_seqlens_q=ctx["cu_q"], + cu_seqlens_k=ctx["cu_k"], + max_seqlen_q=ctx["max_seqlen_q"], + max_seqlen_k=ctx["max_seqlen_k"], + causal=ctx["causal"], + num_splits=ctx["num_splits"], + scheduler_metadata=meta, + disable_scheduler_metadata=(prep == "none"), + ) + + return fn + + return setup + + +# fmt: off +MODES = [ + ("dense", setup_dense), + ("static", make_varlen_setup(clc=False, prep="none")), + ("clc", make_varlen_setup(clc=True, prep="none")), + ("clc-prep", make_varlen_setup(clc=True, prep="precompute")), + ("dynamic-prep", make_varlen_setup(clc=False, prep="precompute")), + ("dynamic+prep", make_varlen_setup(clc=False, prep="recompute")), +] +# fmt: on + + +def parse_args(): + p = argparse.ArgumentParser(description="Benchmark FA4 varlen scheduler modes") + p.add_argument( + "--total-tokens", + type=csv_ints, + default=[32 * 1024], + help="Total tokens (batch*seqlen) per workload, comma-separated. e.g. 32k,64k", + ) + p.add_argument( + "--shapes", + type=parse_shapes, + default=None, + help="Explicit (batch x seqlen) pairs, comma-separated, e.g. 32x1k,16x2k. " + "If unset, derived from --total-tokens by sweeping a default isoline.", + ) + p.add_argument( + "--patterns", + nargs="+", + default=["constant", "longtail", "bimodal", "uniform"], + help="Length distributions: constant, uniform, wide, longtail, bimodal, skew, skew_shuffled", + ) + p.add_argument( + "--sorts", + nargs="+", + default=["none"], + help="Batch ordering by tile count: none, asc, desc", + ) + p.add_argument( + "--decode-fracs", + nargs="+", + type=float, + default=[0.0], + help="Fraction(s) of batches to force seqlen_q=1 (mixed prefill/decode)", + ) + p.add_argument( + "--zero-fracs", + nargs="+", + type=float, + default=[0.0], + help="Fraction(s) of batches to force seqlen=0", + ) + p.add_argument( + "--num-splits", + nargs="+", + type=int, + default=[1], + help="num_splits values; >1 enables SplitKV", + ) + p.add_argument("--modes", nargs="+", default=[cli for cli, _ in MODES]) + p.add_argument("--headdim", type=int, default=128) + p.add_argument("--nheads", type=int, default=16) + p.add_argument("--nheads-kv", type=int, default=2) + p.add_argument("--seeds", type=int, default=3) + p.add_argument("--warmup", type=int, default=2) + p.add_argument("--rep", type=int, default=20) + p.add_argument( + "--sleep", + type=float, + default=0.5, + help="Sleep between modes to dodge clock throttling (seconds)", + ) + p.add_argument("--device", type=int, default=0) + p.add_argument( + "--csv", action="store_true", help="Emit CSV rows instead of the pretty table" + ) + return p.parse_args() + + +def _default_isoline(total_tokens): + """(batch, seqlen) pairs where batch * seqlen == total_tokens, doubling seqlen from 256.""" + return [ + (total_tokens // s, s) + for s in (1 << b for b in range(8, total_tokens.bit_length())) + if total_tokens // s >= 1 + ] + + +def _format_row(cells, csv, widths): + if csv: + return ",".join(str(c) for c in cells) + return " ".join(f"{str(c):<{w}}" for c, w in zip(cells, widths)) + + +def main(): + args = parse_args() + torch.cuda.set_device(args.device) + torch.manual_seed(0) + + if args.shapes is not None: + shapes = args.shapes + else: + shapes = [s for t in args.total_tokens for s in _default_isoline(t)] + + selected_modes = [(cli, fn) for cli, fn in MODES if cli in args.modes] + if not _supports_clc(args.device): + dropped = [cli for cli, _ in selected_modes if cli in _CLC_MODES] + if dropped: + print(f"# skipping CLC modes: {', '.join(dropped)}") + selected_modes = [ + (cli, fn) for cli, fn in selected_modes if cli not in _CLC_MODES + ] + + print(f"# device {args.device}: {torch.cuda.get_device_name(args.device)}") + print( + f"# headdim={args.headdim} nheads={args.nheads} nheads_kv={args.nheads_kv} " + f"(qhead_per_kvhead={args.nheads // args.nheads_kv})" + ) + cols = [ + ("pattern", 14), + ("decode", 8), + ("zero", 6), + ("shape", 10), + ("splits", 8), + ("mode", 14), + ("mean_us", 10), + ("tok/us", 9), + ("rel_static", 11), + ("rel_clc", 9), + ] + widths = [w for _, w in cols] + print(_format_row([h for h, _ in cols], args.csv, widths)) + + for shape, pattern, sort, decode_frac, zero_frac, num_splits in product( + shapes, + args.patterns, + args.sorts, + args.decode_fracs, + args.zero_fracs, + args.num_splits, + ): + batch, seqlen = shape + results = {} + # Workload is identical across modes; build once to get total_q for the report. + ref_ctx = build_ctx( + args, + batch, + seqlen, + pattern, + sort, + decode_frac, + zero_frac, + num_splits, + seed=0, + ) + total_q = sum(ref_ctx["seqlens_q"]) + + for cli, setup in selected_modes: + samples = [] + for s in range(args.seeds): + ctx = build_ctx( + args, + batch, + seqlen, + pattern, + sort, + decode_frac, + zero_frac, + num_splits, + seed=s, + ) + fn = setup(ctx) + if fn is None: + samples = None + break + fn() + torch.cuda.synchronize() + time.sleep(args.sleep) + samples.append(do_bench(fn, warmup=args.warmup, rep=args.rep)) + results[cli] = ( + None if samples is None else sum(samples) / len(samples) * 1e3 + ) + + static_us = results.get("static") + clc_us = results.get("clc") + for cli, _ in selected_modes: + us = results.get(cli) + if us is None: + continue + tok_per_us = (total_q / us) if us > 0 else 0.0 + rel_st = f"{static_us / us:.3f}" if static_us else "-" + rel_cl = f"{clc_us / us:.3f}" if clc_us else "-" + print( + _format_row( + [ + pattern, + f"{decode_frac:.2f}", + f"{zero_frac:.2f}", + f"{batch}x{seqlen}", + num_splits, + cli, + f"{us:.2f}", + f"{tok_per_us:.2f}", + rel_st, + rel_cl, + ], + args.csv, + widths, + ) + ) + + +if __name__ == "__main__": + main() From 70efc72da35e029f74292e59babd525ade31a099 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Wed, 13 May 2026 00:53:53 +0000 Subject: [PATCH 06/15] minor clean up --- flash_attn/cute/flash_fwd_combine.py | 1 - tests/cute/test_flash_attn.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py index e831fb0ce88..97164b8f533 100644 --- a/flash_attn/cute/flash_fwd_combine.py +++ b/flash_attn/cute/flash_fwd_combine.py @@ -365,7 +365,6 @@ def kernel( # Handle semaphore reset — wait for dependent grids first if const_expr(semaphore_to_reset is not None): - # maybe handle on first CTA? if ( tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index e73f4959d86..b188e0301fa 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -34,7 +34,6 @@ _flash_attn_fwd, _flash_attn_bwd, ) -from flash_attn.cute.prepare_scheduler import SchedulerMetadataTensorsTorch def retry_on_oom(func): @wraps(func) @@ -995,7 +994,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_leftpad", [False]) # @pytest.mark.parametrize("has_batch_idx", [False, True]) @pytest.mark.parametrize("has_batch_idx", [False]) -@pytest.mark.parametrize("varlen_q", [True]) +@pytest.mark.parametrize("varlen_q", [False, True]) # @pytest.mark.parametrize("varlen_q", [False]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) From f5d0abd8b297485d8d79531cf9902f271a439b03 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Wed, 13 May 2026 19:36:29 +0000 Subject: [PATCH 07/15] updates to has_work logic, tile scheduler selection, and varlen test suite --- benchmarks/benchmark_varlen_sched.py | 8 +- flash_attn/cute/flash_fwd_sm100.py | 57 ++++++++---- flash_attn/cute/interface.py | 16 ++-- tests/cute/test_flash_attn.py | 128 ++++++++++++++++----------- 4 files changed, 132 insertions(+), 77 deletions(-) diff --git a/benchmarks/benchmark_varlen_sched.py b/benchmarks/benchmark_varlen_sched.py index 2e99eb27d99..f08dc92f7c5 100644 --- a/benchmarks/benchmark_varlen_sched.py +++ b/benchmarks/benchmark_varlen_sched.py @@ -243,7 +243,7 @@ def fn(): # fmt: off MODES = [ ("dense", setup_dense), - ("static", make_varlen_setup(clc=False, prep="none")), + ("single-tile", make_varlen_setup(clc=False, prep="none")), ("clc", make_varlen_setup(clc=True, prep="none")), ("clc-prep", make_varlen_setup(clc=True, prep="precompute")), ("dynamic-prep", make_varlen_setup(clc=False, prep="precompute")), @@ -368,7 +368,7 @@ def main(): ("mode", 14), ("mean_us", 10), ("tok/us", 9), - ("rel_static", 11), + ("rel_st", 7), ("rel_clc", 9), ] widths = [w for _, w in cols] @@ -424,14 +424,14 @@ def main(): None if samples is None else sum(samples) / len(samples) * 1e3 ) - static_us = results.get("static") + single_tile_us = results.get("single-tile") clc_us = results.get("clc") for cli, _ in selected_modes: us = results.get(cli) if us is None: continue tok_per_us = (total_q / us) if us > 0 else 0.0 - rel_st = f"{static_us / us:.3f}" if static_us else "-" + rel_st = f"{single_tile_us / us:.3f}" if single_tile_us else "-" rel_cl = f"{clc_us / us:.3f}" if clc_us else "-" print( _format_row( diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index de653f6b53d..e55e81548ec 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -220,8 +220,7 @@ def __init__( use_clc_scheduler and self.use_tma_KV ) - self.static_persistent = is_static_persistent - self.is_persistent = self.dynamic_persistent or self.static_persistent + self.is_persistent = self.dynamic_persistent or self.is_static_persistent self.sched_stages = 1 if self.use_clc_scheduler: assert self.cluster_shape_mn[1] == 1, f"CLC requires cluster N == 1: {self.cluster_shape_mn}" @@ -236,14 +235,19 @@ def __init__( else SchedulingMode.STATIC ) + self.use_varlen_scheduler = False if is_varlen_q: if self.dynamic_persistent and not self.use_clc_scheduler: + self.use_varlen_scheduler = True self.TileScheduler = DynamicPersistentVarlenScheduler + elif self.is_static_persistent: + self.TileScheduler = StaticPersistentTileScheduler else: + self.use_varlen_scheduler = True self.TileScheduler = SingleTileVarlenScheduler elif self.is_causal or self.is_local or self.use_clc_scheduler: self.TileScheduler = SingleTileLPTScheduler - elif self.static_persistent: + elif self.is_static_persistent: self.TileScheduler = StaticPersistentTileScheduler else: self.TileScheduler = SingleTileScheduler @@ -1545,11 +1549,15 @@ def load( if const_expr(not self.use_block_sparsity): n_block_min, n_block_max = block_info.get_n_block_min_max( - seqlen, m_block, split_idx=split_idx, batch_idx=batch_idx, num_splits=num_splits, + seqlen, + m_block, + split_idx=split_idx, + batch_idx=batch_idx, + num_splits=num_splits, ) if const_expr(self.is_split_kv and block_info.num_splits_dynamic_ptr is not None): split_idx = split_idx & 0xFFFF - if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + if self.process_work_tile(seqlen, n_block_min, n_block_max): n_block_first = n_block_max - 1 if n_block_max > 0 else 0 page_idx = ( mPageTable[batch_idx, n_block_first] @@ -1755,15 +1763,14 @@ def mma( process_tile = block_iter_count > Int32(0) else: n_block_min, n_block_max = block_info.get_n_block_min_max( - seqlen, m_block, split_idx=split_idx, batch_idx=batch_idx, num_splits=num_splits, - ) - if const_expr(self.is_split_kv and block_info.num_splits_dynamic_ptr is not None): - split_idx = split_idx & 0xFFFF + seqlen, + m_block, + split_idx=split_idx, + batch_idx=batch_idx, + num_splits=num_splits, + ) block_iter_count = n_block_max - n_block_min - if const_expr(not self.is_split_kv): - process_tile = True - else: - process_tile = n_block_min < n_block_max + process_tile = self.process_work_tile(seqlen, n_block_min, n_block_max) if process_tile and is_leader_cta: for stage in cutlass.range_constexpr(self.q_stage): @@ -2132,7 +2139,7 @@ def softmax_loop( has_work = tile_block_count > Int32(0) else: tile_block_count = n_block_max - n_block_min - has_work = const_expr(not self.is_split_kv) or tile_block_count > Int32(0) + has_work = self.process_work_tile(seqlen, n_block_min, n_block_max) softmax_step = partial( self.softmax_step, @@ -2213,7 +2220,7 @@ def softmax_loop( sm_stats_barrier.arrive_w_index(index=stage * 4 + warp_idx) # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) else: - if const_expr(not self.is_split_kv) or tile_block_count > Int32(0): + if has_work: mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = softmax_step( mma_si_consumer_phase, sm_stats_producer_phase, @@ -2558,7 +2565,7 @@ def correction_loop( has_work = total_block_count > Int32(0) else: total_block_count = n_block_max - n_block_min - has_work = const_expr(not self.is_split_kv) or total_block_count > Int32(0) + has_work = self.process_work_tile(seqlen, n_block_min, n_block_max) if has_work: # Ignore first signal from softmax as no correction is required @@ -2988,9 +2995,8 @@ def epilogue_s2g( ) if const_expr(self.is_split_kv and block_info.num_splits_dynamic_ptr is not None): split_idx = split_idx & 0xFFFF - has_work = const_expr(self.use_block_sparsity or not self.is_split_kv) or n_block_min < n_block_max - if has_work: + if self.process_work_tile(seqlen, n_block_min, n_block_max): if const_expr(self.is_split_kv): mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] else: @@ -3262,3 +3268,18 @@ def apply_score_mod( constant_q_idx=q_idx_logical, qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, ) + + @cute.jit + def process_work_tile( + self, + seqlen_info: SeqlenInfoQK, + n_block_min: Int32, + n_block_max: Int32, + ): + is_varlen_q = seqlen_info.has_cu_seqlens_q or seqlen_info.has_seqused_q + process_work_tile_k = const_expr(not self.is_split_kv) or n_block_min < n_block_max + if const_expr(is_varlen_q and not self.use_varlen_scheduler): + process_work_tile_q = seqlen_info.seqlen_q > 0 + else: + process_work_tile_q = True + return process_work_tile_k and process_work_tile_q diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index ddedf0c50be..387388269d7 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -541,6 +541,7 @@ def _flash_attn_fwd( q_stage = 1 m_block_size_effective = q_stage * tile_m + max_m_blocks_leq_one = seqlen_q_packgqa <= m_block_size_effective seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, (window_size_right or max_seqlen_k) + (window_size_left or max_seqlen_k) + 1 + tile_m)) num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective total_mblocks = batch_size * num_head_kv * num_m_blocks @@ -745,6 +746,14 @@ def _flash_attn_fwd( num_nheads_in_l2 = None tile_count_semaphore = None + is_static_persistent = ( + not causal + and not local + and cu_seqlens_q is None + and seqused_q is None + and not is_split_kv + ) or (max_m_blocks_leq_one and not is_split_kv) + compile_key = ( dtype, head_dim, @@ -788,6 +797,7 @@ def _flash_attn_fwd( virtual_batch_idx is not None, num_nheads_in_l2 is not None, tile_count_semaphore is not None, + is_static_persistent, qv is not None, gather_kv_length, sparse_kv, @@ -985,11 +995,7 @@ def _flash_attn_fwd( m_block_size=tile_m, n_block_size=tile_n, q_stage=q_stage, - is_static_persistent=not causal - and not local - and cu_seqlens_q is None - and seqused_q is None - and not is_split_kv, + is_static_persistent=is_static_persistent, score_mod=score_mod, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index b188e0301fa..1a15c0cdc27 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -763,10 +763,15 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # num_splits_vals = [1, 3] # SplitKV is not supported for hdim >= 192 num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] - for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + precompute_metadata_vals = [False, True] + for pack_gqa, num_splits, precompute_metadata in itertools.product( + pack_gqa_vals, num_splits_vals, precompute_metadata_vals + ): # SplitKV not supported on SM90 - skip this iteration if IS_SM90 and num_splits > 1: continue + if precompute_metadata and is_fake_mode(): + continue # TODO(wangsiyu): SM100 head_dim=256 2CTA kernel does not support pack_gqa yet. # pack_gqa=None means auto-enable for GQA/MQA (qhead_per_kvhead > 1) # Remove this when support is added. @@ -775,56 +780,79 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): continue if pack_gqa is None and mha_type != "mha": continue - out_unpad, lse = flash_attn_varlen_func( - q_unpad if unpad_q else q, - k_unpad if unpad_kv else k, - v_unpad if unpad_kv else v, - cu_seqlens_q=cu_seqlens_q if unpad_q else None, - cu_seqlens_k=cu_seqlens_k if unpad_kv else None, - max_seqlen_q=seqlen_q, - max_seqlen_k=seqlen_k, - seqused_q=seqused_q if not unpad_q else None, - seqused_k=seqused_k if not unpad_kv else None, - causal=causal, - # qv=qv_unpad, - # q_descale=q_descale, - # k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - # attention_chunk=attention_chunk, - learnable_sink=learnable_sink, - softcap=softcap, - num_splits=num_splits, - pack_gqa=pack_gqa, - deterministic=deterministic, - ) - out = output_pad_fn(out_unpad) if unpad_q else out_unpad - if is_fake_mode(): - # no more flash_attn cutedsl calls for the rest of the loop - # skip data-dependent postprocessing - continue - if query_unused_mask is not None: - out.masked_fill_(q_zero_masking, 0.0) - # When unpad_q=False with seqused_q, the kernel doesn't write positions - # beyond seqused_q, so those contain uninitialized values. Mask them out - # before comparing. - out_cmp, out_ref_cmp, out_pt_cmp = out, out_ref, out_pt - if not unpad_q and seqused_q is not None: - seqused_mask = torch.arange(seqlen_q, device=device)[None, :] < seqused_q[:, None] - seqused_mask = rearrange(seqused_mask, "b s -> b s 1 1") - out_cmp = out.clone().masked_fill_(~seqused_mask, 0.0) - out_ref_cmp = out_ref.clone().masked_fill_(~seqused_mask, 0.0) - out_pt_cmp = out_pt.clone().masked_fill_(~seqused_mask, 0.0) - print(f"Output max diff: {(out_cmp - out_ref_cmp).abs().max().item()}") - print(f"Output mean diff: {(out_cmp - out_ref_cmp).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() + if precompute_metadata: + scheduler_metadata = get_scheduler_metadata( + num_batch=batch_size, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + nheads=nheads, + nheads_kv=nheads_kv, + headdim=d, + headdim_v=dv, + num_splits=num_splits, + tile_m=128, + tile_n=128, + causal=causal, + cu_seqlens_q=cu_seqlens_q if unpad_q else None, + cu_seqlens_k=cu_seqlens_k if unpad_kv else None, + seqused_q=seqused_q if not unpad_q else None, + seqused_k=seqused_k if not unpad_kv else None, + ) + else: + scheduler_metadata = None + # Repeat to exercise metadata reuse across calls. + for _ in range(1 if not precompute_metadata else 2): + out_unpad, lse = flash_attn_varlen_func( + q_unpad if unpad_q else q, + k_unpad if unpad_kv else k, + v_unpad if unpad_kv else v, + cu_seqlens_q=cu_seqlens_q if unpad_q else None, + cu_seqlens_k=cu_seqlens_k if unpad_kv else None, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + seqused_q=seqused_q if not unpad_q else None, + seqused_k=seqused_k if not unpad_kv else None, + causal=causal, + # qv=qv_unpad, + # q_descale=q_descale, + # k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + pack_gqa=pack_gqa, + deterministic=deterministic, + ) + out = output_pad_fn(out_unpad) if unpad_q else out_unpad + if is_fake_mode(): + # no more flash_attn cutedsl calls for the rest of the loop + # skip data-dependent postprocessing + continue + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + # When unpad_q=False with seqused_q, the kernel doesn't write positions + # beyond seqused_q, so those contain uninitialized values. Mask them out + # before comparing. + out_cmp, out_ref_cmp, out_pt_cmp = out, out_ref, out_pt + if not unpad_q and seqused_q is not None: + seqused_mask = torch.arange(seqlen_q, device=device)[None, :] < seqused_q[:, None] + seqused_mask = rearrange(seqused_mask, "b s -> b s 1 1") + out_cmp = out.clone().masked_fill_(~seqused_mask, 0.0) + out_ref_cmp = out_ref.clone().masked_fill_(~seqused_mask, 0.0) + out_pt_cmp = out_pt.clone().masked_fill_(~seqused_mask, 0.0) + print(f"Output max diff: {(out_cmp - out_ref_cmp).abs().max().item()}") + print(f"Output mean diff: {(out_cmp - out_ref_cmp).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() - # Check that FlashAttention's numerical error is at most 3x the numerical error - # of a Pytorch implementation. - assert (out_cmp - out_ref_cmp).abs().max().item() <= rtol * ( - out_pt_cmp - out_ref_cmp - ).abs().max().item() + fwd_atol + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out_cmp - out_ref_cmp).abs().max().item() <= rtol * ( + out_pt_cmp - out_ref_cmp + ).abs().max().item() + fwd_atol if ( dtype != torch.float8_e4m3fn From 2e9bc0f3e198f95acb7a27347418863a32b24373 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Wed, 13 May 2026 23:16:32 +0000 Subject: [PATCH 08/15] fix tile scheduler dispatch logic --- flash_attn/cute/flash_fwd_sm100.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index e55e81548ec..bd636524f08 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -240,7 +240,7 @@ def __init__( if self.dynamic_persistent and not self.use_clc_scheduler: self.use_varlen_scheduler = True self.TileScheduler = DynamicPersistentVarlenScheduler - elif self.is_static_persistent: + elif self.is_static_persistent and not self.use_clc_scheduler: self.TileScheduler = StaticPersistentTileScheduler else: self.use_varlen_scheduler = True From e76dde1179538e9fb25a1733c9a7de7929e61428 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Mon, 25 May 2026 22:11:00 +0000 Subject: [PATCH 09/15] integrate binary batch search for single tile varlen --- benchmarks/benchmark_varlen_sched.py | 54 +++++++++++++++++++++++++-- flash_attn/cute/flash_fwd_sm100.py | 4 ++ flash_attn/cute/interface.py | 55 ++++++++++++++++++++++++---- flash_attn/cute/prepare_scheduler.py | 5 +++ flash_attn/cute/tile_scheduler.py | 46 +++++++++++++++++++++-- 5 files changed, 149 insertions(+), 15 deletions(-) diff --git a/benchmarks/benchmark_varlen_sched.py b/benchmarks/benchmark_varlen_sched.py index f08dc92f7c5..d915123cdf6 100644 --- a/benchmarks/benchmark_varlen_sched.py +++ b/benchmarks/benchmark_varlen_sched.py @@ -176,6 +176,7 @@ def build_ctx( max_seqlen_k=max(seqlens_k) if seqlens_k else 0, causal=True, num_splits=num_splits, + pack_gqa=args.pack_gqa, ) @@ -191,11 +192,20 @@ def _make_meta(ctx): tile_m=128, tile_n=128, causal=ctx["causal"], + pack_gqa=ctx["pack_gqa"], cu_seqlens_q=ctx["cu_q"], cu_seqlens_k=ctx["cu_k"], ) +def _make_meta_no_semaphore(ctx): + """Like _make_meta but with tile_count_semaphore nulled out, so the FA kernel + selects SingleTileVarlenScheduler (STATIC) instead of DynamicPersistentVarlen. + Exercises the binary-search hint path on the scheduler that lacks resumption.""" + m = _make_meta(ctx) + return m._replace(tile_count_semaphore=None) + + def setup_dense(ctx): """Non-varlen baseline; only meaningful when every batch has the same seqlen.""" if ctx["pattern"] != "constant" or ctx["decode_frac"] != 0 or ctx["zero_frac"] != 0: @@ -211,16 +221,22 @@ def setup_dense(ctx): ) -def make_varlen_setup(*, clc: bool, prep: str): - """`prep` is one of 'none', 'precompute', 'recompute'.""" +def make_varlen_setup(*, clc: bool, prep: str, no_semaphore: bool = False): + """`prep` is one of 'none', 'precompute', 'recompute'. + + `no_semaphore=True` nulls out `tile_count_semaphore` in the metadata so the + FA kernel picks SingleTileVarlenScheduler (STATIC) instead of the auto- + selected DynamicPersistentVarlenScheduler. Use this to exercise the binary- + search hint path on the no-resumption scheduler that PR #2520 targets.""" assert prep in ("none", "precompute", "recompute") + meta_fn = _make_meta_no_semaphore if no_semaphore else _make_meta def setup(ctx): - meta_precomputed = _make_meta(ctx) if prep == "precompute" else None + meta_precomputed = meta_fn(ctx) if prep == "precompute" else None def fn(): fa_utils._fa_clc_enabled = clc - meta = _make_meta(ctx) if prep == "recompute" else meta_precomputed + meta = meta_fn(ctx) if prep == "recompute" else meta_precomputed return flash_attn_varlen_func( ctx["q_unpad"], ctx["k_unpad"], @@ -233,6 +249,7 @@ def fn(): num_splits=ctx["num_splits"], scheduler_metadata=meta, disable_scheduler_metadata=(prep == "none"), + pack_gqa=ctx["pack_gqa"], ) return fn @@ -244,6 +261,7 @@ def fn(): MODES = [ ("dense", setup_dense), ("single-tile", make_varlen_setup(clc=False, prep="none")), + ("st-prep", make_varlen_setup(clc=False, prep="precompute", no_semaphore=True)), ("clc", make_varlen_setup(clc=True, prep="none")), ("clc-prep", make_varlen_setup(clc=True, prep="precompute")), ("dynamic-prep", make_varlen_setup(clc=False, prep="precompute")), @@ -304,6 +322,9 @@ def parse_args(): p.add_argument("--headdim", type=int, default=128) p.add_argument("--nheads", type=int, default=16) p.add_argument("--nheads-kv", type=int, default=2) + p.add_argument("--pack-gqa", action="store_true", default=True, + help="Force pack_gqa=True (default). --no-pack-gqa to disable.") + p.add_argument("--no-pack-gqa", dest="pack_gqa", action="store_false") p.add_argument("--seeds", type=int, default=3) p.add_argument("--warmup", type=int, default=2) p.add_argument("--rep", type=int, default=20) @@ -368,6 +389,7 @@ def main(): ("mode", 14), ("mean_us", 10), ("tok/us", 9), + ("tflops", 8), ("rel_st", 7), ("rel_clc", 9), ] @@ -397,6 +419,28 @@ def main(): seed=0, ) total_q = sum(ref_ctx["seqlens_q"]) + # Causal varlen attention FLOPs per batch: + # per (head, query q in [0, sq)): 4 * d * effective_k where + # effective_k = max(0, sk - sq + q + 1). + # sum_q effective_k = sq*sk - sq*(sq-1)/2 (for sk >= sq; otherwise clamped). + total_flops = 0 + for sq, sk in zip(ref_ctx["seqlens_q"], ref_ctx["seqlens_k"]): + if sq == 0 or sk == 0: + continue + if ref_ctx["causal"]: + # sum_{q=0}^{sq-1} max(0, sk - sq + q + 1) + shift = sk - sq + if shift >= 0: + eff = sq * sk - sq * (sq - 1) // 2 + else: + # clamp to non-negative for queries near 0 + first_visible_q = -shift # smallest q with sk - sq + q + 1 > 0 is q = sq - sk + visible = sq - first_visible_q + eff = visible * sk - visible * (visible - 1) // 2 + eff = max(0, eff) + else: + eff = sq * sk + total_flops += 4 * args.headdim * args.nheads * eff for cli, setup in selected_modes: samples = [] @@ -431,6 +475,7 @@ def main(): if us is None: continue tok_per_us = (total_q / us) if us > 0 else 0.0 + tflops = (total_flops / (us * 1e6)) if us > 0 else 0.0 rel_st = f"{single_tile_us / us:.3f}" if single_tile_us else "-" rel_cl = f"{clc_us / us:.3f}" if clc_us else "-" print( @@ -444,6 +489,7 @@ def main(): cli, f"{us:.2f}", f"{tok_per_us:.2f}", + f"{tflops:.2f}", rel_st, rel_cl, ], diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index bd636524f08..3da16ece4fa 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -398,6 +398,8 @@ def __call__( tile_count_semaphore: Optional[cute.Tensor] = None, virtual_batch_idx_ptr: Optional[cute.Tensor] = None, num_nheads_in_l2_ptr: Optional[cute.Tensor] = None, + cu_total_m_blocks_ptr: Optional[cute.Tensor] = None, + cu_total_splits_m_blocks_ptr: Optional[cute.Tensor] = None, max_seqlen_q: Int32 | int | None = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, @@ -686,6 +688,8 @@ def __call__( num_splits_dynamic_ptr=num_splits_dynamic_ptr, virtual_batch_idx_ptr=virtual_batch_idx_ptr, num_nheads_in_l2_ptr=num_nheads_in_l2_ptr, + cu_total_m_blocks_ptr=cu_total_m_blocks_ptr, + cu_total_splits_m_blocks_ptr=cu_total_splits_m_blocks_ptr, tile_count_semaphore=tile_count_semaphore.iterator if tile_count_semaphore is not None else None, ) tile_sched_params = TileScheduler.to_underlying_arguments( diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 387388269d7..52df5f8f1b4 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -717,13 +717,11 @@ def _flash_attn_fwd( has_scheduler_metadata = scheduler_metadata is not None and not disable_scheduler_metadata if has_scheduler_metadata: - ( - num_m_blocks, - num_splits_dynamic, - virtual_batch_idx, - num_nheads_in_l2, - tile_count_semaphore, - ) = scheduler_metadata + num_m_blocks = scheduler_metadata.num_m_blocks_ptr + num_splits_dynamic = scheduler_metadata.num_splits_dynamic_ptr + virtual_batch_idx = scheduler_metadata.virtual_batch_idx_ptr + num_nheads_in_l2 = scheduler_metadata.num_nheads_in_l2_ptr + tile_count_semaphore = scheduler_metadata.tile_count_semaphore assert all( t is None or t.is_cuda for t in scheduler_metadata @@ -739,12 +737,39 @@ def _flash_attn_fwd( ), "these scheduler metadata tensors must have shape (batch_size,)" if tile_count_semaphore is not None: assert tile_count_semaphore.shape == (1,), "semaphore must have size 1" + # The kernel's _get_num_m_blocks uses tile_shape_mn[0] = q_stage * tile_m, + # so cu_total must be built with that effective tile. Rebuild here rather + # than trust scheduler_metadata.cu_total_*_ptr (which were computed with + # the user's tile_m and would mis-decode the CLC-exhausted sentinel). + if ( + num_m_blocks is not None + and num_splits_dynamic is not None + and os.environ.get("FLASH_ATTENTION_DISABLE_BINARY_SEARCH", "0") != "1" + ): + num_m_blocks_eff = (num_m_blocks + q_stage - 1) // q_stage + num_splits_m_blocks_eff = num_m_blocks_eff * num_splits_dynamic + if virtual_batch_idx is not None: + order = virtual_batch_idx.long() + stacked = torch.stack( + [num_m_blocks_eff[order], num_splits_m_blocks_eff[order]], dim=0 + ) + else: + stacked = torch.stack([num_m_blocks_eff, num_splits_m_blocks_eff], dim=0) + cum = torch.cumsum(stacked, dim=1, dtype=torch.int32) + padded = torch.nn.functional.pad(cum, (1, 0)) + cu_total_m_blocks = padded[0] + cu_total_splits_m_blocks = padded[1] + else: + cu_total_m_blocks = None + cu_total_splits_m_blocks = None else: num_m_blocks = None num_splits_dynamic = None virtual_batch_idx = None num_nheads_in_l2 = None tile_count_semaphore = None + cu_total_m_blocks = None + cu_total_splits_m_blocks = None is_static_persistent = ( not causal @@ -797,6 +822,8 @@ def _flash_attn_fwd( virtual_batch_idx is not None, num_nheads_in_l2 is not None, tile_count_semaphore is not None, + cu_total_m_blocks is not None, + cu_total_splits_m_blocks is not None, is_static_persistent, qv is not None, gather_kv_length, @@ -889,6 +916,14 @@ def _flash_attn_fwd( to_cute_tensor(num_nheads_in_l2, assumed_align=4, leading_dim=0) if num_nheads_in_l2 is not None else None ) + cu_total_m_blocks_tensor = ( + to_cute_tensor(cu_total_m_blocks, assumed_align=4, leading_dim=0) + if cu_total_m_blocks is not None else None + ) + cu_total_splits_m_blocks_tensor = ( + to_cute_tensor(cu_total_splits_m_blocks, assumed_align=4, leading_dim=0) + if cu_total_splits_m_blocks is not None else None + ) qv_tensor = to_cute_tensor(qv) if qv is not None else None gather_kv_indices_tensor = to_cute_tensor(gather_kv_indices) if gather_kv_indices is not None else None @@ -1090,6 +1125,8 @@ def _flash_attn_fwd( compile_args.extend([ virtual_batch_idx_tensor, num_nheads_in_l2_tensor, + cu_total_m_blocks_tensor, + cu_total_splits_m_blocks_tensor, max_seqlen_q, ]) compile_args.append(current_stream) @@ -1172,6 +1209,8 @@ def _flash_attn_fwd( call_args.extend([ virtual_batch_idx, num_nheads_in_l2, + cu_total_m_blocks, + cu_total_splits_m_blocks, max_seqlen_q, ]) _flash_attn_fwd.compile_cache[compile_key](*call_args) @@ -2811,6 +2850,8 @@ def get_scheduler_metadata( virtual_batch_idx_ptr=virtual_batch_idx, num_nheads_in_l2_ptr=num_nheads_in_l2, tile_count_semaphore=tile_count_semaphore, + cu_total_m_blocks_ptr=None, + cu_total_splits_m_blocks_ptr=None, ) diff --git a/flash_attn/cute/prepare_scheduler.py b/flash_attn/cute/prepare_scheduler.py index 796f147a5d8..52ef2e881a4 100644 --- a/flash_attn/cute/prepare_scheduler.py +++ b/flash_attn/cute/prepare_scheduler.py @@ -22,6 +22,11 @@ class SchedulerMetadataTensorsTorch(NamedTuple): num_nheads_in_l2_ptr: Optional[torch.Tensor] # tensor of shape (1) tile_count_semaphore: Optional[torch.Tensor] + # tensors of shape (batch + 1) + # cu_total_m_blocks[b+1] = sum_{i<=b} num_m_blocks[i] + # cu_total_splits_m_blocks[b+1] = sum_{i<=b} num_m_blocks[i] * num_splits_dynamic[i] + cu_total_m_blocks_ptr: Optional[torch.Tensor] = None + cu_total_splits_m_blocks_ptr: Optional[torch.Tensor] = None class FlashPrepareScheduler: diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index c232759b577..7090d0ee293 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -265,6 +265,8 @@ class TileSchedulerArguments(ParamsBase): num_m_blocks_ptr: Optional[cute.Tensor] = None virtual_batch_idx_ptr: Optional[cute.Tensor] = None num_nheads_in_l2_ptr: Optional[cute.Tensor] = None + cu_total_m_blocks_ptr: Optional[cute.Tensor] = None + cu_total_splits_m_blocks_ptr: Optional[cute.Tensor] = None tile_count_semaphore: Optional[cute.Pointer] = None persistent_cta_multiplier: cutlass.Constexpr[int] = 1 @@ -993,11 +995,33 @@ def _varlen_coord_map( cluster_shape_m = getattr(params, "cluster_shape_m", 1) num_nheads_in_l2_ptr = getattr(params, "num_nheads_in_l2_ptr", None) virtual_batch_idx_ptr = getattr(params, "virtual_batch_idx_ptr", None) + cu_total_m_blocks_ptr = getattr(params, "cu_total_m_blocks_ptr", None) + cu_total_splits_m_blocks_ptr = getattr(params, "cu_total_splits_m_blocks_ptr", None) + scheduling_mode = getattr(params, "scheduling_mode", None) + if const_expr(params.is_split_kv): + cu_hint_ptr = cu_total_splits_m_blocks_ptr + else: + cu_hint_ptr = cu_total_m_blocks_ptr + # Both SingleTileVarlen STATIC and CLC; not DynamicPersistent (where + # warp-scan's _bidb_start resumption already amortizes per-call cost). + use_cumsum_hint = const_expr( + cluster_shape_m == 1 + and cu_hint_ptr is not None + and (scheduling_mode == SchedulingMode.STATIC or scheduling_mode == SchedulingMode.CLC) + ) + if const_expr(use_cumsum_hint): + target = next_tile_idx // params.num_head + lo = utils.get_batch_from_cu_tensor(target, cu_hint_ptr) + group_size = Int32(cute.arch.WARP_SIZE - 1) + bidb_start = (lo // group_size) * group_size + group_start_tile = cu_hint_ptr[bidb_start] * params.num_head lane_idx = cute.arch.lane_idx() num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=bidb_start) num_splits = self._get_num_splits(lane_idx, bidb_start=bidb_start) - per_batch = num_m_blocks * num_splits if const_expr(params.is_split_kv) else num_m_blocks + per_batch = ( + num_m_blocks * num_splits if const_expr(params.is_split_kv) else num_m_blocks + ) cumulative = utils.warp_prefix_sum(per_batch, lane_idx) m_blocks_in_group = cute.arch.shuffle_sync(cumulative, cute.arch.WARP_SIZE - 1) group_end_tile = m_blocks_in_group * params.num_head + group_start_tile @@ -1012,14 +1036,15 @@ def _varlen_coord_map( num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) num_splits = self._get_num_splits(lane_idx, bidb_start=batch_idx) per_batch = ( - num_m_blocks * num_splits if const_expr(params.is_split_kv) else num_m_blocks + num_m_blocks * num_splits + if const_expr(params.is_split_kv) + else num_m_blocks ) cumulative = utils.warp_prefix_sum(per_batch, lane_idx) m_blocks_in_group = cute.arch.shuffle_sync(cumulative, cute.arch.WARP_SIZE - 1) group_end_tile += m_blocks_in_group * params.num_head is_valid = batch_idx < params.num_batch - block, head_idx, split_idx = Int32(0), Int32(0), Int32(0) if is_valid: group_start_tile = group_end_tile - m_blocks_in_group * params.num_head batch_idx_in_group = cute.arch.popc( @@ -1037,6 +1062,9 @@ def _varlen_coord_map( num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) if const_expr(params.is_split_kv): num_splits = cute.arch.shuffle_sync(num_splits, batch_idx_in_group) + + block, head_idx, split_idx = Int32(0), Int32(0), Int32(0) + if is_valid: mh_block = next_tile_idx - group_start_tile if const_expr(params.lpt or head_swizzle): @@ -1046,7 +1074,9 @@ def _varlen_coord_map( if const_expr(not params.is_split_kv) or num_splits == 1: if const_expr(num_nheads_in_l2_ptr is not None): if const_expr(virtual_batch_idx_ptr is not None): - nheads_in_l2 = Int32(num_nheads_in_l2_ptr[virtual_batch_idx_ptr[batch_idx]]) + nheads_in_l2 = Int32( + num_nheads_in_l2_ptr[virtual_batch_idx_ptr[batch_idx]] + ) else: nheads_in_l2 = Int32(num_nheads_in_l2_ptr[batch_idx]) else: @@ -1128,6 +1158,8 @@ class Params(ParamsBase): num_m_blocks_ptr: Optional[cute.Tensor] = None virtual_batch_idx_ptr: Optional[cute.Tensor] = None num_nheads_in_l2_ptr: Optional[cute.Tensor] = None + cu_total_m_blocks_ptr: Optional[cute.Tensor] = None + cu_total_splits_m_blocks_ptr: Optional[cute.Tensor] = None @staticmethod @cute.jit @@ -1178,6 +1210,8 @@ def create( num_m_blocks_ptr=args.num_m_blocks_ptr, virtual_batch_idx_ptr=args.virtual_batch_idx_ptr, num_nheads_in_l2_ptr=args.num_nheads_in_l2_ptr, + cu_total_m_blocks_ptr=args.cu_total_m_blocks_ptr, + cu_total_splits_m_blocks_ptr=args.cu_total_splits_m_blocks_ptr, ) def __init__( @@ -1372,6 +1406,8 @@ class Params(ParamsBase): num_m_blocks_ptr: Optional[cute.Tensor] = None virtual_batch_idx_ptr: Optional[cute.Tensor] = None num_nheads_in_l2_ptr: Optional[cute.Tensor] = None + cu_total_m_blocks_ptr: Optional[cute.Tensor] = None + cu_total_splits_m_blocks_ptr: Optional[cute.Tensor] = None tile_count_semaphore: Optional[cute.Pointer] = None persistent_cta_multiplier: cutlass.Constexpr[int] = 1 @@ -1403,6 +1439,8 @@ def create( num_m_blocks_ptr=args.num_m_blocks_ptr, virtual_batch_idx_ptr=args.virtual_batch_idx_ptr, num_nheads_in_l2_ptr=args.num_nheads_in_l2_ptr, + cu_total_m_blocks_ptr=args.cu_total_m_blocks_ptr, + cu_total_splits_m_blocks_ptr=args.cu_total_splits_m_blocks_ptr, tile_count_semaphore=args.tile_count_semaphore, persistent_cta_multiplier=args.persistent_cta_multiplier, ) From 83328ab11c9bdb4d540c2f182c143a92c2871a3c Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Mon, 25 May 2026 23:46:49 +0000 Subject: [PATCH 10/15] refactor tile scheduler for compositionality --- flash_attn/cute/flash_fwd_mla_sm100.py | 4 +- ...100_hd256_2cta_fmha_backward_dkdvkernel.py | 4 +- ...sm100_hd256_2cta_fmha_backward_dqkernel.py | 4 +- .../cute/sm100_hd256_2cta_fmha_forward.py | 4 +- flash_attn/cute/tile_scheduler.py | 583 ++++++++---------- 5 files changed, 255 insertions(+), 344 deletions(-) diff --git a/flash_attn/cute/flash_fwd_mla_sm100.py b/flash_attn/cute/flash_fwd_mla_sm100.py index 2987b4c0460..0ee73cf5f0f 100644 --- a/flash_attn/cute/flash_fwd_mla_sm100.py +++ b/flash_attn/cute/flash_fwd_mla_sm100.py @@ -26,7 +26,7 @@ import flash_attn.cute.blackwell_helpers as fa_sm100_utils from flash_attn.cute.softmax import SoftmaxSm100 from flash_attn.cute.tile_scheduler import ( - ClcState, + SchedulerState, SchedulingMode, TileSchedulerArguments, TileSchedulerProtocol, @@ -961,7 +961,7 @@ def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): clc_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps ) - clc = ClcState.create( + clc = SchedulerState.create_clc( hw_scheduler=ClcDynamicPersistentTileScheduler.create( self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), cute.arch.block_idx(), diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py index 885ae336f5f..629e95c42bc 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py @@ -24,7 +24,7 @@ from cutlass.utils import ClcDynamicPersistentTileScheduler from flash_attn.cute.tile_scheduler import ( - ClcState, + SchedulerState, SM100_TMEM_CAPACITY_COLUMNS, make_sm100_thread_cooperative_group as make_thread_cooperative_group, Sm100FmhaClcDynamicTileSchedulerParams as FmhaClcDynamicTileSchedulerParams, @@ -1062,7 +1062,7 @@ def dkdv_bwd( pipeline.Agent.Thread, num_clc_consumer_threads ) clc_response_ptr = storage.clc_response.data_ptr() - clc = ClcState.create( + clc = SchedulerState.create_clc( hw_scheduler=ClcDynamicPersistentTileScheduler.create( self.tile_sched_params.clc_hw_params(), cute.arch.block_idx(), diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py index 25d6a91de70..5b666083132 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py @@ -16,7 +16,7 @@ from cutlass.utils import ClcDynamicPersistentTileScheduler from flash_attn.cute.tile_scheduler import ( - ClcState, + SchedulerState, compute_sm100_fmha_grid as compute_grid, compute_sm100_fmha_grid_clc as compute_grid_clc, make_sm100_thread_cooperative_group as make_thread_cooperative_group, @@ -779,7 +779,7 @@ def kernel( pipeline.Agent.Thread, num_clc_consumer_threads ) clc_response_ptr = storage.clc_response.data_ptr() - clc = ClcState.create( + clc = SchedulerState.create_clc( hw_scheduler=ClcDynamicPersistentTileScheduler.create( self.tile_sched_params.clc_hw_params(), cute.arch.block_idx(), diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py index c2abfab3dfc..6cafc6da30e 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py @@ -15,7 +15,7 @@ from cutlass.utils import ClcDynamicPersistentTileScheduler from flash_attn.cute.tile_scheduler import ( - ClcState, + SchedulerState, compute_sm100_fmha_grid as compute_grid, compute_sm100_fmha_grid_clc as compute_grid_clc, make_sm100_thread_cooperative_group as make_thread_cooperative_group, @@ -699,7 +699,7 @@ def kernel( pipeline.Agent.Thread, num_clc_consumer_threads ) clc_response_ptr = storage.clc_response.data_ptr() - clc = ClcState.create( + clc = SchedulerState.create_clc( hw_scheduler=ClcDynamicPersistentTileScheduler.create( self.tile_sched_params.clc_hw_params(), cute.arch.block_idx(), diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 7090d0ee293..f9282d9c2ca 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -39,27 +39,18 @@ class SchedulingMode(IntEnum): @dataclass class SchedulerState(ParamsBase): - """Owns the runtime state shared by CLC and dynamic persistent tile schedulers. - - Main kernel constructs this state because it owns the - response buffer / work_info smem, mbarrier storage, and launch geometry - needed to initialize the backend (CLC hardware scheduler or atomic-counter - work_info region) and the async pipeline. Individual tile schedulers then - consume this state and map the returned work tiles into their own logical - `WorkTileInfo` coordinates. - - Tagged by `scheduling_mode`: - - CLC: `_hw_scheduler` is set; `prefetch_next_work` issues the HW query. - - DYNAMIC: `_work_info` is set; the scheduler class does its own - atomicAdd + warp-prefix-sum and writes via `write_work_info`. + """Runtime state shared by CLC and dynamic persistent tile schedulers: + the async pipeline and its producer/consumer states. + + Main kernels construct this via `create_clc` / `create_dynamic_persistent`, + which return the appropriate concrete state (`ClcSchedulerState` or + `DynamicPersistentSchedulerState`). Schedulers consume it through the + `ctx: SchedulerState | None` parameter on their `create(...)`. """ - scheduling_mode: cutlass.Constexpr[SchedulingMode] _pipeline: cutlass.pipeline.PipelineAsync _consumer_state: PipelineState _producer_state: PipelineState - _hw_scheduler: Optional[ClcDynamicPersistentTileScheduler] = None - _work_info: Optional[cute.Tensor] = None @staticmethod def create_clc( @@ -68,15 +59,8 @@ def create_clc( pipeline: PipelineClcFetchAsync, consumer_state: PipelineState, producer_state: PipelineState, - ) -> "SchedulerState": - return SchedulerState( - SchedulingMode.CLC, - pipeline, - consumer_state, - producer_state, - hw_scheduler, - None, - ) + ) -> "ClcSchedulerState": + return ClcSchedulerState(pipeline, consumer_state, producer_state, hw_scheduler) @staticmethod def create_dynamic_persistent( @@ -85,17 +69,28 @@ def create_dynamic_persistent( pipeline: cutlass.pipeline.PipelineAsync, consumer_state: PipelineState, producer_state: PipelineState, - ) -> "SchedulerState": - return SchedulerState( - SchedulingMode.DYNAMIC, - pipeline, - consumer_state, - producer_state, - None, - work_info, - ) + ) -> "DynamicPersistentSchedulerState": + return DynamicPersistentSchedulerState(pipeline, consumer_state, producer_state, work_info) + + def consumer_wait(self, *, loc=None, ip=None): + self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip) + + def consumer_release(self, *, loc=None, ip=None): + self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip) + self._consumer_state.advance(loc=loc, ip=ip) + + def advance_consumer_state(self, *, loc=None, ip=None): + self._consumer_state.advance(loc=loc, ip=ip) + + def producer_tail(self, *, loc=None, ip=None): + self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip) + + +@dataclass +class ClcSchedulerState(SchedulerState): + """CLC-backed: `prefetch_next_work` issues the HW query.""" - # ---- CLC-mode ---- + _hw_scheduler: ClcDynamicPersistentTileScheduler def initial_work_tile_info(self): return self._hw_scheduler.initial_work_tile_info() @@ -109,7 +104,13 @@ def prefetch_next_work(self, *, loc=None, ip=None): self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip) self._producer_state.advance(loc=loc, ip=ip) - # ---- Dynamic-persistent ---- + +@dataclass +class DynamicPersistentSchedulerState(SchedulerState): + """Semaphore-backed: the scheduler class drives atomicAdd + warp-prefix-sum + and writes the resolved work tile via `write_work_info`.""" + + _work_info: cute.Tensor def producer_acquire(self, *, loc=None, ip=None): self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip) @@ -126,69 +127,7 @@ def write_work_info(self, block: Int32, head: Int32, batch: Int32, split: Int32) self._work_info[2] = batch self._work_info[3] = split - # ---- Common ---- - - def consumer_wait(self, *, loc=None, ip=None): - self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip) - - def consumer_release(self, *, loc=None, ip=None): - self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip) - self._consumer_state.advance(loc=loc, ip=ip) - - def advance_consumer_state(self, *, loc=None, ip=None): - self._consumer_state.advance(loc=loc, ip=ip) - def producer_tail(self, *, loc=None, ip=None): - self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip) - - def __extract_mlir_values__(self): - ordered = [ - self.scheduling_mode, - self._pipeline, - self._consumer_state, - self._producer_state, - self._hw_scheduler, - self._work_info, - ] - values, self._values_pos = [], [] - for obj in ordered: - if obj is None or isinstance( - obj, (cutlass.Constexpr, int, bool, str, float, type(None)) - ): - self._values_pos.append(0) - continue - obj_values = cutlass.extract_mlir_values(obj) - values += obj_values - self._values_pos.append(len(obj_values)) - return values - - def __new_from_mlir_values__(self, values): - ordered = [ - self.scheduling_mode, - self._pipeline, - self._consumer_state, - self._producer_state, - self._hw_scheduler, - self._work_info, - ] - rebuilt = [] - for obj, n_items in zip(ordered, self._values_pos): - if n_items == 0: - rebuilt.append(obj) - else: - rebuilt.append(cutlass.new_from_mlir_values(obj, values[:n_items])) - values = values[n_items:] - return SchedulerState( - scheduling_mode=rebuilt[0], - _pipeline=rebuilt[1], - _consumer_state=rebuilt[2], - _producer_state=rebuilt[3], - _hw_scheduler=rebuilt[4], - _work_info=rebuilt[5], - ) - - -# Deprecated alias; remove after downstream call sites are updated. ClcState = SchedulerState @@ -898,68 +837,146 @@ def __new_from_mlir_values__(self, values): return self.__class__(*(tuple(obj_list)), loc=self._loc) -class VarlenSchedulerBase: - """Base for varlen tile schedulers (SingleTileVarlenScheduler, - DynamicPersistentVarlenScheduler). Owns the shared per-batch m-block lookup - and the warp-prefix-sum search-and-decode of the work tile. - - Subclasses must expose: - - self.params (ParamsBase) with the fields documented on each method. - - _get_num_splits(lane, bidb_start) -> Int32 +@dataclass +class VarlenDecoder(ParamsBase): + """Per-batch m-block lookup + warp-prefix-sum search-and-decode of the + varlen work tile. Composed into both `SingleTileVarlenScheduler.Params` + and `DynamicPersistentVarlenScheduler.Params`. + + `fold_splits_into_scan` controls whether the prefix-sum scan folds per-batch + `num_splits` into the per-batch tile count (DynamicPersistent) or always + counts only m_blocks (SingleTileVarlen, where splits are dispatched at the + grid level and resolved post-scan). """ + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + max_kvblock_in_l2: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + is_split_kv: cutlass.Constexpr[bool] = False + lpt: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + fold_splits_into_scan: cutlass.Constexpr[bool] = False + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + num_m_blocks_ptr: Optional[cute.Tensor] = None + num_splits_dynamic_ptr: Optional[cute.Tensor] = None + virtual_batch_idx_ptr: Optional[cute.Tensor] = None + num_nheads_in_l2_ptr: Optional[cute.Tensor] = None + cu_total_m_blocks_ptr: Optional[cute.Tensor] = None + cu_total_splits_m_blocks_ptr: Optional[cute.Tensor] = None + + @staticmethod @cute.jit - def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + def create( + args: TileSchedulerArguments, + *, + fold_splits_into_scan: bool, + head_swizzle: bool = False, + cluster_shape_m: int = 1, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> "VarlenDecoder": + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + # if backward, this is qdo block size + kv_block_size = (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + # if backward, add dqaccum block size to calculate swizzle + if head_swizzle: + kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] + max_kvblock_in_l2 = size_l2 // kv_block_size + return VarlenDecoder( + num_head=args.num_head, + num_batch=args.num_batch, + num_splits=args.num_splits, + max_kvblock_in_l2=max_kvblock_in_l2, + tile_shape_mn=args.tile_shape_mn, + qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + is_split_kv=args.is_split_kv, + lpt=args.lpt, + head_swizzle=head_swizzle, + cluster_shape_m=cluster_shape_m, + fold_splits_into_scan=fold_splits_into_scan, + scheduling_mode=scheduling_mode, + mCuSeqlensQ=args.mCuSeqlensQ, + mSeqUsedQ=args.mSeqUsedQ, + num_m_blocks_ptr=args.num_m_blocks_ptr, + num_splits_dynamic_ptr=args.num_splits_dynamic_ptr, + virtual_batch_idx_ptr=args.virtual_batch_idx_ptr, + num_nheads_in_l2_ptr=args.num_nheads_in_l2_ptr, + cu_total_m_blocks_ptr=args.cu_total_m_blocks_ptr, + cu_total_splits_m_blocks_ptr=args.cu_total_splits_m_blocks_ptr, + ) + + @cute.jit + def _num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: """Per-batch m-block count""" - params = self.params - cluster_shape_m = getattr(params, "cluster_shape_m", 1) batch_idx = lane + bidb_start - is_valid = batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 - if cutlass.const_expr(params.num_m_blocks_ptr is not None): + is_valid = batch_idx < self.num_batch and lane < cute.arch.WARP_SIZE - 1 + if cutlass.const_expr(self.num_m_blocks_ptr is not None): num_m_blocks_raw = Int32(0) if is_valid: - if cutlass.const_expr(params.virtual_batch_idx_ptr is not None): - real_batch_idx = params.virtual_batch_idx_ptr[batch_idx] + if cutlass.const_expr(self.virtual_batch_idx_ptr is not None): + real_batch_idx = self.virtual_batch_idx_ptr[batch_idx] else: real_batch_idx = batch_idx - num_m_blocks_raw = Int32(params.num_m_blocks_ptr[real_batch_idx]) - return cute.ceil_div(num_m_blocks_raw, cluster_shape_m) if is_valid else Int32(0) - if cutlass.const_expr(params.virtual_batch_idx_ptr is not None): + num_m_blocks_raw = Int32(self.num_m_blocks_ptr[real_batch_idx]) + return cute.ceil_div(num_m_blocks_raw, self.cluster_shape_m) if is_valid else Int32(0) + if cutlass.const_expr(self.virtual_batch_idx_ptr is not None): seqlen = Int32(0) if is_valid: - real_batch_idx = params.virtual_batch_idx_ptr[batch_idx] - seqlen = ( - params.mCuSeqlensQ[real_batch_idx + 1] - - params.mCuSeqlensQ[real_batch_idx] - ) - if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): - seqlen *= params.qhead_per_kvhead_packgqa + real_batch_idx = self.virtual_batch_idx_ptr[batch_idx] + seqlen = self.mCuSeqlensQ[real_batch_idx + 1] - self.mCuSeqlensQ[real_batch_idx] + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + seqlen *= self.qhead_per_kvhead_packgqa return ( - cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), cluster_shape_m) + cute.ceil_div(cute.ceil_div(seqlen, self.tile_shape_mn[0]), self.cluster_shape_m) if is_valid else Int32(0) ) - if cutlass.const_expr(params.mSeqUsedQ is not None): + if cutlass.const_expr(self.mSeqUsedQ is not None): seqlen = Int32(0) - if batch_idx < params.num_batch: - seqlen = params.mSeqUsedQ[batch_idx] + if batch_idx < self.num_batch: + seqlen = self.mSeqUsedQ[batch_idx] else: - assert params.mCuSeqlensQ is not None + assert self.mCuSeqlensQ is not None cur_cu_seqlen = Int32(0) - if batch_idx <= params.num_batch: - cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] + if batch_idx <= self.num_batch: + cur_cu_seqlen = self.mCuSeqlensQ[batch_idx] next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) seqlen = next_cu_seqlen - cur_cu_seqlen - if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): - seqlen *= params.qhead_per_kvhead_packgqa + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + seqlen *= self.qhead_per_kvhead_packgqa return ( - cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), cluster_shape_m) + cute.ceil_div(cute.ceil_div(seqlen, self.tile_shape_mn[0]), self.cluster_shape_m) if is_valid else Int32(0) ) @cute.jit - def _varlen_coord_map( + def _num_splits(self, lane: Int32, bidb_start: Int32) -> Int32: + if cutlass.const_expr(not self.fold_splits_into_scan): + return Int32(1) + batch_idx = lane + bidb_start + is_valid = batch_idx < self.num_batch and lane < cute.arch.WARP_SIZE - 1 + if cutlass.const_expr(not self.is_split_kv): + return Int32(1) + elif cutlass.const_expr(self.num_splits_dynamic_ptr is not None): + num_splits = Int32(0) + if is_valid: + if cutlass.const_expr(self.virtual_batch_idx_ptr is not None): + batch_idx = self.virtual_batch_idx_ptr[batch_idx] + num_splits = self.num_splits_dynamic_ptr[batch_idx] + return num_splits + else: + return Int32(0) if not is_valid else self.num_splits + + @cute.jit + def decode( self, next_tile_idx: Int32, bidb_start: Int32, @@ -975,81 +992,58 @@ def _varlen_coord_map( - num_splits - group_start_tile - is_valid - - self.params must expose: - - num_head: Int32 - - num_batch: Int32 - - tile_shape_mn: Constexpr[tuple] - - qhead_per_kvhead_packgqa: Constexpr[int] - - max_kvblock_in_l2: Int32 - - is_split_kv: Constexpr[bool] - - lpt: Constexpr[bool] - Optionally: - - head_swizzle: Constexpr[bool] | None - - cluster_shape_m: Constexpr[int] | None - - num_nheads_in_l2_ptr: cute.Tensor | None - - virtual_batch_idx_ptr: cute.Tensor | None """ - params = self.params - head_swizzle = getattr(params, "head_swizzle", False) - cluster_shape_m = getattr(params, "cluster_shape_m", 1) - num_nheads_in_l2_ptr = getattr(params, "num_nheads_in_l2_ptr", None) - virtual_batch_idx_ptr = getattr(params, "virtual_batch_idx_ptr", None) - cu_total_m_blocks_ptr = getattr(params, "cu_total_m_blocks_ptr", None) - cu_total_splits_m_blocks_ptr = getattr(params, "cu_total_splits_m_blocks_ptr", None) - scheduling_mode = getattr(params, "scheduling_mode", None) - if const_expr(params.is_split_kv): - cu_hint_ptr = cu_total_splits_m_blocks_ptr + if const_expr(self.is_split_kv): + cu_hint_ptr = self.cu_total_splits_m_blocks_ptr else: - cu_hint_ptr = cu_total_m_blocks_ptr + cu_hint_ptr = self.cu_total_m_blocks_ptr # Both SingleTileVarlen STATIC and CLC; not DynamicPersistent (where # warp-scan's _bidb_start resumption already amortizes per-call cost). use_cumsum_hint = const_expr( - cluster_shape_m == 1 + self.cluster_shape_m == 1 and cu_hint_ptr is not None - and (scheduling_mode == SchedulingMode.STATIC or scheduling_mode == SchedulingMode.CLC) + and ( + self.scheduling_mode == SchedulingMode.STATIC + or self.scheduling_mode == SchedulingMode.CLC + ) ) if const_expr(use_cumsum_hint): - target = next_tile_idx // params.num_head + target = next_tile_idx // self.num_head lo = utils.get_batch_from_cu_tensor(target, cu_hint_ptr) group_size = Int32(cute.arch.WARP_SIZE - 1) bidb_start = (lo // group_size) * group_size - group_start_tile = cu_hint_ptr[bidb_start] * params.num_head + group_start_tile = cu_hint_ptr[bidb_start] * self.num_head lane_idx = cute.arch.lane_idx() - num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=bidb_start) - num_splits = self._get_num_splits(lane_idx, bidb_start=bidb_start) - per_batch = ( - num_m_blocks * num_splits if const_expr(params.is_split_kv) else num_m_blocks - ) + num_m_blocks = self._num_m_blocks(lane_idx, bidb_start=bidb_start) + num_splits = self._num_splits(lane_idx, bidb_start=bidb_start) + per_batch = num_m_blocks * num_splits if const_expr(self.is_split_kv) else num_m_blocks cumulative = utils.warp_prefix_sum(per_batch, lane_idx) m_blocks_in_group = cute.arch.shuffle_sync(cumulative, cute.arch.WARP_SIZE - 1) - group_end_tile = m_blocks_in_group * params.num_head + group_start_tile + group_end_tile = m_blocks_in_group * self.num_head + group_start_tile batch_idx = bidb_start while group_end_tile <= next_tile_idx: batch_idx += cute.arch.WARP_SIZE - 1 - if batch_idx >= params.num_batch: - batch_idx = Int32(params.num_batch) + if batch_idx >= self.num_batch: + batch_idx = Int32(self.num_batch) group_end_tile = next_tile_idx + 1 else: - num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) - num_splits = self._get_num_splits(lane_idx, bidb_start=batch_idx) + num_m_blocks = self._num_m_blocks(lane_idx, bidb_start=batch_idx) + num_splits = self._num_splits(lane_idx, bidb_start=batch_idx) per_batch = ( - num_m_blocks * num_splits - if const_expr(params.is_split_kv) - else num_m_blocks + num_m_blocks * num_splits if const_expr(self.is_split_kv) else num_m_blocks ) cumulative = utils.warp_prefix_sum(per_batch, lane_idx) m_blocks_in_group = cute.arch.shuffle_sync(cumulative, cute.arch.WARP_SIZE - 1) - group_end_tile += m_blocks_in_group * params.num_head + group_end_tile += m_blocks_in_group * self.num_head - is_valid = batch_idx < params.num_batch + is_valid = batch_idx < self.num_batch if is_valid: - group_start_tile = group_end_tile - m_blocks_in_group * params.num_head + group_start_tile = group_end_tile - m_blocks_in_group * self.num_head batch_idx_in_group = cute.arch.popc( cute.arch.vote_ballot_sync( - group_start_tile + cumulative * params.num_head <= next_tile_idx + group_start_tile + cumulative * self.num_head <= next_tile_idx ) ) batch_idx += batch_idx_in_group @@ -1058,58 +1052,58 @@ def _varlen_coord_map( if batch_idx_in_group == 0 else cute.arch.shuffle_sync(cumulative, batch_idx_in_group - 1) ) - group_start_tile += num_m_blocks_prev_lane * params.num_head + group_start_tile += num_m_blocks_prev_lane * self.num_head num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) - if const_expr(params.is_split_kv): + if const_expr(self.is_split_kv): num_splits = cute.arch.shuffle_sync(num_splits, batch_idx_in_group) block, head_idx, split_idx = Int32(0), Int32(0), Int32(0) if is_valid: mh_block = next_tile_idx - group_start_tile - if const_expr(params.lpt or head_swizzle): + if const_expr(self.lpt or self.head_swizzle): # This is a version of the SingleTileLPTScheduler, complicated by the fact that # the seqlen can vary per batch. # TODO: is there any case where num_m_blocks is 0? - if const_expr(not params.is_split_kv) or num_splits == 1: - if const_expr(num_nheads_in_l2_ptr is not None): - if const_expr(virtual_batch_idx_ptr is not None): + if const_expr(not self.is_split_kv) or num_splits == 1: + if const_expr(self.num_nheads_in_l2_ptr is not None): + if const_expr(self.virtual_batch_idx_ptr is not None): nheads_in_l2 = Int32( - num_nheads_in_l2_ptr[virtual_batch_idx_ptr[batch_idx]] + self.num_nheads_in_l2_ptr[self.virtual_batch_idx_ptr[batch_idx]] ) else: - nheads_in_l2 = Int32(num_nheads_in_l2_ptr[batch_idx]) + nheads_in_l2 = Int32(self.num_nheads_in_l2_ptr[batch_idx]) else: # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here num_n_blocks = ( num_m_blocks - * params.tile_shape_mn[0] - * cluster_shape_m - // params.qhead_per_kvhead_packgqa - // params.tile_shape_mn[1] + * self.tile_shape_mn[0] + * self.cluster_shape_m + // self.qhead_per_kvhead_packgqa + // self.tile_shape_mn[1] ) # Seems faster to have nheads_in_l2 be a power of 2 nheads_in_l2 = ( 16 - if num_n_blocks * 16 <= params.max_kvblock_in_l2 + if num_n_blocks * 16 <= self.max_kvblock_in_l2 else ( 8 - if num_n_blocks * 8 <= params.max_kvblock_in_l2 + if num_n_blocks * 8 <= self.max_kvblock_in_l2 else ( 4 - if num_n_blocks * 4 <= params.max_kvblock_in_l2 - else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) + if num_n_blocks * 4 <= self.max_kvblock_in_l2 + else (2 if num_n_blocks * 2 <= self.max_kvblock_in_l2 else 1) ) ) ) - nheads_in_l2 = min(nheads_in_l2, params.num_head) + nheads_in_l2 = min(nheads_in_l2, self.num_head) mh_in_l2 = nheads_in_l2 * num_m_blocks section_idx = mh_block // mh_in_l2 l2_mod = mh_block - section_idx * mh_in_l2 nheads_in_this_section = ( nheads_in_l2 - if nheads_in_l2 * (section_idx + 1) <= params.num_head - else params.num_head - section_idx * nheads_in_l2 + if nheads_in_l2 * (section_idx + 1) <= self.num_head + else self.num_head - section_idx * nheads_in_l2 ) block = l2_mod // nheads_in_this_section head_idx_residual = l2_mod - block * nheads_in_this_section @@ -1119,47 +1113,30 @@ def _varlen_coord_map( block = mh_block - head_split_idx * num_m_blocks head_idx = head_split_idx // num_splits split_idx = head_split_idx - head_idx * num_splits - if const_expr(params.lpt): + if const_expr(self.lpt): block = num_m_blocks - 1 - block else: head_split_idx = mh_block // num_m_blocks block = mh_block - head_split_idx * num_m_blocks - if const_expr(params.is_split_kv): + if const_expr(self.is_split_kv): head_idx = head_split_idx // num_splits split_idx = head_split_idx - head_idx * num_splits else: head_idx = head_split_idx - if const_expr(cluster_shape_m > 1): + if const_expr(self.cluster_shape_m > 1): bidx_in_cluster = cute.arch.block_in_cluster_idx() - block = block * cluster_shape_m + bidx_in_cluster[0] + block = block * self.cluster_shape_m + bidx_in_cluster[0] return block, head_idx, batch_idx, split_idx, num_splits, group_start_tile, is_valid -class SingleTileVarlenScheduler(VarlenSchedulerBase): +class SingleTileVarlenScheduler: @dataclass class Params(ParamsBase): - num_head: Int32 - num_batch: Int32 total_q: Int32 - num_splits: Int32 - max_kvblock_in_l2: Int32 - tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] - mCuSeqlensQ: Optional[cute.Tensor] = None - mSeqUsedQ: Optional[cute.Tensor] = None - qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 - lpt: cutlass.Constexpr[bool] = False - is_split_kv: cutlass.Constexpr[bool] = False - head_swizzle: cutlass.Constexpr[bool] = False - cluster_shape_m: cutlass.Constexpr[int] = 1 - scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC - num_splits_dynamic_ptr: Optional[cute.Tensor] = None - num_m_blocks_ptr: Optional[cute.Tensor] = None - virtual_batch_idx_ptr: Optional[cute.Tensor] = None - num_nheads_in_l2_ptr: Optional[cute.Tensor] = None - cu_total_m_blocks_ptr: Optional[cute.Tensor] = None - cu_total_splits_m_blocks_ptr: Optional[cute.Tensor] = None + scheduling_mode: cutlass.Constexpr[SchedulingMode] + decoder: VarlenDecoder @staticmethod @cute.jit @@ -1173,15 +1150,6 @@ def create( assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( f"Only STATIC and CLC are supported, got {scheduling_mode!r}" ) - size_l2 = 50 * 1024 * 1024 # 50 MB for K & V - # if backward, this is qdo block size - kv_block_size = ( - (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] - ) - # if backward, add dqaccum block size to calculate swizzle - if args.head_swizzle: - kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] - max_kvblock_in_l2 = size_l2 // kv_block_size assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" ) @@ -1191,27 +1159,19 @@ def create( assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, ( "Varlen CLC currently requires cluster_shape_mn[0] == 1" ) - return SingleTileVarlenScheduler.Params( - num_head=args.num_head, - num_batch=args.num_batch, - total_q=args.total_q, - num_splits=args.num_splits, - max_kvblock_in_l2=max_kvblock_in_l2, - tile_shape_mn=args.tile_shape_mn, - mCuSeqlensQ=args.mCuSeqlensQ, - mSeqUsedQ=args.mSeqUsedQ, - qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, - lpt=args.lpt, - is_split_kv=args.is_split_kv, + decoder = VarlenDecoder.create( + args, + fold_splits_into_scan=False, head_swizzle=args.head_swizzle, cluster_shape_m=args.cluster_shape_mn[0], scheduling_mode=scheduling_mode, - num_splits_dynamic_ptr=args.num_splits_dynamic_ptr, - num_m_blocks_ptr=args.num_m_blocks_ptr, - virtual_batch_idx_ptr=args.virtual_batch_idx_ptr, - num_nheads_in_l2_ptr=args.num_nheads_in_l2_ptr, - cu_total_m_blocks_ptr=args.cu_total_m_blocks_ptr, - cu_total_splits_m_blocks_ptr=args.cu_total_splits_m_blocks_ptr, + loc=loc, + ip=ip, + ) + return SingleTileVarlenScheduler.Params( + total_q=args.total_q, + scheduling_mode=scheduling_mode, + decoder=decoder, ) def __init__( @@ -1260,7 +1220,7 @@ def create( if const_expr(params.scheduling_mode == SchedulingMode.CLC): block_idx = cute.arch.block_idx() split_idx = Int32(0) - if const_expr(params.is_split_kv): + if const_expr(params.decoder.is_split_kv): split_idx = block_idx[1] return SingleTileVarlenScheduler( params, @@ -1281,35 +1241,29 @@ def get_grid_shape( loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: + d = params.decoder total_blocks_max = ( - params.total_q - + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1) - ) // params.tile_shape_mn[0] + params.total_q + d.num_batch * (d.cluster_shape_m * d.tile_shape_mn[0] - 1) + ) // d.tile_shape_mn[0] # Round down to nearest multiple of cluster since odd excess is always padding. - total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m - return (total_blocks_max * params.num_head, params.num_splits, Int32(1)) - - @cute.jit - def _get_num_splits(self, lane: Int32, bidb_start: Int32) -> Int32: - return Int32(1) + total_blocks_max = total_blocks_max // d.cluster_shape_m * d.cluster_shape_m + return (total_blocks_max * d.num_head, d.num_splits, Int32(1)) @cute.jit def _decode_work_tile(self) -> WorkTileInfo: """Map self._tile_idx to (block, head, batch, split) via warp-level prefix sums.""" - params = self.params - next_tile_idx = self._tile_idx // params.cluster_shape_m - block, head_idx, batch_idx, _, _, _, is_valid = self._varlen_coord_map( - next_tile_idx, Int32(0), Int32(0) - ) + d = self.params.decoder + next_tile_idx = self._tile_idx // d.cluster_shape_m + block, head_idx, batch_idx, _, _, _, is_valid = d.decode(next_tile_idx, Int32(0), Int32(0)) is_valid = is_valid and self._is_first_block - split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) - if const_expr(params.virtual_batch_idx_ptr is not None): + split_idx = self._split_idx if const_expr(d.is_split_kv) else Int32(0) + if const_expr(d.virtual_batch_idx_ptr is not None): if is_valid: - batch_idx = params.virtual_batch_idx_ptr[batch_idx] + batch_idx = d.virtual_batch_idx_ptr[batch_idx] # Pack dynamic per-batch num_splits into high 16 bits of split_idx - if const_expr(params.is_split_kv and params.num_splits_dynamic_ptr is not None): + if const_expr(d.is_split_kv and d.num_splits_dynamic_ptr is not None): if is_valid: - num_splits = Int32(params.num_splits_dynamic_ptr[batch_idx]) + num_splits = Int32(d.num_splits_dynamic_ptr[batch_idx]) split_idx = split_idx | (num_splits << 16) return WorkTileInfo( (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(split_idx)), @@ -1328,7 +1282,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: new_split_idx = Int32(0) if clc_work.is_valid_tile: new_tile_idx = clc_work.tile_idx[0] - if const_expr(self.params.is_split_kv): + if const_expr(self.params.decoder.is_split_kv): new_split_idx = clc_work.tile_idx[1] self._tile_idx = new_tile_idx self._split_idx = new_split_idx @@ -1343,7 +1297,7 @@ def initial_work_tile_info(self, *, loc=None, ip=None): new_split_idx = Int32(0) if clc_work.is_valid_tile: new_tile_idx = clc_work.tile_idx[0] - if const_expr(self.params.is_split_kv): + if const_expr(self.params.decoder.is_split_kv): new_split_idx = clc_work.tile_idx[1] self._tile_idx = new_tile_idx self._split_idx = new_split_idx @@ -1388,26 +1342,11 @@ def __new_from_mlir_values__(self, values): return self.__class__(*obj_list, loc=self._loc) -class DynamicPersistentVarlenScheduler(VarlenSchedulerBase): +class DynamicPersistentVarlenScheduler: @dataclass class Params(ParamsBase): - num_head: Int32 - num_batch: Int32 total_q: Int32 - num_splits: Int32 - max_kvblock_in_l2: Int32 - tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] - mCuSeqlensQ: Optional[cute.Tensor] = None - mSeqUsedQ: Optional[cute.Tensor] = None - qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 - lpt: cutlass.Constexpr[bool] = False - is_split_kv: cutlass.Constexpr[bool] = False - num_splits_dynamic_ptr: Optional[cute.Tensor] = None - num_m_blocks_ptr: Optional[cute.Tensor] = None - virtual_batch_idx_ptr: Optional[cute.Tensor] = None - num_nheads_in_l2_ptr: Optional[cute.Tensor] = None - cu_total_m_blocks_ptr: Optional[cute.Tensor] = None - cu_total_splits_m_blocks_ptr: Optional[cute.Tensor] = None + decoder: VarlenDecoder tile_count_semaphore: Optional[cute.Pointer] = None persistent_cta_multiplier: cutlass.Constexpr[int] = 1 @@ -1416,31 +1355,19 @@ class Params(ParamsBase): def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "DynamicPersistentVarlenScheduler.Params": - size_l2 = 50 * 1024 * 1024 # 50 MB for K & V - max_kvblock_in_l2 = size_l2 // ( - (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] - ) assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" ) + decoder = VarlenDecoder.create( + args, + fold_splits_into_scan=True, + scheduling_mode=SchedulingMode.DYNAMIC, + loc=loc, + ip=ip, + ) return DynamicPersistentVarlenScheduler.Params( - num_head=args.num_head, - num_batch=args.num_batch, total_q=args.total_q, - num_splits=args.num_splits, - max_kvblock_in_l2=max_kvblock_in_l2, - tile_shape_mn=args.tile_shape_mn, - mCuSeqlensQ=args.mCuSeqlensQ, - mSeqUsedQ=args.mSeqUsedQ, - qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, - lpt=args.lpt, - is_split_kv=args.is_split_kv, - num_splits_dynamic_ptr=args.num_splits_dynamic_ptr, - num_m_blocks_ptr=args.num_m_blocks_ptr, - virtual_batch_idx_ptr=args.virtual_batch_idx_ptr, - num_nheads_in_l2_ptr=args.num_nheads_in_l2_ptr, - cu_total_m_blocks_ptr=args.cu_total_m_blocks_ptr, - cu_total_splits_m_blocks_ptr=args.cu_total_splits_m_blocks_ptr, + decoder=decoder, tile_count_semaphore=args.tile_count_semaphore, persistent_cta_multiplier=args.persistent_cta_multiplier, ) @@ -1494,33 +1421,17 @@ def get_grid_shape( loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: + d = params.decoder total_blocks_max = ( - params.total_q + params.num_batch * (params.tile_shape_mn[0] - 1) - ) // params.tile_shape_mn[0] - total_blocks = total_blocks_max * params.num_head * params.num_splits + params.total_q + d.num_batch * (d.tile_shape_mn[0] - 1) + ) // d.tile_shape_mn[0] + total_blocks = total_blocks_max * d.num_head * d.num_splits hardware_info = HardwareInfo() sm_count = ( hardware_info.get_device_multiprocessor_count() * params.persistent_cta_multiplier ) return (cutlass.min(sm_count, total_blocks), Int32(1), Int32(1)) - @cute.jit - def _get_num_splits(self, lane: Int32, bidb_start: Int32) -> Int32: - params = self.params - batch_idx = lane + bidb_start - is_valid = batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 - if cutlass.const_expr(not params.is_split_kv): - return Int32(1) - elif cutlass.const_expr(params.num_splits_dynamic_ptr is not None): - num_splits = Int32(0) - if is_valid: - if cutlass.const_expr(params.virtual_batch_idx_ptr is not None): - batch_idx = params.virtual_batch_idx_ptr[batch_idx] - num_splits = params.num_splits_dynamic_ptr[batch_idx] - return num_splits - else: - return Int32(0) if not is_valid else params.num_splits - @cute.jit def get_current_work( self, @@ -1531,16 +1442,16 @@ def get_current_work( loc=None, ip=None, ) -> WorkTileInfo: - params = self.params - block, head_idx, batch_idx, split_idx, num_splits, group_start_tile, is_valid = ( - self._varlen_coord_map(next_tile_idx, bidb_start, group_start_tile) + d = self.params.decoder + block, head_idx, batch_idx, split_idx, num_splits, group_start_tile, is_valid = d.decode( + next_tile_idx, bidb_start, group_start_tile ) - if const_expr(params.is_split_kv and params.num_splits_dynamic_ptr is not None): + if const_expr(d.is_split_kv and d.num_splits_dynamic_ptr is not None): if is_valid: split_idx = split_idx | (num_splits << 16) - if const_expr(params.virtual_batch_idx_ptr is not None): + if const_expr(d.virtual_batch_idx_ptr is not None): if is_valid: - batch_idx = params.virtual_batch_idx_ptr[batch_idx] + batch_idx = d.virtual_batch_idx_ptr[batch_idx] return ( WorkTileInfo( (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(split_idx)), @@ -1581,7 +1492,7 @@ def advance_to_next_work(self, *, loc=None, ip=None) -> WorkTileInfo: head_idx = ctx._work_info[1] batch_idx = ctx._work_info[2] split_idx = ctx._work_info[3] - is_valid = batch_idx < self.params.num_batch + is_valid = batch_idx < self.params.decoder.num_batch work_info = WorkTileInfo((block, head_idx, batch_idx, split_idx), is_valid) ctx.consumer_release() return work_info From 5a695710d382397feaa6a1a5a9113d05f16abb32 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Tue, 26 May 2026 22:00:13 +0000 Subject: [PATCH 11/15] work PR 2520 into interface and kernels --- benchmarks/benchmark_varlen_sched.py | 22 +- flash_attn/cute/flash_bwd.py | 2 + flash_attn/cute/flash_bwd_postprocess.py | 2 + flash_attn/cute/flash_bwd_preprocess.py | 2 + flash_attn/cute/flash_bwd_sm100.py | 2 + flash_attn/cute/flash_bwd_sm90.py | 2 + flash_attn/cute/flash_fwd.py | 4 + flash_attn/cute/flash_fwd_sm100.py | 8 +- flash_attn/cute/flash_fwd_sm90.py | 5 + flash_attn/cute/interface.py | 433 ++++++++++++++--------- flash_attn/cute/prepare_scheduler.py | 4 +- tests/cute/test_flash_attn.py | 168 +++++++++ 12 files changed, 484 insertions(+), 170 deletions(-) diff --git a/benchmarks/benchmark_varlen_sched.py b/benchmarks/benchmark_varlen_sched.py index d915123cdf6..a055fcec58a 100644 --- a/benchmarks/benchmark_varlen_sched.py +++ b/benchmarks/benchmark_varlen_sched.py @@ -181,6 +181,13 @@ def build_ctx( def _make_meta(ctx): + tile_m = 128 + qhead_per_kvhead = ctx["nheads"] // ctx["nheads_kv"] + arch = torch.cuda.get_device_capability()[0] + if arch == 10 and ctx["max_seqlen_q"] * qhead_per_kvhead > tile_m: + q_stage = 2 + else: + q_stage = 1 return get_scheduler_metadata( num_batch=ctx["batch"], max_seqlen_q=ctx["max_seqlen_q"], @@ -189,12 +196,13 @@ def _make_meta(ctx): nheads_kv=ctx["nheads_kv"], headdim=ctx["headdim"], num_splits=ctx["num_splits"], - tile_m=128, + tile_m=tile_m, tile_n=128, causal=ctx["causal"], pack_gqa=ctx["pack_gqa"], cu_seqlens_q=ctx["cu_q"], cu_seqlens_k=ctx["cu_k"], + q_stage=q_stage, ) @@ -322,8 +330,12 @@ def parse_args(): p.add_argument("--headdim", type=int, default=128) p.add_argument("--nheads", type=int, default=16) p.add_argument("--nheads-kv", type=int, default=2) - p.add_argument("--pack-gqa", action="store_true", default=True, - help="Force pack_gqa=True (default). --no-pack-gqa to disable.") + p.add_argument( + "--pack-gqa", + action="store_true", + default=True, + help="Force pack_gqa=True (default). --no-pack-gqa to disable.", + ) p.add_argument("--no-pack-gqa", dest="pack_gqa", action="store_false") p.add_argument("--seeds", type=int, default=3) p.add_argument("--warmup", type=int, default=2) @@ -434,7 +446,9 @@ def main(): eff = sq * sk - sq * (sq - 1) // 2 else: # clamp to non-negative for queries near 0 - first_visible_q = -shift # smallest q with sk - sq + q + 1 > 0 is q = sq - sk + first_visible_q = ( + -shift + ) # smallest q with sk - sq + q + 1 > 0 is q = sq - sk visible = sq - first_visible_q eff = visible * sk - visible * (visible - 1) // 2 eff = max(0, eff) diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 81c8ac68bd9..e5041a75d32 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -389,6 +389,7 @@ def __call__( mdV_semaphore: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): @@ -429,6 +430,7 @@ def __call__( qhead_per_kvhead_packgqa=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, mCuSeqlensQ=mCuSeqlensK, mSeqUsedQ=mSeqUsedK, + cu_total_m_blocks_ptr=mCuTotalMBlocks, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 76c856221c5..f971fe240b4 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -215,6 +215,7 @@ def __call__( scale: cutlass.Float32, mCuSeqlensQ: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], + mCuTotalMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): @@ -258,6 +259,7 @@ def __call__( tile_shape_mn=(self.tile_m, 1), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, + cu_total_m_blocks_ptr=mCuTotalMBlocks, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index 8142def5ebb..d3b9b5d2974 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -136,6 +136,7 @@ def __call__( mCuSeqlensQ: Optional[cute.Tensor], # (batch + 1,) mSeqUsedQ: Optional[cute.Tensor], # (batch,) mdLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q) + mCuTotalMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): @@ -193,6 +194,7 @@ def __call__( tile_shape_mn=(self.tile_m, 1), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, + cu_total_m_blocks_ptr=mCuTotalMBlocks, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 11db2dab563..63c1f6b7cdd 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -466,6 +466,7 @@ def __call__( aux_tensors: Optional[list] = None, # Block-sparse tensors (Q direction - for iterating m_blocks per n_block): blocksparse_tensors: Optional[BlockSparseTensors] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): @@ -732,6 +733,7 @@ def __call__( qhead_per_kvhead_packgqa=1, # pack_gqa disabled for bwd element_size=self.k_dtype.width // 8, is_persistent=self.is_persistent, # persistent mode not tested + cu_total_m_blocks_ptr=mCuTotalMBlocks, lpt=self.spt, head_swizzle=self.deterministic, ) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 2e420924e92..e8649802035 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -357,6 +357,7 @@ def __call__( mdV_semaphore: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): @@ -536,6 +537,7 @@ def _qkv_transpose(t): is_persistent=False, lpt=self.spt, head_swizzle=self.deterministic, + cu_total_m_blocks_ptr=mCuTotalMBlocks, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index d1a43cfd247..36f61efb550 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -631,6 +631,8 @@ def __call__( learnable_sink: Optional[cute.Tensor] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors=None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, + mCuTotalSplitsMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): @@ -692,6 +694,8 @@ def __call__( qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, + cu_total_m_blocks_ptr=mCuTotalMBlocks, + cu_total_splits_m_blocks_ptr=mCuTotalSplitsMBlocks, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 3da16ece4fa..08a4a319972 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -398,8 +398,8 @@ def __call__( tile_count_semaphore: Optional[cute.Tensor] = None, virtual_batch_idx_ptr: Optional[cute.Tensor] = None, num_nheads_in_l2_ptr: Optional[cute.Tensor] = None, - cu_total_m_blocks_ptr: Optional[cute.Tensor] = None, - cu_total_splits_m_blocks_ptr: Optional[cute.Tensor] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, + mCuTotalSplitsMBlocks: Optional[cute.Tensor] = None, max_seqlen_q: Int32 | int | None = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, @@ -688,8 +688,8 @@ def __call__( num_splits_dynamic_ptr=num_splits_dynamic_ptr, virtual_batch_idx_ptr=virtual_batch_idx_ptr, num_nheads_in_l2_ptr=num_nheads_in_l2_ptr, - cu_total_m_blocks_ptr=cu_total_m_blocks_ptr, - cu_total_splits_m_blocks_ptr=cu_total_splits_m_blocks_ptr, + cu_total_m_blocks_ptr=mCuTotalMBlocks, + cu_total_splits_m_blocks_ptr=mCuTotalSplitsMBlocks, tile_count_semaphore=tile_count_semaphore.iterator if tile_count_semaphore is not None else None, ) tile_sched_params = TileScheduler.to_underlying_arguments( diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index 23f92181166..e38361582c1 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -172,6 +172,8 @@ def __call__( learnable_sink: Optional[cute.Tensor] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, + mCuTotalSplitsMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): @@ -312,6 +314,7 @@ def __call__( (self.tile_m, self.tile_hdimv), # No mcast ) if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + # TODO: dispatch to DynamicPersistentVarlenScheduler when appropriate TileScheduler = SingleTileVarlenScheduler else: TileScheduler = ( @@ -341,6 +344,8 @@ def __call__( element_size=self.dtype.width // 8, is_persistent=False, lpt=self.is_causal or self.is_local, + cu_total_m_blocks_ptr=mCuTotalMBlocks, + cu_total_splits_m_blocks_ptr=mCuTotalSplitsMBlocks, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 52df5f8f1b4..1e3b2428217 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -59,6 +59,8 @@ get_block_sparse_broadcast_pattern, ) +BIN_BATCH_SEARCH_THRESH = 512 # SingleTileVarlenScheduler uses binary search to find batch above this + def _parse_arch_str(arch_str): """Parse arch string (e.g. 'sm_80', 'sm_90a', '80', '100') to int (e.g. 80, 90, 100).""" import re @@ -290,6 +292,42 @@ def _resolve_causal_local_window(causal, window_size_left, window_size_right, ma local = False return causal, local, window_size_left, window_size_right + +def _compute_tile_cumsum( + *, + num_m_blocks: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + seqused: Optional[torch.Tensor] = None, + num_splits_dynamic: Optional[torch.Tensor] = None, + virtual_batch_idx: Optional[torch.Tensor] = None, + tile_size: int = 1, + q_stage: int = 1, + qhead_per_kvhead: int = 1, + pack_gqa: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """(cu_total_m_blocks, cu_total_splits_m_blocks), int32, (num_batch + 1,). + + cu_total_splits_m_blocks is None when num_splits_dynamic is None. + """ + if num_m_blocks is None: + seqlens = seqused if seqused is not None else (cu_seqlens[1:] - cu_seqlens[:-1]) + if pack_gqa and qhead_per_kvhead > 1: + seqlens = seqlens * qhead_per_kvhead + num_m_blocks = (seqlens + tile_size - 1) // tile_size + num_m_blocks_eff = (num_m_blocks + q_stage - 1) // q_stage + order = virtual_batch_idx.long() if virtual_batch_idx is not None else None + if order is not None: + num_m_blocks_eff = num_m_blocks_eff[order] + if num_splits_dynamic is None: + cum = torch.cumsum(num_m_blocks_eff, dim=0, dtype=torch.int32) + return torch.nn.functional.pad(cum, (1, 0)), None + splits = num_splits_dynamic[order] if order is not None else num_splits_dynamic + stacked = torch.stack([num_m_blocks_eff, num_m_blocks_eff * splits], dim=0) + cum = torch.cumsum(stacked, dim=1, dtype=torch.int32) + padded = torch.nn.functional.pad(cum, (1, 0)) + return padded[0], padded[1] + + def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, @@ -713,6 +751,7 @@ def _flash_attn_fwd( seqused_q=seqused_q, seqused_k=seqused_k, seqlen_k_per_split=seqlen_k_per_split, + q_stage=q_stage, ) has_scheduler_metadata = scheduler_metadata is not None and not disable_scheduler_metadata @@ -737,39 +776,43 @@ def _flash_attn_fwd( ), "these scheduler metadata tensors must have shape (batch_size,)" if tile_count_semaphore is not None: assert tile_count_semaphore.shape == (1,), "semaphore must have size 1" - # The kernel's _get_num_m_blocks uses tile_shape_mn[0] = q_stage * tile_m, - # so cu_total must be built with that effective tile. Rebuild here rather - # than trust scheduler_metadata.cu_total_*_ptr (which were computed with - # the user's tile_m and would mis-decode the CLC-exhausted sentinel). - if ( - num_m_blocks is not None - and num_splits_dynamic is not None - and os.environ.get("FLASH_ATTENTION_DISABLE_BINARY_SEARCH", "0") != "1" - ): - num_m_blocks_eff = (num_m_blocks + q_stage - 1) // q_stage - num_splits_m_blocks_eff = num_m_blocks_eff * num_splits_dynamic - if virtual_batch_idx is not None: - order = virtual_batch_idx.long() - stacked = torch.stack( - [num_m_blocks_eff[order], num_splits_m_blocks_eff[order]], dim=0 - ) - else: - stacked = torch.stack([num_m_blocks_eff, num_splits_m_blocks_eff], dim=0) - cum = torch.cumsum(stacked, dim=1, dtype=torch.int32) - padded = torch.nn.functional.pad(cum, (1, 0)) - cu_total_m_blocks = padded[0] - cu_total_splits_m_blocks = padded[1] - else: - cu_total_m_blocks = None - cu_total_splits_m_blocks = None else: num_m_blocks = None num_splits_dynamic = None virtual_batch_idx = None num_nheads_in_l2 = None tile_count_semaphore = None - cu_total_m_blocks = None - cu_total_splits_m_blocks = None + + # use binary batch search in SingleTileVarlenScheduler to avoid + # O(N^2) lookup; observed to be faster only for batch_size > BIN_BATCH_SEARCH_THRESH; this is tunable + cu_total_m_blocks = None + cu_total_splits_m_blocks = None + use_single_tile_varlen_scheduler = use_clc_scheduler or tile_count_semaphore is None + use_cu_hint = ( + is_varlen + and use_single_tile_varlen_scheduler + and batch_size > BIN_BATCH_SEARCH_THRESH + and not use_dedicated_hd256_kernel + ) + if ( + use_cu_hint + and has_scheduler_metadata + and scheduler_metadata.cu_total_m_blocks is not None + ): + cu_total_m_blocks = scheduler_metadata.cu_total_m_blocks + cu_total_splits_m_blocks = scheduler_metadata.cu_total_splits_m_blocks + elif use_cu_hint: + cu_total_m_blocks, cu_total_splits_m_blocks = _compute_tile_cumsum( + num_m_blocks=num_m_blocks, + cu_seqlens=cu_seqlens_q, + seqused=seqused_q, + num_splits_dynamic=num_splits_dynamic, + virtual_batch_idx=virtual_batch_idx, + tile_size=tile_m, + q_stage=q_stage, + qhead_per_kvhead=qhead_per_kvhead, + pack_gqa=pack_gqa, + ) is_static_persistent = ( not causal @@ -1280,7 +1323,7 @@ def make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k): def _compile_bwd_preprocess( dtype, head_dim, head_dim_v, m_block_size, has_cuseqlens_q, has_seqused_q, has_dlse, has_dq_accum, - use_padded_offsets, + use_padded_offsets, has_cu_total_m_blocks, ): """Compile bwd preprocess kernel using cute fake tensors (no real GPU tensors needed).""" mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors( @@ -1292,11 +1335,13 @@ def _compile_bwd_preprocess( mSequsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None mdLSE = fake_tensor(Float32, mLSE.shape, divisibility=1) if has_dlse else None mdQaccum = mdQaccum if has_dq_accum else None + mCuTotalMBlocks = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_total_m_blocks else None fa_bwd_pre = FlashAttentionBackwardPreprocess( dtype, head_dim, head_dim_v, m_block_size, use_padded_offsets=use_padded_offsets ) return cute.compile( fa_bwd_pre, mO, mdO, mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSequsedQ, mdLSE, + mCuTotalMBlocks, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", ) @@ -1307,18 +1352,27 @@ def _bwd_preprocess( cu_seqlens_q, seqused_q, dlse, dtype, head_dim, head_dim_v, m_block_size, use_padded_offsets=True, + cu_total_m_blocks=None, ): """Backward preprocess: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum.""" is_varlen = cu_seqlens_q is not None + batch_size = (cu_seqlens_q.shape[0] - 1) if is_varlen else 0 + if cu_total_m_blocks is None and batch_size > BIN_BATCH_SEARCH_THRESH: + cu_total_m_blocks, _ = _compute_tile_cumsum( + cu_seqlens=cu_seqlens_q, + seqused=seqused_q, + tile_size=m_block_size, + ) compile_key = ( dtype, head_dim, head_dim_v, m_block_size, is_varlen, seqused_q is not None, dlse is not None, dq_accum is not None, - use_padded_offsets, + use_padded_offsets, cu_total_m_blocks is not None, ) if compile_key not in _bwd_preprocess.compile_cache: _bwd_preprocess.compile_cache[compile_key] = _compile_bwd_preprocess(*compile_key) if not is_fake_mode(): _bwd_preprocess.compile_cache[compile_key]( - out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse + out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse, + cu_total_m_blocks, ) @@ -1328,7 +1382,7 @@ def _bwd_preprocess( def _compile_bwd_postprocess( dtype, hdim, block_size, num_threads, atom_layout, swap_ab, has_cuseqlens_q, has_seqused_q, - use_2cta_instrs, cluster_size, arch, + use_2cta_instrs, cluster_size, arch, has_cu_total_m_blocks, ): """Compile bwd postprocess kernel using cute fake tensors.""" mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors( @@ -1338,6 +1392,7 @@ def _compile_bwd_postprocess( batchp1 = cute.sym_int() mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None mSeqUsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None + mCuTotalMBlocks = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_total_m_blocks else None fa_bwd_post = FlashAttentionBackwardPostprocess( dtype, hdim, arch, block_size, num_threads, atom_layout, swap_ab, use_2cta_instrs=use_2cta_instrs, @@ -1345,6 +1400,7 @@ def _compile_bwd_postprocess( ) return cute.compile( fa_bwd_post, mdQaccum, mdQ, Float32(0.0), mCuSeqlensQ, mSeqUsedQ, + mCuTotalMBlocks, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", ) @@ -1356,18 +1412,28 @@ def _bwd_postprocess_convert( arch, dtype, hdim, block_size, num_threads, atom_layout, swap_ab, use_2cta_instrs=False, cluster_size=1, + cu_total_m_blocks=None, ): """Backward postprocess: convert float32 accumulator to bf16/fp16 output.""" + is_varlen = cu_seqlens is not None + batch_size = (cu_seqlens.shape[0] - 1) if is_varlen else 0 + if cu_total_m_blocks is None and is_varlen and batch_size > BIN_BATCH_SEARCH_THRESH: + cu_total_m_blocks, _ = _compute_tile_cumsum( + cu_seqlens=cu_seqlens, + seqused=seqused, + tile_size=block_size, + ) compile_key = ( dtype, hdim, block_size, num_threads, atom_layout, swap_ab, cu_seqlens is not None, seqused is not None, - use_2cta_instrs, cluster_size, arch, + use_2cta_instrs, cluster_size, arch, cu_total_m_blocks is not None, ) if compile_key not in _bwd_postprocess_convert.compile_cache: _bwd_postprocess_convert.compile_cache[compile_key] = _compile_bwd_postprocess(*compile_key) if not is_fake_mode(): _bwd_postprocess_convert.compile_cache[compile_key]( accum, output, scale, cu_seqlens, seqused, + cu_total_m_blocks, ) @@ -1477,12 +1543,6 @@ def _flash_attn_bwd( dQ_single_wg = cfg.dQ_single_wg cluster_size = 1 use_2cta_instrs = False - is_varlen = ( - cu_seqlens_q is not None - or cu_seqlens_k is not None - or seqused_q is not None - or seqused_k is not None - ) else: m_block_size = 128 n_block_size = 128 @@ -1503,6 +1563,12 @@ def _flash_attn_bwd( use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256 use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel + is_varlen = ( + cu_seqlens_q is not None + or cu_seqlens_k is not None + or seqused_q is not None + or seqused_k is not None + ) q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ maybe_contiguous(t) @@ -1706,6 +1772,22 @@ def _flash_attn_bwd( dK_semaphore = None dV_semaphore = None + # SingleTileVarlenScheduler uses binary search to find batch idx with > 512 batch size + # shared across preprocess, main bwd, and the three postprocess calls. + cu_total_m_blocks_q = None + cu_total_m_blocks_k = None + if is_varlen and batch_size > BIN_BATCH_SEARCH_THRESH and not use_dedicated_hd256_kernel: + cu_total_m_blocks_q, _ = _compute_tile_cumsum( + cu_seqlens=cu_seqlens_q, + seqused=seqused_q, + tile_size=m_block_size, + ) + cu_total_m_blocks_k, _ = _compute_tile_cumsum( + cu_seqlens=cu_seqlens_k, + seqused=seqused_k, + tile_size=n_block_size, + ) + # Preprocess kernel: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum. # For hd=256 dedicated path, dq_accum is None so preprocess only fills dpsum/lse_log2. _bwd_preprocess( @@ -1713,6 +1795,7 @@ def _flash_attn_bwd( cu_seqlens_q, seqused_q, dlse, dtype, head_dim, head_dim_v, m_block_size, use_padded_offsets=use_dedicated_hd256_kernel, + cu_total_m_blocks=cu_total_m_blocks_q, ) # num_threads: SM90 derives from BwdConfig.num_wg, SM120 is set to 128 above, # SM100/SM110 uses default from function signature (384). @@ -1810,6 +1893,7 @@ def _flash_attn_bwd( # Prevent TVM stride poisoning when only one block is present. (seqlen_q_rounded // m_block_size == 1), (seqlen_k_rounded // n_block_size == 1), + cu_total_m_blocks_k is not None, ) else: compile_key = ( @@ -1846,6 +1930,7 @@ def _flash_attn_bwd( # Prevent TVM stride poisoning when only one block is present. (seqlen_q_rounded // m_block_size == 1), (seqlen_k_rounded // n_block_size == 1), + cu_total_m_blocks_k is not None, ) if compile_key not in _flash_attn_bwd.compile_cache: @@ -1858,9 +1943,9 @@ def _flash_attn_bwd( dk_accum_tensor, dv_accum_tensor = [ to_cute_tensor(t) for t in (dk_accum, dv_accum) ] - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, cu_total_m_blocks_k_tensor = [ to_cute_tensor(t, assumed_align=4) if t is not None else None - for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, cu_total_m_blocks_k) ] dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [ utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order()) @@ -2002,6 +2087,7 @@ def _flash_attn_bwd( dV_semaphore_tensor, cute_aux_tensors, sparse_tensors_compile, + cu_total_m_blocks_k_tensor, current_stream, options="--enable-tvm-ffi", ) @@ -2040,6 +2126,7 @@ def _flash_attn_bwd( ) if normalized_block_sparse_tensors is not None else None, + cu_total_m_blocks_k, ) # Postprocess: convert dq_accum from float32 to dq in bf16/fp16 # hd=256 2CTA backward has its own internal postprocess, skip here. @@ -2058,6 +2145,7 @@ def _flash_attn_bwd( arch, dtype, head_dim, m_block_size, num_threads_post_dQ, AtomLayoutMdQ, dQ_swapAB, use_2cta_instrs=use_2cta_instrs, cluster_size=1, + cu_total_m_blocks=cu_total_m_blocks_q, ) if dKV_postprocess: @@ -2068,6 +2156,7 @@ def _flash_attn_bwd( arch, dtype, head_dim, n_block_size, num_threads_post_dKV, AtomLayoutNdKV, dKV_swapAB, cluster_size=cluster_size, + cu_total_m_blocks=cu_total_m_blocks_k, ) # Postprocess: convert dv_accum from float32 to dv in bf16/fp16 _bwd_postprocess_convert( @@ -2682,6 +2771,7 @@ def get_scheduler_metadata( tile_n: int, headdim_v: Optional[int] = None, pack_gqa: Optional[bool] = False, + q_stage: int = 1, causal: bool = False, enable_pdl: bool = False, sort: bool = False, @@ -2719,75 +2809,20 @@ def get_scheduler_metadata( else: n_blocks_per_split = None - # Allocate metadata torch tensors - num_m_blocks = torch.empty(num_batch, dtype=torch.int32, device=device) - num_splits_dynamic = torch.empty(num_batch, dtype=torch.int32, device=device) - virtual_batch_idx = torch.empty(num_batch, dtype=torch.int32, device=device) if sort else None - num_nheads_in_l2 = torch.empty(num_batch, dtype=torch.int32, device=device) if causal else None - tile_count_semaphore = torch.empty(1, dtype=torch.int32, device=device) + is_split_kv = num_splits > 1 + needs_prepare_kernel = is_split_kv or causal or sort - # Compute num_warps based on num_batch (capped at 32) - num_warps = min((num_batch + 30) // 31, 32) - # Round up to the nearest power of 2 - num_warps = 1 << (num_warps - 1).bit_length() + if needs_prepare_kernel: + num_m_blocks = torch.empty(num_batch, dtype=torch.int32, device=device) + num_splits_dynamic = torch.empty(num_batch, dtype=torch.int32, device=device) + virtual_batch_idx = torch.empty(num_batch, dtype=torch.int32, device=device) if sort else None + num_nheads_in_l2 = torch.empty(num_batch, dtype=torch.int32, device=device) if causal else None + tile_count_semaphore = torch.empty(1, dtype=torch.int32, device=device) - cache_key = ( - num_warps, - tile_m, - tile_n, - nheads, - nheads_kv, - headdim, - headdim_v, - causal, - pack_gqa, - enable_pdl, - sort, - cu_seqlens_q is not None, - cu_seqlens_k is not None, - cu_seqlens_k_new is not None, - seqused_q is not None, - seqused_k is not None, - leftpad_k is not None, - num_m_blocks is not None, - num_splits_dynamic is not None, - virtual_batch_idx is not None, - num_nheads_in_l2 is not None, - tile_count_semaphore is not None, - n_blocks_per_split is not None, - zfill_padded_output, - ) + num_warps = min((num_batch + 30) // 31, 32) + num_warps = 1 << (num_warps - 1).bit_length() - if cache_key not in get_scheduler_metadata.compile_cache: - ( - num_m_blocks_cute, - num_splits_dynamic_cute, - virtual_batch_idx_cute, - num_nheads_in_l2_cute, - tile_count_semaphore_cute, - cu_seqlens_q_cute, - cu_seqlens_k_cute, - cu_seqlens_k_new_cute, - seqused_q_cute, - seqused_k_cute, - leftpad_k_cute, - ) = [ - to_cute_tensor(t, assumed_align=4) if t is not None else None - for t in ( - num_m_blocks, - num_splits_dynamic, - virtual_batch_idx, - num_nheads_in_l2, - tile_count_semaphore, - cu_seqlens_q, - cu_seqlens_k, - cu_seqlens_k_new, - seqused_q, - seqused_k, - leftpad_k, - ) - ] - scheduler = FlashPrepareScheduler( + cache_key = ( num_warps, tile_m, tile_n, @@ -2796,63 +2831,141 @@ def get_scheduler_metadata( headdim, headdim_v, causal, - packgqa=pack_gqa, - sort=sort, - zfill_padded_output=zfill_padded_output, - ) - get_scheduler_metadata.compile_cache[cache_key] = cute.compile( - scheduler, - max_seqlen_q, - max_seqlen_k, - seqlen_k_new, - cu_seqlens_q_cute, - cu_seqlens_k_cute, - cu_seqlens_k_new_cute, - seqused_q_cute, - seqused_k_cute, - leftpad_k_cute, - num_batch, - num_splits, - tile_count_semaphore_cute, - num_m_blocks_cute, - num_splits_dynamic_cute, - virtual_batch_idx_cute, - num_nheads_in_l2_cute, - n_blocks_per_split, - cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), - options="--enable-tvm-ffi", + pack_gqa, + enable_pdl, + sort, + cu_seqlens_q is not None, + cu_seqlens_k is not None, + cu_seqlens_k_new is not None, + seqused_q is not None, + seqused_k is not None, + leftpad_k is not None, + num_m_blocks is not None, + num_splits_dynamic is not None, + virtual_batch_idx is not None, + num_nheads_in_l2 is not None, + tile_count_semaphore is not None, + n_blocks_per_split is not None, + zfill_padded_output, ) - if not is_fake_mode(): - get_scheduler_metadata.compile_cache[cache_key]( - max_seqlen_q, - max_seqlen_k, - seqlen_k_new, - cu_seqlens_q, - cu_seqlens_k, - cu_seqlens_k_new, - seqused_q, - seqused_k, - leftpad_k, - num_batch, - num_splits, - tile_count_semaphore, - num_m_blocks, - num_splits_dynamic, - virtual_batch_idx, - num_nheads_in_l2, - n_blocks_per_split, - ) + if cache_key not in get_scheduler_metadata.compile_cache: + ( + num_m_blocks_cute, + num_splits_dynamic_cute, + virtual_batch_idx_cute, + num_nheads_in_l2_cute, + tile_count_semaphore_cute, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + cu_seqlens_k_new_cute, + seqused_q_cute, + seqused_k_cute, + leftpad_k_cute, + ) = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in ( + num_m_blocks, + num_splits_dynamic, + virtual_batch_idx, + num_nheads_in_l2, + tile_count_semaphore, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_k_new, + seqused_q, + seqused_k, + leftpad_k, + ) + ] + scheduler = FlashPrepareScheduler( + num_warps, + tile_m, + tile_n, + nheads, + nheads_kv, + headdim, + headdim_v, + causal, + packgqa=pack_gqa, + sort=sort, + zfill_padded_output=zfill_padded_output, + ) + get_scheduler_metadata.compile_cache[cache_key] = cute.compile( + scheduler, + max_seqlen_q, + max_seqlen_k, + seqlen_k_new, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + cu_seqlens_k_new_cute, + seqused_q_cute, + seqused_k_cute, + leftpad_k_cute, + num_batch, + num_splits, + tile_count_semaphore_cute, + num_m_blocks_cute, + num_splits_dynamic_cute, + virtual_batch_idx_cute, + num_nheads_in_l2_cute, + n_blocks_per_split, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) - return SchedulerMetadataTensorsTorch( - num_m_blocks_ptr=num_m_blocks, - num_splits_dynamic_ptr=num_splits_dynamic, - virtual_batch_idx_ptr=virtual_batch_idx, - num_nheads_in_l2_ptr=num_nheads_in_l2, - tile_count_semaphore=tile_count_semaphore, - cu_total_m_blocks_ptr=None, - cu_total_splits_m_blocks_ptr=None, - ) + if not is_fake_mode(): + get_scheduler_metadata.compile_cache[cache_key]( + max_seqlen_q, + max_seqlen_k, + seqlen_k_new, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_k_new, + seqused_q, + seqused_k, + leftpad_k, + num_batch, + num_splits, + tile_count_semaphore, + num_m_blocks, + num_splits_dynamic, + virtual_batch_idx, + num_nheads_in_l2, + n_blocks_per_split, + ) + else: + num_m_blocks = None + num_splits_dynamic = None + virtual_batch_idx = None + num_nheads_in_l2 = None + tile_count_semaphore = None + + if is_fake_mode(): + return + + qhead_per_kvhead = nheads // nheads_kv + cu_total_m_blocks, cu_total_splits_m_blocks = _compute_tile_cumsum( + num_m_blocks=num_m_blocks, + cu_seqlens=cu_seqlens_q, + seqused=seqused_q, + num_splits_dynamic=num_splits_dynamic, + virtual_batch_idx=virtual_batch_idx, + tile_size=tile_m, + q_stage=q_stage, + qhead_per_kvhead=qhead_per_kvhead, + pack_gqa=bool(pack_gqa), + ) + + return SchedulerMetadataTensorsTorch( + num_m_blocks_ptr=num_m_blocks, + num_splits_dynamic_ptr=num_splits_dynamic, + virtual_batch_idx_ptr=virtual_batch_idx, + num_nheads_in_l2_ptr=num_nheads_in_l2, + tile_count_semaphore=tile_count_semaphore, + cu_total_m_blocks=cu_total_m_blocks, + cu_total_splits_m_blocks=cu_total_splits_m_blocks, + ) get_scheduler_metadata.compile_cache = get_jit_cache("scheduler_metadata") diff --git a/flash_attn/cute/prepare_scheduler.py b/flash_attn/cute/prepare_scheduler.py index 52ef2e881a4..b7ee8172a18 100644 --- a/flash_attn/cute/prepare_scheduler.py +++ b/flash_attn/cute/prepare_scheduler.py @@ -25,8 +25,8 @@ class SchedulerMetadataTensorsTorch(NamedTuple): # tensors of shape (batch + 1) # cu_total_m_blocks[b+1] = sum_{i<=b} num_m_blocks[i] # cu_total_splits_m_blocks[b+1] = sum_{i<=b} num_m_blocks[i] * num_splits_dynamic[i] - cu_total_m_blocks_ptr: Optional[torch.Tensor] = None - cu_total_splits_m_blocks_ptr: Optional[torch.Tensor] = None + cu_total_m_blocks: Optional[torch.Tensor] = None + cu_total_splits_m_blocks: Optional[torch.Tensor] = None class FlashPrepareScheduler: diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 1a15c0cdc27..fa7fd9c9881 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -994,6 +994,174 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): ).abs().max().item() + dv_atol +@pytest.mark.parametrize( + "cumsum_mode", ["jit_cumsum", "metadata_cumsum_only", "metadata_full"] +) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("qhead_per_kvhead", [1, 4]) +@retry_on_oom +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_varlen_cumsum_metadata_paths(causal, cumsum_mode, qhead_per_kvhead): + """Exercise the cu_total_m_blocks fast paths end-to-end. + + - "jit_cumsum": batch_size > 512 varlen, no scheduler_metadata. Triggers + the just-in-time host cumsum in _flash_attn_fwd and the hoisted Q/K + cumsum in _flash_attn_bwd. + - "metadata_cumsum_only": scheduler_metadata from get_scheduler_metadata + with num_splits=1 — skips the FlashPrepareScheduler kernel and returns + only cu_total_m_blocks. Fwd reads it from scheduler_metadata. + - "metadata_full": scheduler_metadata with num_splits>1 (SM100 only). + Runs the full prepare kernel and populates both cu_total tensors. + """ + if cumsum_mode == "metadata_full" and (IS_SM90 or DISABLE_SPLIT): + pytest.skip("split-kv not yet implemented on SM90") + device = "cuda" + torch.manual_seed(0) + random.seed(0) + + if cumsum_mode == "jit_cumsum": + batch_size = 600 + else: + batch_size = 64 + seqlen_q = seqlen_k = 64 + nheads_kv = 4 + nheads = nheads_kv * qhead_per_kvhead + d = dv = 128 + dtype = torch.bfloat16 + num_splits = 4 if cumsum_mode == "metadata_full" else 1 + + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype + ).requires_grad_() + k_ref = torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype + ).requires_grad_() + v_ref = torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype + ).requires_grad_() + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="third" + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="third" + ) + ( + q_unpad, + k_unpad, + v_unpad, + _qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + _seqused_q, + _seqused_k, + max_seqlen_q, + max_seqlen_k, + _q, + _k, + _v, + _qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + q_unpad = q_unpad.detach().requires_grad_() + k_unpad = k_unpad.detach().requires_grad_() + v_unpad = v_unpad.detach().requires_grad_() + + scheduler_metadata = None + if cumsum_mode != "jit_cumsum": + scheduler_metadata = get_scheduler_metadata( + num_batch=batch_size, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + nheads=nheads, + nheads_kv=nheads_kv, + headdim=d, + headdim_v=dv, + num_splits=num_splits, + tile_m=128, + tile_n=128, + causal=causal, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if is_fake_mode(): + return + assert scheduler_metadata.cu_total_m_blocks is not None + if cumsum_mode == "metadata_cumsum_only" and not causal: + # FlashPrepareScheduler is skipped only when num_splits == 1 and not causal and not sort. + assert scheduler_metadata.num_m_blocks_ptr is None + assert scheduler_metadata.tile_count_semaphore is None + if cumsum_mode == "metadata_full": + assert scheduler_metadata.num_m_blocks_ptr is not None + assert scheduler_metadata.cu_total_splits_m_blocks is not None + + out_ref, _ = attention_ref( + q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal + ) + out_pt, _ = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + upcast=False, + reorder_ops=True, + ) + + out_unpad, _ = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + causal=causal, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + ) + if is_fake_mode(): + return + out = output_pad_fn(out_unpad) + + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert (out - out_ref).abs().max().item() <= 2 * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + if cumsum_mode == "metadata_full": + return # split-kv bwd not supported + + g_unpad = torch.randn_like(out_unpad) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( + out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad + ) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + dq.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + dk.masked_fill_(rearrange(~key_padding_mask, "b s -> b s 1 1"), 0.0) + dv.masked_fill_(rearrange(~key_padding_mask, "b s -> b s 1 1"), 0.0) + + g = output_pad_fn(g_unpad) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + + for name, x, x_ref, x_pt in [ + ("dq", dq, dq_ref, dq_pt), + ("dk", dk, dk_ref, dk_pt), + ("dv", dv, dv_ref, dv_pt), + ]: + atol = 2 * (x_ref + 0.3 - 0.3 - x_ref).abs().max().item() + assert (x - x_ref).abs().max().item() <= 2 * ( + x_pt - x_ref + ).abs().max().item() + atol, name + + # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) From 5b992d5bc97d458dbe580d2429c8d957c17b9f5d Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Tue, 26 May 2026 22:37:56 +0000 Subject: [PATCH 12/15] fix linter errors --- flash_attn/cute/block_info.py | 6 +++--- flash_attn/cute/prepare_scheduler.py | 4 ++-- flash_attn/cute/tile_scheduler.py | 10 +++------- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index 9abdac17a4e..9cadfd38651 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -19,9 +19,9 @@ class BlockInfo: window_size_left: Optional[Int32] = None window_size_right: Optional[Int32] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 - num_splits: Int32 = 1 - num_splits_dynamic_ptr: Optional[cute.Tensor] = None - num_n_blocks_per_split: Optional[cutlass.Constexpr[Int32]] = None + num_splits: Int32 = 1 + num_splits_dynamic_ptr: Optional[cute.Tensor] = None + num_n_blocks_per_split: Optional[cutlass.Constexpr[Int32]] = None @cute.jit def get_n_block_min_max( diff --git a/flash_attn/cute/prepare_scheduler.py b/flash_attn/cute/prepare_scheduler.py index b7ee8172a18..4532d5e3d70 100644 --- a/flash_attn/cute/prepare_scheduler.py +++ b/flash_attn/cute/prepare_scheduler.py @@ -1,13 +1,13 @@ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_prepare_scheduler.cu # from CUTLASS C++ to Cute-DSL. -from typing import Tuple, Optional, Callable, List, NamedTuple +from typing import Tuple, Optional, NamedTuple import operator import torch import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute -from cutlass import Boolean, Int32, const_expr, Constexpr, Float32 +from cutlass import Int32, const_expr, Float32 from cutlass.cute import FastDivmodDivisor import flash_attn.cute.utils as utils diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index f9282d9c2ca..a4f818eb26b 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -298,9 +298,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: else: split_idx = Int32(0) # Pack dynamic per-batch num_splits into high 16 bits of split_idx - if const_expr( - self.params.is_split_kv and self.params.num_splits_dynamic_ptr is not None - ): + if const_expr(self.params.is_split_kv and self.params.num_splits_dynamic_ptr is not None): if is_valid: num_splits = Int32(self.params.num_splits_dynamic_ptr[batch_idx]) split_idx = split_idx | (num_splits << 16) @@ -464,7 +462,7 @@ class Params(ParamsBase): scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC lpt: cutlass.Constexpr[bool] = True use_cluster_idx: cutlass.Constexpr[bool] = True - num_splits_dynamic_ptr: Optional[cute.Tensor] = None + num_splits_dynamic_ptr: Optional[cute.Tensor] = None @staticmethod @cute.jit @@ -611,9 +609,7 @@ def clc_work_to_coords(self, work) -> WorkTileInfo: bidx_in_cluster = cute.arch.block_in_cluster_idx() block_idx = block_idx * self.params.cluster_shape_m + bidx_in_cluster[0] # Pack dynamic per-batch num_splits into high 16 bits of split_idx - if const_expr( - self.params.is_split_kv and self.params.num_splits_dynamic_ptr is not None - ): + if const_expr(self.params.is_split_kv and self.params.num_splits_dynamic_ptr is not None): if work.is_valid_tile: num_splits = Int32(self.params.num_splits_dynamic_ptr[batch_idx]) split_idx = split_idx | (num_splits << 16) From 5391f12e56f07909196b2de027097faff2b1413f Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Mon, 1 Jun 2026 17:18:21 +0000 Subject: [PATCH 13/15] wip: modify scheduler metadata public api --- flash_attn/cute/interface.py | 70 ++++++++++++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index fadc5aa191e..6d977d992db 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. +from numpy._core.defchararray import zfill import os import math from dataclasses import dataclass @@ -733,7 +734,7 @@ def _flash_attn_fwd( and not disable_scheduler_metadata and not use_dedicated_hd256_kernel ): - scheduler_metadata = get_scheduler_metadata( + scheduler_metadata = _get_scheduler_metadata( num_batch=batch_size, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, @@ -2759,7 +2760,7 @@ def flash_attn_combine( return out, lse -def get_scheduler_metadata( +def _get_scheduler_metadata( num_batch: int, max_seqlen_q: int, max_seqlen_k: int, @@ -2849,7 +2850,7 @@ def get_scheduler_metadata( zfill_padded_output, ) - if cache_key not in get_scheduler_metadata.compile_cache: + if cache_key not in _get_scheduler_metadata.compile_cache: ( num_m_blocks_cute, num_splits_dynamic_cute, @@ -2891,7 +2892,7 @@ def get_scheduler_metadata( sort=sort, zfill_padded_output=zfill_padded_output, ) - get_scheduler_metadata.compile_cache[cache_key] = cute.compile( + _get_scheduler_metadata.compile_cache[cache_key] = cute.compile( scheduler, max_seqlen_q, max_seqlen_k, @@ -2915,7 +2916,7 @@ def get_scheduler_metadata( ) if not is_fake_mode(): - get_scheduler_metadata.compile_cache[cache_key]( + _get_scheduler_metadata.compile_cache[cache_key]( max_seqlen_q, max_seqlen_k, seqlen_k_new, @@ -2968,4 +2969,61 @@ def get_scheduler_metadata( ) -get_scheduler_metadata.compile_cache = get_jit_cache("scheduler_metadata") +_get_scheduler_metadata.compile_cache = get_jit_cache("scheduler_metadata") + + +def get_scheduler_metadata( + max_seqlen_q: int, + max_seqlen_k: int, + nheads: int, + nheads_kv: int, + headdim: int, + num_splits: int, + headdim_v: Optional[int] = None, + pack_gqa: Optional[int] = None, + causal: bool = False, + enable_pdl: bool = False, + sort: bool = False, + seqlen_k_new: int = 0, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + seqlen_k_per_split: Optional[int] = None, +) -> SchedulerMetadataTensorsTorch: + """Public entrypoint for scheduler metadata computation""" + num_batch = cu_seqlens_q.shape[0] - 1 # TODO: ensure batch size consistent across tensors + + # TODO: get tile size and q stage from heuristic (same as fwd) + tile_m = 128 + tile_n = 128 + q_stage = 1 + + return _get_scheduler_metadata( + num_batch, + max_seqlen_q, + max_seqlen_k, + nheads, + nheads_kv, + headdim, + num_splits, + tile_m, + tile_n, + headdim_v=headdim_v, + pack_gqa=pack_gqa, + q_stage=q_stage, + causal=causal, + enable_pdl=enable_pdl, + sort=sort, + seqlen_k_new=seqlen_k_new, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + cu_seqlens_k_new=cu_seqlens_k_new, + seqused_q=seqused_q, + seqused_k=seqused_k, + leftpad_k=leftpad_k, + seqlen_k_per_split=seqlen_k_per_split, + zfill_padded_output=True, + ) From 341651470ecc19811e88989a0bd176f88070e70f Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Mon, 1 Jun 2026 21:14:25 +0000 Subject: [PATCH 14/15] clean up scheduler metadata API; add docstrings; split out _get_fwd_config method; remove cluster_size==1 restriction; guard architectures against unused scheduler metadata args --- flash_attn/cute/interface.py | 351 +++++++++++++++++++++--------- flash_attn/cute/tile_scheduler.py | 3 +- tests/cute/test_flash_attn.py | 9 - 3 files changed, 252 insertions(+), 111 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index b510a7b0095..5ac837b4950 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -117,6 +117,8 @@ class FwdConfig: n_block_size: int mma_pv_is_rs: bool intra_wg_overlap: bool + q_stage: int = 1 + num_splits: int = 1 def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, sparse_block_size_q=None): @@ -271,6 +273,100 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): return min(num_SMs // total_mblocks, max_splits, num_n_blocks) +def _get_fwd_config( + *, + arch: int, + head_dim: int, + head_dim_v: int, + max_seqlen_q: int, + max_seqlen_k: int, + num_head_kv: int, + qhead_per_kvhead: int, + pack_gqa: bool, + batch_size: int, + causal: bool, + local: bool, + window_size_left: Optional[int], + window_size_right: Optional[int], + num_splits: int, + device, + seqlen_q: Optional[int] = None, + tile_mn: Optional[Tuple[int, int]] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, + mma_pv_is_rs: Optional[bool] = None, + intra_wg_overlap: Optional[bool] = None, +) -> FwdConfig: + if seqlen_q is None: + seqlen_q = max_seqlen_q + + # Base tile sizes and flags: explicit override, else per-arch heuristic. + cfg = FwdConfig(128, 128, True, True) + if tile_mn is None: + if arch // 10 == 12: + # SM120 tile sizes tuned for 99 KB SMEM capacity: + # D<=64: 128x128 → 48 KB (good occupancy) + # D>64: 128x64 → 64 KB (128x128 would use 96 KB, hurting occupancy) + if head_dim > 64: + cfg = FwdConfig(128, 64, True, True) + elif arch // 10 == 8: + cfg = FwdConfig(128, 64, True, True) # SM80, should tune + elif arch // 10 == 9: + sparse_q = get_sparse_q_block_size(block_sparse_tensors, seqlen_q) + cfg = _tile_size_fwd_sm90( + head_dim, head_dim_v, causal, local, sparse_block_size_q=sparse_q + ) + else: + cfg = FwdConfig(tile_mn[0], tile_mn[1], cfg.mma_pv_is_rs, cfg.intra_wg_overlap) + + tile_m, tile_n = cfg.m_block_size, cfg.n_block_size + if mma_pv_is_rs is None: + mma_pv_is_rs = cfg.mma_pv_is_rs + if intra_wg_overlap is None: + intra_wg_overlap = cfg.intra_wg_overlap + + seqlen_q_packgqa = max_seqlen_q * (qhead_per_kvhead if pack_gqa else 1) + if arch // 10 in [10, 11]: + q_stage = 2 if seqlen_q_packgqa > tile_m else 1 + else: + q_stage = 1 + + m_block_size_effective = q_stage * tile_m + seqlen_k_loaded = ( + max_seqlen_k + if not local + else max( + 0, + min( + max_seqlen_k, + (window_size_right or max_seqlen_k) + + (window_size_left or max_seqlen_k) + + 1 + + tile_m, + ), + ) + ) + num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective + total_mblocks = batch_size * num_head_kv * num_m_blocks + num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n + num_SMs = ( + 132 if is_fake_mode() else torch.cuda.get_device_properties(device).multi_processor_count + ) + if num_splits < 1: + num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) + + # SplitKV uses float32 partial output, which doubles the O buffer size + # in shared memory, causing OOM for diff-headdim (192, 128) + if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1: + if num_n_blocks >= 64 and head_dim_v != 512: + tile_n = 64 + num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n + num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) + else: + num_splits = 1 + + return FwdConfig(tile_m, tile_n, mma_pv_is_rs, intra_wg_overlap, q_stage, num_splits) + + def _resolve_causal_local_window(causal, window_size_left, window_size_right, mask_mod=None): """Resolve causal/local/window settings into canonical form. @@ -303,6 +399,7 @@ def _compute_tile_cumsum( virtual_batch_idx: Optional[torch.Tensor] = None, tile_size: int = 1, q_stage: int = 1, + cluster_shape_m: int = 1, qhead_per_kvhead: int = 1, pack_gqa: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -316,6 +413,7 @@ def _compute_tile_cumsum( seqlens = seqlens * qhead_per_kvhead num_m_blocks = (seqlens + tile_size - 1) // tile_size num_m_blocks_eff = (num_m_blocks + q_stage - 1) // q_stage + num_m_blocks_eff = (num_m_blocks_eff + cluster_shape_m - 1) // cluster_shape_m order = virtual_batch_idx.long() if virtual_batch_idx is not None else None if order is not None: num_m_blocks_eff = num_m_blocks_eff[order] @@ -369,7 +467,6 @@ def _flash_attn_fwd( scheduler_metadata: Optional[SchedulerMetadataTensorsTorch] = None, seqlen_k_per_split: Optional[int] = None, disable_scheduler_metadata: bool = False, - zfill_padded_output: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. @@ -559,33 +656,10 @@ def _flash_attn_fwd( if arch // 10 in [8, 12]: num_threads = 128 - fwd_cfg = FwdConfig(128, 128, True, True) # default - if tile_mn is None: - if arch // 10 == 12: - # SM120 tile sizes tuned for 99 KB SMEM capacity: - # D<=64: 128x128 → 48 KB (good occupancy) - # D>64: 128x64 → 64 KB (128x128 would use 96 KB, hurting occupancy) - if head_dim <= 64: - fwd_cfg = FwdConfig(128, 128, True, True) - else: - fwd_cfg = FwdConfig(128, 64, True, True) - elif arch // 10 == 8: - fwd_cfg = FwdConfig(128, 64, True, True) # SM80, should tune - elif arch // 10 == 9: - sparse_q = get_sparse_q_block_size(block_sparse_tensors, seqlen_q) - fwd_cfg = _tile_size_fwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=sparse_q) - else: - fwd_cfg = FwdConfig(tile_mn[0], tile_mn[1], fwd_cfg.mma_pv_is_rs, fwd_cfg.intra_wg_overlap) - tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size - if mma_pv_is_rs is None: - mma_pv_is_rs = fwd_cfg.mma_pv_is_rs - if intra_wg_overlap is None: - intra_wg_overlap = fwd_cfg.intra_wg_overlap - # TODO: fix GQA + SplitKV + non-varlen if pack_gqa and num_splits != 1 and cu_seqlens_q is None: pack_gqa = False - + if pack_gqa and qv is not None and 128 % qhead_per_kvhead != 0: pack_gqa = False @@ -594,32 +668,38 @@ def _flash_attn_fwd( if max_seqlen_k is None: max_seqlen_k = seqlen_k if cu_seqlens_k is None and seqused_k is None: - min_seqlen_k = seqlen_k - seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead - if arch // 10 in [10, 11]: - q_stage = 2 if seqlen_q_packgqa > tile_m else 1 - else: - q_stage = 1 + min_seqlen_k = seqlen_k - m_block_size_effective = q_stage * tile_m - max_m_blocks_leq_one = seqlen_q_packgqa <= m_block_size_effective - seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, (window_size_right or max_seqlen_k) + (window_size_left or max_seqlen_k) + 1 + tile_m)) - num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective - total_mblocks = batch_size * num_head_kv * num_m_blocks - num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n - num_SMs = 132 if is_fake_mode() else torch.cuda.get_device_properties(device).multi_processor_count - if num_splits < 1: - num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) + fwd_cfg = _get_fwd_config( + arch=arch, + head_dim=head_dim, + head_dim_v=head_dim_v, + causal=causal, + local=local, + window_size_left=window_size_left, + window_size_right=window_size_right, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + qhead_per_kvhead=qhead_per_kvhead, + pack_gqa=pack_gqa, + batch_size=batch_size, + num_head_kv=num_head_kv, + num_splits=num_splits, + device=device, + seqlen_q=seqlen_q, + tile_mn=tile_mn, + block_sparse_tensors=block_sparse_tensors, + mma_pv_is_rs=mma_pv_is_rs, + intra_wg_overlap=intra_wg_overlap, + ) + tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size + q_stage = fwd_cfg.q_stage + num_splits = fwd_cfg.num_splits + mma_pv_is_rs = fwd_cfg.mma_pv_is_rs + intra_wg_overlap = fwd_cfg.intra_wg_overlap - # SplitKV uses float32 partial output, which doubles the O buffer size - # in shared memory, causing OOM for diff-headdim (192, 128) - if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1: - if num_n_blocks >= 64 and head_dim_v != 512: - tile_n = 64 - num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n - num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) - else: - num_splits = 1 + seqlen_q_packgqa = max_seqlen_q * (qhead_per_kvhead if pack_gqa else 1) + max_m_blocks_leq_one = seqlen_q_packgqa <= q_stage * tile_m is_split_kv = num_splits > 1 if is_split_kv: @@ -782,6 +862,7 @@ def _flash_attn_fwd( seqused_k=seqused_k, seqlen_k_per_split=seqlen_k_per_split, q_stage=q_stage, + cluster_shape_m=2 if use_2cta_instrs else 1, ) has_scheduler_metadata = scheduler_metadata is not None and not disable_scheduler_metadata @@ -890,7 +971,6 @@ def _flash_attn_fwd( mma_pv_is_rs, intra_wg_overlap, use_clc_scheduler, - num_m_blocks is not None, num_splits_dynamic is not None, virtual_batch_idx is not None, num_nheads_in_l2 is not None, @@ -972,34 +1052,24 @@ def _flash_attn_fwd( if aux_tensors is not None: cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] - num_splits_dynamic_tensor = ( - to_cute_tensor(num_splits_dynamic, assumed_align=4, leading_dim=0) - if num_splits_dynamic is not None else None - ) - tile_count_semaphore_tensor = ( - to_cute_tensor(tile_count_semaphore, assumed_align=4, leading_dim=0) - if tile_count_semaphore is not None else None - ) - num_m_blocks_tensor = ( - to_cute_tensor(num_m_blocks, assumed_align=4, leading_dim=0) - if num_m_blocks is not None else None - ) - virtual_batch_idx_tensor = ( - to_cute_tensor(virtual_batch_idx, assumed_align=4, leading_dim=0) - if virtual_batch_idx is not None else None - ) - num_nheads_in_l2_tensor = ( - to_cute_tensor(num_nheads_in_l2, assumed_align=4, leading_dim=0) - if num_nheads_in_l2 is not None else None - ) - cu_total_m_blocks_tensor = ( - to_cute_tensor(cu_total_m_blocks, assumed_align=4, leading_dim=0) - if cu_total_m_blocks is not None else None - ) - cu_total_splits_m_blocks_tensor = ( - to_cute_tensor(cu_total_splits_m_blocks, assumed_align=4, leading_dim=0) - if cu_total_splits_m_blocks is not None else None - ) + ( + num_splits_dynamic_tensor, + tile_count_semaphore_tensor, + virtual_batch_idx_tensor, + num_nheads_in_l2_tensor, + cu_total_m_blocks_tensor, + cu_total_splits_m_blocks_tensor, + ) = [ + to_cute_tensor(t, assumed_align=4, leading_dim=0) + for t in ( + num_splits_dynamic, + tile_count_semaphore, + virtual_batch_idx, + num_nheads_in_l2, + cu_total_m_blocks, + cu_total_splits_m_blocks, + ) + ] qv_tensor = to_cute_tensor(qv) if qv is not None else None gather_kv_indices_tensor = to_cute_tensor(gather_kv_indices) if gather_kv_indices is not None else None @@ -1198,20 +1268,21 @@ def _flash_attn_fwd( sparse_tensors, cute_aux_tensors, ]) - if not use_dedicated_hd256_kernel: + if arch // 10 in [10, 11] and not use_dedicated_hd256_kernel: compile_args.extend([ num_splits_dynamic_tensor, tile_count_semaphore_tensor, - ]) - if arch // 10 == 9: - compile_args.append(num_m_blocks_tensor) - compile_args.extend([ virtual_batch_idx_tensor, num_nheads_in_l2_tensor, cu_total_m_blocks_tensor, cu_total_splits_m_blocks_tensor, max_seqlen_q, ]) + elif arch // 10 in [8, 9, 12]: + compile_args.extend([ + cu_total_m_blocks_tensor, + cu_total_splits_m_blocks_tensor, + ]) compile_args.append(current_stream) _flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi") @@ -1285,20 +1356,21 @@ def _flash_attn_fwd( else None, aux_tensors, ]) - if not use_dedicated_hd256_kernel: + if arch // 10 in [10, 11] and not use_dedicated_hd256_kernel: call_args.extend([ num_splits_dynamic, tile_count_semaphore, - ]) - if arch // 10 == 9: - call_args.append(num_m_blocks) - call_args.extend([ virtual_batch_idx, num_nheads_in_l2, cu_total_m_blocks, cu_total_splits_m_blocks, max_seqlen_q, ]) + elif arch // 10 in [8, 9, 12]: + call_args.extend([ + cu_total_m_blocks, + cu_total_splits_m_blocks, + ]) _flash_attn_fwd.compile_cache[compile_key](*call_args) if is_split_kv: _flash_attn_fwd_combine( @@ -2555,6 +2627,15 @@ def flash_attn_varlen_func( min_seqlen_k: for varlen, specifies the minimum kv sequence length for any batch. Used with gather_kv_indices to determine if we need oob masking. + + scheduler_metadata: optional tensors used by certain tile schedulers, for optimization + and functionality. computed in get_scheduler_metadata. + + seqlen_k_per_split: when using dynamic (per-batch) num_splits, can set a fixed seqlen_k to be + covered per split for bitwise reproducibility. + + disable_scheduler_metadata: if True, ignores scheduler_metadata if it is passed and skips + computing metadata fresh. """ return FlashAttnVarlenFunc.apply( q, @@ -2846,6 +2927,7 @@ def _get_scheduler_metadata( headdim_v: Optional[int] = None, pack_gqa: Optional[bool] = False, q_stage: int = 1, + cluster_shape_m: int = 1, causal: bool = False, enable_pdl: bool = False, sort: bool = False, @@ -2889,8 +2971,12 @@ def _get_scheduler_metadata( if needs_prepare_kernel: num_m_blocks = torch.empty(num_batch, dtype=torch.int32, device=device) num_splits_dynamic = torch.empty(num_batch, dtype=torch.int32, device=device) - virtual_batch_idx = torch.empty(num_batch, dtype=torch.int32, device=device) if sort else None - num_nheads_in_l2 = torch.empty(num_batch, dtype=torch.int32, device=device) if causal else None + virtual_batch_idx = ( + torch.empty(num_batch, dtype=torch.int32, device=device) if sort else None + ) + num_nheads_in_l2 = ( + torch.empty(num_batch, dtype=torch.int32, device=device) if causal else None + ) tile_count_semaphore = torch.empty(1, dtype=torch.int32, device=device) num_warps = min((num_batch + 30) // 31, 32) @@ -3027,6 +3113,7 @@ def _get_scheduler_metadata( virtual_batch_idx=virtual_batch_idx, tile_size=tile_m, q_stage=q_stage, + cluster_shape_m=cluster_shape_m, qhead_per_kvhead=qhead_per_kvhead, pack_gqa=bool(pack_gqa), ) @@ -3055,8 +3142,8 @@ def get_scheduler_metadata( headdim_v: Optional[int] = None, pack_gqa: Optional[int] = None, causal: bool = False, - enable_pdl: bool = False, - sort: bool = False, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, seqlen_k_new: int = 0, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, @@ -3065,14 +3152,78 @@ def get_scheduler_metadata( seqused_k: Optional[torch.Tensor] = None, leftpad_k: Optional[torch.Tensor] = None, seqlen_k_per_split: Optional[int] = None, + _arch: Optional[int] = None, ) -> SchedulerMetadataTensorsTorch: - """Public entrypoint for scheduler metadata computation""" - num_batch = cu_seqlens_q.shape[0] - 1 # TODO: ensure batch size consistent across tensors + """Prepares metadata tensors used by varlen tile schedulers (SingleTileVarlenScheduler + and DynamicPersistentVarlenScheduler) + + Explanation of selected args: + num_splits: maximum number of splits per batch entry that the prepare kernel can emit + seqlen_k_per_split: for bitwise reproducibility between forward and backward, can fix + an exact seqlen_k per split; num_splits is calculated accordingly. + + Returns + SchedulerMetadataTensorsTorch, a named tuple including: + - num_splits_dynamic_ptr: per-batch num_splits + - num_nheads_in_l2_ptr: used for head swizzle to avoid l2 cache thrashing + - tile_count_semaphore: the global semaphore used by DynamicPersistentVarlenScheduler atomic incrementation + - cu_total_m_blocks: cumsum tensor counting total m_blocks, used for binary batch search with large batch_size + - cu_total_splits_m_blocks: complementary cumsum tensor used for binary batch search and to + extract dynamic num splits in the absense of num_splits_dynamic_ptr + """ + arch = _get_device_arch() if _arch is None else _arch + if headdim_v is None: + headdim_v = headdim + + batch_sizes = {} + if cu_seqlens_q is not None: + batch_sizes["cu_seqlens_q"] = cu_seqlens_q.shape[0] - 1 + if cu_seqlens_k is not None: + batch_sizes["cu_seqlens_k"] = cu_seqlens_k.shape[0] - 1 + if seqused_q is not None: + batch_sizes["seqused_q"] = seqused_q.shape[0] + if seqused_k is not None: + batch_sizes["seqused_k"] = seqused_k.shape[0] + assert batch_sizes, ( + "get_scheduler_metadata requires at least one of " + "cu_seqlens_q/cu_seqlens_k/seqused_q/seqused_k" + ) + num_batch = next(iter(batch_sizes.values())) + assert all(b == num_batch for b in batch_sizes.values()), ( + f"inconsistent batch size across inputs: {batch_sizes}" + ) + device = next( + t.device for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) if t is not None + ) + + causal, local, window_size_left, window_size_right = _resolve_causal_local_window( + causal, window_size_left, window_size_right + ) - # TODO: get tile size and q stage from heuristic (same as fwd) - tile_m = 128 - tile_n = 128 - q_stage = 1 + qhead_per_kvhead = nheads // nheads_kv + if pack_gqa is None: + pack_gqa = qhead_per_kvhead > 1 + + fwd_cfg = _get_fwd_config( + arch=arch, + head_dim=headdim, + head_dim_v=headdim_v, + causal=causal, + local=local, + window_size_left=window_size_left, + window_size_right=window_size_right, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + qhead_per_kvhead=qhead_per_kvhead, + pack_gqa=pack_gqa, + batch_size=num_batch, + num_head_kv=nheads_kv, + num_splits=num_splits, + device=device, + ) + tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size + q_stage = fwd_cfg.q_stage + num_splits = fwd_cfg.num_splits return _get_scheduler_metadata( num_batch, @@ -3088,8 +3239,8 @@ def get_scheduler_metadata( pack_gqa=pack_gqa, q_stage=q_stage, causal=causal, - enable_pdl=enable_pdl, - sort=sort, + enable_pdl=False, # pdl not yet enabled + sort=False, # LPT batch sort not yet enabled seqlen_k_new=seqlen_k_new, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index a4f818eb26b..a33cd492698 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -996,8 +996,7 @@ def decode( # Both SingleTileVarlen STATIC and CLC; not DynamicPersistent (where # warp-scan's _bidb_start resumption already amortizes per-call cost). use_cumsum_hint = const_expr( - self.cluster_shape_m == 1 - and cu_hint_ptr is not None + cu_hint_ptr is not None and ( self.scheduling_mode == SchedulingMode.STATIC or self.scheduling_mode == SchedulingMode.CLC diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 51a5db34b7e..4d1779fe7f4 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -826,7 +826,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): continue if precompute_metadata: scheduler_metadata = get_scheduler_metadata( - num_batch=batch_size, max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k, nheads=nheads, @@ -834,8 +833,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): headdim=d, headdim_v=dv, num_splits=num_splits, - tile_m=128, - tile_n=128, causal=causal, cu_seqlens_q=cu_seqlens_q if unpad_q else None, cu_seqlens_k=cu_seqlens_k if unpad_kv else None, @@ -1117,7 +1114,6 @@ def test_flash_attn_varlen_cumsum_metadata_paths(causal, cumsum_mode, qhead_per_ scheduler_metadata = None if cumsum_mode != "jit_cumsum": scheduler_metadata = get_scheduler_metadata( - num_batch=batch_size, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, nheads=nheads, @@ -1125,8 +1121,6 @@ def test_flash_attn_varlen_cumsum_metadata_paths(causal, cumsum_mode, qhead_per_ headdim=d, headdim_v=dv, num_splits=num_splits, - tile_m=128, - tile_n=128, causal=causal, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, @@ -1634,7 +1628,6 @@ def test_flash_attn_kvcache( continue if precompute_metadata: scheduler_metadata = get_scheduler_metadata( - num_batch=batch_size, max_seqlen_q=max_seqlen_q if varlen_q else seqlen_q, max_seqlen_k=seqlen_k, nheads=nheads, @@ -1642,8 +1635,6 @@ def test_flash_attn_kvcache( headdim=d, headdim_v=dv, num_splits=num_splits, - tile_m=128, - tile_n=128, causal=causal, sort=True, cu_seqlens_q=cu_seqlens_q, From 3f1f21777e85594593f8f12b98363c4ea6e201a6 Mon Sep 17 00:00:00 2001 From: Reuben Stern Date: Tue, 2 Jun 2026 14:50:06 +0000 Subject: [PATCH 15/15] address driss' comments --- flash_attn/cute/flash_fwd_sm100.py | 24 ++++++++++++------------ flash_attn/cute/interface.py | 1 - flash_attn/cute/tile_scheduler.py | 15 +++++++++++++-- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 58a984d7fa0..f74a823f8af 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -219,12 +219,12 @@ def __init__( ) self.use_clc_scheduler = use_clc_scheduler - self.dynamic_persistent = (has_tile_count_semaphore and is_varlen_q) or use_clc_scheduler # ClC does not compose with these other features, so disable even if requested self.use_clc_scheduler = ( use_clc_scheduler and self.use_tma_KV ) + self.dynamic_persistent = (has_tile_count_semaphore and is_varlen_q) or use_clc_scheduler self.is_persistent = self.dynamic_persistent or self.is_static_persistent self.sched_stages = 1 if self.use_clc_scheduler: @@ -712,14 +712,8 @@ def __call__( cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width) ) - sched_response_size = ( - self.sched_stages * 4 - if (self.use_clc_scheduler or self.dynamic_persistent) else 0 - ) - sched_mbar_size = ( - self.sched_stages * 2 - if (self.use_clc_scheduler or self.dynamic_persistent) else 0 - ) + sched_response_size = self.sched_stages * 4 if self.dynamic_persistent else 0 + sched_mbar_size = self.sched_stages * 2 if self.dynamic_persistent else 0 load_epi_mbar_size = 2 if const_expr(self.overlap_sO_sQ) else 0 @cute.struct @@ -743,7 +737,8 @@ class SharedStorage: # store row max and row sum sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] # Scheduler buffers placed here to utilize padding before sO's 1024-byte - # alignment. PipelineClcFetchAsync / PipelineAsync both expect + # alignment. This avoids adding bytes at the end when we're at the smem limit. + # PipelineClcFetchAsync / PipelineAsync both expect # 2 * sched_stages mbarriers (full + empty). Response is 4 Int32 per stage # (CLC HW response, or work_info written by dynamic persistent producer). sched_mbar_ptr: cute.struct.MemRange[Int64, sched_mbar_size] @@ -1070,6 +1065,9 @@ def kernel( pipeline_load_epi = None if const_expr(self.overlap_sO_sQ and self.is_persistent): + # when overlapping sO and sQ with a persistent kernel, we need this + # additional pipeline to ensure sO from the previous work tile is + # free for use by sQ in the current one. epi_warps_for_release = ( ThreadCooperativeGroup(len(self.correction_warp_ids)) if self.use_correction_warps_for_epi @@ -1168,6 +1166,7 @@ def kernel( cutlass_pipeline.Agent.Thread ) num_sched_consumer_warps_per_cta = self.threads_per_cta // cute.arch.WARP_SIZE + # NB on CTA0 warp15 == scheduler on CTA1 == empty but still both consume num_sched_consumer_warps = num_sched_consumer_warps_per_cta * self.cta_group_size sched_consumer_group = cutlass_pipeline.CooperativeGroup( cutlass_pipeline.Agent.Thread, @@ -3082,11 +3081,12 @@ def scheduler_warp( while work_tile.is_valid_tile: tile_scheduler.prefetch_next_work() work_tile = tile_scheduler.advance_to_next_work() - if const_expr(self.use_clc_scheduler): + if const_expr(self.dynamic_persistent): if cute.arch.thread_idx()[0] == self.scheduler_warp_id * cute.arch.WARP_SIZE: + prefix_str = "[CLC] query " if const_expr(self.use_clc_scheduler) else "[DYNAMIC] info " fa_printf( 3, - "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", + prefix_str + "sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", smid(), cute.arch.block_idx()[0], work_tile.tile_idx[0], diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 5ac837b4950..ec976b12223 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,7 +1,6 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. -from numpy._core.defchararray import zfill import os import math from dataclasses import dataclass diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index a33cd492698..4af09907709 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -45,7 +45,7 @@ class SchedulerState(ParamsBase): Main kernels construct this via `create_clc` / `create_dynamic_persistent`, which return the appropriate concrete state (`ClcSchedulerState` or `DynamicPersistentSchedulerState`). Schedulers consume it through the - `ctx: SchedulerState | None` parameter on their `create(...)`. + `ctx: SchedulerState | None` parameter on their `__init__(...)`. """ _pipeline: cutlass.pipeline.PipelineAsync @@ -88,7 +88,18 @@ def producer_tail(self, *, loc=None, ip=None): @dataclass class ClcSchedulerState(SchedulerState): - """CLC-backed: `prefetch_next_work` issues the HW query.""" + """Owns the runtime state shared by CLC-capable tile schedulers. + + `FlashAttentionForwardSm100` constructs this state because it owns the CLC + response buffer, mbarrier storage, and launch geometry needed to initialize + the hardware scheduler and async pipeline. Individual tile schedulers then + consume this state and map the returned hardware work tiles into their own + logical `WorkTileInfo` coordinates. + + To add CLC support to a scheduler: + - implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler + - map `ctx.initial_work_tile_info()` and `ctx.get_current_work()` into scheduler coordinates + """ _hw_scheduler: ClcDynamicPersistentTileScheduler