diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py
index 255d45b452..1bfb0d3a3b 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py
@@ -16,9 +16,12 @@
     Qwen3OmniMoeAudioEncoder,
 )
 from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.interfaces import (
     MultiModalEmbeddings,
     SupportsMultiModal,
@@ -27,13 +30,12 @@
 from vllm.model_executor.models.qwen2_5_omni_thinker import (
     Qwen2_5OmniThinkerDummyInputsBuilder,
 )
-from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP
+from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP, Qwen3MoeSparseMoeBlock
 from vllm.model_executor.models.qwen3_omni_moe_thinker import Qwen3Omni_VisionTransformer
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader,
     WeightsMapper,
     maybe_prefix,
-    sequence_parallel_chunk,
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
@@ -531,130 +533,198 @@ def forward(self, hidden_state):
         return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
 
 
+class Qwen3OmniMoeTalkerSharedExpertWrapper(nn.Module):
+    """
+    Wrapper that combines the shared_expert MLP with its sigmoid gate.
+
+    This matches the HuggingFace weight structure where:
+    - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight
+    - mlp.shared_expert_gate.weight (a sibling of shared_expert, not a child)
+
+    The wrapper applies: sigmoid(shared_expert_gate(x)) * shared_expert(x)
+    """
+
+    def __init__(
+        self,
+        shared_expert: Qwen3MoeMLP,
+        shared_expert_gate: nn.Linear,
+    ):
+        super().__init__()
+        self._shared_expert = shared_expert
+        self._shared_expert_gate = shared_expert_gate
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = self._shared_expert(x)
+        gate_values = F.sigmoid(self._shared_expert_gate(x))  # [batch, 1]
+        return gate_values * out  # Broadcasting: [batch, 1] * [batch, hidden]
+
+
+class Qwen3OmniMoeTalkerSparseMoeBlock(nn.Module):
+    """
+    Sparse MoE block for the Qwen3 Omni MoE Talker with shared expert support.
+
+    This block uses SharedFusedMoE to compute both the routed experts and the
+    shared expert, potentially overlapping computation with communication.
+
+    Weight structure matches HuggingFace:
+    - mlp.gate.weight (router)
+    - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight
+    - mlp.shared_expert_gate.weight
+    - mlp.experts.{0..n}.{gate_proj, up_proj, down_proj}.weight
+    """
+
+    def __init__(
+        self,
+        config: Qwen3OmniMoeTalkerConfig,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        text_config = config.text_config
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        if self.tp_size > text_config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {text_config.num_experts}."
+            )
+
+        # Router gate for selecting top-k experts
+        self.gate = ReplicatedLinear(
+            text_config.hidden_size,
+            text_config.num_experts,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate",
+        )
+
+        # Shared expert MLP (matches HF: mlp.shared_expert.*)
+        if text_config.shared_expert_intermediate_size > 0:
+            self.shared_expert = Qwen3MoeMLP(
+                hidden_size=text_config.hidden_size,
+                intermediate_size=text_config.shared_expert_intermediate_size,
+                hidden_act=text_config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,  # Don't reduce here; handled after combining
+                prefix=f"{prefix}.shared_expert",
+            )
+            # Shared expert gate (matches HF: mlp.shared_expert_gate.weight)
+            # This is a sibling of shared_expert, not a child
+            self.shared_expert_gate = torch.nn.Linear(text_config.hidden_size, 1, bias=False)
+            # Create wrapper for SharedFusedMoE
+            self._shared_expert_wrapper = Qwen3OmniMoeTalkerSharedExpertWrapper(
+                self.shared_expert, self.shared_expert_gate
+            )
+        else:
+            self.shared_expert = None
+            self.shared_expert_gate = None
+            self._shared_expert_wrapper = None
+
+        # Fused MoE with shared expert support
+        self.experts = SharedFusedMoE(
+            shared_experts=self._shared_expert_wrapper,
+            num_experts=text_config.num_experts,
+            top_k=text_config.num_experts_per_tok,
+            hidden_size=text_config.hidden_size,
+            intermediate_size=text_config.moe_intermediate_size,
+            reduce_results=False,  # Reduced manually after combining
+            renormalize=text_config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # Compute router logits
+        router_logits, _ = self.gate(hidden_states)
+
+        # Forward through SharedFusedMoE.
+        # Returns (shared_out, fused_out) when a shared expert is present.
+        final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits)
+
+        # Combine shared and routed expert outputs
+        if self._shared_expert_wrapper is not None:
+            # SharedFusedMoE returns a tuple: (shared_out, fused_out)
+            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
+
+        # Apply tensor parallel reduction if needed
+        if self.tp_size > 1:
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
+
+        return final_hidden_states.view(orig_shape)
+
+
 class Qwen3OmniMoeModel(Qwen3MoeLLMForCausalLM):
-    def __init__(self, vllm_config, talker_config, prefix):
+    """
+    Qwen3 Omni MoE Talker language model.
+
+    This model extends Qwen3MoeLLMForCausalLM with:
+    - Shared expert support via SharedFusedMoE
+    - Codec embedding instead of text embedding
+    - No LM head (the codec head is handled separately)
+    """
+
+    def __init__(self, vllm_config: VllmConfig, talker_config: Qwen3OmniMoeTalkerConfig, prefix: str):
+        # Create a vllm_config for the talker's text model
         talker_vllm_config = vllm_config.with_hf_config(
             talker_config.text_config, architectures=["Qwen3MoeForCausalLM"]
         )
         talker_vllm_config.model_config.hf_text_config = talker_vllm_config.model_config.hf_config
+
         super().__init__(
             vllm_config=talker_vllm_config,
             prefix=prefix,
         )
         self.config = talker_config
+        self.talker_vllm_config = talker_vllm_config
 
         # Remove the inherited LM head so the talker only exposes codec outputs.
         if hasattr(self, "lm_head"):
             del self.lm_head
 
-        # Replace the base embed tokens with codec embedding (defined below).
+        # Replace the base embed tokens with codec embedding.
         if hasattr(self.model, "embed_tokens"):
             del self.model.embed_tokens
 
         # Codec embedding for RVQ code generation
         self.model.codec_embedding = nn.Embedding(
-            talker_config.text_config.vocab_size, talker_config.text_config.hidden_size
+            talker_config.text_config.vocab_size,
+            talker_config.text_config.hidden_size,
         )
 
-        # Add shared expert to each MoE layer and patch the forward method
-        layer_idx = 0
-        for layer in self.model.layers:
-            # add shared expert to Qwen3OmniMoeSparseMoeBlock layers
-            if hasattr(layer.mlp, "experts"):  # Check if it's a SparseMoeBlock
-                # Shared expert is a regular gated MLP (SwiGLU)
-                layer.mlp.shared_expert = Qwen3MoeMLP(
-                    hidden_size=self.config.text_config.hidden_size,
-                    intermediate_size=self.config.text_config.shared_expert_intermediate_size,
-                    hidden_act=self.config.text_config.hidden_act,
-                    quant_config=talker_vllm_config.quant_config,
-                    reduce_results=False,  # Don't reduce since we'll add it manually
-                    prefix=f"{prefix}.layers.{layer_idx}.mlp.shared_expert",
-                )
+        # Replace MoE blocks with shared expert versions
+        self._replace_moe_blocks_with_shared_expert(prefix)
 
-                # Shared expert gate outputs a single scalar per token
-                layer.mlp.shared_expert_gate = ReplicatedLinear(
-                    self.config.text_config.hidden_size,
-                    1,  # Output single scalar per token
-                    bias=False,
-                    quant_config=None,
-                    prefix=f"{prefix}.layers.{layer_idx}.mlp.shared_expert_gate",
-                )
-
-                # Store MoE config values for router computation
-                layer.mlp.top_k = self.config.text_config.num_experts_per_tok
-                layer.mlp.norm_topk_prob = self.config.text_config.norm_topk_prob
-                layer.mlp.num_experts = self.config.text_config.num_experts
-
-                # Monkey-patch the forward method to use shared expert
-                layer.mlp.forward = self._create_moe_forward_with_shared_expert(layer.mlp)
-
-            layer_idx += 1
-
-    def _create_moe_forward_with_shared_expert(self, moe_layer):
-        """Create a forward method that includes shared expert computation.
-
-        This matches the Transformers implementation where:
-        1. Compute shared expert output (regular MLP)
-        2. Gate it with sigmoid(shared_expert_gate(x))
-        3. Apply softmax BEFORE top-k selection (matches Transformers router)
-        4. Add to routed expert outputs
+    def _replace_moe_blocks_with_shared_expert(self, prefix: str) -> None:
         """
-
-        def forward_with_shared_expert(hidden_states: torch.Tensor, layer_idx: int = 0) -> torch.Tensor:
-            # Save original shape
-            orig_shape = hidden_states.shape
-            hidden_dim = hidden_states.shape[-1]
-            hidden_states = hidden_states.view(-1, hidden_dim)
-
-            # handle sequence parallel if needed
-            if hasattr(moe_layer, "is_sequence_parallel") and moe_layer.is_sequence_parallel:
-                hidden_states = sequence_parallel_chunk(hidden_states)
-
-            # Compute shared expert output
-            # The shared expert is a regular MLP, not a routed MoE
-            shared_output = None
-            if hasattr(moe_layer, "shared_expert") and moe_layer.shared_expert is not None:
-                # Forward through shared expert MLP
-                shared_output = moe_layer.shared_expert(hidden_states)
-
-                # Apply gating with sigmoid: sigmoid(gate(x)) * shared_expert(x)
-                if hasattr(moe_layer, "shared_expert_gate") and moe_layer.shared_expert_gate is not None:
-                    gate_logits, _ = moe_layer.shared_expert_gate(hidden_states)
-                    gate_values = F.sigmoid(gate_logits)  # [batch, 1]
-                    shared_output = gate_values * shared_output  # Broadcasting: [batch, 1] * [batch, hidden]
-
-            # Compute experts results
-            # router_logits: (num_tokens, n_experts)
-            router_logits, _ = moe_layer.gate(hidden_states)
-            experts_output = moe_layer.experts(hidden_states=hidden_states, router_logits=router_logits)
-
-            # combine experts and shared expert results
-            if shared_output is not None:
-                final_hidden_states = experts_output + shared_output
-
-            # Handle sequence parallel if needed
-            if hasattr(moe_layer, "is_sequence_parallel") and moe_layer.is_sequence_parallel:
-                from vllm.distributed import tensor_model_parallel_all_gather
-
-                num_tokens = orig_shape[0] if len(orig_shape) > 1 else 1
-                final_hidden_states = tensor_model_parallel_all_gather(final_hidden_states, 0)
-                final_hidden_states = final_hidden_states[:num_tokens]
-            try:
-                final_hidden_states.view(orig_shape)
-            except Exception as e:
-                print(f"Error viewing final hidden states: {e}")
-                print(f"final_hidden_states.shape: {final_hidden_states.shape}")
-                print(f"orig_shape: {orig_shape}")
-                raise e
-            # Return with original shape
-            return final_hidden_states.view(orig_shape)
-
-        return forward_with_shared_expert
+        Replace Qwen3MoeSparseMoeBlock layers with Qwen3OmniMoeTalkerSparseMoeBlock,
+        which adds shared expert support via SharedFusedMoE.
+        """
+        # Get compilation config to clean up registered layer names
+        compilation_config = self.talker_vllm_config.compilation_config
+
+        for layer_idx, layer in enumerate(self.model.layers):
+            # Only replace MLPs that are sparse MoE blocks
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+                # Remove old layer registration from static_forward_context
+                old_experts_prefix = f"{prefix}.model.layers.{layer_idx}.mlp.experts"
+                if old_experts_prefix in compilation_config.static_forward_context:
+                    del compilation_config.static_forward_context[old_experts_prefix]
+
+                # Create new MoE block with shared expert support
+                layer.mlp = Qwen3OmniMoeTalkerSparseMoeBlock(
+                    config=self.config,
+                    quant_config=self.talker_vllm_config.quant_config,
+                    prefix=f"{prefix}.model.layers.{layer_idx}.mlp",
+                )
 
     def embed_input_ids(
         self,
         input_ids: torch.Tensor,
-        **kwargs: object,
     ) -> torch.Tensor:
+        """Embed codec input IDs."""
        return self.model.codec_embedding(input_ids)
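
A minimal standalone sketch (illustrative only, not part of the patch) of the gating math that Qwen3OmniMoeTalkerSharedExpertWrapper implements and that Qwen3OmniMoeTalkerSparseMoeBlock adds to the routed-expert output; the module definitions, names, and shapes below are stand-ins rather than the real Qwen3MoeMLP/SharedFusedMoE components:

    import torch
    import torch.nn as nn

    hidden_size, num_tokens = 8, 4
    # Stand-ins for the shared expert MLP and its per-token scalar gate
    shared_mlp = nn.Sequential(nn.Linear(hidden_size, 16), nn.SiLU(), nn.Linear(16, hidden_size))
    shared_gate = nn.Linear(hidden_size, 1, bias=False)

    x = torch.randn(num_tokens, hidden_size)
    # sigmoid(shared_expert_gate(x)) * shared_expert(x); [num_tokens, 1] broadcasts over hidden
    shared_out = torch.sigmoid(shared_gate(x)) * shared_mlp(x)
    routed_out = torch.zeros_like(x)   # placeholder for the fused routed-expert output
    combined = shared_out + routed_out  # what the new MoE block returns, before any TP all-reduce
    print(combined.shape)  # torch.Size([4, 8])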