[Perf] Use vLLM's SharedFusedMoE in Qwen3-Omni #560
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,9 +16,12 @@ | |
| Qwen3OmniMoeAudioEncoder, | ||
| ) | ||
| from vllm.config import VllmConfig | ||
| from vllm.distributed import get_tensor_model_parallel_world_size | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY | ||
| from vllm.model_executor.layers.fused_moe import SharedFusedMoE | ||
| from vllm.model_executor.layers.linear import ReplicatedLinear | ||
| from vllm.model_executor.layers.quantization import QuantizationConfig | ||
| from vllm.model_executor.models.interfaces import ( | ||
| MultiModalEmbeddings, | ||
| SupportsMultiModal, | ||
|
|
@@ -27,13 +30,12 @@ | |
| from vllm.model_executor.models.qwen2_5_omni_thinker import ( | ||
| Qwen2_5OmniThinkerDummyInputsBuilder, | ||
| ) | ||
| from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP | ||
| from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP, Qwen3MoeSparseMoeBlock | ||
| from vllm.model_executor.models.qwen3_omni_moe_thinker import Qwen3Omni_VisionTransformer | ||
| from vllm.model_executor.models.utils import ( | ||
| AutoWeightsLoader, | ||
| WeightsMapper, | ||
| maybe_prefix, | ||
| sequence_parallel_chunk, | ||
| ) | ||
| from vllm.multimodal import MULTIMODAL_REGISTRY | ||
| from vllm.sequence import IntermediateTensors | ||
|
|
@@ -531,130 +533,198 @@ def forward(self, hidden_state): | |
| return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) | ||
|
|
||
|
|
||
| class Qwen3OmniMoeTalkerSharedExpertWrapper(nn.Module): | ||
| """ | ||
| Wrapper that combines shared_expert MLP with its sigmoid gate. | ||
|
|
||
| This matches the HuggingFace weight structure where: | ||
| - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight | ||
| - mlp.shared_expert_gate.weight (sibling, not child) | ||
|
|
||
| The wrapper applies: sigmoid(shared_expert_gate(x)) * shared_expert(x) | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| shared_expert: Qwen3MoeMLP, | ||
| shared_expert_gate: nn.Linear, | ||
| ): | ||
| super().__init__() | ||
| self._shared_expert = shared_expert | ||
| self._shared_expert_gate = shared_expert_gate | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| out = self._shared_expert(x) | ||
| gate_values = F.sigmoid(self._shared_expert_gate(x)) # [batch, 1] | ||
| return gate_values * out # Broadcasting: [batch, 1] * [batch, hidden] | ||
|
|
||
|
|
||
| class Qwen3OmniMoeTalkerSparseMoeBlock(nn.Module): | ||
| """ | ||
| Sparse MoE block for Qwen3 Omni MoE Talker with shared expert support. | ||
|
|
||
| This block uses SharedFusedMoE to efficiently compute both routed experts | ||
| and the shared expert, potentially overlapping computation with communication. | ||
|
|
||
| Weight structure matches HuggingFace: | ||
| - mlp.gate.weight (router) | ||
| - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight | ||
| - mlp.shared_expert_gate.weight | ||
| - mlp.experts.{0..n}.{gate_proj, up_proj, down_proj}.weight | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| config: Qwen3OmniMoeTalkerConfig, | ||
| quant_config: QuantizationConfig | None = None, | ||
| prefix: str = "", | ||
| ): | ||
| super().__init__() | ||
| text_config = config.text_config | ||
| self.tp_size = get_tensor_model_parallel_world_size() | ||
|
|
||
| if self.tp_size > text_config.num_experts: | ||
| raise ValueError( | ||
| f"Tensor parallel size {self.tp_size} is greater than the number of experts {text_config.num_experts}." | ||
| ) | ||
|
|
||
| # Router gate for selecting top-k experts | ||
| self.gate = ReplicatedLinear( | ||
| text_config.hidden_size, | ||
| text_config.num_experts, | ||
| bias=False, | ||
| quant_config=quant_config, | ||
| prefix=f"{prefix}.gate", | ||
| ) | ||
|
|
||
| # Shared expert MLP (matches HF: mlp.shared_expert.*) | ||
| if text_config.shared_expert_intermediate_size > 0: | ||
| self.shared_expert = Qwen3MoeMLP( | ||
| hidden_size=text_config.hidden_size, | ||
| intermediate_size=text_config.shared_expert_intermediate_size, | ||
| hidden_act=text_config.hidden_act, | ||
| quant_config=quant_config, | ||
| reduce_results=False, # Don't reduce, we'll handle it | ||
| prefix=f"{prefix}.shared_expert", | ||
| ) | ||
| # Shared expert gate (matches HF: mlp.shared_expert_gate.weight) | ||
| # This is a sibling of shared_expert, not a child | ||
| self.shared_expert_gate = torch.nn.Linear(text_config.hidden_size, 1, bias=False) | ||
| # Create wrapper for SharedFusedMoE | ||
| self._shared_expert_wrapper = Qwen3OmniMoeTalkerSharedExpertWrapper( | ||
| self.shared_expert, self.shared_expert_gate | ||
| ) | ||
| else: | ||
| self.shared_expert = None | ||
| self.shared_expert_gate = None | ||
| self._shared_expert_wrapper = None | ||
|
|
||
| # Fused MoE with shared expert support | ||
| self.experts = SharedFusedMoE( | ||
| shared_experts=self._shared_expert_wrapper, | ||
| num_experts=text_config.num_experts, | ||
| top_k=text_config.num_experts_per_tok, | ||
| hidden_size=text_config.hidden_size, | ||
| intermediate_size=text_config.moe_intermediate_size, | ||
| reduce_results=False, # We'll reduce manually after combining | ||
| renormalize=text_config.norm_topk_prob, | ||
| quant_config=quant_config, | ||
| prefix=f"{prefix}.experts", | ||
| ) | ||
|
|
||
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||
| # NOTE: hidden_states can have either 1D or 2D shape. | ||
| orig_shape = hidden_states.shape | ||
| hidden_dim = hidden_states.shape[-1] | ||
| hidden_states = hidden_states.view(-1, hidden_dim) | ||
|
|
||
| # Compute router logits | ||
| router_logits, _ = self.gate(hidden_states) | ||
|
|
||
| # Forward through SharedFusedMoE | ||
| # Returns (shared_out, fused_out) when shared_expert is present | ||
| final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) | ||
|
|
||
| # Combine shared and routed expert outputs | ||
| if self._shared_expert_wrapper is not None: | ||
| # SharedFusedMoE returns tuple: (shared_out, fused_out) | ||
| final_hidden_states = final_hidden_states[0] + final_hidden_states[1] | ||
|
|
||
| # Apply tensor parallel reduction if needed | ||
| if self.tp_size > 1: | ||
| final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states) | ||
|
|
||
| return final_hidden_states.view(orig_shape) | ||
|
|
||
|
|
||
| class Qwen3OmniMoeModel(Qwen3MoeLLMForCausalLM): | ||
| def __init__(self, vllm_config, talker_config, prefix): | ||
| """ | ||
| Qwen3 Omni MoE Talker language model. | ||
|
|
||
| This model extends Qwen3MoeLLMForCausalLM with: | ||
| - Shared expert support via SharedFusedMoE | ||
| - Codec embedding instead of text embedding | ||
| - No LM head (codec head is separate in the parent class) | ||
| """ | ||
|
|
||
| def __init__(self, vllm_config: VllmConfig, talker_config: Qwen3OmniMoeTalkerConfig, prefix: str): | ||
| # Create a vllm_config for the talker's text model | ||
| talker_vllm_config = vllm_config.with_hf_config( | ||
| talker_config.text_config, architectures=["Qwen3MoeForCausalLM"] | ||
| ) | ||
| talker_vllm_config.model_config.hf_text_config = talker_vllm_config.model_config.hf_config | ||
|
|
||
| super().__init__( | ||
| vllm_config=talker_vllm_config, | ||
| prefix=prefix, | ||
| ) | ||
|
|
||
| self.config = talker_config | ||
| self.talker_vllm_config = talker_vllm_config | ||
|
|
||
| # Remove the inherited LM head so the talker only exposes codec outputs. | ||
| if hasattr(self, "lm_head"): | ||
| del self.lm_head | ||
|
|
||
| # Replace the base embed tokens with codec embedding (defined below). | ||
| # Replace the base embed tokens with codec embedding. | ||
| if hasattr(self.model, "embed_tokens"): | ||
| del self.model.embed_tokens | ||
|
|
||
| # Codec embedding for RVQ code generation | ||
| self.model.codec_embedding = nn.Embedding( | ||
| talker_config.text_config.vocab_size, talker_config.text_config.hidden_size | ||
| talker_config.text_config.vocab_size, | ||
| talker_config.text_config.hidden_size, | ||
| ) | ||
|
|
||
| # Add shared expert to each MoE layer and patch the forward method | ||
| layer_idx = 0 | ||
| for layer in self.model.layers: | ||
| # add shared expert to Qwen3OmniMoeSparseMoeBlock layers | ||
| if hasattr(layer.mlp, "experts"): # Check if it's a SparseMoeBlock | ||
| # Shared expert is a regular gated MLP (SwiGLU) | ||
| layer.mlp.shared_expert = Qwen3MoeMLP( | ||
| hidden_size=self.config.text_config.hidden_size, | ||
| intermediate_size=self.config.text_config.shared_expert_intermediate_size, | ||
| hidden_act=self.config.text_config.hidden_act, | ||
| quant_config=talker_vllm_config.quant_config, | ||
| reduce_results=False, # Don't reduce since we'll add it manually | ||
| prefix=f"{prefix}.layers.{layer_idx}.mlp.shared_expert", | ||
| ) | ||
| # Replace MoE blocks with shared expert versions | ||
| self._replace_moe_blocks_with_shared_expert(prefix) | ||
|
|
||
| # Shared expert gate outputs a single scalar per token | ||
| layer.mlp.shared_expert_gate = ReplicatedLinear( | ||
| self.config.text_config.hidden_size, | ||
| 1, # Output single scalar per token | ||
| bias=False, | ||
| quant_config=None, | ||
| prefix=f"{prefix}.layers.{layer_idx}.mlp.shared_expert_gate", | ||
| ) | ||
|
|
||
| # Store MoE config values for router computation | ||
| layer.mlp.top_k = self.config.text_config.num_experts_per_tok | ||
| layer.mlp.norm_topk_prob = self.config.text_config.norm_topk_prob | ||
| layer.mlp.num_experts = self.config.text_config.num_experts | ||
|
|
||
| # Monkey-patch the forward method to use shared expert | ||
| layer.mlp.forward = self._create_moe_forward_with_shared_expert(layer.mlp) | ||
|
|
||
| layer_idx += 1 | ||
|
|
||
| def _create_moe_forward_with_shared_expert(self, moe_layer): | ||
| """Create a forward method that includes shared expert computation. | ||
|
|
||
| This matches the Transformers implementation where: | ||
| 1. Compute shared expert output (regular MLP) | ||
| 2. Gate it with sigmoid(shared_expert_gate(x)) | ||
| 3. Apply softmax BEFORE top-k selection (matches Transformers router) | ||
| 4. Add to routed expert outputs | ||
| def _replace_moe_blocks_with_shared_expert(self, prefix: str) -> None: | ||
| """ | ||
|
|
||
| def forward_with_shared_expert(hidden_states: torch.Tensor, layer_idx: int = 0) -> torch.Tensor: | ||
| # Save original shape | ||
| orig_shape = hidden_states.shape | ||
| hidden_dim = hidden_states.shape[-1] | ||
| hidden_states = hidden_states.view(-1, hidden_dim) | ||
|
|
||
| # handle sequence parallel if needed | ||
| if hasattr(moe_layer, "is_sequence_parallel") and moe_layer.is_sequence_parallel: | ||
| hidden_states = sequence_parallel_chunk(hidden_states) | ||
|
|
||
| # Compute shared expert output | ||
| # The shared expert is a regular MLP, not a routed MoE | ||
| shared_output = None | ||
| if hasattr(moe_layer, "shared_expert") and moe_layer.shared_expert is not None: | ||
| # Forward through shared expert MLP | ||
| shared_output = moe_layer.shared_expert(hidden_states) | ||
|
|
||
| # Apply gating with sigmoid: sigmoid(gate(x)) * shared_expert(x) | ||
| if hasattr(moe_layer, "shared_expert_gate") and moe_layer.shared_expert_gate is not None: | ||
| gate_logits, _ = moe_layer.shared_expert_gate(hidden_states) | ||
| gate_values = F.sigmoid(gate_logits) # [batch, 1] | ||
| shared_output = gate_values * shared_output # Broadcasting: [batch, 1] * [batch, hidden] | ||
|
|
||
| # Compute experts results | ||
| # router_logits: (num_tokens, n_experts) | ||
| router_logits, _ = moe_layer.gate(hidden_states) | ||
| experts_output = moe_layer.experts(hidden_states=hidden_states, router_logits=router_logits) | ||
|
|
||
| # combine experts and shared expert results | ||
| if shared_output is not None: | ||
| final_hidden_states = experts_output + shared_output | ||
|
|
||
| # Handle sequence parallel if needed | ||
| if hasattr(moe_layer, "is_sequence_parallel") and moe_layer.is_sequence_parallel: | ||
| from vllm.distributed import tensor_model_parallel_all_gather | ||
|
|
||
| num_tokens = orig_shape[0] if len(orig_shape) > 1 else 1 | ||
| final_hidden_states = tensor_model_parallel_all_gather(final_hidden_states, 0) | ||
| final_hidden_states = final_hidden_states[:num_tokens] | ||
| try: | ||
| final_hidden_states.view(orig_shape) | ||
| except Exception as e: | ||
| print(f"Error viewing final hidden states: {e}") | ||
| print(f"final_hidden_states.shape: {final_hidden_states.shape}") | ||
| print(f"orig_shape: {orig_shape}") | ||
| raise e | ||
| # Return with original shape | ||
| return final_hidden_states.view(orig_shape) | ||
|
|
||
| return forward_with_shared_expert | ||
| Replace Qwen3MoeSparseMoeBlock layers with Qwen3OmniMoeTalkerSparseMoeBlock | ||
| that includes shared expert support via SharedFusedMoE. | ||
| """ | ||
| # Get compilation config to clean up registered layer names | ||
| compilation_config = self.talker_vllm_config.compilation_config | ||
|
|
||
| for layer_idx, layer in enumerate(self.model.layers): | ||
| # Check if this layer has a MoE block (has experts attribute) | ||
| if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): | ||
| # Remove old layer registration from static_forward_context | ||
| old_experts_prefix = f"{prefix}.model.layers.{layer_idx}.mlp.experts" | ||
| if old_experts_prefix in compilation_config.static_forward_context: | ||
| del compilation_config.static_forward_context[old_experts_prefix] | ||
|
|
||
| # Create new MoE block with shared expert support | ||
| layer.mlp = Qwen3OmniMoeTalkerSparseMoeBlock( | ||
| config=self.config, | ||
| quant_config=self.talker_vllm_config.quant_config, | ||
| prefix=f"{prefix}.model.layers.{layer_idx}.mlp", | ||
| ) | ||
|
Comment on lines
+704
to
+723
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks a bit hacky for compilation workaround. Perhaps we can upstream the
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. We should make it upstream later. |
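For context on the workaround being discussed, a minimal, self-contained sketch of the swap-and-re-register pattern follows. A plain dict stands in for vLLM's `compilation_config.static_forward_context`; `OldMoeBlock` and `NewMoeBlock` are hypothetical stand-ins for the real modules, and the duplicate-name check is an assumption about why the stale entry has to be deleted before rebuilding the block under the same prefix.

```python
import torch.nn as nn

# Stand-in for compilation_config.static_forward_context (assumption: registered
# MoE layers are keyed by prefix and duplicate registrations are rejected).
static_forward_context: dict[str, nn.Module] = {}


class OldMoeBlock(nn.Module):  # hypothetical stand-in for Qwen3MoeSparseMoeBlock
    def __init__(self, prefix: str):
        super().__init__()
        self.experts = nn.Linear(8, 8)
        static_forward_context[f"{prefix}.experts"] = self.experts


class NewMoeBlock(nn.Module):  # hypothetical stand-in for the talker MoE block
    def __init__(self, prefix: str):
        super().__init__()
        self.experts = nn.Linear(8, 8)
        key = f"{prefix}.experts"
        if key in static_forward_context:
            raise ValueError(f"duplicate layer name: {key}")
        static_forward_context[key] = self.experts


layers = nn.ModuleList(OldMoeBlock(f"model.layers.{i}.mlp") for i in range(2))
for i in range(len(layers)):
    prefix = f"model.layers.{i}.mlp"
    # Drop the stale registration left behind by the old block before building
    # the replacement under the same prefix.
    static_forward_context.pop(f"{prefix}.experts", None)
    layers[i] = NewMoeBlock(prefix)
```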
||
|
|
||
| def embed_input_ids( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| **kwargs: object, | ||
| ) -> torch.Tensor: | ||
| """Embed codec input IDs.""" | ||
| return self.model.codec_embedding(input_ids) | ||
|
gcanlin marked this conversation as resolved.
If the shared expert is instantiated on every tensor-parallel rank (as it is here) and only the routed experts are sharded, the current order will all-reduce the shared expert output along with the routed output. That sums the shared contribution across TP ranks when tp_size > 1, so logits scale with TP size and diverge from HF for multi-GPU runs. A safer order is to all-reduce only the routed output, then add the shared output after reduction (or otherwise prevent the shared output from being summed). This only affects TP > 1 with a replicated shared expert.
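A minimal sketch of the reduction order this comment suggests (not the PR's code; `all_reduce` stands in for vLLM's `tensor_model_parallel_all_reduce`, and the function name and signature are hypothetical):

```python
from typing import Callable

import torch


def combine_routed_and_shared(
    routed_out: torch.Tensor,   # partial sum from the TP-sharded routed experts
    shared_out: torch.Tensor,   # shared-expert output, replicated on every rank
    all_reduce: Callable[[torch.Tensor], torch.Tensor],
    tp_size: int,
) -> torch.Tensor:
    # all_reduce(shared_out + routed_out) would count the replicated shared
    # contribution tp_size times; reducing only the routed part keeps it
    # counted exactly once.
    if tp_size > 1:
        routed_out = all_reduce(routed_out)
    return routed_out + shared_out
```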