Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
f19650d
Support Qwen3 Omni model quantization
lvliang-intel Feb 4, 2026
913a993
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
fb0ae29
Merge branch 'main' into lvl/support_omni
lvliang-intel Feb 4, 2026
2fef4c2
fix lint issue
lvliang-intel Feb 4, 2026
9000a8d
Merge branch 'main' into lvl/support_omni
lvliang-intel Feb 9, 2026
cf69613
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2026
5072aff
fix issuse with transformers running
lvliang-intel Feb 11, 2026
6797b5d
Merge branch 'main' into lvl/support_omni
lvliang-intel Feb 11, 2026
5ca1500
Merge branch 'main' into lvl/support_omni
lvliang-intel Feb 11, 2026
9e95103
Update auto_round/special_model_handler.py
lvliang-intel Feb 11, 2026
b2a65fc
remove unnecessary code
lvliang-intel Feb 11, 2026
fa9b4e4
Merge branch 'main' into lvl/support_omni
lvliang-intel Mar 3, 2026
4a3d1c5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2026
3e2906e
Update special_model_handler.py
lvliang-intel Mar 3, 2026
40e5fa0
Merge branch 'main' into lvl/support_omni
lvliang-intel Mar 4, 2026
6a8cd62
support qwen2.5 omni
lvliang-intel Mar 4, 2026
8efe1f7
Merge branch 'main' of https://github.com/intel/auto-round into lvl/s…
lvliang-intel Mar 5, 2026
88a04f3
Merge branch 'lvl/support_omni' of https://github.com/intel/auto-roun…
lvliang-intel Mar 5, 2026
bc49bca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
4733d83
fix qwen2.5 omni issues
lvliang-intel Mar 5, 2026
e60e3e7
add ut
lvliang-intel Mar 6, 2026
6926dfb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
4c89a4a
Merge branch 'main' into lvl/support_omni
lvliang-intel Mar 6, 2026
54cb02a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
a4fe356
Merge branch 'lvl/support_omni' of https://github.com/intel/auto-roun…
lvliang-intel Mar 6, 2026
681d391
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
4ac6f9b
fix lint issue
lvliang-intel Mar 6, 2026
1f2f464
Merge branch 'lvl/support_omni' of https://github.com/intel/auto-roun…
lvliang-intel Mar 6, 2026
9d976d6
Merge branch 'main' into lvl/support_omni
lvliang-intel Mar 17, 2026
5f32868
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 17, 2026
633f3af
Apply suggestion from @Copilot
lvliang-intel Mar 17, 2026
fbfebba
Merge branch 'main' into lvl/support_omni
lvliang-intel Mar 17, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion auto_round/auto_scheme/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
get_layer_features,
get_module,
is_hpex_available,
normalize_no_split_modules,
parse_available_devices,
)

