
[Examples] QwenOmni Example#2125

Merged
dsikka merged 5 commits into main from kylesayrs/qwen_omni
Jan 14, 2026
Conversation

@kylesayrs
Collaborator

@kylesayrs kylesayrs commented Dec 14, 2025

Purpose

Changes

  • Pass model.thinker to oneshot, since the top-level model does not implement a forward method (the thinker module is a PreTrainedModel that contains all of the parameters worth quantizing)
  • Patch fast_pos_embed_interpolate to support accelerate offloading
  • Patch the dataloader to squeeze image_grid_thw while leaving pixel_values and other inputs unsqueezed
  • Save example audio generation
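The collator change above can be sketched as follows. This is an illustrative stand-in, not the PR's exact code: the key name image_grid_thw comes from the PR, but the batch-size-1 assumption and the dict-comprehension shape are assumptions about how such a collator is typically written.

```python
def data_collator(batch):
    # Sketch (assumption): calibration batches have size 1. Most keys keep
    # their leading batch dimension, but image_grid_thw must be squeezed so
    # its shape matches what the vision tower expects.
    assert len(batch) == 1
    sample = batch[0]
    return {
        key: value.squeeze(0) if key == "image_grid_thw" else value
        for key, value in sample.items()
    }
```

The collator would then be passed to oneshot via its data_collator argument, as shown in the example scripts later in this thread.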

Testing

  • Ran example e2e and produced sane outputs
    output.wav
['Based on the image provided, the animal is a small, white kitten.\n\nHere are some specific details about the animal:\n\n*   **Species and Age:** It is a very young domestic cat, commonly known as a kitten.\n*   **Color and Coat:** The kitten has a pure white coat. Its fur appears soft and fluffy.\n*   **Posture and Location:** It is curled up and sleeping soundly on a white computer keyboard, which is placed on a wooden desk.\n*   **Features:** The kitten has its eyes closed, indicating it is asleep. It has small, pointed ears and a small nose.']

@github-actions

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

@gemini-code-assist
Contributor

Summary of Changes

Hello @kylesayrs, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the Qwen3OmniMoeForConditionalGeneration model by providing a comprehensive example for its quantization using llmcompressor. It addresses specific technical requirements such as patching the position embedding interpolation for offloading compatibility and adjusting data collation for multimodal inputs, ultimately enabling efficient and functional use of this advanced model.

Highlights

  • Qwen3OmniMoeForConditionalGeneration Support: This pull request introduces support for the Qwen3OmniMoeForConditionalGeneration model, enabling its use within the system.
  • New Example Script: A new example script, qwen3_omni_example.py, has been added to demonstrate the quantization of the Qwen3OmniMoeForConditionalGeneration model using llmcompressor's oneshot functionality.
  • Position Embedding Patch: The fast_pos_embed_interpolate function has been patched to correctly handle scenarios where pos_embed.weight might be offloaded, ensuring compatibility with accelerate offloading strategies.
  • Custom Data Collator: A custom data_collator is implemented to properly squeeze the image_grid_thw tensor while leaving other inputs unsqueezed, which is crucial for correct data processing during calibration.
  • Audio Generation Example: The example now includes saving generated audio output to sample_output.wav, showcasing the model's multimodal capabilities.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new example for Qwen3OmniMoeForConditionalGeneration, including a patch to support accelerate offloading. The example script demonstrates how to perform one-shot quantization with GPTQ and generate sample outputs. The changes are well-structured and the example is clear. My review includes a suggestion to improve the performance of the patch file by using more efficient tensor operations, and a comment on improving the clarity of the example script's save directory naming.

@brian-dellabetta
Collaborator

Related to #1673 as well

@allerou4

allerou4 commented Dec 17, 2025

Hi, I pass model.thinker to oneshot to quantize, but model.save_pretrained saves a full bf16 model.
If I save model.thinker only, the output is correct.

@Sekri0

Sekri0 commented Dec 17, 2025

Hi, I pass model.thinker to quantize, but model.save_pretrained saved a full bf16 model If I save model.thinker only, it's correct

This is because, in oneshot's preprocessing, model.save_pretrained is modified to support saving compressed models. Passing model.thinker to oneshot means that save_pretrained on the top-level model is not modified. Adding the code below before saving fixes the problem:

from llmcompressor.transformers.compression.compressed_tensors_utils import modify_save_pretrained
modify_save_pretrained(model)
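The mechanism behind this can be illustrated with a small stand-in (the names and logic below are hypothetical, not llmcompressor's actual code): oneshot wraps save_pretrained on the exact object it is given, so wrapping model.thinker leaves the top-level model's method untouched, and the outer model still saves uncompressed weights.

```python
def modify_save(obj):
    """Wrap obj.save in place (hypothetical stand-in for modify_save_pretrained)."""
    original = obj.save

    def patched(*args, **kwargs):
        # Pretend the wrapped method now writes a compressed checkpoint.
        return "compressed:" + original(*args, **kwargs)

    obj.save = patched  # per-instance patch; other objects are unaffected

class Thinker:
    def save(self):
        return "bf16"

class OmniModel:
    def __init__(self):
        self.thinker = Thinker()
    def save(self):
        return "bf16"

model = OmniModel()
modify_save(model.thinker)  # only the inner module's method is wrapped
# model.thinker.save() now returns "compressed:bf16",
# while model.save() still returns plain "bf16".
```

This is why calling modify_save_pretrained(model) on the top-level model before saving restores the compressed save path.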


@allerou4


why does it save nothing?

# Apply AWQ quantization.
oneshot(
    model=model.thinker,
    processor=processor,
    recipe=recipe,
    dataset=ds,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    data_collator=data_collator,
)

# calibrated_model.save_pretrained(SAVE_DIR, save_compressed=True)

print("========== SAMPLE GENERATION ==============")
# dispatch_for_generation(model)
modify_save_pretrained(model)
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)

@Sekri0

Sekri0 commented Dec 17, 2025


I tried data-free quantization (no calibration) and successfully saved a compressed model.

Code:

import requests
import soundfile as sf
from PIL import Image
from qwen3_omni_patch import fast_pos_embed_interpolate
from transformers import (
    AutoProcessor,
    Qwen3OmniMoeForConditionalGeneration,
    default_data_collator,
)

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.transformers.compression.compressed_tensors_utils import modify_save_pretrained

model_id = "/mnt/home/model/Qwen3-Omni-30B-A3B-Instruct"
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
    model_id, torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

scheme = "W4A16"

recipe = [
    QuantizationModifier(
        targets="Linear",
        scheme=scheme,
        ignore=[
            r"re:.*lm_head.*",
            r"re:.*talker.*",
            r"re:.*code2wav.*",
            r"re:.*mlp.gate.*",
            r"re:.*audio_tower.*",
            r"re:.*visual.*",
        ],
    ),
]

oneshot(
    model=model.thinker,
    processor=processor,
    recipe=recipe,
)

modify_save_pretrained(model)

SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-thinker-" + scheme
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)

@allerou4

@Sekri0 Thanks, it worked

@kylesayrs
Collaborator Author

kylesayrs commented Dec 17, 2025

Awesome catch, thank you @Sekri0 @allerou4 !

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
@kylesayrs kylesayrs added the ready When a PR is ready for review label Jan 8, 2026
@mergify
Contributor

mergify bot commented Jan 14, 2026

Documentation update

@mergify mergify bot added the documentation Improvements or additions to documentation label Jan 14, 2026
Collaborator

@brian-dellabetta brian-dellabetta left a comment


One comment regarding patch file placement

dsikka
dsikka previously approved these changes Jan 14, 2026
Collaborator

@dsikka dsikka left a comment


Small nits. LGTM. Thanks!

@dsikka dsikka added the qwen For any PR / issue related to Qwen support label Jan 14, 2026
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
@dsikka dsikka merged commit 4876c4a into main Jan 14, 2026
20 of 21 checks passed
@dsikka dsikka deleted the kylesayrs/qwen_omni branch January 14, 2026 20:44
@JartX
Contributor

JartX commented Jan 27, 2026

Hi! Is there any way to quantize Qwen3 Omni using AWQ?

@JartX
Contributor

JartX commented Jan 27, 2026

After quantizing with the script, I encounter the following error when loading the model in vLLM:

Loading safetensors checkpoint shards:  67% 4/6 [01:15<00:34, 17.06s/it](Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772] WorkerProc failed to start.
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772] Traceback (most recent call last):
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 743, in worker_main
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]     worker = WorkerProc(*args, **kwargs)
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 578, in __init__
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]     self.worker.load_model()
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 275, in load_model
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]     self.model_runner.load_model(eep_scale_up=eep_scale_up)
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 3993, in load_model
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]     self.model = model_loader.load_model(
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]                  ^^^^^^^^^^^^^^^^^^^^^^^^
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/base_loader.py", line 58, in load_model
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]     self.load_weights(model, model_config)
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/default_loader.py", line 288, in load_weights
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]     loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen3_omni_moe_thinker.py", line 1869, in load_weights
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]     loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/online_quantization.py", line 173, in patched_model_load_weights
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]     return original_load_weights(auto_weight_loader, weights, mapper=mapper)
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 342, in load_weights
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]     autoloaded_weights = set(self._load_module("", self.module, weights))
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 290, in _load_module
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]     yield from self._load_module(
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 263, in _load_module
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]     loaded_params = module_load_weights(weights)
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen3_omni_moe_thinker.py", line 507, in load_weights
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]     param = params_dict[name]
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772]             ~~~~~~~~~~~^^^^^^
vllm1-1  | (Worker_TP0 pid=116) ERROR 01-27 21:48:47 [multiproc_executor.py:772] KeyError: 'layers.0.self_attn.qkv.weight_packed'
Loading safetensors checkpoint shards:  67% 4/6 [01:15<00:37, 18.88s/it]
vllm1-1  | (Worker_TP0 pid=116) INFO 01-27 21:48:47 [multiproc_executor.py:730] Parent process exited, terminating worker
vllm1-1  | (Worker_TP1 pid=137) INFO 01-27 21:48:47 [multiproc_executor.py:730] Parent process exited, terminating worker
vllm1-1  | [rank0]:[W127 21:48:47.075485175 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
vllm1-1  | (Worker_TP1 pid=137) ERROR 01-27 21:48:48 [multiproc_executor.py:772] WorkerProc failed to start.
vllm1-1  | (Worker_TP1 pid=137) ERROR 01-27 21:48:48 [multiproc_executor.py:772] KeyError: 'layers.0.self_attn.qkv.weight_packed'
vllm1-1  | (EngineCore_DP0 pid=76) ERROR 01-27 21:48:49 [core.py:935] EngineCore failed to start.
vllm1-1  | (EngineCore_DP0 pid=76) ERROR 01-27 21:48:49 [core.py:935] Traceback (most recent call last):
vllm1-1  | (EngineCore_DP0 pid=76) ERROR 01-27 21:48:49 [core.py:935]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 926, in run_engine_core

@JartX
Contributor

JartX commented Jan 30, 2026

@kylesayrs adding these ignore layers solves this error: #2125 (comment)

recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        ignore=[
            "re:.*visual.*",
            "re:.*code2wav.*",
            "re:.*audio_tower.*",
            "re:^talker\..*",
            "re:.*embed_tokens",
            "re:.*mlp\.gate$",
            "re:.*shared_expert_gate$",
            "re:.*input_layernorm$",
            "re:.*post_attention_layernorm$",
            "re:.*norm$",
            "re:.*lm_head$",
        ],
    ),
]
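To sanity-check which modules such ignore patterns will actually skip, the regexes (with the "re:" prefix stripped, mirroring llmcompressor's convention as an assumption) can be tested against module names directly. The module names below are hypothetical examples, not taken from the actual model:

```python
import re

# Patterns from the recipe above, "re:" prefix removed.
ignore = [
    r".*visual.*",
    r".*code2wav.*",
    r".*audio_tower.*",
    r"^talker\..*",
    r".*embed_tokens",
    r".*mlp\.gate$",
    r".*lm_head$",
]

def is_ignored(name):
    # re.match anchors at the start of the string, so "^talker\..*" only
    # matches names that literally begin with "talker.".
    return any(re.match(pattern, name) for pattern in ignore)

# Hypothetical module names: the talker and mlp.gate modules are skipped,
# while an ordinary thinker attention projection is quantized.
```

Running quick checks like this before a multi-hour quantization run can catch patterns that silently match nothing (or match too much).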


Labels

documentation — Improvements or additions to documentation
qwen — For any PR / issue related to Qwen support
ready — When a PR is ready for review

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: get_input_embeddings not auto-handled for Qwen3OmniMoeForConditionalGeneration

6 participants