Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,10 @@ class Qwen3OmniMoeTalkerSharedExpertWrapper(nn.Module):
- mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight
- mlp.shared_expert_gate.weight (sibling, not child)

The wrapper applies: sigmoid(shared_expert_gate(x)) * shared_expert(x)
The wrapper applies: sigmoid(shared_expert_gate(x)) * shared_expert(x).

It also exposes the underlying shared_expert interface to keep
compatibility with backends that split shared-expert computation.
"""

def __init__(
Expand All @@ -575,9 +578,30 @@ def __init__(
self._shared_expert = shared_expert
self._shared_expert_gate = shared_expert_gate

@property
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need the modification in the model file?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We add those properties so the wrapper presents the same interface as a normal Qwen MLP. vLLM-Ascend’s SharedFusedMoE validation splits the shared expert into gate_up_proj → act_fn → down_proj and expects these attributes to exist. The wrapper only holds the real MLP, so we forward them to keep validation and split execution working on NPU. It doesn't affect GPU behaviour. When we upgrade to vLLM v0.15.0 or later, we will remove this hack wrapper because upstream has implemented shared expert in Qwen3-MoE that we can reuse directly.

def gate_up_proj(self):
    """Expose the ``gate_up_proj`` attribute of the wrapped shared expert."""
    wrapped = self._shared_expert
    return wrapped.gate_up_proj

@property
def down_proj(self):
    """Expose the ``down_proj`` attribute of the wrapped shared expert."""
    wrapped = self._shared_expert
    return wrapped.down_proj

@property
def act_fn(self):
    """Expose the ``act_fn`` attribute of the wrapped shared expert."""
    wrapped = self._shared_expert
    return wrapped.act_fn

def expert_gate(self, x: torch.Tensor):
    """Run the shared-expert gate on ``x`` and always return a tuple.

    If the gate module already returns a tuple it is passed through
    unchanged; otherwise the single result is paired with ``None``.
    """
    result = self._shared_expert_gate(x)
    return result if isinstance(result, tuple) else (result, None)

def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return the shared-expert output scaled by its sigmoid gate.

    Computes ``sigmoid(shared_expert_gate(x)) * shared_expert(x)``.

    The gate module may return either a tensor or a tuple whose first
    element is the gate tensor; in the tuple case only that first element
    is used. The gate is invoked exactly once per call.
    """
    out = self._shared_expert(x)
    gate_out = self._shared_expert_gate(x)
    # Some linear layers return (output, bias); keep only the output.
    if isinstance(gate_out, tuple):
        gate_out = gate_out[0]
    gate_values = F.sigmoid(gate_out)  # [batch, 1]
    return gate_values * out  # Broadcasting: [batch, 1] * [batch, hidden]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ stage_args:
worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.6
enforce_eager: true
enforce_eager: false
trust_remote_code: true
engine_output_type: latent # Output hidden states for talker
distributed_executor_backend: "mp"
Expand Down Expand Up @@ -44,7 +44,7 @@ stage_args:
worker_cls: vllm_omni.worker.npu.npu_ar_worker.NPUARWorker
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.2
enforce_eager: true
enforce_eager: false
trust_remote_code: true
engine_output_type: latent # Output codec codes for code2wav
# tensor_parallel_size: 2
Expand Down Expand Up @@ -75,6 +75,7 @@ stage_args:
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
trust_remote_code: true
async_scheduling: false
enable_prefix_caching: false
engine_output_type: audio # Final output: audio waveform
gpu_memory_utilization: 0.1
Expand Down
22 changes: 22 additions & 0 deletions vllm_omni/model_executor/stage_configs/npu/qwen3_tts.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
stage_args:
- stage_id: 0
stage_type: llm # Use llm stage type to launch OmniLLM
runtime:
devices: "0"
max_batch_size: 1
engine_args:
model_stage: qwen3_tts
model_arch: Qwen3TTSForConditionalGeneration
worker_cls: vllm_omni.worker.npu.npu_generation_worker.NPUGenerationWorker
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
trust_remote_code: true
async_scheduling: false
enable_prefix_caching: false
engine_output_type: audio # Final output: audio waveform
gpu_memory_utilization: 0.1
distributed_executor_backend: "mp"
max_num_batched_tokens: 1000000

