diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py
index 6f39539adfd..7881128e0fb 100644
--- a/flash_attn/cute/mask.py
+++ b/flash_attn/cute/mask.py
@@ -77,11 +77,11 @@ class AttentionMask:
     window_size_right: Optional[Int32] = None
     qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1  # only pass in if we're doing PackGQA
     swap_AB: cutlass.Constexpr[bool] = False
-
+
     @property
     def seqlen_q(self) -> Int32:
         return self.seqlen_info.seqlen_q
-
+
     @property
     def seqlen_k(self) -> Int32:
         return self.seqlen_info.seqlen_k
@@ -549,6 +549,7 @@ def apply_mask_sm100_transposed(
                 head_idx_ssa,
                 q_idx_ssa,
                 kv_idx_ssa,
+                self.seqlen_info,
                 aux_tensors,
             )
             cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py
index 82c135a8ee1..c90fc14c629 100644
--- a/tests/cute/test_score_mod.py
+++ b/tests/cute/test_score_mod.py
@@ -35,6 +35,8 @@
     dual_buffer_factory as dual_buffer_bias,
 )
 
+COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0]
+
 # Test pairs: (cute_jit_function, eager_reference_function)
 TEST_PAIRS = [
     (score_mod_1, None),
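
For context, a minimal sketch of how a module-level `COMPUTE_CAPABILITY` constant like the one added above is typically consumed in a pytest file. This is a hypothetical usage example, not part of the patch: the test name and the `< 10` threshold (SM100 corresponds to compute capability 10.x) are assumptions for illustration.

```python
import pytest
import torch

# Major compute capability of the current device, as in the patched test file.
# Note: get_device_capability() requires a visible CUDA device at import time.
COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0]


@pytest.mark.skipif(
    COMPUTE_CAPABILITY < 10,  # assumed threshold: SM100 (Blackwell) is CC 10.x
    reason="transposed SM100 mask path requires SM100-class hardware",
)
def test_mask_sm100_transposed_smoke():
    # Body elided; a real test would construct an AttentionMask (with its
    # seqlen_info, now threaded through to the mask callback) and exercise
    # the transposed SM100 masking path.
    pass
```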