From 38eb8457ab1b8ae7ad2468e1c4ead8577e7cb2a9 Mon Sep 17 00:00:00 2001
From: zhyncs
Date: Tue, 17 Jun 2025 23:56:01 +0000
Subject: [PATCH] fix: only enable flash_attn test on sm80 sm90

---
 sgl-kernel/tests/test_flash_attention.py   |  8 ++++++--
 sgl-kernel/tests/test_sparse_flash_attn.py | 15 +++++++++++++--
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/sgl-kernel/tests/test_flash_attention.py b/sgl-kernel/tests/test_flash_attention.py
index def092a3462..0c7e854b900 100644
--- a/sgl-kernel/tests/test_flash_attention.py
+++ b/sgl-kernel/tests/test_flash_attention.py
@@ -13,7 +13,7 @@

 def is_hopper():
     # Only Hopper supports different V headdim
-    return torch.cuda.get_device_properties(0).major >= 9
+    return torch.cuda.get_device_properties(0).major == 9


 def is_fa3_supported(device=None) -> bool:
@@ -451,7 +451,7 @@ def generate_qkv(

 @pytest.mark.skipif(
     not is_fa3_supported(),
-    reason="flash_attn at sgl-kernel is only supported on sm90 and above",
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
 )
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
 @pytest.mark.parametrize(
@@ -1009,6 +1009,10 @@ def _generate_block_kvcache(
     return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks


+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
 @pytest.mark.parametrize(
     "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
diff --git a/sgl-kernel/tests/test_sparse_flash_attn.py b/sgl-kernel/tests/test_sparse_flash_attn.py
index 4aa0a7c19ac..28c64cb6162 100644
--- a/sgl-kernel/tests/test_sparse_flash_attn.py
+++ b/sgl-kernel/tests/test_sparse_flash_attn.py
@@ -8,9 +8,8 @@
     convert_vertical_slash_indexes,
     convert_vertical_slash_indexes_mergehead,
     sparse_attn_func,
-    sparse_attn_varlen_func,
 )
-from test_flash_attention import construct_local_mask
+from test_flash_attention import construct_local_mask, is_fa3_supported


 def ref_attn(
@@ -172,6 +171,10 @@ def ref_paged_attn(
     return torch.cat(outputs, dim=0)


+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 @pytest.mark.parametrize("batch_size", [1, 2])
 @pytest.mark.parametrize(
     "seq_lens",
@@ -257,6 +260,10 @@ def test_sparse_attention(

 # sparse attention utils
 # origin
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 @pytest.mark.parametrize("causal", [True, False])
 def test_convert_vertical_slash_indexes(causal):
     # Prepare small, hand-checkable inputs
@@ -311,6 +318,10 @@ def test_convert_vertical_slash_indexes(causal):


 # mergehead
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 @pytest.mark.parametrize("causal", [True, False])
 def test_convert_vertical_slash_indexes_mergehead(causal):
     # Prepare small, hand-checkable inputs for mergehead version
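
Note: is_fa3_supported() itself is not modified by this patch; only its call sites and the skip reason change. For context, below is a minimal sketch of the kind of capability check such a guard is assumed to perform (gating the tests on sm80/sm90 only); the name is_fa3_supported_sketch and its body are illustrative assumptions, not the actual helper shipped in sgl-kernel.

import torch


def is_fa3_supported_sketch(device=None) -> bool:
    # Illustrative assumption: the FA3 kernels exercised by these tests are
    # built only for Ampere (sm80) and Hopper (sm90), so skip elsewhere.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(device)
    return major in (8, 9)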