diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index c544d2d3d195..648c848d7808 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -7,6 +7,7 @@
 from torch._ops import OpOverload
 
 import vllm.envs as envs
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
@@ -14,6 +15,8 @@
     rocm_aiter_sparse_attn_indexer_fake,
 )
 
+logger = init_logger(__name__)
+
 # fp8_dtype is not cached.
 # on ROCm the fp8_dtype always calls is_fp8_fnuz
 # which is a host op, so we cache it once here.
@@ -999,7 +1002,46 @@ def is_fp8bmm_enabled(cls) -> bool:
     @classmethod
     @if_aiter_supported
     def is_fp4bmm_enabled(cls) -> bool:
-        return cls._AITER_ENABLED and cls._FP4BMM_ENABLED
+        """Return whether FP4 BMM is both requested and usable on this GPU.
+
+        FP4 (MXFP4) requires gfx950 (AMD Instinct MI350X/MI355X). gfx942
+        parts (MI300X/MI300A/MI325X) do not support FP4.
+
+        The check combines the class-level enable flags with AITER's
+        hardware probe so that FP4 is never selected on unsupported hardware.
+
+        Returns:
+            True only if both enable flags are set, AITER's arch_info
+            helpers can be imported, and ``is_fp4_avail()`` reports support.
+        """
+        if not (cls._AITER_ENABLED and cls._FP4BMM_ENABLED):
+            return False
+
+        # Keep the try body to the import alone so that only a missing
+        # arch_info module is treated as "FP4 unavailable".
+        try:
+            from aiter.ops.triton.utils._triton.arch_info import (
+                get_arch,
+                is_fp4_avail,
+            )
+        except ImportError:
+            logger.warning(
+                "AITER arch_info not available. Disabling FP4BMM to avoid "
+                "potential runtime errors."
+            )
+            return False
+
+        if is_fp4_avail():
+            return True
+
+        # Lazy %-style args: formatting is deferred to the logging framework.
+        logger.info(
+            "FP4BMM requested via VLLM_ROCM_USE_AITER_FP4BMM but not "
+            "supported on %s. FP4 requires gfx950 "
+            "(MI350X/MI355X). Falling back to FP8.",
+            get_arch(),
+        )
+        return False
 
     @classmethod
     @if_aiter_supported