
[ROCm] Enable FP8 inference on gfx1201 AMD RDNA4 (Radeon AI PRO R9700) with aiter kernels#36659

Draft
vllmellm wants to merge 7 commits into vllm-project:main from EmbeddedLLM:rdna4-aiter

Conversation

Contributor

@vllmellm vllmellm commented Mar 10, 2026

Purpose

No tuned Triton FP8 MoE configuration existed for the AMD Radeon AI PRO R9700 (gfx1201, RDNA4). vLLM selects fused MoE tiling parameters by (E, N, device_name, dtype) key at runtime — without an R9700 entry, it falls back to untuned defaults, leaving significant performance on the table.

Related: #28649.

Changes:

1. FP8 MoE Path for gfx12x (vllm/model_executor/layers/fused_moe/fused_moe.py)

device_supports_fp8 in TritonExperts._supports_quant_scheme() was gated on is_rocm_on_gfx9. Extended to include is_rocm_on_gfx12 (via the existing on_gfx12x() platform helper), enabling the FP8 expert linear kernel path for RDNA4. Without this, FP8 weight quantization is skipped and the MoE falls back to BF16/FP16 computation.
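A minimal sketch of the gating change described above. The real check lives in `TritonExperts._supports_quant_scheme()` in `vllm/model_executor/layers/fused_moe/fused_moe.py`; the platform helpers here are stubs for illustration only, keyed on the `gfx` architecture string:

```python
def is_rocm_on_gfx9(arch: str) -> bool:
    # Stub for vLLM's gfx9 (CDNA / MI-series) platform check.
    return arch.startswith("gfx9")


def is_rocm_on_gfx12(arch: str) -> bool:
    # Stub for the existing on_gfx12x() platform helper (RDNA4).
    return arch.startswith("gfx12")


def device_supports_fp8(arch: str) -> bool:
    # Before this PR only gfx9 passed; with gfx12x included, RDNA4
    # takes the FP8 expert linear kernel path instead of falling
    # back to BF16/FP16 computation.
    return is_rocm_on_gfx9(arch) or is_rocm_on_gfx12(arch)


print(device_supports_fp8("gfx1201"))  # R9700 (RDNA4): True
print(device_supports_fp8("gfx942"))   # MI300-class: True
print(device_supports_fp8("gfx1100"))  # RDNA3: False, unchanged
```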

2. Tuned Triton FP8 MoE Config for R9700 (vllm/model_executor/layers/fused_moe/configs/)

Added E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,dtype=fp8_w8a8,block_shape=[128,128].json.
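For reference, these config files map candidate batch sizes (`M`, as string keys) to Triton tile parameters. The fragment below only illustrates the file layout; the numeric values are placeholders, not the tuned numbers added in this PR:

```json
{
  "1":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2},
  "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 2}
}
```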

As noted above, tuned configs are keyed by (E, N, device_name, dtype) at runtime, and no R9700 entry existed previously. This config covers Qwen3-30B-A3B-FP8 (E=64 routed experts, moe_intermediate_size=768).

Test Plan

  • vLLM serving benchmark on Qwen3-0.6B-FP8 and Qwen3-30B-A3B-FP8
  • ISL/OSL evaluated: 1024/1024, 2048/2048, 4096/4096, 8192/1024, 16384/2048
  • Hardware: AMD Radeon AI PRO R9700 (gfx1201)

Test Result

Comparison: Default (no tuned config, fallback tiling) vs Tuned MoE (with this config).

Qwen3-30B-A3B-FP8 (MoE)

Mean TTFT (s) — lower is better

| ISL/OSL | Default | Tuned MoE | Δ |
|---|---|---|---|
| 1024/1024 | 0.926 | 0.726 | −22% |
| 2048/2048 | 1.565 | 1.231 | −21% |
| 4096/4096 | 5.333 | 4.351 | −18% |
| 8192/1024 | 14.066 | 10.974 | −22% |
| 16384/2048 | 144.387 | 120.308 | −17% |

Mean TPOT (s) — lower is better

| ISL/OSL | Default | Tuned MoE | Δ |
|---|---|---|---|
| 1024/1024 | 0.0378 | 0.0288 | −24% |
| 2048/2048 | 0.0377 | 0.0314 | −17% |
| 4096/4096 | 0.0456 | 0.0402 | −12% |
| 8192/1024 | 0.0788 | 0.0687 | −13% |
| 16384/2048 | 0.0715 | 0.0659 | −8% |

Total Token Throughput (tok/s) — higher is better

| ISL/OSL | Default | Tuned MoE | Δ |
|---|---|---|---|
| 1024/1024 | 1652 | 2167 | +31% |
| 2048/2048 | 1662 | 1995 | +20% |
| 4096/4096 | 1360 | 1547 | +14% |
| 8192/1024 | 2856 | 3411 | +19% |
| 16384/2048 | 1715 | 1979 | +15% |
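The Δ columns above can be re-derived from the raw numbers; they are rounded to the nearest whole percent:

```python
def delta_pct(default: float, tuned: float) -> int:
    # Signed percentage change from the default to the tuned run.
    return round((tuned - default) / default * 100)


# Total token throughput, 1024/1024 row: +31%
print(delta_pct(1652, 2167))
# Mean TTFT, 1024/1024 row: -22%
print(delta_pct(0.926, 0.726))
# Mean TPOT, 1024/1024 row: -24%
print(delta_pct(0.0378, 0.0288))
```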

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the rocm Related to AMD ROCm label Mar 10, 2026

mergify bot commented Mar 10, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request extends vLLM's support for ROCm gfx12x (RDNA4) GPUs by introducing specific detection for gfx12x architectures and integrating it into various components. Key changes include enabling AITER support for gfx12x, implementing a conditional mha_v3 Triton kernel for flash_attn_varlen_func on gfx12x for performance, and extending FP8 quantization support to gfx12x for fused and batched Mixture of Experts (MoE) layers. Additionally, new Triton tuning configurations for fused MoE on AMD_Radeon_AI_PRO_R9700 (a gfx12x device) have been added. A potential issue was identified where the condition on_gfx1x() and on_gfx12x() in get_vit_attn_backend will always be false, inadvertently disabling the Triton Flash Attention backend for both gfx11 and gfx12x devices in that specific function.

wi-adam added a commit to wi-adam/vllm that referenced this pull request Mar 12, 2026
Cherry-picked and adapted from 4 open PRs:

- vllm-project#34740 (laudney): Replace on_gfx9()/on_mi3xx() FP8 gates with
  supports_fp8(), unblocking FP8 on RDNA4/gfx12
- vllm-project#34709 (laudney): Enable wvSplitK/wvSplitKQ skinny GEMM kernels
  for RDNA4 decode (~15% improvement), wave32 DPP reduction
- vllm-project#34741 (laudney): FP8 KV-cache for RDNA4 custom paged attention
  via software dequantization
- vllm-project#36659 (vllmellm): Tuned FP8 MoE Triton configs for AMD Radeon
  AI PRO R9700, AITER mha_v3 attention on gfx12x
@freddybc

If anyone is interested in also validating MoE tuning for Int4, the configuration files are attached:

E=128,N=768,device_name=AMD_Radeon_AI_PRO_R9700,dtype=int4_w4a16.json
E=256,N=512,device_name=AMD_Radeon_AI_PRO_R9700,dtype=int4_w4a16.json



Labels

needs-rebase, rocm (Related to AMD ROCm)

Projects

Status: Todo


3 participants