[Bugfix] Fix GPT-OSS AR+NORM fusion#28841
Conversation
Code Review
This pull request refactors the Mixture-of-Experts layer to enable an important fusion optimization (all_reduce + norm). The change involves moving a tensor slicing operation to occur before the all_reduce call, which allows the compiler to fuse all_reduce with the subsequent normalization. The implementation correctly adjusts all call sites to account for this reordering. The logic is sound, and the provided performance metrics show a clear benefit without any significant accuracy degradation. The changes are well-contained and address the issue described.
ProExpertProg
left a comment
Nice, this should also reduce the size of the reduction in general! cc @ilmarkov @varun-sundar-rabindranath @bnellnm, can you take a look please?
@elvischenv @mgoin should we add GPT-OSS to nightly E2E fusion tests?

Yes definitely, the 20b is good for CI.
Force-pushed from f010fd4 to 3dd153b
ModelBackendTestCase(
    model_name="openai/gpt-oss-20b",
    model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
    backend=AttentionBackendEnum.FLASHINFER,
    matches=Matches(
        attention_fusion=0,
        allreduce_fusion=49,
        sequence_parallel=49,
        async_tp=48,
    ),
),
@ProExpertProg @mgoin Added the 20b e2e fusion test.
Also tested on main and got expected failure:
> assert int(log_matches[0]) == matches.allreduce_fusion
^^^^^^^^^^^^^^^^^
E AssertionError: assert 25 == 49
E + where 25 = int('25')
E + and 49 = Matches(attention_fusion=0, allreduce_fusion=49, sequence_parallel=49, async_tp=48).allreduce_fusion
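The assertion above compares a fusion count scraped from the compile-pass logs against the expected `Matches` value. A minimal sketch of that check, with an assumed log format (the exact vLLM log message and the `log_matches` extraction may differ):

```python
import re

# Hypothetical log line emitted by the fusion pass; the real message text
# in vLLM may differ -- only the "extract the count and compare" idea is shown.
log_line = "Replaced 49 patterns"
log_matches = re.findall(r"(\d+) patterns", log_line)

expected_allreduce_fusion = 49  # from Matches(allreduce_fusion=49, ...)
assert int(log_matches[0]) == expected_allreduce_fusion
```

On main, the pass only finds 25 fusible patterns for this model, so the same assertion fails with `assert 25 == 49`, which is the expected failure shown above.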
if self.hidden_size != og_hidden_states:
    states = states[..., :og_hidden_states].contiguous()
Is og_hidden_states going to be the same across all ranks? Also, is the call to contiguous necessary?
Please see the above few lines. og_hidden_states is just used for padding before the MoE kernel. The trtllm AR kernel raises an error if the memory is not contiguous.
Can we move the contiguous call so that it only applies for the trtllm kernel?
Sure. Does that mean other all-reduce kernels support non-contiguous memory?
Not sure, but there didn't seem to be a call to contiguous before?
Sorry, I just thought the issue may not be a trtllm-specific issue, but a general issue with symm_mem:
File "/workspace/vllm/vllm/compilation/collective_fusion.py", line 560, in call_trtllm_fused_allreduce_norm
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/vllm/vllm/distributed/communication_op.py", line 14, in tensor_model_parallel_all_reduce
return get_tp_group().all_reduce(input_)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/vllm/vllm/distributed/parallel_state.py", line 499, in all_reduce
return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.12/site-packages/torch/_ops.py", line 1255, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/vllm/vllm/distributed/parallel_state.py", line 122, in all_reduce
return group._all_reduce_out_place(tensor)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/vllm/vllm/distributed/parallel_state.py", line 506, in _all_reduce_out_place
return self.device_communicator.all_reduce(input_)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/vllm/vllm/distributed/device_communicators/cuda_communicator.py", line 157, in all_reduce
out = symm_mem_comm.all_reduce(input_)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/vllm/vllm/distributed/device_communicators/symm_mem.py", line 134, in all_reduce
self.buffer[: inp.numel()].copy_(inp.view(-1))
^^^^^^^^^^^^
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
Just pushed a fix with .reshape(). Or maybe there are better solutions?
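The failure above is the standard view-vs-reshape distinction: slicing the last dimension produces a non-contiguous view, `.view()` requires compatible strides and raises, while `.reshape()` silently copies when a view is impossible. A small standalone reproduction of the behavior:

```python
import torch

# Slicing the trailing dim yields a non-contiguous view -- the same shape
# the symm_mem all_reduce buffer copy received here.
x = torch.randn(4, 8)
y = x[..., :6]
assert not y.is_contiguous()

# .view(-1) fails on this layout because no stride arrangement can
# flatten it without copying.
try:
    y.view(-1)
    raised = False
except RuntimeError:
    raised = True
assert raised

# .reshape(-1) succeeds by materializing a contiguous copy.
flat = y.reshape(-1)
assert flat.shape == (24,)
```

The trade-off is that `.reshape()` hides a copy; calling `.contiguous()` explicitly at the slice site makes that copy visible, which is why the discussion above weighs where to put it.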
Force-pushed from 3dd153b to 66c597f
ProExpertProg
left a comment
It seems like the "Blackwell Compile and Fusion tests" did not get triggered in CI. Could you add fused_moe/layer.py to the list of dependencies?
Force-pushed from 66c597f to 9d62f0b
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Force-pushed from 9d62f0b to 54ab93f
elvischenv
left a comment
Resolved merge conflict.
@ProExpertProg Could you review again and see if your comment was addressed?

@mgoin could you help us trigger the failing pipelines? I can't see the pipeline logs. Thanks!

This PR has been reverted by #29483 as it broke LoRA TP tests. Please open a new version of the PR that passes the test.
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
…lm-project#29483) Signed-off-by: Huamin Li <3ericli@gmail.com>
…lm-project#29483) Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Purpose
The slice op between all_reduce and rms_norm breaks the AR+NORM fusion for GPT-OSS on B200.
This PR moves the slice op to before the all_reduce, enabling the fusion.
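A minimal single-rank sketch of the reordering (the kernel, collective, and norm are stand-ins, not vLLM's real implementations): the hidden dim is padded for the MoE kernel, and the PR slices it back down before the all_reduce, so all_reduce and rms_norm sit back-to-back for the compiler to fuse, and the reduction itself operates on fewer elements.

```python
import torch

def moe_forward(hidden_states: torch.Tensor, og_hidden: int) -> torch.Tensor:
    """Sketch of the post-PR ordering: slice -> all_reduce -> rms_norm."""
    # Stand-ins for the real MoE kernel, tensor-parallel all-reduce, and norm.
    moe_kernel = lambda x: x * 2.0
    all_reduce = lambda x: x  # single-rank placeholder for the collective
    rms_norm = lambda x: x / x.pow(2).mean(-1, keepdim=True).add(1e-6).sqrt()

    out = moe_kernel(hidden_states)               # runs on the padded width
    if out.shape[-1] != og_hidden:
        # Moved to *before* the all_reduce by this PR; contiguous() keeps the
        # collective's input layout valid after the trailing-dim slice.
        out = out[..., :og_hidden].contiguous()
    out = all_reduce(out)                          # now feeds the norm directly
    return rms_norm(out)                           # AR+NORM fusible
```

Before the PR, the slice sat between all_reduce and rms_norm, which blocked the pattern match.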
Test Plan & Result
Accuracy:
PR:
main:
Perf (TP8 con8):
PR: 3.8% improvement
main: