Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
from torch._ops import OpOverload

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_aiter_sparse_attn_indexer,
rocm_aiter_sparse_attn_indexer_fake,
)

logger = init_logger(__name__)

# fp8_dtype is not cached.
# on ROCm the fp8_dtype always calls is_fp8_fnuz
# which is a host op, so we cache it once here.
Expand Down Expand Up @@ -999,7 +1002,42 @@
@classmethod
@if_aiter_supported
def is_fp4bmm_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._FP4BMM_ENABLED
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just do this?

return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950()

Using

def on_gfx950() -> bool:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with @BowenBao

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@khairulkabir1661 can you keep the changes to just this one line?

"""Check if FP4 BMM is enabled and supported by hardware.

FP4 (MXFP4) is only supported on AMD MI325X/MI350X (gfx950).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: gfx950 corresponds to MI350X and MI355X.
MI325X is gfx942.

MI300X/MI300A (gfx942) do not support FP4.

This method checks both environment variables AND hardware capability
to prevent runtime errors on unsupported hardware.

Returns:
bool: True if FP4 BMM is both requested and hardware-supported.
"""
if not (cls._AITER_ENABLED and cls._FP4BMM_ENABLED):
return False

# Check hardware support before enabling FP4
try:
from aiter.ops.triton.utils._triton.arch_info import (
get_arch,
is_fp4_avail,
)

if not is_fp4_avail():
arch = get_arch()
logger.info(
"FP4BMM requested via VLLM_ROCM_USE_AITER_FP4BMM but not "
f"supported on {arch}. FP4 requires gfx950 "
"(MI325X/MI350X). Falling back to FP8."

Check failure on line 1031 in vllm/_aiter_ops.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (G004)

vllm/_aiter_ops.py:1029:21: G004 Logging statement uses f-string
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MI325x is gfx942 as well.

)
return False
return True
except ImportError:
logger.warning(
"AITER arch_info not available. Disabling FP4BMM to avoid "
"potential runtime errors."
)
return False
Comment on lines 1004 to +1040
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For performance and consistency with other checks in this file (like _check_aiter_mla_fp8_support), the result of the hardware capability check should be cached. This avoids repeatedly performing file imports and function calls, which can be expensive if this method is called frequently.

You can cache the result in a class attribute. The check will be performed only on the first call, and subsequent calls will use the cached value.

    def is_fp4bmm_enabled(cls) -> bool:
        """Return whether FP4 BMM is both requested and hardware-supported.

        FP4 (MXFP4) requires gfx950 (AMD MI350X/MI355X). gfx942 parts
        (MI300X/MI300A/MI325X) do not support FP4.

        The flags ``_AITER_ENABLED`` and ``_FP4BMM_ENABLED`` are consulted
        first; only when both are truthy is the hardware probed via AITER's
        ``is_fp4_avail``. The probe result is cached on the class in
        ``_FP4BMM_HW_SUPPORTED`` so the import and the capability query run
        at most once.

        Returns:
            True if FP4 BMM is requested and AITER reports FP4 as available;
            False otherwise, including when AITER's ``arch_info`` module
            cannot be imported.
        """
        if not (cls._AITER_ENABLED and cls._FP4BMM_ENABLED):
            return False

        # Reuse the cached probe result when a previous call already ran it.
        if hasattr(cls, "_FP4BMM_HW_SUPPORTED"):
            return cls._FP4BMM_HW_SUPPORTED

        # Check hardware support before enabling FP4.
        try:
            from aiter.ops.triton.utils._triton.arch_info import (
                get_arch,
                is_fp4_avail,
            )

            hw_supported = bool(is_fp4_avail())
            if not hw_supported:
                # Lazy %-style arguments keep the logging call Ruff G004-clean.
                logger.info(
                    "FP4BMM requested via VLLM_ROCM_USE_AITER_FP4BMM but not "
                    "supported on %s. FP4 requires gfx950 "
                    "(MI350X/MI355X). Falling back to FP8.",
                    get_arch(),
                )
        except ImportError:
            logger.warning(
                "AITER arch_info not available. Disabling FP4BMM to avoid "
                "potential runtime errors."
            )
            hw_supported = False

        # Single write point for the cache, whichever branch produced the answer.
        cls._FP4BMM_HW_SUPPORTED = hw_supported
        return hw_supported


@classmethod
@if_aiter_supported
Expand Down
Loading