[Kernel][MoE] fix computation order of MoE weight multiplication and improve flow (#31962)
Conversation
…type at last, make flow compact) Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Code Review
This pull request refactors the computation flow in the fused_moe_kernel to improve numerical stability and code clarity. The changes correctly reorder the operations to perform dequantization scaling, bias addition, and router weight multiplication sequentially, all within float32 precision, before a single final cast to the target compute_type. This is a significant improvement over the previous implementation, which had multiple casts and a less optimal operation order. The new code is more robust, readable, and maintainable. The changes look good and are a solid improvement.
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Please feel free to review. @AndreasKaratzas @tjtanaa @mgoin
@xuebwang-amd I ran your changes on MI325. The test passed every time. Where and how did it fail for you? Were you able to debug it? For example, we could adjust the atol of the test if it was a small diff.
@AndreasKaratzas Yes, please find the MI325 test results in the PR description. Both test cases show good accuracy on MI325.
I agree. vllm/tests/lora/test_olmoe_tp.py Line 91 in bde38c1 performs an exact string match, so minor numerical fluctuations could probably cause the assert to be False there, as seen from the debugging output. I recommend using a PPL check in place of the exact string match.
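Such an elastic check could be sketched roughly as follows (a hypothetical helper, not the actual test code; the log-prob values and threshold are made up for illustration, and a real test would take per-token log-probs from the engine output):

```python
import math

def perplexity(logprobs):
    # Perplexity = exp(-mean log-prob) over the generated tokens.
    return math.exp(-sum(logprobs) / len(logprobs))

# Hypothetical per-token log-probs of a generated sequence.
logprobs = [-0.21, -0.05, -0.80, -0.33, -0.12]

# Assumption: a per-model threshold tuned so small numerical
# fluctuations in the kernel do not flip the test result.
PPL_THRESHOLD = 3.0

assert perplexity(logprobs) < PPL_THRESHOLD
```

Unlike an exact string match, this tolerates tiny logit perturbations that merely reorder near-tied tokens.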
@xuebwang-amd If you have time, maybe integrate such a more elastic checking mechanism into this PR (I understand it's a small mod, but if you think it needs a more careful study, we can leave this for another PR).
cc @mgoin |
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
# Final precision conversion:
# Cast once at the end to the desired compute/output dtype.
accumulator = accumulator.to(compute_type)
Duplicate type conversion left after refactoring
Low Severity
The code contains two consecutive identical calls to accumulator.to(compute_type) at lines 566 and 568. The comment explicitly says "Cast once at the end" but the code performs the cast twice. This appears to be a refactoring artifact where the new line at 566 was added, but the original line at 568 was not removed. While functionally harmless (the conversion is idempotent), this is redundant code that contradicts the comment and may confuse future maintainers.
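The redundancy is functionally harmless because the cast is idempotent, which can be illustrated with a small NumPy stand-in (not the actual kernel code):

```python
import numpy as np

# Stand-in for the accumulator before the final cast.
acc = np.array([1.5, 2.5], dtype=np.float32)

# Casting once vs. casting twice to the same dtype yields identical
# results; the second cast is dead code left over from the refactor.
once = acc.astype(np.float16)
twice = acc.astype(np.float16).astype(np.float16)

assert np.array_equal(once, twice)
```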
The duplicated line has been removed.
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
I propose the second way.
…improve flow (vllm-project#31962) Signed-off-by: xuebwang-amd <xuebwang@amd.com> Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
…improve flow (vllm-project#31962) Signed-off-by: xuebwang-amd <xuebwang@amd.com>
…improve flow (vllm-project#31962) Signed-off-by: xuebwang-amd <xuebwang@amd.com> Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Purpose
Background:
This PR continues the work of two previous PRs:
This PR fixes the computation flow as follows:
1. dequantization for supported quantization schemes while keeping the accumulator in float32
2. bias addition (bias is usually also in float32)
3. MoE weight multiplication in float32
4. cast to desired compute/output dtype
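The four steps above can be sketched as a minimal NumPy stand-in (illustrative only, not the actual Triton `fused_moe_kernel`; the `scale`, `bias`, and `router_weight` values are made up):

```python
import numpy as np

def moe_epilogue(accumulator, scale, bias, router_weight,
                 compute_dtype=np.float16):
    """Post-matmul flow: everything stays in fp32 until one final cast."""
    acc = accumulator.astype(np.float32)
    acc = acc * scale                         # 1. dequantization scaling in fp32
    if bias is not None:
        acc = acc + bias.astype(np.float32)   # 2. bias addition in fp32
    acc = acc * router_weight                 # 3. router (MoE) weight multiply in fp32
    return acc.astype(compute_dtype)          # 4. single cast to compute_type

acc = np.array([[1.0, 2.0]], dtype=np.float32)
out = moe_epilogue(acc, scale=0.5,
                   bias=np.array([0.1, 0.1]), router_weight=2.0)
```

Keeping steps 1–3 in float32 and casting only once at the end avoids the precision loss of the previous flow, which cast to the (possibly fp16/bf16) compute type before the router weight multiplication.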
Test Plan
Mostly covers the two previous PRs:
- amd/gpt-oss-20b-WFP8-AFP8-KVFP8, since the gpt-oss model has bias in addition to weight (y = Wx + b); the amd/gpt-oss-20b-WFP8-AFP8-KVFP8 model used for the test has all essential ingredients, including quantization (use_fp8_w8a8) and bias.
- pytest -s -v tests/lora/test_olmoe_tp.py::test_olmoe_lora_mixed: the allenai/OLMoE-1B-7B-0125-Instruct model used for the test includes neither quantization nor bias.
Test Result
Note
Aligns fused MoE compute flow for numerical correctness across quantization modes.
- In fused_moe_kernel, perform fp32 dequantization (int8_w8a16, fp8_w8a8, int8_w8a8) → add bias → multiply router (MoE) weights in fp32 → cast once to compute_type
- Minor refactors (+=, improved comments)
Written by Cursor Bugbot for commit 5e17691. This will update automatically on new commits.
Note
Aligns fused MoE post-matmul flow for numerical correctness across quantization modes.
- In fused_moe_kernel, perform fp32 dequantization for int8_w8a16, fp8_w8a8, and int8_w8a8 → add bias → multiply router weights in fp32 → cast once to compute_type
- Minor refactors (+= and improved comments); no other functional changes
Written by Cursor Bugbot for commit 27ac7fb. This will update automatically on new commits.