diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index f19c8fb7f05..c069007873b 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -401,6 +401,7 @@ def to_cute_block_sparse_tensors( """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi""" if not is_block_sparsity_enabled(tensors): return None + ( mask_block_cnt, mask_block_idx, diff --git a/flash_attn/cute/cache_utils.py b/flash_attn/cute/cache_utils.py index 8606f04b62b..3fca0579d98 100644 --- a/flash_attn/cute/cache_utils.py +++ b/flash_attn/cute/cache_utils.py @@ -32,7 +32,11 @@ logger = logging.getLogger(__name__) _handler = logging.StreamHandler() -_handler.setFormatter(logging.Formatter("%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")) +_handler.setFormatter( + logging.Formatter( + "%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) +) logger.addHandler(_handler) logger.setLevel(logging.DEBUG) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index e38be692834..6c9c20d0b76 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -14,7 +14,7 @@ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py import math -from typing import Type, Tuple, Callable, Optional, Literal +from typing import Tuple, Callable, Optional, Literal from functools import partial import cuda.bindings.driver as cuda @@ -27,6 +27,7 @@ import cutlass.utils.blackwell_helpers as sm100_utils_basic from cutlass import pipeline from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.utils import ClcDynamicPersistentTileScheduler from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import BaseDSL @@ -36,6 +37,7 @@ from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute import utils import flash_attn.cute.pipeline as pipeline_custom +import cutlass.pipeline as cutlass_pipeline from flash_attn.cute.mask import AttentionMask from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -54,12 +56,17 @@ from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( + ClcState, + SchedulingMode, TileSchedulerArguments, + TileSchedulerProtocol, SingleTileScheduler, StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ) +from flash_attn.cute.fa_logging import fa_log, fa_printf +from flash_attn.cute.utils import smid # === TUNING KNOBS (agent-editable) === # Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) @@ -106,6 +113,7 @@ def __init__( paged_kv_non_tma: bool = False, is_varlen_q: bool = False, use_2cta_instrs: bool = False, + use_clc_scheduler: bool = False, ): self.use_tma_KV = not paged_kv_non_tma # self.dtype = dtype @@ -179,6 +187,32 @@ def __init__( "Paged KV does not support irregular head dim" ) + self.use_clc_scheduler = ( + use_clc_scheduler + and self.use_tma_KV + and not self.overlap_sO_sQ + ) + self.sched_stages = 1 + if self.use_clc_scheduler: + assert self.cluster_shape_mn[1] == 1, f"CLC requires cluster N == 1: {self.cluster_shape_mn}" + assert self.cluster_shape_mn[0] in (1, 2), f"bad CLC cluster M: {self.cluster_shape_mn}" + assert self.cluster_shape_mn[0] == self.cta_group_size, ( + f"CLC cluster M != cta_group_size: 
{self.cluster_shape_mn}, {self.cta_group_size}" + ) + + self.scheduling_mode = SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC + + if is_varlen_q: + self.TileScheduler = SingleTileVarlenScheduler + elif self.is_causal or self.is_local or self.use_clc_scheduler: + self.TileScheduler = SingleTileLPTScheduler + elif self.is_persistent: + self.TileScheduler = StaticPersistentTileScheduler + else: + self.TileScheduler = SingleTileScheduler + + fa_log(1, f"TileScheduler={self.TileScheduler.__name__}, scheduling_mode={self.scheduling_mode.name}, USE_2CTA={self.use_2cta_instrs}") + self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) self.correction_warp_ids = (8, 9, 10, 11) @@ -219,6 +253,8 @@ def __init__( elif self.is_varlen_q: # fallback self.epilogue_warp_ids = (13, 14) + self.clc_scheduler_warp_id = self.empty_warp_ids[0] if self.use_clc_scheduler else None + self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 self.tmem_o_offset = [ self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded @@ -551,19 +587,7 @@ def __call__( vO_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) - if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): - TileScheduler = SingleTileVarlenScheduler - else: - if const_expr(self.is_causal or self.is_local): - TileScheduler = SingleTileLPTScheduler - else: - TileScheduler = ( - SingleTileScheduler - if const_expr(not self.is_persistent) - else StaticPersistentTileScheduler - ) - # For non-persistent 2CTA (use_cluster_idx), each cluster covers - # cta_tiler[0] * cta_group_size rows, so num_block must be divided accordingly + TileScheduler = self.TileScheduler _num_block_divisor = self.cta_tiler[0] * (self.cta_group_size if not self.is_persistent and self.cta_group_size > 1 else 1) tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mQ.shape[0]), _num_block_divisor), @@ -591,7 +615,9 @@ def __call__( cluster_shape_mn=self.cluster_shape_mn, use_cluster_idx=not self.is_persistent and self.cta_group_size > 1, ) - tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + tile_sched_params = TileScheduler.to_underlying_arguments( + tile_sched_args, scheduling_mode=self.scheduling_mode + ) self.tile_scheduler_cls = TileScheduler grid_dim = TileScheduler.get_grid_shape(tile_sched_params) @@ -601,6 +627,9 @@ def __call__( cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width) ) + clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 + clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + @cute.struct class SharedStorage: # m_barriers for pipelines @@ -620,6 +649,13 @@ class SharedStorage: # Smem tensors # store row max and row sum sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] + # CLC buffers placed here to utilize padding before sO's 1024-byte alignment. + # This avoids adding bytes at the end when we're at the smem limit. + # PipelineClcFetchAsync expects 2 * sched_stages mbarriers (full + empty). + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] + # CLC response storage (16 bytes per stage, stored as 4 Int32s). 
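+            # (Assumed cross-reference: this 16-byte per-stage size is what the
+            # tx_count=16 on the PipelineClcFetchAsync created in kernel() accounts for.)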
+ clc_response: cute.struct.MemRange[Int32, clc_response_size] + # Large TMA buffers with 1024-byte alignment sO: cute.struct.Align[ cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes ] @@ -980,17 +1016,69 @@ def kernel( window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) - TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) - # Cluster wait before tensor memory alloc pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + if const_expr(self.use_clc_scheduler): + clc_response_ptr = storage.clc_response.data_ptr() + clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() + + clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread + ) + num_clc_consumer_warps_per_cta = self.threads_per_cta // cute.arch.WARP_SIZE + # NB on CTA0 warp15 == scheduler on CTA1 == empty but still both consume + num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size + clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps + ) + + block_idx = cute.arch.block_idx() + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), + block_idx, + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=cutlass_pipeline.PipelineClcFetchAsync.create( + barrier_storage=clc_mbar_ptr, + num_stages=self.sched_stages, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ), + consumer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages + ), + producer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, self.sched_stages + ), + ) + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc=clc) + else: + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + assert isinstance(tile_scheduler, TileSchedulerProtocol), f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}" + # /////////////////////////////////////////////////////////////////////////////// - # EMPTY + # EMPTY / CLC SCHEDULER WARP # /////////////////////////////////////////////////////////////////////////////// - for i in cutlass.range_constexpr(len(self.empty_warp_ids)): - if warp_idx == self.empty_warp_ids[i]: + if const_expr(self.use_clc_scheduler): + if warp_idx == self.clc_scheduler_warp_id: cute.arch.setmaxregister_decrease(self.num_regs_other) + if is_leader_cta: + self.clc_scheduler_warp(tile_scheduler) + else: + self.empty_warp(tile_scheduler) + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i] and warp_idx != self.clc_scheduler_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + self.empty_warp(tile_scheduler) + else: + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i]: + cute.arch.setmaxregister_decrease(self.num_regs_other) # /////////////////////////////////////////////////////////////////////////////// # LOAD @@ -1016,8 +1104,8 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + tile_scheduler=tile_scheduler, ) # /////////////////////////////////////////////////////////////////////////////// @@ -1047,8 +1135,8 @@ def kernel( block_info, 
num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + tile_scheduler=tile_scheduler, ) # Dealloc the tensor memory buffer tmem.relinquish_alloc_permit() @@ -1070,8 +1158,8 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, mma_tile_coord_v, + tile_scheduler=tile_scheduler, ) # /////////////////////////////////////////////////////////////////////////////// @@ -1103,11 +1191,11 @@ def kernel( num_splits=num_splits, SeqlenInfoCls=SeqlenInfoCls, AttentionMaskCls=AttentionMaskCls, - TileSchedulerCls=TileSchedulerCls, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, head_divmod=head_divmod, blocksparse_tensors=blocksparse_tensors, + tile_scheduler=tile_scheduler, ) if const_expr(not self.s0_s1_barrier): @@ -1151,8 +1239,8 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + tile_scheduler=tile_scheduler, ) tmem_alloc_barrier.arrive() @@ -1179,8 +1267,8 @@ def load( block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors], + tile_scheduler: TileSchedulerProtocol, ): num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE tidx = cute.arch.thread_idx()[0] % num_load_threads @@ -1197,7 +1285,6 @@ def load( kv_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.kv_stage ) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -1368,9 +1455,8 @@ def load( self.q_subtile_factor if self.q_subtile_factor is not None else 1, ) - tile_scheduler.prefetch_next_work() - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop if issue_kv_for_this_warp: @@ -1399,8 +1485,8 @@ def mma( block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors], + tile_scheduler=None, ): tSrQ = tiled_mma_qk.make_fragment_A(sQ) tSrK = tiled_mma_qk.make_fragment_B(sK) @@ -1489,7 +1575,6 @@ def mma( ) P_full_O_rescaled_phase = Int32(0) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -1660,8 +1745,7 @@ def mma( # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop # We don't need pipeline_s_p_o.producer_tail() since there's no dangling mbarrier at the end @@ -1690,11 +1774,11 @@ def softmax_loop( num_splits: Int32, SeqlenInfoCls: Callable, AttentionMaskCls: Callable, - TileSchedulerCls: Callable, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), head_divmod=None, blocksparse_tensors: Optional[BlockSparseTensors] = None, + tile_scheduler=None, ): """Compute softmax on attention scores from QK matrix multiplication. 
@@ -1754,7 +1838,6 @@ def softmax_loop( warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -1997,8 +2080,7 @@ def softmax_loop( # gLSE[tidx] = lse # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop # This is equivalent to pipeline_sm_stats.producer_tail @@ -2168,8 +2250,8 @@ def correction_loop( block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors] = None, + tile_scheduler=None, ): tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 @@ -2199,7 +2281,6 @@ def correction_loop( o_corr_consumer_phase = Int32(0) corr_epi_producer_phase = Int32(1) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -2430,8 +2511,7 @@ def correction_loop( cute.make_tensor(lse_gmem_ptr, (1,))[0] = lse # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop # This is equivalent to pipeline_o_epi.consumer_tail() for the correction warps @@ -2638,11 +2718,10 @@ def epilogue_s2g( block_info: BlockInfo, num_splits: int, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, mma_tile_coord_v: Int32 = 0, + tile_scheduler=None, ): epi_consumer_phase = Int32(0) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -2698,8 +2777,39 @@ def epilogue_s2g( epi_consumer_phase ^= 1 # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() + + @cute.jit + def clc_scheduler_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.advance_to_next_work() + if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE: + fa_printf( + 3, + "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", + smid(), + cute.arch.block_idx()[0], + work_tile.tile_idx[0], + work_tile.tile_idx[1], + work_tile.tile_idx[2], + work_tile.tile_idx[3], + work_tile.is_valid_tile, + ) + tile_scheduler.producer_tail() + + @cute.jit + def empty_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_tile = tile_scheduler.advance_to_next_work() def load_Q( self, diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index f922cb26477..e4b55456ffe 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -448,7 +448,9 @@ def _flash_attn_fwd( causal, window_size_left, window_size_right, mask_mod ) - # In fake mode (CPU-only compilation), use a fake stream placeholder. 
+ requested_use_clc_scheduler = utils._get_use_clc_scheduler_default() + requested_disable_2cta = utils._get_disable_2cta_default() + current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) # SM80/SM120: uses SM80 MMA, 128 threads (4 warps) @@ -517,6 +519,7 @@ def _flash_attn_fwd( use_2cta_instrs = ( arch // 10 in [10, 11] + and not requested_disable_2cta and not causal and not local and not is_split_kv @@ -621,6 +624,7 @@ def _flash_attn_fwd( q_subtile_factor, mma_pv_is_rs, intra_wg_overlap, + requested_use_clc_scheduler, fa_logging.get_fa_log_level(), ) if compile_key not in _flash_attn_fwd.compile_cache: @@ -728,6 +732,7 @@ def _flash_attn_fwd( is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, q_subtile_factor=q_subtile_factor, use_2cta_instrs=use_2cta_instrs, + use_clc_scheduler=requested_use_clc_scheduler, ) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity @@ -1056,8 +1061,10 @@ def _flash_attn_bwd( dKV_swapAB = False AtomLayoutMdQ = 1 AtomLayoutNdKV = 1 + requested_disable_2cta = utils._get_disable_2cta_default() disable_2cta = ( - score_mod is not None + requested_disable_2cta + or score_mod is not None or score_mod_bwd is not None or mask_mod is not None ) diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index c7067ae154b..2a1adfc1b42 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, Tri Dao. -from typing import Optional, Tuple +from enum import IntEnum, auto +from typing import Optional, Tuple, Protocol, runtime_checkable from dataclasses import dataclass try: @@ -9,10 +10,12 @@ from typing_extensions import override import cutlass +from cutlass.pipeline import PipelineClcFetchAsync, PipelineState from cutlass._mlir import ir import cutlass.cute as cute from cutlass import Int32, const_expr from cutlass.cute import FastDivmodDivisor +from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams from quack.cute_dsl_utils import ParamsBase @@ -20,6 +23,67 @@ from flash_attn.cute.fast_math import clz +class SchedulingMode(IntEnum): + NONE = auto() + STATIC = auto() + DYNAMIC = auto() + CLC = auto() + + +@dataclass +class ClcState(ParamsBase): + """Owns the runtime state shared by CLC-capable tile schedulers. + + `FlashAttentionForwardSm100` constructs this state because it owns the CLC + response buffer, mbarrier storage, and launch geometry needed to initialize + the hardware scheduler and async pipeline. Individual tile schedulers then + consume this state and map the returned hardware work tiles into their own + logical `WorkTileInfo` coordinates. 
+ + To add CLC support to a scheduler: + - implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler + - accept `clc: ClcState | None` in `create(...)` / `__init__` + - map `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates + """ + + _hw_scheduler: ClcDynamicPersistentTileScheduler + _pipeline: PipelineClcFetchAsync + _consumer_state: PipelineState + _producer_state: PipelineState + + @staticmethod + def create( + *, + hw_scheduler: ClcDynamicPersistentTileScheduler, + pipeline: PipelineClcFetchAsync, + consumer_state: PipelineState, + producer_state: PipelineState, + ) -> "ClcState": + return ClcState(hw_scheduler, pipeline, consumer_state, producer_state) + + def initial_work_tile_info(self): + return self._hw_scheduler.initial_work_tile_info() + + def get_current_work(self): + return self._hw_scheduler.get_current_work() + + def prefetch_next_work(self, *, loc=None, ip=None): + self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip) + mbarrier_addr = self._pipeline.producer_get_barrier(self._producer_state, loc=loc, ip=ip) + self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip) + self._producer_state.advance(loc=loc, ip=ip) + + def consumer_wait(self, *, loc=None, ip=None): + self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip) + + def consumer_release(self, *, loc=None, ip=None): + self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip) + self._consumer_state.advance(loc=loc, ip=ip) + + def producer_tail(self, *, loc=None, ip=None): + self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip) + + class WorkTileInfo(cutlass.utils.WorkTileInfo): """Altered WorkTileInfo which includes four axes: (block, head, batch, split)""" @@ -31,6 +95,47 @@ def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": return WorkTileInfo(new_tile_idx, new_is_valid_tile) +@runtime_checkable +class TileSchedulerProtocol(Protocol): + """Protocol defining the interface all tile schedulers must implement. + + Schedulers are responsible for: + 1. Coordinate mapping: linear tile index -> (m_block, head, batch, split) + 2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic) + """ + + def get_current_work(self) -> WorkTileInfo: + """Get the current work tile coordinates.""" + ... + + def initial_work_tile_info(self) -> WorkTileInfo: + """Get the initial work tile for this CTA.""" + ... + + def advance_to_next_work(self, *, loc=None, ip=None): + """Consumer-side advance: move to next tile and return it. + + For static schedulers: grid-stride increment + get_current_work. + For CLC schedulers: consumer wait + get_current_work + consumer release + state advance. + """ + ... + + def prefetch_next_work(self, *, loc=None, ip=None) -> None: + """Producer-side prefetch of next work tile (no-op for static schedulers). + + For CLC schedulers: producer acquire + issue CLC query + producer state advance. + Only called by the scheduler warp. + """ + ... + + def producer_tail(self, *, loc=None, ip=None) -> None: + """Producer-side cleanup after the last tile. + + No-op for static schedulers. For CLC schedulers: pipeline producer_tail. + """ + ... 
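+
+    # Illustrative sketch only (not part of this change): the two advance_to_next_work
+    # flavours this protocol describes. See SingleTileLPTScheduler below for the real
+    # CLC implementation and StaticPersistentTileScheduler for the static one.
+    #
+    #     def advance_to_next_work(self, *, loc=None, ip=None):
+    #         if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
+    #             self.clc.consumer_wait()        # wait on the response the scheduler warp prefetched
+    #             work = self.get_current_work()  # decode it into (m_block, head, batch, split)
+    #             self.clc.consumer_release()     # free the response slot for the next CLC query
+    #             return work
+    #         self._tile_idx += cute.arch.grid_dim()[0]  # static path: grid-stride increment
+    #         return self.get_current_work()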
+ + @dataclass class TileSchedulerArguments(ParamsBase): num_block: Int32 @@ -89,15 +194,25 @@ def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod - def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler": + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileScheduler": if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx): blk_coord = cute.arch.block_idx() else: - # All CTAs in a cluster must get the same block coordinate blk_coord = cute.arch.cluster_idx() return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) @@ -141,6 +256,10 @@ def prefetch_next_work(self, *, loc=None, ip=None): def advance_to_next_work(self, *, loc=None, ip=None): self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass def __extract_mlir_values__(self): values, self._values_pos = [], [] @@ -186,18 +305,28 @@ def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"StaticPersistentTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod - def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler": + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler": if const_expr(cute.size(params.cluster_shape_m) == 1): tile_idx = cute.arch.block_idx()[0] else: tile_idx = cute.arch.cluster_idx()[0] return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip) - # called by host @staticmethod def get_grid_shape( params: Params, @@ -207,18 +336,14 @@ def get_grid_shape( ) -> Tuple[Int32, Int32, Int32]: hardware_info = cutlass.utils.HardwareInfo() sm_count = hardware_info.get_device_multiprocessor_count() - # Grid must be a multiple of cluster_shape_m for CUDA cluster launch. 
max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * params.cluster_shape_m) return (grid_x, Int32(1), Int32(1)) - # @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod) batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) is_valid = self._tile_idx < self.params.total_blocks_cluster - # if cute.arch.thread_idx()[0] == 0: - # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) return WorkTileInfo( (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid ) @@ -234,6 +359,10 @@ def advance_to_next_work(self, *, loc=None, ip=None): self._tile_idx += cute.arch.grid_dim()[0] else: self._tile_idx += cute.arch.cluster_dim()[0] + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass def __extract_mlir_values__(self): values, self._values_pos = [], [] @@ -260,32 +389,41 @@ class Params(ParamsBase): total_blocks: Int32 num_splits: Int32 num_block: Int32 + num_head: Int32 + num_batch: Int32 l2_minor: Int32 - num_block_divmod: FastDivmodDivisor num_head_divmod: FastDivmodDivisor l2_minor_divmod: FastDivmodDivisor l2_major_divmod: FastDivmodDivisor l2_minor_residual_divmod: FastDivmodDivisor num_hb_quotient: Int32 + num_splits_divmod: FastDivmodDivisor is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + lpt: cutlass.Constexpr[bool] = True @staticmethod @cute.jit def create( - args: TileSchedulerArguments, *, loc=None, ip=None + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, ) -> "SingleTileLPTScheduler.Params": - # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size) + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size size_one_head = size_one_kv_head size_l2 = 50 * 1024 * 1024 # 40 MB for K & V # Swizzle is the size of each "section". Round swizzle to a power of 2 # Need to be careful about the case where only one head will fit # swizzle is how many heads can fit in L2 - # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) - # Seems faster if swizzle if a power of 2 + # Seems faster if swizzle is a power of 2 log2_floor = lambda n: 31 - clz(n) swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) - # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) # If we're in the last section (called residual), we don't want to divide by # swizzle. Instead we want to divide by the remainder. 
num_hb_quotient = (args.num_head * args.num_batch) // swizzle @@ -293,37 +431,84 @@ def create( return SingleTileLPTScheduler.Params( total_blocks=args.num_block * args.num_head * args.num_batch, num_block=args.num_block, + num_head=args.num_head, + num_batch=args.num_batch, l2_minor=Int32(swizzle), - num_block_divmod=FastDivmodDivisor(args.num_block), num_head_divmod=FastDivmodDivisor(args.num_head), l2_minor_divmod=FastDivmodDivisor(swizzle), l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block), - l2_minor_residual_divmod=FastDivmodDivisor( - max(num_hb_remainder, 1) - ), # don't divide by 0 + l2_minor_residual_divmod=FastDivmodDivisor(max(num_hb_remainder, 1)), num_hb_quotient=Int32(num_hb_quotient), num_splits=args.num_splits, + num_splits_divmod=FastDivmodDivisor(args.num_splits), is_split_kv=args.is_split_kv, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + lpt=args.lpt, ) - def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): self.params = params self._tile_idx = tile_idx self._split_idx = split_idx + self.clc = clc self._loc = loc self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: - return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip) + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileLPTScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) + + @staticmethod + def _clc_grid_shape(params: Params): + num_batch_splits = ( + params.num_batch * params.num_splits + if const_expr(params.is_split_kv) + else params.num_batch + ) + return ( + cute.round_up(params.num_block, params.cluster_shape_m), + params.num_head, + num_batch_splits, + ) + + @staticmethod + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileLPTScheduler._clc_grid_shape(params), + cluster_shape_mnk=(params.cluster_shape_m, 1, 1), + ) @staticmethod @cute.jit - def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler": + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler( + params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip + ) tile_idx, split_idx, _ = cute.arch.block_idx() return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) - # called by host @staticmethod def get_grid_shape( params: Params, @@ -331,10 +516,40 @@ def get_grid_shape( loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler._clc_grid_shape(params) return (params.total_blocks, params.num_splits, Int32(1)) + @cute.jit + def clc_work_to_coords(self, work) -> WorkTileInfo: + """Convert CLC response (block, head, batch_split) to WorkTileInfo. + + CLC returns raw grid coordinates — no L2 swizzle (hardware decides order). + We only apply cluster division, optional LPT block reversal, and split_kv unpacking. 
+ """ + block_idx = work.tile_idx[0] + if const_expr(self.params.cluster_shape_m > 1): + block_idx = block_idx // self.params.cluster_shape_m + if const_expr(self.params.lpt): + # Longest-processing-time-first: reverse block order + block_idx = self.params.num_block - 1 - block_idx + split_idx = Int32(0) + if const_expr(self.params.is_split_kv): + batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod) + else: + batch_idx = work.tile_idx[2] + return WorkTileInfo( + (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), + work.is_valid_tile, + ) + @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.get_current_work() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) + # Static path: L2-swizzled coordinate mapping params = self.params # Implement LPT scheduling coordinate calculation bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod) @@ -348,25 +563,44 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: bidhb_actual = bidhb * params.l2_minor + bidhb_residual batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) # Longest-processing-time-first - block = params.num_block - 1 - block + if const_expr(params.lpt): + block = params.num_block - 1 - block is_valid = self._tile_idx < params.total_blocks return WorkTileInfo( (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid ) + @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.initial_work_tile_info() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) return self.get_current_work(loc=loc, ip=ip) def prefetch_next_work(self, *, loc=None, ip=None): - pass + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): - # Single tile scheduler - set to invalid tile_idx to indicate no more work + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work self._tile_idx = self.params.total_blocks + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.params, self._tile_idx, self._split_idx]: + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -374,10 +608,13 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos): + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return self.__class__(*(tuple(obj_list)), loc=self._loc) + return 
self.__class__(*obj_list, loc=self._loc) class SingleTileLPTBwdScheduler: @@ -436,7 +673,16 @@ def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileLPTBwdScheduler only supports STATIC, got {scheduling_mode!r}" + ) return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod @@ -487,6 +733,7 @@ def prefetch_next_work(self, *, loc=None, ip=None): def advance_to_next_work(self, *, loc=None, ip=None): # Single tile scheduler - set to invalid tile_idx to indicate no more work self._tile_idx = self.params.total_blocks + return self.get_current_work() def __extract_mlir_values__(self): values, self._values_pos = [], [] @@ -528,7 +775,9 @@ def create( ) -> "SingleTileVarlenScheduler.Params": size_l2 = 50 * 1024 * 1024 # 50 MB for K & V # if backward, this is qdo block size - kv_block_size = (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + kv_block_size = ( + (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + ) # if backward, add dqaccum block size to calculate swizzle if args.head_swizzle: kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] @@ -562,7 +811,16 @@ def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=Non self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileVarlenScheduler only supports STATIC, got {scheduling_mode!r}" + ) return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod @@ -717,6 +975,7 @@ def prefetch_next_work(self, *, loc=None, ip=None): def advance_to_next_work(self, *, loc=None, ip=None): # Single tile scheduler - set to invalid tile_idx to indicate no more work self._is_first_block = False + return self.get_current_work() def __extract_mlir_values__(self): values, self._values_pos = [], [] diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 2d8767c87f7..31186618569 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -3,12 +3,13 @@ import math import hashlib import inspect +import os from typing import Type, Callable, Optional, Tuple, overload import cutlass import cutlass.cute as cute -from cutlass import Float32, const_expr +from cutlass import Float32, Int32, const_expr from cutlass.cute import FastDivmodDivisor from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import nvvm, llvm @@ -55,6 +56,17 @@ ), } +_fa_clc_enabled: bool = os.environ.get("FA_CLC", "0") == "1" +_fa_disable_2cta_enabled: bool = os.environ.get("FA_DISABLE_2CTA", "0") == "1" + + +def _get_use_clc_scheduler_default() -> bool: + return _fa_clc_enabled + + +def _get_disable_2cta_default() -> bool: + return _fa_disable_2cta_enabled + def _compute_base_hash(func: Callable) -> str: """Compute hash from source code or bytecode and closure values.""" @@ -250,6 +262,21 @@ def warp_reduce( return val +@dsl_user_op +def smid(*, loc=None, ip=None) -> Int32: 
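+    # Read the %smid PTX special register: the physical SM the calling thread is
+    # running on. Used by the CLC scheduler warp's fa_printf debug logging in
+    # flash_fwd_sm100.py.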
+ return Int32( + llvm.inline_asm( + T.i32(), + [], + "mov.u32 $0, %smid;", + "=r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op def fmax( a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None diff --git a/tests/cute/conftest.py b/tests/cute/conftest.py index 86deb53608d..d2162255775 100644 --- a/tests/cute/conftest.py +++ b/tests/cute/conftest.py @@ -59,6 +59,9 @@ def pytest_configure(config): os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[worker_num % len(gpu_ids)] def pytest_collection_finish(session): + if not session.config.option.collectonly: + return + # file_name -> test_name -> counter test_counts: dict[str, dict[str, int]] = {} for item in session.items: diff --git a/tests/cute/test_clc_fuzz.py b/tests/cute/test_clc_fuzz.py new file mode 100644 index 00000000000..b51a6bb6323 --- /dev/null +++ b/tests/cute/test_clc_fuzz.py @@ -0,0 +1,394 @@ +"""Adversarial regression tests for CLC tile scheduling. + +These cases intentionally target scheduler-sensitive shapes: mismatched +sequence lengths, non-aligned tiles, GQA ratios, minimal problems, and +larger persistent workloads. This is deterministic adversarial coverage, +not randomized fuzzing. +""" + +from contextlib import contextmanager +import os +from unittest import mock + +import pytest +import torch + +from flash_attn.cute import utils as cute_utils +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 +from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func +from flash_attn.cute.testing import attention_ref +from flash_attn.cute.tile_scheduler import SchedulingMode, SingleTileLPTScheduler, SingleTileVarlenScheduler + + +if torch.cuda.is_available(): + COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] + SM_COUNT = torch.cuda.get_device_properties("cuda").multi_processor_count +else: + COMPUTE_CAPABILITY = 0 + SM_COUNT = 0 +pytestmark = pytest.mark.skipif( + COMPUTE_CAPABILITY not in (10, 11), + reason="CLC adversarial tests require SM100/SM110 persistent forward", +) + +_captured_schedulers: list[tuple[type, SchedulingMode]] = [] +_orig_init = FlashAttentionForwardSm100.__init__ + + +def _spy_init(self_inner, *a, **kw): + _orig_init(self_inner, *a, **kw) + _captured_schedulers.append((self_inner.TileScheduler, self_inner.scheduling_mode)) + + +@contextmanager +def clc_scheduler_enabled(): + with ( + mock.patch.dict(os.environ, {"FA_CLC": "1"}, clear=False), + mock.patch.object(cute_utils, "_fa_clc_enabled", True), + mock.patch.object(FlashAttentionForwardSm100, "__init__", _spy_init), + ): + yield + + +def check_output(q, k, v, *, causal=False, window_size=(None, None), num_splits=1, assert_clc=True): + _captured_schedulers.clear() + out, _ = flash_attn_func(q, k, v, causal=causal, window_size=window_size, num_splits=num_splits) + torch.cuda.synchronize() + if assert_clc and _captured_schedulers: + sched_cls, sched_mode = _captured_schedulers[-1] + assert sched_cls is SingleTileLPTScheduler, f"Expected SingleTileLPTScheduler, got {sched_cls.__name__}" + assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + out_ref, _ = attention_ref(q, k, v, causal=causal, window_size=window_size) + out_pt, _ = attention_ref(q, k, v, causal=causal, window_size=window_size, upcast=False, reorder_ops=True) + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert (out - out_ref).abs().max().item() <= 2 * ( + out_pt - out_ref + 
).abs().max().item() + fwd_atol, ( + f"max_diff={(out - out_ref).abs().max().item()}, " + f"pt_max_diff={(out_pt - out_ref).abs().max().item()}, " + f"fwd_atol={fwd_atol}, " + f"q={list(q.shape)} k={list(k.shape)} v={list(v.shape)} " + f"causal={causal} window_size={window_size} num_splits={num_splits}" + ) + + +def randn(b, s, h, d): + return torch.randn(b, s, h, d, device="cuda", dtype=torch.bfloat16) + + +def expected_total_tiles_mha(batch, seqlen_q, heads): + q_stage = 2 if COMPUTE_CAPABILITY == 10 and seqlen_q > 128 else 1 + num_block = (seqlen_q + q_stage * 128 - 1) // (q_stage * 128) + return num_block * heads * batch + + +@pytest.fixture(autouse=True) +def seed(): + torch.random.manual_seed(42) + + +@pytest.fixture(autouse=True) +def enable_clc_scheduler(): + with clc_scheduler_enabled(): + yield + + +class TestCLCMismatchedSeqlens: + + @pytest.mark.parametrize("sq,sk", [ + (128, 512), + (128, 1024), + (128, 2048), + (256, 64), + (256, 128), + (512, 127), + (512, 129), + (64, 4096), + (1, 128), + (1, 512), + (1, 1024), + ]) + def test_qk_mismatch(self, sq, sk): + check_output(randn(4, sq, 4, 128), randn(4, sk, 4, 128), randn(4, sk, 4, 128)) + + @pytest.mark.parametrize("sq,sk", [ + (128, 513), + (256, 1023), + (64, 257), + (192, 383), + (1, 255), + ]) + def test_qk_mismatch_nonaligned_k(self, sq, sk): + check_output(randn(4, sq, 4, 128), randn(4, sk, 4, 128), randn(4, sk, 4, 128)) + + @pytest.mark.parametrize("sq,sk", [ + (1, 128), + (1, 256), + (1, 1024), + (2, 128), + (3, 512), + ]) + def test_tiny_q_long_k(self, sq, sk): + check_output(randn(2, sq, 4, 128), randn(2, sk, 4, 128), randn(2, sk, 4, 128)) + + +class TestCLCNonAlignedShapes: + @pytest.mark.parametrize("sq", [1, 3, 7, 15, 31, 33, 63, 65, 127, 129, 191, 193, 255, 257]) + def test_nonaligned_q(self, sq): + check_output(randn(2, sq, 4, 128), randn(2, 256, 4, 128), randn(2, 256, 4, 128)) + + @pytest.mark.parametrize("sk", [1, 7, 31, 33, 63, 65, 127, 129, 255, 257, 511, 513]) + def test_nonaligned_k(self, sk): + check_output(randn(2, 256, 4, 128), randn(2, sk, 4, 128), randn(2, sk, 4, 128)) + + +class TestCLCPrimes: + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (1, 1, 127, 131), + (3, 5, 131, 127), + (7, 3, 257, 251), + (11, 7, 67, 509), + (13, 1, 191, 193), + (5, 11, 61, 67), + (2, 3, 509, 127), + ]) + def test_all_prime(self, batch, heads, sq, sk): + check_output( + randn(batch, sq, heads, 128), + randn(batch, sk, heads, 128), + randn(batch, sk, heads, 128), + ) + + +class TestCLC2CTA: + @pytest.mark.parametrize("sq,sk", [ + (128, 512), + (256, 127), + (256, 129), + (128, 2048), + (1, 512), + (64, 1024), + (512, 64), + ]) + def test_2cta_qk_mismatch(self, sq, sk): + check_output(randn(4, sq, 4, 128), randn(4, sk, 4, 128), randn(4, sk, 4, 128)) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (1, 1, 128, 128), + (1, 1, 256, 512), + (3, 5, 128, 1024), + (7, 3, 512, 127), + (9, 7, 256, 257), + (13, 1, 128, 64), + ]) + def test_2cta_adversarial_combos(self, batch, heads, sq, sk): + check_output( + randn(batch, sq, heads, 128), + randn(batch, sk, heads, 128), + randn(batch, sk, heads, 128), + ) + + +class TestCLCGQA: + @pytest.mark.parametrize("q_heads,kv_heads,sq,sk", [ + (4, 1, 128, 512), + (4, 1, 256, 127), + (8, 1, 64, 1024), + (8, 2, 512, 129), + (8, 4, 1, 256), + (6, 2, 192, 383), + (6, 3, 128, 1), + (12, 4, 257, 511), + ]) + def test_gqa_mismatch(self, q_heads, kv_heads, sq, sk): + check_output( + randn(4, sq, q_heads, 128), + randn(4, sk, kv_heads, 128), + randn(4, sk, kv_heads, 128), + ) + + 
@pytest.mark.parametrize("q_heads,kv_heads", [ + (4, 1), (4, 2), (8, 1), (8, 2), (8, 4), (6, 2), (6, 3), (12, 4), + ]) + def test_gqa_ratios(self, q_heads, kv_heads): + check_output( + randn(4, 512, q_heads, 128), + randn(4, 512, kv_heads, 128), + randn(4, 512, kv_heads, 128), + ) + + +class TestCLCHeadDim: + @pytest.mark.parametrize("d,dv,sq,sk", [ + (64, 64, 128, 512), + (64, 64, 1, 256), + (96, 96, 255, 127), + (128, 64, 192, 384), + (128, 64, 1, 1024), + ]) + def test_head_dims_adversarial(self, d, dv, sq, sk): + check_output(randn(4, sq, 4, d), randn(4, sk, 4, d), randn(4, sk, 4, dv)) + + def test_overlap_sO_sQ_fallback(self): + from flash_attn.cute.tile_scheduler import SingleTileScheduler + + _captured_schedulers.clear() + check_output(randn(4, 128, 4, 192), randn(4, 257, 4, 192), randn(4, 257, 4, 128), assert_clc=False) + assert _captured_schedulers, "No scheduler was captured" + sched_cls, sched_mode = _captured_schedulers[-1] + assert sched_cls is SingleTileScheduler, f"Expected SingleTileScheduler fallback, got {sched_cls.__name__}" + assert sched_mode == SchedulingMode.STATIC, f"Expected STATIC fallback, got {sched_mode!r}" + + +class TestCLCFallback: + + def test_varlen_fallback(self): + _captured_schedulers.clear() + batch, seqlen, heads, d = 4, 256, 4, 128 + lens = torch.tensor([64, 128, 32, 32], dtype=torch.int32) + cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), lens.cumsum(0)]).to(device="cuda", dtype=torch.int32) + total = int(cu_seqlens[-1]) + q = torch.randn(total, heads, d, device="cuda", dtype=torch.bfloat16) + k = torch.randn(total, heads, d, device="cuda", dtype=torch.bfloat16) + v = torch.randn(total, heads, d, device="cuda", dtype=torch.bfloat16) + out, _ = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=int(lens.max()), + max_seqlen_k=int(lens.max()), + ) + torch.cuda.synchronize() + assert _captured_schedulers, "No scheduler was captured" + sched_cls, sched_mode = _captured_schedulers[-1] + assert sched_cls is SingleTileVarlenScheduler, ( + f"Expected SingleTileVarlenScheduler fallback for varlen, got {sched_cls.__name__}" + ) + assert sched_mode == SchedulingMode.STATIC, f"Expected STATIC fallback, got {sched_mode!r}" + + @pytest.mark.parametrize("sq,sk,wl,wr", [ + (512, 512, 128, 128), + (256, 1024, 64, 64), + (512, 512, 255, 0), + (128, 2048, 32, 512), + ]) + def test_local_window_with_clc(self, sq, sk, wl, wr): + check_output( + randn(4, sq, 4, 128), + randn(4, sk, 4, 128), + randn(4, sk, 4, 128), + window_size=(wl, wr), + ) + + +class TestCLCMinimal: + @pytest.mark.parametrize("sq,sk", [(1, 1), (1, 2), (2, 1), (1, 128), (128, 1)]) + def test_minimal(self, sq, sk): + check_output(randn(1, sq, 1, 128), randn(1, sk, 1, 128), randn(1, sk, 1, 128)) + + def test_single_element(self): + check_output(randn(1, 1, 1, 64), randn(1, 1, 1, 64), randn(1, 1, 1, 64)) + + +class TestCLCCausal: + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (3, 5, 259, 259), + (7, 3, 513, 513), + (1, 7, 1023, 1023), + (5, 11, 2049, 2049), + (2, 3, 4097, 4097), + ]) + def test_causal_square(self, batch, heads, sq, sk): + check_output(randn(batch, sq, heads, 128), randn(batch, sk, heads, 128), randn(batch, sk, heads, 128), causal=True) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (3, 7, 127, 513), + (5, 3, 259, 1023), + (7, 5, 63, 2049), + (11, 1, 1, 511), + (2, 9, 1, 1025), + (9, 3, 33, 4097), + ]) + def test_causal_qk_mismatch(self, batch, heads, sq, sk): + check_output(randn(batch, sq, heads, 128), 
randn(batch, sk, heads, 128), randn(batch, sk, heads, 128), causal=True) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (3, 7, 191, 191), + (7, 5, 193, 193), + (5, 3, 383, 383), + (11, 1, 129, 509), + (2, 13, 1, 131), + (9, 3, 67, 251), + ]) + def test_causal_nonaligned(self, batch, heads, sq, sk): + check_output(randn(batch, sq, heads, 128), randn(batch, sk, heads, 128), randn(batch, sk, heads, 128), causal=True) + + @pytest.mark.parametrize("batch,q_heads,kv_heads,sq", [ + (3, 6, 2, 513), + (7, 8, 1, 259), + (5, 12, 4, 1023), + (2, 8, 2, 2049), + (11, 4, 1, 191), + ]) + def test_causal_gqa(self, batch, q_heads, kv_heads, sq): + check_output( + randn(batch, sq, q_heads, 128), + randn(batch, sq, kv_heads, 128), + randn(batch, sq, kv_heads, 128), + causal=True, + ) + + def test_causal_large(self): + check_output(randn(3, 4097, 13, 128), randn(3, 4097, 13, 128), randn(3, 4097, 13, 128), causal=True) + + +class TestCLCLargeScale: + def test_large_batch(self): + check_output(randn(32, 512, 8, 128), randn(32, 512, 8, 128), randn(32, 512, 8, 128)) + + def test_long_seq(self): + check_output(randn(2, 4096, 4, 128), randn(2, 4096, 4, 128), randn(2, 4096, 4, 128)) + + def test_many_heads(self): + check_output(randn(4, 512, 32, 128), randn(4, 512, 32, 128), randn(4, 512, 32, 128)) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (24, 8, 768, 2048), + (16, 8, 1536, 4096), + (12, 8, 2305, 4096), + ]) + def test_work_stealing_pressure(self, batch, heads, sq, sk): + total_tiles = expected_total_tiles_mha(batch, sq, heads) + assert total_tiles > SM_COUNT, f"expected total_tiles={total_tiles} > sm_count={SM_COUNT}" + check_output( + randn(batch, sq, heads, 128), + randn(batch, sk, heads, 128), + randn(batch, sk, heads, 128), + ) + + def test_long_k_short_q(self): + check_output(randn(8, 64, 8, 128), randn(8, 8192, 8, 128), randn(8, 8192, 8, 128)) + + def test_long_q_short_k(self): + check_output(randn(4, 4096, 4, 128), randn(4, 64, 4, 128), randn(4, 64, 4, 128)) + + +class TestCLCRepeatability: + @pytest.mark.parametrize("trial", range(5)) + def test_repeat_mismatch(self, trial): + torch.random.manual_seed(trial) + check_output(randn(7, 192, 5, 128), randn(7, 513, 5, 128), randn(7, 513, 5, 128)) + + @pytest.mark.parametrize("trial", range(5)) + def test_repeat_2cta(self, trial): + torch.random.manual_seed(trial) + check_output(randn(9, 257, 3, 128), randn(9, 511, 3, 128), randn(9, 511, 3, 128)) + + @pytest.mark.parametrize("trial", range(5)) + def test_repeat_gqa_mismatch(self, trial): + torch.random.manual_seed(trial) + check_output(randn(5, 128, 8, 128), randn(5, 1024, 2, 128), randn(5, 1024, 2, 128)) + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
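+
+# Usage note: these tests force the CLC scheduler through the enable_clc_scheduler
+# fixture, but the same paths can be exercised outside pytest via the environment
+# variables read at import time in flash_attn/cute/utils.py, e.g.
+#
+#     FA_CLC=1 python your_script.py                      # opt into the CLC tile scheduler
+#     FA_CLC=1 FA_DISABLE_2CTA=1 python your_script.py    # additionally disable the 2-CTA path
+#
+# (your_script.py is a placeholder; both variables must be set before the process
+# starts because they are cached as module-level booleans.)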