
[Perf] Use vLLM's SharedFusedMoE in Qwen3-Omni#560

Merged
Isotr0py merged 8 commits into vllm-project:main from gcanlin:qwen3-omni-perf on Jan 10, 2026

Conversation

@gcanlin
Collaborator

@gcanlin gcanlin commented Dec 31, 2025


Purpose

Add Qwen3OmniMoeTalkerSparseMoeBlock, which uses vLLM's SharedFusedMoE instead of FusedMoE, so we no longer need to implement the shared-expert fused MoE path by hand and can reuse the high-performance operators already available in vLLM. Because SharedFusedMoE is a custom op, it is dispatched based on the hardware platform, enabling better performance on each target backend.
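
For readers unfamiliar with SharedFusedMoE, the sketch below illustrates the idea: the sparse MoE block hands its shared expert to vLLM's SharedFusedMoE layer, which runs it together with the routed experts behind one custom op. This is a hedged illustration, not the code merged in this PR; the import path, the constructor keyword arguments, and the Qwen-MoE-style config field names (num_experts, num_experts_per_tok, moe_intermediate_size, shared_expert_intermediate_size) are assumptions that may differ across vLLM versions.

```python
# Hedged sketch, not this PR's actual implementation.
import torch.nn as nn

# Import path is an assumption; it varies across vLLM versions.
from vllm.model_executor.layers.fused_moe import SharedFusedMoE


class TalkerSparseMoeBlockSketch(nn.Module):
    def __init__(self, config, quant_config=None, prefix: str = ""):
        super().__init__()
        # Stand-in for the real shared-expert MLP (Qwen3MoeMLP in vLLM).
        shared_expert = nn.Sequential(
            nn.Linear(config.hidden_size, config.shared_expert_intermediate_size),
            nn.SiLU(),
            nn.Linear(config.shared_expert_intermediate_size, config.hidden_size),
        )
        # SharedFusedMoE runs the replicated shared expert together with the
        # TP-sharded routed experts inside a single hardware-dispatched op,
        # so the model code no longer combines them by hand.
        self.experts = SharedFusedMoE(
            shared_experts=shared_expert,
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )
        # The router/gate and the forward pass (which adds up the tuple that
        # SharedFusedMoE returns) are omitted; see the diff for the real code.
```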

Test Plan

Test Result

  • execute_model: 177 ms -> 123 ms
  • Qwen3MoeDecoderLayer_0: 34 ms -> 21 ms
  • Qwen3MoeSparseMoeBlock_0: 24 ms -> 13 ms

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft.


Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@hsliuustc0106 hsliuustc0106 linked an issue Dec 31, 2025 that may be closed by this pull request
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin
Collaborator Author

gcanlin commented Dec 31, 2025

@Isotr0py Could you please help review? Actually, I'm not very familiar with SharedFusedMoE, and I'm afraid I might introduce some hidden bugs into the modeling code.

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@hsliuustc0106
Collaborator

@amy-why-3459 PTAL

Comment thread vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py
Comment on lines +750 to +769
        Replace Qwen3MoeSparseMoeBlock layers with Qwen3OmniMoeTalkerSparseMoeBlock
        that includes shared expert support via SharedFusedMoE.
        """
        # Get compilation config to clean up registered layer names
        compilation_config = self.talker_vllm_config.compilation_config

        for layer_idx, layer in enumerate(self.model.layers):
            # Check if this layer has a MoE block (has experts attribute)
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                # Remove old layer registration from static_forward_context
                old_experts_prefix = f"{prefix}.model.layers.{layer_idx}.mlp.experts"
                if old_experts_prefix in compilation_config.static_forward_context:
                    del compilation_config.static_forward_context[old_experts_prefix]

                # Create new MoE block with shared expert support
                layer.mlp = Qwen3OmniMoeTalkerSparseMoeBlock(
                    config=self.config,
                    quant_config=self.talker_vllm_config.quant_config,
                    prefix=f"{prefix}.model.layers.{layer_idx}.mlp",
                )
Member

@Isotr0py Isotr0py Dec 31, 2025


This looks a bit hacky as a compilation workaround. Perhaps we can upstream SharedFusedMoE support to vLLM's qwen3_moe.py in a follow-up PR to avoid this?

Collaborator Author


Yeah. We should upstream it later.

Comment thread vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py Outdated
@hsliuustc0106 hsliuustc0106 requested a review from a team January 6, 2026 14:39
@ZJY0516
Member

ZJY0516 commented Jan 8, 2026

Any progress on this?

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin
Collaborator Author

gcanlin commented Jan 8, 2026

Any progress on this?

Yes. I'm updating it today.

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin gcanlin marked this pull request as ready for review January 8, 2026 09:50
@gcanlin gcanlin requested a review from hsliuustc0106 as a code owner January 8, 2026 09:50

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d0a9346260


Comment on lines +641 to +648
        # Combine shared and routed expert outputs
        if self._shared_expert_wrapper is not None:
            # SharedFusedMoE returns tuple: (shared_out, fused_out)
            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]

        # Apply tensor parallel reduction if needed
        if self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states)

P2: Reduce routed output before adding shared expert

If the shared expert is instantiated on every tensor-parallel rank (as it is here) and only the routed experts are sharded, the current order will all-reduce the shared expert output along with the routed output. That sums the shared contribution across TP ranks when tp_size > 1, so logits scale with TP size and diverge from HF for multi-GPU runs. A safer order is to all-reduce only the routed output, then add the shared output after reduction (or otherwise prevent the shared output from being summed). This only affects TP>1 with a replicated shared expert.
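
As a hedged sketch of the ordering suggested above (the variable names shared_out and routed_out are illustrative, maybe_all_reduce_tensor_model_parallel is the helper already used in the snippet, and the shared expert is assumed to be replicated on every TP rank):

```python
# Illustrative sketch only, not the merged code.
shared_out, routed_out = final_hidden_states  # SharedFusedMoE returns (shared, routed)

if self.tp_size > 1:
    # All-reduce only the TP-sharded routed output...
    routed_out = self.experts.maybe_all_reduce_tensor_model_parallel(routed_out)

# ...then add the replicated shared output once, so it is not summed tp_size times.
final_hidden_states = routed_out + shared_out
```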


Member

@Isotr0py Isotr0py left a comment


LGTM

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@Isotr0py Isotr0py enabled auto-merge (squash) January 10, 2026 03:36
@Isotr0py Isotr0py added the ready label (label to trigger buildkite CI) Jan 10, 2026
@Isotr0py Isotr0py merged commit 9c2f746 into vllm-project:main Jan 10, 2026
6 of 7 checks passed
@gcanlin
Collaborator Author

gcanlin commented Jan 10, 2026

@Isotr0py Many thanks!

sniper35 pushed a commit to sniper35/vllm-omni that referenced this pull request Jan 10, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
qibaoyuan pushed a commit to qibaoyuan/vllm-omni that referenced this pull request Jan 12, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: 齐保元 <qibaoyuan@xiaomi.com>
qibaoyuan pushed a commit to qibaoyuan/vllm-omni that referenced this pull request Jan 12, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: 齐保元 <qibaoyuan@xiaomi.com>
with1015 pushed a commit to with1015/vllm-omni that referenced this pull request Jan 20, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

Labels

ready (label to trigger buildkite CI)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[RFC]: Qwen3-Omni deployment

5 participants