diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index 9887355fa8d..dcaa3656b52 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -2,7 +2,7 @@ Block-sparsity utilities for FlexAttention """ -from typing import NamedTuple, Optional, Tuple +from typing import Callable, NamedTuple, Tuple import cutlass.cute as cute import torch @@ -17,8 +17,8 @@ def ceildiv(a: int, b: int) -> int: class BlockSparseTensors(NamedTuple): mask_block_cnt: cute.Tensor mask_block_idx: cute.Tensor - full_block_cnt: Optional[cute.Tensor] - full_block_idx: Optional[cute.Tensor] + full_block_cnt: cute.Tensor | None + full_block_idx: cute.Tensor | None def __new_from_mlir_values__(self, values): if len(values) == 2: @@ -29,14 +29,16 @@ def __new_from_mlir_values__(self, values): class BlockSparseTensorsTorch(NamedTuple): mask_block_cnt: torch.Tensor mask_block_idx: torch.Tensor - full_block_cnt: Optional[torch.Tensor] = None - full_block_idx: Optional[torch.Tensor] = None + full_block_cnt: torch.Tensor | None = None + full_block_idx: torch.Tensor | None = None def _expand_sparsity_tensor( tensor: torch.Tensor, expected_shape: Tuple[int, ...], tensor_name: str, + context: str | None, + hint: str | Callable[[], str] | None, ) -> torch.Tensor: """Check if we need to expand the tensor to expected shape, and do so if possible.""" needs_expand = tensor.shape != expected_shape @@ -44,19 +46,25 @@ def _expand_sparsity_tensor( return tensor can_expand = all(map(lambda cur, tgt: cur == tgt or cur == 1, tensor.shape, expected_shape)) if not can_expand: + context_clause = f" ({context})" if context else "" + resolved_hint = hint() if callable(hint) else hint + hint_clause = f" Hint: {resolved_hint}" if resolved_hint else "" raise ValueError( - f"{tensor_name} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}." + f"{tensor_name}{context_clause} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}." 
+ f"{hint_clause}" ) return tensor.expand(*expected_shape).contiguous() def _check_and_expand_block( name: str, - cnt: Optional[torch.Tensor], - idx: Optional[torch.Tensor], + cnt: torch.Tensor | None, + idx: torch.Tensor | None, expected_count_shape: Tuple[int, int, int], expected_index_shape: Tuple[int, int, int, int], -) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + context: str | None, + hint: str | Callable[[], str] | None, +) -> Tuple[torch.Tensor | None, torch.Tensor | None]: if (cnt is None) != (idx is None): raise ValueError( f"{name}_block_cnt and {name}_block_idx must both be provided or both be None" @@ -69,8 +77,12 @@ def _check_and_expand_block( raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device") if not cnt.is_cuda or not idx.is_cuda: raise ValueError(f"{name}_block tensors must live on CUDA") - expanded_cnt = _expand_sparsity_tensor(cnt, expected_count_shape, f"{name}_block_cnt") - expanded_idx = _expand_sparsity_tensor(idx, expected_index_shape, f"{name}_block_idx") + expanded_cnt = _expand_sparsity_tensor( + cnt, expected_count_shape, f"{name}_block_cnt", context, hint + ) + expanded_idx = _expand_sparsity_tensor( + idx, expected_index_shape, f"{name}_block_idx", context, hint + ) return expanded_cnt, expanded_idx @@ -120,6 +132,8 @@ def normalize_block_sparse_tensors( *, expected_count_shape: Tuple[int, int, int], expected_index_shape: Tuple[int, int, int, int], + context: str | None = None, + hint: str | Callable[[], str] | None = None, ) -> BlockSparseTensorsTorch: if tensors.mask_block_cnt is None or tensors.mask_block_idx is None: raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.") @@ -130,6 +144,8 @@ def normalize_block_sparse_tensors( tensors.mask_block_idx, expected_count_shape, expected_index_shape, + context, + hint, ) if mask_cnt is None or mask_idx is None: raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.") @@ -140,6 +156,8 @@ def normalize_block_sparse_tensors( tensors.full_block_idx, expected_count_shape, expected_index_shape, + context, + hint, ) if full_cnt is not None and mask_cnt.device != full_cnt.device: raise ValueError("All block sparse tensors must be on the same device") @@ -158,7 +176,7 @@ def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool: def to_cute_block_sparse_tensors( tensors: BlockSparseTensorsTorch, enable_tvm_ffi: bool = True -) -> Optional[BlockSparseTensors]: +) -> BlockSparseTensors | None: """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi""" if not is_block_sparsity_enabled(tensors): return None diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index d9b504cee23..6c0c60b9724 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -23,6 +23,7 @@ from flash_attn.cute import pipeline from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd +from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute.block_sparse_utils import ( get_total_q_block_count_bwd, @@ -70,6 +71,8 @@ def __init__( AtomLayoutMdQ: int = 1, num_threads: int = 384, V_in_regs: bool = False, + score_mod: cutlass.Constexpr | None = None, + score_mod_bwd: cutlass.Constexpr | None = None, mask_mod: 
cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, subtile_factor: cutlass.Constexpr[int] = 1, @@ -118,6 +121,8 @@ def __init__( self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64 self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64 + self.score_mod = score_mod + self.score_mod_bwd = score_mod_bwd self.mask_mod = mask_mod self.has_aux_tensors = has_aux_tensors self.subtile_factor = subtile_factor @@ -125,6 +130,7 @@ def __init__( self.vec_size: cutlass.Constexpr = 1 else: self.vec_size: cutlass.Constexpr = 4 + self.qk_acc_dtype = Float32 @staticmethod def can_implement( @@ -443,7 +449,10 @@ def __call__( grid_dim = TileScheduler.get_grid_shape(tile_sched_params) LOG2_E = math.log2(math.e) - softmax_scale_log2 = softmax_scale * LOG2_E + if const_expr(self.score_mod is None): + softmax_scale_log2 = softmax_scale * LOG2_E + else: + softmax_scale_log2 = LOG2_E fastdiv_mods = None if const_expr(aux_tensors is not None): @@ -856,6 +865,93 @@ def load( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + @cute.jit + def apply_score_mod( + self, + acc_S: cute.Tensor, + thr_mma_SdP: cute.core.ThrMma, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen_info: SeqlenInfoQK, + aux_tensors=None, + fastdiv_mods=(None, None), + ): + # [NOTE] SdP_swapAB: swapAB transposes the tile, so use (n, m) indexing + cS = cute.make_identity_tensor( + (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n) + ) + cS = cute.domain_offset( + (n_block * self.tile_n, m_block * self.tile_m) + if self.SdP_swapAB + else (m_block * self.tile_m, n_block * self.tile_n), + cS, + ) + tScS = thr_mma_SdP.partition_C(cS) + + apply_score_mod_inner( + acc_S, + tScS, + self.score_mod, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead, + transpose_indices=self.SdP_swapAB, + ) + + @cute.jit + def apply_score_mod_bwd( + self, + grad_tensor: cute.Tensor, + score_tensor: cute.Tensor, + thr_mma_SdP: cute.core.ThrMma, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen_info: SeqlenInfoQK, + aux_tensors=None, + fastdiv_mods=(None, None), + ): + cS = cute.make_identity_tensor( + (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n) + ) + cS = cute.domain_offset( + (n_block * self.tile_n, m_block * self.tile_m) + if self.SdP_swapAB + else (m_block * self.tile_m, n_block * self.tile_n), + cS, + ) + tScS = thr_mma_SdP.partition_C(cS) + + apply_score_mod_bwd_inner( + grad_tensor, + score_tensor, + tScS, + self.score_mod_bwd, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead, + transpose_indices=self.SdP_swapAB, + ) + @cute.jit def mma( self, @@ -1196,6 +1292,24 @@ def mma_one_m_block( ) acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1) + if const_expr(self.score_mod_bwd is not None): + acc_S_pre = cute.make_fragment_like(acc_S) + cute.autovec_copy(acc_S, acc_S_pre) + + if const_expr(self.score_mod is not None): + self.apply_score_mod( + acc_S, + thr_mma_SdP, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen, + aux_tensors, + fastdiv_mods, + ) + # (3) [Pointwise 1] P = exp(S - LSE) if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, m_block=m_block) @@ -1226,6 +1340,21 @@ def 
mma_one_m_block( for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True): acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r]) + if const_expr(self.score_mod_bwd is not None): + self.apply_score_mod_bwd( + acc_dP, + acc_S_pre, + thr_mma_SdP, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen, + aux_tensors, + fastdiv_mods, + ) + # Convert dS from f32 -> f16 tdKrdS = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_dP), self.dtype) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 574413bbd0f..37cbf42fdd4 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -713,7 +713,6 @@ def _flash_attn_bwd( assert cu_seqlens_q is None and cu_seqlens_k is None, ( "varlen + score_mod not supported in bwd yet" ) - assert compute_capability in [10, 11], "score_mod in bwd only supported on SM100/SM110 for now" device = q.device out_torch_dtype = q.dtype @@ -910,7 +909,6 @@ def _flash_attn_bwd( num_aux_tensors, use_block_sparsity, ) - cute_aux_tensors = None else: compile_key = ( compute_capability, @@ -999,6 +997,8 @@ def _flash_attn_bwd( AtomLayoutMdQ, num_threads, V_in_regs=V_in_regs, + score_mod=score_mod, + score_mod_bwd=score_mod_bwd, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, subtile_factor=subtile_factor, @@ -1034,6 +1034,12 @@ def _flash_attn_bwd( block_sparse_tensors, expected_count_shape=expected_count_shape, expected_index_shape=expected_index_shape, + context="_flash_attn_bwd", + hint=lambda: ( + f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). " + f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) " + f"(sparse_block_size_q={sparse_block_size_q})." + ), ) sparse_tensors_compile = to_cute_block_sparse_tensors(compile_time_normalized) @@ -1076,6 +1082,12 @@ def _flash_attn_bwd( block_sparse_tensors, expected_count_shape=expected_count_shape, expected_index_shape=expected_index_shape, + context="_flash_attn_bwd", + hint=lambda: ( + f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). " + f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) " + f"(sparse_block_size_q={sparse_block_size_q})." + ), ) _flash_attn_bwd.compile_cache[compile_key]( diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 7881128e0fb..4616edd6f9b 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -139,7 +139,6 @@ def apply_mask( ): # FlexAttention mask mod nrow = const_expr(cute.size(tScS_mn.shape[0])) ncol = const_expr(cute.size(tScS_mn.shape[1])) - thr_col_offset = tScS_mn[0, 0][1] has_fastdiv = const_expr( fastdiv_mods is not None and fastdiv_mods[0] is not None @@ -150,7 +149,9 @@ def apply_mask( ) for r in cutlass.range_constexpr(nrow): - global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m + # Respect swap_AB: ROW/COL determine which coordinate component corresponds to Q/KV. 
+ local_row = tScS_mn[r, 0][ROW] + global_row_idx = local_row + m_block * self.tile_m row_for_mod = global_row_idx head_idx_for_mod = head_idx if const_expr(self.qhead_per_kvhead_packgqa != 1): @@ -162,7 +163,7 @@ def apply_mask( _, row_for_mod = divmod(row_for_mod, fastdiv_mods[0]) for col in cutlass.range_constexpr(ncol): - col_idx_local = t0ScS_mn[0, col][1] + col_idx_local = t0ScS_mn[0, col][COL] # Convert to absolute column index global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n col_for_mod = global_col_idx @@ -354,7 +355,7 @@ def apply_mask_sm100( mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True) elif const_expr(not mask_causal and not mask_local and mask_mod is not None): - # Block sparse w/ mask_mod + # Block sparse case w/ mask_mod has_fastdiv = const_expr( fastdiv_mods is not None and fastdiv_mods[0] is not None diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 01261789f39..59409862406 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -13,7 +13,6 @@ # pytest test_mask_mod.py # Run all tests import math -from typing import Optional import pytest import torch @@ -62,7 +61,7 @@ def create_tensors( } -def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: Optional[tuple[int, int]] = None): +def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: tuple[int, int] | None = None): """Compute reference using flex_attention for custom mask_mods""" batch_size, seqlen_q, nheads, headdim = tensors["q"].shape _, seqlen_k, nheads_kv, _ = tensors["k"].shape @@ -240,6 +239,31 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): *_, ) = bm.as_tuple() + # SM90 block-sparse backward expects BlockMask granularity (128, 128) regardless of fwd tiling. + if COMPUTE_CAPABILITY == 9 and use_block_sparsity: + bm_bwd = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(128, 128), + ) + ( + _seq_q, + _seq_k, + _kv_mask_cnt, + _kv_mask_idx, + _full_kv_cnt, + _full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm_bwd.as_tuple() + softmax_scale = 1.0 / math.sqrt(headdim) block_sparse_mask_fwd = BlockSparseTensorsTorch( @@ -343,8 +367,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" ) - # Backward pass (SM100 only) - if needs_backward and COMPUTE_CAPABILITY == 10 and kv_mode == "mha": + if needs_backward and kv_mode == "mha": q = tensors["q"] k = tensors["k"] v = tensors["v"] @@ -453,9 +476,6 @@ def test_q_boundary_masking_block_sparse_bwd(seqlen_q, seqlen_k, mask_name): - Block-sparse with mask_mod: exercises is_full_block=True path - Backward pass: where the bug manifested """ - if COMPUTE_CAPABILITY != 10: - pytest.skip("SM100-only backward test") - _run_mask_test( seqlen_q=seqlen_q, seqlen_k=seqlen_k, @@ -474,6 +494,7 @@ def test_q_boundary_masking_block_sparse_bwd(seqlen_q, seqlen_k, mask_name): ) +@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="Test uses SM100 block mask conventions (2*tile_m)") def test_single_doc_bwd_minimal(): """Minimal test to isolate single-document backward pass bug. 
@@ -484,9 +505,6 @@ def test_single_doc_bwd_minimal(): Run with: pytest tests/cute/test_mask_mod.py::test_single_doc_bwd_minimal -v -s """ - if COMPUTE_CAPABILITY != 10: - pytest.skip("SM100-only test") - import random random.seed(42) torch.manual_seed(42) @@ -803,5 +821,76 @@ def run_flex_reference_bwd(q, k, v, block_mask, grad_out, dtype=None): ) +def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): + if COMPUTE_CAPABILITY != 9: + pytest.skip("SM90-only test") + + batch_size = 1 + seqlen_q = 256 + seqlen_k = 256 + nheads = 4 + nheads_kv = nheads + headdim = 128 + dtype = torch.bfloat16 + tile_m = 80 + tile_n = 128 + + tensors = create_tensors(batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim, dtype) + mask_mod_cute, mask_mod_flex = get_mask_pair("block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k) + bm = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _seq_q, + _seq_k, + _kv_mask_cnt, + _kv_mask_idx, + _full_kv_cnt, + _full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm.as_tuple() + + block_sparse_mask_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, + ) + + softmax_scale = 1.0 / math.sqrt(headdim) + out = torch.empty(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype) + lse = torch.empty(batch_size, nheads, seqlen_q, device="cuda", dtype=torch.float32) + grad_out = torch.randn_like(out) + + with pytest.raises( + ValueError, + match=r"Hint: Backward expects Q-direction block-sparse tensors.*BLOCK_SIZE=\(128, 128\)", + ): + _flash_attn_bwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=out, + dout=grad_out, + lse=lse, + softmax_scale=softmax_scale, + causal=False, + m_block_size=tile_m, + n_block_size=tile_n, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_mask_bwd, + ) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py index c90fc14c629..11efcc8cdbc 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -6,6 +6,9 @@ import operator from torch.nn.attention.flex_attention import flex_attention from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd + +COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] + from score_mod_definitions import ( # TensorSSA-based score mods score_mod_identity as score_mod_1, @@ -291,6 +294,7 @@ def _generate_block_kvcache( ], ) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) +@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="Paged KV cache only supported on SM100") def test_score_mod_with_paged_kvcache( seqlen_q, seqlen_kv, @@ -447,6 +451,7 @@ def masked_score_mod(score, b, h, q_idx, kv_idx): ], ) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS) +@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="Paged KV cache only supported on SM100") def test_score_mod_with_paged_kvcache_aux_tensors( seqlen_q, seqlen_kv, @@ -740,6 +745,9 @@ def run_flex_reference_bwd(q, k, v, eager_score_mod, grad_out, dtype=None): @pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS) def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_mod_triple): """Test backward pass with score_mod against flex_attention reference.""" + if COMPUTE_CAPABILITY == 9 and dim == 64: + 
pytest.skip("head_dim=64 not supported on SM90 for backward") + torch.random.manual_seed(42) cute_fwd, cute_bwd, eager_ref = score_mod_triple @@ -811,6 +819,9 @@ def make_aux_tensors_for_bwd(cute_score_mod, eager_factory, seqlen_q, num_heads, def test_cute_vs_flex_attention_backward_with_aux( seqlen_q, seqlen_kv, dim, dtype, score_mod_triple ): + if COMPUTE_CAPABILITY == 9 and dim == 64: + pytest.skip("head_dim=64 not supported on SM90 for backward") + torch.random.manual_seed(42) cute_fwd, cute_bwd, eager_factory = score_mod_triple @@ -864,14 +875,16 @@ def test_cute_vs_flex_attention_backward_with_aux( @pytest.mark.parametrize("seqlen_q,seqlen_kv", [(128, 128), (128, 256)]) -@pytest.mark.parametrize("dim", [64]) +@pytest.mark.parametrize("dim", [64, 128]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)]) @pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS_PACK_GQA) def test_cute_vs_flex_attention_backward_pack_gqa( seqlen_q, seqlen_kv, dim, dtype, qhead_per_kvhead, num_kv_heads, score_mod_triple ): - pytest.skip("pack_gqa backward not yet implemented") + if COMPUTE_CAPABILITY == 9: + pytest.xfail("pack_gqa backward not yet implemented on SM90") + torch.random.manual_seed(42) cute_fwd, cute_bwd, eager_ref = score_mod_triple