From 066a22689a48e0b40c450b72d3f7f19242b84506 Mon Sep 17 00:00:00 2001 From: pratapyash Date: Sun, 15 Mar 2026 10:59:25 +0000 Subject: [PATCH 1/3] feat: add lora support; enhance Qwen3OmniMoeAudioEncoder with ReplicatedLinear and quantization support - Replaced nn.Linear with ReplicatedLinear for conv_out, proj1, and proj2 layers to support quantization. - Added quant_config parameter to Qwen3OmniMoeAudioEncoder constructor. - Updated method calls to handle outputs from ReplicatedLinear layers. - Included SupportsLoRA in Qwen3OmniMoeThinkerForConditionalGeneration class. --- .../models/qwen3_omni_moe_thinker.py | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index ff352a735a65..7dd76f81c31f 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -60,6 +60,7 @@ from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -85,6 +86,7 @@ from .interfaces import ( MultiModalEmbeddings, + SupportsLoRA, SupportsMRoPE, SupportsMultiModal, SupportsPP, @@ -324,6 +326,7 @@ class Qwen3OmniMoeAudioEncoder(nn.Module): def __init__( self, config: Qwen3OmniMoeAudioEncoderConfig, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -360,7 +363,10 @@ def __init__( conv_out_dim = config.downsample_hidden_size * ( (((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2 ) - self.conv_out = nn.Linear(conv_out_dim, config.d_model, bias=False) + self.conv_out = ReplicatedLinear( + conv_out_dim, config.d_model, bias=False, + quant_config=quant_config, + ) # Transformer encoder layers self.layers = nn.ModuleList( @@ -375,9 +381,15 @@ def __init__( # Output layers self.ln_post = nn.LayerNorm(config.d_model) - self.proj1 = nn.Linear(config.d_model, config.d_model) + self.proj1 = ReplicatedLinear( + config.d_model, config.d_model, bias=True, + quant_config=quant_config, + ) self.act = _ACTIVATION_REGISTRY[config.activation_function] - self.proj2 = nn.Linear(config.d_model, config.output_dim) + self.proj2 = ReplicatedLinear( + config.d_model, config.output_dim, bias=True, + quant_config=quant_config, + ) # Get attention backend self.attn_backend = get_vit_attn_backend( @@ -458,7 +470,7 @@ def forward( # (batch, channels, freq, time) -> (batch, time, channels*freq) b, c, f, t = padded_embed.size() - padded_embed = self.conv_out( + padded_embed, _ = self.conv_out( padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f) ) @@ -501,9 +513,9 @@ def forward( # Apply output layers hidden_states = self.ln_post(hidden_states) - hidden_states = self.proj1(hidden_states) + hidden_states, _ = self.proj1(hidden_states) hidden_states = self.act(hidden_states) - hidden_states = self.proj2(hidden_states) + hidden_states, _ = self.proj2(hidden_states) return hidden_states @@ -1645,6 +1657,7 @@ def _process_audio_input( class Qwen3OmniMoeThinkerForConditionalGeneration( nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP, SupportsMRoPE, Qwen3OmniMoeConditionalGenerationMixin, @@ -1664,12 +1677,10 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( "k_proj", "v_proj", ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], } + embedding_modules = {} + supported_languages = ISO639_1_SUPPORTED_LANGS @classmethod @@ -1698,6 +1709,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): with self._mark_tower_model(vllm_config, "audio"): self.audio_tower = Qwen3OmniMoeAudioEncoder( thinker_config.audio_config, + quant_config=quant_config, prefix=maybe_prefix(prefix, "audio_tower"), ) From e1bd31bee271f2c9e2bacff806736bbf388f130a Mon Sep 17 00:00:00 2001 From: pratapyash Date: Sun, 15 Mar 2026 13:36:47 +0000 Subject: [PATCH 2/3] refactor: replace ReplicatedLinear with nn.Linear in Qwen3OmniMoeAudioEncoder - Removed ReplicatedLinear usage for conv_out, proj1, and proj2 layers. - Eliminated quant_config parameter from Qwen3OmniMoeAudioEncoder constructor. - Updated method calls to reflect changes in layer outputs. --- .../models/qwen3_omni_moe_thinker.py | 24 +++++-------------- 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 7dd76f81c31f..8481a0a05960 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -60,7 +60,6 @@ from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -326,7 +325,6 @@ class Qwen3OmniMoeAudioEncoder(nn.Module): def __init__( self, config: Qwen3OmniMoeAudioEncoderConfig, - quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -363,10 +361,7 @@ def __init__( conv_out_dim = config.downsample_hidden_size * ( (((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2 ) - self.conv_out = ReplicatedLinear( - conv_out_dim, config.d_model, bias=False, - quant_config=quant_config, - ) + self.conv_out = nn.Linear(conv_out_dim, config.d_model, bias=False) # Transformer encoder layers self.layers = nn.ModuleList( @@ -381,15 +376,9 @@ def __init__( # Output layers self.ln_post = nn.LayerNorm(config.d_model) - self.proj1 = ReplicatedLinear( - config.d_model, config.d_model, bias=True, - quant_config=quant_config, - ) + self.proj1 = nn.Linear(config.d_model, config.d_model) self.act = _ACTIVATION_REGISTRY[config.activation_function] - self.proj2 = ReplicatedLinear( - config.d_model, config.output_dim, bias=True, - quant_config=quant_config, - ) + self.proj2 = nn.Linear(config.d_model, config.output_dim) # Get attention backend self.attn_backend = get_vit_attn_backend( @@ -470,7 +459,7 @@ def forward( # (batch, channels, freq, time) -> (batch, time, channels*freq) b, c, f, t = padded_embed.size() - padded_embed, _ = self.conv_out( + padded_embed = self.conv_out( padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f) ) @@ -513,9 +502,9 @@ def forward( # Apply output layers hidden_states = self.ln_post(hidden_states) - hidden_states, _ = self.proj1(hidden_states) + hidden_states = self.proj1(hidden_states) hidden_states = self.act(hidden_states) - hidden_states, _ = self.proj2(hidden_states) + hidden_states = self.proj2(hidden_states) return hidden_states @@ -1709,7 +1698,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): with self._mark_tower_model(vllm_config, "audio"): self.audio_tower = Qwen3OmniMoeAudioEncoder( thinker_config.audio_config, - quant_config=quant_config, prefix=maybe_prefix(prefix, "audio_tower"), ) From 2d11aa82afd4851fdc4467cfac2acd5eed20c1c5 Mon Sep 17 00:00:00 2001 From: pratapyash Date: Mon, 16 Mar 2026 11:27:08 +0000 Subject: [PATCH 3/3] feat: add support for skipping audio/vision tower modules during LoRA loading - Introduced lora_skip_prefixes to exclude audio_tower and visual modules from LoRA loading. - This change addresses the requirement for enable_tower_connector_lora, which is not yet supported. --- vllm/model_executor/models/qwen3_omni_moe_thinker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 8481a0a05960..60eb14ad06a0 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -1670,6 +1670,10 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( embedding_modules = {} + # Skip audio/vision tower modules during LoRA loading -- tower LoRA + # requires enable_tower_connector_lora which is not yet supported. + lora_skip_prefixes = ["audio_tower.", "visual."] + supported_languages = ISO639_1_SUPPORTED_LANGS @classmethod