
[Kernel][MoE] fix computation order of MoE weight multiplication and improve flow #31962

Merged
mgoin merged 4 commits into vllm-project:main from xuebwang-amd:xuebin_triton_fused_moe_mul_routed_weight_issue
Jan 12, 2026
Conversation

@xuebwang-amd (Contributor) commented Jan 8, 2026

Purpose

Background:

This PR continues the work of two previous PRs:

  • PR #31676
    • Key contribution: move bias addition to after dequantization.
    • Computation order: to(compute_type) -> HAS_BIAS (bias addition) -> MUL_ROUTED_WEIGHT. This is the closest to the correct order, except that the MoE weight multiplication is not performed in float32.
  • PR #31931
    • Key contribution: preserve router weight scaling in float32.
    • Computation order: MUL_ROUTED_WEIGHT -> to(compute_type) -> HAS_BIAS (bias addition). This order is correct when there is no quantization and no bias, but is wrong when quantization is involved.

This PR fixes the computation flow as:

1. dequantization for supported quantization schemes, keeping the accumulator in float32
2. bias addition (the bias is usually also in float32)
3. MoE weight multiplication in float32
4. a single cast to the desired compute/output dtype
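The four steps above can be sketched as follows (a toy model of the post-matmul flow, not the Triton kernel itself; the function name and values are illustrative, and numpy's float16 stands in for compute_type):

```python
import numpy as np

def moe_epilogue(acc_i32, w_scale, a_scale, bias, routed_w):
    """Post-matmul flow: dequantize (fp32) -> bias -> routed weight -> one cast."""
    # 1. dequantization: apply the weight/activation scales, accumulator stays fp32
    acc = acc_i32.astype(np.float32) * (w_scale * a_scale)
    # 2. bias addition, still in float32
    acc += bias.astype(np.float32)
    # 3. MoE (router) weight multiplication in float32
    acc *= routed_w
    # 4. cast once at the end to the desired compute/output dtype
    return acc.astype(np.float16)

out = moe_epilogue(
    acc_i32=np.array([1000, -500], dtype=np.int32),
    w_scale=np.float32(0.002),
    a_scale=np.float32(1.0),
    bias=np.array([0.5, 0.5], dtype=np.float32),
    routed_w=np.float32(0.25),
)
# out -> [0.625, -0.125], dtype float16
```

Keeping everything in float32 until the final `astype` means only one rounding step is paid, instead of one per intermediate cast.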

Test Plan

The tests mostly cover the two previous PRs:

  • PR #31676: amd/gpt-oss-20b-WFP8-AFP8-KVFP8, since the gpt-oss model has a bias in addition to the weight (y = Wx + b)
    • Note: This is a complete test case, since the model amd/gpt-oss-20b-WFP8-AFP8-KVFP8 exercises all the essential ingredients, including use_fp8_w8a8 quantization and bias.
  • PR #31931: pytest -s -v tests/lora/test_olmoe_tp.py::test_olmoe_lora_mixed
    • Note: This is probably not a complete test case in terms of the computation flow, since the model allenai/OLMoE-1B-7B-0125-Instruct uses neither quantization nor bias.

Test Result

| PR | amd/gpt-oss-20b-WFP8-AFP8-KVFP8 (TP=2, non-eager), MI355 | MI325 | tests/lora/test_olmoe_tp.py::test_olmoe_lora_mixed, MI355 | MI325 |
|---|---|---|---|---|
| PR #31676 | gsm8k_platinum: 0.90 | gsm8k_platinum: 0.90 | 4/4 passed | 3/4 passed |
| PR #31931 | gsm8k_platinum: 0.00 | gsm8k_platinum: 0.00 | 3/4 passed | 4/4 passed |
| This PR | gsm8k_platinum: 0.90 | gsm8k_platinum: 0.90 | 3/4 passed | 4/4 passed |

Note

Aligns fused MoE compute flow for numerical correctness across quantization modes.

  • In fused_moe_kernel, perform fp32 dequantization (int8_w8a16, fp8_w8a8, int8_w8a8) → add bias → multiply router (MoE) weights in fp32 → cast once to compute_type
  • Ensures bias is added after dequantization and router weight scaling happens in fp32; minor cleanups (use +=, improved comments)

Written by Cursor Bugbot for commit 5e17691.
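As a toy illustration of why the multiplication order matters (plain Python, made-up values; not the kernel code): the router weight must scale the whole expert output, bias included, so applying it before the bias is added changes the result.

```python
# Illustrative values: fp32 accumulator after dequantization, a bias, a router weight
x, bias, routed_w = 2.0, 1.0, 0.25

# This PR: add bias first, then apply the router (MoE) weight to the whole output
right = (x + bias) * routed_w   # (2.0 + 1.0) * 0.25 = 0.75

# Router weight applied before the bias is added gives a different result
wrong = x * routed_w + bias     # 2.0 * 0.25 + 1.0 = 1.5

assert right != wrong
```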



…type at last, make flow compact)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
@gemini-code-assist bot left a comment

Code Review

This pull request refactors the computation flow in the fused_moe_kernel to improve numerical stability and code clarity. The changes correctly reorder the operations to perform dequantization scaling, bias addition, and router weight multiplication sequentially, all in float32, before a single final cast to the target compute_type. This is a significant improvement over the previous implementation, which had multiple casts and a less optimal operation order; the new code is more robust, readable, and maintainable.

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
@xuebwang-amd (Author)

Please feel free to review. @AndreasKaratzas @tjtanaa @mgoin

@AndreasKaratzas (Collaborator) commented Jan 8, 2026

@xuebwang-amd I ran your changes on MI325 and the test passed every time. Where and how did it fail for you? Were you able to debug it?

For example, we could adjust the test's atol if the difference is small.

@xuebwang-amd (Author)

> @xuebwang-amd I ran your changes on MI325 and the test passed every time. Where and how did it fail for you? Were you able to debug it?

@AndreasKaratzas Yes, please see the MI325 test results in the PR description. Both test cases show good accuracy on MI325.

> For example, we could adjust the test's atol if the difference is small.

I agree.
The current assert is a bit strict,

assert generated_text.startswith(expected_output)

so minor numerical fluctuations can cause the assert to fail there, as seen in the debugging output:

> /home/xuebwang/github/xuebwang/vllm/tests/lora/test_olmoe_tp.py(92)generate_and_test()
-> assert generated_text.startswith(expected_output)
(Pdb) generated_text
'SELECT c.Poll_Source FROM candidate c\nJOIN people p ON c.People_ID = p.People_ID\nORDER BY COUNT(c.Candidate_ID) DESC\nLIMIT 1;\n\n##Explanation:\nTo find the poll with the most candidates, we'
(Pdb) expected_output
'SELECT Candidate_ID, Poll_Source FROM candidate WHERE People_ID IN (SELECT People_ID FROM people) ORDER BY COUNT(*) DESC LIMIT 1'

I recommend using a PPL check in place of the exact string match there.
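A minimal sketch of such a looser check (hypothetical helper, not existing test code; `loose_match` and its threshold are illustrative): token overlap instead of an exact prefix, as a cheap stand-in until a proper PPL check lands.

```python
def loose_match(generated: str, expected: str, threshold: float = 0.6) -> bool:
    """Pass if enough of the expected tokens appear in the generated text."""
    exp_tokens = set(expected.split())
    if not exp_tokens:
        return True
    overlap = len(set(generated.split()) & exp_tokens) / len(exp_tokens)
    return overlap >= threshold

# Small numerical drift that rephrases the tail still passes,
# while a completely different answer fails.
assert loose_match("SELECT a FROM t ORDER BY b LIMIT 1;", "SELECT a FROM t LIMIT 1")
assert not loose_match("I cannot answer that.", "SELECT a FROM t LIMIT 1")
```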

@AndreasKaratzas (Collaborator) commented Jan 9, 2026

> I recommend using a PPL check in place of the exact string match there.

@xuebwang-amd If you have time, maybe integrate such a more elastic checking mechanism into this PR (I understand it's a small change, but if you think it needs more careful study, we can leave it for another PR).

@LucasWilkinson (Collaborator)

cc @mgoin

mergify bot commented Jan 10, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuebwang-amd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 10, 2026
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
@mergify mergify bot removed the needs-rebase label Jan 12, 2026
@xuebwang-amd (Author)

@mgoin @tjtanaa


# Final precision conversion:
# Cast once at the end to the desired compute/output dtype.
accumulator = accumulator.to(compute_type)

Duplicate type conversion left after refactoring

Low Severity

The code contains two consecutive identical calls to accumulator.to(compute_type) at lines 566 and 568. The comment explicitly says "Cast once at the end" but the code performs the cast twice. This appears to be a refactoring artifact where the new line at 566 was added, but the original line at 568 was not removed. While functionally harmless (the conversion is idempotent), this is redundant code that contradicts the comment and may confuse future maintainers.
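The "functionally harmless" claim is easy to check: casting an already-converted array to the same dtype is a no-op, so the redundant line can simply be deleted (numpy float16 shown here as an illustrative stand-in for the tensor types involved):

```python
import numpy as np

acc = np.array([0.1, 1.5, -3.2], dtype=np.float32)
once = acc.astype(np.float16)
twice = acc.astype(np.float16).astype(np.float16)

# the second conversion changes nothing
assert (once == twice).all() and twice.dtype == np.float16
```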


@xuebwang-amd (Author)

The duplicated line has been removed.

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
@xuebwang-amd (Author)

> @xuebwang-amd If you have time, maybe integrate such a more elastic checking mechanism into this PR (I understand it's a small change, but if you think it needs more careful study, we can leave it for another PR).

I'd go with the second option and leave it for another PR.

@mgoin mgoin added bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed labels Jan 12, 2026
@mgoin mgoin merged commit 629584b into vllm-project:main Jan 12, 2026
56 of 57 checks passed
TomerBN-Nvidia pushed a commit to TomerBN-Nvidia/vllm that referenced this pull request Jan 13, 2026
…improve flow (vllm-project#31962)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
sammysun0711 pushed a commit to sammysun0711/vllm that referenced this pull request Jan 16, 2026
…improve flow (vllm-project#31962)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
…improve flow (vllm-project#31962)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…improve flow (vllm-project#31962)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…improve flow (vllm-project#31962)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
