[ROCm] Use supports_fp8() for FP8 feature gates instead of arch checks#34740

Open
laudney wants to merge 1 commit into vllm-project:main from mmonad:feat/rocm-fp8-capability-gates
Conversation

@laudney
Contributor

@laudney laudney commented Feb 17, 2026

Summary

Replace verbose architecture-specific FP8 capability checks (on_gfx9(), on_mi3xx(), has_device_capability(94)) with the cross-platform current_platform.supports_fp8() predicate across all FP8-related code paths.

This is a small refactoring PR that unblocks FP8 features on RDNA4 (gfx12) GPUs, which support FP8 but were excluded by MI300-specific gates. The supports_fp8() method already correctly covers MI300, gfx950, and gfx12 on ROCm, and capability >= 8.9 on CUDA — this PR simply switches the callers to use it.
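To make the gating logic concrete, here is a minimal, self-contained sketch of what such a cross-platform predicate looks like. The names (`is_rocm`, `gcn_arch`, `capability`) and the arch-prefix table are illustrative assumptions for this sketch, not vLLM's actual implementation, which lives on the platform objects:

```python
# Hypothetical sketch of a cross-platform FP8 capability predicate.
# Arch prefixes are assumptions: gfx94x = MI300 series, gfx950 = MI350,
# gfx12x = RDNA4. On CUDA, FP8 is gated at compute capability 8.9 (Ada).
ROCM_FP8_ARCH_PREFIXES = ("gfx94", "gfx950", "gfx12")

def supports_fp8(is_rocm: bool,
                 gcn_arch: str = "",
                 capability: tuple = (0, 0)) -> bool:
    if is_rocm:
        # Prefix match covers e.g. gfx942, gfx950, gfx1201
        return any(gcn_arch.startswith(p) for p in ROCM_FP8_ARCH_PREFIXES)
    # CUDA path: delegate to a device-capability comparison
    return capability >= (8, 9)
```

With a single predicate like this, every caller reduces to one line, and adding a new FP8-capable architecture means touching one table instead of every gate.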

Changes (5 files, net -27 lines)

| File | Before | After |
| --- | --- | --- |
| fused_batched_moe.py | 10-line on_gfx9() + CUDA capability check | supports_fp8() one-liner |
| fused_moe.py | Same 10-line pattern | supports_fp8() one-liner |
| scaled_mm/rocm.py | on_mi3xx() — excluded gfx12 | supports_fp8() — includes gfx12 |
| scaled_mm/pytorch.py | on_mi3xx() + capability >= 94 | supports_fp8() |
| ptpc_fp8.py | has_device_capability(94) | supports_fp8() + updated docstring |

Why this matters

Without this change, RDNA4 GPUs fall through to non-FP8 paths even though they have hardware FP8 support (v_dot4_f32_fp8_fp8, FP8 format in torch._scaled_mm). The old on_mi3xx()/on_gfx9() gates were written before RDNA4 existed.
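The fall-through behavior can be seen in a side-by-side sketch of the old and new gate shapes. Function names and arguments here are assumptions for illustration, not the actual vLLM call sites:

```python
# Illustrative contrast between the old arch-specific gate and the
# new predicate-based gate. Names are hypothetical.

def use_fp8_old(is_rocm: bool, gcn_arch: str,
                cuda_capability: tuple) -> bool:
    # Before: MI300-specific check; RDNA4 (gfx12xx) silently
    # falls through to the non-FP8 path despite hardware support.
    if is_rocm:
        return gcn_arch.startswith("gfx94")  # on_mi3xx()-style gate
    return cuda_capability >= (8, 9)

def use_fp8_new(platform_supports_fp8: bool) -> bool:
    # After: a single cross-platform predicate decides.
    return platform_supports_fp8
```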

Related PRs (RDNA4/gfx12 series)

Test plan

  • Qwen3-14B-FP8 (block-wise FP8) serving on gfx1201 — FP8 rowwise matmul path now activates
  • Qwen3-Coder-30B-A3B AWQ-4bit — FP8 MoE gate correctly enables Triton FP8 experts
  • No regression on MI300-series (predicate returns same result as old checks)
  • No regression on CUDA (predicate delegates to has_device_capability(8, 9))

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request is a well-executed refactoring that replaces verbose, architecture-specific FP8 capability checks with a unified current_platform.supports_fp8() predicate. The changes are applied consistently across multiple files, simplifying the code and improving maintainability. Most importantly, this change correctly enables FP8 features on newer RDNA4 (gfx12) GPUs, which were previously excluded by MI300-specific gates. The updated error messages and docstrings are also clearer and more generic. The changes are correct and a clear improvement to the codebase.

@mergify

mergify bot commented Feb 17, 2026

Hi @laudney, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@laudney laudney force-pushed the feat/rocm-fp8-capability-gates branch 2 times, most recently from 5cd81f8 to c2f191a on February 17, 2026 at 20:40
Member

@yewentao256 yewentao256 left a comment

LGTM, thanks for the work!

@yewentao256 yewentao256 added the ready label (ONLY add when PR is ready to merge/full CI is needed) Feb 18, 2026
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) February 20, 2026 04:11
@mergify

mergify bot commented Feb 24, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @laudney.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 24, 2026
wi-adam added a commit to wi-adam/vllm that referenced this pull request Mar 12, 2026
Cherry-picked and adapted from 4 open PRs:

- vllm-project#34740 (laudney): Replace on_gfx9()/on_mi3xx() FP8 gates with
  supports_fp8(), unblocking FP8 on RDNA4/gfx12
- vllm-project#34709 (laudney): Enable wvSplitK/wvSplitKQ skinny GEMM kernels
  for RDNA4 decode (~15% improvement), wave32 DPP reduction
- vllm-project#34741 (laudney): FP8 KV-cache for RDNA4 custom paged attention
  via software dequantization
- vllm-project#36659 (vllmellm): Tuned FP8 MoE Triton configs for AMD Radeon
  AI PRO R9700, AITER mha_v3 attention on gfx12x
auto-merge was automatically disabled March 22, 2026 12:48

Head branch was pushed to by a user without write access

@laudney laudney force-pushed the feat/rocm-fp8-capability-gates branch from 8c7e7af to a6eb66e on March 22, 2026 at 12:48
@laudney
Contributor Author

laudney commented Mar 22, 2026

Hey, this is approved and rebased on latest main. What else do I need to do to get it merged?

@mergify mergify bot added the intel-gpu (Related to Intel GPU) label Mar 27, 2026
@wendyliu235
Contributor

@laudney We added a new Intel CI pipeline that only gates Intel PRs, so it does not apply to your PR. Feel free to ignore the result.

Replace verbose architecture-specific checks (on_gfx9(), on_mi3xx(),
has_device_capability(94)) with the cross-platform supports_fp8()
predicate across FP8-related code paths. This enables FP8 features
on RDNA4 (gfx12) GPUs which support FP8 but were excluded by the
MI300-specific gates.

Affected paths:
- TritonExperts / BatchedTritonExperts: FP8 MoE gate
- ROCmFP8ScaledMMLinearKernel: per-tensor FP8 skinny GEMM gate
- RowWiseTorchFP8ScaledMMLinearKernel: rowwise FP8 matmul gate
- PTPCFp8Config: dynamic FP8 quantization config

Signed-off-by: L.B.R. <lbr@mmonad.com>
@laudney laudney force-pushed the feat/rocm-fp8-capability-gates branch from a6eb66e to 7658407 on March 27, 2026 at 13:40
@laudney
Contributor Author

laudney commented Mar 27, 2026

Rebased onto latest main. The only conflict was with ptpc_fp8.py which was removed upstream in #32700 — accepted the deletion since the file no longer exists. All other changes applied cleanly. Pre-commit hooks pass (the mypy failures are pre-existing on main, unrelated to this PR).

@mergify mergify bot removed the needs-rebase label Mar 27, 2026

Labels

intel-gpu: Related to Intel GPU
ready: ONLY add when PR is ready to merge/full CI is needed
rocm: Related to AMD ROCm
