
[Bugfix] Fix GPT-OSS AR+NORM fusion#28841

Merged
ProExpertProg merged 8 commits into vllm-project:main from
elvischenv:elvischenv/fix-gpt-oss-ar-norm-fusion
Nov 25, 2025

Conversation

@elvischenv
Contributor

@elvischenv elvischenv commented Nov 17, 2025

Purpose

The slice op between all_reduce and rms_norm breaks the AR+NORM fusion for GPT-OSS on B200.

# moe
auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.moe_forward.default, hidden_states = constant_pad_nd, router_logits = addmm_1, layer_name = 'model.layers.0.mlp.experts');  constant_pad_nd = addmm_1 = None
getitem_3: "bf16[s72, 3072]" = auto_functionalized_1[0];  auto_functionalized_1 = None

# all_reduce
all_reduce_1: "bf16[s72, 3072]" = torch.ops.vllm.all_reduce.default(getitem_3, 'tp:0');  getitem_3 = None

# slice
slice_1: "bf16[s72, 2880]" = torch.ops.aten.slice.Tensor(all_reduce_1, 1, 0, 2880)

# rms_norm
auto_functionalized_2 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = slice_1, residual = getitem_18, weight = arg8_1, epsilon = 1e-05);  slice_1 = getitem_18 = arg8_1 = None

This PR moves the slice op before the all_reduce to enable the fusion:

# moe
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.moe_forward.default, hidden_states = constant_pad_nd, router_logits = addmm_1, layer_name = 'model.layers.1.mlp.experts');  constant_pad_nd = addmm_1 = None
getitem: "bf16[s72, 3072]" = auto_functionalized[0];  auto_functionalized = None

# slice
slice_1: "bf16[s72, 2880]" = torch.ops.aten.slice.Tensor(getitem, 1, 0, 2880);  getitem = None

# all_reduce + rms_norm
auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default, allreduce_in = slice_1, residual = getitem_14, norm_out = None, quant_out = None, scale_out = None, rms_gamma = arg8_1, rms_eps = 1e-05, pattern_code = 1, world_rank = 0, world_size = 4, launch_with_pdl = True, trigger_completion_at_end = True, fp32_acc = True, max_token_num = 5825);  slice_1 = getitem_14 = arg8_1 = None
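Since all_reduce is an elementwise sum across ranks, slicing off the padding columns commutes with the reduction: both orderings produce the same values, while slicing first also reduces the data moved over the network and exposes the fusable AR+NORM pattern. A minimal numpy sketch of that equivalence (the per-rank tensors, the sizes, and the `all_reduce` stand-in are illustrative, not vLLM code):

```python
import numpy as np

rng = np.random.default_rng(0)
world_size, tokens, padded, hidden = 4, 8, 3072, 2880

# per-rank partial MoE outputs, padded from 2880 to 3072 columns
parts = [rng.standard_normal((tokens, padded)) for _ in range(world_size)]

def all_reduce(tensors):
    # stand-in for tensor_model_parallel_all_reduce: elementwise sum across ranks
    return np.sum(tensors, axis=0)

# before this PR: all_reduce the padded tensor, then slice off the padding
before = all_reduce(parts)[:, :hidden]

# with this PR: slice each rank's output first, then all_reduce the smaller tensor
after = all_reduce([p[:, :hidden] for p in parts])

assert np.allclose(before, after)
```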

Test Plan & Result

Accuracy:

PR:

[{'eval_name': 'aime25', 'model_name': 'gpt-oss-20b-medium_temp1.0_20251117_060221', 'metric': 0.7166666666666667}]

main:

[{'eval_name': 'aime25', 'model_name': 'gpt-oss-20b-medium_temp1.0_20251117_071029', 'metric': 0.7375}]

Perf (TP8, concurrency 8):

PR: 3.8% improvement

============ Serving Benchmark Result ============
Successful requests:                     40
Benchmark duration (s):                  10.84
Total input tokens:                      40960
Total generated tokens:                  40960
Request throughput (req/s):              3.69
Output token throughput (tok/s):         3779.37
Total Token throughput (tok/s):          7558.74
---------------Time to First Token----------------
Mean TTFT (ms):                          54.30
Median TTFT (ms):                        56.31
P99 TTFT (ms):                           71.81
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          2.06
Median TPOT (ms):                        2.07
P99 TPOT (ms):                           2.11
---------------Inter-token Latency----------------
Mean ITL (ms):                           2.07
Median ITL (ms):                         2.06
P99 ITL (ms):                            2.38
----------------End-to-end Latency----------------
Mean E2EL (ms):                          2165.61
Median E2EL (ms):                        2168.17
P99 E2EL (ms):                           2211.53
==================================================

main:

============ Serving Benchmark Result ============
Successful requests:                     40
Benchmark duration (s):                  11.24
Total input tokens:                      40960
Total generated tokens:                  40960
Request throughput (req/s):              3.56
Output token throughput (tok/s):         3644.69
Total Token throughput (tok/s):          7289.38
---------------Time to First Token----------------
Mean TTFT (ms):                          54.18
Median TTFT (ms):                        63.99
P99 TTFT (ms):                           72.94
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          2.14
Median TPOT (ms):                        2.15
P99 TPOT (ms):                           2.17
---------------Inter-token Latency----------------
Mean ITL (ms):                           2.14
Median ITL (ms):                         2.14
P99 ITL (ms):                            2.48
----------------End-to-end Latency----------------
Mean E2EL (ms):                          2245.97
Median E2EL (ms):                        2240.58
P99 E2EL (ms):                           2277.84
==================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the Mixture-of-Experts layer to enable an important fusion optimization (all_reduce + norm). The change involves moving a tensor slicing operation to occur before the all_reduce call, which allows the compiler to fuse all_reduce with the subsequent normalization. The implementation correctly adjusts all call sites to account for this reordering. The logic is sound, and the provided performance metrics show a clear benefit without any significant accuracy degradation. The changes are well-contained and address the issue described.

@ZJY0516
Member

ZJY0516 commented Nov 17, 2025

cc @ProExpertProg

Collaborator

@ProExpertProg ProExpertProg left a comment


Nice, this should also reduce the size of the reduction in general! cc @ilmarkov @varun-sundar-rabindranath @bnellnm can you take a look please

@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Nov 17, 2025
@ProExpertProg ProExpertProg added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 17, 2025
@ProExpertProg
Collaborator

@elvischenv @mgoin should we add GPT-OSS to nightly E2E fusion tests?

@mgoin
Member

mgoin commented Nov 17, 2025

Yes definitely, the 20b is good for CI

@elvischenv elvischenv force-pushed the elvischenv/fix-gpt-oss-ar-norm-fusion branch from f010fd4 to 3dd153b Compare November 18, 2025 06:45
Comment on lines +118 to +128
ModelBackendTestCase(
model_name="openai/gpt-oss-20b",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=AttentionBackendEnum.FLASHINFER,
matches=Matches(
attention_fusion=0,
allreduce_fusion=49,
sequence_parallel=49,
async_tp=48,
),
),
Contributor Author


@ProExpertProg @mgoin Added the 20b e2e fusion test.
Also tested on main and got the expected failure:

>   assert int(log_matches[0]) == matches.allreduce_fusion
      ^^^^^^^^^^^^^^^^^
E   AssertionError: assert 25 == 49
E    +  where 25 = int('25')
E    +  and   49 = Matches(attention_fusion=0, allreduce_fusion=49, sequence_parallel=49, async_tp=48).allreduce_fusion
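The e2e fusion test compares a count scraped from the compilation logs against the expected `Matches` value, which is why running on main (25 fusions) fails against the new expectation (49). A hypothetical sketch of that log-scraping step (the log line format and variable names here are illustrative, not the actual vLLM test code):

```python
import re

# illustrative log output in the spirit of a fusion-pass summary line
logs = "[compilation] Replaced 49 patterns in AllReduceFusionPass"

# pull the fusion count out of the logs, as the assertion above suggests
log_matches = re.findall(r"Replaced (\d+) patterns in AllReduceFusionPass", logs)

expected_allreduce_fusion = 49  # Matches(...).allreduce_fusion
assert int(log_matches[0]) == expected_allreduce_fusion
```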

Comment on lines +1491 to +1492
if self.hidden_size != og_hidden_states:
    states = states[..., :og_hidden_states].contiguous()
Collaborator


Is og_hidden_states going to be the same across all ranks? Also, is the call to contiguous necessary?

Contributor Author


Please see the few lines above. og_hidden_states is only used for padding before the MoE kernel. The trtllm AR kernel raises an error if the memory is not contiguous.

Collaborator


Can we move the contiguous call so that it applies only to the trtllm kernel?

Contributor Author


Sure. Does that mean other all-reduce kernels support non-contiguous memory?

Collaborator


Not sure but there didn't seem to be a call to contiguous before?

Contributor Author


Sorry, I just thought the issue may not be a trtllm-specific issue, but a general issue with symm_mem:

self.buffer[: inp.numel()].copy_(inp.view(-1))

  File "/workspace/vllm/vllm/compilation/collective_fusion.py", line 560, in call_trtllm_fused_allreduce_norm
    allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/distributed/communication_op.py", line 14, in tensor_model_parallel_all_reduce
    return get_tp_group().all_reduce(input_)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/distributed/parallel_state.py", line 499, in all_reduce
    return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.12/site-packages/torch/_ops.py", line 1255, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/distributed/parallel_state.py", line 122, in all_reduce
    return group._all_reduce_out_place(tensor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/distributed/parallel_state.py", line 506, in _all_reduce_out_place
    return self.device_communicator.all_reduce(input_)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/distributed/device_communicators/cuda_communicator.py", line 157, in all_reduce
    out = symm_mem_comm.all_reduce(input_)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm/vllm/distributed/device_communicators/symm_mem.py", line 134, in all_reduce
    self.buffer[: inp.numel()].copy_(inp.view(-1))
                                     ^^^^^^^^^^^^
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Just pushed a fix with .reshape(). Or maybe there are better solutions?
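The root cause is that slicing the hidden dimension yields a non-contiguous view, and torch's `view(-1)` cannot flatten such a layout without copying, whereas `reshape(-1)` silently falls back to a copy. A small numpy analogue of the same layout issue (illustrative sketch, not vLLM code; torch's view/reshape distinction behaves analogously on this layout):

```python
import numpy as np

x = np.zeros((4, 3072), dtype=np.float32)

# slicing columns yields a non-contiguous view, like all_reduce_out[:, :2880]
sliced = x[:, :2880]
assert not sliced.flags["C_CONTIGUOUS"]

# torch's view(-1) would raise here because the strides are incompatible with a
# flat layout; reshape(-1) falls back to a copy, which is the fix that was pushed
flat = sliced.reshape(-1)
assert flat.shape == (4 * 2880,)
assert flat.flags["C_CONTIGUOUS"]
```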

@elvischenv elvischenv force-pushed the elvischenv/fix-gpt-oss-ar-norm-fusion branch from 3dd153b to 66c597f Compare November 18, 2025 16:50
Collaborator

@ProExpertProg ProExpertProg left a comment


It seems like the "Blackwell Compile and Fusion tests" did not get triggered in CI. Could you add fused_moe/layer.py to the list of dependencies?

@github-project-automation github-project-automation bot moved this from Ready to In progress in gpt-oss Issues & Enhancements Nov 18, 2025
@elvischenv elvischenv force-pushed the elvischenv/fix-gpt-oss-ar-norm-fusion branch from 66c597f to 9d62f0b Compare November 19, 2025 02:11
@mergify mergify bot added the ci/build label Nov 19, 2025
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
@elvischenv elvischenv force-pushed the elvischenv/fix-gpt-oss-ar-norm-fusion branch from 9d62f0b to 54ab93f Compare November 19, 2025 17:21
Contributor Author

@elvischenv elvischenv left a comment


Resolved merge conflict.

@nvpohanh
Contributor

@ProExpertProg Could you review again and see if your comment was addressed?

@github-project-automation github-project-automation bot moved this from In progress to Ready in gpt-oss Issues & Enhancements Nov 21, 2025
@nvpohanh
Contributor

@mgoin could you help us trigger the failing pipelines? I can't see the pipeline logs. Thanks!

@ProExpertProg ProExpertProg merged commit 6330f94 into vllm-project:main Nov 25, 2025
56 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Nov 25, 2025
DarkLight1337 pushed a commit that referenced this pull request Nov 26, 2025
Signed-off-by: Huamin Li <3ericli@gmail.com>
@DarkLight1337
Member

DarkLight1337 commented Nov 26, 2025

This PR has been reverted by #29483 as it broke LoRA TP tests. Please open a new version of the PR that passes the test.

devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…lm-project#29483)

Signed-off-by: Huamin Li <3ericli@gmail.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>

Labels

ci/build gpt-oss Related to GPT-OSS models nvidia performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

7 participants