Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm_omni/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

# Register custom configs (AutoConfig, AutoTokenizer) as early as possible.
from vllm_omni.transformers_utils import configs as _configs # noqa: F401, E402
from vllm_omni.transformers_utils import parsers as _parsers # noqa: F401, E402

from .config import OmniModelConfig

Expand Down
4 changes: 0 additions & 4 deletions vllm_omni/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ def _register_omni_hf_configs() -> None:
from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import (
Qwen3TTSConfig,
)
from vllm_omni.model_executor.models.voxtral_tts.configuration_voxtral_tts import (
VoxtralTTSConfig,
)
from vllm_omni.transformers_utils.configs.voxcpm import VoxCPMConfig
from vllm_omni.transformers_utils.configs.voxcpm2 import VoxCPM2Config
except Exception as exc: # pragma: no cover - best-effort optional registration
Expand All @@ -61,7 +58,6 @@ def _register_omni_hf_configs() -> None:
("qwen3_tts", Qwen3TTSConfig),
("cosyvoice3", CosyVoice3Config),
("omnivoice", OmniVoiceConfig),
("voxtral_tts", VoxtralTTSConfig),
("voxcpm", VoxCPMConfig),
("voxcpm2", VoxCPM2Config),
]:
Expand Down

This file was deleted.

3 changes: 3 additions & 0 deletions vllm_omni/transformers_utils/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"FishSpeechFastARConfig": "vllm_omni.transformers_utils.configs.fish_speech",
"VoxCPMConfig": "vllm_omni.transformers_utils.configs.voxcpm",
"VoxCPM2Config": "vllm_omni.transformers_utils.configs.voxcpm2",
"VoxtralTTSConfig": "vllm_omni.transformers_utils.configs.voxtral_tts",
"BailingMoeV2Config": "vllm_omni.transformers_utils.configs.ming_flash_omni",
"BailingMM2Config": "vllm_omni.transformers_utils.configs.ming_flash_omni",
"MingFlashOmniConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni",
Expand All @@ -36,6 +37,7 @@
"FishSpeechFastARConfig",
"VoxCPMConfig",
"VoxCPM2Config",
"VoxtralTTSConfig",
"BailingMoeV2Config",
"BailingMM2Config",
"MingFlashOmniConfig",
Expand Down Expand Up @@ -64,3 +66,4 @@ def __dir__():
from vllm_omni.transformers_utils.configs import ming_flash_omni as _ming_flash_omni # noqa: F401, E402
from vllm_omni.transformers_utils.configs import voxcpm as _voxcpm # noqa: F401, E402
from vllm_omni.transformers_utils.configs import voxcpm2 as _voxcpm2 # noqa: F401, E402
from vllm_omni.transformers_utils.configs import voxtral_tts as _voxtral_tts # noqa: F401, E402
36 changes: 36 additions & 0 deletions vllm_omni/transformers_utils/configs/voxtral_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

from typing import Any

from transformers import AutoConfig, PretrainedConfig


class VoxtralTTSConfig(PretrainedConfig):
    """HuggingFace-style configuration container for Voxtral TTS models.

    Holds two pieces of state on top of the base ``PretrainedConfig``:

    * ``text_config`` -- always a ``PretrainedConfig`` instance.
    * ``audio_config`` -- a plain dict of audio settings (``{}`` when the
      caller supplies nothing or a falsy value).
    """

    model_type = "voxtral_tts"

    def __init__(
        self,
        text_config: PretrainedConfig | dict | None = None,
        audio_config: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialise the config.

        Args:
            text_config: An existing ``PretrainedConfig`` (stored as-is), a
                dict (converted with ``PretrainedConfig.from_dict``), or
                anything else, in which case an empty ``PretrainedConfig``
                is used.
            audio_config: Audio settings; a falsy value is replaced by ``{}``.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)

        # Normalise the three accepted shapes into a PretrainedConfig.
        if isinstance(text_config, dict):
            resolved_text = PretrainedConfig.from_dict(text_config)
        elif isinstance(text_config, PretrainedConfig):
            resolved_text = text_config
        else:
            resolved_text = PretrainedConfig()
        self.text_config = resolved_text

        self.audio_config = audio_config if audio_config else {}

    def get_text_config(self, **kwargs: Any) -> PretrainedConfig:
        """Return the stored ``text_config``; any keyword arguments are ignored."""
        return self.text_config


# Register the config class with transformers' AutoConfig under the
# "voxtral_tts" model type. This runs as an import-time side effect of
# this module.
AutoConfig.register("voxtral_tts", VoxtralTTSConfig)
Comment thread
yuanheng-zhao marked this conversation as resolved.

