# Entries of a processor result that squeeze_result leaves untouched.
_OMNI_UNSQUEEZED_KEYS = ("pixel_values", "pixel_values_videos", "input_features")


@register_processor("qwen2_5_omni")
class Qwen2_5OmniProcessor(HFProcessor):
    """Processor for Qwen2.5-Omni multimodal models.

    Qwen2.5-Omni supports text, image, video, and audio inputs.
    This processor handles proper tokenization and preprocessing for calibration.
    """

    @staticmethod
    def squeeze_result(ret):
        """Replace every entry of ``ret`` with its first element, in place.

        Entries named in ``_OMNI_UNSQUEEZED_KEYS`` are left as they are.

        Args:
            ret: Mapping produced by the underlying processor.

        Returns:
            The same ``ret`` object, mutated.
        """
        for key in ret:
            # Skip squeezing for multi-modal data that may have special dimensions
            if key in _OMNI_UNSQUEEZED_KEYS:
                continue
            ret[key] = ret[key][0]
        return ret

    def _process_v1(self, messages, image):
        """Build a chat-template conversation from ``messages`` and tokenize it.

        Each message becomes one conversation turn whose text has ``IMAGE_TOKEN``
        removed; turns whose original text contained ``IMAGE_TOKEN`` also get
        an image part carrying ``image``.

        Args:
            messages: Iterable of dicts with ``"role"`` and ``"content"`` keys.
            image: Image object attached to turns that reference ``IMAGE_TOKEN``.

        Returns:
            The result of ``self.processor.apply_chat_template`` (called with
            ``tokenize=True`` and ``return_dict=True``).
        """
        conversation = []
        for message in messages:
            text = message["content"]
            parts = [{"text": text.replace(self.IMAGE_TOKEN, ""), "type": "text"}]
            if self.IMAGE_TOKEN in text:
                parts.append({"image": image, "type": "image"})
            conversation.append({"role": message["role"], "content": parts})
        return self.processor.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=True, return_dict=True
        )


@register_processor("qwen3_omni")
class Qwen3OmniProcessor(Qwen2_5OmniProcessor):
    """Processor for Qwen3-Omni multimodal models.

    Qwen3-Omni supports text, image, video, and audio inputs.
    This processor handles proper tokenization and preprocessing for calibration.

    The preprocessing is identical to Qwen2.5-Omni, so ``squeeze_result`` and
    ``_process_v1`` are inherited rather than duplicated.
    """
