diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index f19c8fb7f05..c069007873b 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -401,6 +401,7 @@ def to_cute_block_sparse_tensors( """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi""" if not is_block_sparsity_enabled(tensors): return None + ( mask_block_cnt, mask_block_idx, diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 905deb98848..1f3767cf2d9 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -14,7 +14,7 @@ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py import math -from typing import Type, Tuple, Callable, Optional, Literal +from typing import Tuple, Callable, Optional, Literal from functools import partial import cuda.bindings.driver as cuda @@ -36,6 +36,7 @@ from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute import utils import flash_attn.cute.pipeline as pipeline_custom +import cutlass.pipeline as cutlass_pipeline from flash_attn.cute.mask import AttentionMask from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -55,11 +56,15 @@ from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, + TileSchedulerProtocol, SingleTileScheduler, StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, + CLCDynamicTileScheduler, ) +from flash_attn.cute.logging import fa_log, fa_printf +from flash_attn.cute.utils import smid class FlashAttentionForwardSm100: @@ -84,6 +89,8 @@ def __init__( paged_kv_non_tma: bool = False, is_varlen_q: bool = False, use_2cta_instrs: bool = False, + use_clc_scheduler: bool = False, + sched_stages: int = 1, ): self.use_tma_KV = not paged_kv_non_tma # self.dtype = dtype @@ -159,6 
+166,34 @@ def __init__( "Paged KV does not support irregular head dim" ) + self.use_clc_scheduler = ( + use_clc_scheduler + and self.use_tma_KV + and not self.overlap_sO_sQ + and not is_varlen_q + ) + self.sched_stages = sched_stages + if self.use_clc_scheduler: + assert self.cluster_shape_mn[1] == 1, f"CLC requires cluster N == 1: {self.cluster_shape_mn}" + assert self.cluster_shape_mn[0] in (1, 2), f"bad CLC cluster M: {self.cluster_shape_mn}" + assert self.cluster_shape_mn[0] == self.cta_group_size, ( + f"CLC cluster M != cta_group_size: {self.cluster_shape_mn}, {self.cta_group_size}" + ) + + if is_varlen_q: + self.TileScheduler = SingleTileVarlenScheduler + # NB: check CLC first since it's opt-in and we don't want other schedulers to override it + elif self.use_clc_scheduler: + self.TileScheduler = CLCDynamicTileScheduler + elif self.is_causal or self.is_local: + self.TileScheduler = SingleTileLPTScheduler + elif self.is_persistent: + self.TileScheduler = StaticPersistentTileScheduler + else: + self.TileScheduler = SingleTileScheduler + + fa_log(1, f"use_clc_scheduler={self.use_clc_scheduler}, TileScheduler={self.TileScheduler.__name__}, USE_2CTA={self.use_2cta_instrs}") + self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) self.correction_warp_ids = (8, 9, 10, 11) @@ -197,6 +232,8 @@ def __init__( elif self.is_varlen_q: # fallback self.epilogue_warp_ids = (13, 14) + self.clc_scheduler_warp_id = self.empty_warp_ids[0] if self.use_clc_scheduler else None + self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 self.tmem_o_offset = [ self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded @@ -524,17 +561,8 @@ def __call__( vO_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) - if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): - TileScheduler = SingleTileVarlenScheduler - else: - if const_expr(self.is_causal or 
self.is_local): - TileScheduler = SingleTileLPTScheduler - else: - TileScheduler = ( - SingleTileScheduler - if const_expr(not self.is_persistent) - else StaticPersistentTileScheduler - ) + TileScheduler = self.TileScheduler + tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]), cute.size(mQ.shape[2]), @@ -560,7 +588,12 @@ def __call__( is_split_kv=self.is_split_kv, cluster_shape_mn=self.cluster_shape_mn, ) - tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + if const_expr(self.use_clc_scheduler): + tile_sched_params = TileScheduler.to_underlying_arguments( + tile_sched_args, sched_stages=self.sched_stages + ) + else: + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) self.tile_scheduler_cls = TileScheduler grid_dim = TileScheduler.get_grid_shape(tile_sched_params) @@ -570,6 +603,9 @@ def __call__( cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width) ) + clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 + clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + @cute.struct class SharedStorage: # m_barriers for pipelines @@ -589,6 +625,13 @@ class SharedStorage: # Smem tensors # store row max and row sum sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] + # CLC buffers placed here to utilize padding before sO's 1024-byte alignment. + # This avoids adding bytes at the end when we're at the smem limit. + # PipelineClcFetchAsync expects 2 * sched_stages mbarriers (full + empty). + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] + # CLC response storage (16 bytes per stage, stored as 4 Int32s). 
+ clc_response: cute.struct.MemRange[Int32, clc_response_size] + # Large TMA buffers with 1024-byte alignment sO: cute.struct.Align[ cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes ] @@ -940,17 +983,67 @@ def kernel( window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) - TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) - # Cluster wait before tensor memory alloc pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + if const_expr(self.use_clc_scheduler): + clc_response_ptr = storage.clc_response.data_ptr() + clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() + + clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread + ) + num_clc_consumer_warps_per_cta = ( + len(self.empty_warp_ids) + + len(self.load_warp_ids) + + 1 # mma_warp_id + + len(self.softmax0_warp_ids) + + len(self.softmax1_warp_ids) + + len(self.correction_warp_ids) + + len(self.epilogue_warp_ids) + ) + # NB on CTA0 warp15 == scheduler on CTA1 == empty but still both consume + num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size + clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps + ) + clc_pipeline = cutlass_pipeline.PipelineClcFetchAsync.create( + barrier_storage=clc_mbar_ptr, + num_stages=self.sched_stages, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ) + + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc_response_ptr) + clc_consumer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages + ) + tile_scheduler.set_clc_pipeline(clc_pipeline, clc_consumer_state) + else: + clc_pipeline = None + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + 
assert isinstance(tile_scheduler, TileSchedulerProtocol), f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}" + # /////////////////////////////////////////////////////////////////////////////// - # EMPTY + # EMPTY / CLC SCHEDULER WARP # /////////////////////////////////////////////////////////////////////////////// - for i in cutlass.range_constexpr(len(self.empty_warp_ids)): - if warp_idx == self.empty_warp_ids[i]: + if const_expr(self.use_clc_scheduler): + if warp_idx == self.clc_scheduler_warp_id: cute.arch.setmaxregister_decrease(self.num_regs_other) + if is_leader_cta: + self.clc_scheduler_warp(clc_pipeline, tile_scheduler) + else: + self.empty_warp(clc_pipeline, tile_scheduler) + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i] and warp_idx != self.clc_scheduler_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + self.empty_warp(clc_pipeline, tile_scheduler) + else: + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i]: + cute.arch.setmaxregister_decrease(self.num_regs_other) # /////////////////////////////////////////////////////////////////////////////// # LOAD @@ -975,8 +1068,9 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + clc_pipeline=clc_pipeline, + tile_scheduler=tile_scheduler, ) # /////////////////////////////////////////////////////////////////////////////// @@ -1006,8 +1100,9 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + clc_pipeline=clc_pipeline, + tile_scheduler=tile_scheduler, ) # Dealloc the tensor memory buffer tmem.relinquish_alloc_permit() @@ -1029,8 +1124,9 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, mma_tile_coord_v, + clc_pipeline=clc_pipeline, + tile_scheduler=tile_scheduler, ) # /////////////////////////////////////////////////////////////////////////////// @@ -1062,11 
+1158,12 @@ def kernel( num_splits=num_splits, SeqlenInfoCls=SeqlenInfoCls, AttentionMaskCls=AttentionMaskCls, - TileSchedulerCls=TileSchedulerCls, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, head_divmod=head_divmod, blocksparse_tensors=blocksparse_tensors, + clc_pipeline=clc_pipeline, + tile_scheduler=tile_scheduler, ) if const_expr(not self.s0_s1_barrier): @@ -1110,8 +1207,9 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + clc_pipeline=clc_pipeline, + tile_scheduler=tile_scheduler, ) tmem_alloc_barrier.arrive() @@ -1137,8 +1235,9 @@ def load( block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors], + clc_pipeline: cutlass_pipeline.PipelineClcFetchAsync | None, + tile_scheduler: TileSchedulerProtocol ): num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE tidx = cute.arch.thread_idx()[0] % num_load_threads @@ -1147,7 +1246,6 @@ def load( kv_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.kv_stage ) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -1308,9 +1406,9 @@ def load( self.q_subtile_factor if self.q_subtile_factor is not None else 1, ) + tile_scheduler.prefetch_next_work() - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.consumer_advance() # End of persistent scheduler loop pipeline_kv.producer_tail(kv_producer_state) @@ -1338,8 +1436,9 @@ def mma( block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors], + clc_pipeline=None, + tile_scheduler=None, ): tSrQ = tiled_mma_qk.make_fragment_A(sQ) tSrK = tiled_mma_qk.make_fragment_B(sK) @@ -1428,7 +1527,6 @@ def mma( ) 
P_full_O_rescaled_phase = Int32(0) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -1599,8 +1697,7 @@ def mma( # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.consumer_advance() # End of persistent scheduler loop # We don't need pipeline_s_p_o.producer_tail() since there's no dangling mbarrier at the end @@ -1629,11 +1726,12 @@ def softmax_loop( num_splits: Int32, SeqlenInfoCls: Callable, AttentionMaskCls: Callable, - TileSchedulerCls: Callable, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), head_divmod=None, blocksparse_tensors: Optional[BlockSparseTensors] = None, + clc_pipeline=None, + tile_scheduler=None, ): """Compute softmax on attention scores from QK matrix multiplication. @@ -1693,7 +1791,6 @@ def softmax_loop( warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -1936,8 +2033,7 @@ def softmax_loop( # gLSE[tidx] = lse # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.consumer_advance() # End of persistent scheduler loop # This is equivalent to pipeline_sm_stats.producer_tail @@ -2107,8 +2203,9 @@ def correction_loop( block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors] = None, + clc_pipeline=None, + tile_scheduler=None, ): tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 @@ -2138,7 +2235,6 
@@ def correction_loop( o_corr_consumer_phase = Int32(0) corr_epi_producer_phase = Int32(1) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -2355,8 +2451,7 @@ def correction_loop( gLSE[tidx] = lse # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.consumer_advance() # End of persistent scheduler loop # This is equivalent to pipeline_o_epi.consumer_tail() for the correction warps @@ -2562,11 +2657,11 @@ def epilogue_s2g( block_info: BlockInfo, num_splits: int, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, mma_tile_coord_v: Int32 = 0, + clc_pipeline=None, + tile_scheduler=None, ): epi_consumer_phase = Int32(0) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -2619,8 +2714,60 @@ def epilogue_s2g( epi_consumer_phase ^= 1 # Advance to next tile - tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.consumer_advance() + + @cute.jit + def clc_scheduler_warp( + self, + clc_pipeline: cutlass_pipeline.PipelineClcFetchAsync | None, + tile_scheduler: TileSchedulerProtocol, + ): + clc_producer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, self.sched_stages + ) + clc_consumer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages + ) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + clc_pipeline.producer_acquire(clc_producer_state) + mbarrier_addr = clc_pipeline.producer_get_barrier(clc_producer_state) + tile_scheduler.advance_to_next_work(mbarrier_addr=mbarrier_addr) + clc_producer_state.advance() + + clc_pipeline.consumer_wait(clc_consumer_state) + work_tile = 
tile_scheduler.get_current_work() + if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE: + fa_printf( + 3, + "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", + smid(), + cute.arch.block_idx()[0], + work_tile.tile_idx[0], + work_tile.tile_idx[1], + work_tile.tile_idx[2], + work_tile.tile_idx[3], + work_tile.is_valid_tile, + ) + clc_pipeline.consumer_release(clc_consumer_state) + clc_consumer_state.advance() + clc_pipeline.producer_tail(clc_producer_state) + + @cute.jit + def empty_warp( + self, + clc_pipeline: cutlass_pipeline.PipelineClcFetchAsync | None, + tile_scheduler: TileSchedulerProtocol, + ): + clc_consumer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages + ) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + clc_pipeline.consumer_wait(clc_consumer_state) work_tile = tile_scheduler.get_current_work() + clc_pipeline.consumer_release(clc_consumer_state) + clc_consumer_state.advance() def load_Q( self, diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 0a15aa65b93..ea446df7cfe 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -209,6 +209,7 @@ def _flash_attn_fwd( out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, aux_tensors: Optional[list[torch.Tensor]] = None, + sched_stages: int = 1, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. 
@@ -339,6 +340,9 @@ def _flash_attn_fwd( causal, window_size_left, window_size_right, mask_mod ) + requested_use_clc_scheduler = utils._get_use_clc_scheduler_default() + requested_disable_2cta = utils._get_disable_2cta_default() + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) fwd_cfg = FwdConfig(128, 128, True, True) # default @@ -393,6 +397,7 @@ def _flash_attn_fwd( use_2cta_instrs = ( arch // 10 in [10, 11] + and not requested_disable_2cta and not causal and not local and not is_split_kv @@ -496,6 +501,8 @@ def _flash_attn_fwd( q_subtile_factor, mma_pv_is_rs, intra_wg_overlap, + requested_use_clc_scheduler, + sched_stages, fa_logging.get_fa_log_level(), ) if compile_key not in _flash_attn_fwd.compile_cache: @@ -569,8 +576,8 @@ def _flash_attn_fwd( is_local=local, is_split_kv=is_split_kv, pack_gqa=pack_gqa, - tile_m=tile_m, - tile_n=tile_n, + m_block_size=tile_m, + n_block_size=tile_n, q_stage=q_stage, is_persistent=not causal and not local @@ -582,6 +589,8 @@ def _flash_attn_fwd( has_aux_tensors=aux_tensors is not None, paged_kv_non_tma=page_size not in [None, 128], is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, + use_clc_scheduler=requested_use_clc_scheduler, + sched_stages=sched_stages, q_subtile_factor=q_subtile_factor, use_2cta_instrs=use_2cta_instrs, ) @@ -811,8 +820,10 @@ def _flash_attn_bwd( dKV_swapAB = False AtomLayoutMdQ = 1 AtomLayoutNdKV = 1 + requested_disable_2cta = utils._get_disable_2cta_default() disable_2cta = ( - local + requested_disable_2cta + or local or score_mod is not None or score_mod_bwd is not None or mask_mod is not None diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 95481099b21..450ab13753b 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -1,6 +1,6 @@ # Copyright (c) 2025, Tri Dao. 
-from typing import Optional, Tuple +from typing import Optional, Tuple, Protocol, runtime_checkable from dataclasses import dataclass try: @@ -18,6 +18,7 @@ import flash_attn.cute.utils as utils from flash_attn.cute.fast_math import clz +from flash_attn.cute.logging import fa_printf class WorkTileInfo(cutlass.utils.WorkTileInfo): @@ -31,6 +32,49 @@ def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": return WorkTileInfo(new_tile_idx, new_is_valid_tile) +@runtime_checkable +class TileSchedulerProtocol(Protocol): + """Protocol defining the interface all tile schedulers must implement. + + Schedulers are responsible for: + 1. Coordinate mapping: linear tile index -> (m_block, head, batch, split) + 2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic) + + The `mbarrier_addr` parameter in `advance_to_next_work` enables CLC support: + - Static schedulers: assert mbarrier_addr is None, use grid-stride iteration + - CLC schedulers: assert mbarrier_addr is not None, issue async CLC query + """ + + def get_current_work(self) -> WorkTileInfo: + """Get the current work tile coordinates.""" + ... + + def initial_work_tile_info(self) -> WorkTileInfo: + """Get the initial work tile for this CTA.""" + ... + + def advance_to_next_work(self, *, mbarrier_addr=None) -> None: + """Advance to the next work tile. + + Args: + mbarrier_addr: For CLC schedulers, the mbarrier address for async query. + Static schedulers should assert this is None. + """ + ... + + def prefetch_next_work(self) -> None: + """Prefetch next work tile info (optional, may be no-op).""" + ... + + def consumer_advance(self) -> WorkTileInfo: + """Consumer-side advance: move to next tile and return it. + + For static schedulers: advance + get_current_work. + For CLC schedulers: pipeline wait/release + get_current_work + state advance. + """ + ... 
+ + @dataclass class TileSchedulerArguments(ParamsBase): num_block: Int32 @@ -133,9 +177,14 @@ def initial_work_tile_info(self, *, loc=None, ip=None): def prefetch_next_work(self, *, loc=None, ip=None): pass - def advance_to_next_work(self, *, loc=None, ip=None): + def advance_to_next_work(self, *, loc=None, ip=None, mbarrier_addr=None): + assert mbarrier_addr is None self._is_first_block = False + def consumer_advance(self): + self.advance_to_next_work() + return self.get_current_work() + def __extract_mlir_values__(self): values, self._values_pos = [], [] for obj in [self.params, self._blk_coord]: @@ -223,12 +272,17 @@ def initial_work_tile_info(self, *, loc=None, ip=None): def prefetch_next_work(self, *, loc=None, ip=None): pass - def advance_to_next_work(self, *, loc=None, ip=None): + def advance_to_next_work(self, *, loc=None, ip=None, mbarrier_addr=None): + assert mbarrier_addr is None if const_expr(self.params.cluster_shape_m == 1): self._tile_idx += cute.arch.grid_dim()[0] else: self._tile_idx += cute.arch.cluster_dim()[0] + def consumer_advance(self): + self.advance_to_next_work() + return self.get_current_work() + def __extract_mlir_values__(self): values, self._values_pos = [], [] for obj in [self.params, self._tile_idx]: @@ -354,10 +408,14 @@ def initial_work_tile_info(self, *, loc=None, ip=None): def prefetch_next_work(self, *, loc=None, ip=None): pass - def advance_to_next_work(self, *, loc=None, ip=None): - # Single tile scheduler - set to invalid tile_idx to indicate no more work + def advance_to_next_work(self, *, loc=None, ip=None, mbarrier_addr=None): + assert mbarrier_addr is None self._tile_idx = self.params.total_blocks + def consumer_advance(self): + self.advance_to_next_work() + return self.get_current_work() + def __extract_mlir_values__(self): values, self._values_pos = [], [] for obj in [self.params, self._tile_idx, self._split_idx]: @@ -478,10 +536,14 @@ def initial_work_tile_info(self, *, loc=None, ip=None): def prefetch_next_work(self, 
*, loc=None, ip=None): pass - def advance_to_next_work(self, *, loc=None, ip=None): - # Single tile scheduler - set to invalid tile_idx to indicate no more work + def advance_to_next_work(self, *, loc=None, ip=None, mbarrier_addr=None): + assert mbarrier_addr is None self._tile_idx = self.params.total_blocks + def consumer_advance(self): + self.advance_to_next_work() + return self.get_current_work() + def __extract_mlir_values__(self): values, self._values_pos = [], [] for obj in [self.params, self._tile_idx]: @@ -704,10 +766,14 @@ def initial_work_tile_info(self, *, loc=None, ip=None): def prefetch_next_work(self, *, loc=None, ip=None): pass - def advance_to_next_work(self, *, loc=None, ip=None): - # Single tile scheduler - set to invalid tile_idx to indicate no more work + def advance_to_next_work(self, *, loc=None, ip=None, mbarrier_addr=None): + assert mbarrier_addr is None self._is_first_block = False + def consumer_advance(self): + self.advance_to_next_work() + return self.get_current_work() + def __extract_mlir_values__(self): values, self._values_pos = [], [] for obj in [self.params, self._tile_idx, self._split_idx]: @@ -725,3 +791,246 @@ def __new_from_mlir_values__(self, values): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc) + + +class CLCDynamicTileScheduler: + """Dynamic tile scheduler using CLC (Cluster Launch Control) for load balancing. + + This scheduler wraps the native CuTeDSL CLC support for dynamic work distribution. + SMs can "pull" work as they finish, naturally balancing load across the GPU. + Particularly beneficial for flex-attention workloads with variable computation per tile. + + Architecture (following CUTLASS pattern): + - Scheduler warp (producer): Issues CLC queries via advance_to_next_work(mbarrier_addr=...)
+ - All other warps (consumers): Read work tiles via get_current_work() + + Requires `use_tma_KV=True` so the kernel can dedicate an empty warp to drive CLC queries. + """ + + @dataclass + class Params(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch_splits: Int32 + num_splits_divmod: FastDivmodDivisor + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + cluster_shape_m: cutlass.Constexpr[int] = 1 + is_split_kv: cutlass.Constexpr[bool] = False + sched_stages: cutlass.Constexpr[int] = 1 + lpt: cutlass.Constexpr[bool] = False + + @staticmethod + def create( + args: TileSchedulerArguments, + sched_stages: int = 1, + *, + loc=None, + ip=None, + ) -> "CLCDynamicTileScheduler.Params": + assert sched_stages == 1, f"CLC scheduler only supports 1 stage, got {sched_stages}" + num_batch_splits = ( + args.num_batch * args.num_splits if const_expr(args.is_split_kv) else args.num_batch + ) + return CLCDynamicTileScheduler.Params( + num_block=args.num_block, + num_head=args.num_head, + num_batch_splits=num_batch_splits, + num_splits_divmod=FastDivmodDivisor(args.num_splits), + cluster_shape_mn=args.cluster_shape_mn, + cluster_shape_m=args.cluster_shape_mn[0], + is_split_kv=args.is_split_kv, + sched_stages=sched_stages, + lpt=args.lpt, + ) + + def __init__( + self, + params: Params, + cutlass_scheduler, + tile_idx: Int32, + clc_pipeline=None, + clc_consumer_state=None, + *, + loc=None, + ip=None, + ): + self.params = params + self._scheduler = cutlass_scheduler + self._tile_idx = tile_idx + self._clc_pipeline = clc_pipeline + self._clc_consumer_state = clc_consumer_state + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + sched_stages: int = 1, + *, + loc=None, + ip=None, + ) -> Params: + return CLCDynamicTileScheduler.Params.create(args, sched_stages, loc=loc, ip=ip) + + @staticmethod + @cute.jit + def create( + params: Params, + clc_response_ptr: cute.Pointer, + *, + loc=None, + ip=None, + 
) -> "CLCDynamicTileScheduler": + from cutlass.utils import ( + ClcDynamicPersistentTileScheduler, + ClcDynamicPersistentTileSchedulerParams, + ) + + cutlass_params = ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=( + cute.round_up(params.num_block, params.cluster_shape_m), + params.num_head, + params.num_batch_splits, + ), + cluster_shape_mnk=(*params.cluster_shape_mn, 1), + ) + block_idx = cute.arch.block_idx() + grid_dim = cute.arch.grid_dim() + cutlass_scheduler = ClcDynamicPersistentTileScheduler.create( + cutlass_params, + block_idx, + grid_dim, + clc_response_ptr, + ) + tile_idx = block_idx[0] + return CLCDynamicTileScheduler(params, cutlass_scheduler, tile_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + # CLC needs the full problem grid as backlog for try_cancel work-stealing. + # Grid x must be a multiple of cluster_shape_m for CUDA cluster launch. + return ( + cute.round_up(params.num_block, params.cluster_shape_m), + params.num_head, + params.num_batch_splits, + ) + + @cute.jit + def _work_to_coords( + self, block_idx: Int32, head_idx: Int32, batch_split_idx: Int32, is_valid + ) -> WorkTileInfo: + """Convert CUTLASS work tile coordinates to WorkTileInfo.""" + if const_expr(self.params.cluster_shape_m > 1): + block_idx = block_idx // self.params.cluster_shape_m + if const_expr(self.params.lpt): + block_idx = self.params.num_block - 1 - block_idx + split_idx = Int32(0) + if const_expr(self.params.is_split_kv): + batch_idx, split_idx = divmod(batch_split_idx, self.params.num_splits_divmod) + else: + batch_idx = batch_split_idx + return WorkTileInfo( + (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(split_idx)), + is_valid, + ) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + """CLC consumer: Read current work tile from CLC response.""" + work = self._scheduler.get_current_work() + self._tile_idx = 
work.tile_idx[0] + result = self._work_to_coords( + work.tile_idx[0], work.tile_idx[1], work.tile_idx[2], work.is_valid_tile + ) + if cute.arch.thread_idx()[0] == 0: + self._debug_print("pull", *result.tile_idx, result.is_valid_tile) + return result + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + work = self._scheduler.initial_work_tile_info() + self._tile_idx = work.tile_idx[0] + result = self._work_to_coords( + work.tile_idx[0], work.tile_idx[1], work.tile_idx[2], work.is_valid_tile + ) + if cute.arch.thread_idx()[0] == 0: + self._debug_print("map", *result.tile_idx, result.is_valid_tile) + return result + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + @cute.jit + def advance_to_next_work(self, *, loc=None, ip=None, mbarrier_addr=None): + """CLC producer: Issue async CLC query for next tile.""" + assert mbarrier_addr is not None + self._scheduler.advance_to_next_work(mbarrier_addr) + + @cute.jit + def consumer_advance(self): + self._clc_pipeline.consumer_wait(self._clc_consumer_state) + work_tile = self.get_current_work() + self._clc_pipeline.consumer_release(self._clc_consumer_state) + self._clc_consumer_state.advance() + return work_tile + + def set_clc_pipeline(self, clc_pipeline, clc_consumer_state): + self._clc_pipeline = clc_pipeline + self._clc_consumer_state = clc_consumer_state + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [ + self.params, + self._scheduler, + self._tile_idx, + self._clc_pipeline, + self._clc_consumer_state, + ]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [ + self.params, + self._scheduler, + self._tile_idx, + self._clc_pipeline, + self._clc_consumer_state, + ], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + 
values = values[n_items:] + return CLCDynamicTileScheduler(*obj_list, loc=self._loc) + + def _debug_print(self, phase: str, block_idx, head_idx, batch_idx, split_idx, is_valid): + linear_idx = ( + batch_idx * self.params.num_head * self.params.num_block + + head_idx * self.params.num_block + + block_idx + ) + total_tiles = self.params.num_block * self.params.num_head * self.params.num_batch_splits + fa_printf( + 3, + f"[CLC] {phase} sm={{}} cta={{}}/{{}} linear={{}}/{{}} (m_blk={{}},h={{}},b={{}},s={{}}) valid={{}}\n", + utils.smid(), + cute.arch.block_idx_in_cluster(), + Int32(self.params.cluster_shape_mn[0]), + linear_idx, + total_tiles, + block_idx, + head_idx, + batch_idx, + split_idx, + is_valid, + ) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 05985462116..6b3904e0764 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -3,12 +3,13 @@ import math import hashlib import inspect +import os from typing import Type, Callable, Optional, Tuple, overload import cutlass import cutlass.cute as cute -from cutlass import Float32, const_expr +from cutlass import Float32, Int32, const_expr from cutlass.cute import FastDivmodDivisor from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import nvvm, llvm @@ -55,6 +56,17 @@ ), } +_fa_clc_enabled: bool = os.environ.get("FA_CLC", "0") == "1" +_fa_disable_2cta_enabled: bool = os.environ.get("FA_DISABLE_2CTA", "0") == "1" + + +def _get_use_clc_scheduler_default() -> bool: + return _fa_clc_enabled + + +def _get_disable_2cta_default() -> bool: + return _fa_disable_2cta_enabled + def _compute_base_hash(func: Callable) -> str: """Compute hash from source code or bytecode and closure values.""" @@ -250,6 +262,21 @@ def warp_reduce( return val +@dsl_user_op +def smid(*, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [], + "mov.u32 $0, %smid;", + "=r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) 
+ ) + + @dsl_user_op def fmax( a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None diff --git a/tests/cute/conftest.py b/tests/cute/conftest.py index 86deb53608d..d2162255775 100644 --- a/tests/cute/conftest.py +++ b/tests/cute/conftest.py @@ -59,6 +59,9 @@ def pytest_configure(config): os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[worker_num % len(gpu_ids)] def pytest_collection_finish(session): + if not session.config.option.collectonly: + return + # file_name -> test_name -> counter test_counts: dict[str, dict[str, int]] = {} for item in session.items: diff --git a/tests/cute/test_clc_fuzz.py b/tests/cute/test_clc_fuzz.py new file mode 100644 index 00000000000..3c2c01db4a0 --- /dev/null +++ b/tests/cute/test_clc_fuzz.py @@ -0,0 +1,342 @@ +"""Adversarial regression tests for CLC tile scheduling. + +These cases intentionally target scheduler-sensitive shapes: mismatched +sequence lengths, non-aligned tiles, GQA ratios, minimal problems, and +larger persistent workloads. This is deterministic adversarial coverage, +not randomized fuzzing. 
+""" + +from contextlib import contextmanager +import os +from unittest import mock + +import pytest +import torch + +from flash_attn.cute import utils as cute_utils +from flash_attn.cute.interface import flash_attn_func +from flash_attn.cute.testing import attention_ref + + +if torch.cuda.is_available(): + COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] + SM_COUNT = torch.cuda.get_device_properties("cuda").multi_processor_count +else: + COMPUTE_CAPABILITY = 0 + SM_COUNT = 0 +pytestmark = pytest.mark.skipif( + COMPUTE_CAPABILITY not in (10, 11), + reason="CLC adversarial tests require SM100/SM110 persistent forward", +) + + +@contextmanager +def clc_scheduler_enabled(): + with ( + mock.patch.dict(os.environ, {"FA_CLC": "1"}, clear=False), + mock.patch.object(cute_utils, "_fa_clc_enabled", True), + ): + yield + + +def check_output(q, k, v, *, causal=False, num_splits=1): + out, _ = flash_attn_func(q, k, v, causal=causal, num_splits=num_splits) + torch.cuda.synchronize() + out_ref, _ = attention_ref(q, k, v, causal=causal) + out_pt, _ = attention_ref(q, k, v, causal=causal, upcast=False, reorder_ops=True) + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert (out - out_ref).abs().max().item() <= 2 * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol, ( + f"max_diff={(out - out_ref).abs().max().item()}, " + f"pt_max_diff={(out_pt - out_ref).abs().max().item()}, " + f"fwd_atol={fwd_atol}, " + f"q={list(q.shape)} k={list(k.shape)} v={list(v.shape)} " + f"causal={causal} num_splits={num_splits}" + ) + + +def randn(b, s, h, d): + return torch.randn(b, s, h, d, device="cuda", dtype=torch.bfloat16) + + +def expected_total_tiles_mha(batch, seqlen_q, heads): + q_stage = 2 if COMPUTE_CAPABILITY == 10 and seqlen_q > 128 else 1 + num_block = (seqlen_q + q_stage * 128 - 1) // (q_stage * 128) + return num_block * heads * batch + + +@pytest.fixture(autouse=True) +def seed(): + torch.random.manual_seed(42) + + +@pytest.fixture(autouse=True) +def 
enable_clc_scheduler(): + with clc_scheduler_enabled(): + yield + + +class TestCLCMismatchedSeqlens: + + @pytest.mark.parametrize("sq,sk", [ + (128, 512), + (128, 1024), + (128, 2048), + (256, 64), + (256, 128), + (512, 127), + (512, 129), + (64, 4096), + (1, 128), + (1, 512), + (1, 1024), + ]) + def test_qk_mismatch(self, sq, sk): + check_output(randn(4, sq, 4, 128), randn(4, sk, 4, 128), randn(4, sk, 4, 128)) + + @pytest.mark.parametrize("sq,sk", [ + (128, 513), + (256, 1023), + (64, 257), + (192, 383), + (1, 255), + ]) + def test_qk_mismatch_nonaligned_k(self, sq, sk): + check_output(randn(4, sq, 4, 128), randn(4, sk, 4, 128), randn(4, sk, 4, 128)) + + @pytest.mark.parametrize("sq,sk", [ + (1, 128), + (1, 256), + (1, 1024), + (2, 128), + (3, 512), + ]) + def test_tiny_q_long_k(self, sq, sk): + check_output(randn(2, sq, 4, 128), randn(2, sk, 4, 128), randn(2, sk, 4, 128)) + + +class TestCLCNonAlignedShapes: + @pytest.mark.parametrize("sq", [1, 3, 7, 15, 31, 33, 63, 65, 127, 129, 191, 193, 255, 257]) + def test_nonaligned_q(self, sq): + check_output(randn(2, sq, 4, 128), randn(2, 256, 4, 128), randn(2, 256, 4, 128)) + + @pytest.mark.parametrize("sk", [1, 7, 31, 33, 63, 65, 127, 129, 255, 257, 511, 513]) + def test_nonaligned_k(self, sk): + check_output(randn(2, 256, 4, 128), randn(2, sk, 4, 128), randn(2, sk, 4, 128)) + + +class TestCLCPrimes: + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (1, 1, 127, 131), + (3, 5, 131, 127), + (7, 3, 257, 251), + (11, 7, 67, 509), + (13, 1, 191, 193), + (5, 11, 61, 67), + (2, 3, 509, 127), + ]) + def test_all_prime(self, batch, heads, sq, sk): + check_output( + randn(batch, sq, heads, 128), + randn(batch, sk, heads, 128), + randn(batch, sk, heads, 128), + ) + + +class TestCLC2CTA: + @pytest.mark.parametrize("sq,sk", [ + (128, 512), + (256, 127), + (256, 129), + (128, 2048), + (1, 512), + (64, 1024), + (512, 64), + ]) + def test_2cta_qk_mismatch(self, sq, sk): + check_output(randn(4, sq, 4, 128), randn(4, sk, 4, 128), 
randn(4, sk, 4, 128)) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (1, 1, 128, 128), + (1, 1, 256, 512), + (3, 5, 128, 1024), + (7, 3, 512, 127), + (9, 7, 256, 257), + (13, 1, 128, 64), + ]) + def test_2cta_adversarial_combos(self, batch, heads, sq, sk): + check_output( + randn(batch, sq, heads, 128), + randn(batch, sk, heads, 128), + randn(batch, sk, heads, 128), + ) + + +class TestCLCGQA: + @pytest.mark.parametrize("q_heads,kv_heads,sq,sk", [ + (4, 1, 128, 512), + (4, 1, 256, 127), + (8, 1, 64, 1024), + (8, 2, 512, 129), + (8, 4, 1, 256), + (6, 2, 192, 383), + (6, 3, 128, 1), + (12, 4, 257, 511), + ]) + def test_gqa_mismatch(self, q_heads, kv_heads, sq, sk): + check_output( + randn(4, sq, q_heads, 128), + randn(4, sk, kv_heads, 128), + randn(4, sk, kv_heads, 128), + ) + + @pytest.mark.parametrize("q_heads,kv_heads", [ + (4, 1), (4, 2), (8, 1), (8, 2), (8, 4), (6, 2), (6, 3), (12, 4), + ]) + def test_gqa_ratios(self, q_heads, kv_heads): + check_output( + randn(4, 512, q_heads, 128), + randn(4, 512, kv_heads, 128), + randn(4, 512, kv_heads, 128), + ) + + +class TestCLCHeadDim: + @pytest.mark.parametrize("d,dv,sq,sk", [ + (64, 64, 128, 512), + (64, 64, 1, 256), + (96, 96, 255, 127), + (128, 64, 192, 384), + (128, 64, 1, 1024), + ]) + def test_head_dims_adversarial(self, d, dv, sq, sk): + check_output(randn(4, sq, 4, d), randn(4, sk, 4, d), randn(4, sk, 4, dv)) + + def test_overlap_sO_sQ_fallback(self): + from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 + from flash_attn.cute.tile_scheduler import SingleTileScheduler + + orig_init = FlashAttentionForwardSm100.__init__ + schedulers = [] + def spy_init(self_inner, *a, **kw): + orig_init(self_inner, *a, **kw) + schedulers.append(self_inner.TileScheduler) + + with mock.patch.object(FlashAttentionForwardSm100, '__init__', spy_init): + check_output(randn(4, 128, 4, 192), randn(4, 257, 4, 192), randn(4, 257, 4, 128)) + + assert schedulers and schedulers[-1] is SingleTileScheduler + + +class 
TestCLCMinimal: + @pytest.mark.parametrize("sq,sk", [(1, 1), (1, 2), (2, 1), (1, 128), (128, 1)]) + def test_minimal(self, sq, sk): + check_output(randn(1, sq, 1, 128), randn(1, sk, 1, 128), randn(1, sk, 1, 128)) + + def test_single_element(self): + check_output(randn(1, 1, 1, 64), randn(1, 1, 1, 64), randn(1, 1, 1, 64)) + + +class TestCLCCausal: + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (3, 5, 259, 259), + (7, 3, 513, 513), + (1, 7, 1023, 1023), + (5, 11, 2049, 2049), + (2, 3, 4097, 4097), + ]) + def test_causal_square(self, batch, heads, sq, sk): + check_output(randn(batch, sq, heads, 128), randn(batch, sk, heads, 128), randn(batch, sk, heads, 128), causal=True) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (3, 7, 127, 513), + (5, 3, 259, 1023), + (7, 5, 63, 2049), + (11, 1, 1, 511), + (2, 9, 1, 1025), + (9, 3, 33, 4097), + ]) + def test_causal_qk_mismatch(self, batch, heads, sq, sk): + check_output(randn(batch, sq, heads, 128), randn(batch, sk, heads, 128), randn(batch, sk, heads, 128), causal=True) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (3, 7, 191, 191), + (7, 5, 193, 193), + (5, 3, 383, 383), + (11, 1, 129, 509), + (2, 13, 1, 131), + (9, 3, 67, 251), + ]) + def test_causal_nonaligned(self, batch, heads, sq, sk): + check_output(randn(batch, sq, heads, 128), randn(batch, sk, heads, 128), randn(batch, sk, heads, 128), causal=True) + + @pytest.mark.parametrize("batch,q_heads,kv_heads,sq", [ + (3, 6, 2, 513), + (7, 8, 1, 259), + (5, 12, 4, 1023), + (2, 8, 2, 2049), + (11, 4, 1, 191), + ]) + def test_causal_gqa(self, batch, q_heads, kv_heads, sq): + check_output( + randn(batch, sq, q_heads, 128), + randn(batch, sq, kv_heads, 128), + randn(batch, sq, kv_heads, 128), + causal=True, + ) + + def test_causal_large(self): + check_output(randn(3, 4097, 13, 128), randn(3, 4097, 13, 128), randn(3, 4097, 13, 128), causal=True) + + +class TestCLCLargeScale: + def test_large_batch(self): + check_output(randn(32, 512, 8, 128), randn(32, 512, 8, 
128), randn(32, 512, 8, 128)) + + def test_long_seq(self): + check_output(randn(2, 4096, 4, 128), randn(2, 4096, 4, 128), randn(2, 4096, 4, 128)) + + def test_many_heads(self): + check_output(randn(4, 512, 32, 128), randn(4, 512, 32, 128), randn(4, 512, 32, 128)) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (24, 8, 768, 2048), + (16, 8, 1536, 4096), + (12, 8, 2305, 4096), + ]) + def test_work_stealing_pressure(self, batch, heads, sq, sk): + total_tiles = expected_total_tiles_mha(batch, sq, heads) + assert total_tiles > SM_COUNT, f"expected total_tiles={total_tiles} > sm_count={SM_COUNT}" + check_output( + randn(batch, sq, heads, 128), + randn(batch, sk, heads, 128), + randn(batch, sk, heads, 128), + ) + + def test_long_k_short_q(self): + check_output(randn(8, 64, 8, 128), randn(8, 8192, 8, 128), randn(8, 8192, 8, 128)) + + def test_long_q_short_k(self): + check_output(randn(4, 4096, 4, 128), randn(4, 64, 4, 128), randn(4, 64, 4, 128)) + + +class TestCLCRepeatability: + @pytest.mark.parametrize("trial", range(5)) + def test_repeat_mismatch(self, trial): + torch.random.manual_seed(trial) + check_output(randn(7, 192, 5, 128), randn(7, 513, 5, 128), randn(7, 513, 5, 128)) + + @pytest.mark.parametrize("trial", range(5)) + def test_repeat_2cta(self, trial): + torch.random.manual_seed(trial) + check_output(randn(9, 257, 3, 128), randn(9, 511, 3, 128), randn(9, 511, 3, 128)) + + @pytest.mark.parametrize("trial", range(5)) + def test_repeat_gqa_mismatch(self, trial): + torch.random.manual_seed(trial) + check_output(randn(5, 128, 8, 128), randn(5, 1024, 2, 128), randn(5, 1024, 2, 128)) + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file