266 changes: 168 additions & 98 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py
@@ -16,9 +16,12 @@
Qwen3OmniMoeAudioEncoder,
)
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
@@ -27,13 +30,12 @@
from vllm.model_executor.models.qwen2_5_omni_thinker import (
Qwen2_5OmniThinkerDummyInputsBuilder,
)
from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP
from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP, Qwen3MoeSparseMoeBlock
from vllm.model_executor.models.qwen3_omni_moe_thinker import Qwen3Omni_VisionTransformer
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
WeightsMapper,
maybe_prefix,
sequence_parallel_chunk,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
@@ -531,130 +533,198 @@ def forward(self, hidden_state):
return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))


class Qwen3OmniMoeTalkerSharedExpertWrapper(nn.Module):
"""
Wrapper that combines shared_expert MLP with its sigmoid gate.

This matches the HuggingFace weight structure where:
- mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight
- mlp.shared_expert_gate.weight (sibling, not child)

The wrapper applies: sigmoid(shared_expert_gate(x)) * shared_expert(x)
"""

def __init__(
self,
shared_expert: Qwen3MoeMLP,
shared_expert_gate: nn.Linear,
):
super().__init__()
self._shared_expert = shared_expert
self._shared_expert_gate = shared_expert_gate

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self._shared_expert(x)
gate_values = F.sigmoid(self._shared_expert_gate(x)) # [batch, 1]
return gate_values * out # Broadcasting: [batch, 1] * [batch, hidden]
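
For intuition, a minimal standalone sketch of the gating math this wrapper applies. The plain `nn.Linear` standing in for the SwiGLU shared expert and the dummy shapes are illustrative assumptions, not the real `Qwen3MoeMLP`:

```python
import torch
import torch.nn as nn

hidden_size = 8
x = torch.randn(4, hidden_size)                       # [batch, hidden]
shared_expert = nn.Linear(hidden_size, hidden_size)   # stand-in for Qwen3MoeMLP
shared_expert_gate = nn.Linear(hidden_size, 1, bias=False)

gate_values = torch.sigmoid(shared_expert_gate(x))    # [batch, 1]
out = gate_values * shared_expert(x)                  # broadcasts to [batch, hidden]
print(out.shape)                                      # torch.Size([4, 8])
```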


class Qwen3OmniMoeTalkerSparseMoeBlock(nn.Module):
"""
Sparse MoE block for Qwen3 Omni MoE Talker with shared expert support.

This block uses SharedFusedMoE to efficiently compute both routed experts
and the shared expert, potentially overlapping computation with communication.

Weight structure matches HuggingFace:
- mlp.gate.weight (router)
- mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight
- mlp.shared_expert_gate.weight
- mlp.experts.{0..n}.{gate_proj, up_proj, down_proj}.weight
"""

def __init__(
self,
config: Qwen3OmniMoeTalkerConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
text_config = config.text_config
self.tp_size = get_tensor_model_parallel_world_size()

if self.tp_size > text_config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than the number of experts {text_config.num_experts}."
)

# Router gate for selecting top-k experts
self.gate = ReplicatedLinear(
text_config.hidden_size,
text_config.num_experts,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate",
)

# Shared expert MLP (matches HF: mlp.shared_expert.*)
if text_config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen3MoeMLP(
hidden_size=text_config.hidden_size,
intermediate_size=text_config.shared_expert_intermediate_size,
hidden_act=text_config.hidden_act,
quant_config=quant_config,
reduce_results=False, # Don't reduce, we'll handle it
prefix=f"{prefix}.shared_expert",
)
# Shared expert gate (matches HF: mlp.shared_expert_gate.weight)
# This is a sibling of shared_expert, not a child
self.shared_expert_gate = torch.nn.Linear(text_config.hidden_size, 1, bias=False)
# Create wrapper for SharedFusedMoE
self._shared_expert_wrapper = Qwen3OmniMoeTalkerSharedExpertWrapper(
self.shared_expert, self.shared_expert_gate
)
else:
self.shared_expert = None
self.shared_expert_gate = None
self._shared_expert_wrapper = None

# Fused MoE with shared expert support
self.experts = SharedFusedMoE(
shared_experts=self._shared_expert_wrapper,
num_experts=text_config.num_experts,
top_k=text_config.num_experts_per_tok,
hidden_size=text_config.hidden_size,
intermediate_size=text_config.moe_intermediate_size,
reduce_results=False, # We'll reduce manually after combining
renormalize=text_config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)

# Compute router logits
router_logits, _ = self.gate(hidden_states)

# Forward through SharedFusedMoE
# Returns (shared_out, fused_out) when shared_expert is present
final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits)

# Combine shared and routed expert outputs
if self._shared_expert_wrapper is not None:
# SharedFusedMoE returns tuple: (shared_out, fused_out)
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]

# Apply tensor parallel reduction if needed
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
Comment on lines +648 to +655

P2: Reduce routed output before adding shared expert

If the shared expert is instantiated on every tensor-parallel rank (as it is here) and only the routed experts are sharded, the current order will all-reduce the shared expert output along with the routed output. That sums the shared contribution across TP ranks when tp_size > 1, so logits scale with TP size and diverge from HF for multi-GPU runs. A safer order is to all-reduce only the routed output, then add the shared output after reduction (or otherwise prevent the shared output from being summed). This only affects TP>1 with a replicated shared expert.

return final_hidden_states.view(orig_shape)
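
A minimal sketch of the ordering suggested in the review comment above, assuming (as the code comments state) that `SharedFusedMoE` returns a `(shared_out, fused_out)` tuple when a shared expert is present: all-reduce only the routed output, then add the shared output once. This illustrates the reviewer's suggestion, not the code merged in this PR:

```python
# Inside forward(), the combine step could instead read:
shared_out, routed_out = self.experts(
    hidden_states=hidden_states, router_logits=router_logits
)
if self.tp_size > 1:
    # Reduce only the sharded routed experts so a shared expert that is
    # replicated on every TP rank is not summed tp_size times.
    routed_out = self.experts.maybe_all_reduce_tensor_model_parallel(routed_out)
final_hidden_states = routed_out + shared_out
```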


class Qwen3OmniMoeModel(Qwen3MoeLLMForCausalLM):
def __init__(self, vllm_config, talker_config, prefix):
"""
Qwen3 Omni MoE Talker language model.

This model extends Qwen3MoeLLMForCausalLM with:
- Shared expert support via SharedFusedMoE
- Codec embedding instead of text embedding
- No LM head (codec head is separate in the parent class)
"""

def __init__(self, vllm_config: VllmConfig, talker_config: Qwen3OmniMoeTalkerConfig, prefix: str):
# Create a vllm_config for the talker's text model
talker_vllm_config = vllm_config.with_hf_config(
talker_config.text_config, architectures=["Qwen3MoeForCausalLM"]
)
talker_vllm_config.model_config.hf_text_config = talker_vllm_config.model_config.hf_config

super().__init__(
vllm_config=talker_vllm_config,
prefix=prefix,
)

self.config = talker_config
self.talker_vllm_config = talker_vllm_config

# Remove the inherited LM head so the talker only exposes codec outputs.
if hasattr(self, "lm_head"):
del self.lm_head

# Replace the base embed tokens with codec embedding (defined below).
# Replace the base embed tokens with codec embedding.
if hasattr(self.model, "embed_tokens"):
del self.model.embed_tokens

# Codec embedding for RVQ code generation
self.model.codec_embedding = nn.Embedding(
talker_config.text_config.vocab_size, talker_config.text_config.hidden_size
talker_config.text_config.vocab_size,
talker_config.text_config.hidden_size,
)

# Add shared expert to each MoE layer and patch the forward method
layer_idx = 0
for layer in self.model.layers:
# add shared expert to Qwen3OmniMoeSparseMoeBlock layers
if hasattr(layer.mlp, "experts"): # Check if it's a SparseMoeBlock
# Shared expert is a regular gated MLP (SwiGLU)
layer.mlp.shared_expert = Qwen3MoeMLP(
hidden_size=self.config.text_config.hidden_size,
intermediate_size=self.config.text_config.shared_expert_intermediate_size,
hidden_act=self.config.text_config.hidden_act,
quant_config=talker_vllm_config.quant_config,
reduce_results=False, # Don't reduce since we'll add it manually
prefix=f"{prefix}.layers.{layer_idx}.mlp.shared_expert",
)
# Replace MoE blocks with shared expert versions
self._replace_moe_blocks_with_shared_expert(prefix)

# Shared expert gate outputs a single scalar per token
layer.mlp.shared_expert_gate = ReplicatedLinear(
self.config.text_config.hidden_size,
1, # Output single scalar per token
bias=False,
quant_config=None,
prefix=f"{prefix}.layers.{layer_idx}.mlp.shared_expert_gate",
)

# Store MoE config values for router computation
layer.mlp.top_k = self.config.text_config.num_experts_per_tok
layer.mlp.norm_topk_prob = self.config.text_config.norm_topk_prob
layer.mlp.num_experts = self.config.text_config.num_experts

# Monkey-patch the forward method to use shared expert
layer.mlp.forward = self._create_moe_forward_with_shared_expert(layer.mlp)

layer_idx += 1

def _create_moe_forward_with_shared_expert(self, moe_layer):
"""Create a forward method that includes shared expert computation.

This matches the Transformers implementation where:
1. Compute shared expert output (regular MLP)
2. Gate it with sigmoid(shared_expert_gate(x))
3. Apply softmax BEFORE top-k selection (matches Transformers router)
4. Add to routed expert outputs
def _replace_moe_blocks_with_shared_expert(self, prefix: str) -> None:
"""

def forward_with_shared_expert(hidden_states: torch.Tensor, layer_idx: int = 0) -> torch.Tensor:
# Save original shape
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)

# handle sequence parallel if needed
if hasattr(moe_layer, "is_sequence_parallel") and moe_layer.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)

# Compute shared expert output
# The shared expert is a regular MLP, not a routed MoE
shared_output = None
if hasattr(moe_layer, "shared_expert") and moe_layer.shared_expert is not None:
# Forward through shared expert MLP
shared_output = moe_layer.shared_expert(hidden_states)

# Apply gating with sigmoid: sigmoid(gate(x)) * shared_expert(x)
if hasattr(moe_layer, "shared_expert_gate") and moe_layer.shared_expert_gate is not None:
gate_logits, _ = moe_layer.shared_expert_gate(hidden_states)
gate_values = F.sigmoid(gate_logits) # [batch, 1]
shared_output = gate_values * shared_output # Broadcasting: [batch, 1] * [batch, hidden]

# Compute experts results
# router_logits: (num_tokens, n_experts)
router_logits, _ = moe_layer.gate(hidden_states)
experts_output = moe_layer.experts(hidden_states=hidden_states, router_logits=router_logits)

# combine experts and shared expert results
if shared_output is not None:
final_hidden_states = experts_output + shared_output

# Handle sequence parallel if needed
if hasattr(moe_layer, "is_sequence_parallel") and moe_layer.is_sequence_parallel:
from vllm.distributed import tensor_model_parallel_all_gather

num_tokens = orig_shape[0] if len(orig_shape) > 1 else 1
final_hidden_states = tensor_model_parallel_all_gather(final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
try:
final_hidden_states.view(orig_shape)
except Exception as e:
print(f"Error viewing final hidden states: {e}")
print(f"final_hidden_states.shape: {final_hidden_states.shape}")
print(f"orig_shape: {orig_shape}")
raise e
# Return with original shape
return final_hidden_states.view(orig_shape)

return forward_with_shared_expert
Replace Qwen3MoeSparseMoeBlock layers with Qwen3OmniMoeTalkerSparseMoeBlock
that includes shared expert support via SharedFusedMoE.
"""
# Get compilation config to clean up registered layer names
compilation_config = self.talker_vllm_config.compilation_config

for layer_idx, layer in enumerate(self.model.layers):
# Check if this layer has a MoE block (has experts attribute)
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
# Remove old layer registration from static_forward_context
old_experts_prefix = f"{prefix}.model.layers.{layer_idx}.mlp.experts"
if old_experts_prefix in compilation_config.static_forward_context:
del compilation_config.static_forward_context[old_experts_prefix]

# Create new MoE block with shared expert support
layer.mlp = Qwen3OmniMoeTalkerSparseMoeBlock(
config=self.config,
quant_config=self.talker_vllm_config.quant_config,
prefix=f"{prefix}.model.layers.{layer_idx}.mlp",
)
Comment on lines +704 to +723

Isotr0py (Member) commented on Dec 31, 2025:

This looks a bit hacky for compilation workaround. Perhaps we can upstream the SharedFusedMoE support to vLLM's qwen3_moe.py to avoid this in following PR?

Author (Collaborator) replied:

Yeah. We should make it upstream later.

def embed_input_ids(
self,
input_ids: torch.Tensor,
**kwargs: object,
) -> torch.Tensor:
"""Embed codec input IDs."""
return self.model.codec_embedding(input_ids)
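
A hypothetical usage sketch of the codec embedding path; `talker` and the token IDs below are made-up placeholders for a constructed `Qwen3OmniMoeModel`, not values from this PR:

```python
# Illustration only: codec (RVQ) token IDs are looked up in codec_embedding,
# yielding hidden states of size talker_config.text_config.hidden_size.
codec_ids = torch.tensor([[17, 402, 9]])        # placeholder codec token IDs
embeds = talker.embed_input_ids(codec_ids)      # -> [1, 3, hidden_size]
```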