-
-
Notifications
You must be signed in to change notification settings - Fork 15.6k
[ROCm] Add hardware detection for FP4 BMM to prevent MI300X crashes #34647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,13 +7,16 @@ | |
| from torch._ops import OpOverload | ||
|
|
||
| import vllm.envs as envs | ||
| from vllm.logger import init_logger | ||
| from vllm.platforms import current_platform | ||
| from vllm.utils.torch_utils import direct_register_custom_op | ||
| from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( | ||
| rocm_aiter_sparse_attn_indexer, | ||
| rocm_aiter_sparse_attn_indexer_fake, | ||
| ) | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| # fp8_dtype is not cached. | ||
| # on ROCm the fp8_dtype always calls is_fp8_fnuz | ||
| # which is a host op, so we cache it once here. | ||
|
|
@@ -999,7 +1002,42 @@ | |
| @classmethod | ||
| @if_aiter_supported | ||
| def is_fp4bmm_enabled(cls) -> bool: | ||
| return cls._AITER_ENABLED and cls._FP4BMM_ENABLED | ||
| """Check if FP4 BMM is enabled and supported by hardware. | ||
|
|
||
| FP4 (MXFP4) is only supported on AMD MI325X/MI350X (gfx950). | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note: gfx950 corresponds to MI355X and MI350X. |
||
| MI300X/MI300A (gfx942) do not support FP4. | ||
|
|
||
| This method checks both environment variables AND hardware capability | ||
| to prevent runtime errors on unsupported hardware. | ||
|
|
||
| Returns: | ||
| bool: True if FP4 BMM is both requested and hardware-supported. | ||
| """ | ||
| if not (cls._AITER_ENABLED and cls._FP4BMM_ENABLED): | ||
| return False | ||
|
|
||
| # Check hardware support before enabling FP4 | ||
| try: | ||
| from aiter.ops.triton.utils._triton.arch_info import ( | ||
| get_arch, | ||
| is_fp4_avail, | ||
| ) | ||
|
|
||
| if not is_fp4_avail(): | ||
| arch = get_arch() | ||
| logger.info( | ||
| "FP4BMM requested via VLLM_ROCM_USE_AITER_FP4BMM but not " | ||
| f"supported on {arch}. FP4 requires gfx950 " | ||
| "(MI325X/MI350X). Falling back to FP8." | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. MI325X is gfx942 as well. |
||
| ) | ||
| return False | ||
| return True | ||
| except ImportError: | ||
| logger.warning( | ||
| "AITER arch_info not available. Disabling FP4BMM to avoid " | ||
| "potential runtime errors." | ||
| ) | ||
| return False | ||
|
Comment on lines
1004
to
+1040
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For performance and consistency with other checks in this file, you can cache the result in a class attribute. The check will be performed only on the first call, and subsequent calls will use the cached value.
def is_fp4bmm_enabled(cls) -> bool:
"""Check if FP4 BMM is enabled and supported by hardware.
FP4 (MXFP4) is only supported on AMD MI325X/MI350X (gfx950).
MI300X/MI300A (gfx942) do not support FP4.
This method checks both environment variables AND hardware capability
to prevent runtime errors on unsupported hardware.
Returns:
bool: True if FP4 BMM is both requested and hardware-supported.
"""
if not (cls._AITER_ENABLED and cls._FP4BMM_ENABLED):
return False
if hasattr(cls, "_FP4BMM_HW_SUPPORTED"):
return cls._FP4BMM_HW_SUPPORTED
# Check hardware support before enabling FP4
try:
from aiter.ops.triton.utils._triton.arch_info import (
get_arch,
is_fp4_avail,
)
if not is_fp4_avail():
arch = get_arch()
logger.info(
"FP4BMM requested via VLLM_ROCM_USE_AITER_FP4BMM but not "
f"supported on {arch}. FP4 requires gfx950 "
"(MI325X/MI350X). Falling back to FP8."
)
cls._FP4BMM_HW_SUPPORTED = False
return False
cls._FP4BMM_HW_SUPPORTED = True
return True
except ImportError:
logger.warning(
"AITER arch_info not available. Disabling FP4BMM to avoid "
"potential runtime errors."
)
cls._FP4BMM_HW_SUPPORTED = False
return False |
||
|
|
||
| @classmethod | ||
| @if_aiter_supported | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we just do this?
return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950()
Using
vllm/vllm/platforms/rocm.py
Line 150 in 8d9babd
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with @BowenBao
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@khairulkabir1661 can you keep the changes to just this one line?