
[Bugfix][ROCm][MoE] Fix mxfp4 oracle regressions from #37128 #37787

Merged

tjtanaa merged 15 commits into vllm-project:main from ROCm:akaratza_fix_gptoss on Mar 25, 2026

Conversation

@AndreasKaratzas (Collaborator) commented Mar 22, 2026

Fixes several issues introduced by #37128 that broke gpt-oss on ROCm.

  • Restore the gfx950 gate for CK mxfp4 backend selection. The old code picked CK only on gfx950 via on_gfx950(); the refactor dropped this check, allowing CK to be selected on gfx942, where it crashes.
  • Restore the CK_MXFP4_MOE_DIM_ALIGNMENT (256) check. Models whose intermediate_size is not aligned to 256 (such as gpt-oss-20b at 2880) hit a reshape error in aiter's shuffle_scale_a16w4. Added is_supported_config to AiterExperts so the backend selector falls through to Triton.
  • Restore hidden_pad/intermediate_pad for the CK path. These were passed to rocm_aiter_ops.fused_moe() in the old code but were lost in the refactor. Added the fields to FusedMoEQuantConfig and wired them through.
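The gating and alignment rules described in the bullets above can be sketched as follows. CK_MXFP4_MOE_DIM_ALIGNMENT and the gfx950/gfx942 behavior come from the PR description; the selector function itself is a hypothetical illustration, not vLLM's actual code:

```python
# Sketch of the restored backend gate, based on the PR description.
# select_mxfp4_backend is an illustrative helper, not a vLLM function.
CK_MXFP4_MOE_DIM_ALIGNMENT = 256

def select_mxfp4_backend(intermediate_size: int, on_gfx950: bool) -> str:
    """Pick CK only on gfx950 with 256-aligned dims; otherwise fall back to Triton."""
    if on_gfx950 and intermediate_size % CK_MXFP4_MOE_DIM_ALIGNMENT == 0:
        return "ck"
    return "triton"

# gpt-oss-20b has intermediate_size 2880, not a multiple of 256 -> Triton
print(select_mxfp4_backend(2880, on_gfx950=True))   # triton
print(select_mxfp4_backend(2816, on_gfx950=True))   # ck (2816 = 11 * 256)
print(select_mxfp4_backend(2816, on_gfx950=False))  # triton (non-gfx950)
```

The non-gfx950 case is the regression path: before this fix, CK could be selected on gfx942 and crash instead of falling through.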

Tested on MI325X (gfx942):

  • test_gpt_oss_speculative_reasoning_leakage passes
  • GPQA eval via gpt_oss.evals: 56.76% (1584 questions, effort=low)
  • Backend correctly falls back to Triton on non-gfx950

Related:

cc @kenroche



Signed-off-by: Andreas Karatzas <akaratza@amd.com>


@AndreasKaratzas AndreasKaratzas marked this pull request as ready for review March 22, 2026 02:52
@AndreasKaratzas AndreasKaratzas added the rocm Related to AMD ROCm label Mar 22, 2026
@AndreasKaratzas AndreasKaratzas added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 22, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Mar 22, 2026
@mergify mergify bot added gpt-oss Related to GPT-OSS models bug Something isn't working labels Mar 22, 2026
@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request addresses several regressions related to MoE with mxfp4 quantization, primarily affecting ROCm platforms, which were introduced in a recent refactoring. The fixes include restoring platform-specific checks for the CK backend, correctly handling dimension padding, and resolving an issue with LoRA on NVIDIA. My review identifies a critical contradiction in one of the changes: while the PR description claims to enable mxfp4 LoRA on ROCm, the code continues to raise a NotImplementedError, albeit with a more descriptive message. This discrepancy needs to be resolved.

@mergify bot (Contributor) commented Mar 22, 2026

Hi @AndreasKaratzas, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@mergify mergify bot added the ci/build label Mar 22, 2026
@AndreasKaratzas (Collaborator, Author) commented:

Putting this PR in draft mode, as the CI regression does not seem to be addressed by this PR. The most straightforward solution is probably to revert the problematic PR.

@AndreasKaratzas AndreasKaratzas marked this pull request as draft March 22, 2026 08:26
@mergify bot (Contributor) commented Mar 22, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @AndreasKaratzas.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

… in AiterExperts

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@BowenBao (Contributor) commented:

Please share if there's a discussion with upstream folks. #37128 broke CI, and rather than reverting it to figure out a proper solution, there's pressure to quickly land this PR. The size -> shape change and padding logic feel like quick (hacky) band-aids that will require the ROCm team to do more follow-up work down the line.

I don't feel this aligns with OSS best practices. However, if the maintainers are aware of this and have agreed it's the best way forward, I have nothing to add. cc @tjtanaa, @gshtras, @dllehr-amd, @ChuanLi1101

@AndreasKaratzas (Collaborator, Author) commented:

> Please share if there's a discussion with upstream folks. #37128 broke CI, and rather than reverting it to figure out a proper solution, there's pressure to quickly land this PR. The size -> shape change and padding logic feel like quick (hacky) band-aids that will require the ROCm team to do more follow-up work down the line.
>
> I don't feel this aligns with OSS best practices. However, if the maintainers are aware of this and have agreed it's the best way forward, I have nothing to add. cc @tjtanaa, @gshtras, @dllehr-amd, @ChuanLi1101

The .size() to .shape change isn't a band-aid; it's the correct fix. triton_kernels.tensor.Tensor (used for MXFP4 swizzled weights since #37128) exposes .shape but not .size() or .dim(). .shape is the common interface that works with both tensor types, and it is the more Pythonic, numpy-standard accessor anyway. There's no follow-up work needed here.
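The duck-typing point can be shown without any GPU dependencies. The class below is a stand-in for triton_kernels.tensor.Tensor (illustrative only, not the real class): it exposes .shape but no .size() method, so code that reads .shape works on both tensor types while .size(-1) would fail here:

```python
class SwizzledTensor:
    """Illustrative stand-in for triton_kernels.tensor.Tensor: it exposes
    .shape but deliberately has no .size() or .dim() methods."""
    def __init__(self, shape):
        self.shape = tuple(shape)

def hidden_dim(t) -> int:
    # t.shape[-1] works for torch.Tensor and SwizzledTensor alike;
    # t.size(-1) would raise AttributeError on the latter.
    return t.shape[-1]

t = SwizzledTensor((32, 2880))
print(hidden_dim(t))  # 2880
```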

On the padding logic, I hear the concern about it being incremental. The CK MXFP4 kernels on gfx950 require 256-byte-aligned dimensions, and the padding was already happening in create_weights but wasn't being properly communicated to the layer/config.
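The padding amount itself is a standard round-up-to-alignment computation. The sketch below uses the field names from the PR description (hidden_pad/intermediate_pad); the helper is illustrative, not vLLM's actual implementation:

```python
# Hedged sketch of the padding computation: how many elements to add so a
# dimension becomes a multiple of the CK alignment. Illustrative helper only.
CK_MXFP4_MOE_DIM_ALIGNMENT = 256

def pad_to_alignment(dim: int, alignment: int = CK_MXFP4_MOE_DIM_ALIGNMENT) -> int:
    """Return the number of elements needed to round dim up to a multiple of alignment."""
    return (-dim) % alignment

intermediate_size = 2880  # gpt-oss-20b
intermediate_pad = pad_to_alignment(intermediate_size)
print(intermediate_pad)                      # 192
print(intermediate_size + intermediate_pad)  # 3072, a multiple of 256
```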

@gshtras (Collaborator) commented Mar 24, 2026

This appears to correctly fix the fallback to Triton caused by #35893 on gpt-oss-120b. lm_eval scores are back to normal when using CK MoE.

# the triton_kernels/aiter side. This matches pre-#37128.
raise NotImplementedError(
    "Mxfp4 LoRA is only supported on CUDA. "
    "ROCm support is blocked by triton_kernels.tensor.Tensor "
A Collaborator left a review comment on the snippet above:

Nit: not is_cuda doesn't automatically mean ROCm. Other platforms may get confused by this message.
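One way to address the nit is to name the detected platform in the message rather than assuming ROCm. The helper below is a hypothetical sketch of that idea, not vLLM code:

```python
def mxfp4_lora_error_message(platform_name: str) -> str:
    """Build a platform-accurate error message instead of assuming ROCm.

    Hypothetical helper illustrating the review nit: whatever non-CUDA
    platform was detected is reported by name.
    """
    return (
        "Mxfp4 LoRA is only supported on CUDA; "
        f"the current platform ({platform_name}) is not supported."
    )

print(mxfp4_lora_error_message("rocm"))
print(mxfp4_lora_error_message("cpu"))
```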

@AndreasKaratzas (Collaborator, Author) replied:

Done :)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Mar 24, 2026
@tjtanaa tjtanaa merged commit 679c6a3 into vllm-project:main Mar 25, 2026
72 of 73 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Mar 25, 2026
@AndreasKaratzas AndreasKaratzas deleted the akaratza_fix_gptoss branch March 25, 2026 00:46