Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Iterable
from dataclasses import replace
from functools import cached_property

import torch
Expand All @@ -18,6 +19,11 @@
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler

from vllm_omni.quantization.component_config import (
PRE_QUANTIZED_METHODS,
ComponentQuantizationConfig,
)


class Qwen2_5OmniTalkerForConditionalGeneration(
nn.Module,
Expand All @@ -41,6 +47,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: Qwen2_5OmniTalkerConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
if isinstance(quant_config, ComponentQuantizationConfig):
quant_config = quant_config.resolve("talker")
elif quant_config is not None and quant_config.get_name() not in PRE_QUANTIZED_METHODS:
quant_config = None
vllm_config = replace(vllm_config, quant_config=None)
self.vllm_config = vllm_config
self.prefix = prefix
self.quant_config = quant_config
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Thin Omni wrapper: reuse upstream Qwen2.5-Omni thinker (v0.14) with minimal overrides."""

from collections.abc import Iterable, Mapping
from dataclasses import replace
from typing import Any

import torch
Expand Down Expand Up @@ -65,7 +66,8 @@
from vllm.sequence import IntermediateTensors

from vllm_omni.quantization.component_config import (
resolve_encoder_quant_config,
PRE_QUANTIZED_METHODS,
ComponentQuantizationConfig,
)

try:
Expand Down Expand Up @@ -372,10 +374,26 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.quant_config = quant_config

# Pre-quantized checkpoints (modelopt NVFP4/FP8/MXFP8) only quantize
# the Thinker LM. Vision encoder weights remain in BF16 with no FP8
# scale tensors; passing quant_config causes FP8 kernels to run on
# BF16 weights, producing garbage embeddings. Keep None for encoders.
visual_quant_config = resolve_encoder_quant_config(quant_config)
# the Thinker LM (language model). Vision and audio encoder weights
# remain in BF16 and have no corresponding scale tensors in the
# checkpoint. Dynamic quantization methods (e.g. --quantization fp8)
# should also only target the language model.
visual_prefix = maybe_prefix(prefix, "visual")
language_prefix = maybe_prefix(prefix, "language_model")
if isinstance(quant_config, ComponentQuantizationConfig):
visual_quant_config = quant_config.resolve(visual_prefix)
elif quant_config is not None:
if quant_config.get_name() in PRE_QUANTIZED_METHODS:
visual_quant_config = None
else:
quant_config = ComponentQuantizationConfig(
component_configs={language_prefix: quant_config},
default_config=None,
)
vllm_config = replace(vllm_config, quant_config=quant_config)
visual_quant_config = None
else:
visual_quant_config = None

with self._mark_tower_model(vllm_config, "audio"):
if multimodal_config.get_limit_per_prompt("audio"):
Expand Down
Loading