Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions tests/rocm/aiter/test_mla_fp8_support_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for AITER MLA FP8 support detection.

These tests verify that the _check_aiter_mla_fp8_support() function
correctly handles various error conditions without crashing.
"""

from unittest.mock import patch

import pytest


class TestAiterMlaFp8SupportCheck:
"""Test cases for _check_aiter_mla_fp8_support() function."""

def setup_method(self):
"""Reset the global cache before each test."""
import vllm._aiter_ops as aiter_ops

aiter_ops._AITER_MLA_SUPPORTS_FP8 = None

@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_import_error_handling(self, mock_supported):
"""Test that ImportError is handled gracefully."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support

aiter_ops._AITER_MLA_SUPPORTS_FP8 = None

# Should return False without raising
with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=ImportError("No module"),
):
result = _check_aiter_mla_fp8_support()
assert result is False

@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_module_not_found_error_handling(self, mock_supported):
"""Test that ModuleNotFoundError is handled gracefully."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support

aiter_ops._AITER_MLA_SUPPORTS_FP8 = None

with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=ModuleNotFoundError("Module not found"),
):
# Should return False without raising
assert _check_aiter_mla_fp8_support() is False
# Cache should be set to False
assert aiter_ops._AITER_MLA_SUPPORTS_FP8 is False

@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_attribute_error_handling(self, mock_supported):
"""Test that AttributeError is handled gracefully."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support

aiter_ops._AITER_MLA_SUPPORTS_FP8 = None

with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=AttributeError("No attribute"),
):
assert _check_aiter_mla_fp8_support() is False
assert aiter_ops._AITER_MLA_SUPPORTS_FP8 is False

@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_value_error_handling(self, mock_supported):
"""Test that ValueError is handled gracefully (no signature)."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support

aiter_ops._AITER_MLA_SUPPORTS_FP8 = None

with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=ValueError("No signature"),
):
assert _check_aiter_mla_fp8_support() is False
assert aiter_ops._AITER_MLA_SUPPORTS_FP8 is False

@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_type_error_handling(self, mock_supported):
"""Test that TypeError is handled gracefully (not callable)."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support

aiter_ops._AITER_MLA_SUPPORTS_FP8 = None

with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=TypeError("Not a callable"),
):
assert _check_aiter_mla_fp8_support() is False
assert aiter_ops._AITER_MLA_SUPPORTS_FP8 is False

@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_result_caching(self, mock_supported):
"""Test that the result is cached after first check."""
import vllm._aiter_ops as aiter_ops

# Set cache to True
aiter_ops._AITER_MLA_SUPPORTS_FP8 = True

from vllm._aiter_ops import _check_aiter_mla_fp8_support

# Should return cached value without re-checking
result = _check_aiter_mla_fp8_support()
assert result is True


if __name__ == "__main__":
pytest.main([__file__, "-v"])
12 changes: 11 additions & 1 deletion vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,17 @@ def _check_aiter_mla_fp8_support() -> bool:
_AITER_MLA_SUPPORTS_FP8 = (
"q_scale" in sig.parameters and "kv_scale" in sig.parameters
)
except Exception:
except (
ImportError,
ModuleNotFoundError,
AttributeError,
ValueError,
TypeError,
):
# ImportError/ModuleNotFoundError: aiter.mla module not available
# AttributeError: mla_decode_fwd doesn't exist
# ValueError: mla_decode_fwd has no signature (e.g., built-in)
# TypeError: mla_decode_fwd is not a callable
_AITER_MLA_SUPPORTS_FP8 = False
return _AITER_MLA_SUPPORTS_FP8

Expand Down