
[ROCm][CI] Fix GPT-OSS Quark MXFP4+FP8 MoE startup#41330

Closed
AndreasKaratzas wants to merge 1 commit into vllm-project:main from ROCm:akaratza_fix_gptoss_ci

Conversation

@AndreasKaratzas
Collaborator

Fix GPT-OSS Quark MXFP4+FP8 MoE startup on ROCm/gfx950 by applying the padding required by the Triton/CDNA4 MXFP4 scale layout, and align the GPT-OSS Quark monolithic MoE method with the current MoERunner call signature. The GPQA ROCm config for amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8 failed during model startup while processing MoE weights:

RuntimeError: shape '[-1, 90, 2, 16, 11, 2, 4, 1]' is invalid for input of size 8294400

The failing path is QuarkOCP_MX_MoEMethod_OSS, used for GPT-OSS Quark weights with MXFP4 weights and static FP8 activations. This subclass still swizzles MXFP4 weights and scales through Triton kernels in process_weights_after_loading().

However, the base QuarkOCP_MX_MoEMethod treats w_mxfp4_a_fp8 as an emulation path for sizing purposes, so it does not apply the MXFP4 backend padding. On ROCm/CDNA4, Triton scale swizzling requires aligned dimensions. For the 20B TP=2 case, the unpadded dimensions produce scale shapes based on hidden_size=2880 and intermediate_size_per_partition=1440; hidden_size / 32 = 90, which is not compatible with the CDNA4 scale swizzle layout.
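
For reference, the failing view can be checked directly from the numbers in the traceback: the product of the fixed swizzle dimensions does not evenly divide the tensor's element count, which is exactly the condition the padded sizes are meant to guarantee. A minimal standalone arithmetic sketch (variable names are illustrative, not vLLM code):

import math

numel = 8_294_400                        # element count reported in the RuntimeError
swizzle_dims = (90, 2, 16, 11, 2, 4, 1)  # fixed dims of the reshape; 90 == hidden_size // 32

fixed = math.prod(swizzle_dims)          # 253_440
print(numel % fixed)                     # 184_320 -> not divisible, so view(-1, *swizzle_dims) raises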

Testing

Ran the focused padded MoE test file:

pytest -s -v evals/gpt_oss/test_gpqa_correctness.py --config-list-file=configs/models-gfx950.txt

cc @kenroche

…dding required by the Triton CDNA4 MXFP4 scale layout

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@AndreasKaratzas
Collaborator Author

cc @Rohan138

@AndreasKaratzas AndreasKaratzas marked this pull request as ready for review April 30, 2026 03:34

@claude (Bot) left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify (Bot) added the gpt-oss (Related to GPT-OSS models) and rocm (Related to AMD ROCm) labels Apr 30, 2026
@AndreasKaratzas removed the gpt-oss (Related to GPT-OSS models) label Apr 30, 2026
@mergify (Bot) added the gpt-oss (Related to GPT-OSS models) label Apr 30, 2026
@github-project-automation (Bot) moved this to Todo in AMD Apr 30, 2026
Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces the maybe_roundup_sizes method to the Quark MoE implementation to ensure proper padding for MXFP4 weights when using Triton kernels. The feedback suggests using a more idiomatic Python approach with super() to call the grandparent class method instead of referencing the base class directly, which improves code style and MRO handling.

Comment on lines +1498 to +1504
FusedMoEMethodBase.maybe_roundup_sizes(
self,
hidden_size=hidden_size,
intermediate_size_per_partition=intermediate_size_per_partition,
act_dtype=act_dtype,
moe_parallel_config=moe_parallel_config,
)
Contributor


Severity: high

The use of FusedMoEMethodBase.maybe_roundup_sizes(self, ...) is a bit unconventional for calling a grandparent's method. While it works, using super(QuarkOCP_MX_MoEMethod, self).maybe_roundup_sizes(...) is the standard and more idiomatic way in Python to explicitly skip the immediate parent's implementation and call the next one in the MRO.

            super(QuarkOCP_MX_MoEMethod, self).maybe_roundup_sizes(
                hidden_size=hidden_size,
                intermediate_size_per_partition=intermediate_size_per_partition,
                act_dtype=act_dtype,
                moe_parallel_config=moe_parallel_config,
            )
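
For illustration, a standalone sketch of the MRO behavior the suggestion relies on (the class names here are placeholders, not the vLLM classes): two-argument super() starts the lookup after the named class, so the immediate parent's override is skipped and the grandparent's implementation runs.

class Base:
    def maybe_roundup_sizes(self):
        return "Base"

class Parent(Base):
    def maybe_roundup_sizes(self):
        return "Parent"

class Child(Parent):
    def maybe_roundup_sizes(self):
        # Same effect as Base.maybe_roundup_sizes(self), but resolved
        # through the MRO instead of hard-coding the grandparent class.
        return super(Parent, self).maybe_roundup_sizes()

print(Child().maybe_roundup_sizes())  # -> "Base"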

@Rohan138
Contributor

> However, the base QuarkOCP_MX_MoEMethod treats w_mxfp4_a_fp8 as an emulation path for sizing purposes, so it does not apply the MXFP4 backend padding

So this part is actually incorrect, coming from #39801 ... Similar to #41175, we need to add back w_mxfp4_a_fp8 to the allowlist. @BowenBao has a PR to refactor the MXFP4 backends correctly: #39136

@AndreasKaratzas
Collaborator Author

Closing in favor of #39136

@github-project-automation (Bot) moved this from To Triage to Done in gpt-oss Issues & Enhancements Apr 30, 2026
@github-project-automation (Bot) moved this from Todo to Done in AMD Apr 30, 2026

Labels

gpt-oss (Related to GPT-OSS models), rocm (Related to AMD ROCm)

Projects

Status: Done

