diff --git a/examples/multimodal_vision/qwen3_omni_example.py b/examples/multimodal_vision/qwen3_omni_example.py
new file mode 100644
index 0000000000..4753575ae2
--- /dev/null
+++ b/examples/multimodal_vision/qwen3_omni_example.py
@@ -0,0 +1,106 @@
+import requests
+import soundfile as sf
+from PIL import Image
+from transformers import (
+    AutoProcessor,
+    Qwen3OmniMoeForConditionalGeneration,
+    default_data_collator,
+)
+
+from llmcompressor import oneshot
+from llmcompressor.modeling.patch.qwen3_omni_patch import fast_pos_embed_interpolate
+from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.transformers.compression.compressed_tensors_utils import (
+    modify_save_pretrained,
+)
+from llmcompressor.utils import dispatch_for_generation
+
+# Load model.
+model_id = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
+    model_id, torch_dtype="auto"
+)
+processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+# Apply patch to fix accelerate offloading, can be removed after #2148
+model.thinker.visual.fast_pos_embed_interpolate = fast_pos_embed_interpolate.__get__(
+    model.thinker.visual
+)
+
+# Oneshot arguments
+BATCH_SIZE = 1
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+DATASET_ID = "flickr30k"
+DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}
+
+# Recipe
+recipe = [
+    GPTQModifier(
+        targets="Linear",
+        scheme="W4A16",
+        ignore=[
+            "lm_head",
+            r"re:.*visual.*",
+            r"re:.*code2wav.*",
+        ],
+    ),
+]
+
+
+def data_collator(features):
+    batch = default_data_collator(features)
+    batch["image_grid_thw"] = batch["image_grid_thw"].squeeze(0)
+    return batch
+
+
+# Perform oneshot
+oneshot(
+    model=model.thinker,  # base model does not define forward: pass `thinker` instead
+    processor=processor,
+    dataset=DATASET_ID,
+    splits=DATASET_SPLIT,
+    recipe=recipe,
+    batch_size=BATCH_SIZE,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    data_collator=data_collator,
+)
+
+# Confirm generations of the quantized model look sane.
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Please describe the animal in this image\n"},
+            {"type": "image"},
+        ],
+    },
+]
+prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
+raw_image = Image.open(requests.get(image_url, stream=True).raw)
+
+inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device)
+text_ids, audio = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
+text = processor.batch_decode(
+    text_ids.sequences[:, inputs["input_ids"].shape[1] :],
+    skip_special_tokens=True,
+    clean_up_tokenization_spaces=False,
+)
+print(text)
+if audio is not None:
+    sf.write(
+        "sample_output.wav",
+        audio.reshape(-1).detach().cpu().numpy(),
+        samplerate=24000,
+    )
+print("==========================================")
+
+# Save to disk compressed.
+modify_save_pretrained(model)
+SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+processor.save_pretrained(SAVE_DIR)
diff --git a/src/llmcompressor/modeling/patch/qwen3_omni_patch.py b/src/llmcompressor/modeling/patch/qwen3_omni_patch.py
new file mode 100644
index 0000000000..ff8f79add9
--- /dev/null
+++ b/src/llmcompressor/modeling/patch/qwen3_omni_patch.py
@@ -0,0 +1,72 @@
+# flake8: noqa
+# ruff: noqa
+
+import torch
+from compressed_tensors import get_execution_device
+
+
+def fast_pos_embed_interpolate(self, grid_thw):
+    grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+
+    idx_list = [[] for _ in range(4)]
+    weight_list = [[] for _ in range(4)]
+
+    for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+        h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+        w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+        h_idxs_floor = h_idxs.int()
+        w_idxs_floor = w_idxs.int()
+        h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+        w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+        dh = h_idxs - h_idxs_floor
+        dw = w_idxs - w_idxs_floor
+
+        base_h = h_idxs_floor * self.num_grid_per_side
+        base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+        indices = [
+            (base_h[None].T + w_idxs_floor[None]).flatten(),
+            (base_h[None].T + w_idxs_ceil[None]).flatten(),
+            (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+            (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+        ]
+
+        weights = [
+            ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+            ((1 - dh)[None].T * dw[None]).flatten(),
+            (dh[None].T * (1 - dw)[None]).flatten(),
+            (dh[None].T * dw[None]).flatten(),
+        ]
+
+        for i in range(4):
+            idx_list[i].extend(indices[i].tolist())
+            weight_list[i].extend(weights[i].tolist())
+
+    # PATCH: do not rely on `pos_embed.weight`, which may be offloaded
+    device = get_execution_device(self.pos_embed)
+
+    idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
+    weight_tensor = torch.tensor(
+        weight_list, dtype=self.pos_embed.weight.dtype, device=device
+    )
+    pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+    patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+    patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
+    patch_pos_embeds_permute = []
+    merge_size = self.config.spatial_merge_size
+    for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+        pos_embed = pos_embed.repeat(t, 1)
+        pos_embed = (
+            pos_embed.view(
+                t, h // merge_size, merge_size, w // merge_size, merge_size, -1
+            )
+            .permute(0, 1, 3, 2, 4, 5)
+            .flatten(0, 4)
+        )
+        patch_pos_embeds_permute.append(pos_embed)
+    patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+    return patch_pos_embeds