diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 1a868e0a286..cb6bc44eae2 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -22,7 +22,12 @@ # from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler from flash_attn.cute.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler -from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + +try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func +except ImportError: + flash_attn_func = None + flash_attn_varlen_func = None from flash_attn.cute.interface import flash_attn_func as flash_attn_func_python from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python try: diff --git a/flash_attn/cute/fast_math.py b/flash_attn/cute/fast_math.py index 943388fd291..c56ea89e798 100644 --- a/flash_attn/cute/fast_math.py +++ b/flash_attn/cute/fast_math.py @@ -1,12 +1,8 @@ # Copyright (c) 2025, Tri Dao. -from typing import Tuple - import cutlass import cutlass.cute as cute -from cutlass import Int32, Uint32 -from cutlass.cutlass_dsl import T, dsl_user_op -from cutlass._mlir.dialects import llvm +from cutlass import Int32 @cute.jit @@ -23,75 +19,3 @@ def clz(x: Int32) -> Int32: res = Int32(i) done = True return res - - -def find_log2(x: Int32) -> Int32: - a: Int32 = Int32(31 - clz(x)) - return a + ((x & (x - 1)) != 0) # Round up, add 1 if not a power of 2. - - -@dsl_user_op -def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32: - return Uint32( - llvm.inline_asm( - T.i32(), - [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)], - "mul.hi.u32 $0, $1, $2;", - "=r,r,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -class FastDivmod: - def __init__( - self, divisor: Int32, multipler: Uint32, shift_right: Uint32, *, loc=None, ip=None - ): - self.divisor = divisor - self.multiplier = multipler - self.shift_right = shift_right - self._loc = loc - - # called by host - @staticmethod - def create(divisor: Int32, *, loc=None, ip=None) -> "FastDivmod": - """Construct the FastDivmod object, in host code. - This precomputes some values based on the divisor and is computationally expensive. 
- """ - p = Uint32(31 + find_log2(divisor)) - divisor_u32 = Uint32(divisor) - multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32) - shift_right = Uint32(p - 32) - return FastDivmod(divisor, multiplier, shift_right, loc=loc, ip=ip) - - @cute.jit - def div(self, dividend: Int32) -> Int32: - return ( - Int32(umulhi(dividend, self.multiplier) >> self.shift_right) - if self.divisor != 1 - else dividend - ) - - def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]: - quotient = self.div(dividend) - remainder = dividend - quotient * self.divisor - return quotient, remainder - - def __extract_mlir_values__(self): - values, self._values_pos = [], [] - for obj in [self.divisor, self.multiplier, self.shift_right]: - obj_values = cutlass.extract_mlir_values(obj) - values += obj_values - self._values_pos.append(len(obj_values)) - return values - - def __new_from_mlir_values__(self, values): - obj_list = [] - for obj, n_items in zip( - [self.divisor, self.multiplier, self.shift_right], self._values_pos - ): - obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) - values = values[n_items:] - return FastDivmod(*(tuple(obj_list)), loc=self._loc) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index fb0e2e9b778..7fc45666638 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -414,8 +414,7 @@ def __call__( assert mdK_semaphore is not None assert mdV_semaphore is not None mdK_semaphore, mdV_semaphore = [ - utils.select(t, mode=semaphore_transpose) - for t in (mdK_semaphore, mdV_semaphore) + utils.select(t, mode=semaphore_transpose) for t in (mdK_semaphore, mdV_semaphore) ] else: mdK_semaphore = None @@ -562,7 +561,7 @@ def __call__( cute.size(mQ.shape[2]), # num_heads = num_query_heads cute.size(mK.shape[3]), 1, # num_splits - cute.size(mQ.shape[0]), # pass seqlen_q for seqlen_k + cute.size(mQ.shape[0]), # pass seqlen_q for seqlen_k mQ.shape[1], mV.shape[1], total_q=cute.size(mQ.shape[0]), @@ -1905,7 +1904,9 @@ def compute_loop( if const_expr(not self.use_smem_dS_for_mma_dK): cute.arch.fence_view_async_tmem_store() - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) self.compute_sync_barrier.arrive_and_wait() # with cute.arch.elect_one(): @@ -2032,7 +2033,7 @@ def dQacc_reduce( gdQaccum = cute.flat_divide( gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,) ) - + if const_expr(self.deterministic): mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] @@ -2068,12 +2069,17 @@ def dQacc_reduce( if const_expr(self.spt): n_block_max_for_m_block = min( n_block_global_max, - cute.ceil_div((m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q, self.tile_n) + cute.ceil_div( + (m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q, + self.tile_n, + ), ) lock_value = n_block_max_for_m_block - 1 - n_block else: lock_value = n_block - barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value) + barrier.wait_eq( + mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value + ) self.reduce_sync_barrier.arrive_and_wait() # Copy from shared memory to global memory if is_tma_warp: @@ -2101,7 +2107,9 @@ def dQacc_reduce( # semaphore release for prior m_block if const_expr(self.deterministic and stage == 0 and delay_semaphore_release): if m_block > m_block_min: - 
barrier.arrive_inc(mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1) + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1 + ) # semaphore release # NOTE: arrive_inc calls red_release which issues membar diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index e341ac4feee..57874f6559f 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -44,7 +44,7 @@ SingleTileVarlenScheduler, ParamsBase, ) -from flash_attn.cute.fast_math import FastDivmod +from cutlass.cute import FastDivmodDivisor class FlashAttentionForwardBase: @@ -692,8 +692,8 @@ def __call__( self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 ) seqlen_k = cute.size(mK.shape[0]) - seqlen_q_divmod = FastDivmod.create(seqlen_q) - seqlen_k_divmod = FastDivmod.create(seqlen_k) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) self.kernel( @@ -1503,8 +1503,8 @@ def __call__( self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 ) seqlen_k = cute.size(mK.shape[0]) - seqlen_q_divmod = FastDivmod.create(seqlen_q) - seqlen_k_divmod = FastDivmod.create(seqlen_k) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) self.kernel( diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py index b23ab8ba78e..02672e319de 100644 --- a/flash_attn/cute/flash_fwd_combine.py +++ b/flash_attn/cute/flash_fwd_combine.py @@ -14,8 +14,8 @@ from cutlass import Float32, Int32, const_expr from flash_attn.cute import utils -from flash_attn.cute.fast_math import FastDivmod from flash_attn.cute.seqlen_info import SeqlenInfo +from cutlass.cute import FastDivmodDivisor class FlashAttentionForwardCombine: @@ -257,9 +257,9 @@ class SharedStorage: num_head = mO_partial.shape[3] batch_size = mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1) - # Create FastDivmod objects for efficient division - seqlen_divmod = FastDivmod.create(seqlen) - head_divmod = FastDivmod.create(num_head) + # Create FastDivmodDivisor objects for efficient division + seqlen_divmod = FastDivmodDivisor(seqlen) + head_divmod = FastDivmodDivisor(num_head) grid_dim = ( cute.ceil_div(seqlen * num_head, self.m_block_size), @@ -311,8 +311,8 @@ def kernel( gmem_tiled_copy_O: cute.TiledCopy, gmem_tiled_copy_LSE: cute.TiledCopy, s2r_tiled_copy_LSE: cute.TiledCopy, - seqlen_divmod: FastDivmod, - head_divmod: FastDivmod, + seqlen_divmod: FastDivmodDivisor, + head_divmod: FastDivmodDivisor, varlen: cutlass.Constexpr[bool], ): # Thread and block indices @@ -380,9 +380,9 @@ def kernel( mi = tLSEcLSE[0, 0, m][1] # Get m coordinate idx = m_block * self.m_block_size + mi if idx < max_idx: - # Calculate actual sequence position and head using FastDivmod + # Calculate actual sequence position and head using FastDivmodDivisor if const_expr(not varlen): - head_idx, m_idx = seqlen_divmod.divmod(idx) + head_idx, m_idx = divmod(idx, seqlen_divmod) else: head_idx = idx // seqlen m_idx = idx - head_idx * seqlen @@ -420,7 +420,7 @@ def kernel( mi = tOcO[0, m, 0][0] # m coordinate idx = m_block * self.m_block_size + mi if const_expr(not varlen): - tOhidx[m], tOmidx[m] = seqlen_divmod.divmod(idx) + tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod) else: tOhidx[m] = idx // seqlen tOmidx[m] = idx - tOhidx[m] * seqlen @@ -536,7 +536,7 @@ def kernel( idx = m_block * 
self.m_block_size + mi if idx < max_idx: if const_expr(not varlen): - head_idx, m_idx = seqlen_divmod.divmod(idx) + head_idx, m_idx = divmod(idx, seqlen_divmod) else: head_idx = idx // seqlen m_idx = idx - head_idx * seqlen diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 2234d69ca99..645ad97b003 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -45,7 +45,7 @@ from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils -from flash_attn.cute.fast_math import FastDivmod +from cutlass.cute import FastDivmodDivisor from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, @@ -659,8 +659,8 @@ class SharedStorage: self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 ) seqlen_k = cute.size(mK.shape[0]) - seqlen_q_divmod = FastDivmod.create(seqlen_q) - seqlen_k_divmod = FastDivmod.create(seqlen_k) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) @@ -1190,7 +1190,7 @@ def load( mPageTable, mK, mV, - FastDivmod.create(page_size), + FastDivmodDivisor(page_size), batch_idx, head_idx_kv, tidx, @@ -2660,7 +2660,7 @@ def apply_score_mod( if cutlass.const_expr(aux_tensors is not None): seqlen_q_divmod, _ = fastdiv_mods - _, q_idx_logical = seqlen_q_divmod.divmod(q_idx_logical) + _, q_idx_logical = divmod(q_idx_logical, seqlen_q_divmod) apply_score_mod_inner( tSrS_t2r, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index aa3d1bba099..da3ed8fb2d3 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -145,7 +145,7 @@ def apply_mask( global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m row_for_mod = global_row_idx if const_expr(wrap_aux_indices): - _, row_for_mod = fastdiv_mods[0].divmod(global_row_idx) + _, row_for_mod = divmod(global_row_idx, fastdiv_mods[0]) for col in cutlass.range_constexpr(ncol): col_idx_local = t0ScS_mn[0, col][1] @@ -153,7 +153,7 @@ def apply_mask( global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n col_for_mod = global_col_idx if const_expr(wrap_aux_indices): - _, col_for_mod = fastdiv_mods[1].divmod(global_col_idx) + _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1]) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) @@ -357,7 +357,7 @@ def apply_mask_sm100( mask_row = global_row mask_row_for_mod = mask_row if const_expr(wrap_aux_indices): - _, mask_row_for_mod = fastdiv_mods[0].divmod(mask_row) + _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32) ncol = const_expr(cute.size(tScS_t2r.shape)) @@ -366,7 +366,7 @@ def apply_mask_sm100( global_col = col_coord + n_block * self.tile_n global_col_for_mod = global_col if const_expr(wrap_aux_indices): - _, global_col_for_mod = fastdiv_mods[1].divmod(global_col) + _, global_col_for_mod = divmod(global_col, fastdiv_mods[1]) kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32) mask_value = mask_mod( batch_idx_ssa, diff --git a/flash_attn/cute/paged_kv.py b/flash_attn/cute/paged_kv.py index ccb2296b4a7..8b0949d1404 100644 --- a/flash_attn/cute/paged_kv.py +++ b/flash_attn/cute/paged_kv.py @@ -7,8 +7,8 @@ from cutlass import Int32, 
const_expr from flash_attn.cute import utils -from flash_attn.cute.fast_math import FastDivmod from flash_attn.cute.cute_dsl_utils import ParamsBase +from cutlass.cute import FastDivmodDivisor @dataclass @@ -18,7 +18,7 @@ class PagedKVManager(ParamsBase): mV_paged: cute.Tensor thread_idx: Int32 - page_size_divmod: FastDivmod + page_size_divmod: FastDivmodDivisor seqlen_k: Int32 leftpad_k: Int32 n_block_size: Int32 @@ -42,7 +42,7 @@ def create( mPageTable: cute.Tensor, mK_paged: cute.Tensor, mV_paged: cute.Tensor, - page_size_divmod: FastDivmod, + page_size_divmod: FastDivmodDivisor, bidb: Int32, bidh: Int32, thread_idx: Int32, @@ -118,7 +118,7 @@ def load_page_table(self, n_block: Int32): row = (i * self.num_threads + self.thread_idx) // self.gmem_threads_per_row row_idx = n_block * self.n_block_size + row - page_idx, page_offset = self.page_size_divmod.divmod(row_idx + self.leftpad_k) + page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod) is_valid = ( (i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size @@ -173,4 +173,16 @@ def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): ) elif const_expr(K_or_V == "V"): # Don't need to clear out the rest of the smem for K since we'll mask out the scores anyway. - tXsX[None, m, None].fill(0) + fill_swizzled(tXsX[None, m, None], 0) + + +@cutlass.dsl_user_op +def fill_swizzled(tensor, value: cutlass.Numeric, *, loc=None, ip=None) -> None: + """Fill tensor with a constant value. + + Fills all elements of the tensor with the specified value, assuming static size + and supported memory space. + """ + rTmp = cute.make_rmem_tensor_like(tensor, tensor.element_type) + rTmp.fill(value) + cute.autovec_copy(rTmp, tensor) diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml index 1b21df4b227..8b5942b10d0 100644 --- a/flash_attn/cute/pyproject.toml +++ b/flash_attn/cute/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ ] dependencies = [ - "nvidia-cutlass-dsl==4.3.0.dev0", + "nvidia-cutlass-dsl==4.3.0", "torch", "einops", "typing_extensions", diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 0ca08f3f2e3..658934ce753 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -392,12 +392,12 @@ def apply_score_mod_inner( if cutlass.const_expr(constant_q_idx is None): seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods q_idx_floored = floor_if_packed(index_tensor[i + j][0], qhead_per_kvhead) - _, q_idx_wrapped = seqlen_q_divmod.divmod(q_idx_floored) + _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod) q_idx_vec[j] = q_idx_wrapped else: _, seqlen_k_divmod = fastdiv_mods - _, kv_idx_wrapped = seqlen_k_divmod.divmod(index_tensor[i + j][1]) + _, kv_idx_wrapped = divmod(index_tensor[i + j][1], seqlen_k_divmod) kv_idx_vec[j] = kv_idx_wrapped else: # No bounds checking - direct indexing diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index ad6ab099b0a..ef47cedecdf 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -14,7 +14,8 @@ from cutlass import Int32, const_expr import flash_attn.cute.utils as utils -from flash_attn.cute.fast_math import FastDivmod, clz +from flash_attn.cute.fast_math import clz +from cutlass.cute import FastDivmodDivisor class WorkTileInfo(cutlass.utils.WorkTileInfo): @@ -80,7 +81,7 @@ class Params(ParamsBase): num_head: Int32 num_batch: Int32 num_splits: Int32 - num_splits_divmod: FastDivmod + num_splits_divmod: FastDivmodDivisor is_split_kv: 
cutlass.Constexpr[bool] = False cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) @@ -93,7 +94,7 @@ def create( args.num_head, args.num_batch, args.num_splits, - FastDivmod.create(args.num_splits), + FastDivmodDivisor(args.num_splits), args.is_split_kv, args.cluster_shape_mn, ) @@ -133,7 +134,7 @@ def get_grid_shape( def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: block_idx, head_idx, batch_idx = self._blk_coord if const_expr(self.params.is_split_kv): - head_idx, split_idx = self.params.num_splits_divmod.divmod(head_idx) + head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod) else: split_idx = Int32(0) return WorkTileInfo( @@ -169,8 +170,8 @@ def __new_from_mlir_values__(self, values): class StaticPersistentTileScheduler: @dataclass class Params(ParamsBase): - num_block_divmod: FastDivmod - num_head_divmod: FastDivmod + num_block_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor total_blocks: Int32 @staticmethod @@ -179,7 +180,7 @@ def create( ) -> "StaticPersistentTileScheduler.Params": total_blocks = args.num_block * args.num_head * args.num_batch return StaticPersistentTileScheduler.Params( - FastDivmod.create(args.num_block), FastDivmod.create(args.num_head), total_blocks + FastDivmodDivisor(args.num_block), FastDivmodDivisor(args.num_head), total_blocks ) def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): @@ -211,8 +212,8 @@ def get_grid_shape( # @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: - hn_idx, block_idx = self.params.num_block_divmod.divmod(self._tile_idx) - batch_idx, head_idx = self.params.num_head_divmod.divmod(hn_idx) + hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_divmod) + batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) is_valid = self._tile_idx < self.params.total_blocks # if cute.arch.thread_idx()[0] == 0: # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) @@ -253,11 +254,13 @@ class SingleTileLPTScheduler: class Params(ParamsBase): total_blocks: Int32 num_splits: Int32 - num_block_divmod: FastDivmod - num_head_divmod: FastDivmod - l2_minor_divmod: FastDivmod - l2_major_divmod: FastDivmod - l2_minor_residual_divmod: FastDivmod + num_block: Int32 + l2_minor: Int32 + num_block_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor + l2_minor_divmod: FastDivmodDivisor + l2_major_divmod: FastDivmodDivisor + l2_minor_residual_divmod: FastDivmodDivisor num_hb_quotient: Int32 is_split_kv: cutlass.Constexpr[bool] = False @@ -284,11 +287,13 @@ def create( num_hb_remainder = (args.num_head * args.num_batch) % swizzle return SingleTileLPTScheduler.Params( total_blocks=args.num_block * args.num_head * args.num_batch, - num_block_divmod=FastDivmod.create(args.num_block), - num_head_divmod=FastDivmod.create(args.num_head), - l2_minor_divmod=FastDivmod.create(swizzle), - l2_major_divmod=FastDivmod.create(swizzle * args.num_block), - l2_minor_residual_divmod=FastDivmod.create( + num_block=args.num_block, + l2_minor=Int32(swizzle), + num_block_divmod=FastDivmodDivisor(args.num_block), + num_head_divmod=FastDivmodDivisor(args.num_head), + l2_minor_divmod=FastDivmodDivisor(swizzle), + l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block), + l2_minor_residual_divmod=FastDivmodDivisor( max(num_hb_remainder, 1) ), # don't divide by 0 num_hb_quotient=Int32(num_hb_quotient), @@ -327,18 +332,18 @@ def 
get_grid_shape( def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: params = self.params # Implement LPT scheduling coordinate calculation - bidhb, l2_mod = params.l2_major_divmod.divmod(self._tile_idx) + bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod) # If we're in the last section (called residual), we don't want to divide by # swizzle. Instead we want to divide by the remainder. block, bidhb_residual = 0, 0 if bidhb < params.num_hb_quotient: - block, bidhb_residual = params.l2_minor_divmod.divmod(l2_mod) + block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod) else: - block, bidhb_residual = params.l2_minor_residual_divmod.divmod(l2_mod) - bidhb_actual = bidhb * params.l2_minor_divmod.divisor + bidhb_residual - batch_idx, head_idx = params.num_head_divmod.divmod(bidhb_actual) + block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) + bidhb_actual = bidhb * params.l2_minor + bidhb_residual + batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) # Longest-processing-time-first - block = params.num_block_divmod.divisor - 1 - block + block = params.num_block - 1 - block is_valid = self._tile_idx < params.total_blocks return WorkTileInfo( (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid @@ -375,10 +380,11 @@ class SingleTileLPTBwdScheduler: class Params(ParamsBase): total_blocks: Int32 num_block: Int32 - num_head_divmod: FastDivmod - l2_minor_divmod: FastDivmod - l2_major_divmod: FastDivmod - l2_minor_residual_divmod: FastDivmod + l2_minor: Int32 + num_head_divmod: FastDivmodDivisor + l2_minor_divmod: FastDivmodDivisor + l2_major_divmod: FastDivmodDivisor + l2_minor_residual_divmod: FastDivmodDivisor num_hb_quotient: Int32 cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) spt: cutlass.Constexpr[bool] = True @@ -406,10 +412,11 @@ def create( * args.num_head * args.num_batch, num_block=num_block, - num_head_divmod=FastDivmod.create(args.num_head), - l2_minor_divmod=FastDivmod.create(swizzle), - l2_major_divmod=FastDivmod.create(swizzle * num_block), - l2_minor_residual_divmod=FastDivmod.create( + l2_minor=Int32(swizzle), + num_head_divmod=FastDivmodDivisor(args.num_head), + l2_minor_divmod=FastDivmodDivisor(swizzle), + l2_major_divmod=FastDivmodDivisor(swizzle * num_block), + l2_minor_residual_divmod=FastDivmodDivisor( max(num_hb_remainder, 1) ), # don't divide by 0 num_hb_quotient=Int32(num_hb_quotient), @@ -448,16 +455,16 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: cluster_idx = self._tile_idx // self.params.cluster_shape_mn[0] params = self.params # Implement LPT scheduling coordinate calculation - bidhb, l2_mod = params.l2_major_divmod.divmod(cluster_idx) + bidhb, l2_mod = divmod(cluster_idx, params.l2_major_divmod) # If we're in the last section (called residual), we don't want to divide by # swizzle. Instead we want to divide by the remainder. 
block, bidhb_residual = 0, 0 if bidhb < params.num_hb_quotient: - block, bidhb_residual = params.l2_minor_divmod.divmod(l2_mod) + block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod) else: - block, bidhb_residual = params.l2_minor_residual_divmod.divmod(l2_mod) - bidhb_actual = bidhb * params.l2_minor_divmod.divisor + bidhb_residual - batch_idx, head_idx = params.num_head_divmod.divmod(bidhb_actual) + block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) + bidhb_actual = bidhb * params.l2_minor + bidhb_residual + batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) is_valid = self._tile_idx < params.total_blocks bidx_in_cluster = cute.arch.block_in_cluster_idx() block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0] diff --git a/tests/cute/test_flash_attn_varlen.py b/tests/cute/test_flash_attn_varlen.py index 3a514664449..53d907eed94 100644 --- a/tests/cute/test_flash_attn_varlen.py +++ b/tests/cute/test_flash_attn_varlen.py @@ -29,7 +29,7 @@ def test_varlen( ): if min_seq_len > max_seq_len: pytest.skip("Skipping min_seq_len > max_seq_len") - + q, k, v, cu_seqlens_q, cu_seqlens_k, total_q, total_k = generate_varlen_args( batch_size=B, n_heads=H, @@ -40,30 +40,36 @@ def test_varlen( dtype=dtype ) - ok = check_backward_vs_torch_flash( - q, k, v, - cu_seqlens_q, cu_seqlens_k, - total_q=total_q, total_k=total_k, - softmax_scale=softmax_scale, + # SM100 (Blackwell) backward pass doesn't support varlen yet + compute_capability = torch.cuda.get_device_capability()[0] + skip_backward = (compute_capability == 10) + + ok = check_varlen_vs_torch_flash( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + total_q=total_q, total_k=total_k, + softmax_scale=softmax_scale, causal=causal, mha_type=mha_type, + skip_backward=skip_backward, ) assert ok -def check_backward_vs_torch_flash( - q, k, v, - cu_seqlens_q=None, - cu_seqlens_k=None, - seqused_q=None, - seqused_k=None, +def check_varlen_vs_torch_flash( + q, k, v, + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, total_q=None, total_k=None, - softmax_scale=None, + softmax_scale=None, causal=True, mha_type='mha', softcap=0.0, - atol=3e-2, + atol=3e-2, rtol=3e-2, + skip_backward=False, ): assert q.requires_grad and k.requires_grad and v.requires_grad, "Set requires_grad=True on inputs" @@ -103,18 +109,27 @@ def clone_like(t): ) out_t = torch_flash_ref( - q_t, k_t, v_t, - cu_seqlens_q=cu_seqlens_q_t, - cu_seqlens_k=cu_seqlens_k_t, + q_t, k_t, v_t, + cu_seqlens_q=cu_seqlens_q_t, + cu_seqlens_k=cu_seqlens_k_t, seqused_q=seqused_q, seqused_k=seqused_k, total_q=total_q, total_k=total_k, - softmax_scale=softmax_scale, + softmax_scale=softmax_scale, causal=causal, mha_type=mha_type, ) + + ok_fwd = torch.allclose(out_fa.float(), out_t.float(), atol=atol, rtol=rtol) + if not ok_fwd: + return False + + # Skip backward if not supported (e.g., SM100 varlen) + if skip_backward: + return True + # Use the same upstream gradient to compare backward paths grad_out = torch.randn_like(out_fa) @@ -164,7 +179,7 @@ def generate_varlen_args( total_q = cu_seqlens_q[-1] total_k = cu_seqlens_k[-1] - + cu_seqlens_q = cu_seqlens_q.contiguous().to(dtype=torch.int32, device=device) cu_seqlens_k = cu_seqlens_k.contiguous().to(dtype=torch.int32, device=device) @@ -187,15 +202,15 @@ def generate_varlen_args( # Simple for loop over batch dim implementation def torch_flash_ref( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor = None, - cu_seqlens_k: torch.Tensor = None, + q: 
torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor = None, + cu_seqlens_k: torch.Tensor = None, total_q: int = 0, total_k: int = 0, - softmax_scale: Optional[float] = None, - causal: bool = False, + softmax_scale: Optional[float] = None, + causal: bool = False, **kwargs ): @@ -255,7 +270,7 @@ def torch_flash_ref( for b in range(B): if hcseq_q is not None: q_start, q_end = int(hcseq_q[b]), int(hcseq_q[b+1]) - qb = q[q_start:q_end] + qb = q[q_start:q_end] else: qb = q[b] @@ -266,7 +281,7 @@ def torch_flash_ref( else: kb = k[b] vb = v[b] - + qb = qb.permute(1, 0, 2).unsqueeze(0) kb = kb.permute(1, 0, 2).unsqueeze(0) vb = vb.permute(1, 0, 2).unsqueeze(0)
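
Notes on the division refactor: this PR deletes the hand-rolled `FastDivmod` from `flash_attn/cute/fast_math.py` and switches every call site to `cutlass.cute.FastDivmodDivisor` (available now that the `pyproject.toml` pin moves from `nvidia-cutlass-dsl==4.3.0.dev0` to `4.3.0`), with quotient/remainder obtained through Python's `divmod` protocol rather than a `.divmod()` method. Since the divisor object no longer exposes `.divisor`, the LPT schedulers also stash `num_block`/`l2_minor` as plain `Int32` fields. For reference, the removed class implemented the classic Granlund–Montgomery magic-number trick: precompute a multiplier and shift on the host so device code can replace an integer divide with `mul.hi.u32` plus a shift. Below is a pure-Python model of exactly the constants the removed `create()` computed; `FastDivmodModel` and `log2_ceil` are illustrative names, and `>> 32` stands in for the `mul.hi.u32` that the removed `umulhi` wrapper emitted.

```python
def log2_ceil(x: int) -> int:
    # Equivalent to the deleted find_log2: 31 - clz(x), plus 1 when x is not
    # a power of two (i.e. ceil(log2(x))).
    return (x.bit_length() - 1) + (1 if x & (x - 1) else 0)


class FastDivmodModel:
    """Host-side model: precompute (multiplier, shift) so device code can
    replace `x // d` with a multiply-high plus a shift."""

    def __init__(self, divisor: int):
        assert 0 < divisor < 2**31
        p = 31 + log2_ceil(divisor)
        self.divisor = divisor
        self.multiplier = ((1 << p) + divisor - 1) // divisor  # ceil(2^p / d), fits in u32
        self.shift_right = p - 32  # -1 when divisor == 1, hence the special case below

    def __rdivmod__(self, dividend: int):
        # Mirrors the new call style in this PR: divmod(idx, divisor_obj).
        if self.divisor == 1:  # same guard the deleted div() had
            return dividend, 0
        quotient = ((dividend * self.multiplier) >> 32) >> self.shift_right
        return quotient, dividend - quotient * self.divisor


# Sanity check against Python's divmod over 31-bit dividends.
for d in (1, 3, 7, 40, 128, 1000):
    model = FastDivmodModel(d)
    for x in (0, 1, 5, 123456, 2**31 - 1):
        assert divmod(x, model) == divmod(x, d)
```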
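
The other behavioral change in `paged_kv.py` swaps a direct `tXsX[...].fill(0)` for the new `fill_swizzled` helper. The pattern is worth spelling out: register fragments have plain layouts, so filling one and writing it out with `cute.autovec_copy` sidesteps element-wise stores into a swizzled smem view. That motivation is inferred from the helper's name and the DSL upgrade, not stated in the PR. The helper again, annotated (same three calls as the hunk above, nothing new assumed):

```python
import cutlass
import cutlass.cute as cute


@cutlass.dsl_user_op
def fill_swizzled(tensor, value: cutlass.Numeric, *, loc=None, ip=None) -> None:
    # Stage through registers: an rmem fragment shaped like the destination
    # has a plain layout, so .fill() is trivially supported there...
    rTmp = cute.make_rmem_tensor_like(tensor, tensor.element_type)
    rTmp.fill(value)
    # ...and autovec_copy writes it back through the (possibly swizzled)
    # destination layout, vectorizing stores where the layouts allow.
    cute.autovec_copy(rTmp, tensor)
```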
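
On the test side, `check_backward_vs_torch_flash` is renamed to `check_varlen_vs_torch_flash`, gains an explicit forward comparison plus a `skip_backward` escape hatch (set on SM100, where varlen backward is not supported yet), and compares against the per-batch `torch_flash_ref` loop. A condensed, runnable sketch of that reference pattern, assuming packed `(total_tokens, num_heads, head_dim)` inputs and equal Q/KV head counts; whether the test's attention call is SDPA or a manual matmul-softmax is not visible in the hunk, so SDPA is used here for brevity, and GQA head repetition plus the bottom-right causal alignment flash-attn uses when `seqlen_q != seqlen_k` are glossed over:

```python
import torch
import torch.nn.functional as F


def varlen_sdpa_ref(q, k, v, cu_seqlens_q, cu_seqlens_k, causal=True):
    # q, k, v: packed (total_tokens, H, D); cu_seqlens_*: (B + 1,) int32 prefix sums
    outs = []
    for b in range(cu_seqlens_q.numel() - 1):
        qb = q[cu_seqlens_q[b]:cu_seqlens_q[b + 1]]  # (seqlen_q_b, H, D)
        kb = k[cu_seqlens_k[b]:cu_seqlens_k[b + 1]]
        vb = v[cu_seqlens_k[b]:cu_seqlens_k[b + 1]]
        # SDPA expects (batch, heads, seq, dim), hence the same
        # permute/unsqueeze dance torch_flash_ref does above.
        ob = F.scaled_dot_product_attention(
            qb.permute(1, 0, 2).unsqueeze(0),
            kb.permute(1, 0, 2).unsqueeze(0),
            vb.permute(1, 0, 2).unsqueeze(0),
            is_causal=causal,
        )
        outs.append(ob.squeeze(0).permute(1, 0, 2))  # back to (seqlen_q_b, H, D)
    return torch.cat(outs, dim=0)  # (total_q, H, D)
```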