
[Bugfix] Fix Triton FusedMoE LoRA#30585

Merged
jeejeelee merged 2 commits into vllm-project:main from xyang16:nan
Jan 9, 2026

Conversation

@xyang16
Contributor

@xyang16 xyang16 commented Dec 13, 2025

Purpose

This PR fixes the Triton fused_moe_lora path.

  • Reorder the rows of intermediate_cache1 so that the LoRA weights are added to the correct rows here.
  • This fixes test_gptoss_tp.py for the Triton backend.
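The row-alignment issue above can be illustrated with a small NumPy sketch (hypothetical names and shapes; the real code operates on intermediate_cache1 inside the Triton kernels): adding a token-ordered LoRA delta to an expert-sorted buffer is only correct if the delta rows are reordered to match.

```python
import numpy as np

# Hypothetical shapes: 4 routed tokens, hidden size 3.
rng = np.random.default_rng(0)
cache1 = rng.standard_normal((4, 3))      # base expert output, in permuted (expert-sorted) order
lora_delta = rng.standard_normal((4, 3))  # LoRA contribution, in original token order
dst_indx = np.array([2, 0, 3, 1])         # dst_indx[i]: row of cache1 holding token i

# Wrong: adding token-ordered deltas to expert-sorted rows mixes tokens up.
wrong = cache1 + lora_delta

# Right: route each token's delta to the row where that token actually lives.
right = cache1.copy()
right[dst_indx] += lora_delta

# Token 0's row now holds token 0's base output plus token 0's delta.
assert np.allclose(right[dst_indx[0]], cache1[dst_indx[0]] + lora_delta[0])
```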

Test Plan

pytest -s -v tests/lora/test_gptoss_tp.py

Tests passed.

Accuracy Testing

Marlin:

VLLM_MXFP4_USE_MARLIN=1 vllm serve openai/gpt-oss-20b \
  --tensor-parallel-size 1 \
  --max-num-seqs 16 \
  --enable-lora \
  --max-loras 1 \
  --lora-modules lora1=/opt/dlami/nvme/models/gpt-oss-20b-lora-gpqa/checkpoint-13 \
  --max-lora-rank 32 \
  --no-enable-prefix-caching
OPENAI_API_KEY=EMPTY python3 -m gpt_oss.evals --model lora1 --eval gpqa --n-threads 200 --reasoning-effort low
Writing report to /tmp/gpqa_lora1-low_temp1.0_20251214_195839.html
{'chars': np.float64(21.839015151515152), 'chars:std': np.float64(141.08266424169597), 'score': np.float64(0.586489898989899), 'score:std': np.float64(0.4924626862745207)}
Writing results to /tmp/gpqa_lora1-low_temp1.0_20251214_195839.json
Writing all results to /tmp/gpqa_lora1-low_temp1.0_20251214_195839_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'lora1-low_temp1.0_20251214_195839', 'metric': 0.586489898989899}]

Triton:

vllm serve openai/gpt-oss-20b \
  --tensor-parallel-size 1 \
  --max-num-seqs 16 \
  --enable-lora \
  --max-loras 1 \
  --lora-modules lora1=/opt/dlami/nvme/models/gpt-oss-20b-lora-gpqa/checkpoint-13 \
  --max-lora-rank 32 \
  --no-enable-prefix-caching
OPENAI_API_KEY=EMPTY python3 -m gpt_oss.evals --model lora1 --eval gpqa --n-threads 200 --reasoning-effort low
Writing report to /tmp/gpqa_lora1-low_temp1.0_20251214_201352.html
{'chars': np.float64(25.067550505050505), 'chars:std': np.float64(149.54616903920677), 'score': np.float64(0.5883838383838383), 'score:std': np.float64(0.4921263019922218)}
Writing results to /tmp/gpqa_lora1-low_temp1.0_20251214_201352.json
Writing all results to /tmp/gpqa_lora1-low_temp1.0_20251214_201352_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'lora1-low_temp1.0_20251214_201352', 'metric': 0.5883838383838383}]

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

cc @robertgshaw2-redhat @jeejeelee


Note

Addresses incorrect row alignment in Triton Unfused MoE LoRA path.

  • In gpt_oss_triton_kernels_moe.py (UnfusedOAITritonExperts): apply activation to intermediate_cache1.view(-1, N)[gather_indx.dst_indx] and feed intermediate_cache2[gather_indx.src_indx] into second matmul_ogs; update comments accordingly
  • Tests: tests/lora/test_gptoss_tp.py now parametrizes VLLM_MXFP4_USE_MARLIN (True/False) for both single-GPU and TP2 runs, covering fully-sharded variants
  • Config: add indx to typos allowlist in pyproject.toml

Written by Cursor Bugbot for commit b20346f.
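The gather/scatter index pair described in the note above (gather_indx.src_indx / gather_indx.dst_indx) can be sketched in NumPy (hypothetical values; the exact convention lives in gpt_oss_triton_kernels_moe.py): scattering rows with one permutation and gathering with its inverse is the identity, which is what keeps per-token rows aligned between the two matmul_ogs calls.

```python
import numpy as np

# Hypothetical index pair: dst_indx maps token i to its expert-sorted row;
# src_indx is the inverse permutation, mapping rows back to token order.
dst_indx = np.array([2, 0, 3, 1])
src_indx = np.argsort(dst_indx)               # inverse of dst_indx

x = np.arange(8, dtype=float).reshape(4, 2)   # 4 tokens, hidden size 2

scattered = x[src_indx]          # rows in expert-sorted order
activated = np.maximum(scattered, 0.0)  # stand-in for the real activation
restored = activated[dst_indx]   # rows back in token order

# Round-tripping through the inverse permutation restores row order,
# so each token meets its own activation in the second matmul.
assert np.array_equal(restored, np.maximum(x, 0.0))
```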


@mergify mergify bot added the v1 label Dec 13, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a critical bug where NaN values could appear in the attention output. The root cause is that the output tensor, allocated with torch.empty(), was not fully initialized, and subsequent attention operations only filled a portion of it up to num_actual_tokens. The added line output[num_actual_tokens:].fill_(0) effectively zeros out the remaining uninitialized part of the tensor, preventing any garbage values or NaNs from propagating. This is a robust and necessary fix. This same pattern of not zeroing out the padded portion of the output tensor may exist in other attention backends, and it would be beneficial to audit them for similar issues to ensure consistent behavior across the system.
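The pattern the review describes (a buffer allocated uninitialized, with only the first num_actual_tokens rows written by the kernel) can be sketched in NumPy (hypothetical shapes; the real code uses torch.empty() and output[num_actual_tokens:].fill_(0)):

```python
import numpy as np

# Hypothetical sizes: buffer padded to max_num_tokens rows, kernel fills fewer.
max_num_tokens, hidden, num_actual_tokens = 8, 4, 5

output = np.empty((max_num_tokens, hidden))  # uninitialized, like torch.empty()
output[:num_actual_tokens] = 1.0             # attention writes only the real tokens

# Equivalent of output[num_actual_tokens:].fill_(0): zero the padded tail so
# stale memory (possibly NaN/Inf) cannot leak into downstream reductions.
output[num_actual_tokens:] = 0.0

assert not np.isnan(output).any()
```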

@dcmaddix
Contributor

Great find thanks a lot @xyang16! cc: @jeejeelee, @robertgshaw2-redhat, @varun-sundar-rabindranath

@xyang16 xyang16 changed the title [Bugfix] Fix NaN issue in attention output [Bugfix] Fix NaN issue for Triton FusedMoE LoRA Dec 13, 2025
@mergify mergify bot added the gpt-oss Related to GPT-OSS models label Dec 14, 2025
@mergify

mergify bot commented Dec 14, 2025

Hi @xyang16, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint


@robertgshaw2-redhat
Collaborator

@jeejeelee - lmk if this looks okay to you

@bbrowning
Contributor

I was able to reproduce this error on my A5500 hardware by running pytest -sv tests/lora/test_gptoss_tp.py and 2 of the tests failed because they generated 'SELECT!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'

However, instead of applying the fix here, I applied the fix from #30650 and after that all of these tests passed. So, these two PRs and code paths are at least related if not directly intertwined.

@bbrowning
Contributor

I believe there are two separate things in play here. The change to vllm/v1/attention/backends/flash_attn.py here looks directly related to #30650, and we probably either need to zero these in all backends or use the fix from 30650 to ensure we're exercising the custom_all_reduce path during compile/warmup.

With that said, I tried testing just the changes in this PR on an H100 that's using FLASH_ATTN and Triton MXFP4 kernels and am still seeing the infinite generation:

E           AssertionError: assert False                                                                                                                                                                            
E            +  where False = <built-in method startswith of str object at 0x7f70c96bf630>('SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;')
E            +    where <built-in method startswith of str object at 0x7f70c96bf630> = 'SELECT AVG(!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'.startswith

Logs from the failed test showing FLASH_ATTN and Triton MXFP4 backend in use:

(EngineCore_DP0 pid=220174) (Worker_TP0 pid=220180) INFO 12-15 23:59:56 [gpu_model_runner.py:3562] Starting to load model openai/gpt-oss-20b...
(EngineCore_DP0 pid=220174) (Worker_TP0 pid=220180) INFO 12-15 23:59:56 [cuda.py:351] Using FLASH_ATTN attention backend out of potential backends: ('FLASH_ATTN', 'TRITON_ATTN')
(EngineCore_DP0 pid=220174) (Worker_TP0 pid=220180) INFO 12-15 23:59:56 [layer.py:372] Enabled separate cuda stream for MoE shared_experts  
(EngineCore_DP0 pid=220174) (Worker_TP0 pid=220180) INFO 12-15 23:59:56 [mxfp4.py:102] [get_mxfp4_backend_with_lora] Using Triton backend
(EngineCore_DP0 pid=220174) (Worker_TP1 pid=220182) INFO 12-15 23:59:56 [mxfp4.py:102] [get_mxfp4_backend_with_lora] Using Triton backend

@xyang16
Contributor Author

xyang16 commented Dec 16, 2025

@bbrowning Thanks for helping investigate this!

I believe there are two separate things in play here. The change to vllm/v1/attention/backends/flash_attn.py here looks directly related to #30650, and we probably either need to zero these in all backends or use the fix from 30650 to ensure we're exercising the custom_all_reduce path during compile/warmup.

Is there any reason, other than skipping custom_all_reduce, that would cause output[num_actual_tokens:] to contain NaNs?

With that said, I tried testing just the changes in this PR on an H100 that's using FLASH_ATTN and Triton MXFP4 kernels and am still seeing the infinite generation:

Yes. This PR doesn't address the NaN caused by FULL_AND_PIECEWISE cudagraph mode (see #29539 (comment)), so cudagraph_mode needs to be set to PIECEWISE in addition to this PR to make it work.

llm = vllm.LLM(
    MODEL_PATH,
    max_model_len=1024,
    enable_lora=True,
    max_loras=4,
    max_lora_rank=8,
    compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
        cudagraph_mode=vllm.config.compilation.CUDAGraphMode.PIECEWISE,
        cudagraph_specialize_lora=False,
    ),
)

@xyang16 xyang16 force-pushed the nan branch 3 times, most recently from 6287d37 to f31c2fd Compare December 17, 2025 19:36
@dcmaddix
Contributor

dcmaddix commented Jan 6, 2026

cc: @robertgshaw2-redhat I think LoRA stream still needs this fix in addition to #30887 thanks!

@xyang16 xyang16 changed the title [Bugfix] Fix NaN issue for Triton FusedMoE LoRA [Bugfix] Fix Triton FusedMoE LoRA Jan 7, 2026
@@ -162,6 +162,7 @@ depthwise_seperable_CNN = "depthwise_seperable_CNN"
[tool.typos.default.extend-words]
iy = "iy"
tendencias = "tendencias"
indx = "indx"
Contributor


@xyang16 is this required ?

Contributor Author

@xyang16 xyang16 Jan 7, 2026


Yes, otherwise the pre-commit check fails, reporting that "indx" is not a valid spelling. Thanks!

@xyang16 xyang16 requested a review from jeejeelee as a code owner January 7, 2026 20:29
@xyang16 xyang16 force-pushed the nan branch 2 times, most recently from 791774b to 9c1b3b2 Compare January 7, 2026 20:34
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath left a comment


LGTM! Thanks for adding the smoke test @xyang16 .

@xyang16 xyang16 force-pushed the nan branch 2 times, most recently from 4652ab4 to 654bb0a Compare January 7, 2026 20:46
@varun-sundar-rabindranath
Contributor

cc @robertgshaw2-redhat @jeejeelee

@mergify

mergify bot commented Jan 9, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xyang16.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 9, 2026
Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: Xin Yang <xyangx@amazon.com>
Collaborator

@jeejeelee jeejeelee left a comment


Sorry for missing this PR.

@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Jan 9, 2026
@jeejeelee jeejeelee enabled auto-merge (squash) January 9, 2026 09:50
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 9, 2026
@jeejeelee jeejeelee merged commit e7b68f4 into vllm-project:main Jan 9, 2026
61 of 62 checks passed
@xyang16 xyang16 deleted the nan branch January 9, 2026 18:41
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
Signed-off-by: Xin Yang <xyangx@amazon.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Signed-off-by: Xin Yang <xyangx@amazon.com>

Labels

gpt-oss Related to GPT-OSS models ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done
