diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py index 946694ed229..ac7d8120dfa 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( Qwen3OmniMoeTalkerConfig, ) @@ -11,12 +10,8 @@ Qwen3OmniMoeAudioEncoder, ) from vllm.config import VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY -from vllm.model_executor.layers.fused_moe import SharedFusedMoE -from vllm.model_executor.layers.linear import ReplicatedLinear -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.interfaces import ( MultiModalEmbeddings, SupportsPP, @@ -24,7 +19,6 @@ from vllm.model_executor.models.qwen2_5_omni_thinker import ( Qwen2_5OmniThinkerDummyInputsBuilder, ) -from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP, Qwen3MoeSparseMoeBlock from vllm.model_executor.models.qwen3_omni_moe_thinker import Qwen3Omni_VisionTransformer from vllm.model_executor.models.utils import ( AutoWeightsLoader, @@ -512,162 +506,14 @@ def forward(self, hidden_state): return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) -class Qwen3OmniMoeTalkerSharedExpertWrapper(nn.Module): - """ - Wrapper that combines shared_expert MLP with its sigmoid gate. - - This matches the HuggingFace weight structure where: - - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight - - mlp.shared_expert_gate.weight (sibling, not child) - - The wrapper applies: sigmoid(shared_expert_gate(x)) * shared_expert(x). - - It also exposes the underlying shared_expert interface to keep - compatibility with backends that split shared-expert computation. - """ - - def __init__( - self, - shared_expert: Qwen3MoeMLP, - shared_expert_gate: nn.Linear, - ): - super().__init__() - self._shared_expert = shared_expert - self._shared_expert_gate = shared_expert_gate - - @property - def gate_up_proj(self): - return self._shared_expert.gate_up_proj - - @property - def down_proj(self): - return self._shared_expert.down_proj - - @property - def act_fn(self): - return self._shared_expert.act_fn - - def expert_gate(self, x: torch.Tensor): - gate_out = self._shared_expert_gate(x) - if isinstance(gate_out, tuple): - return gate_out - return gate_out, None - - def forward(self, x: torch.Tensor) -> torch.Tensor: - out = self._shared_expert(x) - gate_out = self._shared_expert_gate(x) - if isinstance(gate_out, tuple): - gate_out = gate_out[0] - gate_values = F.sigmoid(gate_out) # [batch, 1] - return gate_values * out # Broadcasting: [batch, 1] * [batch, hidden] - - -class Qwen3OmniMoeTalkerSparseMoeBlock(nn.Module): - """ - Sparse MoE block for Qwen3 Omni MoE Talker with shared expert support. - - This block uses SharedFusedMoE to efficiently compute both routed experts - and the shared expert, potentially overlapping computation with communication. - - Weight structure matches HuggingFace: - - mlp.gate.weight (router) - - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight - - mlp.shared_expert_gate.weight - - mlp.experts.{0..n}.{gate_proj, up_proj, down_proj}.weight - """ - - def __init__( - self, - config: Qwen3OmniMoeTalkerConfig, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - text_config = config.text_config - self.tp_size = get_tensor_model_parallel_world_size() - - if self.tp_size > text_config.num_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than the number of experts {text_config.num_experts}." - ) - - # Router gate for selecting top-k experts - self.gate = ReplicatedLinear( - text_config.hidden_size, - text_config.num_experts, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate", - ) - - # Shared expert MLP (matches HF: mlp.shared_expert.*) - if text_config.shared_expert_intermediate_size > 0: - self.shared_expert = Qwen3MoeMLP( - hidden_size=text_config.hidden_size, - intermediate_size=text_config.shared_expert_intermediate_size, - hidden_act=text_config.hidden_act, - quant_config=quant_config, - reduce_results=False, # Don't reduce, we'll handle it - prefix=f"{prefix}.shared_expert", - ) - # Shared expert gate (matches HF: mlp.shared_expert_gate.weight) - # This is a sibling of shared_expert, not a child - self.shared_expert_gate = torch.nn.Linear(text_config.hidden_size, 1, bias=False) - # Create wrapper for SharedFusedMoE - self._shared_expert_wrapper = Qwen3OmniMoeTalkerSharedExpertWrapper( - self.shared_expert, self.shared_expert_gate - ) - else: - self.shared_expert = None - self.shared_expert_gate = None - self._shared_expert_wrapper = None - - # Fused MoE with shared expert support - self.experts = SharedFusedMoE( - shared_experts=self._shared_expert_wrapper, - num_experts=text_config.num_experts, - top_k=text_config.num_experts_per_tok, - hidden_size=text_config.hidden_size, - intermediate_size=text_config.moe_intermediate_size, - reduce_results=False, # We'll reduce manually after combining - renormalize=text_config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # NOTE: hidden_states can have either 1D or 2D shape. - orig_shape = hidden_states.shape - hidden_dim = hidden_states.shape[-1] - hidden_states = hidden_states.view(-1, hidden_dim) - - # Compute router logits - router_logits, _ = self.gate(hidden_states) - - # Forward through SharedFusedMoE - # Returns (shared_out, fused_out) when shared_expert is present - final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) - - # Combine shared and routed expert outputs - if self._shared_expert_wrapper is not None: - # SharedFusedMoE returns tuple: (shared_out, fused_out) - final_hidden_states = final_hidden_states[0] + final_hidden_states[1] - - # Apply tensor parallel reduction if needed - if self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states) - - return final_hidden_states.view(orig_shape) - - class Qwen3OmniMoeModel(Qwen3MoeLLMForCausalLM): """ Qwen3 Omni MoE Talker language model. - This model extends Qwen3MoeLLMForCausalLM with: - - Shared expert support via SharedFusedMoE - - Codec embedding instead of text embedding - - No LM head (codec head is separate in the parent class) + Extends Qwen3MoeLLMForCausalLM (which already uses SharedFusedMoE with + shared-expert support) and replaces the text embedding / LM head with a + codec embedding so the talker operates over audio-codec tokens instead + of text tokens. """ def __init__(self, vllm_config: VllmConfig, talker_config: Qwen3OmniMoeTalkerConfig, prefix: str): @@ -699,32 +545,6 @@ def __init__(self, vllm_config: VllmConfig, talker_config: Qwen3OmniMoeTalkerCon talker_config.text_config.hidden_size, ) - # Replace MoE blocks with shared expert versions - self._replace_moe_blocks_with_shared_expert(prefix) - - def _replace_moe_blocks_with_shared_expert(self, prefix: str) -> None: - """ - Replace Qwen3MoeSparseMoeBlock layers with Qwen3OmniMoeTalkerSparseMoeBlock - that includes shared expert support via SharedFusedMoE. - """ - # Get compilation config to clean up registered layer names - compilation_config = self.talker_vllm_config.compilation_config - - for layer_idx, layer in enumerate(self.model.layers): - # Check if this layer has a MoE block (has experts attribute) - if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): - # Remove old layer registration from static_forward_context - old_experts_prefix = f"{prefix}.model.layers.{layer_idx}.mlp.experts" - if old_experts_prefix in compilation_config.static_forward_context: - del compilation_config.static_forward_context[old_experts_prefix] - - # Create new MoE block with shared expert support - layer.mlp = Qwen3OmniMoeTalkerSparseMoeBlock( - config=self.config, - quant_config=self.talker_vllm_config.quant_config, - prefix=f"{prefix}.model.layers.{layer_idx}.mlp", - ) - def embed_input_ids( self, input_ids: torch.Tensor,