
[ROCm][CI] Extended Fused MoE and FP8 MoE test support #41100

Draft
AndreasKaratzas wants to merge 1 commit into vllm-project:main from ROCm:akaratza_ci_fusedmoe_modelopt

Conversation

@AndreasKaratzas
Collaborator

This PR makes the fused MoE layer test matrix usable on ROCm/MI355 by fixing the real ModelOpt FP8/FP4 failures it exposes and by making distributed subcase failures visible to pytest.

Key changes:

  • Propagate fused MoE distributed subcase failures back to the parent pytest process, instead of letting child-rank failures print as failed subcases while the parent test reports PASSED (sketched below).
  • Avoid collecting invalid no-parallel feature combinations where routed_input_transform or gate is requested without shared_experts (sketched below).
  • Allow modelopt_fp4 MoE test configs on ROCm and SM90+ paths, where native or emulated NVFP4 execution is available.
  • Use the existing NVFP4 reference quantization path to create packed FP4 test weights on ROCm, since ops.scaled_fp4_quant is not available there.
  • Keep NVFP4 emulation lookup tensors on the same device as the packed FP4 input during dequantization (sketched below).
  • Keep ModelOpt FP8 tensor-wise MoE activation scales as rank-1 tensors after reduction, so Triton receives loadable scale pointers rather than constexpr scalar values (sketched below).
  • Gate the experimental MoRI fused MoE layer matrix behind VLLM_TEST_ENABLE_MORI_MOE_LAYER=1; when enabled, the test sets the AITER fused MoE env requirements and disables AITER shared expert fusion (sketched below).
  • Enable the modular OAI Triton MoE test on CUDA-like platforms, and pad MXFP4 test weights/inputs to the CDNA4 scale-layout alignment on ROCm while slicing outputs back to the original test shape (sketched below).
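
The sketches below are illustrative only; helper names, constants, and test bodies that are not mentioned in this PR are hypothetical stand-ins, not the actual vLLM test code. First, the failure-propagation idea: each worker rank records failing subcases in a temporary report file, and the parent pytest process asserts that no reports exist after all ranks join, so a child-rank failure can no longer coexist with a PASSED parent test.

```python
# Minimal sketch: surface child-rank subcase failures to the parent pytest
# process via temporary failure reports. `run_subcase` is a stand-in for the
# real fused MoE layer subcase check.
import json
import os
import tempfile

import torch.multiprocessing as mp


def run_subcase(rank: int, world_size: int, case_id: int) -> None:
    # Stand-in for a real fused MoE subcase; raises AssertionError on mismatch.
    assert case_id >= 0


def _worker(rank: int, world_size: int, report_dir: str) -> None:
    failures = []
    for case_id in range(4):  # stand-in for the subcase matrix
        try:
            run_subcase(rank, world_size, case_id)
        except AssertionError as exc:
            failures.append({"rank": rank, "case": case_id, "error": str(exc)})
    if failures:
        with open(os.path.join(report_dir, f"rank{rank}.json"), "w") as f:
            json.dump(failures, f)


def test_fused_moe_distributed_subcases():
    world_size = 2
    with tempfile.TemporaryDirectory() as report_dir:
        mp.spawn(_worker, args=(world_size, report_dir), nprocs=world_size)
        reports = sorted(os.listdir(report_dir))
        # Fail the parent test if any rank recorded a failing subcase.
        assert not reports, f"distributed subcase failures in: {reports}"
```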
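Next, a sketch of dropping the invalid no-parallel combinations at collection time. The boolean flags are a simplification of the real feature configs; the point is that routed_input_transform or gate is only ever paired with shared_experts.

```python
# Minimal sketch: only collect feature combinations where a routed-input
# transform or gate is accompanied by shared experts.
import itertools

import pytest

FEATURE_COMBOS = [
    (shared_experts, routed_input_transform, gate)
    for shared_experts, routed_input_transform, gate in itertools.product(
        [False, True], repeat=3
    )
    if shared_experts or not (routed_input_transform or gate)
]


@pytest.mark.parametrize(
    "shared_experts,routed_input_transform,gate", FEATURE_COMBOS
)
def test_fused_moe_feature_combo(shared_experts, routed_input_transform, gate):
    ...  # invalid combinations are never collected
```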
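The device fix for NVFP4 emulation, shown here with a plain E2M1 lookup table (the actual dequantization helper is more involved): the lookup tensor must follow the packed FP4 input's device before indexing.

```python
# Minimal sketch: keep the FP4 (E2M1) code-to-float lookup table on the same
# device as the packed FP4 codes it indexes during emulated dequantization.
import torch

# The 16 representable E2M1 values.
FP4_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)


def dequant_fp4_codes(codes: torch.Tensor) -> torch.Tensor:
    lut = FP4_LUT.to(codes.device)  # LUT must live where the codes live
    return lut[codes.long()]
```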
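The scale-shape issue in one line: reducing a per-token FP8 activation scale with .max() produces a 0-d tensor, which Triton can end up treating as a constexpr scalar rather than a pointer it can load from, while reshaping to rank-1 keeps a real device pointer. A minimal illustration:

```python
# Minimal sketch of the 0-d vs. rank-1 scale distinction after reduction.
import torch

per_token_scale = torch.rand(128, dtype=torch.float32)

scalar_scale = per_token_scale.max()            # 0-d: torch.Size([])
rank1_scale = per_token_scale.max().reshape(1)  # rank-1: torch.Size([1])

assert scalar_scale.ndim == 0
assert rank1_scale.ndim == 1  # passed to Triton as a loadable pointer
```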
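A sketch of the opt-in gate for the MoRI layer matrix; the marker name and skip reason are illustrative, only the environment variable comes from this PR.

```python
# Minimal sketch: gate the MoRI fused MoE layer matrix behind an opt-in env var.
import os

import pytest

requires_mori = pytest.mark.skipif(
    os.environ.get("VLLM_TEST_ENABLE_MORI_MOE_LAYER") != "1",
    reason="MoRI fused MoE layer tests are opt-in via VLLM_TEST_ENABLE_MORI_MOE_LAYER=1",
)


@requires_mori
def test_mori_fused_moe_layer():
    ...  # when enabled, the real test also sets the AITER env requirements
```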
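Finally, the pad-then-slice pattern used for the MXFP4 test on ROCm, with a made-up alignment constant (the real CDNA4 scale-layout alignment differs): inputs and weights are zero-padded up to the boundary the kernel expects, and the output is sliced back to the original test shape.

```python
# Minimal sketch: pad a dimension up to a hardware alignment, then slice the
# kernel output back to the original size. ALIGN is a placeholder, not the
# actual CDNA4 scale-layout constant.
import torch

ALIGN = 256


def pad_dim(x: torch.Tensor, dim: int, align: int = ALIGN) -> torch.Tensor:
    pad = (-x.shape[dim]) % align
    if pad == 0:
        return x
    pad_shape = list(x.shape)
    pad_shape[dim] = pad
    return torch.cat([x, x.new_zeros(pad_shape)], dim=dim)


hidden = 1000                  # deliberately unaligned test size
x = torch.randn(4, hidden)
x_padded = pad_dim(x, dim=-1)  # shape: (4, 1024)

# out = mxfp4_kernel(x_padded, w_padded)  # hypothetical kernel call
# out = out[..., :hidden]                 # slice back to the test shape
```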

cc @kenroche

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

@gemini-code-assist (Bot) left a comment


Code Review

This pull request updates the MoE (Mixture of Experts) testing infrastructure and kernel implementations. Key changes include adding new Buildkite test configurations for AMD hardware, implementing FP4 emulation for ROCm, and improving the MoE layer test runner to handle distributed failures more robustly with temporary failure reports. Additionally, minor adjustments were made to Triton kernel inputs and quantization utilities to ensure compatibility across different platforms. I have no feedback to provide as there were no review comments.


Labels

ci/build, rocm (Related to AMD ROCm)

Projects

Status: Todo