__all__ = ["VoxtralTTSConfig"]
29 changes: 29 additions & 0 deletions vllm_omni/transformers_utils/parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Custom vLLM config parsers for vllm-omni."""

from __future__ import annotations

import importlib

_CLASS_TO_MODULE: dict[str, str] = {
"VoxtralTTSConfigParser": "vllm_omni.transformers_utils.parsers.voxtral_tts",
}

__all__ = ["VoxtralTTSConfigParser"]


def __getattr__(name: str):
    """Lazily resolve an exported parser class by attribute name.

    Looks ``name`` up in ``_CLASS_TO_MODULE``; on a hit the owning module is
    imported and the attribute of the same name is returned from it.

    Raises:
        AttributeError: If ``name`` is not a key of ``_CLASS_TO_MODULE``.
    """
    target_module = _CLASS_TO_MODULE.get(name)
    if target_module is None:
        raise AttributeError(f"module 'vllm_omni.transformers_utils.parsers' has no attribute {name!r}")

    return getattr(importlib.import_module(target_module), name)


def __dir__():
    """Return the names listed in ``__all__`` as a new sorted list."""
    # sorted() accepts any iterable and always returns a fresh list, so no
    # intermediate list() copy is needed.
    return sorted(__all__)


# Eagerly import parser modules so their registry side-effects run as soon as
# `vllm_omni.transformers_utils.parsers` is imported.
from vllm_omni.transformers_utils.parsers import voxtral_tts as _voxtral_tts # noqa: F401, E402
106 changes: 106 additions & 0 deletions vllm_omni/transformers_utils/parsers/voxtral_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from __future__ import annotations

from pathlib import Path
from typing import Any

from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
_CONFIG_FORMAT_TO_CONFIG_PARSER,
MistralConfigParser,
_download_mistral_config_file,
)

from vllm_omni.transformers_utils.configs.voxtral_tts import VoxtralTTSConfig

logger = init_logger(__name__)

# Architecture names and model type that identify a Voxtral-TTS checkpoint.
_VOXTRAL_TTS_ARCHS = frozenset(("VoxtralTTSForConditionalGeneration",))
_VOXTRAL_TTS_MODEL_TYPE = "voxtral_tts"


def _is_voxtral_tts_params(config_dict: dict) -> bool:
    """Tell whether a Mistral ``params.json`` dict describes a Voxtral-TTS model.

    A config matches when its ``model_type`` equals ``"voxtral_tts"`` or when
    any entry of its ``architectures`` list is a known Voxtral-TTS
    architecture. A missing or null ``architectures`` value counts as empty.
    """
    if config_dict.get("model_type") != _VOXTRAL_TTS_MODEL_TYPE:
        declared = set(config_dict.get("architectures") or [])
        return not _VOXTRAL_TTS_ARCHS.isdisjoint(declared)
    return True


def _remap_voxtral_tts_audio_args(config_dict: dict) -> dict:
    """Extract the Voxtral-TTS audio settings from a Mistral ``params.json`` dict.

    ``audio_model_args`` and ``audio_tokenizer_args`` are popped out of
    ``config_dict["multimodal"]`` (the caller's dict is mutated) and reshaped
    into the audio config dict returned to the caller.

    Args:
        config_dict: Parsed ``params.json`` contents. Must contain a
            ``"multimodal"`` mapping that has an ``"audio_model_args"`` key.

    Returns:
        ``{}`` when ``audio_model_args`` is ``None``. Otherwise a dict with the
        keys ``sampling_rate``, ``codec_args``, ``audio_model_args`` and
        ``speaker_id``.

    Raises:
        KeyError: If ``"multimodal"``, ``"audio_model_args"`` or
            ``audio_model_args["audio_encoding_args"]["sampling_rate"]`` is
            missing.
    """
    encoder_args = config_dict["multimodal"].pop("audio_model_args")
    audio_tokenizer_args = config_dict["multimodal"].pop("audio_tokenizer_args", None)
    if encoder_args is None:
        return {}

    # Work on the dict that lives inside encoder_args, creating it when the
    # key is absent or null. A detached ``.get(..., {})`` default would lose
    # the n_decoding_steps fallback written below, even though the warning
    # tells the user the default is being applied.
    acoustic_args = encoder_args.get("acoustic_transformer_args")
    if acoustic_args is None:
        acoustic_args = {}
        encoder_args["acoustic_transformer_args"] = acoustic_args

    if acoustic_args.get("n_decoding_steps") is None:
        logger.warning(
            "n_decoding_steps not provided in acoustic_transformer_args, defaulting to 7. "
            "Please add 'n_decoding_steps' to params.json under acoustic_transformer_args."
        )
        acoustic_args["n_decoding_steps"] = 7

    return {
        "sampling_rate": encoder_args["audio_encoding_args"]["sampling_rate"],
        "codec_args": audio_tokenizer_args,
        "audio_model_args": encoder_args,
        # audio_tokenizer_args may be None; fall back to an empty mapping.
        "speaker_id": (audio_tokenizer_args or {}).get("voice", {}),
    }


def _parse_voxtral_tts(config_dict: dict) -> tuple[dict, PretrainedConfig]:
    """Build a ``VoxtralTTSConfig`` from a Mistral-format ``params.json`` dict.

    The audio section (``multimodal.audio_model_args`` and friends) is
    remapped into ``audio_config``. Every other top-level key goes through
    the upstream Mistral remapping helpers and becomes the nested
    ``text_config``.

    Args:
        config_dict: Parsed ``params.json`` contents. When audio args are
            present, the remap step pops them out of
            ``config_dict["multimodal"]``, so this dict is mutated.

    Returns:
        A ``(config_dict, config)`` tuple: the (possibly mutated) input dict
        and the constructed ``VoxtralTTSConfig``.
    """
    # Imported lazily, as in the original code, rather than at module import time.
    from vllm.transformers_utils.configs.mistral import (
        _remap_general_mistral_args,
        _remap_mistral_quantization_args,
    )

    audio_config: dict[str, Any] = {}
    if (config_dict.get("multimodal") or {}).get("audio_model_args"):
        audio_config = _remap_voxtral_tts_audio_args(config_dict)

    # Everything except the multimodal section describes the text backbone.
    text_config = {k: v for k, v in config_dict.items() if k != "multimodal"}
    text_config = _remap_general_mistral_args(text_config)
    if text_config.get("quantization"):
        text_config = _remap_mistral_quantization_args(text_config)
    text_config.setdefault("architectures", ["MistralForCausalLM"])

    # Use ``or`` rather than a ``.get`` default so that a null or empty
    # "architectures" entry also falls back to the Voxtral-TTS architecture.
    # This matches how _is_voxtral_tts_params treats a null value.
    architectures = config_dict.get("architectures") or ["VoxtralTTSForConditionalGeneration"]

    config = VoxtralTTSConfig(
        text_config=PretrainedConfig.from_dict(text_config),
        audio_config=audio_config,
        architectures=architectures,
    )
    return config_dict, config


class VoxtralTTSConfigParser(MistralConfigParser):
    """Mistral-format config parser with an extra branch for Voxtral-TTS.

    Voxtral-TTS checkpoints are turned into a ``VoxtralTTSConfig``; every
    other checkpoint is delegated to ``MistralConfigParser.parse``.
    """

    def parse(
        self,
        model: str | Path,
        trust_remote_code: bool,
        revision: str | None = None,
        code_revision: str | None = None,
        **kwargs: Any,
    ) -> tuple[dict, PretrainedConfig]:
        """Parse the Mistral config file of ``model``.

        Args:
            model: Model name or path passed to the Mistral config download
                helper.
            trust_remote_code: Only forwarded to the parent parser.
            revision: Revision used when fetching the config file; also
                forwarded to the parent parser.
            code_revision: Only forwarded to the parent parser.
            **kwargs: Only forwarded to the parent parser.

        Returns:
            The ``(config_dict, config)`` pair from ``_parse_voxtral_tts`` for
            Voxtral-TTS checkpoints, otherwise whatever the parent parser
            returns.
        """
        params = _download_mistral_config_file(model, revision)

        if not _is_voxtral_tts_params(params):
            # Not a Voxtral-TTS checkpoint: defer entirely to upstream.
            # NOTE(review): the parent presumably fetches the config file
            # again on its own -- confirm if the double read matters.
            return super().parse(
                model,
                trust_remote_code,
                revision=revision,
                code_revision=code_revision,
                **kwargs,
            )

        return _parse_voxtral_tts(params)


# Replace the default "mistral" slot directly.
# Any non-Voxtral-TTS Mistral checkpoint still goes through
# the upstream code path via super().parse().
_CONFIG_FORMAT_TO_CONFIG_PARSER["mistral"] = VoxtralTTSConfigParser

__all__ = ["VoxtralTTSConfigParser"]
Loading