Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion tests/attention/test_fmha_v2_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import math
from typing import Optional, Tuple, Union

pytestmark = pytest.mark.skip(
reason="todo(jimmyzho): temporarily skip this test due to hangs"
)

import flashinfer
from flashinfer.prefill import fmha_v2_prefill_deepseek
from tests.utils_fp8 import to_float8
Expand Down Expand Up @@ -837,6 +841,10 @@ def test_trtllm_fmha_v2_prefill(
and mask_mode == "SLIDING_WINDOW"
):
pytest.skip("Skip due to bug in fp8 sliding window")
if mask_mode == "SLIDING_WINDOW":
pytest.skip("todo(jimmyzho): temporarily skip sliding window test due to hang")
Comment on lines +844 to +845
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This new if statement makes the preceding check for mask_mode == "SLIDING_WINDOW" on lines 831-839 redundant, creating dead code. To avoid this, you can modify the condition to exclude the case that is already handled by the previous if block. This keeps both pytest.skip calls active for their respective conditions and makes the code easier to maintain.

    if mask_mode == "SLIDING_WINDOW" and not (
        batch_size == 16
        and num_kv_heads == 4
        and head_dim == 256
        and dtype == torch.float8_e4m3fn
        and input_layout in ["PACKED_QKV", "CONTIGUOUS_Q_KV"]
    ):
        pytest.skip("todo(jimmyzho): temporarily skip sliding window test due to hang")

if dtype == torch.float8_e4m3fn and o_dtype == torch.float8_e4m3fn:
pytest.skip("todo(jimmyzho): temporarily skip fp8 tests due to hang")
run_trtllm_fmha_v2_prefill_case(
input_layout=input_layout,
batch_size=batch_size,
Expand Down Expand Up @@ -955,7 +963,8 @@ def test_trtllm_fmha_v2_prefill_attention_sinks(

if not is_sm90a_supported(torch.device("cuda")):
pytest.skip("FMHA v2 requires SM90+ (Hopper) GPUs.")

if mask_mode == "SLIDING_WINDOW":
pytest.skip("todo(jimmyzho): temporarily skip sliding window test due to hang")
torch.manual_seed(42)
device = torch.device("cuda")

Expand Down
Loading