[ROCm] Use supports_fp8() for FP8 feature gates instead of arch checks#34740
laudney wants to merge 1 commit into vllm-project:main.
Conversation
Code Review
This pull request is a well-executed refactoring that replaces verbose, architecture-specific FP8 capability checks with a unified current_platform.supports_fp8() predicate. The changes are applied consistently across multiple files, simplifying the code and improving maintainability. Most importantly, this change correctly enables FP8 features on newer RDNA4 (gfx12) GPUs, which were previously excluded by MI300-specific gates. The updated error messages and docstrings are also clearer and more generic. The changes are correct and a clear improvement to the codebase.
Hi @laudney, the pre-commit checks have failed. Please run:

```
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
Force-pushed from 5cd81f8 to c2f191a.
yewentao256 left a comment:
LGTM, thanks for the work!
This pull request has merge conflicts that must be resolved before it can be merged.
Cherry-picked and adapted from 4 open PRs:
- vllm-project#34740 (laudney): Replace on_gfx9()/on_mi3xx() FP8 gates with supports_fp8(), unblocking FP8 on RDNA4/gfx12
- vllm-project#34709 (laudney): Enable wvSplitK/wvSplitKQ skinny GEMM kernels for RDNA4 decode (~15% improvement), wave32 DPP reduction
- vllm-project#34741 (laudney): FP8 KV-cache for RDNA4 custom paged attention via software dequantization
- vllm-project#36659 (vllmellm): Tuned FP8 MoE Triton configs for AMD Radeon AI PRO R9700, AITER mha_v3 attention on gfx12x
Head branch was pushed to by a user without write access
Force-pushed from 8c7e7af to a6eb66e.
Hey, this is approved and rebased on latest main. What else do I need to do to get it merged?
@laudney We added a new Intel CI pipeline that only gates Intel PRs, so it does not apply to your PR. Feel free to ignore the result.
Replace verbose architecture-specific checks (on_gfx9(), on_mi3xx(), has_device_capability(94)) with the cross-platform supports_fp8() predicate across FP8-related code paths. This enables FP8 features on RDNA4 (gfx12) GPUs, which support FP8 but were excluded by the MI300-specific gates.

Affected paths:
- TritonExperts / BatchedTritonExperts: FP8 MoE gate
- ROCmFP8ScaledMMLinearKernel: per-tensor FP8 skinny GEMM gate
- RowWiseTorchFP8ScaledMMLinearKernel: rowwise FP8 matmul gate
- PTPCFp8Config: dynamic FP8 quantization config

Signed-off-by: L.B.R. <lbr@mmonad.com>
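To make the gate collapse described in this commit message concrete, here is a minimal, hypothetical sketch. The class and helper names mirror the predicates mentioned above but are stand-ins, not vLLM's actual current_platform code:

```python
# Hypothetical sketch of the refactor pattern; vllm's real current_platform
# object is mocked here, and the arch prefixes are illustrative assumptions.

class FakePlatform:
    """Stand-in for vLLM's current_platform."""

    def __init__(self, gcn_arch: str):
        self.gcn_arch = gcn_arch

    # Old style: callers chained arch-specific helpers like these.
    def on_gfx9(self) -> bool:
        return self.gcn_arch.startswith("gfx9")

    def on_mi3xx(self) -> bool:
        return self.gcn_arch in ("gfx940", "gfx941", "gfx942")

    # New style: one predicate that owns the FP8 support matrix.
    def supports_fp8(self) -> bool:
        return self.gcn_arch.startswith(("gfx94", "gfx95", "gfx12"))


def fp8_gate_old(p: FakePlatform) -> bool:
    # Before: MI300-specific gate, silently excludes RDNA4 (gfx12).
    return p.on_gfx9() and p.on_mi3xx()


def fp8_gate_new(p: FakePlatform) -> bool:
    # After: one-liner that includes every FP8-capable arch.
    return p.supports_fp8()


rdna4 = FakePlatform("gfx1200")
assert not fp8_gate_old(rdna4)  # old gate wrongly blocks RDNA4
assert fp8_gate_new(rdna4)      # new gate enables FP8
```

Both gates agree on MI300 (gfx942); they differ only on the newer architectures the old helpers predate, which is why the diff is a net deletion.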
Force-pushed from a6eb66e to 7658407.
Rebased onto latest main.
Summary
Replace verbose architecture-specific FP8 capability checks (on_gfx9(), on_mi3xx(), has_device_capability(94)) with the cross-platform current_platform.supports_fp8() predicate across all FP8-related code paths. This is a small refactoring PR that unblocks FP8 features on RDNA4 (gfx12) GPUs, which support FP8 but were excluded by MI300-specific gates. The supports_fp8() method already correctly covers MI300, gfx950, and gfx12 on ROCm, and capability >= 8.9 on CUDA; this PR simply switches the callers to use it.

Changes (5 files, net -27 lines)
| File | Before | After |
| --- | --- | --- |
| fused_batched_moe.py | on_gfx9() + CUDA capability check | supports_fp8() one-liner |
| fused_moe.py | | supports_fp8() one-liner |
| scaled_mm/rocm.py | on_mi3xx() (excluded gfx12) | supports_fp8() (includes gfx12) |
| scaled_mm/pytorch.py | on_mi3xx() + capability >= 94 | supports_fp8() |
| ptpc_fp8.py | has_device_capability(94) | supports_fp8() + updated docstring |

Why this matters
Without this change, RDNA4 GPUs fall through to non-FP8 paths even though they have hardware FP8 support (v_dot4_f32_fp8_fp8, FP8 format in torch._scaled_mm). The old on_mi3xx()/on_gfx9() gates were written before RDNA4 existed.

Related PRs (RDNA4/gfx12 series)
Test plan
has_device_capability(8, 9))
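For reference, the support matrix the summary describes (MI300, gfx950, and gfx12 on ROCm; compute capability >= 8.9 on CUDA) can be sketched as a single predicate. This is an illustrative approximation under stated assumptions, not vLLM's actual supports_fp8() implementation; the signature, parameter names, and arch prefixes are hypothetical:

```python
# Illustrative approximation of the FP8 support matrix described in the PR
# summary. Not vLLM's real current_platform.supports_fp8(); the arch-prefix
# list and the free-function signature are assumptions for this sketch.

ROCM_FP8_ARCH_PREFIXES = ("gfx94", "gfx95", "gfx12")  # MI300, gfx950, RDNA4


def supports_fp8(is_rocm: bool, gcn_arch: str = "",
                 cuda_capability: tuple[int, int] = (0, 0)) -> bool:
    if is_rocm:
        # ROCm: gate on the GCN arch name reported by the device.
        return gcn_arch.startswith(ROCM_FP8_ARCH_PREFIXES)
    # CUDA: FP8 needs compute capability 8.9 (Ada) or newer.
    return cuda_capability >= (8, 9)


assert supports_fp8(True, gcn_arch="gfx942")        # MI300: FP8-capable
assert supports_fp8(True, gcn_arch="gfx1201")       # RDNA4: now included
assert not supports_fp8(True, gcn_arch="gfx90a")    # MI200: no FP8
assert supports_fp8(False, cuda_capability=(8, 9))  # Ada and newer
assert not supports_fp8(False, cuda_capability=(8, 6))
```

Centralizing the matrix in one predicate means future architectures only need a change here, rather than in every caller.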