diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 33d61c5f1d16..6baf4bf83f49 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -40,13 +40,12 @@ @pytest.mark.parametrize( "model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]) -@pytest.mark.parametrize( - "use_triton_fa", [True, False] if current_platform.is_rocm() else [False]) +@pytest.mark.parametrize("use_triton_fa", [True, False]) @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Only test CUDA and ROCm") -def test_attention_fusion(example_prompts, monkeypatch, model: str, - quant_key: QuantKey, use_triton_fa: bool): +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="V0 attn quant fusion only on ROCm") +def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, + quant_key: QuantKey, use_triton_fa: bool): # Clean Dynamo cache to avoid reusing other test cases # (for some reason the reset at the end is not enough) torch._dynamo.reset()