diff --git a/.buildkite/test-ready.yml b/.buildkite/test-ready.yml index 6f3ad6504e8..3d9810447c4 100644 --- a/.buildkite/test-ready.yml +++ b/.buildkite/test-ready.yml @@ -317,6 +317,31 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" + - label: "VoxCPM2 Native AR E2E Test" + timeout_in_minutes: 20 + depends_on: upload-ready-pipeline + commands: + - | + timeout 20m bash -c ' + pip install voxcpm + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_WORKER_MULTIPROC_METHOD=spawn + pytest -s -v tests/e2e/offline_inference/test_voxcpm2.py -m "core_model" --run-level "core_model" + ' + agents: + queue: "gpu_1_queue" + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + shm-size: "8gb" + environment: + - "HF_HOME=/fsx/hf_cache" + - "HF_TOKEN" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + - label: "OmniVoice E2E Test" timeout_in_minutes: 20 depends_on: upload-ready-pipeline diff --git a/examples/offline_inference/voxcpm2/README.md b/examples/offline_inference/voxcpm2/README.md new file mode 100644 index 00000000000..df48a85f569 --- /dev/null +++ b/examples/offline_inference/voxcpm2/README.md @@ -0,0 +1,83 @@ +# VoxCPM2 Offline Inference (Native AR) + +VoxCPM2 is a 2B-parameter tokenizer-free diffusion AR TTS model. It produces 48kHz audio and supports 30+ languages with a single-stage native AR pipeline backed by MiniCPM4. + +## Prerequisites + +Install the `voxcpm` package, or set the environment variable pointing to the source tree: + +```bash +# Option A: install package +pip install voxcpm + +# Option B: use source checkout +export VLLM_OMNI_VOXCPM_CODE_PATH=/path/to/voxcpm +``` + +## Quick Start + +Zero-shot synthesis: + +```bash +python examples/offline_inference/voxcpm2/end2end.py \ + --model openbmb/VoxCPM2 \ + --text "Hello, this is a VoxCPM2 demo." \ + --output-dir output_audio +``` + +Voice cloning with a reference audio: + +```bash +python examples/offline_inference/voxcpm2/end2end.py \ + --text "Hello, this is a voice clone demo." \ + --reference-audio /path/to/reference.wav \ + --output-dir output_clone +``` + +Prompt continuation (matched audio + text prefix): + +```bash +python examples/offline_inference/voxcpm2/end2end.py \ + --text "Continuation target sentence." \ + --prompt-audio /path/to/prompt.wav \ + --prompt-text "Transcript of the prompt audio." \ + --output-dir output_cont +``` + +The script accepts the following arguments: + +| Argument | Default | Description | +|---|---|---| +| `--model` | `openbmb/VoxCPM2` | HuggingFace repo ID or local path | +| `--text` | (example sentence) | Text to synthesize | +| `--output-dir` | `output_audio` | Directory for output WAV files | +| `--stage-configs-path` | `voxcpm2.yaml` | Stage config YAML path | +| `--reference-audio` | `None` | Reference audio for voice cloning (isolated) | +| `--prompt-audio` | `None` | Prompt audio for continuation mode | +| `--prompt-text` | `None` | Transcript matching `--prompt-audio` | + +## Performance + +Measured on a single H20 GPU (80 GB), voxcpm 0.0.0, PyTorch 2.10.0+cu128: + +| Input length | RTF | Sample rate | +|---|---|---| +| Short (~6 words) | ~0.81 | 48 kHz | +| Long (~50 words) | ~0.72 | 48 kHz | + +RTF < 1.0 means faster than real time. + +## Architecture + +VoxCPM2 uses a single-stage native AR pipeline: + +``` +feat_encoder +└─► MiniCPM4 (base LM) + └─► FSQ (finite scalar quantization) + └─► residual_lm (residual AR) + └─► LocDiT (local diffusion transformer) + └─► AudioVAE → 48 kHz waveform +``` + +All stages are fused into one vllm-native execution graph via `voxcpm2.yaml`, eliminating inter-stage coordination overhead and enabling true end-to-end batching. diff --git a/examples/offline_inference/voxcpm2/end2end.py b/examples/offline_inference/voxcpm2/end2end.py new file mode 100644 index 00000000000..2dce7508975 --- /dev/null +++ b/examples/offline_inference/voxcpm2/end2end.py @@ -0,0 +1,145 @@ +"""Offline VoxCPM2 inference example (native AR pipeline). + +Uses the single-stage native AR config (voxcpm2.yaml). +Requires the `voxcpm` package or VLLM_OMNI_VOXCPM_CODE_PATH env var. +""" + +from __future__ import annotations + +import os +import time +from pathlib import Path + +import soundfile as sf +import torch +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from vllm_omni import Omni + +REPO_ROOT = Path(__file__).resolve().parents[3] +DEFAULT_STAGE_CONFIGS_PATH = str(REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm2.yaml") +SAMPLE_RATE = 48_000 + + +def parse_args(): + parser = FlexibleArgumentParser(description="Offline VoxCPM2 native AR inference") + parser.add_argument( + "--model", + type=str, + default="openbmb/VoxCPM2", + help="VoxCPM2 model path or HuggingFace repo ID.", + ) + parser.add_argument( + "--text", + type=str, + default="This is a VoxCPM2 native AR synthesis example running on vLLM Omni.", + help="Text to synthesize.", + ) + parser.add_argument( + "--output-dir", + type=str, + default="output_audio", + help="Directory for output WAV files.", + ) + parser.add_argument( + "--stage-configs-path", + type=str, + default=DEFAULT_STAGE_CONFIGS_PATH, + help="Path to the stage config YAML file.", + ) + parser.add_argument( + "--reference-audio", + type=str, + default=None, + help="Path to reference audio for voice cloning (isolated ref mode).", + ) + parser.add_argument( + "--prompt-audio", + type=str, + default=None, + help="Path to prompt audio for continuation mode (requires --prompt-text).", + ) + parser.add_argument( + "--prompt-text", + type=str, + default=None, + help="Text matching --prompt-audio for continuation mode.", + ) + return parser.parse_args() + + +def extract_audio(multimodal_output: dict) -> torch.Tensor: + """Extract the final complete audio tensor from multimodal output. + + The output processor accumulates per-step full audio under ``audio`` + as a list. The last element is the complete waveform. + """ + audio = multimodal_output.get("audio") or multimodal_output.get("model_outputs") + if audio is None: + raise ValueError(f"No audio key in multimodal_output: {list(multimodal_output.keys())}") + + if isinstance(audio, list): + # Take the last valid tensor (most complete audio) + valid = [torch.as_tensor(a).float().cpu().reshape(-1) for a in audio if a is not None] + if not valid: + raise ValueError("Audio list is empty or all elements are None.") + return valid[-1] + + return torch.as_tensor(audio).float().cpu().reshape(-1) + + +def main(): + args = parse_args() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + engine = Omni( + model=args.model, + stage_configs_path=args.stage_configs_path, + ) + + additional: dict = {} + if args.reference_audio: + additional["reference_audio"] = args.reference_audio + if args.prompt_audio and args.prompt_text: + additional["prompt_audio"] = args.prompt_audio + additional["prompt_text"] = args.prompt_text + + prompt: dict = {"prompt": args.text} + if additional: + prompt["additional_information"] = additional + + print(f"Model : {args.model}") + print(f"Text : {args.text}") + if args.reference_audio: + print(f"Ref audio : {args.reference_audio}") + if args.prompt_audio: + print(f"Prompt audio: {args.prompt_audio}") + print(f"Prompt text : {args.prompt_text}") + print(f"Output dir : {output_dir}") + + t_start = time.perf_counter() + outputs = engine.generate([prompt]) + elapsed = time.perf_counter() - t_start + + # outputs[0].outputs[0].multimodal_output["audio"] is a list of tensors + request_output = outputs[0] + mm = request_output.outputs[0].multimodal_output + audio = extract_audio(mm) + + duration = audio.numel() / SAMPLE_RATE + rtf = elapsed / duration if duration > 0 else float("inf") + + output_path = output_dir / "output.wav" + sf.write(str(output_path), audio.numpy(), SAMPLE_RATE, format="WAV") + + print(f"Saved : {output_path}") + print(f"Duration : {duration:.2f}s") + print(f"Inference : {elapsed:.2f}s") + print(f"RTF : {rtf:.3f}") + + +if __name__ == "__main__": + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + main() diff --git a/tests/e2e/offline_inference/test_voxcpm2.py b/tests/e2e/offline_inference/test_voxcpm2.py new file mode 100644 index 00000000000..7e17c6a3691 --- /dev/null +++ b/tests/e2e/offline_inference/test_voxcpm2.py @@ -0,0 +1,101 @@ +"""E2E test for VoxCPM2 native AR offline inference.""" + +import os + +import pytest +import torch + +from tests.utils import hardware_test + +VOXCPM2_MODEL = "openbmb/VoxCPM2" +STAGE_CONFIG = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "..", + "vllm_omni", + "model_executor", + "stage_configs", + "voxcpm2.yaml", +) +SAMPLE_RATE = 48000 + + +@pytest.fixture(scope="module") +def voxcpm2_engine(): + """Create VoxCPM2 engine for testing.""" + from vllm_omni import Omni + + engine = Omni(model=VOXCPM2_MODEL, stage_configs_path=STAGE_CONFIG) + yield engine + + +def _extract_audio(multimodal_output: dict) -> torch.Tensor: + """Extract the final complete audio tensor from multimodal output.""" + assert isinstance(multimodal_output, dict), f"Expected dict, got {type(multimodal_output)}" + + # Output processor accumulates per-step full audio under "audio". + audio = multimodal_output.get("audio") or multimodal_output.get("model_outputs") + assert audio is not None, f"No audio key, got {list(multimodal_output.keys())}" + + if isinstance(audio, list): + valid = [x for x in audio if isinstance(x, torch.Tensor) and x.numel() > 100] + assert valid, "No valid audio tensors in output list" + audio = valid[-1] + + assert isinstance(audio, torch.Tensor), f"Expected Tensor, got {type(audio)}" + return audio + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "L4"}, num_cards=1) +def test_voxcpm2_zero_shot_001(voxcpm2_engine): + """Test zero-shot TTS produces valid audio output.""" + outputs = voxcpm2_engine.generate([{"prompt": "Hello, this is a test."}]) + assert len(outputs) == 1 + + audio = _extract_audio(outputs[0].outputs[0].multimodal_output) + duration_s = audio.shape[0] / SAMPLE_RATE + assert 0.5 < duration_s < 30.0, f"Audio duration out of range: {duration_s:.2f}s" + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "L4"}, num_cards=1) +def test_voxcpm2_voice_clone_002(voxcpm2_engine): + """Test voice cloning with a reference audio file. + + Uses the example ``reference_speaker.wav`` bundled with the voxcpm + package. Skipped if the file is not present. + """ + # Try to locate a reference wav from the voxcpm package / env override + candidates = [] + env_path = os.environ.get("VLLM_OMNI_VOXCPM_CODE_PATH") + if env_path: + candidates.append(os.path.join(env_path, "..", "examples", "reference_speaker.wav")) + try: + import voxcpm # noqa: F401 (only used to locate path) + + vox_dir = os.path.dirname(os.path.dirname(os.path.abspath(voxcpm.__file__))) + candidates.append(os.path.join(vox_dir, "examples", "reference_speaker.wav")) + except ImportError: + pass + + ref_path = next((p for p in candidates if p and os.path.exists(p)), None) + if ref_path is None: + pytest.skip("No reference audio available for voice clone test") + + outputs = voxcpm2_engine.generate( + [ + { + "prompt": "Hello, this is a voice clone demo.", + "additional_information": {"reference_audio": ref_path}, + } + ] + ) + assert len(outputs) == 1 + + audio = _extract_audio(outputs[0].outputs[0].multimodal_output) + duration_s = audio.shape[0] / SAMPLE_RATE + assert 0.5 < duration_s < 30.0, f"Audio duration out of range: {duration_s:.2f}s" diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index b6637892624..e887b4799b3 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -20,6 +20,7 @@ _ARCH_TO_MODEL_TYPE: dict[str, str] = { "CosyVoice3Model": "cosyvoice3", "OmniVoiceModel": "omnivoice", + "VoxCPM2TalkerForConditionalGeneration": "voxcpm2", } # Maps model architecture names to tokenizer subfolder paths within HF repos. @@ -40,6 +41,7 @@ def _register_omni_hf_configs() -> None: from vllm_omni.model_executor.models.voxtral_tts.configuration_voxtral_tts import ( VoxtralTTSConfig, ) + from vllm_omni.transformers_utils.configs.voxcpm2 import VoxCPM2Config except Exception as exc: # pragma: no cover - best-effort optional registration logger.warning("Skipping omni HF config registration due to import error: %s", exc) return @@ -57,6 +59,7 @@ def _register_omni_hf_configs() -> None: ("cosyvoice3", CosyVoice3Config), ("omnivoice", OmniVoiceConfig), ("voxtral_tts", VoxtralTTSConfig), + ("voxcpm2", VoxCPM2Config), ]: try: AutoConfig.register(model_type, config_cls) diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index 3b51f20023d..08940880056 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -145,6 +145,12 @@ "fish_speech_dac_decoder", "FishSpeechDACDecoder", ), + ## VoxCPM2 + "VoxCPM2TalkerForConditionalGeneration": ( + "voxcpm2", + "voxcpm2_talker", + "VoxCPM2TalkerForConditionalGeneration", + ), ## Voxtral TTS "VoxtralTTSForConditionalGeneration": ( "voxtral_tts", diff --git a/vllm_omni/model_executor/models/voxcpm2/__init__.py b/vllm_omni/model_executor/models/voxcpm2/__init__.py new file mode 100644 index 00000000000..77bd8dfb518 --- /dev/null +++ b/vllm_omni/model_executor/models/voxcpm2/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .voxcpm2_talker import VoxCPM2TalkerForConditionalGeneration + +__all__ = ["VoxCPM2TalkerForConditionalGeneration"] diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_import_utils.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_import_utils.py new file mode 100644 index 00000000000..231a51bbca4 --- /dev/null +++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_import_utils.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Dynamic import utilities for the native VoxCPM2 package. + +Supports three discovery modes (first match wins): +1. ``VLLM_OMNI_VOXCPM_CODE_PATH`` env var (explicit source tree) +2. Sibling ``../VoxCPM/src`` relative to the vllm-omni repo root +3. pip-installed ``voxcpm`` package (>= 2.0) +""" + +from __future__ import annotations + +import importlib +import os +import sys +from pathlib import Path +from typing import Any + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def _iter_voxcpm2_src_candidates() -> list[Path]: + """Yield candidate source directories for VoxCPM2.""" + candidates: list[Path] = [] + env_path = os.environ.get("VLLM_OMNI_VOXCPM_CODE_PATH") + if env_path: + candidates.append(Path(env_path).expanduser()) + + repo_root = Path(__file__).resolve().parents[4] + candidates.append(repo_root.parent / "VoxCPM" / "src") + + seen: set[str] = set() + unique: list[Path] = [] + for c in candidates: + key = str(c) + if key not in seen: + seen.add(key) + unique.append(c) + return unique + + +def _prepend_src(candidate: Path) -> None: + candidate_str = str(candidate) + if candidate_str not in sys.path: + sys.path.insert(0, candidate_str) + + +def _import_voxcpm2_attrs(module_name: str, *attr_names: str) -> tuple[Any, ...]: + """Import attributes from the voxcpm package, trying source tree first.""" + last_exc: ImportError | None = None + + for candidate in _iter_voxcpm2_src_candidates(): + if not candidate.exists(): + continue + _prepend_src(candidate) + try: + mod = importlib.import_module(module_name) + return tuple(getattr(mod, name) for name in attr_names) + except (ImportError, AttributeError) as exc: + last_exc = ImportError(str(exc)) + continue + + try: + mod = importlib.import_module(module_name) + return tuple(getattr(mod, name) for name in attr_names) + except (ImportError, AttributeError) as exc: + last_exc = ImportError(str(exc)) + + raise ImportError( + f"Could not import {attr_names} from {module_name}. " + f"Install voxcpm>=2.0: pip install voxcpm. " + f"Or set VLLM_OMNI_VOXCPM_CODE_PATH to the VoxCPM source tree. " + f"Last error: {last_exc}" + ) + + +def import_voxcpm2_core(): + """Import the VoxCPM core class used to load the native TTS model.""" + (VoxCPM,) = _import_voxcpm2_attrs("voxcpm.core", "VoxCPM") + return VoxCPM diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py new file mode 100644 index 00000000000..ade68b673b7 --- /dev/null +++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py @@ -0,0 +1,569 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""VoxCPM2 native AR talker — uses native MiniCPM4 base_lm directly. + +Uses native VoxCPM2 modules (no PagedAttention, manual KV cache). +Each AR decode step: + feat_encoder → base_lm → FSQ → residual_lm → LocDiT → stop + +TODO(PagedAttention): The base_lm is a MiniCPM4 variant (GQA + LongRoPE, +use_mup=False). vllm's MiniCPMModel already supports the architecture +(LongRoPE via Phi3LongRoPEScaledRotaryEmbedding, muP via config), but +two issues block replacing the native base_lm with a vllm MiniCPM4Model: + 1. Per-request state isolation — residual_lm and LocDiT diffusion use + shared native KV caches; concurrent requests clobber each other. + Fix: save/restore residual_lm cache per request, or pool N instances. + 2. Streaming audio — make_omni_output re-decodes all patches each step. + Fix: sliding-window VAE decode (decode_pad pattern from nanovllm). +""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn as nn +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.models.minicpm import MiniCPMModel +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + maybe_prefix, +) +from vllm.sequence import IntermediateTensors + +from vllm_omni.model_executor.models.output_templates import OmniOutput + +from .voxcpm2_import_utils import import_voxcpm2_core + +logger = init_logger(__name__) + + +class VoxCPM2TalkerForConditionalGeneration(nn.Module): + """VoxCPM2 talker using native MiniCPM4 base_lm. + + Loads the full VoxCPM2 model natively and decomposes the AR loop: + each vllm decode step runs one iteration of the native generate loop. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + self.config = vllm_config.model_config.hf_config + + # Flags for OmniGPUModelRunner + self.have_multimodal_outputs = True + self.has_preprocess = True + self.has_postprocess = True + self._accumulated_patches: list[torch.Tensor] = [] + + # vllm MiniCPMModel scaffold — needed for warmup/profiling/KV cache + # sizing. Not used for actual computation (native modules are used). + self.model = MiniCPMModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors + + # Placeholder — actual native model loaded in load_weights + self._tts: nn.Module | None = None + self._device = "cuda" + self._side_dtype = torch.bfloat16 + + # Config values + self._patch_size = getattr(self.config, "patch_size", 4) + self._feat_dim = getattr(self.config, "feat_dim", 64) + self._inference_timesteps = 10 + self._cfg_value = 2.0 + + # TODO: implement sliding-window VAE decode (nanovllm pattern) + # for O(1) per-step streaming. Current impl re-decodes all patches. + + @property + def tts(self) -> nn.Module: + assert self._tts is not None, "Model not loaded yet" + return self._tts + + # -------------------- vllm hooks -------------------- + + def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: + """Embed input IDs using native base_lm with scale_emb.""" + embeds = self.tts.base_lm.embed_tokens(input_ids) + return embeds * self.tts.config.lm_config.scale_emb + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: Any, + ) -> torch.Tensor | IntermediateTensors: + """Full VoxCPM2 AR step: base_lm → FSQ → residual_lm → diffusion.""" + # Always run scaffold model to keep FlashInfer/attention happy + model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) + if isinstance(model_output, IntermediateTensors): + return model_output + scaffold_hidden = model_output + if isinstance(scaffold_hidden, tuple): + scaffold_hidden = scaffold_hidden[0] + + # Real computation: use native modules + has_infos = bool(getattr(self, "_current_step_infos", None)) + is_prefill = scaffold_hidden.shape[0] > 1 + + if is_prefill and has_infos: + self._forward_prefill(inputs_embeds, scaffold_hidden.device) + # Return scaffold output (right shape for engine) — our side + # computation results are stored in instance state + return scaffold_hidden + + if not is_prefill and hasattr(self, "_prev_feat_embed"): + self._forward_decode(inputs_embeds, scaffold_hidden.device) + return scaffold_hidden + + return scaffold_hidden + + def _build_prefill_inputs(self, text: str, dev: Any): + """Build text_token / audio_feat / masks like native _generate_with_prompt_cache. + + Returns a dict with keys: text_token, audio_feat, text_mask, audio_mask, + prefix_feat_cond. Handles zero-shot, reference (voice clone), continuation, + and ref_continuation modes. + """ + tts = self.tts + dtype = self._side_dtype + cache = getattr(self, "_prompt_cache", None) + mode = cache.get("mode", "continuation") if cache else "zero_shot" + + if cache is not None and mode in ("continuation", "ref_continuation"): + full_text = cache.get("prompt_text", "") + text + else: + full_text = text + + text_token = torch.LongTensor(tts.text_tokenizer(full_text)) + text_token = torch.cat( + [ + text_token, + torch.tensor([tts.audio_start_token], dtype=torch.int32, device=text_token.device), + ], + dim=-1, + ) + text_length = text_token.shape[0] + latent_dim = tts.audio_vae.latent_dim + patch_size = tts.patch_size + + if mode in ("zero_shot", "continuation"): + prompt_audio_feat = ( + cache["audio_feat"] if cache else torch.empty((0, patch_size, latent_dim), dtype=torch.float32) + ) + audio_length = prompt_audio_feat.size(0) + text_pad_token = torch.zeros(audio_length, dtype=torch.int32) + text_pad_feat = torch.zeros((text_length, patch_size, latent_dim), dtype=torch.float32) + text_token = torch.cat([text_token, text_pad_token]) + audio_feat = torch.cat([text_pad_feat, prompt_audio_feat], dim=0) + text_mask = torch.cat( + [ + torch.ones(text_length, dtype=torch.int32), + torch.zeros(audio_length, dtype=torch.int32), + ] + ) + audio_mask = torch.cat( + [ + torch.zeros(text_length, dtype=torch.int32), + torch.ones(audio_length, dtype=torch.int32), + ] + ) + elif mode == "reference": + ref_audio_feat = cache["ref_audio_feat"] + ref_tokens, ref_feats, ref_t_mask, ref_a_mask = tts._make_ref_prefix(ref_audio_feat, text_token.device) + text_pad_feat = torch.zeros((text_length, patch_size, latent_dim), dtype=torch.float32) + text_token = torch.cat([ref_tokens.cpu(), text_token]) + audio_feat = torch.cat([ref_feats.cpu(), text_pad_feat], dim=0) + text_mask = torch.cat([ref_t_mask.cpu(), torch.ones(text_length, dtype=torch.int32)]) + audio_mask = torch.cat([ref_a_mask.cpu(), torch.zeros(text_length, dtype=torch.int32)]) + else: + # ref_continuation + ref_audio_feat = cache["ref_audio_feat"] + prompt_audio_feat = cache["audio_feat"] + prompt_audio_length = prompt_audio_feat.size(0) + ref_tokens, ref_feats, ref_t_mask, ref_a_mask = tts._make_ref_prefix(ref_audio_feat, text_token.device) + prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32) + text_pad_feat = torch.zeros((text_length, patch_size, latent_dim), dtype=torch.float32) + text_token = torch.cat([ref_tokens.cpu(), text_token, prompt_pad_token]) + audio_feat = torch.cat([ref_feats.cpu(), text_pad_feat, prompt_audio_feat], dim=0) + text_mask = torch.cat( + [ + ref_t_mask.cpu(), + torch.ones(text_length, dtype=torch.int32), + torch.zeros(prompt_audio_length, dtype=torch.int32), + ] + ) + audio_mask = torch.cat( + [ + ref_a_mask.cpu(), + torch.zeros(text_length, dtype=torch.int32), + torch.ones(prompt_audio_length, dtype=torch.int32), + ] + ) + + return { + "text_token": text_token.unsqueeze(0).to(dev), + "audio_feat": audio_feat.unsqueeze(0).to(dev).to(dtype), + "text_mask": text_mask.unsqueeze(0).to(dev), + "audio_mask": audio_mask.unsqueeze(0).to(dev), + } + + def _forward_prefill(self, inputs_embeds: torch.Tensor, dev: Any) -> torch.Tensor: + """Prefill: build combined embeds, run base_lm + residual_lm + first diffusion. + + Uses the same path as native ``VoxCPM2Model._inference`` so zero-shot, + voice cloning (reference), continuation, and ref_continuation modes + all share the same code. + """ + tts = self.tts + dtype = self._side_dtype + text = getattr(self, "_prefill_text", None) + if text is None: + # Fallback (should not hit at runtime; preprocess sets this) + text = "" + + inputs = self._build_prefill_inputs(text, dev) + text_token = inputs["text_token"] + feat = inputs["audio_feat"] + text_mask = inputs["text_mask"] + feat_mask = inputs["audio_mask"] + + # Compose combined_embed exactly like native _inference + feat_embed = tts.feat_encoder(feat) + feat_embed = tts.enc_to_lm_proj(feat_embed) + scale_emb = tts.config.lm_config.scale_emb if tts.config.lm_config.use_mup else 1.0 + text_embed = tts.base_lm.embed_tokens(text_token) * scale_emb + combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed + + # last audio patch becomes initial prefix_feat_cond (zeros for zero-shot, + # last reference/prompt patch for voice clone / continuation) + prefix_feat_cond = ( + feat[:, -1, ...] + if feat.shape[1] > 0 + else torch.zeros(1, tts.patch_size, tts.feat_dim, device=dev, dtype=dtype) + ) + + # Base LM prefill + tts.base_lm.setup_cache(1, 4096, dev, dtype) + enc_out, enc_kv = tts.base_lm(inputs_embeds=combined_embed, is_causal=True) + tts.base_lm.kv_cache.fill_caches(enc_kv) + + # FSQ: identity on text positions, quantized on audio positions + enc_outputs = tts.fsq_layer(enc_out) * feat_mask.unsqueeze(-1) + enc_out * text_mask.unsqueeze(-1) + lm_hidden = enc_outputs[:, -1, :] # [1, H] + + logger.info( + "PREFILL: enc shape=%s last_norm=%.4f", + enc_outputs.shape, + lm_hidden.norm().item(), + ) + + # Residual LM prefill + tts.residual_lm.setup_cache(1, 4096, dev, dtype) + residual_input = tts.fusion_concat_proj(torch.cat([enc_outputs, feat_mask.unsqueeze(-1) * feat_embed], dim=-1)) + res_out, res_kv = tts.residual_lm(inputs_embeds=residual_input, is_causal=True) + tts.residual_lm.kv_cache.fill_caches(res_kv) + residual_hidden = res_out[:, -1, :] # [1, H] + + # Precompute stop logits for first compute_logits call + stop_logits = tts.stop_head(tts.stop_actn(tts.stop_proj(lm_hidden))) + self._precomputed_stop_logits = stop_logits.detach() + logger.info("PREFILL stop: %s", stop_logits[0].tolist()) + + # First diffusion step + dit_h = torch.cat( + [ + tts.lm_to_dit_proj(lm_hidden), + tts.res_to_dit_proj(residual_hidden), + ], + dim=-1, + ) + pred_feat = tts.feat_decoder( + mu=dit_h, + patch_size=tts.patch_size, + cond=prefix_feat_cond.transpose(1, 2).contiguous(), + n_timesteps=self._inference_timesteps, + cfg_value=self._cfg_value, + ).transpose(1, 2) # [1, P, D] + + with torch.no_grad(): + curr_embed = tts.enc_to_lm_proj(tts.feat_encoder(pred_feat.unsqueeze(1))).squeeze(1) + + # Store state for decode steps + self._curr_embed_for_next = curr_embed.detach() + self._prev_feat_embed = curr_embed.detach() + self._curr_prefix_feat_cond = pred_feat[0].detach() + self._last_audio_patch = pred_feat.reshape(1, -1).detach().cpu().float() + + logger.info( + "PREFILL patch: norm=%.4f first3=%s", + pred_feat.norm().item(), + pred_feat[0, 0, :3].tolist(), + ) + + return lm_hidden.to(dtype) + + def _forward_decode(self, inputs_embeds: torch.Tensor | None, dev: Any) -> torch.Tensor: + """Decode step: base_lm → FSQ → residual_lm → diffusion.""" + tts = self.tts + dtype = self._side_dtype + + # 1. Base LM step with curr_embed from previous diffusion + curr_embed = self._curr_embed_for_next.to(dev, dtype=dtype) + if curr_embed.ndim == 2: + curr_embed_3d = curr_embed.unsqueeze(0) # [1, 1, H] + else: + curr_embed_3d = curr_embed + + step_pos = torch.tensor([tts.base_lm.kv_cache.step()], device=dev) + new_hidden = tts.base_lm.forward_step(curr_embed_3d[:, 0, :], step_pos).clone() + + # 2. FSQ + new_lm_hidden = tts.fsq_layer(new_hidden) + if new_lm_hidden.ndim == 1: + new_lm_hidden = new_lm_hidden.unsqueeze(0) + + # 3. Residual LM step + prev_fe = self._prev_feat_embed.to(dtype) + if prev_fe.ndim == 1: + prev_fe = prev_fe.unsqueeze(0) + res_input = tts.fusion_concat_proj(torch.cat([new_lm_hidden, prev_fe], dim=-1)) + res_step_pos = torch.tensor([tts.residual_lm.kv_cache.step()], device=dev) + new_res_hidden = tts.residual_lm.forward_step(res_input, res_step_pos).clone() + if new_res_hidden.ndim == 1: + new_res_hidden = new_res_hidden.unsqueeze(0) + + # 4. Diffusion + p = self._patch_size + pfc = self._curr_prefix_feat_cond.to(dtype).unsqueeze(0) + + dit_h = torch.cat( + [ + tts.lm_to_dit_proj(new_lm_hidden), + tts.res_to_dit_proj(new_res_hidden), + ], + dim=-1, + ) + pred_feat = tts.feat_decoder( + mu=dit_h, + patch_size=p, + cond=pfc.transpose(1, 2).contiguous(), + n_timesteps=self._inference_timesteps, + cfg_value=self._cfg_value, + ).transpose(1, 2) # [1, P, D] + + # 5. feat_encoder → curr_embed + with torch.no_grad(): + curr_embed = tts.enc_to_lm_proj(tts.feat_encoder(pred_feat.unsqueeze(1))).squeeze(1) + + # 6. Stop logits + stop_logits = tts.stop_head(tts.stop_actn(tts.stop_proj(new_lm_hidden))) + self._precomputed_stop_logits = stop_logits.detach() + + # 7. Store state + self._curr_embed_for_next = curr_embed.detach() + self._prev_feat_embed = curr_embed.detach() + self._curr_prefix_feat_cond = pred_feat[0].detach() + self._last_audio_patch = pred_feat.reshape(1, -1).detach().cpu().float() + + return new_lm_hidden[-1:].detach() + + def compute_logits( + self, + hidden_states: torch.Tensor | OmniOutput, + sampling_metadata: Any = None, + ) -> torch.Tensor | None: + if isinstance(hidden_states, OmniOutput): + hidden_states = hidden_states.text_hidden_states + if hidden_states is None: + return None + + precomputed = getattr(self, "_precomputed_stop_logits", None) + if precomputed is not None: + self._precomputed_stop_logits = None + raw_logits = precomputed[: hidden_states.shape[0]] + else: + # Fallback for warmup + bsz = hidden_states.shape[0] + raw_logits = torch.zeros(bsz, 2, device=hidden_states.device) + raw_logits[:, 0] = 1.0 # continue + + bsz = raw_logits.shape[0] + full_logits = torch.full( + (bsz, self.config.vocab_size), + float("-inf"), + device=raw_logits.device, + dtype=raw_logits.dtype, + ) + full_logits[:, 0] = raw_logits[:, 0] # continue + full_logits[:, 1] = raw_logits[:, 1] # stop + return full_logits + + # -------------------- Omni output -------------------- + + def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: Any) -> OmniOutput: + if isinstance(model_outputs, OmniOutput): + return model_outputs + + hidden = model_outputs + patch = getattr(self, "_last_audio_patch", None) + mm: dict[str, Any] = {} + + if patch is not None: + self._last_audio_patch = None + self._accumulated_patches.append(patch.clone()) + + # Decode all accumulated patches → full audio waveform. + # TODO: implement sliding-window VAE decode (nanovllm pattern) + # for O(1) per-step streaming instead of O(N) re-decode. + if self._accumulated_patches: + all_p = torch.cat(self._accumulated_patches, dim=0) + d = self._feat_dim + from einops import rearrange + + feat = rearrange(all_p.float().reshape(1, -1, d), "b t d -> b d t") + with torch.no_grad(): + audio = self.tts.audio_vae.decode(feat.to(self._device)).reshape(-1).detach().cpu().float() + + mm["model_outputs"] = [audio] + mm["sr"] = [torch.tensor(48000, dtype=torch.int32)] + + return OmniOutput( + text_hidden_states=hidden, + multimodal_outputs=mm, + ) + + # -------------------- preprocess / postprocess -------------------- + + def preprocess( + self, + input_ids: torch.Tensor, + input_embeds: torch.Tensor | None, + **info_dict: Any, + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: + additional_information = info_dict.get("additional_information") + if isinstance(additional_information, dict): + merged = {k: v for k, v in info_dict.items() if k != "additional_information"} + for k, v in additional_information.items(): + merged.setdefault(k, v) + info_dict = merged + + span_len = int(input_ids.shape[0]) + dev = input_ids.device + + if span_len > 1: + # ---- Prefill ---- + # Decode the text from input_ids for native-matching tokenization. + # Speech API tokenizes with BOS; we use the detokenized string so + # native's ``text_tokenizer`` produces the exact same tokens as + # ``generate()``. + ids = input_ids.tolist() + if ids and ids[0] == self.config.bos_token_id: + ids = ids[1:] + text = self.tts.text_tokenizer.tokenizer.decode(ids, skip_special_tokens=True) + self._prefill_text = text + + # Voice clone / continuation: build prompt cache from info_dict. + ref_audio = info_dict.get("reference_audio") or info_dict.get("ref_audio") + prompt_audio = info_dict.get("prompt_audio") + prompt_text = info_dict.get("prompt_text") + if isinstance(ref_audio, list): + ref_audio = ref_audio[0] if ref_audio else None + if isinstance(prompt_audio, list): + prompt_audio = prompt_audio[0] if prompt_audio else None + if isinstance(prompt_text, list): + prompt_text = prompt_text[0] if prompt_text else None + + self._prompt_cache = None + if ref_audio or (prompt_audio and prompt_text): + try: + self._prompt_cache = self.tts.build_prompt_cache( + prompt_text=prompt_text, + prompt_wav_path=prompt_audio, + reference_wav_path=ref_audio, + ) + except Exception as e: + logger.warning("build_prompt_cache failed: %s; falling back to zero-shot", e) + self._prompt_cache = None + + # Reset per-request state (fresh generation) + self._accumulated_patches = [] + if hasattr(self, "_prev_feat_embed"): + del self._prev_feat_embed + if hasattr(self, "_curr_embed_for_next"): + del self._curr_embed_for_next + + # Store info for forward + self._current_step_infos = [{"is_prefill": True}] + + # The scaffold model still needs embeddings sized to span_len for + # its warmup/attention bookkeeping. Native modules use the full + # (potentially longer) sequence internally. Pass zeros — scaffold + # output is discarded. + embeds = torch.zeros( + span_len, + self.config.hidden_size, + device=dev, + dtype=self._side_dtype, + ) + + return input_ids, embeds, {} + + # ---- Decode ---- + curr_embed = getattr(self, "_curr_embed_for_next", None) + if curr_embed is not None: + inputs_embeds = curr_embed.to(dev, dtype=self._side_dtype).reshape(1, -1) + else: + inputs_embeds = torch.zeros( + 1, + self.config.hidden_size, + device=dev, + dtype=self._side_dtype, + ) + + self._current_step_infos = [{}] + return input_ids, inputs_embeds, {} + + def postprocess(self, hidden_states: torch.Tensor, **info: Any) -> dict[str, Any]: + return {} + + # -------------------- Weight loading -------------------- + + # Weight mapping for vllm scaffold + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"base_lm.": "model."}) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load scaffold weights via vllm + native model for computation.""" + + # Filter: only pass base_lm weights to the scaffold + def _base_lm_only(ws): + for name, tensor in ws: + if name.startswith("base_lm."): + yield name, tensor + + loader = AutoWeightsLoader(self) + loaded = loader.load_weights(_base_lm_only(weights), mapper=self.hf_to_vllm_mapper) + + # Load the full native model for actual computation + model_path = self.vllm_config.model_config.model + VoxCPM = import_voxcpm2_core() + native = VoxCPM.from_pretrained(model_path, load_denoiser=False, optimize=False) + self._tts = native.tts_model.to("cuda") + self._side_dtype = self._tts.fusion_concat_proj.weight.dtype + self._device = "cuda" + + self._patch_size = self._tts.patch_size + self._feat_dim = self._tts.feat_dim + + logger.info( + "Loaded native VoxCPM2 (patch_size=%d, feat_dim=%d, dtype=%s)", + self._patch_size, + self._feat_dim, + self._side_dtype, + ) + return loaded diff --git a/vllm_omni/model_executor/stage_configs/voxcpm2.yaml b/vllm_omni/model_executor/stage_configs/voxcpm2.yaml new file mode 100644 index 00000000000..de15c88de4e --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/voxcpm2.yaml @@ -0,0 +1,36 @@ +# VoxCPM2 native AR single-stage pipeline. +# Uses native MiniCPM4 base_lm + native VAE decode in one stage. +# All computation (base_lm, residual_lm, diffusion, VAE) in forward(). +stage_args: + - stage_id: 0 + stage_type: llm + is_comprehension: true + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + dtype: bfloat16 + model_stage: latent_generator + model_arch: VoxCPM2TalkerForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: true + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.9 + distributed_executor_backend: "mp" + max_num_batched_tokens: 4096 + max_model_len: 4096 + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.0 + stop_token_ids: [1] + final_output: true + final_output_type: audio diff --git a/vllm_omni/transformers_utils/configs/__init__.py b/vllm_omni/transformers_utils/configs/__init__.py index 59b23f91490..5f957c2f6de 100644 --- a/vllm_omni/transformers_utils/configs/__init__.py +++ b/vllm_omni/transformers_utils/configs/__init__.py @@ -17,6 +17,7 @@ "FishSpeechConfig": "vllm_omni.transformers_utils.configs.fish_speech", "FishSpeechSlowARConfig": "vllm_omni.transformers_utils.configs.fish_speech", "FishSpeechFastARConfig": "vllm_omni.transformers_utils.configs.fish_speech", + "VoxCPM2Config": "vllm_omni.transformers_utils.configs.voxcpm2", } __all__ = [ @@ -27,6 +28,7 @@ "FishSpeechConfig", "FishSpeechSlowARConfig", "FishSpeechFastARConfig", + "VoxCPM2Config", ] @@ -47,3 +49,4 @@ def __dir__(): # run as soon as `vllm_omni.transformers_utils.configs` is imported. from vllm_omni.transformers_utils.configs import fish_speech as _fish_speech # noqa: F401, E402 from vllm_omni.transformers_utils.configs import mammoth_moda2 as _mammoth_moda2 # noqa: F401, E402 +from vllm_omni.transformers_utils.configs import voxcpm2 as _voxcpm2 # noqa: F401, E402 diff --git a/vllm_omni/transformers_utils/configs/voxcpm2.py b/vllm_omni/transformers_utils/configs/voxcpm2.py new file mode 100644 index 00000000000..c625284bd67 --- /dev/null +++ b/vllm_omni/transformers_utils/configs/voxcpm2.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class VoxCPM2Config(PretrainedConfig): + """Configuration for VoxCPM2 native AR integration. + + The HuggingFace checkpoint stores LM parameters inside a nested + ``lm_config`` dict. This class hoists them to top-level attributes + so that vllm's ``MiniCPMModel`` can consume them directly. + + vllm's MiniCPM **always** applies muP scaling (scale_emb, scale_depth, + dim_model_base). VoxCPM2 was trained with ``use_mup=false``, so we + neutralise the scalings: + * ``scale_emb = 1.0`` + * ``scale_depth = sqrt(num_hidden_layers)`` (cancels the division) + * ``dim_model_base = hidden_size`` (makes scale_width = 1.0) + """ + + model_type = "voxcpm2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + # -- top-level VoxCPM2 params -- + architecture: str = "voxcpm2", + lm_config: dict | None = None, + encoder_config: dict | None = None, + dit_config: dict | None = None, + audio_vae_config: dict | None = None, + patch_size: int = 4, + feat_dim: int = 64, + residual_lm_num_layers: int = 8, + residual_lm_no_rope: bool = True, + scalar_quantization_latent_dim: int = 512, + scalar_quantization_scale: int = 9, + max_length: int = 8192, + device: str = "cuda", + dtype: str = "bfloat16", + # -- LM defaults (overridden by lm_config if present) -- + bos_token_id: int = 1, + eos_token_id: int = 2, + vocab_size: int = 73448, + hidden_size: int = 2048, + intermediate_size: int = 6144, + max_position_embeddings: int = 32768, + num_attention_heads: int = 16, + num_hidden_layers: int = 28, + num_key_value_heads: int = 2, + rms_norm_eps: float = 1e-5, + rope_theta: float = 10000.0, + rope_scaling: dict | None = None, + **kwargs, + ): + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + self.architecture = architecture + + # -- VoxCPM2-specific fields -- + self.lm_config = lm_config or {} + self.encoder_config = encoder_config or {} + self.dit_config = dit_config or {} + self.audio_vae_config = audio_vae_config or {} + self.patch_size = patch_size + self.feat_dim = feat_dim + self.residual_lm_num_layers = residual_lm_num_layers + self.residual_lm_no_rope = residual_lm_no_rope + self.scalar_quantization_latent_dim = scalar_quantization_latent_dim + self.scalar_quantization_scale = scalar_quantization_scale + self.max_length = max_length + self.device = device + self.dtype = dtype + + # -- Hoist LM parameters to top-level for MiniCPMModel -- + lm = self.lm_config + self.vocab_size = lm.get("vocab_size", vocab_size) + self.hidden_size = lm.get("hidden_size", hidden_size) + self.intermediate_size = lm.get("intermediate_size", intermediate_size) + self.max_position_embeddings = lm.get("max_position_embeddings", max_position_embeddings) + self.num_attention_heads = lm.get("num_attention_heads", num_attention_heads) + self.num_hidden_layers = lm.get("num_hidden_layers", num_hidden_layers) + self.num_key_value_heads = lm.get("num_key_value_heads", num_key_value_heads) + self.rms_norm_eps = lm.get("rms_norm_eps", rms_norm_eps) + self.rope_theta = lm.get("rope_theta", rope_theta) + + # MiniCPM-specific: kv_channels overrides head_dim when set. + kv_channels = lm.get("kv_channels") + if kv_channels is not None: + self.head_dim = kv_channels + else: + self.head_dim = self.hidden_size // self.num_attention_heads + + # MiniCPM requires hidden_act; VoxCPM2 uses SiLU. + self.hidden_act = "silu" + self.hidden_act_param = 0.0 + self.tie_word_embeddings = False + self.num_experts = 0 + + # -- muP scaling -- + # Native VoxCPM2 MiniCPM gates scale_depth behind use_mup: + # use_mup=True → residual += h * (scale_depth / sqrt(N)) + # use_mup=False → residual += h (plain add, no scaling) + # But vllm's MiniCPMModel ALWAYS applies scale_depth / sqrt(N). + # Native applies scale_emb externally; vllm applies it in embed_input_ids. + use_mup = lm.get("use_mup", False) + self.scale_emb = lm.get("scale_emb", 1.0) + if use_mup: + self.scale_depth = lm.get("scale_depth", 1.0) + self.dim_model_base = lm.get("dim_model_base", self.hidden_size) + else: + # Neutralize: scale_depth/sqrt(N) = 1.0, scale_width = 1.0 + self.scale_depth = math.sqrt(self.num_hidden_layers) + self.dim_model_base = self.hidden_size + + # -- RoPE scaling (longrope) -- + raw_rope = lm.get("rope_scaling", rope_scaling) + if raw_rope is not None: + self.rope_scaling = dict(raw_rope) + # HF expects "rope_type" not "type" + if "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling.pop("type") + # longrope requires "factor" (used by HF validation) + if "factor" not in self.rope_scaling: + self.rope_scaling["factor"] = 1.0 + rope_config_validation(self) + + # vllm's MiniCPMAttention reads config.rope_parameters (a dict + # with rope_type, theta, scaling factors, etc.). HF transformers + # only auto-computes this for known model_types; for custom + # types we must build it manually. + if not getattr(self, "rope_parameters", None): + rp = dict(self.rope_scaling) + rp["rope_theta"] = self.rope_theta + self.rope_parameters = rp + else: + self.rope_scaling = None + + def get_text_config(self, **kwargs): + """Return self as the text config — LM attributes are top-level.""" + return self + + +AutoConfig.register("voxcpm2", VoxCPM2Config) + +__all__ = ["VoxCPM2Config"]