diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 57874f6559f..23fee1e1850 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1050,6 +1050,7 @@ def compute_one_n_block( batch_idx: cutlass.Int32, head_idx: cutlass.Int32, m_block: cutlass.Int32, + seqlen: SeqlenInfoQK, aux_tensors=None, fastdiv_mods=None, mask_fn: Optional[Callable] = None, @@ -1105,6 +1106,7 @@ def load_V_next(): m_block, acc_S, n_block, + seqlen, softmax_scale=softmax.softmax_scale, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, @@ -1502,7 +1504,11 @@ def __call__( seqlen_q = cute.size(mQ.shape[0]) // ( self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 ) - seqlen_k = cute.size(mK.shape[0]) + seqlen_k = ( + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1] + ) seqlen_q_divmod = FastDivmodDivisor(seqlen_q) seqlen_k_divmod = FastDivmodDivisor(seqlen_k) fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) @@ -1982,6 +1988,25 @@ def mma( # shape: (atom_v_m * rest_m) m_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) + + # Recompute fastdiv_mods if necessary for varlen with aux_tensors + recompute_fastdiv_mods_q = cutlass.const_expr( + aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) + ) + recompute_fastdiv_mods_k = cutlass.const_expr( + aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k) + ) + if cutlass.const_expr(fastdiv_mods is not None): + seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods + fastdiv_mods = ( + seqlen_q_divmod + if not recompute_fastdiv_mods_q + else FastDivmodDivisor(seqlen.seqlen_q), + seqlen_k_divmod + if not recompute_fastdiv_mods_k + else FastDivmodDivisor(seqlen.seqlen_k), + ) + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( mask.apply_mask, @@ -2046,6 +2071,7 @@ def mma( if const_expr(self.intra_wg_overlap): kv_consumer_state = process_first_half_block( n_block=n_block_max - 1, + seqlen=seqlen, kv_consumer_state=kv_consumer_state, mask_fn=partial(mask_fn, mask_mod=self.mask_mod), score_mod_fn=score_mod_fn, @@ -2058,6 +2084,7 @@ def mma( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=n_block_max - 1, + seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=True), is_first_n_block=True, mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), @@ -2077,6 +2104,7 @@ def mma( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), ) @@ -2091,6 +2119,7 @@ def mma( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), ) @@ -2102,6 +2131,7 @@ def mma( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), ) @@ -2195,6 +2225,7 @@ def first_half_block_overlap( tOrP: cute.Tensor, smem_copy_params: SimpleNamespace, softmax: Softmax, + seqlen: SeqlenInfoQK, mask_fn: Callable = None, score_mod_fn: Optional[Callable] = None, is_first_block: bool = False, @@ -2207,7 +2238,7 @@ def first_half_block_overlap( # 
Apply score modification if present if const_expr(score_mod_fn is not None): - score_mod_fn(acc_S, n_block=n_block) + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) # Apply mask; mask_seqlen always True for first block # Caveat: if full block further right than mask block, seqlen masking is redundant; @@ -2267,6 +2298,7 @@ def mma_one_n_block( tOrP: cute.Tensor, smem_copy_params: SimpleNamespace, softmax: Softmax, + seqlen: SeqlenInfoQK, score_mod_fn: Optional[Callable] = None, mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, @@ -2281,7 +2313,7 @@ def mma_one_n_block( # handle score mods and masking if const_expr(score_mod_fn is not None): - score_mod_fn(acc_S, n_block=n_block) + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) if const_expr(mask_fn is not None): mask_fn(acc_S=acc_S, n_block=n_block) @@ -2326,6 +2358,7 @@ def mma_one_n_block_intrawg_overlap( tOrP: cute.Tensor, smem_copy_params: SimpleNamespace, softmax: Softmax, + seqlen: SeqlenInfoQK, score_mod_fn: Optional[Callable] = None, mask_fn: Optional[Callable] = None, check_inf: cutlass.Constexpr = True, @@ -2345,7 +2378,7 @@ def mma_one_n_block_intrawg_overlap( # handle score mods and masking if const_expr(score_mod_fn is not None): - score_mod_fn(acc_S, n_block=n_block) + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) if const_expr(mask_fn is not None): mask_fn(acc_S=acc_S, n_block=n_block) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) @@ -2392,6 +2425,7 @@ def apply_score_mod( acc_S, n_block, softmax_scale, + seqlen, aux_tensors: Optional[list] = None, fastdiv_mods=None, ): @@ -2411,6 +2445,7 @@ def apply_score_mod( self.qk_acc_dtype, aux_tensors, fastdiv_mods, + seqlen_info=seqlen, constant_q_idx=None, qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) @@ -2436,4 +2471,5 @@ def warp_scheduler_barrier_arrive(self): cute.arch.barrier_arrive( barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * self.num_threads_per_warp_group, - ) \ No newline at end of file + ) + diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 645ad97b003..aa5a5e30b2d 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -658,7 +658,11 @@ class SharedStorage: seqlen_q = cute.size(mQ.shape[0]) // ( self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 ) - seqlen_k = cute.size(mK.shape[0]) + seqlen_k = ( + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1] + ) seqlen_q_divmod = FastDivmodDivisor(seqlen_q) seqlen_k_divmod = FastDivmodDivisor(seqlen_k) fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) @@ -1624,6 +1628,26 @@ def softmax_loop( head_idx=head_idx, aux_tensors=aux_tensors, ) + + # Recompute fastdiv_mods if necessary + recompute_fastdiv_mods_q = cutlass.const_expr( + aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) + ) + recompute_fastdiv_mods_k = cutlass.const_expr( + aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k) + ) + + if cutlass.const_expr(fastdiv_mods is not None): + seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods + fastdiv_mods = ( + seqlen_q_divmod + if not recompute_fastdiv_mods_q + else FastDivmodDivisor(seqlen.seqlen_q), + seqlen_k_divmod + if not recompute_fastdiv_mods_k + else FastDivmodDivisor(seqlen.seqlen_k), + ) + mask_mod = self.mask_mod if const_expr(self.mask_mod is not None) else None mask_fn = 
partial( mask.apply_mask_sm100, @@ -1874,6 +1898,7 @@ def softmax_step( m_block, n_block, softmax, + seqlen, aux_tensors, fastdiv_mods, ) @@ -2369,7 +2394,7 @@ def correction_epilogue( self.check_hdim_v_oob, self.qhead_per_kvhead, ) - + # load acc O from smem to rmem for wider vectorization tOrO = cute.make_fragment_like(tOsO, self.o_dtype) cute.autovec_copy(tOsO, tOrO) @@ -2637,6 +2662,7 @@ def apply_score_mod( m_block, n_block, softmax, + seqlen: SeqlenInfoQK, aux_tensors=None, fastdiv_mods=(None, None), ): @@ -2673,6 +2699,7 @@ def apply_score_mod( self.qk_acc_dtype, aux_tensors, fastdiv_mods, + seqlen_info=seqlen, constant_q_idx=q_idx_logical, qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, ) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 346cbd82cad..c181f0e281f 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -114,7 +114,7 @@ def _flash_attn_fwd( ... score_mod: A callable that takes the attention scores and applies a modification. mask_mod: A callable that takes token position information and selectively masks - block_sparse_tensors: A tuple of tensors used for block sparsity. + block_sparse_tensors: A tuple of tensors used for block sparsity. return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate out: Optional pre-allocated output tensor. If None, will be allocated internally. lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. @@ -294,6 +294,7 @@ def _flash_attn_fwd( if compute_capability == 9: # TODO: tune block size according to hdim. if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity: n_block_size = 192 + if compute_capability == 10: # TODO: fix the varlen case if ( @@ -335,7 +336,7 @@ def _flash_attn_fwd( elif lse is not None: lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) else: - lse_tensor = None + lse_tensor = None # hash score and mask mods for compile cache score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False @@ -351,11 +352,6 @@ def _flash_attn_fwd( or seqused_q is not None or seqused_k is not None ) - if score_mod is not None: - if is_varlen: - raise NotImplementedError( - "score_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." 
- ) if mask_mod is not None: if is_varlen: @@ -1154,6 +1150,8 @@ def forward( num_splits: int = 1, pack_gqa: Optional[bool] = None, deterministic: bool = False, + score_mod: Optional[Callable] = None, + aux_tensors: Optional[list] = None, ): out, lse = _flash_attn_fwd( q, @@ -1172,6 +1170,8 @@ def forward( softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, + score_mod=score_mod, + aux_tensors=aux_tensors, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.softmax_scale = softmax_scale @@ -1261,6 +1261,8 @@ def flash_attn_varlen_func( num_splits: int = 1, pack_gqa: Optional[bool] = None, deterministic: bool = False, + score_mod: Optional[Callable] = None, + aux_tensors: Optional[list] = None, ): return FlashAttnVarlenFunc.apply( q, @@ -1279,6 +1281,8 @@ def flash_attn_varlen_func( num_splits, pack_gqa, deterministic, + score_mod, + aux_tensors, ) diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index 0851ddd0522..baa38236a78 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -42,6 +42,8 @@ class SeqlenInfoQK: seqlen_k: cutlass.Int32 has_cu_seqlens_q: cutlass.Constexpr[bool] has_cu_seqlens_k: cutlass.Constexpr[bool] + has_seqused_q: cutlass.Constexpr[bool] + has_seqused_k: cutlass.Constexpr[bool] @staticmethod def create( @@ -73,8 +75,17 @@ def create( ) has_cu_seqlens_q: int = mCuSeqlensQ is not None has_cu_seqlens_k: int = mCuSeqlensK is not None + has_seqused_q: int = mSeqUsedQ is not None + has_seqused_k: int = mSeqUsedK is not None return SeqlenInfoQK( - offset_q, offset_k, seqlen_q, seqlen_k, has_cu_seqlens_q, has_cu_seqlens_k + offset_q, + offset_k, + seqlen_q, + seqlen_k, + has_cu_seqlens_q, + has_cu_seqlens_k, + has_seqused_q, + has_seqused_k, ) def offset_batch_Q(self, mQ: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor: diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 658934ce753..e824324355a 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -11,6 +11,7 @@ import flash_attn.cute.utils as utils from flash_attn.cute.cute_dsl_utils import ParamsBase +from flash_attn.cute.seqlen_info import SeqlenInfoQK @dataclass @@ -29,8 +30,8 @@ def create( arch: cutlass.Constexpr[int] = 80, softmax_scale: Float32 | None = None, ): - row_max = cute.make_fragment(num_rows, Float32) - row_sum = cute.make_fragment(num_rows, Float32) + row_max = cute.make_rmem_tensor(num_rows, Float32) + row_sum = cute.make_rmem_tensor(num_rows, Float32) return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale) def reset(self) -> None: @@ -168,8 +169,8 @@ def create( ): num_rows = 1 arch = 100 - row_max = cute.make_fragment(num_rows, Float32) - row_sum = cute.make_fragment(num_rows, Float32) + row_max = cute.make_rmem_tensor(num_rows, Float32) + row_sum = cute.make_rmem_tensor(num_rows, Float32) return SoftmaxSm100( scale_log2, num_rows, @@ -339,6 +340,7 @@ def apply_score_mod_inner( qk_acc_dtype: cutlass.Constexpr, aux_tensors, fastdiv_mods, + seqlen_info: SeqlenInfoQK, constant_q_idx: cutlass.Constexpr, qhead_per_kvhead: cutlass.Constexpr[int] = 1, ): @@ -355,25 +357,26 @@ def apply_score_mod_inner( qk_acc_dtype: Data type for accumulator aux_tensors: Optional aux_tensors for FlexAttention fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping + seqlen_info: Sequence length info constant_q_idx: If provided, use this constant for all q_idx values - If None, compute q_idx per-element + If None, compute q_idx 
per-element qhead_per_kvhead_packgqa: Pack-GQA replication factor. Divide q_idx by this when greater than 1 so score mods see logical heads. """ n_vals = cutlass.const_expr(cute.size(score_tensor.shape)) - score_vec = cute.make_fragment(vec_size, qk_acc_dtype) - kv_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + score_vec = cute.make_rmem_tensor(vec_size, qk_acc_dtype) + kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) # SSA values for batch (constant across all elements) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,)) # Handle q_idx based on whether it's constant - q_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + q_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) # For Pack-GQA with non-constant q_idx, we need per-element head indices # since a thread my process multiple query head indices if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): - head_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + head_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) for i in cutlass.range(0, n_vals, vec_size, unroll_full=True): for j in cutlass.range(vec_size, unroll_full=True): @@ -431,6 +434,7 @@ def apply_score_mod_inner( head_idx_ssa, q_idx=q_idx_ssa, kv_idx=kv_idx_ssa, + seqlen_info=seqlen_info, aux_tensors=aux_args, ) diff --git a/tests/cute/score_mod_definitions.py b/tests/cute/score_mod_definitions.py new file mode 100644 index 00000000000..be6333a6448 --- /dev/null +++ b/tests/cute/score_mod_definitions.py @@ -0,0 +1,591 @@ +import torch +import cutlass +import cutlass.cute as cute +from cutlass._mlir.dialects import math as mlir_math +import operator + +# ============================================================================= +# Score_mod functions that don't use global indices +# All use signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors) +# ============================================================================= + + +@cute.jit +def score_mod_identity(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + return tSrS_ssa + + +@cute.jit +def score_mod_causal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + mask = operator.ge(q_idx, kv_idx) + return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) + + +@cute.jit +def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + diff = q_idx - kv_idx + abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype) + return tSrS_ssa + abs_diff.to(cutlass.Float32) + + +@cute.jit +def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + diff = q_idx - kv_idx + abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype) + scaled = abs_diff * cute.full_like(abs_diff, 2) + return tSrS_ssa + scaled.to(cutlass.Float32) + + +@cute.jit +def score_mod_times_two(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + return tSrS_ssa * cute.full_like(tSrS_ssa, 2) + + +@cute.jit +def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + score = tSrS_ssa.to(cutlass.Float32) + slope_exp = (h_idx + cute.full_like(h_idx, 1)) * cute.full_like(h_idx, -8) + slope = cute.math.exp2( + slope_exp.to(cutlass.Float32) + * cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634) + ) + diff = q_idx - kv_idx + abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype).to(cutlass.Float32) + return score - slope * abs_diff + + 
+@cute.jit +def score_mod_sliding_window(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + diff = q_idx - kv_idx + abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype) + mask = operator.le(abs_diff, cute.full_like(abs_diff, 256)) + return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) + + +@cute.jit +def score_mod_block_diagonal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + q_block = q_idx // 64 + kv_block = kv_idx // 64 + mask = operator.eq(q_block, kv_block) + return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) + + +@cute.jit +def score_mod_causal_v2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + diff = q_idx - kv_idx + mask = operator.ge(diff, cute.full_like(diff, 0)) + return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) + + +@cute.jit +def score_mod_batch_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + batch_bias = aux_tensors[0] + dtype = batch_bias.element_type + b_frag = cute.make_fragment(1, cutlass.Int32) + b_frag.store(b_idx) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = batch_bias[b_frag[0]] + bias_val = (bias_frag.load()).to(cutlass.Float32) + return tSrS_ssa + bias_val + + +@cute.jit +def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + head_bias = aux_tensors[0] + pos_bias = aux_tensors[1] + dtype = head_bias.element_type + + h_frag = cute.make_fragment(1, cutlass.Int32) + h_frag.store(h_idx) + head_val_frag = cute.make_fragment(1, dtype) + head_val_frag[0] = head_bias[h_frag[0]] + head_val = (head_val_frag.load()).to(cutlass.Float32) + + q_frag = cute.make_fragment(1, cutlass.Int32) + q_frag.store(q_idx) + pos_val_frag = cute.make_fragment(1, dtype) + pos_val_frag[0] = pos_bias[q_frag[0]] + pos_val = (pos_val_frag.load()).to(cutlass.Float32) + + return tSrS_ssa + head_val + pos_val + + +# ============================================================================= +# Score_mod functions that use global indices +# All use signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors) +# Global indices computed as: q_idx_global = q_idx + seqlen_info.offset_q (and similarly for kv) +# ============================================================================= + + +@cute.jit +def score_mod_global_kv_bias( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Per-token bias using global kv index.""" + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + token_bias = aux_tensors[0] + dtype = token_bias.element_type + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[kv_frag[0]] + + return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32) + + +@cute.jit +def score_mod_global_q_bias( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Per-token bias using global q index.""" + offset_q = seqlen_info.offset_q + q_idx_global = q_idx + offset_q + token_bias = aux_tensors[0] + dtype = token_bias.element_type + q_frag = cute.make_fragment(1, cutlass.Int32) + q_frag.store(q_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[q_frag[0]] + return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32) + + +@cute.jit +def score_mod_global_rel_plus_kv_bias( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Relative position (logical) + per-token 
bias (global kv).""" + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + token_bias = aux_tensors[0] + dtype = token_bias.element_type + + rel_pos = q_idx - kv_idx + rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype) + rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.1) + + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[kv_frag[0]] + + return tSrS_ssa + rel_bias + (bias_frag.load()).to(cutlass.Float32) + + +@cute.jit +def score_mod_global_q_and_kv_bias( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Both q and kv global indices.""" + offset_q = seqlen_info.offset_q + q_idx_global = q_idx + offset_q + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + q_bias = aux_tensors[0] + kv_bias = aux_tensors[1] + dtype = q_bias.element_type + + q_frag = cute.make_fragment(1, cutlass.Int32) + q_frag.store(q_idx_global) + q_bias_frag = cute.make_fragment(1, dtype) + q_bias_frag[0] = q_bias[q_frag[0]] + + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + kv_bias_frag = cute.make_fragment(1, dtype) + kv_bias_frag[0] = kv_bias[kv_frag[0]] + + return ( + tSrS_ssa + + (q_bias_frag.load()).to(cutlass.Float32) + + (kv_bias_frag.load()).to(cutlass.Float32) + ) + + +@cute.jit +def score_mod_global_logical_rel_plus_kv_bias( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Logical relative + global-indexed per-token bias.""" + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + token_bias = aux_tensors[0] + dtype = token_bias.element_type + + rel_pos = q_idx - kv_idx + rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype) + rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.01) + + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[kv_frag[0]] + + return tSrS_ssa + rel_bias + (bias_frag.load()).to(cutlass.Float32) + + +# "Stress tests" - score_mods with complex global index usage + +@cute.jit +def score_mod_stress_complex_arithmetic( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """All indices in complex arithmetic.""" + offset_q = seqlen_info.offset_q + q_idx_global = q_idx + offset_q + bias = aux_tensors[0] + dtype = bias.element_type + + # Use absolute value instead of squaring to avoid overflow with large sequences + rel_pos = q_idx - kv_idx + rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype) + rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001) + + q_frag = cute.make_fragment(1, cutlass.Int32) + q_frag.store(q_idx_global) + bias_q_frag = cute.make_fragment(1, dtype) + bias_q_frag[0] = bias[q_frag[0]] + bias_q = (bias_q_frag.load()).to(cutlass.Float32) + + scale = (b_idx + cute.full_like(b_idx, 1)) * (h_idx + cute.full_like(h_idx, 1)) + scale_f32 = scale.to(cutlass.Float32) * 0.001 + + result = tSrS_ssa + rel_bias + bias_q * scale_f32 + return result + + +@cute.jit +def score_mod_stress_conditional_mask( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Conditional masking with global vs logical.""" + offset_q = seqlen_info.offset_q + q_idx_global = q_idx + offset_q + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + token_bias = 
aux_tensors[0] + dtype = token_bias.element_type + + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[kv_frag[0]] + bias_val = (bias_frag.load()).to(cutlass.Float32) + + is_causal = operator.ge(q_idx, kv_idx) + + global_diff = q_idx_global - kv_idx_global + is_nearby = operator.le( + cute.TensorSSA(mlir_math.absi(global_diff), global_diff.shape, global_diff.dtype), + cute.full_like(global_diff, 512), + ) + + both_conditions = is_causal & is_nearby + return cute.where(both_conditions, tSrS_ssa + bias_val, cute.full_like(tSrS_ssa, float("-inf"))) + + +@cute.jit +def score_mod_stress_multi_buffer( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Multiple aux tensors with different indexing.""" + offset_q = seqlen_info.offset_q + q_idx_global = q_idx + offset_q + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + batch_bias = aux_tensors[0] + head_scale = aux_tensors[1] + q_pos_bias = aux_tensors[2] + kv_pos_bias = aux_tensors[3] + rel_pos_scale = aux_tensors[4] + + dtype = batch_bias.element_type + + b_frag = cute.make_fragment(1, cutlass.Int32) + b_frag.store(b_idx) + bb_frag = cute.make_fragment(1, dtype) + bb_frag[0] = batch_bias[b_frag[0]] + bb_val = (bb_frag.load()).to(cutlass.Float32) + + h_frag = cute.make_fragment(1, cutlass.Int32) + h_frag.store(h_idx) + hs_frag = cute.make_fragment(1, dtype) + hs_frag[0] = head_scale[h_frag[0]] + hs_val = (hs_frag.load()).to(cutlass.Float32) + + qg_frag = cute.make_fragment(1, cutlass.Int32) + qg_frag.store(q_idx_global) + qpb_frag = cute.make_fragment(1, dtype) + qpb_frag[0] = q_pos_bias[qg_frag[0]] + qpb_val = (qpb_frag.load()).to(cutlass.Float32) + + kvg_frag = cute.make_fragment(1, cutlass.Int32) + kvg_frag.store(kv_idx_global) + kvpb_frag = cute.make_fragment(1, dtype) + kvpb_frag[0] = kv_pos_bias[kvg_frag[0]] + kvpb_val = (kvpb_frag.load()).to(cutlass.Float32) + + rel_idx = q_idx - kv_idx + cute.full_like(q_idx, 512) + rel_idx_clamped = cute.where( + operator.lt(rel_idx, cute.full_like(rel_idx, 0)), cute.full_like(rel_idx, 0), rel_idx + ) + rel_idx_clamped = cute.where( + operator.gt(rel_idx_clamped, cute.full_like(rel_idx_clamped, 1024)), + cute.full_like(rel_idx_clamped, 1024), + rel_idx_clamped, + ) + ri_frag = cute.make_fragment(1, cutlass.Int32) + ri_frag.store(rel_idx_clamped) + rps_frag = cute.make_fragment(1, dtype) + rps_frag[0] = rel_pos_scale[ri_frag[0]] + rps_val = (rps_frag.load()).to(cutlass.Float32) + + return tSrS_ssa * hs_val + bb_val + qpb_val + kvpb_val + rps_val * cute.full_like(tSrS_ssa, 0.1) + + +@cute.jit +def score_mod_stress_global_offset( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Verify global - logical = offset.""" + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + token_bias = aux_tensors[0] + dtype = token_bias.element_type + + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[kv_frag[0]] + + return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32) + + +@cute.jit +def score_mod_stress_xor_pattern( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """XOR-based pattern using index bits.""" + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + token_bias = aux_tensors[0] + dtype = token_bias.element_type + + xor_logical = q_idx ^ kv_idx + pattern_logical = xor_logical & 
cute.full_like(xor_logical, 0xFF) + pattern_bias = pattern_logical.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001) + + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[kv_frag[0]] + + return ( + tSrS_ssa + + pattern_bias + + (bias_frag.load()).to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.1) + ) + + +@cute.jit +def score_mod_debug_global_idx( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + # Don't read from aux_tensors at all - just add the global index as bias + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + bias = kv_idx_global.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001) + return tSrS_ssa + bias + + +# ============================================================================= +# Eager reference functions +# ============================================================================= + + +def identity_eager(score, b, h, q_idx, kv_idx): + return score + + +def causal_eager(score, b, h, q_idx, kv_idx): + return torch.where(q_idx >= kv_idx, score, float("-inf")) + + +def rel_bias_eager(score, b, h, q_idx, kv_idx): + return score + torch.abs(q_idx - kv_idx) + + +def rel_bias_x2_eager(score, b, h, q_idx, kv_idx): + return score + 2 * torch.abs(q_idx - kv_idx) + + +def times_two_eager(score, b, h, q_idx, kv_idx): + return score * 2 + + +def alibi_eager(score, b, h, q_idx, kv_idx): + slope = 2 ** (-8 * (h + 1) / 8) + return score - slope * torch.abs(q_idx - kv_idx) + + +def sliding_window_eager(score, b, h, q_idx, kv_idx): + return torch.where(torch.abs(q_idx - kv_idx) <= 256, score, float("-inf")) + + +def block_diagonal_eager(score, b, h, q_idx, kv_idx): + return torch.where(q_idx // 64 == kv_idx // 64, score, float("-inf")) + + +def causal_v2_eager(score, b, h, q_idx, kv_idx): + return torch.where(q_idx - kv_idx >= 0, score, float("-inf")) + + +def batch_bias_factory(bias_tensor): + def mod(score, b, h, q_idx, kv_idx): + return score + bias_tensor[b] + + return mod + + +def dual_buffer_factory(head_bias, pos_bias): + def mod(score, b, h, q_idx, kv_idx): + return score + head_bias[h] + pos_bias[q_idx] + + return mod + + +def packed_kv_bias_factory(bias_tensor, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + # Calculate valid length for this sequence + start = cu_seqlens_k[b] + seq_len = cu_seqlens_k[b+1] - start + + # Clamp kv_idx. 
+ safe_kv_idx = torch.clamp(kv_idx, max=seq_len - 1) + + return score + bias_tensor[start + safe_kv_idx] + return mod + + +def packed_q_bias_factory(bias_tensor, cu_seqlens_q): + def mod(score, b, h, q_idx, kv_idx): + start = cu_seqlens_q[b] + seq_len = cu_seqlens_q[b+1] - start + + # Clamp q_idx + safe_q_idx = torch.clamp(q_idx, max=seq_len - 1) + + return score + bias_tensor[start + safe_q_idx] + return mod + + +def packed_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + start = cu_seqlens_k[b] + seq_len = cu_seqlens_k[b+1] - start + + # Clamp kv_idx + safe_kv_idx = torch.clamp(kv_idx, max=seq_len - 1) + + rel_bias = torch.abs(q_idx - kv_idx).float() * 0.1 + return score + rel_bias + bias_tensor[start + safe_kv_idx] + + return mod + + +def packed_q_and_kv_bias_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + # Handle Q bounds + q_start = cu_seqlens_q[b] + q_len = cu_seqlens_q[b+1] - q_start + safe_q_idx = torch.clamp(q_idx, max=q_len - 1) + + # Handle KV bounds + kv_start = cu_seqlens_k[b] + kv_len = cu_seqlens_k[b+1] - kv_start + safe_kv_idx = torch.clamp(kv_idx, max=kv_len - 1) + + return score + q_bias[q_start + safe_q_idx] + kv_bias[kv_start + safe_kv_idx] + + return mod + + +def packed_logical_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + rel_bias = torch.abs(q_idx - kv_idx).float() * 0.01 + return score + rel_bias + bias_tensor[cu_seqlens_k[b] + kv_idx] + + return mod + + +def stress_complex_arithmetic_factory(bias, cu_seqlens_q): + def mod(score, b, h, q_idx, kv_idx): + # Use absolute value instead of squaring to avoid overflow with large sequences + rel_pos_abs = torch.abs(q_idx - kv_idx) + q_global = cu_seqlens_q[b] + q_idx + bias_q = bias[q_global] + scale = (b + 1) * (h + 1) * 0.001 + rel_bias = rel_pos_abs * 0.001 + return score + rel_bias + bias_q * scale + + return mod + + +def stress_conditional_mask_factory(token_bias, cu_seqlens_q, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + kv_global = cu_seqlens_k[b] + kv_idx + bias_val = token_bias[kv_global] + is_causal = q_idx >= kv_idx + q_global = cu_seqlens_q[b] + q_idx + global_diff = q_global - kv_global + is_nearby = torch.abs(global_diff) <= 512 + both_conditions = is_causal & is_nearby + return torch.where(both_conditions, score + bias_val, float("-inf")) + + return mod + + +def stress_multi_buffer_factory( + batch_bias, + head_scale, + q_pos_bias, + kv_pos_bias, + rel_pos_scale, + cu_seqlens_q, + cu_seqlens_k, + max_rel_pos=512, +): + def mod(score, b, h, q_idx, kv_idx): + bb_val = batch_bias[b] + hs_val = head_scale[h] + qpb_val = q_pos_bias[cu_seqlens_q[b] + q_idx] + kvpb_val = kv_pos_bias[cu_seqlens_k[b] + kv_idx] + rel_idx = (q_idx - kv_idx + max_rel_pos).clamp(0, max_rel_pos * 2) + rps_val = rel_pos_scale[rel_idx] + return score * hs_val + bb_val + qpb_val + kvpb_val + rps_val * 0.1 + + return mod + + +def stress_global_offset_factory(token_bias, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + return score + token_bias[cu_seqlens_k[b] + kv_idx] + + return mod + + +def stress_xor_pattern_factory(token_bias, cu_seqlens_q, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + xor_logical = q_idx ^ kv_idx + pattern_bias = (xor_logical & 0xFF).float() * 0.001 + kv_global = cu_seqlens_k[b] + kv_idx + return score + pattern_bias + token_bias[kv_global] * 0.1 + + return mod + +def debug_global_idx_factory(bias, cu_seqlens_k): + offsets = cu_seqlens_k.tolist() + def 
mod(score, b, h, q_idx, kv_idx): + global_kv = offsets[b] + kv_idx + return score + global_kv.float() * 0.001 + return mod diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 98a752a3a35..83d2b9d3bf5 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -274,6 +274,7 @@ def test_flash_attn_output( and dv == d and learnable_sink is None # and False + and not ((causal or local) and seqlen_k < seqlen_q) ): g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py index 147e5519394..d5577ceaec8 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -6,218 +6,34 @@ import operator from torch.nn.attention.flex_attention import flex_attention from flash_attn.cute.interface import _flash_attn_fwd - - -@cute.jit -def score_mod_1(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): - tmp0 = tSrS_ssa - tSrS_ssa = tmp0 - return tSrS_ssa - - -@cute.jit -def score_mod_2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): - tmp0 = q_idx - tmp1 = kv_idx - tmp2 = operator.ge(tmp0, tmp1) - tmp3 = tSrS_ssa - tmp4 = cute.where(tmp2, tmp3, cute.full_like(tmp3, float("-inf"))) - tSrS_ssa = tmp4 - return tSrS_ssa - - -@cute.jit -def score_mod_3(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): - tmp0 = tSrS_ssa - tmp1 = q_idx - tmp2 = kv_idx - tmp3 = tmp1 - tmp2 - tmp4 = cute.TensorSSA(mlir_math.absi(tmp3), tmp3.shape, tmp3.dtype) - tmp5 = tmp4.to(cutlass.Float32) - tmp6 = tmp0 + tmp5 - tSrS_ssa = tmp6 - return tSrS_ssa - - -@cute.jit -def score_mod_4(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): - tmp0 = tSrS_ssa - tmp1 = q_idx - tmp2 = kv_idx - tmp3 = tmp1 - tmp2 - tmp4 = cute.TensorSSA(mlir_math.absi(tmp3), tmp3.shape, tmp3.dtype) - tmp5 = tmp4 * cute.full_like(tmp4, 2) - tmp6 = tmp5.to(cutlass.Float32) - tmp7 = tmp0 + tmp6 - tSrS_ssa = tmp7 - return tSrS_ssa - - -@cute.jit -def score_mod_5(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): - tmp0 = tSrS_ssa - tmp1 = tmp0 * cute.full_like(tmp0, 2) - tSrS_ssa = tmp1 - return tSrS_ssa - - -@cute.jit -def score_mod_6(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): - tmp0 = tSrS_ssa - tmp1 = tmp0.to(cutlass.Float32) - tmp2 = h_idx - tmp3 = tmp2 + cute.full_like(tmp2, 1) - tmp4 = tmp3 * cute.full_like(tmp3, -8) - tmp5 = tmp4.to(cutlass.Float32) - tmp6 = tmp5 * cute.full_like(tmp5, 0.125) - tmp7 = tmp6 * cute.full_like(tmp6, 0.6931471805599453) - tmp8 = cute.math.exp2(tmp7 * 1.4426950408889634) - tmp9 = q_idx - tmp10 = kv_idx - tmp11 = tmp9 - tmp10 - tmp12 = cute.TensorSSA(mlir_math.absi(tmp11), tmp11.shape, tmp11.dtype) - tmp13 = tmp12.to(cutlass.Float32) - tmp14 = tmp8 * tmp13 - tmp15 = tmp1 - tmp14 - tSrS_ssa = tmp15 - return tSrS_ssa - - -@cute.jit -def score_mod_7(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): - tmp0 = q_idx - tmp1 = kv_idx - tmp2 = tmp0 - tmp1 - tmp3 = cute.TensorSSA(mlir_math.absi(tmp2), tmp2.shape, tmp2.dtype) - tmp4 = operator.le(tmp3, cute.full_like(tmp3, 256)) - tmp5 = tSrS_ssa - tmp6 = cute.where(tmp4, tmp5, cute.full_like(tmp5, float("-inf"))) - tSrS_ssa = tmp6 - return tSrS_ssa - - -@cute.jit -def score_mod_8(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): - tmp0 = q_idx - tmp1 = kv_idx - tmp2 = tSrS_ssa - tmp3 = cute.where( - operator.eq(tmp0 // 64, tmp1 // 64), tmp2, cute.full_like(tmp2, float("-inf")) - ) - tSrS_ssa = tmp3 - return tSrS_ssa - - -@cute.jit -def score_mod_9(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, 
aux_tensors): - tmp0 = q_idx - tmp1 = kv_idx - tmp2 = tmp0 - tmp1 - tmp3 = operator.ge(tmp2, cute.full_like(tmp2, 0)) - tmp4 = tSrS_ssa - tmp5 = cute.where(tmp3, tmp4, cute.full_like(tmp4, float("-inf"))) - tSrS_ssa = tmp5 - return tSrS_ssa - - -@cute.jit -def score_mod_10(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): - batch_bias = aux_tensors[0] - - # Detect dtype from buffer element type - dtype = batch_bias.element_type - - b_frag = cute.make_fragment(1, cutlass.Int32) - b_frag.store(b_idx) - bias_frag = cute.make_fragment(1, dtype) - bias_frag[0] = batch_bias[b_frag[0]] - bias_val = (bias_frag.load()).to(cutlass.Float32) - - return tSrS_ssa + bias_val - - -@cute.jit -def score_mod_11(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): - head_bias = aux_tensors[0] - pos_bias = aux_tensors[1] - - # Detect dtype from buffer element type - dtype = head_bias.element_type - - h_frag = cute.make_fragment(1, cutlass.Int32) - h_frag.store(h_idx) - head_val_frag = cute.make_fragment(1, dtype) - head_val_frag[0] = head_bias[h_frag[0]] - head_val = (head_val_frag.load()).to(cutlass.Float32) - - q_frag = cute.make_fragment(1, cutlass.Int32) - q_frag.store(q_idx) - pos_val_frag = cute.make_fragment(1, dtype) - pos_val_frag[0] = pos_bias[q_frag[0]] - pos_val = (pos_val_frag.load()).to(cutlass.Float32) - - return tSrS_ssa + head_val + pos_val - - -# Eager reference functions for comparison -def identity_eager(score, b, h, q_idx, kv_idx): - return score - - -def causal_mask_eager(score, b, h, q_idx, kv_idx): - return torch.where(q_idx >= kv_idx, score, float("-inf")) - - -def relative_bias_eager(score, b, h, q_idx, kv_idx): - return score + torch.abs(q_idx - kv_idx) - - -def relative_bias_v2_eager(score, b, h, q_idx, kv_idx): - return score + 2 * torch.abs(q_idx - kv_idx) - - -def times_two_eager(score, b, h, q_idx, kv_idx): - return score * 2 - - -def alibi_bias_eager(score, b, h, q_idx, kv_idx): - slope = 2 ** (-8 * (h + 1) / 8) - return score - slope * torch.abs(q_idx - kv_idx) - - -def sliding_window_eager(score, b, h, q_idx, kv_idx): - return torch.where(torch.abs(q_idx - kv_idx) <= 256, score, float("-inf")) - - -def block_diagonal_eager(score, b, h, q_idx, kv_idx): - q_block = q_idx // 64 - kv_block = kv_idx // 64 - return torch.where(q_block == kv_block, score, float("-inf")) - - -def causal_mask_v2_eager(score, b, h, q_idx, kv_idx): - return torch.where(q_idx - kv_idx >= 0, score, float("-inf")) - - -def batch_bias(bias_tensor): - """Per-batch bias (tests batch indexing).""" - - def batch_bias_mod(score, b, h, q_idx, kv_idx): - return score + bias_tensor[b] - - return batch_bias_mod - - -def dual_buffer_bias(head_bias, pos_scale): - """Dual buffer loading (tests loading from 2 separate tensors).""" - - def dual_buffer_mod(score, b, h, q_idx, kv_idx): - head_component = head_bias[h] - pos_component = pos_scale[q_idx] - return score + pos_component + head_component - - return dual_buffer_mod - +from score_mod_definitions import ( + # TensorSSA-based score mods + score_mod_identity as score_mod_1, + score_mod_causal as score_mod_2, + score_mod_rel_bias as score_mod_3, + score_mod_rel_bias_x2 as score_mod_4, + score_mod_times_two as score_mod_5, + score_mod_alibi as score_mod_6, + score_mod_sliding_window as score_mod_7, + score_mod_block_diagonal as score_mod_8, + score_mod_causal_v2 as score_mod_9, + score_mod_batch_bias as score_mod_10, + score_mod_dual_buffer as score_mod_11, +) # isort: split +from score_mod_definitions import ( + # Eager (torch) reference score mods + 
identity_eager, + causal_eager as causal_mask_eager, + rel_bias_eager as relative_bias_eager, + rel_bias_x2_eager as relative_bias_v2_eager, + times_two_eager, + alibi_eager as alibi_bias_eager, + sliding_window_eager, + block_diagonal_eager, + causal_v2_eager as causal_mask_v2_eager, + batch_bias_factory as batch_bias, + dual_buffer_factory as dual_buffer_bias, +) # Test pairs: (cute_jit_function, eager_reference_function) TEST_PAIRS = [ @@ -238,6 +54,29 @@ def dual_buffer_mod(score, b, h, q_idx, kv_idx): (score_mod_11, dual_buffer_bias), ] +SEQLEN_CONFIGS = [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), +] + def create_tensors( batch_size=2, num_heads=4, seqlen_q=64, seqlen_kv=64, dim=128, dtype=torch.bfloat16 @@ -277,31 +116,7 @@ def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: ) -@pytest.mark.parametrize( - "seqlen_q,seqlen_kv", - [ - (1, 1), - (64, 128), - (128, 192), - (256, 256), - (239, 1), - (799, 3), - (113, 203), - (113, 128), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (384, 256), - (640, 128), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (4096, 4096), - (4224, 4224), - ], -) +@pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 2), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) @@ -354,31 +169,7 @@ def test_cute_vs_flex_attention( ) -@pytest.mark.parametrize( - "seqlen_q,seqlen_kv", - [ - (1, 1), - (64, 128), - (128, 192), - (256, 256), - (239, 1), - (799, 3), - (113, 203), - (113, 128), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (384, 256), - (640, 128), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (4096, 4096), - (4224, 4224), - ], -) +@pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS) @@ -451,48 +242,359 @@ def test_cute_vs_flex_attention_with_aux_tensors( ) -@pytest.mark.xfail( - raises=NotImplementedError, reason="Varlen with score_mod not yet supported" +def _generate_block_kvcache( + seqlen_k, page_size, batch_size, nheads_k, d, device, dtype +): + import math + from einops import rearrange + + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype + ) + v_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype + ) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache_bshd = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache_bshd = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + k_cache = k_cache_bshd.transpose(1, 2) + v_cache = v_cache_bshd.transpose(1, 2) + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("page_size", [None, 1, 4, 128]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 2), (4, 2)]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_kv", + [ + (1, 128), + (64, 256), + (64, 800), + (256, 256), + (113, 203), + ], ) -def test_varlen_with_score_mod(): - """Test that varlen (variable length sequences) works with score_mod. +@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) +def test_score_mod_with_paged_kvcache( + seqlen_q, + seqlen_kv, + qhead_per_kvhead, + num_kv_heads, + page_size, + dtype, + score_mod_pair, +): + if page_size is not None and seqlen_kv % page_size != 0: + pytest.skip() - For varlen, tokens from different sequences should not attend to each other. - Without proper index mapping, the causal mask will be applied to the global - indices instead of per-sequence logical indices. - """ torch.random.manual_seed(42) + cute_score_mod, eager_score_mod = score_mod_pair + + batch_size = 2 + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + dim = 128 + device = "cuda" + + q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype) + + if page_size is None: + k_cache = torch.randn( + batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype + ) + v_cache = torch.randn( + batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype + ) + page_table = None + k_cache_paged = None + v_cache_paged = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype + ) + + cache_seqlens = torch.randint( + 1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device + ) + + from einops import rearrange + + arange = rearrange(torch.arange(seqlen_kv, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + key_padding_mask = arange < cache_seqlens_expanded + + if pack_gqa: + k_cache_rep = k_cache.repeat_interleave(qhead_per_kvhead, dim=1) + v_cache_rep = v_cache.repeat_interleave(qhead_per_kvhead, dim=1) + else: + k_cache_rep = k_cache + v_cache_rep = v_cache - seqlens = [64, 56, 128] - total_seq = sum(seqlens) - num_heads = 4 - dtype = torch.bfloat16 + def make_masked_score_mod(base_score_mod, seqused_k_tensor): + seqused_k_dev = seqused_k_tensor - cu_seqlens = torch.tensor( - [0] + list(torch.tensor(seqlens).cumsum(0).tolist()), - device="cuda", - dtype=torch.int32, + def masked_score_mod(score, b, h, q_idx, kv_idx): + if base_score_mod is not None: + score = base_score_mod(score, b, h, q_idx, kv_idx) + seqlen_limit = torch.gather(seqused_k_dev, 0, b.long()) + valid_mask = kv_idx < seqlen_limit + return torch.where(valid_mask, score, torch.full_like(score, float("-inf"))) + + return masked_score_mod + + masked_score_mod_fp32 = make_masked_score_mod(eager_score_mod, cache_seqlens) + masked_score_mod = make_masked_score_mod(eager_score_mod, cache_seqlens) + + out_ref_fp32 = run_flex_reference( + q, k_cache_rep, v_cache_rep, masked_score_mod_fp32, dtype=torch.float32 ) - q = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) - k = torch.randn(total_seq, num_heads, 128, device="cuda", 
dtype=dtype) - v = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) + out_pt = run_flex_reference(q, k_cache_rep, v_cache_rep, masked_score_mod) + + q_bshd = q.transpose(1, 2) + out_cute = torch.empty_like(q_bshd) + + if page_size is None: + k_bshd = k_cache.transpose(1, 2) + v_bshd = v_cache.transpose(1, 2) + _flash_attn_fwd( + q_bshd, + k_bshd, + v_bshd, + seqused_k=cache_seqlens, + return_lse=True, + score_mod=cute_score_mod, + out=out_cute, + lse=None, + pack_gqa=pack_gqa, + ) + else: + _flash_attn_fwd( + q_bshd, + k_cache_paged, + v_cache_paged, + seqused_k=cache_seqlens, + page_table=page_table, + return_lse=True, + score_mod=cute_score_mod, + out=out_cute, + lse=None, + pack_gqa=pack_gqa, + ) + + out_cute = out_cute.transpose(1, 2) - out_cute = torch.empty_like(q) + assert out_cute.shape == out_ref_fp32.shape == out_pt.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert not torch.isnan(out_pt).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + assert torch.isfinite(out_pt).all() - _flash_attn_fwd( - q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - return_lse=True, - score_mod=score_mod_2, - out=out_cute, - lse=None, + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + print( + f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):" + ) + print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") + print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") + print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") + print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}") + + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("page_size", [None, 128]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_kv", + [ + (64, 128), + (128, 256), + (256, 256), + ], +) +@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS) +def test_score_mod_with_paged_kvcache_aux_tensors( + seqlen_q, + seqlen_kv, + qhead_per_kvhead, + num_kv_heads, + page_size, + dtype, + score_mod_pair, +): + if page_size is not None and seqlen_kv % page_size != 0: + pytest.skip() + + torch.random.manual_seed(42) + cute_score_mod, eager_score_mod_factory = score_mod_pair + + batch_size = 2 + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + dim = 128 + device = "cuda" + + q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype) + + if page_size is None: + k_cache = torch.randn( + batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype + ) + v_cache = torch.randn( + batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype + ) + page_table = None + k_cache_paged = None + v_cache_paged = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype + ) + + cache_seqlens = torch.randint( + 1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device + ) + + if cute_score_mod == score_mod_10: + 
buffer = torch.randn(batch_size, device=device, dtype=dtype) * 0.1 + aux_tensors = [buffer] + eager_score_mod = eager_score_mod_factory(buffer) + elif cute_score_mod == score_mod_11: + head_bias = torch.randn(num_q_heads, device=device, dtype=dtype) * 0.2 + pos_scale = torch.arange(seqlen_q, device=device, dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_scale] + eager_score_mod = eager_score_mod_factory(head_bias, pos_scale) + + from einops import rearrange + + arange = rearrange(torch.arange(seqlen_kv, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + key_padding_mask = arange < cache_seqlens_expanded + + if pack_gqa: + k_cache_rep = k_cache.repeat_interleave(qhead_per_kvhead, dim=1) + v_cache_rep = v_cache.repeat_interleave(qhead_per_kvhead, dim=1) + else: + k_cache_rep = k_cache + v_cache_rep = v_cache + + def make_masked_score_mod(base_score_mod, seqused_k_tensor): + seqused_k_dev = seqused_k_tensor + + def masked_score_mod(score, b, h, q_idx, kv_idx): + if base_score_mod is not None: + score = base_score_mod(score, b, h, q_idx, kv_idx) + seqlen_limit = torch.gather(seqused_k_dev, 0, b.long()) + valid_mask = kv_idx < seqlen_limit + return torch.where(valid_mask, score, torch.full_like(score, float("-inf"))) + + return masked_score_mod + + masked_score_mod_fp32 = make_masked_score_mod(eager_score_mod, cache_seqlens) + masked_score_mod = make_masked_score_mod(eager_score_mod, cache_seqlens) + + out_ref_fp32 = run_flex_reference( + q, k_cache_rep, v_cache_rep, masked_score_mod_fp32, dtype=torch.float32 + ) + out_pt = run_flex_reference(q, k_cache_rep, v_cache_rep, masked_score_mod) + + q_bshd = q.transpose(1, 2) + out_cute = torch.empty_like(q_bshd) + + if page_size is None: + k_bshd = k_cache.transpose(1, 2) + v_bshd = v_cache.transpose(1, 2) + _flash_attn_fwd( + q_bshd, + k_bshd, + v_bshd, + seqused_k=cache_seqlens, + return_lse=True, + score_mod=cute_score_mod, + out=out_cute, + lse=None, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + else: + _flash_attn_fwd( + q_bshd, + k_cache_paged, + v_cache_paged, + seqused_k=cache_seqlens, + page_table=page_table, + return_lse=True, + score_mod=cute_score_mod, + out=out_cute, + lse=None, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + + out_cute = out_cute.transpose(1, 2) + + assert out_cute.shape == out_ref_fp32.shape == out_pt.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert not torch.isnan(out_pt).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + assert torch.isfinite(out_pt).all() + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + print( + f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):" ) + print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") + print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") + print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") + print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}") - assert not torch.isnan(out_cute).any(), "Output contains NaN values" - assert torch.isfinite(out_cute).all(), "Output contains infinite values" + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) if __name__ == "__main__": diff --git 
a/tests/cute/test_score_mod_varlen.py b/tests/cute/test_score_mod_varlen.py new file mode 100644 index 00000000000..3f339e548c5 --- /dev/null +++ b/tests/cute/test_score_mod_varlen.py @@ -0,0 +1,1048 @@ +import pytest +import torch +from torch.nn.attention.flex_attention import flex_attention +from flash_attn.cute.interface import _flash_attn_fwd +from test_score_mod import _generate_block_kvcache +from score_mod_definitions import ( + # TensorSSA-based score mods + score_mod_alibi, + score_mod_batch_bias, + score_mod_block_diagonal, + score_mod_causal, + score_mod_causal_v2, + score_mod_debug_global_idx, + score_mod_dual_buffer, + score_mod_global_kv_bias, + score_mod_global_logical_rel_plus_kv_bias, + score_mod_global_q_and_kv_bias, + score_mod_global_q_bias, + score_mod_global_rel_plus_kv_bias, + score_mod_identity, + score_mod_rel_bias, + score_mod_rel_bias_x2, + score_mod_sliding_window, + score_mod_stress_complex_arithmetic, + score_mod_stress_conditional_mask, + score_mod_stress_global_offset, + score_mod_stress_multi_buffer, + score_mod_stress_xor_pattern, + score_mod_times_two, +) # isort: split +from score_mod_definitions import ( + # Eager (torch) reference score mods + identity_eager, + causal_eager, + rel_bias_eager, + rel_bias_x2_eager, + times_two_eager, + alibi_eager, + sliding_window_eager, + block_diagonal_eager, + causal_v2_eager, + batch_bias_factory, + dual_buffer_factory, + packed_kv_bias_factory, + packed_q_bias_factory, + packed_rel_plus_kv_bias_factory, + packed_q_and_kv_bias_factory, + packed_logical_rel_plus_kv_bias_factory, + stress_complex_arithmetic_factory, + stress_conditional_mask_factory, + stress_multi_buffer_factory, + stress_global_offset_factory, + stress_xor_pattern_factory, + debug_global_idx_factory, +) + +# ============================================================================= +# Test pairs +# ============================================================================= + +# (cute_score_mod, eager_factory_or_fn, aux_type) +# aux_type: None, "batch", "dual_buffer" +# All score_mods use 7-arg signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors) +TEST_PAIRS_NO_GLOBAL = [ + (score_mod_identity, identity_eager, None), + (score_mod_causal, causal_eager, None), + (score_mod_rel_bias, rel_bias_eager, None), + (score_mod_rel_bias_x2, rel_bias_x2_eager, None), + (score_mod_times_two, times_two_eager, None), + (score_mod_alibi, alibi_eager, None), + (score_mod_sliding_window, sliding_window_eager, None), + (score_mod_block_diagonal, block_diagonal_eager, None), + (score_mod_causal_v2, causal_v2_eager, None), + (score_mod_batch_bias, batch_bias_factory, "batch"), + (score_mod_dual_buffer, dual_buffer_factory, "dual_buffer"), +] + +# (cute_score_mod, eager_factory, aux_type, requires_global) +# aux_type: "kv", "q", "q_and_kv", "q_concat", "kv_with_cu", "multi_buffer" +# requires_global: "q" (needs varlen_q), "kv" (needs varlen_k), "both" (needs both) +# All score_mods use 7-arg signature and compute global indices from seqlen_info +TEST_PAIRS_WITH_GLOBAL = [ + (score_mod_global_kv_bias, packed_kv_bias_factory, "kv", "kv"), + (score_mod_global_q_bias, packed_q_bias_factory, "q", "q"), + (score_mod_global_rel_plus_kv_bias, packed_rel_plus_kv_bias_factory, "kv", "kv"), + (score_mod_global_q_and_kv_bias, packed_q_and_kv_bias_factory, "q_and_kv", "both"), + ( + score_mod_global_logical_rel_plus_kv_bias, + packed_logical_rel_plus_kv_bias_factory, + "kv", + "kv", + ), + ( + score_mod_stress_complex_arithmetic, + 
stress_complex_arithmetic_factory, + "q_concat", + "q", + ), + ( + score_mod_stress_conditional_mask, + stress_conditional_mask_factory, + "kv_with_cu", + "both", + ), + ( + score_mod_stress_multi_buffer, + stress_multi_buffer_factory, + "multi_buffer", + "both", + ), + (score_mod_stress_global_offset, stress_global_offset_factory, "kv", "kv"), + (score_mod_stress_xor_pattern, stress_xor_pattern_factory, "kv_with_cu", "kv"), + (score_mod_debug_global_idx, debug_global_idx_factory, "kv", "kv"), +] + +SEQLEN_CONFIGS = [ + ([1], [1]), + ([1, 1], [1, 1]), + ([2, 3], [2, 3]), + ([8, 16], [8, 16]), + ([32, 32], [32, 32]), + ([64, 128], [64, 128]), + ([64, 56, 128], [64, 56, 128]), + ([256, 512], [256, 512]), + ([113, 203], [113, 203]), + ([239, 1], [239, 1]), + ([64], [64]), + ([128, 128], [128, 128]), + ([32, 32, 32, 32], [32, 32, 32, 32]), + ([16, 32, 64, 128, 256], [16, 32, 64, 128, 256]), + ([1, 1024], [1, 1024]), + ([1024, 1], [1024, 1]), + ([1, 256, 1], [1, 256, 1]), + ([256, 1, 256], [256, 1, 256]), + ([17, 33, 65], [17, 33, 65]), + ([64, 128], [32, 64]), + ([100, 100], [50, 50]), + ([256, 512, 256], [128, 256, 128]), + ([2, 1], [16384, 32 * 1024]), + ([1, 1], [128 * 1024] * 2), + ([2, 1], [8192, 8192]), + ([1, 3], [8192, 8192]), + ([3, 3], [8192, 8192]), + ([128, 128], [8192, 8192]), + ([2, 2, 2], [8 * 1024] * 3), + ([2, 1], [1024 * 32, 16384]), + ([1, 2], [1024 * 32, 16384]), + ([1, 1, 1], [128 * 1024] * 3), + ([1, 1, 1], [256 * 1024] * 3), +] + +# ============================================================================= +# Helper functions +# ============================================================================= + + +def run_cute_flash( + q, + k, + v, + score_mod, + aux_tensors=None, + pack_gqa=False, + cu_seqlens_q=None, + cu_seqlens_k=None, + page_table=None, + seqused_k=None, +): + """Run CuTE flash attention.""" + if cu_seqlens_q is not None or cu_seqlens_k is not None: + out = torch.empty_like(q) + _flash_attn_fwd( + q, + k, + v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + page_table=page_table, + return_lse=True, + score_mod=score_mod, + out=out, + lse=None, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + return out + + out = torch.empty_like(q) + _flash_attn_fwd( + q, + k, + v, + seqused_k=seqused_k, + page_table=page_table, + return_lse=True, + score_mod=score_mod, + out=out, + lse=None, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + return out + + +def run_flex_varlen_ref(q, k, v, cu_seqlens_q, cu_seqlens_k, score_mod, dtype=None): + """Run flex_attention per-sequence for varlen reference.""" + if cu_seqlens_q is not None: + num_batches = len(cu_seqlens_q) - 1 + else: + num_batches = len(cu_seqlens_k) - 1 + + results = [] + for i in range(num_batches): + # Get Q slice + if cu_seqlens_q is not None: + q_slice = ( + q[cu_seqlens_q[i] : cu_seqlens_q[i + 1]].unsqueeze(0).transpose(1, 2) + ) + else: + q_slice = q[i : i + 1].transpose(1, 2) + + # Get K/V slices + if cu_seqlens_k is not None: + k_slice = ( + k[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze(0).transpose(1, 2) + ) + v_slice = ( + v[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze(0).transpose(1, 2) + ) + else: + k_slice = k[i : i + 1].transpose(1, 2) + v_slice = v[i : i + 1].transpose(1, 2) + + if dtype is not None: + q_slice, k_slice, v_slice = ( + q_slice.to(dtype), + k_slice.to(dtype), + v_slice.to(dtype), + ) + + def wrapped_mod(score, b, h, q_idx, kv_idx): + return score_mod(score, i, h, q_idx, kv_idx) + + out = flex_attention( + q_slice, + k_slice, 
+ v_slice, + score_mod=wrapped_mod, + enable_gqa=q_slice.shape[1] != k_slice.shape[1], + ) + results.append(out.transpose(1, 2).squeeze(0)) + + return torch.cat(results, dim=0) + + +def setup_tensors(seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype): + """Create Q, K, V tensors and cu_seqlens based on varlen flags.""" + batch_size = len(seqlens_q) + + if varlen_q: + total_q = sum(seqlens_q) + q = torch.randn(total_q, num_heads, head_dim, device="cuda", dtype=dtype) + cu_seqlens_q = torch.tensor( + [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()), + device="cuda", + dtype=torch.int32, + ) + else: + seqlen_q = seqlens_q[0] # All sequences have the same length for non-varlen + q = torch.randn( + batch_size, seqlen_q, num_heads, head_dim, device="cuda", dtype=dtype + ) + cu_seqlens_q = None + + if varlen_k: + total_k = sum(seqlens_k) + k = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype) + v = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype) + cu_seqlens_k = torch.tensor( + [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()), + device="cuda", + dtype=torch.int32, + ) + else: + seqlen_k = seqlens_k[0] # All sequences have the same length for non-varlen + k = torch.randn( + batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype + ) + v = torch.randn( + batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype + ) + cu_seqlens_k = None + + return q, k, v, cu_seqlens_q, cu_seqlens_k + + +def prepare_ref_tensors( + q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q +): + """Prepare tensors for flex_attention reference (handle mixed varlen formats).""" + num_heads = q.shape[1] if varlen_q else q.shape[2] + + if not varlen_q and varlen_k: + seqlen_q = q.shape[1] + q_packed = q.reshape(-1, num_heads, q.shape[-1]) + ref_cu_seqlens_q = torch.tensor( + [seqlen_q * i for i in range(batch_size + 1)], + device="cuda", + dtype=torch.int32, + ) + return q_packed, k, v, ref_cu_seqlens_q, cu_seqlens_k + + if varlen_q and not varlen_k: + return q, k, v, cu_seqlens_q, None + + return q, k, v, cu_seqlens_q, cu_seqlens_k + + +def check_results( + out_cute, + out_ref_fp32, + out_pt, + test_name, + rtol=2, + extra_atol=1e-4, + seqlens_q=None, + cu_seqlens_q=None, +): + """Compare CuTE output against references.""" + assert not torch.isnan(out_cute).any(), f"{test_name}: NaN in output" + assert torch.isfinite(out_cute).all(), f"{test_name}: Inf in output" + + varlen_q = cu_seqlens_q is not None + + if varlen_q: + # Unpack and compare per-sequence + assert seqlens_q is not None, "varlen_q requires use of seqlens_q" + num_seqs = len(seqlens_q) + max_cute_error = 0.0 + max_pt_error = 0.0 + + for i in range(num_seqs): + # Extract sequences using cu_seqlens (all outputs are in packed format) + start_q = cu_seqlens_q[i] + end_q = cu_seqlens_q[i + 1] + cute_seq = out_cute[start_q:end_q] + ref_seq = out_ref_fp32[start_q:end_q] + pt_seq = out_pt[start_q:end_q] + + max_cute_error = max( + max_cute_error, (cute_seq - ref_seq).abs().max().item() + ) + max_pt_error = max(max_pt_error, (pt_seq - ref_seq).abs().max().item()) + + cute_error = max_cute_error + pt_error = max_pt_error + else: + # Direct comparison + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + + print(f"\n{test_name}:") + print(f" PyTorch vs FP32 ref: {pt_error:.2e}") + print(f" CuTE vs FP32 ref: 
{cute_error:.2e}") + + tol = rtol * pt_error + fwd_atol + extra_atol + assert cute_error <= tol, ( + f"{test_name}: CuTE error {cute_error:.2e} exceeds tolerance {tol:.2e}" + ) + + +# ============================================================================= +# Tests +# ============================================================================= + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("varlen_q", [True, False]) +@pytest.mark.parametrize("varlen_k", [True, False]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)]) +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_NO_GLOBAL) +def test_varlen_with_score_mod( + seqlens_q, + seqlens_k, + varlen_q, + varlen_k, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_tuple, +): + """Test varlen attention with score_mod functions that don't use global indices. + + Covers: both varlen, varlen Q only, varlen K only. + Skips: neither varlen + """ + if not varlen_q and not varlen_k: + pytest.skip( + "At least one of varlen_q or varlen_k must be True for varlen tests" + ) + + # For non-varlen dimension, all sequences must have same length + if not varlen_q: + seqlens_q = [seqlens_q[0]] * len(seqlens_q) + if not varlen_k: + seqlens_k = [seqlens_k[0]] * len(seqlens_k) + + torch.random.manual_seed(42) + cute_score_mod, eager_factory, aux_type = score_mod_tuple + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = len(seqlens_q) + + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors( + seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype + ) + + if pack_gqa: + if varlen_k: + k = k[:, :num_kv_heads, :].clone() + v = v[:, :num_kv_heads, :].clone() + else: + k = k[:, :, :num_kv_heads, :].clone() + v = v[:, :, :num_kv_heads, :].clone() + + aux_tensors = None + if aux_type == "batch": + bias = torch.zeros(batch_size, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias) + elif aux_type == "dual_buffer": + seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q) + head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2 + pos_bias = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_bias] + eager_score_mod = eager_factory(head_bias, pos_bias) + else: + eager_score_mod = eager_factory + + # Prepare reference tensors + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors( + q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q + ) + + out_ref_fp32 = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32 + ) + out_pt = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype + ) + out_cute = run_cute_flash( + q, + k, + v, + cute_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + + if not varlen_q and varlen_k: + seqlen_q = q.shape[1] + out_ref_fp32 = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim) + out_pt = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim) + + assert out_cute.shape == out_ref_fp32.shape, ( + f"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}" + ) + + test_name = f"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k})" + extra_atol = 2e-3 + check_results( + out_cute, + out_ref_fp32, + out_pt, + 
test_name, + extra_atol=extra_atol, + seqlens_q=seqlens_q if varlen_q else None, + cu_seqlens_q=cu_seqlens_q if varlen_q else None, + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("varlen_q", [True, False]) +@pytest.mark.parametrize("varlen_k", [True, False]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_WITH_GLOBAL) +def test_varlen_with_global_idx_score_mod( + seqlens_q, + seqlens_k, + varlen_q, + varlen_k, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_tuple, +): + """Test varlen attention with score_mod functions that use global indices. + + These score_mods compute q_idx_global and/or kv_idx_global from seqlen_info for packed tensor indexing. + Skips tests where required global indices aren't available. + """ + if not varlen_q and not varlen_k: + pytest.skip( + "At least one of varlen_q or varlen_k must be True for varlen tests" + ) + + cute_score_mod, eager_factory, aux_type, requires_global = score_mod_tuple + + # Skip if score_mod requires global indices we can't provide + if requires_global == "q" and not varlen_q: + pytest.skip(f"{cute_score_mod.__name__} requires varlen_q for q_idx_global") + if requires_global == "kv" and not varlen_k: + pytest.skip(f"{cute_score_mod.__name__} requires varlen_k for kv_idx_global") + if requires_global == "both" and (not varlen_q or not varlen_k): + pytest.skip(f"{cute_score_mod.__name__} requires both varlen_q and varlen_k") + + # For non-varlen dimension, all sequences must have same length + if not varlen_q: + seqlens_q = [seqlens_q[0]] * len(seqlens_q) + if not varlen_k: + seqlens_k = [seqlens_k[0]] * len(seqlens_k) + + torch.random.manual_seed(42) + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = len(seqlens_q) + max_rel_pos = 512 + + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + + cu_seqlens_q = torch.tensor( + [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()), + device="cuda", + dtype=torch.int32, + ) + cu_seqlens_k = torch.tensor( + [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()), + device="cuda", + dtype=torch.int32, + ) + + if varlen_q: + q = torch.randn(total_q, num_heads, head_dim, device="cuda", dtype=dtype) + else: + seqlen_q = seqlens_q[0] + q = torch.randn( + batch_size, seqlen_q, num_heads, head_dim, device="cuda", dtype=dtype + ) + + if varlen_k: + k = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype) + v = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype) + else: + seqlen_k = seqlens_k[0] + k = torch.randn( + batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype + ) + v = torch.randn( + batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype + ) + + if pack_gqa: + if varlen_k: + k = k[:, :num_kv_heads, :].clone() + v = v[:, :num_kv_heads, :].clone() + else: + k = k[:, :, :num_kv_heads, :].clone() + v = v[:, :, :num_kv_heads, :].clone() + + # Setup aux tensors based on indexing type + if aux_type == "kv": + bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias, cu_seqlens_k) + elif aux_type == "q": + bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias, cu_seqlens_q) + elif aux_type == "q_and_kv": + q_bias = 
torch.randn(total_q, device="cuda", dtype=dtype) * 0.1 + kv_bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [q_bias, kv_bias] + eager_score_mod = eager_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k) + elif aux_type == "q_concat": + bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias, cu_seqlens_q) + elif aux_type == "kv_with_cu": + kv_bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [kv_bias] + eager_score_mod = eager_factory(kv_bias, cu_seqlens_q, cu_seqlens_k) + elif aux_type == "multi_buffer": + batch_bias = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 + head_scale = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.1 + 1.0 + q_pos_bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1 + kv_pos_bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1 + rel_pos_scale = ( + torch.randn(max_rel_pos * 2 + 1, device="cuda", dtype=dtype) * 0.1 + ) + aux_tensors = [batch_bias, head_scale, q_pos_bias, kv_pos_bias, rel_pos_scale] + eager_score_mod = eager_factory( + batch_bias, + head_scale, + q_pos_bias, + kv_pos_bias, + rel_pos_scale, + cu_seqlens_q, + cu_seqlens_k, + max_rel_pos, + ) + else: + raise ValueError(f"Unknown aux_type: {aux_type}") + + # Prepare reference tensors for flex_attention + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors( + q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q + ) + + out_ref_fp32 = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32 + ) + out_pt = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype + ) + + kernel_cu_seqlens_q = cu_seqlens_q if varlen_q else None + kernel_cu_seqlens_k = cu_seqlens_k if varlen_k else None + out_cute = run_cute_flash( + q, + k, + v, + cute_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + cu_seqlens_q=kernel_cu_seqlens_q, + cu_seqlens_k=kernel_cu_seqlens_k, + ) + + if varlen_q: + out_ref_final = out_ref_fp32 + out_pt_final = out_pt + out_cute_final = out_cute + else: + seqlen_q = seqlens_q[0] + out_ref_final = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim) + out_pt_final = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim) + out_cute_final = out_cute + + assert out_cute_final.shape == out_ref_final.shape, ( + f"Shape mismatch: {out_cute_final.shape} vs {out_ref_final.shape}" + ) + + test_name = f"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k}, {aux_type})" + + check_results( + out_cute_final, + out_ref_final, + out_pt_final, + test_name, + extra_atol=1e-3, + seqlens_q=seqlens_q if varlen_q else None, + cu_seqlens_q=cu_seqlens_q if varlen_q else None, + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("page_size", [None, 128]) +@pytest.mark.parametrize("varlen_q", [True, False]) +@pytest.mark.parametrize("varlen_k", [True, False]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)]) +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_NO_GLOBAL) +def test_varlen_score_mod_kvcache( + seqlens_q, + seqlens_k, + varlen_q, + varlen_k, + qhead_per_kvhead, + num_kv_heads, + page_size, + dtype, + score_mod_tuple, +): + """Test varlen attention with score_mod and paged KV cache.""" + if not varlen_q and not varlen_k: + pytest.skip( + "At least one of 
varlen_q or varlen_k must be True for varlen tests" + ) + + if page_size is not None and varlen_k: + pytest.skip("Paged KV cache requires batched (non-varlen) K") + + if not varlen_q: + seqlens_q = [seqlens_q[0]] * len(seqlens_q) + if not varlen_k: + seqlens_k = [seqlens_k[0]] * len(seqlens_k) + + # Skip if page_size doesn't divide seqlen_k evenly (for simplicity) + if page_size is not None and not varlen_k: + if seqlens_k[0] % page_size != 0: + pytest.skip("page_size must divide seqlen_k") + + torch.random.manual_seed(42) + cute_score_mod, eager_factory, aux_type = score_mod_tuple + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = len(seqlens_q) + device = "cuda" + + # Set up tensors + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors( + seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype + ) + + if pack_gqa: + if varlen_k: + k = k[:, :num_kv_heads, :].clone() + v = v[:, :num_kv_heads, :].clone() + else: + k = k[:, :, :num_kv_heads, :].clone() + v = v[:, :, :num_kv_heads, :].clone() + + page_table = None + k_cache_paged = None + v_cache_paged = None + k_cache = k + v_cache = v + + if page_size is not None: + seqlen_k = seqlens_k[0] + ( + k_cache_bhsd, + v_cache_bhsd, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size, num_kv_heads, head_dim, device, dtype + ) + k_cache = k_cache_bhsd.transpose(1, 2) # BHSD -> BSHD + v_cache = v_cache_bhsd.transpose(1, 2) + seqused_k = torch.tensor(seqlens_k, dtype=torch.int32, device=device) + else: + seqused_k = None + + # Set up aux tensors and eager score_mod + aux_tensors = None + if aux_type == "batch": + bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias) + elif aux_type == "dual_buffer": + seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q) + head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.2 + pos_bias = torch.arange(seqlen_q, device=device, dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_bias] + eager_score_mod = eager_factory(head_bias, pos_bias) + else: + eager_score_mod = eager_factory + + # Prepare reference tensors + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors( + q, + k_cache, + v_cache, + cu_seqlens_q, + cu_seqlens_k, + varlen_q, + varlen_k, + batch_size, + seqlens_q, + ) + + out_ref_fp32 = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32 + ) + out_pt = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype + ) + + k_input = k_cache_paged if page_size is not None else k_cache + v_input = v_cache_paged if page_size is not None else v_cache + + out_cute = run_cute_flash( + q, + k_input, + v_input, + cute_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + cu_seqlens_q=cu_seqlens_q if varlen_q else None, + cu_seqlens_k=cu_seqlens_k if (varlen_k and page_size is None) else None, + page_table=page_table if page_size is not None else None, + seqused_k=seqused_k if page_size is not None else None, + ) + + if not varlen_q and varlen_k: + seqlen_q = q.shape[1] + out_ref_fp32 = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim) + out_pt = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim) + + assert out_cute.shape == out_ref_fp32.shape, ( + f"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}" + ) + + test_name = f"{cute_score_mod.__name__} 
(varlen_q={varlen_q}, varlen_k={varlen_k}, paged={page_size is not None})" + extra_atol = 2e-3 + check_results( + out_cute, + out_ref_fp32, + out_pt, + test_name, + extra_atol=extra_atol, + seqlens_q=seqlens_q if varlen_q else None, + cu_seqlens_q=cu_seqlens_q if varlen_q else None, + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("page_size", [None, 128]) +@pytest.mark.parametrize("varlen_q", [True, False]) +@pytest.mark.parametrize("varlen_k", [True, False]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_WITH_GLOBAL) +def test_varlen_score_mod_with_paged_kvcache_global( + seqlens_q, + seqlens_k, + varlen_q, + varlen_k, + qhead_per_kvhead, + num_kv_heads, + page_size, + dtype, + score_mod_tuple, +): + """Test varlen attention with global idx score_mod and paged KV cache.""" + if page_size is not None and varlen_k: + pytest.skip("Paged KV cache requires batched (non-varlen) K") + + if not varlen_q and not varlen_k: + pytest.skip( + "At least one of varlen_q or varlen_k must be True for varlen tests" + ) + + if not varlen_q: + seqlens_q = [seqlens_q[0]] * len(seqlens_q) + if not varlen_k: + seqlens_k = [seqlens_k[0]] * len(seqlens_k) + + if page_size is not None and not varlen_k: + if seqlens_k[0] % page_size != 0: + pytest.skip("page_size must divide seqlen_k") + + cute_score_mod, eager_factory, aux_type, requires_global = score_mod_tuple + + if requires_global == "q" and not varlen_q: + pytest.skip(f"{cute_score_mod.__name__} requires varlen_q for q_idx_global") + if requires_global == "kv" and not varlen_k: + pytest.skip(f"{cute_score_mod.__name__} requires varlen_k for kv_idx_global") + if requires_global == "both" and (not varlen_q or not varlen_k): + pytest.skip(f"{cute_score_mod.__name__} requires both varlen_q and varlen_k") + + torch.random.manual_seed(42) + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = len(seqlens_q) + max_rel_pos = 512 + device = "cuda" + + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + + cu_seqlens_q = torch.tensor( + [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + cu_seqlens_k = torch.tensor( + [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + cu_seqlens_k_for_kernel = cu_seqlens_k if varlen_k else None + + q = torch.randn(total_q, num_heads, head_dim, device=device, dtype=dtype) + if varlen_k: + k = torch.randn(total_k, num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(total_k, num_heads, head_dim, device=device, dtype=dtype) + else: + seqlen_k = seqlens_k[0] + k = torch.randn( + batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype + ) + + if pack_gqa: + if varlen_k: + k = k[:, :num_kv_heads, :].clone() + v = v[:, :num_kv_heads, :].clone() + else: + k = k[:, :, :num_kv_heads, :].clone() + v = v[:, :, :num_kv_heads, :].clone() + + page_table = None + k_cache_paged = None + v_cache_paged = None + k_cache = k + v_cache = v + + if page_size is not None: + seqlen_k = seqlens_k[0] + ( + k_cache_bhsd, + v_cache_bhsd, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size, num_kv_heads, head_dim, 
device, dtype + ) + k_cache = k_cache_bhsd.transpose(1, 2) # BHSD -> BSHD + v_cache = v_cache_bhsd.transpose(1, 2) + seqused_k = torch.tensor(seqlens_k, dtype=torch.int32, device=device) + else: + seqused_k = None + + if aux_type == "kv": + bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias, cu_seqlens_k) + elif aux_type == "q": + bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias, cu_seqlens_q) + elif aux_type == "q_and_kv": + q_bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1 + kv_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1 + aux_tensors = [q_bias, kv_bias] + eager_score_mod = eager_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k) + elif aux_type == "q_concat": + bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias, cu_seqlens_q) + elif aux_type == "kv_with_cu": + kv_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1 + aux_tensors = [kv_bias] + eager_score_mod = eager_factory(kv_bias, cu_seqlens_q, cu_seqlens_k) + elif aux_type == "multi_buffer": + batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.1 + head_scale = torch.randn(num_heads, device=device, dtype=dtype) * 0.1 + 1.0 + q_pos_bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1 + kv_pos_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1 + rel_pos_scale = ( + torch.randn(max_rel_pos * 2 + 1, device=device, dtype=dtype) * 0.1 + ) + aux_tensors = [batch_bias, head_scale, q_pos_bias, kv_pos_bias, rel_pos_scale] + eager_score_mod = eager_factory( + batch_bias, + head_scale, + q_pos_bias, + kv_pos_bias, + rel_pos_scale, + cu_seqlens_q, + cu_seqlens_k, + max_rel_pos, + ) + else: + raise ValueError(f"Unknown aux_type: {aux_type}") + + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors( + q, + k_cache, + v_cache, + cu_seqlens_q, + cu_seqlens_k, + True, + varlen_k, + batch_size, + seqlens_q, + ) + + out_ref_fp32 = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32 + ) + out_pt = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype + ) + + # Run CuTE + k_input = k_cache_paged if page_size is not None else k_cache + v_input = v_cache_paged if page_size is not None else v_cache + + out_cute = torch.empty_like(q) + _flash_attn_fwd( + q, + k_input, + v_input, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k_for_kernel if page_size is None else None, + seqused_k=seqused_k if page_size is not None else None, + page_table=page_table, + return_lse=True, + score_mod=cute_score_mod, + out=out_cute, + lse=None, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + + assert out_cute.shape == out_ref_fp32.shape, ( + f"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}" + ) + + test_name = f"{cute_score_mod.__name__} (paged={page_size is not None}, {aux_type})" + check_results( + out_cute, + out_ref_fp32, + out_pt, + test_name, + extra_atol=1e-3, + seqlens_q=seqlens_q, + cu_seqlens_q=cu_seqlens_q, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
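
For the "global index" score mods exercised above, each CuTE score_mod is paired with an eager factory that receives the packed bias buffer(s) together with the relevant cu_seqlens tensor(s), so the flex_attention reference can translate a per-sequence (b, q_idx, kv_idx) coordinate into an offset in the packed buffer, mirroring what the kernel-side score_mod recovers from seqlen_info. The real definitions live in score_mod_definitions.py and are not part of this diff; the following is only a minimal sketch, with hypothetical naming, of the shape such a factory might take.

import torch

def packed_kv_bias_factory_sketch(kv_bias, cu_seqlens_k):
    # kv_bias: packed (total_k,) bias buffer covering all sequences back to back.
    # cu_seqlens_k: (batch_size + 1,) cumulative K lengths, as built in the tests above.
    def score_mod(score, b, h, q_idx, kv_idx):
        # Per-sequence kv_idx -> offset into the packed buffer for sequence b.
        kv_idx_global = cu_seqlens_k[b] + kv_idx
        return score + kv_bias[kv_idx_global]
    return score_mod

run_flex_varlen_ref calls the returned score_mod once per sequence with b fixed to that sequence's index, which is why the factories above only need the packed buffers and the cu_seqlens tensors as arguments.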