final_output: true
final_output_type: audio
63 changes: 0 additions & 63 deletions vllm_omni/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from vllm_omni.inputs.data import OmniTokensPrompt
from vllm_omni.model_executor.layers.mrope import MRotaryEmbedding
from vllm_omni.request import OmniRequest
from vllm_omni.utils import is_npu

for module_name, module in sys.modules.items():
# only do patch on module of vllm, pass others
Expand All @@ -32,65 +31,3 @@
module.Request = OmniRequest
if hasattr(module, "EngineCoreRequest") and module.EngineCoreRequest == _OriginalEngineCoreRequest:
module.EngineCoreRequest = OmniEngineCoreRequest


# Patch for vllm-ascend prefetch functions bug fix
# Issue: The original functions access forward_context attributes such as
# prefetch_mlp_gate_up_proj, prefetch_mlp_down_proj and layer_idx without checking
# whether they exist, which causes an AttributeError when prefetch_mlp_enabled is not set.
# Workaround: the overrides below run the visual encoder inside
# set_ascend_forward_context(...).
# TODO: Remove this patch after upgrading to vllm-ascend v0.13.0 or later.
# This issue has been fixed in https://github.com/vllm-project/vllm-ascend/pull/5035
if is_npu():
    import torch
    import torch.nn as nn
    from vllm.model_executor.models.qwen2_5_omni_thinker import Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs
    from vllm_ascend.ascend_forward_context import set_ascend_forward_context

    from vllm_omni.model_executor.models.qwen2_5_omni.qwen2_5_omni_thinker import (
        Qwen2_5OmniThinkerForConditionalGeneration,
    )

    def _encode_visual_in_ascend_context(
        model, pixel_values: torch.Tensor, grid_thw: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Run ``model.visual`` inside an Ascend forward context and split per item.

        Shared by the image and video overrides below, which previously
        duplicated this logic line for line.
        """
        with set_ascend_forward_context(None, model.vllm_config):
            embeds = model.visual(pixel_values, grid_thw=grid_thw)
        # Split concatenated embeddings for each image/video item.
        merge_size = model.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
        return embeds.split(sizes.tolist())

    class AscendQwen2_5OmniThinkerForConditionalGeneration(nn.Module):
        """Holder for the NPU overrides monkeypatched onto the thinker model.

        The methods are never called on this class itself; they are assigned
        onto ``Qwen2_5OmniThinkerForConditionalGeneration`` below, so ``self``
        is an instance of that class at call time.
        """

        def _process_image_input(
            self, image_input: Qwen2_5_VLImageInputs
        ) -> torch.Tensor | tuple[torch.Tensor, ...]:
            """Return image embeddings.

            Precomputed ``image_embeds`` are returned as a single tensor cast
            to the visual encoder dtype; otherwise pixel values are encoded
            and the result is split into one tensor per image item.
            """
            if image_input["type"] == "image_embeds":
                return image_input["image_embeds"].type(self.visual.dtype)

            grid_thw = image_input["image_grid_thw"]
            assert grid_thw.ndim == 2

            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            return _encode_visual_in_ascend_context(self, pixel_values, grid_thw)

        def _process_video_input(
            self,
            video_input: Qwen2_5_VLVideoInputs,
            video_hashes: list[str] | None = None,
            cached_video_embeds: torch.Tensor | None = None,
        ) -> torch.Tensor | tuple[torch.Tensor, ...]:
            """Return video embeddings.

            Precomputed ``video_embeds`` are returned as a single tensor cast
            to the visual encoder dtype; otherwise pixel values are encoded
            and the result is split into one tensor per video item.
            ``video_hashes`` and ``cached_video_embeds`` are accepted for
            signature compatibility and are not used here.
            """
            if video_input["type"] == "video_embeds":
                return video_input["video_embeds"].type(self.visual.dtype)

            grid_thw = video_input["video_grid_thw"]
            assert grid_thw.ndim == 2

            pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
            return _encode_visual_in_ascend_context(self, pixel_values_videos, grid_thw)

    Qwen2_5OmniThinkerForConditionalGeneration._process_image_input = (
        AscendQwen2_5OmniThinkerForConditionalGeneration._process_image_input
    )
    Qwen2_5OmniThinkerForConditionalGeneration._process_video_input = (
        AscendQwen2_5OmniThinkerForConditionalGeneration._process_video_input
    )
Loading