a/auto_round/compressors/mllm/utils.py b/auto_round/compressors/mllm/utils.py index e8535666c..34609d964 100644 --- a/auto_round/compressors/mllm/utils.py +++ b/auto_round/compressors/mllm/utils.py @@ -27,6 +27,9 @@ "audio", "talker", "token2wav", + "code2wav", + "audio_tower", + "code_predictor", "multi_modal_projector", "vision_tower", "multimodal_projector", diff --git a/auto_round/compressors/shard_writer.py b/auto_round/compressors/shard_writer.py index 20e959e8e..f61f7394d 100644 --- a/auto_round/compressors/shard_writer.py +++ b/auto_round/compressors/shard_writer.py @@ -171,8 +171,18 @@ def finalize(self): # 1. Capture remaining weights not yet saved full_sd = self.model.state_dict() tie_word_embeddings = False + config = getattr(self.model, "config", None) if hasattr(self.model, "config") and hasattr(self.model.config, "tie_word_embeddings"): tie_word_embeddings = self.model.config.tie_word_embeddings + if tie_word_embeddings is None: + # For multimodal models, check nested text/thinker configs + for sub_attr in ("text_config", "thinker_config", "language_config", "llm_config"): + sub_config = getattr(config, sub_attr, None) + if sub_config is not None: + val = getattr(sub_config, "tie_word_embeddings", None) + if val is not None: + tie_word_embeddings = val + break finalize_skipped_meta_tensors = [] for pname, tensor in full_sd.items(): diff --git a/auto_round/inference/convert_model.py b/auto_round/inference/convert_model.py index 1e4324244..ba746e1c7 100644 --- a/auto_round/inference/convert_model.py +++ b/auto_round/inference/convert_model.py @@ -493,6 +493,10 @@ def post_init(model: torch.nn.Module, used_backends: list[str]) -> None: used_backends (List[str]): List of backend names used for quantization. 
""" + from auto_round.utils.common import monkey_patch_model + + monkey_patch_model(model) + need_autogptq_init = False need_gptqmodel_init = False need_ipex_init = False diff --git a/auto_round/modeling/fused_moe/qwen3_omni.py b/auto_round/modeling/fused_moe/qwen3_omni.py new file mode 100644 index 000000000..ad2004592 --- /dev/null +++ b/auto_round/modeling/fused_moe/qwen3_omni.py @@ -0,0 +1,237 @@ +# Copyright (c) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""MoE module replacements for Qwen3-Omni model. + +Qwen3-Omni has MoE blocks in both thinker and talker: +- Thinker: Qwen3OmniMoeThinkerTextSparseMoeBlock (experts + gate, no shared expert) +- Talker: Qwen3OmniMoeTalkerTextSparseMoeBlock (experts + gate + shared_expert + shared_expert_gate) + +This module provides replacement classes that unfuse fused expert weights (3D Parameters) +into individual nn.Linear layers, enabling per-expert quantization with meta device optimization. +""" + +import torch + +from auto_round.modeling.fused_moe.replace_modules import ReplacementModuleBase +from auto_round.modeling.fused_moe.utils import _update_parameter +from auto_round.utils import clear_memory, unsupported_meta_device + +# --------------------------------------------------------------------------- +# Thinker MoE replacement (no shared expert) +# --------------------------------------------------------------------------- + + +class LinearQwen3OmniThinkerSparseMoeBlock(ReplacementModuleBase): + """Calibration replacement for Qwen3OmniMoeThinkerTextSparseMoeBlock. + + Unfuses fused expert weights into individual nn.Linear layers for + per-expert quantization. Uses meta device to avoid doubling memory. + + Structure: gate (router) + experts (unfused). 
+ """ + + def __init__(self, original, config): + super().__init__(original) + self.gate = original.gate + self.num_experts = original.experts.num_experts + text_config = config.thinker_config.text_config + with torch.device("meta"): + self.experts = SequentialQwen3OmniThinkerExperts(text_config, original.experts) + + @classmethod + def original_module_class(cls) -> str: + return "Qwen3OmniMoeThinkerTextSparseMoeBlock" + + def _materialize_weights(self) -> None: + original = self._get_original_module() + self.experts._materialize_weights(original.experts) + clear_memory() + + def experts_forward(self, hidden_states, top_k_index, top_k_weights): + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + current_hidden_states = self.experts[expert_idx](current_state) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_reshaped = hidden_states.view(-1, hidden_dim) + _, routing_weights, selected_experts = self.gate(hidden_states_reshaped) + expert_output = self.experts_forward(hidden_states_reshaped, selected_experts, routing_weights) + return expert_output.reshape(batch_size, sequence_length, hidden_dim) + + @classmethod + def from_original(cls, original, config, **kwargs): + return cls(original, config) + + +# 
--------------------------------------------------------------------------- +# Talker MoE replacement (with shared expert, same pattern as qwen3_5_moe) +# --------------------------------------------------------------------------- + + +class LinearQwen3OmniTalkerSparseMoeBlock(ReplacementModuleBase): + """Calibration replacement for Qwen3OmniMoeTalkerTextSparseMoeBlock. + + Unfuses fused expert weights and preserves the shared_expert + shared_expert_gate. + Similar to Qwen3.5-MoE pattern. + + Structure: gate (router) + experts (unfused) + shared_expert + shared_expert_gate. + """ + + def __init__(self, original, config): + super().__init__(original) + self.gate = original.gate + self.shared_expert = original.shared_expert + self.shared_expert_gate = original.shared_expert_gate + self.num_experts = original.experts.num_experts + text_config = config.talker_config.text_config + with torch.device("meta"): + self.experts = SequentialQwen3OmniTalkerExperts(text_config, original.experts) + + @classmethod + def original_module_class(cls) -> str: + return "Qwen3OmniMoeTalkerTextSparseMoeBlock" + + def _materialize_weights(self) -> None: + original = self._get_original_module() + self.experts._materialize_weights(original.experts) + clear_memory() + + def experts_forward(self, hidden_states, top_k_index, top_k_weights): + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + current_hidden_states = self.experts[expert_idx](current_state) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + 
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_reshaped = hidden_states.view(-1, hidden_dim) + shared_expert_output = self.shared_expert(hidden_states_reshaped) + _, routing_weights, selected_experts = self.gate(hidden_states_reshaped) + expert_output = self.experts_forward(hidden_states_reshaped, selected_experts, routing_weights) + + shared_expert_output = torch.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output + + expert_output = expert_output + shared_expert_output + return expert_output.reshape(batch_size, sequence_length, hidden_dim) + + @classmethod + def from_original(cls, original, config, **kwargs): + return cls(original, config) + + +# --------------------------------------------------------------------------- +# Sequential expert containers (unfused nn.Linear per expert) +# --------------------------------------------------------------------------- + + +class SequentialQwen3OmniThinkerExperts(torch.nn.ModuleList): + """Unfused per-expert nn.Linear layers for Qwen3-Omni thinker MoE. + + Replaces fused 3D Parameters (gate_up_proj, down_proj) with individual + Qwen3OmniMoeThinkerTextMLP modules per expert. + """ + + def __init__(self, config, original): + super().__init__() + self.num_experts = original.gate_up_proj.shape[0] + intermediate_size = config.moe_intermediate_size + + from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeThinkerTextMLP, + ) + + with torch.device("meta"): + super().__init__([Qwen3OmniMoeThinkerTextMLP(config, intermediate_size) for _ in range(self.num_experts)]) + + def _materialize_weights(self, original) -> None: + """Unfuse fused expert weights into individual nn.Linear layers. 
+ + gate_up_proj shape: (num_experts, 2 * moe_intermediate, hidden) + down_proj shape: (num_experts, hidden, moe_intermediate) + """ + intermediate_size = original.down_proj.shape[-1] + if not unsupported_meta_device(original): + for i in range(self.num_experts): + gate_up = original.gate_up_proj[i] + down = original.down_proj[i] + + gate_proj = gate_up[:intermediate_size, :] + up_proj = gate_up[intermediate_size:, :] + + _update_parameter(self[i].gate_proj, "weight", gate_proj.contiguous()) + _update_parameter(self[i].up_proj, "weight", up_proj.contiguous()) + _update_parameter(self[i].down_proj, "weight", down.contiguous()) + del gate_up, down, gate_proj, up_proj + original.to_empty(device="meta") # release original fused parameters + clear_memory() + + +class SequentialQwen3OmniTalkerExperts(torch.nn.ModuleList): + """Unfused per-expert nn.Linear layers for Qwen3-Omni talker MoE. + + Replaces fused 3D Parameters (gate_up_proj, down_proj) with individual + Qwen3OmniMoeTalkerTextMLP modules per expert. + """ + + def __init__(self, config, original): + super().__init__() + self.num_experts = original.gate_up_proj.shape[0] + intermediate_size = config.moe_intermediate_size + + from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeTalkerTextMLP, + ) + + with torch.device("meta"): + super().__init__([Qwen3OmniMoeTalkerTextMLP(config, intermediate_size) for _ in range(self.num_experts)]) + + def _materialize_weights(self, original) -> None: + """Unfuse fused expert weights into individual nn.Linear layers. 
+ + gate_up_proj shape: (num_experts, 2 * moe_intermediate, hidden) + down_proj shape: (num_experts, hidden, moe_intermediate) + """ + intermediate_size = original.down_proj.shape[-1] + if not unsupported_meta_device(original): + for i in range(self.num_experts): + gate_up = original.gate_up_proj[i] + down = original.down_proj[i] + + gate_proj = gate_up[:intermediate_size, :] + up_proj = gate_up[intermediate_size:, :] + + _update_parameter(self[i].gate_proj, "weight", gate_proj.contiguous()) + _update_parameter(self[i].up_proj, "weight", up_proj.contiguous()) + _update_parameter(self[i].down_proj, "weight", down.contiguous()) + del gate_up, down, gate_proj, up_proj + original.to_empty(device="meta") # release original fused parameters + clear_memory() diff --git a/auto_round/modeling/fused_moe/replace_modules.py b/auto_round/modeling/fused_moe/replace_modules.py index f18cc2151..e7993d109 100644 --- a/auto_round/modeling/fused_moe/replace_modules.py +++ b/auto_round/modeling/fused_moe/replace_modules.py @@ -39,6 +39,8 @@ "qwen3_5_moe_text": LazyImport("auto_round.modeling.fused_moe.qwen3_5_moe"), # Step 3.5 MoE: splits fused MoELinear into per-expert nn.Linear "step3p5": LazyImport("auto_round.modeling.fused_moe.step3_5_moe"), + # Qwen3-Omni MoE: thinker (no shared expert) + talker (with shared expert) + "qwen3_omni_moe": LazyImport("auto_round.modeling.fused_moe.qwen3_omni"), } diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index fa0000bce..5e271f691 100644 --- a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -21,7 +21,14 @@ from auto_round.modeling.fused_moe.replace_modules import apply_replacements, release_original_module_ from auto_round.utils import is_moe_model_via_config, logger -mllms_with_limited_bs = ("llava", "qwen2_vl", "phi3_v", "mllama") # Limitations on batch_size +mllms_with_limited_bs = ( + "llava", + "qwen2_vl", + "phi3_v", + "mllama", + "qwen2_5_omni", + "qwen3_omni_moe", 
def _resolve_attr_path(root, dotted_path):
    """Follow ``dotted_path`` from ``root``; return None as soon as an attribute is missing."""
    current = root
    for attr_name in dotted_path.split("."):
        if not hasattr(current, attr_name):
            return None
        current = getattr(current, attr_name)
    return current


def _collect_omni_block_names(model, quant_vision=False):
    """Collect per-layer block names shared by the Qwen2.5-Omni and Qwen3-Omni layouts.

    Each layer container that exists on ``model`` contributes one list of names of
    the form ``"<path>.<index>"``. Vision and audio encoder containers are only
    considered when ``quant_vision`` is true.
    """
    candidate_paths = []
    if quant_vision:
        # Vision encoder blocks, then audio encoder layers
        candidate_paths += ["thinker.visual.blocks", "thinker.audio_tower.layers"]
    # Thinker text model layers (main LLM decoder), then talker model layers (if available)
    candidate_paths += ["thinker.model.layers", "talker.model.layers"]

    block_names = []
    for path in candidate_paths:
        layers = _resolve_attr_path(model, path)
        if layers is not None:
            block_names.append([f"{path}.{i}" for i in range(len(layers))])
    return block_names


def _get_qwen2_5_omni_multimodal_block(model, quant_vision=False):
    """Get block names for Qwen2.5-Omni model.

    Qwen2.5-Omni has the following structure:
    - thinker: Contains audio_tower, visual, model (text decoder)
    - talker: Contains model (talker decoder)
    - token2wav: Audio decoder

    For quantization, we focus on:
    - thinker.model.layers (text decoder layers) - main LLM layers
    - talker.model.layers (talker decoder layers)
    - Optionally: visual encoder blocks, audio encoder layers

    Args:
        model: The loaded Qwen2.5-Omni model.
        quant_vision: When true, also include visual and audio encoder blocks.

    Returns:
        A list of lists of block names, one inner list per layer container found.
    """
    return _collect_omni_block_names(model, quant_vision)


def _get_qwen3_omni_moe_multimodal_block(model, quant_vision=False):
    """Get block names for Qwen3-Omni MoE model.

    Qwen3-Omni has the following structure:
    - thinker: Contains audio_tower, visual, model (text decoder)
    - talker: Contains model (talker decoder), code_predictor
    - code2wav: Audio decoder

    For quantization, we focus on:
    - thinker.model.layers (text decoder layers) - main LLM layers
    - talker.model.layers (talker decoder layers)
    - Optionally: visual encoder blocks, audio encoder layers

    Args:
        model: The loaded Qwen3-Omni MoE model.
        quant_vision: When true, also include visual and audio encoder blocks.

    Returns:
        A list of lists of block names, one inner list per layer container found.
    """
    return _collect_omni_block_names(model, quant_vision)


def _calibrate_omni_talker(
    model, thinker_output, input_ids, attention_mask, *, projection_attr, hidden_size_getter, model_label
):
    """Best-effort talker forward so that talker layers see calibration data.

    Does nothing unless ``model.has_talker`` is truthy and ``model.talker`` is not
    None. Any exception raised while preparing or running the talker is logged as
    a warning and swallowed, so thinker calibration is never interrupted.

    Args:
        model: The omni model holding ``thinker`` and ``talker``.
        thinker_output: Thinker output; its last hidden state decides whether to run.
        input_ids: Token ids, embedded by the thinker and projected for the talker.
        attention_mask: Attention mask forwarded to the talker model.
        projection_attr: Name of the talker attribute that projects thinker embeddings.
        hidden_size_getter: Callable ``model -> int`` used for the zero-embedding fallback.
        model_label: Model name used in the warning message.
    """
    if not getattr(model, "has_talker", False) or getattr(model, "talker", None) is None:
        return

    try:
        # Use the last hidden state from thinker to decide whether the talker can run
        thinker_hidden = thinker_output.hidden_states[-1] if thinker_output.hidden_states else None
        if thinker_hidden is None:
            return

        talker = model.talker
        if hasattr(talker, projection_attr):
            # Project thinker embeddings into talker space, matching the projection dtype
            projection = getattr(talker, projection_attr)
            thinker_embeds = model.thinker.get_input_embeddings()(input_ids)
            proj_dtype = next(projection.parameters()).dtype
            talker_inputs_embeds = projection(thinker_embeds.to(proj_dtype))
        else:
            # Fallback: create zero embeddings of correct size
            batch_size, seq_len, _ = thinker_hidden.shape
            talker_inputs_embeds = torch.zeros(
                batch_size,
                seq_len,
                hidden_size_getter(model),
                device=thinker_hidden.device,
                dtype=thinker_hidden.dtype,
            )

        # Align dtype with talker model weights
        talker_dtype = next(talker.model.parameters()).dtype
        _ = talker.model(
            inputs_embeds=talker_inputs_embeds.to(talker_dtype),
            attention_mask=attention_mask,
            use_cache=False,
        )
    except Exception as exc:
        # Log talker forward errors during calibration without interrupting thinker quantization
        logger.warning(
            "%s talker forward failed during calibration; " "continuing with thinker-only quantization. Error: %s",
            model_label,
            exc,
            exc_info=True,
        )


def _qwen2_5_omni_forward(
    model,
    input_ids=None,
    input_features=None,
    pixel_values=None,
    pixel_values_videos=None,
    image_grid_thw=None,
    video_grid_thw=None,
    attention_mask=None,
    feature_attention_mask=None,
    audio_feature_lengths=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    rope_deltas=None,
    labels=None,
    use_cache=None,
    use_audio_in_video=None,
    cache_position=None,
    video_second_per_grid=None,
    **kwargs,
):
    """Forward function for Qwen2.5-Omni model.

    This delegates to the thinker module for calibration, then optionally
    runs a forward through the talker to ensure its layers are also calibrated.

    ``output_hidden_states`` and ``return_dict`` are always forced to True for the
    thinker call, overriding any value supplied through ``kwargs``.

    Returns:
        The thinker output; the talker result is discarded.
    """
    # Force the flags through kwargs so a caller-supplied value cannot collide
    # with an explicit keyword ("got multiple values for keyword argument").
    kwargs["output_hidden_states"] = True
    kwargs["return_dict"] = True
    thinker_output = model.thinker(
        input_ids=input_ids,
        input_features=input_features,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        attention_mask=attention_mask,
        feature_attention_mask=feature_attention_mask,
        audio_feature_lengths=audio_feature_lengths,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        rope_deltas=rope_deltas,
        labels=labels,
        use_cache=use_cache,
        use_audio_in_video=use_audio_in_video,
        cache_position=cache_position,
        video_second_per_grid=video_second_per_grid,
        **kwargs,
    )

    # Run talker forward if available (for calibration purposes)
    _calibrate_omni_talker(
        model,
        thinker_output,
        input_ids,
        attention_mask,
        projection_attr="thinker_to_talker_proj",
        hidden_size_getter=lambda m: m.talker.model.config.hidden_size,
        model_label="Qwen2.5-Omni",
    )

    return thinker_output


def _qwen3_omni_moe_forward(
    model,
    input_ids=None,
    input_features=None,
    pixel_values=None,
    pixel_values_videos=None,
    image_grid_thw=None,
    video_grid_thw=None,
    attention_mask=None,
    feature_attention_mask=None,
    audio_feature_lengths=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    rope_deltas=None,
    labels=None,
    use_cache=None,
    output_router_logits=None,
    use_audio_in_video=None,
    cache_position=None,
    video_second_per_grid=None,
    **kwargs,
):
    """Forward function for Qwen3-Omni-MoE model.

    This runs forward through both thinker and talker modules for calibration.
    The thinker processes text/vision/audio input, and the talker is then run on
    projected thinker input embeddings (or zero embeddings as a fallback).

    ``output_hidden_states`` and ``return_dict`` are always forced to True for the
    thinker call, overriding any value supplied through ``kwargs``.

    Returns:
        The thinker output; the talker result is discarded.
    """
    # Force the flags through kwargs so a caller-supplied value cannot collide
    # with an explicit keyword ("got multiple values for keyword argument").
    kwargs["output_hidden_states"] = True
    kwargs["return_dict"] = True
    thinker_output = model.thinker(
        input_ids=input_ids,
        input_features=input_features,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        attention_mask=attention_mask,
        feature_attention_mask=feature_attention_mask,
        audio_feature_lengths=audio_feature_lengths,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        rope_deltas=rope_deltas,
        labels=labels,
        use_cache=use_cache,
        output_router_logits=output_router_logits,
        use_audio_in_video=use_audio_in_video,
        cache_position=cache_position,
        video_second_per_grid=video_second_per_grid,
        **kwargs,
    )

    # Run talker forward if available (for calibration purposes)
    _calibrate_omni_talker(
        model,
        thinker_output,
        input_ids,
        attention_mask,
        projection_attr="text_projection",
        hidden_size_getter=lambda m: m.config.talker_config.text_config.hidden_size,
        model_label="Qwen3-Omni-MoE",
    )

    return thinker_output
def normalize_no_split_modules(no_split_modules):
    """Flatten a possibly nested ``_no_split_modules`` value into a flat list.

    Lists, tuples and sets are expanded recursively; every other value (for
    example a string) is kept as a single item. ``None`` items are dropped.

    Args:
        no_split_modules: The raw value, possibly nested, empty or None.

    Returns:
        A new flat list; empty when the input is falsy.
    """
    if not no_split_modules:
        return []

    def _flatten(value):
        if isinstance(value, (list, tuple, set)):
            for item in value:
                yield from _flatten(item)
        else:
            yield value

    return [item for item in _flatten(no_split_modules) if item is not None]


def monkey_patch_model(model) -> None:
    """Apply model-instance-level monkey patches after a model is loaded.

    This is the central place for all instance-level patches (as opposed to the
    class-level patches in ``monkey_patch_transformers``).
    """
    _patch_prepare_inputs_for_generation(model)


def _patch_prepare_inputs_for_generation(model) -> None:
    """Fix positional-arg mismatch in models whose prepare_inputs_for_generation
    passes arguments positionally to GenerationMixin.prepare_inputs_for_generation.

    transformers >= 5.1 inserted ``next_sequence_length`` as the 2nd positional
    parameter in the base ``GenerationMixin.prepare_inputs_for_generation``.
    Some model implementations (e.g. Qwen2.5-Omni / Qwen3-Omni MoE talker)
    still pass ``past_key_values`` etc. positionally, causing a
    "got multiple values for argument 'next_sequence_length'" TypeError.

    This function monkey-patches the affected sub-model to use keyword arguments
    instead, which works with both old and new transformers. Models whose
    ``config.model_type`` is neither ``qwen2_5_omni`` nor ``qwen3_omni_moe``
    (or that have no config) are left untouched.
    """
    model_type = getattr(getattr(model, "config", None), "model_type", None)

    if model_type == "qwen2_5_omni":
        _patch_qwen25_omni_talker(model)
    elif model_type == "qwen3_omni_moe":
        _patch_qwen3_omni_moe_talker(model)


def _patch_qwen25_omni_talker(model) -> None:
    """Patch Qwen2.5-Omni talker prepare_inputs_for_generation.

    Does nothing when ``model`` has no ``talker`` attribute or it is None.
    """
    talker = getattr(model, "talker", None)
    if talker is None:
        return

    import types

    def _fixed_prepare_inputs_for_generation(
        self,
        input_ids,
        input_text_ids=None,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        thinker_reply_part=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        input_audio_features=None,
        audio_feature_attention_mask=None,
        audio_feature_lengths=None,
        use_audio_in_video=None,
        video_second_per_grid=None,
        **kwargs,
    ):
        from transformers.generation.utils import GenerationMixin

        # Everything after input_ids goes by keyword so it cannot land in the
        # next_sequence_length slot of newer transformers.
        model_inputs = GenerationMixin.prepare_inputs_for_generation(
            self,
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            use_cache=use_cache,
            thinker_reply_part=thinker_reply_part,
            input_text_ids=input_text_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            input_audio_features=input_audio_features,
            audio_feature_attention_mask=audio_feature_attention_mask,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
            video_second_per_grid=video_second_per_grid,
            **kwargs,
        )
        model_inputs["position_ids"] = None
        return model_inputs

    talker.prepare_inputs_for_generation = types.MethodType(_fixed_prepare_inputs_for_generation, talker)
    from auto_round.logger import logger

    logger.info("Patched Qwen2.5-Omni talker prepare_inputs_for_generation for transformers compat.")


def _patch_qwen3_omni_moe_talker(model) -> None:
    """Patch Qwen3-Omni MoE talker prepare_inputs_for_generation.

    The talker passes past_key_values, attention_mask, inputs_embeds positionally
    to super().prepare_inputs_for_generation(), colliding with the new
    next_sequence_length parameter in transformers >= 5.1.

    Does nothing when ``model`` has no ``talker`` attribute or it is None.
    """
    talker = getattr(model, "talker", None)
    if talker is None:
        return

    import types

    def _fixed_prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        is_first_iteration=False,
        **kwargs,
    ):
        from transformers.generation.utils import GenerationMixin

        # hidden_states is consumed here and must not reach the base implementation.
        hidden_states = kwargs.pop("hidden_states", None)
        inputs = GenerationMixin.prepare_inputs_for_generation(
            self,
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        # Qwen3-Omni will prepare position ids in forward with deltas
        inputs["position_ids"] = None

        # Reproduce the talker's codec logic for non-first iterations
        if not is_first_iteration and kwargs.get("use_cache", True):
            import torch

            input_ids_last = input_ids[:, -1:]
            generation_step = kwargs.get("generation_step")
            trailing_text_hidden = kwargs.get("trailing_text_hidden")
            tts_pad_embed = kwargs.get("tts_pad_embed")
            last_id_hidden = self.get_input_embeddings()(input_ids_last)

            past_hidden = hidden_states[0][-1][:, -1:].to(last_id_hidden.device)
            predictor_result = self.code_predictor.generate(
                inputs_embeds=torch.cat((past_hidden, last_id_hidden), dim=1),
                max_new_tokens=self.config.num_code_groups - 1,
                do_sample=True,
                top_k=50,
                top_p=0.8,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            residual_codes = torch.cat((input_ids_last, predictor_result.sequences.to(input_ids_last.device)), dim=-1)

            mid_residual_hiddens = [hid[0].to(last_id_hidden.device) for hid in predictor_result.hidden_states[1:]]
            last_residual_hidden = self.code_predictor.get_input_embeddings()[-1](
                predictor_result.sequences[..., -1:]
            ).to(last_id_hidden.device)
            codec_hiddens = torch.cat(
                [last_id_hidden] + mid_residual_hiddens + [last_residual_hidden],
                dim=1,
            )
            inputs_embeds_new = codec_hiddens.sum(1, keepdim=True)

            # Add the text hidden state for this step, or the TTS pad embedding once
            # the trailing text is exhausted.
            if generation_step < trailing_text_hidden.shape[1]:
                inputs_embeds_new = inputs_embeds_new + trailing_text_hidden[:, generation_step].unsqueeze(1).to(
                    inputs_embeds_new.device
                )
            else:
                inputs_embeds_new = inputs_embeds_new + tts_pad_embed.to(inputs_embeds_new.device)
            inputs["inputs_embeds"] = inputs_embeds_new
            inputs["residual_codes"] = residual_codes

        return inputs

    talker.prepare_inputs_for_generation = types.MethodType(_fixed_prepare_inputs_for_generation, talker)
    from auto_round.logger import logger

    logger.info("Patched Qwen3-Omni MoE talker prepare_inputs_for_generation for transformers compat.")
" + "Please upgrade: pip install transformers>=4.52.0" + ) + + if model_type == "qwen3_omni_moe": + if version.parse(transformers.__version__) < version.parse("5.1.0"): + raise RuntimeError( + f"Qwen3-Omni requires transformers >= 5.1.0, but found {transformers.__version__}. " + "Please upgrade: pip install transformers>=5.1.0" + ) + processor, image_processor = None, None if "deepseek_vl_v2" == model_type: from deepseek_vl2.models import DeepseekVLV2ForCausalLM, DeepseekVLV2Processor # pylint: disable=E0401 @@ -712,6 +726,8 @@ def is_moe_layer(module: torch.nn.Module) -> bool: "Qwen2MoeSparseMoeBlock".lower(), "Qwen3MoeSparseMoeBlock".lower(), "Qwen3VLMoeTextSparseMoeBlock".lower(), + "Qwen3OmniMoeThinkerTextSparseMoeBlock".lower(), + "Qwen3OmniMoeTalkerTextSparseMoeBlock".lower(), ] ) @@ -811,6 +827,8 @@ def module_match_name_list(module, name_list): "DeepseekV2MoE", "DeepseekV3MoE", "Qwen3VLMoeTextSparseMoeBlock", + "Qwen3OmniMoeThinkerTextSparseMoeBlock", + "Qwen3OmniMoeTalkerTextSparseMoeBlock", ], ): return ["gate_proj", "down_proj", "up_proj"] @@ -846,6 +864,8 @@ def module_match_name_list(module, name_list): "Qwen2MoeSparseMoeBlock", "Qwen3MoeSparseMoeBlock", "Qwen3VLMoeTextSparseMoeBlock", + "Qwen3OmniMoeThinkerTextSparseMoeBlock", + "Qwen3OmniMoeTalkerTextSparseMoeBlock", "DeepseekMoE", "DeepseekV2MoE", "DeepseekV3MoE", @@ -1525,6 +1545,28 @@ def _get_reference_amax_from_experts(moe_module: torch.nn.Module, attr_name: str return torch.max(all_values) +# Extra non-weight files that some models require at load time but are not saved +# by model.save_pretrained(). These are copied from the source model cache to +# the quantized output directory so that from_pretrained() works out of the box. 
+_EXTRA_MODEL_FILES = { + "spk_dict.pt", # Qwen2.5-Omni speaker dictionary for audio output +} + + +def _copy_extra_model_files(src_dir: str, dst_dir: str): + """Copy known extra model files from *src_dir* to *dst_dir* if they exist.""" + import os + import shutil + + for file in os.listdir(src_dir): + if file in _EXTRA_MODEL_FILES: + src_file = os.path.join(src_dir, file) + dst_file = os.path.join(dst_dir, file) + if os.path.isfile(src_file) and not os.path.exists(dst_file): + logger.debug(f"Transferring extra model file {src_file} to {dst_dir}") + shutil.copy(src_file, dst_dir) + + # Adapted from https://github.com/vllm-project/llm-compressor/blob/ # 5b3ddff74cae9651f24bef15d3255c4ee053fc60/src/llmcompressor/pytorch/model_load/helpers.py#L144 def copy_python_files_from_model_cache(model, save_path: str): @@ -1563,6 +1605,8 @@ def copy_python_files_from_model_cache(model, save_path: str): logger.debug(f"Transferring {full_file_name} to {save_path}") shutil.copy(full_file_name, save_path) + _copy_extra_model_files(cache_path, save_path) + def extract_block_names_to_str(quant_block_list): if not isinstance(quant_block_list, (list, tuple)): diff --git a/pyproject.toml b/pyproject.toml index 014a453f6..9249c6ad4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ extend-exclude = [ [tool.typos.default.extend-words] ue = "ue" endianess = "endianess" +thw = "thw" [tool.ruff] # Exclude a variety of commonly ignored directories. 
diff --git a/test/fixtures.py b/test/fixtures.py index 4ec157b1f..624ee1a19 100644 --- a/test/fixtures.py +++ b/test/fixtures.py @@ -15,6 +15,8 @@ lamini_name_or_path, opt_name_or_path, phi2_name_or_path, + qwen2_5_omni_name_or_path, + qwen3_omni_name_or_path, qwen_2_5_vl_name_or_path, qwen_moe_name_or_path, qwen_name_or_path, @@ -123,6 +125,75 @@ def tiny_qwen_2_5_vl_model_path(): shutil.rmtree(tiny_model_path, ignore_errors=True) +@pytest.fixture(scope="session") +def tiny_qwen2_5_omni(): + """Tiny Qwen2.5-Omni-3B model built from real config with reduced layers. + + Uses random weights (no checkpoint loading) so it is fast for CPU unit + tests while still exercising the real config structure. + Skipped automatically when the model path does not exist locally. + """ + + from transformers import AutoConfig, AutoProcessor, AutoTokenizer, Qwen2_5OmniForConditionalGeneration + + model_name = qwen2_5_omni_name_or_path + if not os.path.isdir(model_name): + pytest.skip(f"Qwen2.5-Omni-3B not found at {model_name}") + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + + # Reduce layers — keeps real config structure but uses random weights + config.thinker_config.text_config.num_hidden_layers = 1 + config.thinker_config.vision_config.depth = 1 + config.thinker_config.audio_config.num_hidden_layers = 1 + config.talker_config.num_hidden_layers = 1 + if hasattr(config.thinker_config.text_config, "layer_types"): + config.thinker_config.text_config.layer_types = config.thinker_config.text_config.layer_types[:1] + if hasattr(config.talker_config, "layer_types"): + config.talker_config.layer_types = config.talker_config.layer_types[:1] + + model = Qwen2_5OmniForConditionalGeneration(config) + yield model, tokenizer, processor + + +@pytest.fixture(scope="session") +def tiny_qwen3_omni_moe(): + 
"""Tiny Qwen3-Omni-MoE model built from real config with reduced layers. + + Uses random weights (no checkpoint loading) so it is fast for CI while + still exercising the real config structure. + Skipped automatically when the model path does not exist locally. + """ + + from transformers import AutoConfig, AutoProcessor, AutoTokenizer, Qwen3OmniMoeForConditionalGeneration + + model_name = qwen3_omni_name_or_path + if not os.path.isdir(model_name): + pytest.skip(f"Qwen3-Omni-30B-A3B-Instruct not found at {model_name}") + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + + # Reduce layers — keeps real config structure but uses random weights + config.thinker_config.text_config.num_hidden_layers = 1 + config.thinker_config.vision_config.depth = 1 + config.thinker_config.audio_config.num_hidden_layers = 1 + if hasattr(config.thinker_config.text_config, "layer_types"): + config.thinker_config.text_config.layer_types = config.thinker_config.text_config.layer_types[:1] + # Talker + if hasattr(config, "talker_config"): + if hasattr(config.talker_config, "text_config"): + config.talker_config.text_config.num_hidden_layers = 1 + elif hasattr(config.talker_config, "num_hidden_layers"): + config.talker_config.num_hidden_layers = 1 + + model = Qwen3OmniMoeForConditionalGeneration(config) + yield model, tokenizer, processor + + # Mock torch.cuda.get_device_capability to always return (9, 0) like H100 @pytest.fixture() def mock_fp8_capable_device(): diff --git a/test/helpers.py b/test/helpers.py index 240791a6a..8a5d24649 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -98,6 +98,8 @@ def get_model_path(model_name: str) -> str: qwen_vl_name_or_path = get_model_path("Qwen/Qwen2-VL-2B-Instruct") qwen_2_5_vl_name_or_path = get_model_path("Qwen/Qwen2.5-VL-3B-Instruct") gemma_name_or_path = 
get_model_path("benzart/gemma-2b-it-fine-tuning-for-code-test") +qwen2_5_omni_name_or_path = get_model_path("Qwen/Qwen2.5-Omni-3B") +qwen3_omni_name_or_path = get_model_path("Qwen/Qwen3-Omni-30B-A3B-Instruct") # Slice model into tiny model for speedup diff --git a/test/test_cpu/models/test_omni_model.py b/test/test_cpu/models/test_omni_model.py new file mode 100644 index 000000000..8cd83b8bf --- /dev/null +++ b/test/test_cpu/models/test_omni_model.py @@ -0,0 +1,399 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Qwen2.5-Omni and Qwen3-Omni-MoE model support. + +Tests cover: +- Block name discovery (thinker + talker layers) +- MoE module replacement and weight fidelity (Qwen3-Omni) +- Forward function patching (_handle_special_model) +- Processor and template registration +- MoE utility functions (is_moe_layer, get_expert_linear_names, etc.) 
+- Ignore layer registration (mlp.gate for MoE) +""" + +import copy +import shutil + +import pytest +import torch +from transformers import ( + Qwen3OmniMoeConfig, + Qwen3OmniMoeForConditionalGeneration, +) + +from ...helpers import check_version, transformers_version + +pytestmark = pytest.mark.skipif( + not check_version("transformers>=5.1.0"), + reason="Qwen-Omni models require transformers >= 5.1.0", +) + + +# --------------------------------------------------------------------------- +# Helper: create tiny Qwen3-Omni-MoE config (no real checkpoint needed for MoE) +# --------------------------------------------------------------------------- +def _make_tiny_qwen3_omni_moe_config(): + config = Qwen3OmniMoeConfig() + # Thinker + config.thinker_config.text_config.num_hidden_layers = 1 + config.thinker_config.text_config.hidden_size = 64 + config.thinker_config.text_config.intermediate_size = 128 + config.thinker_config.text_config.moe_intermediate_size = 32 + config.thinker_config.text_config.num_attention_heads = 4 + config.thinker_config.text_config.num_key_value_heads = 2 + config.thinker_config.text_config.num_experts = 4 + config.thinker_config.text_config.num_experts_per_tok = 2 + config.thinker_config.vision_config.depth = 1 + config.thinker_config.vision_config.embed_dim = 64 + config.thinker_config.vision_config.hidden_size = 64 + config.thinker_config.vision_config.num_heads = 4 + config.thinker_config.audio_config.num_hidden_layers = 1 + # Talker + config.talker_config.text_config.num_hidden_layers = 1 + config.talker_config.text_config.hidden_size = 64 + config.talker_config.text_config.intermediate_size = 128 + config.talker_config.text_config.moe_intermediate_size = 32 + config.talker_config.text_config.num_attention_heads = 4 + config.talker_config.text_config.num_key_value_heads = 2 + config.talker_config.text_config.num_experts = 4 + config.talker_config.text_config.num_local_experts = 4 + config.talker_config.text_config.num_experts_per_tok = 2 + 
config.talker_config.text_config.shared_expert_intermediate_size = 64 + config.talker_config.thinker_hidden_size = 64 + config.talker_config.spatial_merge_size = 2 + # Code2wav (minimal) + config.code2wav_config.hidden_size = 64 + config.code2wav_config.num_hidden_layers = 1 + config.code2wav_config.num_attention_heads = 4 + config.code2wav_config.num_key_value_heads = 4 + config.code2wav_config.intermediate_size = 128 + return config + + +# ========================= Qwen2.5-Omni Tests ============================= +# NOTE: Tests use the `tiny_qwen2_5_omni` session-scoped fixture from fixtures.py +# (real config, reduced layers, random weights). Skipped if model not available. + + +class TestQwen2_5OmniBlockNames: + """Test block name discovery for Qwen2.5-Omni (dense, not MoE).""" + + def test_block_names_default(self, tiny_qwen2_5_omni): + """Test that get_block_names returns thinker + talker layers.""" + from auto_round.utils import get_block_names + + model, _, _ = tiny_qwen2_5_omni + block_names = get_block_names(model, quant_vision=False) + # Should have thinker.model.layers and talker.model.layers + assert any( + "thinker.model.layers" in str(b) for b in block_names + ), f"Expected thinker.model.layers in block_names, got: {block_names}" + assert any( + "talker.model.layers" in str(b) for b in block_names + ), f"Expected talker.model.layers in block_names, got: {block_names}" + + def test_block_names_quant_vision(self, tiny_qwen2_5_omni): + """Test that quant_vision adds visual and audio blocks.""" + from auto_round.utils import get_block_names + + model, _, _ = tiny_qwen2_5_omni + blocks_no_vision = get_block_names(model, quant_vision=False) + blocks_with_vision = get_block_names(model, quant_vision=True) + + assert len(blocks_with_vision) > len(blocks_no_vision), "quant_vision=True should add visual/audio blocks" + + +class TestQwen2_5OmniForward: + """Test forward function patching for Qwen2.5-Omni.""" + + def test_handle_special_model(self, 
tiny_qwen2_5_omni): + """Test that _handle_special_model patches the forward function.""" + from auto_round.special_model_handler import _handle_special_model + + # Deepcopy to avoid mutating the shared session-scoped fixture + model = copy.deepcopy(tiny_qwen2_5_omni[0]) + original_forward = model.forward + model = _handle_special_model(model) + assert model.forward != original_forward, "Forward should be patched for qwen2_5_omni" + + +class TestQwen2_5OmniProcessor: + """Test processor and template registration for Qwen2.5-Omni.""" + + def test_processor_registered(self): + from auto_round.compressors.mllm.processor import PROCESSORS + + assert "qwen2_5_omni" in PROCESSORS, "qwen2_5_omni processor not registered" + + def test_template_registered(self): + from auto_round.compressors.mllm.template import TEMPLATES + + assert "qwen2_5_omni" in TEMPLATES, "qwen2_5_omni template not registered" + + def test_template_default_dataset(self): + from auto_round.compressors.mllm.template import TEMPLATES + + template = TEMPLATES["qwen2_5_omni"] + assert template.default_dataset is not None + + +# ========================= Qwen3-Omni-MoE Tests =========================== + + +class TestQwen3OmniMoeBlockNames: + """Test block name discovery for Qwen3-Omni-MoE.""" + + def test_block_names_default(self): + """Test that get_block_names returns thinker + talker layers.""" + from auto_round.utils import get_block_names + + config = _make_tiny_qwen3_omni_moe_config() + model = Qwen3OmniMoeForConditionalGeneration(config) + + block_names = get_block_names(model, quant_vision=False) + assert any( + "thinker.model.layers" in str(b) for b in block_names + ), f"Expected thinker.model.layers, got: {block_names}" + assert any( + "talker.model.layers" in str(b) for b in block_names + ), f"Expected talker.model.layers, got: {block_names}" + + def test_block_names_quant_vision(self): + """Test that quant_vision adds visual and audio blocks.""" + from auto_round.utils import get_block_names + 
+ config = _make_tiny_qwen3_omni_moe_config() + model = Qwen3OmniMoeForConditionalGeneration(config) + + blocks_no_vision = get_block_names(model, quant_vision=False) + blocks_with_vision = get_block_names(model, quant_vision=True) + + assert len(blocks_with_vision) > len(blocks_no_vision), "quant_vision=True should add visual/audio blocks" + + +class TestQwen3OmniMoeForward: + """Test forward function patching for Qwen3-Omni-MoE.""" + + def test_handle_special_model(self): + """Test that _handle_special_model patches the forward function.""" + from auto_round.special_model_handler import _handle_special_model + + config = _make_tiny_qwen3_omni_moe_config() + model = Qwen3OmniMoeForConditionalGeneration(config) + + original_forward = model.forward + model = _handle_special_model(model) + assert model.forward != original_forward, "Forward should be patched for qwen3_omni_moe" + + +class TestQwen3OmniMoeReplacement: + """Test MoE module replacement for Qwen3-Omni-MoE.""" + + def test_replacement_registered(self): + """Test that both thinker and talker MoE blocks are registered.""" + from auto_round.modeling.fused_moe.qwen3_omni import ( + LinearQwen3OmniTalkerSparseMoeBlock, + LinearQwen3OmniThinkerSparseMoeBlock, + ) + from auto_round.modeling.fused_moe.replace_modules import ReplacementModuleBase + + assert ReplacementModuleBase.is_registered("Qwen3OmniMoeThinkerTextSparseMoeBlock") + assert ReplacementModuleBase.is_registered("Qwen3OmniMoeTalkerTextSparseMoeBlock") + + def test_builtin_modules_entry(self): + """Test that qwen3_omni_moe is in BUILTIN_MODULES.""" + from auto_round.modeling.fused_moe.replace_modules import BUILTIN_MODULES + + assert "qwen3_omni_moe" in BUILTIN_MODULES + + def test_is_custom_model(self): + """Test that is_custom_model returns True for Qwen3-Omni-MoE.""" + from auto_round.modeling.fused_moe.replace_modules import is_custom_model + + config = _make_tiny_qwen3_omni_moe_config() + model = Qwen3OmniMoeForConditionalGeneration(config) + 
assert is_custom_model(model) + + def test_apply_replacements(self): + """Test that MoE blocks are correctly replaced.""" + from auto_round.modeling.fused_moe.replace_modules import apply_replacements + + config = _make_tiny_qwen3_omni_moe_config() + model = Qwen3OmniMoeForConditionalGeneration(config) + + model = apply_replacements(model) + + # Check that thinker MoE was replaced + thinker_mlp = model.thinker.model.layers[0].mlp + assert ( + "LinearQwen3OmniThinker" in thinker_mlp.__class__.__name__ + ), f"Expected LinearQwen3OmniThinker, got {thinker_mlp.__class__.__name__}" + + # Check that talker MoE was replaced + talker_mlp = model.talker.model.layers[0].mlp + assert ( + "LinearQwen3OmniTalker" in talker_mlp.__class__.__name__ + ), f"Expected LinearQwen3OmniTalker, got {talker_mlp.__class__.__name__}" + + def test_weight_fidelity(self): + """Test that unfused weights match original fused weights.""" + from auto_round.modeling.fused_moe.replace_modules import apply_replacements, materialize_model_ + + config = _make_tiny_qwen3_omni_moe_config() + model = Qwen3OmniMoeForConditionalGeneration(config) + + # Save original fused weights + thinker_gate_up = model.thinker.model.layers[0].mlp.experts.gate_up_proj.data.clone() + thinker_down = model.thinker.model.layers[0].mlp.experts.down_proj.data.clone() + talker_gate_up = model.talker.model.layers[0].mlp.experts.gate_up_proj.data.clone() + talker_down = model.talker.model.layers[0].mlp.experts.down_proj.data.clone() + + model = apply_replacements(model) + materialize_model_(model) + + intermediate = 32 # moe_intermediate_size + # Verify thinker expert weights + for i in range(4): + expert = model.thinker.model.layers[0].mlp.experts[i] + assert torch.allclose(expert.gate_proj.weight.data, thinker_gate_up[i, :intermediate, :]) + assert torch.allclose(expert.up_proj.weight.data, thinker_gate_up[i, intermediate:, :]) + assert torch.allclose(expert.down_proj.weight.data, thinker_down[i]) + + # Verify talker expert 
weights + for i in range(4): + expert = model.talker.model.layers[0].mlp.experts[i] + assert torch.allclose(expert.gate_proj.weight.data, talker_gate_up[i, :intermediate, :]) + assert torch.allclose(expert.up_proj.weight.data, talker_gate_up[i, intermediate:, :]) + assert torch.allclose(expert.down_proj.weight.data, talker_down[i]) + + def test_forward_output_match(self): + """Test that replaced MoE forward output matches original.""" + from auto_round.modeling.fused_moe.replace_modules import apply_replacements, materialize_model_ + + config = _make_tiny_qwen3_omni_moe_config() + model = Qwen3OmniMoeForConditionalGeneration(config) + + x = torch.randn(1, 4, 64) + with torch.no_grad(): + orig_thinker_out = model.thinker.model.layers[0].mlp(x) + orig_talker_out = model.talker.model.layers[0].mlp(x) + + model = apply_replacements(model) + materialize_model_(model) + + with torch.no_grad(): + new_thinker_out = model.thinker.model.layers[0].mlp(x) + new_talker_out = model.talker.model.layers[0].mlp(x) + + assert torch.allclose(orig_thinker_out, new_thinker_out, atol=1e-5), "Thinker MoE forward mismatch" + assert torch.allclose(orig_talker_out, new_talker_out, atol=1e-5), "Talker MoE forward mismatch" + + +class TestQwen3OmniMoeProcessor: + """Test processor and template registration for Qwen3-Omni-MoE.""" + + def test_processor_registered(self): + from auto_round.compressors.mllm.processor import PROCESSORS + + assert "qwen3_omni" in PROCESSORS, "qwen3_omni processor not registered" + + def test_template_registered(self): + from auto_round.compressors.mllm.template import TEMPLATES + + assert "qwen3_omni_moe" in TEMPLATES, "qwen3_omni_moe template not registered" + + +class TestQwen3OmniMoeUtils: + """Test MoE utility functions for Qwen3-Omni-MoE classes.""" + + def test_is_moe_layer_thinker(self): + from auto_round.utils.model import is_moe_layer + + config = _make_tiny_qwen3_omni_moe_config() + model = Qwen3OmniMoeForConditionalGeneration(config) + moe_block = 
model.thinker.model.layers[0].mlp + assert is_moe_layer(moe_block), f"Thinker MoE block ({moe_block.__class__.__name__}) should be detected as MoE" + + def test_is_moe_layer_talker(self): + from auto_round.utils.model import is_moe_layer + + config = _make_tiny_qwen3_omni_moe_config() + model = Qwen3OmniMoeForConditionalGeneration(config) + moe_block = model.talker.model.layers[0].mlp + assert is_moe_layer(moe_block), f"Talker MoE block ({moe_block.__class__.__name__}) should be detected as MoE" + + def test_get_expert_linear_names(self): + from auto_round.utils.model import get_expert_linear_names + + config = _make_tiny_qwen3_omni_moe_config() + model = Qwen3OmniMoeForConditionalGeneration(config) + + thinker_mlp = model.thinker.model.layers[0].mlp + names = get_expert_linear_names(thinker_mlp) + assert set(names) == {"gate_proj", "down_proj", "up_proj"} + + talker_mlp = model.talker.model.layers[0].mlp + names = get_expert_linear_names(talker_mlp) + assert set(names) == {"gate_proj", "down_proj", "up_proj"} + + def test_get_expert_input_proj_names(self): + from auto_round.utils.model import get_expert_input_proj_names + + config = _make_tiny_qwen3_omni_moe_config() + model = Qwen3OmniMoeForConditionalGeneration(config) + + thinker_mlp = model.thinker.model.layers[0].mlp + names = get_expert_input_proj_names(thinker_mlp) + assert set(names) == {"gate_proj", "up_proj"} + + def test_ignore_layers_registered(self): + from auto_round.special_model_handler import get_predefined_ignore_layers + + config = _make_tiny_qwen3_omni_moe_config() + model = Qwen3OmniMoeForConditionalGeneration(config) + + ignore_layers = get_predefined_ignore_layers(model) + assert ( + "mlp.gate" in ignore_layers + ), f"Expected mlp.gate in ignore_layers for qwen3_omni_moe, got: {ignore_layers}" + + +class TestQwen2_5OmniNotMoe: + """Verify Qwen2.5-Omni is not treated as MoE.""" + + def test_not_moe(self, tiny_qwen2_5_omni): + from auto_round.utils.model import is_moe_layer + + model, _, _ = 
tiny_qwen2_5_omni + # Thinker uses dense MLP, not MoE + thinker_mlp = model.thinker.model.layers[0].mlp + assert not is_moe_layer(thinker_mlp), "Qwen2.5-Omni should not be detected as MoE" + + def test_not_custom_model(self, tiny_qwen2_5_omni): + from auto_round.modeling.fused_moe.replace_modules import is_custom_model + + model, _, _ = tiny_qwen2_5_omni + assert not is_custom_model(model), "Qwen2.5-Omni should not be in BUILTIN_MODULES" + + +class TestVisualKeysExclusion: + """Test that omni sub-modules are properly excluded from quantization.""" + + def test_visual_keys_contain_omni_keys(self): + from auto_round.compressors.mllm.utils import VISUAL_KEYS + + expected_keys = ["thinker", "talker", "audio", "token2wav", "code2wav", "audio_tower", "code_predictor"] + for key in expected_keys: + assert key in VISUAL_KEYS, f"'{key}' should be in VISUAL_KEYS" diff --git a/test/test_cuda/models/test_omni_model.py b/test/test_cuda/models/test_omni_model.py new file mode 100644 index 000000000..3b31de75f --- /dev/null +++ b/test/test_cuda/models/test_omni_model.py @@ -0,0 +1,176 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CUDA integration tests for Qwen2.5-Omni and Qwen3-Omni-MoE quantization. 
+ +Tests cover end-to-end quantization flow: +- Qwen2.5-Omni: loads real config from pretrained, reduces layers, random weights + (uses the shared ``tiny_qwen2_5_omni`` session-scoped fixture) +- Qwen3-Omni-MoE: fully synthetic tiny config with random weights + (no pretrained checkpoint needed — model is too large for CI) +- Quantize with AutoRound +- Save and reload +- Run inference on reloaded model +""" + +import os +import shutil + +import pytest +import torch +from transformers import ( + AutoTokenizer, + Qwen2_5OmniForConditionalGeneration, + Qwen3OmniMoeConfig, + Qwen3OmniMoeForConditionalGeneration, +) + +from auto_round import AutoRound + +from ...helpers import check_version, qwen2_5_omni_name_or_path + +pytestmark = [ + pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available"), + pytest.mark.skipif( + not check_version("transformers>=5.1.0"), + reason="Qwen-Omni models require transformers >= 5.1.0", + ), +] + + +# --------------------------------------------------------------------------- +# Fixture: tiny Qwen3-Omni-MoE +# Priority: use real config from qwen3_omni_name_or_path (skipped if absent); +# fall back to fully synthetic config using qwen_name_or_path tokenizer. +# --------------------------------------------------------------------------- +@pytest.fixture +def setup_qwen3_omni_moe(tiny_qwen3_omni_moe): + """Create a tiny Qwen3-Omni-MoE model. + + Uses the session-scoped ``tiny_qwen3_omni_moe`` fixture which loads the + real tokenizer/processor from ``qwen3_omni_name_or_path`` and builds a + model with reduced layers and random weights. 
+ """ + model, tokenizer, processor = tiny_qwen3_omni_moe + output_dir = "./tmp/test_quantized_qwen3_omni_moe" + return model, tokenizer, processor, output_dir, model.config + + +# ========================= Qwen2.5-Omni Integration Tests ================== + + +class TestQwen2_5OmniQuantization: + """End-to-end quantization test for Qwen2.5-Omni (dense model).""" + + def test_quantize_and_reload(self, tiny_qwen2_5_omni): + """Quantize, save, reload, verify weights, and run inference.""" + model, tokenizer, processor = tiny_qwen2_5_omni + output_dir = "./tmp/test_quantized_qwen2_5_omni" + + # Quantize + autoround = AutoRound( + model, + tokenizer, + processor=processor, + nsamples=2, + iters=1, + seqlen=32, + ignore_layers="self_attn,lm_head", + ) + quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) + assert quantized_model is not None, "Quantized model should not be None" + + # Copy model-specific files required for from_pretrained (e.g. 
spk_dict.pt for token2wav) + for extra_file in ["spk_dict.pt"]: + src = os.path.join(qwen2_5_omni_name_or_path, extra_file) + if os.path.exists(src): + shutil.copy2(src, output_dir) + + # Reload + loaded_model = Qwen2_5OmniForConditionalGeneration.from_pretrained(output_dir) + loaded_model.to("cuda") + + # Run inference on thinker + inp = torch.randint(0, 100, (1, 64)).to("cuda") + with torch.inference_mode(): + output = loaded_model.thinker(input_ids=inp) + assert output is not None, "Inference failed on reloaded model" + + # Cleanup + shutil.rmtree(output_dir, ignore_errors=True) + + +# ========================= Qwen3-Omni-MoE Integration Tests ================ + + +class TestQwen3OmniMoeQuantization: + """End-to-end quantization test for Qwen3-Omni-MoE.""" + + def test_quantize_and_reload(self, setup_qwen3_omni_moe): + """Quantize, save, reload, verify weights, and run inference.""" + model, tokenizer, processor, output_dir, config = setup_qwen3_omni_moe + + # Quantize + autoround = AutoRound( + model, + tokenizer, + processor=processor, + nsamples=2, + iters=1, + seqlen=32, + ignore_layers="self_attn,lm_head,mlp.gate", + ) + quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) + assert quantized_model is not None, "Quantized model should not be None" + + # Reload + loaded_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(output_dir) + loaded_model.to("cuda") + + # Run inference on thinker + inp = torch.randint(0, 100, (1, 64)).to("cuda") + with torch.inference_mode(): + output = loaded_model.thinker(input_ids=inp) + assert output is not None, "Inference failed on reloaded model (thinker)" + + # Cleanup + shutil.rmtree(output_dir, ignore_errors=True) + + def test_quantize_mxfp4(self, setup_qwen3_omni_moe): + """Quantize with MXFP4 scheme and verify.""" + model, tokenizer, processor, output_dir, config = setup_qwen3_omni_moe + + autoround = AutoRound( + model, + tokenizer, + processor=processor, + 
scheme="MXFP4", + nsamples=2, + iters=1, + seqlen=32, + ignore_layers="self_attn,lm_head,mlp.gate", + ) + quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) + assert quantized_model is not None, "MXFP4 quantized model should not be None" + + # Reload and inference + loaded_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(output_dir) + loaded_model.to("cuda") + + inp = torch.randint(0, 100, (1, 64)).to("cuda") + with torch.inference_mode(): + output = loaded_model.thinker(input_ids=inp) + assert output is not None + + shutil.rmtree(output_dir, ignore_errors=True)