diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py index 35d988895b3..4468bba86f9 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py @@ -1,4 +1,5 @@ from collections.abc import Iterable +from dataclasses import replace from functools import cached_property import torch @@ -18,6 +19,11 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler +from vllm_omni.quantization.component_config import ( + PRE_QUANTIZED_METHODS, + ComponentQuantizationConfig, +) + class Qwen2_5OmniTalkerForConditionalGeneration( nn.Module, @@ -41,6 +47,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: Qwen2_5OmniTalkerConfig = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + if isinstance(quant_config, ComponentQuantizationConfig): + quant_config = quant_config.resolve("talker") + elif quant_config is not None and quant_config.get_name() not in PRE_QUANTIZED_METHODS: + quant_config = None + vllm_config = replace(vllm_config, quant_config=None) self.vllm_config = vllm_config self.prefix = prefix self.quant_config = quant_config diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py index 931049d8d44..f28d798b15a 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py @@ -1,6 +1,7 @@ """Thin Omni wrapper: reuse upstream Qwen2.5-Omni thinker (v0.14) with minimal overrides.""" from collections.abc import Iterable, Mapping +from dataclasses import replace from typing import Any import torch @@ -65,7 +66,8 @@ from vllm.sequence import IntermediateTensors from vllm_omni.quantization.component_config import ( - resolve_encoder_quant_config, + PRE_QUANTIZED_METHODS, + ComponentQuantizationConfig, ) try: @@ -372,10 +374,26 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config # Pre-quantized checkpoints (modelopt NVFP4/FP8/MXFP8) only quantize - # the Thinker LM. Vision encoder weights remain in BF16 with no FP8 - # scale tensors; passing quant_config causes FP8 kernels to run on - # BF16 weights, producing garbage embeddings. Keep None for encoders. - visual_quant_config = resolve_encoder_quant_config(quant_config) + # the Thinker LM (language model). Vision and audio encoder weights + # remain in BF16 and have no corresponding scale tensors in the + # checkpoint. Dynamic quantization methods (e.g. --quantization fp8) + # should also only target the language model. + visual_prefix = maybe_prefix(prefix, "visual") + language_prefix = maybe_prefix(prefix, "language_model") + if isinstance(quant_config, ComponentQuantizationConfig): + visual_quant_config = quant_config.resolve(visual_prefix) + elif quant_config is not None: + if quant_config.get_name() in PRE_QUANTIZED_METHODS: + visual_quant_config = None + else: + quant_config = ComponentQuantizationConfig( + component_configs={language_prefix: quant_config}, + default_config=None, + ) + vllm_config = replace(vllm_config, quant_config=quant_config) + visual_quant_config = None + else: + visual_quant_config = None with self._mark_tower_model(vllm_config, "audio"): if multimodal_config.get_limit_per_prompt("audio"):