Expand Down Expand Up @@ -219,7 +220,7 @@ def dispatch_model_by_all_available_devices(
if device_map is None:
device_map = 0

no_split_modules = list(getattr(model, "_no_split_modules", []))
no_split_modules = normalize_no_split_modules(getattr(model, "_no_split_modules", []))
if device_map == "auto":
max_memory = get_balanced_memory(
model,
Expand Down
1 change: 1 addition & 0 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
llm_load_model,
memory_monitor,
mv_module_from_gpu,
normalize_no_split_modules,
set_amax_for_all_moe_layers,
set_module,
to_device,
Expand Down
70 changes: 70 additions & 0 deletions auto_round/compressors/mllm/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,76 @@ def squeeze_result(ret):
return ret


@register_processor("qwen2_5_omni")
class Qwen2_5OmniProcessor(HFProcessor):
"""Processor for Qwen2.5-Omni multimodal models.

Qwen2.5-Omni supports text, image, video, and audio inputs.
This processor handles proper tokenization and preprocessing for calibration.
"""

@staticmethod
def squeeze_result(ret):
for key in ret:
# Skip squeezing for multi-modal data that may have special dimensions
if key in ["pixel_values", "pixel_values_videos", "input_features"]:
continue
ret[key] = ret[key][0]
return ret

def _process_v1(self, messages, image):
"""Process messages for Qwen2.5-Omni model."""
conversation = []
for content in messages:
conversation.append(
{
"role": content["role"],
"content": [{"text": content["content"].replace(self.IMAGE_TOKEN, ""), "type": "text"}],
}
)
if self.IMAGE_TOKEN in content["content"]:
conversation[-1]["content"].append({"image": image, "type": "image"})
ret = self.processor.apply_chat_template(
conversation, add_generation_prompt=True, tokenize=True, return_dict=True
)
return ret


@register_processor("qwen3_omni")
class Qwen3OmniProcessor(HFProcessor):
"""Processor for Qwen3-Omni multimodal models.

Qwen3-Omni supports text, image, video, and audio inputs.
This processor handles proper tokenization and preprocessing for calibration.
"""

@staticmethod
def squeeze_result(ret):
for key in ret:
# Skip squeezing for multi-modal data that may have special dimensions
if key in ["pixel_values", "pixel_values_videos", "input_features"]:
continue
ret[key] = ret[key][0]
return ret

def _process_v1(self, messages, image):
"""Process messages for Qwen3-Omni model."""
conversation = []
for content in messages:
conversation.append(
{
"role": content["role"],
"content": [{"text": content["content"].replace(self.IMAGE_TOKEN, ""), "type": "text"}],
}
)
if self.IMAGE_TOKEN in content["content"]:
conversation[-1]["content"].append({"image": image, "type": "image"})
ret = self.processor.apply_chat_template(
conversation, add_generation_prompt=True, tokenize=True, return_dict=True
)
return ret


@register_processor("cogvlm2")
class CogVLM2Processor(BasicProcessor):
def get_input(self, text, images, truncation=False, squeeze=True, max_length=None, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions auto_round/compressors/mllm/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def _register_template(

_register_template("qwen2_vl", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["qwen2_vl"])
_register_template("qwen2_5_vl", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["qwen2_vl"])
_register_template("qwen2_5_omni", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["qwen2_5_omni"])
_register_template("qwen3_omni_moe", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["qwen3_omni"])
_register_template("mllama", default_dataset="liuhaotian/llava", processor=PROCESSORS["hf"])
_register_template("deepseek_vl_v2", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["deepseek_v2"])
_register_template("mistral3", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["hf"])
Expand Down
3 changes: 3 additions & 0 deletions auto_round/compressors/mllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
"audio",
"talker",
"token2wav",
"code2wav",
"audio_tower",
"code_predictor",
"multi_modal_projector",
"vision_tower",
"multimodal_projector",
Expand Down
10 changes: 10 additions & 0 deletions auto_round/compressors/shard_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,18 @@ def finalize(self):
# 1. Capture remaining weights not yet saved
full_sd = self.model.state_dict()
tie_word_embeddings = False
config = getattr(self.model, "config", None)
if hasattr(self.model, "config") and hasattr(self.model.config, "tie_word_embeddings"):
tie_word_embeddings = self.model.config.tie_word_embeddings
if tie_word_embeddings is None:
# For multimodal models, check nested text/thinker configs
for sub_attr in ("text_config", "thinker_config", "language_config", "llm_config"):
sub_config = getattr(config, sub_attr, None)
if sub_config is not None:
val = getattr(sub_config, "tie_word_embeddings", None)
if val is not None:
tie_word_embeddings = val
break

finalize_skipped_meta_tensors = []
for pname, tensor in full_sd.items():
Expand Down
4 changes: 4 additions & 0 deletions auto_round/inference/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,10 @@ def post_init(model: torch.nn.Module, used_backends: list[str]) -> None:
used_backends (List[str]): List of backend names used for quantization.

"""
from auto_round.utils.common import monkey_patch_model

monkey_patch_model(model)

need_autogptq_init = False
need_gptqmodel_init = False
need_ipex_init = False
Expand Down
237 changes: 237 additions & 0 deletions auto_round/modeling/fused_moe/qwen3_omni.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
# Copyright (c) 2026 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""MoE module replacements for Qwen3-Omni model.

Qwen3-Omni has MoE blocks in both thinker and talker:
- Thinker: Qwen3OmniMoeThinkerTextSparseMoeBlock (experts + gate, no shared expert)
- Talker: Qwen3OmniMoeTalkerTextSparseMoeBlock (experts + gate + shared_expert + shared_expert_gate)

This module provides replacement classes that unfuse fused expert weights (3D Parameters)
into individual nn.Linear layers, enabling per-expert quantization with meta device optimization.
"""

import torch

from auto_round.modeling.fused_moe.replace_modules import ReplacementModuleBase
from auto_round.modeling.fused_moe.utils import _update_parameter
from auto_round.utils import clear_memory, unsupported_meta_device

# ---------------------------------------------------------------------------
# Thinker MoE replacement (no shared expert)
# ---------------------------------------------------------------------------


class LinearQwen3OmniThinkerSparseMoeBlock(ReplacementModuleBase):
"""Calibration replacement for Qwen3OmniMoeThinkerTextSparseMoeBlock.

Unfuses fused expert weights into individual nn.Linear layers for
per-expert quantization. Uses meta device to avoid doubling memory.

Structure: gate (router) + experts (unfused).
"""

def __init__(self, original, config):
super().__init__(original)
self.gate = original.gate
self.num_experts = original.experts.num_experts
text_config = config.thinker_config.text_config
with torch.device("meta"):
self.experts = SequentialQwen3OmniThinkerExperts(text_config, original.experts)

@classmethod
def original_module_class(cls) -> str:
return "Qwen3OmniMoeThinkerTextSparseMoeBlock"

def _materialize_weights(self) -> None:
original = self._get_original_module()
self.experts._materialize_weights(original.experts)
clear_memory()

def experts_forward(self, hidden_states, top_k_index, top_k_weights):
final_hidden_states = torch.zeros_like(hidden_states)
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == self.num_experts:
continue
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
current_hidden_states = self.experts[expert_idx](current_state)
current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

return final_hidden_states

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
_, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
expert_output = self.experts_forward(hidden_states_reshaped, selected_experts, routing_weights)
return expert_output.reshape(batch_size, sequence_length, hidden_dim)

@classmethod
def from_original(cls, original, config, **kwargs):
return cls(original, config)


# ---------------------------------------------------------------------------
# Talker MoE replacement (with shared expert, same pattern as qwen3_5_moe)
# ---------------------------------------------------------------------------


class LinearQwen3OmniTalkerSparseMoeBlock(ReplacementModuleBase):
"""Calibration replacement for Qwen3OmniMoeTalkerTextSparseMoeBlock.

Unfuses fused expert weights and preserves the shared_expert + shared_expert_gate.
Similar to Qwen3.5-MoE pattern.

Structure: gate (router) + experts (unfused) + shared_expert + shared_expert_gate.
"""

def __init__(self, original, config):
super().__init__(original)
self.gate = original.gate
self.shared_expert = original.shared_expert
self.shared_expert_gate = original.shared_expert_gate
self.num_experts = original.experts.num_experts
text_config = config.talker_config.text_config
with torch.device("meta"):
self.experts = SequentialQwen3OmniTalkerExperts(text_config, original.experts)

@classmethod
def original_module_class(cls) -> str:
return "Qwen3OmniMoeTalkerTextSparseMoeBlock"

def _materialize_weights(self) -> None:
original = self._get_original_module()
self.experts._materialize_weights(original.experts)
clear_memory()

def experts_forward(self, hidden_states, top_k_index, top_k_weights):
final_hidden_states = torch.zeros_like(hidden_states)
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == self.num_experts:
continue
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
current_hidden_states = self.experts[expert_idx](current_state)
current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

return final_hidden_states

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
shared_expert_output = self.shared_expert(hidden_states_reshaped)
_, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
expert_output = self.experts_forward(hidden_states_reshaped, selected_experts, routing_weights)

shared_expert_output = torch.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output

expert_output = expert_output + shared_expert_output
return expert_output.reshape(batch_size, sequence_length, hidden_dim)

@classmethod
def from_original(cls, original, config, **kwargs):
return cls(original, config)


# ---------------------------------------------------------------------------
# Sequential expert containers (unfused nn.Linear per expert)
# ---------------------------------------------------------------------------


class SequentialQwen3OmniThinkerExperts(torch.nn.ModuleList):
"""Unfused per-expert nn.Linear layers for Qwen3-Omni thinker MoE.

Replaces fused 3D Parameters (gate_up_proj, down_proj) with individual
Qwen3OmniMoeThinkerTextMLP modules per expert.
"""

def __init__(self, config, original):
super().__init__()
self.num_experts = original.gate_up_proj.shape[0]
intermediate_size = config.moe_intermediate_size

from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
Qwen3OmniMoeThinkerTextMLP,
)

with torch.device("meta"):
super().__init__([Qwen3OmniMoeThinkerTextMLP(config, intermediate_size) for _ in range(self.num_experts)])

def _materialize_weights(self, original) -> None:
"""Unfuse fused expert weights into individual nn.Linear layers.

gate_up_proj shape: (num_experts, 2 * moe_intermediate, hidden)
down_proj shape: (num_experts, hidden, moe_intermediate)
"""
intermediate_size = original.down_proj.shape[-1]
if not unsupported_meta_device(original):
for i in range(self.num_experts):
gate_up = original.gate_up_proj[i]
down = original.down_proj[i]

gate_proj = gate_up[:intermediate_size, :]
up_proj = gate_up[intermediate_size:, :]

_update_parameter(self[i].gate_proj, "weight", gate_proj.contiguous())
_update_parameter(self[i].up_proj, "weight", up_proj.contiguous())
_update_parameter(self[i].down_proj, "weight", down.contiguous())
del gate_up, down, gate_proj, up_proj
original.to_empty(device="meta") # release original fused parameters
clear_memory()


class SequentialQwen3OmniTalkerExperts(torch.nn.ModuleList):
"""Unfused per-expert nn.Linear layers for Qwen3-Omni talker MoE.

Replaces fused 3D Parameters (gate_up_proj, down_proj) with individual
Qwen3OmniMoeTalkerTextMLP modules per expert.
"""

def __init__(self, config, original):
super().__init__()
self.num_experts = original.gate_up_proj.shape[0]
intermediate_size = config.moe_intermediate_size

from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
Qwen3OmniMoeTalkerTextMLP,
)

with torch.device("meta"):
super().__init__([Qwen3OmniMoeTalkerTextMLP(config, intermediate_size) for _ in range(self.num_experts)])

def _materialize_weights(self, original) -> None:
"""Unfuse fused expert weights into individual nn.Linear layers.

gate_up_proj shape: (num_experts, 2 * moe_intermediate, hidden)
down_proj shape: (num_experts, hidden, moe_intermediate)
"""
intermediate_size = original.down_proj.shape[-1]
if not unsupported_meta_device(original):
for i in range(self.num_experts):
gate_up = original.gate_up_proj[i]
down = original.down_proj[i]

gate_proj = gate_up[:intermediate_size, :]
up_proj = gate_up[intermediate_size:, :]

_update_parameter(self[i].gate_proj, "weight", gate_proj.contiguous())
_update_parameter(self[i].up_proj, "weight", up_proj.contiguous())
_update_parameter(self[i].down_proj, "weight", down.contiguous())
del gate_up, down, gate_proj, up_proj
original.to_empty(device="meta") # release original fused parameters
clear_memory()
2 changes: 2 additions & 0 deletions auto_round/modeling/fused_moe/replace_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
"qwen3_5_moe_text": LazyImport("auto_round.modeling.fused_moe.qwen3_5_moe"),
# Step 3.5 MoE: splits fused MoELinear into per-expert nn.Linear
"step3p5": LazyImport("auto_round.modeling.fused_moe.step3_5_moe"),
# Qwen3-Omni MoE: thinker (no shared expert) + talker (with shared expert)
"qwen3_omni_moe": LazyImport("auto_round.modeling.fused_moe.qwen3_omni"),
}


Expand Down
Loading
Loading