[Bugfix] Fix Triton FusedMoE LoRA #30585
Conversation
Code Review
This pull request correctly addresses a critical bug where NaN values could appear in the attention output. The root cause is that the output tensor, allocated with torch.empty(), was not fully initialized, and subsequent attention operations only filled a portion of it up to num_actual_tokens. The added line output[num_actual_tokens:].fill_(0) effectively zeros out the remaining uninitialized part of the tensor, preventing any garbage values or NaNs from propagating. This is a robust and necessary fix. This same pattern of not zeroing out the padded portion of the output tensor may exist in other attention backends, and it would be beneficial to audit them for similar issues to ensure consistent behavior across the system.
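The pattern the review describes can be sketched as follows. This is a minimal, self-contained illustration — the function name and shapes are hypothetical stand-ins, not the actual vLLM backend code:

```python
import torch

def build_attention_output(num_padded_tokens: int,
                           num_actual_tokens: int,
                           hidden_size: int) -> torch.Tensor:
    # torch.empty() returns uninitialized memory, so any rows the
    # attention kernel does not write may contain garbage (even NaN).
    output = torch.empty(num_padded_tokens, hidden_size)

    # Stand-in for the attention kernel, which only writes the first
    # num_actual_tokens rows of the padded output buffer.
    output[:num_actual_tokens] = 1.0

    # The fix: zero the uninitialized padded tail so garbage values
    # cannot propagate into downstream operations.
    output[num_actual_tokens:].fill_(0)
    return output
```

Any backend that allocates with `torch.empty()` and writes only a prefix of the buffer has the same hazard, which is why the suggested audit of other attention backends makes sense.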
Great find, thanks a lot @xyang16! cc: @jeejeelee, @robertgshaw2-redhat, @varun-sundar-rabindranath
Hi @xyang16, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.
Force-pushed 851c552 to b1e7f12.
@jeejeelee - lmk if this looks okay to you
I was able to reproduce this error on my A5500 hardware. However, instead of applying the fix here, I applied the fix from #30650, and after that all of these tests passed. So these two PRs and code paths are at least related, if not directly intertwined.
I believe there are two separate things in play here. With that said, I tried testing just the changes in this PR on an H100 that's using FLASH_ATTN and Triton MXFP4 kernels, and I am still seeing the infinite generation. Logs from the failed test show FLASH_ATTN and the Triton MXFP4 backend in use:
@bbrowning Thanks for helping investigate this!
Is there any reason, other than skipping custom_all_reduce, that would make this happen?
Yes, I have a note that says: this PR doesn't address the NaN caused by FULL_AND_PIECEWISE cudagraph mode (see #29539 (comment)), so cudagraph_mode needs to be set to PIECEWISE in addition to this PR to make it work.
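As a sketch of that workaround: the exact flag shape varies across vLLM versions, and the model name here is an assumption, so treat this as illustrative rather than authoritative:

```shell
# Hypothetical invocation: force PIECEWISE cudagraphs instead of the
# default FULL_AND_PIECEWISE mode while running with this PR's fix.
# Model name and flag syntax are assumptions, not taken from the PR.
vllm serve openai/gpt-oss-20b \
  --compilation-config '{"cudagraph_mode": "PIECEWISE"}'
```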
Force-pushed 6287d37 to f31c2fd.
cc: @robertgshaw2-redhat I think the LoRA stream still needs this fix in addition to #30887, thanks!
@@ -162,6 +162,7 @@ depthwise_seperable_CNN = "depthwise_seperable_CNN"
 [tool.typos.default.extend-words]
 iy = "iy"
 tendencias = "tendencias"
+indx = "indx"
Yes, otherwise the pre-commit check will fail, saying "indx" is not a valid spelling. Thanks!
Force-pushed 791774b to 9c1b3b2.
varun-sundar-rabindranath left a comment

LGTM! Thanks for adding the smoke test @xyang16.
Force-pushed 4652ab4 to 654bb0a.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Xin Yang <xyangx@amazon.com>
jeejeelee left a comment

Sorry for missing this PR.
Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Purpose

This PR fixes Triton fused_moe_lora:
- Fix intermediate_cache1 to make sure LoRA weights are added to the correct rows.
- Add test_gptoss_tp.py for the Triton backend.

Test Plan

Tests passed.
Accuracy Testing
Marlin:
Triton:
cc @robertgshaw2-redhat @jeejeelee
Note
Addresses incorrect row alignment in the Triton Unfused MoE LoRA path.
- gpt_oss_triton_kernels_moe.py (UnfusedOAITritonExperts): apply the activation to intermediate_cache1.view(-1, N)[gather_indx.dst_indx] and feed intermediate_cache2[gather_indx.src_indx] into the second matmul_ogs; update comments accordingly.
- tests/lora/test_gptoss_tp.py now parametrizes VLLM_MXFP4_USE_MARLIN (True/False) for both single-GPU and TP2 runs, covering fully-sharded variants.
- Add indx to the typos allowlist in pyproject.toml.

Written by Cursor Bugbot for commit b20346f.
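The row-alignment issue above can be illustrated with a small sketch. All names here are hypothetical — this is not the vLLM kernel; it only shows why per-token updates must be addressed through the permutation indices after a gather/permute step:

```python
import torch

def apply_update_aligned(cache: torch.Tensor,
                         dst_indx: torch.Tensor,
                         delta: torch.Tensor) -> torch.Tensor:
    """Illustrative only: after MoE routing permutes token rows, any
    per-token adjustment (e.g. a LoRA delta) must be indexed through
    the permutation, or it lands on the wrong rows."""
    out = cache.clone()
    # Correct: address the permuted row positions via dst_indx.
    # Updating out[:len(delta)] instead would hit the wrong tokens.
    out[dst_indx] += delta
    return out
```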
gpt_oss_triton_kernels_moe.py(UnfusedOAITritonExperts): apply activation tointermediate_cache1.view(-1, N)[gather_indx.dst_indx]and feedintermediate_cache2[gather_indx.src_indx]into secondmatmul_ogs; update comments accordinglytests/lora/test_gptoss_tp.pynow parametrizesVLLM_MXFP4_USE_MARLIN(True/False) for both single-GPU and TP2 runs, covering fully-sharded variantsindxto typos allowlist inpyproject.tomlWritten by Cursor Bugbot for commit b20346f. This will update automatically on new commits. Configure here.