diff --git a/.buildkite/test-ready.yml b/.buildkite/test-ready.yml index 2f749f0ee9..68f8e61528 100644 --- a/.buildkite/test-ready.yml +++ b/.buildkite/test-ready.yml @@ -295,6 +295,31 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" + - label: "VoxCPM E2E Test" + timeout_in_minutes: 20 + depends_on: upload-ready-pipeline + commands: + - | + timeout 20m bash -c ' + pip install voxcpm + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_WORKER_MULTIPROC_METHOD=spawn + pytest -s -v tests/e2e/offline_inference/test_voxcpm.py -m "core_model" --run-level "core_model" + ' + agents: + queue: "gpu_1_queue" + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + shm-size: "8gb" + environment: + - "HF_HOME=/fsx/hf_cache" + - "HF_TOKEN" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + - label: "VoxCPM2 Native AR E2E Test" timeout_in_minutes: 20 depends_on: upload-ready-pipeline diff --git a/benchmarks/voxcpm/README.md b/benchmarks/voxcpm/README.md new file mode 100644 index 0000000000..17f904101b --- /dev/null +++ b/benchmarks/voxcpm/README.md @@ -0,0 +1,119 @@ +# VoxCPM Benchmark + +This directory contains both: + +- online serving benchmark through the OpenAI-compatible `/v1/audio/speech` API +- offline benchmark for `Omni` / `AsyncOmni` +- full offline smoke-matrix orchestration + +Both benchmark paths report: + +- TTFP: time to first PCM packet +- E2E latency +- RTF: real-time factor (`e2e / audio_duration`) + +## Offline Benchmark + +Single offline benchmark run: + +```bash +python benchmarks/voxcpm/vllm_omni/bench_tts_offline.py \ + --model /path/to/voxcpm-model \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm.yaml \ + --text "This is a split-stage VoxCPM synthesis example running on vLLM Omni." \ + --warmup-runs 1 \ + --output-dir benchmarks/voxcpm/results/offline_single +``` + +Streaming offline benchmark: + +```bash +python benchmarks/voxcpm/vllm_omni/bench_tts_offline.py \ + --model /path/to/voxcpm-model \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \ + --text "This is a split-stage VoxCPM streaming example running on vLLM Omni." \ + --warmup-runs 1 \ + --output-dir benchmarks/voxcpm/results/offline_streaming +``` + +Full fixed offline matrix, equivalent to the old `examples/offline_inference/voxcpm/test.py`: + +```bash +python benchmarks/voxcpm/vllm_omni/run_offline_matrix.py \ + --model /path/to/voxcpm-model \ + --ref-audio /path/to/reference.wav \ + --ref-text "The exact transcript spoken in reference.wav." \ + --output-root benchmarks/voxcpm/results/offline_matrix +``` + +The full matrix covers both routes: + +- streaming: `voxcpm_async_chunk.yaml` +- sync: `voxcpm.yaml` + +And these six scenarios under each route: + +- warmup + single TTS +- warmup + single voice cloning +- warmup + batch TTS +- warmup + batch voice cloning +- cold single TTS +- cold single voice cloning + +`bench_tts_offline.py` itself no longer writes `summary.json` / `results.json`; it prints TTFP / RTF inline and saves generated WAV files only. The matrix runner keeps only per-case `run.log`. 
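+
+Batch inputs are also supported: `--txt-prompts` takes one synthesis text per line, and `--jsonl-prompts` takes one JSON object per line with a required `text` field and an optional `ref_audio` / `ref_text` pair (both must be given together). A minimal sketch (the file path, texts, and reference paths below are illustrative):
+
+```bash
+cat > /tmp/voxcpm_prompts.jsonl <<'EOF'
+{"text": "A plain text-to-speech line without a reference voice."}
+{"text": "A voice-cloning line with its own reference.", "ref_audio": "/path/to/reference.wav", "ref_text": "The exact transcript spoken in reference.wav."}
+EOF
+
+python benchmarks/voxcpm/vllm_omni/bench_tts_offline.py \
+    --model /path/to/voxcpm-model \
+    --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm.yaml \
+    --jsonl-prompts /tmp/voxcpm_prompts.jsonl \
+    --output-dir benchmarks/voxcpm/results/offline_jsonl
+```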
+ +## Start the Server + +Async-chunk: + +```bash +vllm serve /path/to/voxcpm-model \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \ + --trust-remote-code \ + --enforce-eager \ + --omni \ + --port 8091 +``` + +Non-streaming: + +```bash +vllm serve /path/to/voxcpm-model \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm.yaml \ + --trust-remote-code \ + --enforce-eager \ + --omni \ + --port 8091 +``` + +## Run the Benchmark + +```bash +python benchmarks/voxcpm/vllm_omni/bench_tts_serve.py \ + --host 127.0.0.1 \ + --port 8091 \ + --num-prompts 20 \ + --max-concurrency 1 \ + --result-dir /tmp/voxcpm_bench +``` + +Voice cloning benchmark: + +```bash +python benchmarks/voxcpm/vllm_omni/bench_tts_serve.py \ + --host 127.0.0.1 \ + --port 8091 \ + --num-prompts 10 \ + --max-concurrency 1 \ + --ref-audio https://example.com/reference.wav \ + --ref-text "The exact transcript spoken in the reference audio." \ + --result-dir /tmp/voxcpm_clone_bench +``` + +## Notes + +- The benchmark uses `stream=true` and `response_format=pcm` so TTFP is measured from the first audio packet. +- `RTF < 1.0` means the server generates audio faster than real time. +- For `voxcpm_async_chunk.yaml`, keep concurrency at `1`. This matches native VoxCPM streaming more closely. +- Do not benchmark concurrent online streaming on `voxcpm_async_chunk.yaml`; use `voxcpm.yaml` for multi-request throughput runs. +- For the offline matrix mode, `--ref-audio` and `--ref-text` are required because clone cases are part of the fixed coverage set. diff --git a/benchmarks/voxcpm/vllm_omni/bench_tts_offline.py b/benchmarks/voxcpm/vllm_omni/bench_tts_offline.py new file mode 100644 index 0000000000..a3bad3e692 --- /dev/null +++ b/benchmarks/voxcpm/vllm_omni/bench_tts_offline.py @@ -0,0 +1,890 @@ +"""Offline VoxCPM benchmark for vLLM Omni. + +Supports both: +- sync one-shot (Omni.generate) +- streaming (AsyncOmni.generate with async_chunk config) +- text-only synthesis +- voice cloning +- text/clone batch inputs from txt or jsonl +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import tempfile +import time +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from vllm_omni import AsyncOmni, Omni + +REPO_ROOT = Path(__file__).resolve().parents[3] +DEFAULT_STAGE_ASYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm_async_chunk.yaml" +DEFAULT_STAGE_SYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml" + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True, slots=True) +class PromptSpec: + text: str + label: str + ref_audio: str | None = None + ref_text: str | None = None + + +def _require_soundfile(): + try: + import soundfile as sf # type: ignore + except ModuleNotFoundError as exc: + raise RuntimeError( + "soundfile is required to write VoxCPM benchmark WAV outputs. 
Install it with: pip install soundfile" + ) from exc + return sf + + +def _build_prompt( + args, + *, + text: str, + ref_audio: str | None = None, + ref_text: str | None = None, + global_request_id: str | None = None, +) -> dict[str, Any]: + additional_information: dict[str, list[Any]] = { + "text": [text], + "cfg_value": [args.cfg_value], + "inference_timesteps": [args.inference_timesteps], + "min_len": [args.min_len], + "max_new_tokens": [args.max_new_tokens], + } + if args.streaming_prefix_len is not None: + additional_information["streaming_prefix_len"] = [args.streaming_prefix_len] + + if ref_audio: + additional_information["ref_audio"] = [ref_audio] + if ref_text: + additional_information["ref_text"] = [ref_text] + if global_request_id is not None: + additional_information["global_request_id"] = [global_request_id] + + return { + "prompt_token_ids": [1], + "additional_information": additional_information, + } + + +def _extract_audio_tensor(mm: dict[str, Any]) -> torch.Tensor: + audio = mm.get("audio", mm.get("model_outputs")) + if audio is None: + raise ValueError("No audio output found in multimodal output.") + if isinstance(audio, list): + parts = [torch.as_tensor(a).float().cpu().reshape(-1) for a in audio] + audio = torch.cat(parts, dim=-1) if parts else torch.zeros(0) + if not isinstance(audio, torch.Tensor): + audio = torch.as_tensor(audio) + return audio.float().cpu().reshape(-1) + + +def _extract_sample_rate(mm: dict[str, Any]) -> int: + sr_raw = mm.get("sr", 24000) + if isinstance(sr_raw, list) and sr_raw: + sr_raw = sr_raw[-1] + if hasattr(sr_raw, "item"): + return int(sr_raw.item()) + return int(sr_raw) + + +def _emit_offline_metrics( + *, + request_id: str, + elapsed_s: float, + first_audio_elapsed: float | None, + audio_duration_s: float, +) -> None: + metrics = { + "request_id": request_id, + "ttfp_ms": round(first_audio_elapsed * 1000.0, 3) if first_audio_elapsed is not None else None, + "audio_duration_s": round(audio_duration_s, 6), + "rtf": round(elapsed_s / audio_duration_s, 6) if audio_duration_s > 0 else None, + } + print(f"[OfflineMetrics] {metrics}") + + +def _write_audio_tensor(output_path: Path, audio_tensor: Any, sample_rate: int) -> None: + sf = _require_soundfile() + if isinstance(audio_tensor, torch.Tensor): + audio_np = audio_tensor.float().cpu().clamp(-1.0, 1.0).numpy() + else: + audio_np = torch.as_tensor(audio_tensor).float().cpu().clamp(-1.0, 1.0).numpy() + sf.write( + output_path, + audio_np, + sample_rate, + format="WAV", + subtype="PCM_16", + ) + + +def _save_wav(mm: dict[str, Any], output_dir: Path, request_id: str) -> Path: + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / f"output_{request_id}.wav" + _write_audio_tensor(output_path, _extract_audio_tensor(mm), _extract_sample_rate(mm)) + return output_path + + +def _iter_request_multimodal_outputs(request_output: Any): + outputs = getattr(request_output, "outputs", None) + if outputs: + for output in outputs: + mm = getattr(output, "multimodal_output", None) + if isinstance(mm, dict): + yield mm + + mm = getattr(request_output, "multimodal_output", None) + if isinstance(mm, dict): + yield mm + + +def _read_non_empty_lines(path: str) -> list[str]: + with open(path, encoding="utf-8") as f: + return [line.strip() for line in f if line.strip()] + + +def _load_prompt_specs(args) -> list[PromptSpec]: + specs: list[PromptSpec] = [] + + if args.txt_prompts is not None: + texts = _read_non_empty_lines(args.txt_prompts) + if not texts: + raise ValueError(f"No prompts found in 
{args.txt_prompts}") + for idx, text in enumerate(texts, start=1): + specs.append( + PromptSpec( + text=text, + label=f"item{idx:03d}", + ref_audio=args.ref_audio, + ref_text=args.ref_text, + ) + ) + return specs + + if args.jsonl_prompts is not None: + with open(args.jsonl_prompts, encoding="utf-8") as f: + for line_no, raw_line in enumerate(f, start=1): + line = raw_line.strip() + if not line: + continue + try: + item = json.loads(line) + except json.JSONDecodeError as exc: + raise ValueError(f"{args.jsonl_prompts}:{line_no} is not valid JSON: {exc}") from exc + if not isinstance(item, dict): + raise ValueError(f"{args.jsonl_prompts}:{line_no} must be a JSON object") + + text = item.get("text") + if not isinstance(text, str) or not text.strip(): + raise ValueError(f"{args.jsonl_prompts}:{line_no} requires non-empty string field 'text'") + + ref_audio = item.get("ref_audio", args.ref_audio) + ref_text = item.get("ref_text", args.ref_text) + if (ref_audio is None) != (ref_text is None): + raise ValueError( + f"{args.jsonl_prompts}:{line_no} must provide both 'ref_audio' and 'ref_text' together" + ) + + specs.append( + PromptSpec( + text=text.strip(), + label=f"item{len(specs) + 1:03d}", + ref_audio=ref_audio, + ref_text=ref_text, + ) + ) + + if not specs: + raise ValueError(f"No prompts found in {args.jsonl_prompts}") + return specs + + specs.append( + PromptSpec( + text=args.text, + label="item001", + ref_audio=args.ref_audio, + ref_text=args.ref_text, + ) + ) + return specs + + +def _build_prompt_for_spec(args, spec: PromptSpec, *, global_request_id: str | None = None) -> dict[str, Any]: + return _build_prompt( + args, + text=spec.text, + ref_audio=spec.ref_audio, + ref_text=spec.ref_text, + global_request_id=global_request_id, + ) + + +def _count_voice_clone_prompts(prompt_specs: list[PromptSpec]) -> int: + return sum(1 for spec in prompt_specs if spec.ref_audio is not None) + + +def _get_warmup_specs(prompt_specs: list[PromptSpec]) -> list[PromptSpec]: + return prompt_specs[:1] + + +def _extract_stream_finished(stage_output: Any) -> bool: + request_output = getattr(stage_output, "request_output", None) + request_finished = getattr(request_output, "finished", None) + if request_finished is not None: + return bool(request_finished) + return bool(getattr(stage_output, "finished", False)) + + +def _build_profiled_stage_config( + stage_configs_path: str, + profiler_dir: str, +) -> str: + stage_config_path = Path(stage_configs_path) + yaml_text = stage_config_path.read_text(encoding="utf-8") + injected_lines: list[str] = [] + injected_count = 0 + + for line in yaml_text.splitlines(): + injected_lines.append(line) + if line.strip() != "engine_args:": + continue + indent = line[: len(line) - len(line.lstrip())] + child_indent = indent + " " + grandchild_indent = child_indent + " " + injected_lines.extend( + [ + f"{child_indent}profiler_config:", + f'{grandchild_indent}profiler: "torch"', + f'{grandchild_indent}torch_profiler_dir: "{profiler_dir}"', + f"{grandchild_indent}torch_profiler_with_stack: true", + ] + ) + injected_count += 1 + + if injected_count == 0: + raise ValueError(f"No engine_args block found in stage config: {stage_configs_path}") + + tmp = tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + delete=False, + suffix=".yaml", + prefix=f"{stage_config_path.stem}_profile_", + ) + tmp.write("\n".join(injected_lines) + "\n") + tmp.close() + return tmp.name + + +def parse_args(): + parser = FlexibleArgumentParser( + description="Offline split-stage VoxCPM inference with 
vLLM Omni (auto sync/streaming by stage config)" + ) + parser.add_argument( + "--model", + type=str, + default=os.environ.get("VOXCPM_MODEL"), + help="Local VoxCPM model directory. Defaults to $VOXCPM_MODEL.", + ) + parser.add_argument( + "--text", + type=str, + default="This is a split-stage VoxCPM synthesis example running on vLLM Omni.", + help="Text to synthesize. Ignored when --txt-prompts or --jsonl-prompts is used.", + ) + parser.add_argument( + "--txt-prompts", + type=str, + default=None, + help="Path to a .txt file with one synthesis text per line.", + ) + parser.add_argument( + "--jsonl-prompts", + type=str, + default=None, + help=( + "Path to a .jsonl file. Each line must contain at least {'text': ...}; " + "clone rows can also set ref_audio/ref_text, and ref_text must be the " + "real transcript of ref_audio." + ), + ) + parser.add_argument( + "--ref-audio", + type=str, + default=None, + help=( + "Optional reference audio path for voice cloning. With --txt-prompts, " + "the same reference is applied to every line." + ), + ) + parser.add_argument( + "--ref-text", + type=str, + default=None, + help=( + "Real transcript of the reference audio. Placeholder text or mismatched " + "text will usually produce noisy/electronic clone audio." + ), + ) + parser.add_argument( + "--stage-configs-path", + type=str, + default=str(DEFAULT_STAGE_SYNC), + help="Stage config YAML path. Routing is selected only from this path.", + ) + parser.add_argument( + "--cfg-value", + type=float, + default=2.0, + help="Classifier-free guidance value for VoxCPM.", + ) + parser.add_argument( + "--inference-timesteps", + type=int, + default=10, + help="Number of inference timesteps.", + ) + parser.add_argument( + "--min-len", + type=int, + default=2, + help="Minimum generated token length.", + ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=4096, + help="Maximum generated token length.", + ) + parser.add_argument( + "--streaming-prefix-len", + type=int, + default=None, + help="VoxCPM streaming window (optional, streaming mode only).", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory for output WAV files.", + ) + parser.add_argument( + "--stage-init-timeout", + type=int, + default=600, + help="Stage initialization timeout in seconds.", + ) + parser.add_argument( + "--log-stats", + dest="log_stats", + action="store_true", + help="Enable vLLM Omni stats logging.", + ) + parser.add_argument( + "--no-log-stats", + dest="log_stats", + action="store_false", + help="Disable vLLM Omni stats logging.", + ) + parser.set_defaults(log_stats=True) + parser.add_argument( + "--num-runs", + type=int, + default=1, + help="Number of full inference runs (same prompt each time). Default 1.", + ) + parser.add_argument( + "--warmup-runs", + type=int, + default=0, + help=( + "Optional number of warmup passes before measured runs. Warmup uses only " + "the first prompt and does not save outputs." + ), + ) + parser.add_argument( + "--enable-profiler", + action="store_true", + help=( + "Enable torch profiler for the configured stages. A temporary profiled " + "stage config is generated automatically." + ), + ) + parser.add_argument( + "--profiler-dir", + type=str, + default=None, + help="Directory for profiler traces. Defaults to /profiler when profiling is enabled.", + ) + parser.add_argument( + "--profiler-stages", + type=int, + nargs="*", + default=None, + help="Optional stage ids to profile. 
Defaults to all stages that have profiler_config.", + ) + parser.add_argument( + "--profiler-wait-seconds", + type=float, + default=30.0, + help="Seconds to wait after stop_profile for trace files to flush.", + ) + args = parser.parse_args() + + if not args.model: + parser.error("--model is required unless $VOXCPM_MODEL is set") + if args.txt_prompts is not None and args.jsonl_prompts is not None: + parser.error("--txt-prompts and --jsonl-prompts are mutually exclusive") + if (args.ref_audio is None) != (args.ref_text is None): + parser.error("--ref-audio and --ref-text must be provided together") + if args.num_runs < 1: + parser.error("--num-runs must be >= 1") + if args.warmup_runs < 0: + parser.error("--warmup-runs must be >= 0") + if args.output_dir is None: + args.output_dir = ( + "output_audio_streaming" if _is_streaming_stage_config(args.stage_configs_path) else "output_audio" + ) + if args.enable_profiler and args.profiler_dir is None: + args.profiler_dir = str(Path(args.output_dir) / "profiler") + try: + args.prompt_specs = _load_prompt_specs(args) + except ValueError as exc: + parser.error(str(exc)) + + return args + + +def _is_streaming_stage_config(stage_configs_path: str) -> bool: + cfg_name = Path(stage_configs_path).name.lower() + # Keep routing purely config-path based: + # - voxcpm.yaml => sync + # - voxcpm_async_chunk.yaml => streaming + return "async_chunk" in cfg_name + + +async def _collect_streaming_audio( + omni: AsyncOmni, + args: Any, + spec: PromptSpec, + request_id: str, + *, + phase_label: str, + prompt_index: int, + prompt_count: int, + print_prompt: bool = False, +) -> tuple[torch.Tensor, int, float, float | None]: + prompt = _build_prompt_for_spec(args, spec, global_request_id=request_id) + delta_chunks: list[torch.Tensor] = [] + sample_rate = 24000 + chunk_i = 0 + prev_total_samples = 0 + t_start = time.perf_counter() + first_audio_elapsed: float | None = None + + if print_prompt: + print(f"---prompt---:{prompt}") + + async for stage_output in omni.generate(prompt, request_id=request_id): + mm = getattr(stage_output, "multimodal_output", None) + if not isinstance(mm, dict): + ro = getattr(stage_output, "request_output", None) + if ro is None: + continue + mm = getattr(ro, "multimodal_output", None) + if not isinstance(mm, dict) and getattr(ro, "outputs", None): + seq = ro.outputs[0] + mm = getattr(seq, "multimodal_output", None) + if not isinstance(mm, dict): + continue + sample_rate = _extract_sample_rate(mm) + try: + w = _extract_audio_tensor(mm) + n = int(w.numel()) + if n == 0: + continue + finished = _extract_stream_finished(stage_output) + if n > prev_total_samples: + delta = w.reshape(-1)[prev_total_samples:] + prev_total_samples = n + elif finished and n == prev_total_samples: + delta = w.reshape(-1)[:0] + else: + delta = w.reshape(-1) + prev_total_samples += int(delta.numel()) + if int(delta.numel()) > 0: + delta_chunks.append(delta) + if first_audio_elapsed is None and int(delta.numel()) > 0: + first_audio_elapsed = time.perf_counter() - t_start + logger.info( + "%s prompt=%d/%d chunk=%d delta_samples=%d buf_len=%d finished=%s", + phase_label, + prompt_index + 1, + prompt_count, + chunk_i, + int(delta.numel()), + n, + finished, + ) + chunk_i += 1 + except ValueError: + if not _extract_stream_finished(stage_output): + logger.debug("skip non-audio partial output chunk=%d", chunk_i) + + if not delta_chunks: + raise RuntimeError("No audio chunks received; check stage config and logs.") + + audio_cat = torch.cat([c.reshape(-1) for c in delta_chunks], 
dim=0) + elapsed = time.perf_counter() - t_start + return audio_cat, sample_rate, elapsed, first_audio_elapsed + + +async def _abort_streaming_residual_work( + omni: AsyncOmni, + request_id: str, + *, + settle_seconds: float = 0.1, +) -> None: + """Stop any late stage-0 work once the final audio has been collected.""" + await omni.engine.abort_async([request_id]) + if settle_seconds > 0: + await asyncio.sleep(settle_seconds) + + +async def _run_streaming_single( + omni: AsyncOmni, + args: Any, + spec: PromptSpec, + output_dir: Path, + request_id: str, + *, + run_index: int, + num_runs: int, + prompt_index: int, + prompt_count: int, +) -> Path: + audio_cat, sample_rate, elapsed, first_audio_elapsed = await _collect_streaming_audio( + omni, + args, + spec, + request_id, + phase_label=f"run={run_index + 1}/{num_runs}", + prompt_index=prompt_index, + prompt_count=prompt_count, + print_prompt=(run_index == 0 and prompt_index == 0), + ) + await _abort_streaming_residual_work(omni, request_id) + output_path = output_dir / f"output_run{run_index + 1}_{spec.label}.wav" + _write_audio_tensor(output_path, audio_cat, sample_rate) + audio_duration_s = float(audio_cat.numel()) / float(sample_rate) if sample_rate > 0 else 0.0 + ttfp_text = f", ttfp={first_audio_elapsed:.2f}s" if first_audio_elapsed is not None else "" + rtf_text = f", rtf={elapsed / audio_duration_s:.3f}" if audio_duration_s > 0 else "" + print( + f"Saved (streaming) run {run_index + 1}/{num_runs}, " + f"prompt {prompt_index + 1}/{prompt_count}: {output_path} ({elapsed:.2f}s{ttfp_text}{rtf_text})" + ) + _emit_offline_metrics( + request_id=request_id, + elapsed_s=elapsed, + first_audio_elapsed=first_audio_elapsed, + audio_duration_s=audio_duration_s, + ) + return output_path + + +async def _run_streaming_warmup(args, omni: AsyncOmni) -> None: + if args.warmup_runs == 0: + return + + warmup_specs = _get_warmup_specs(args.prompt_specs) + print( + f"Warmup: {args.warmup_runs} run(s) using the first prompt " + f"({len(warmup_specs)} prompt(s)); outputs will be discarded." 
+ ) + for warmup_index in range(args.warmup_runs): + t_warmup = time.perf_counter() + tasks = [] + request_ids: list[str] = [] + for prompt_index, spec in enumerate(warmup_specs): + request_id = f"warmup_stream_{warmup_index + 1}_{spec.label}_{uuid.uuid4().hex[:8]}" + request_ids.append(request_id) + tasks.append( + _collect_streaming_audio( + omni, + args, + spec, + request_id, + phase_label=f"warmup={warmup_index + 1}/{args.warmup_runs}", + prompt_index=prompt_index, + prompt_count=len(warmup_specs), + ) + ) + results = await asyncio.gather(*tasks) + for request_id in request_ids: + await _abort_streaming_residual_work(omni, request_id) + total_samples = sum(int(audio.numel()) for audio, _, _, _ in results) + warmup_ttfps = [ttfp for _, _, _, ttfp in results if ttfp is not None] + ttfp_text = f", ttfp={min(warmup_ttfps):.2f}s" if warmup_ttfps else "" + print( + f"Warmup (streaming) {warmup_index + 1}/{args.warmup_runs} finished: " + f"{len(results)} prompt(s), {total_samples} sample(s) " + f"({time.perf_counter() - t_warmup:.2f}s{ttfp_text})" + ) + + +async def _run_streaming(args) -> list[Path]: + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + omni = AsyncOmni( + model=args.model, + stage_configs_path=args.stage_configs_path, + log_stats=args.log_stats, + stage_init_timeout=args.stage_init_timeout, + ) + + await _run_streaming_warmup(args, omni) + profiler_started = False + if args.enable_profiler: + profile_prefix = f"voxcpm_streaming_{int(time.time())}" + stages_text = args.profiler_stages if args.profiler_stages is not None else "all-configured" + print(f"Starting profiler (streaming): stages={stages_text}, dir={args.profiler_dir}") + await omni.start_profile(profile_prefix=profile_prefix, stages=args.profiler_stages) + profiler_started = True + t_total = time.perf_counter() + total_elapsed = 0.0 + paths: list[Path] = [] + prompt_specs: list[PromptSpec] = args.prompt_specs + try: + for run in range(args.num_runs): + for prompt_index, spec in enumerate(prompt_specs): + request_id = f"stream_{run + 1}_{spec.label}_{uuid.uuid4().hex[:8]}" + paths.append( + await _run_streaming_single( + omni, + args, + spec, + output_dir, + request_id, + run_index=run, + num_runs=args.num_runs, + prompt_index=prompt_index, + prompt_count=len(prompt_specs), + ) + ) + total_elapsed = time.perf_counter() - t_total + finally: + if profiler_started: + print("Stopping profiler (streaming)...") + await omni.stop_profile(stages=args.profiler_stages) + if args.profiler_wait_seconds > 0: + print(f"Waiting {args.profiler_wait_seconds:.1f}s for profiler traces to flush...") + await asyncio.sleep(args.profiler_wait_seconds) + + print( + f"All streaming runs finished: {args.num_runs} run(s), " + f"{len(prompt_specs)} prompt(s), {len(paths)} file(s) in {total_elapsed:.2f}s total" + ) + return paths + + +def _run_sync(args) -> list[Path]: + output_dir = Path(args.output_dir) + + omni = Omni( + model=args.model, + stage_configs_path=args.stage_configs_path, + log_stats=args.log_stats, + stage_init_timeout=args.stage_init_timeout, + ) + + def _run_sync_single( + spec: PromptSpec, + *, + request_prefix: str, + save_outputs: bool, + run_index: int | None = None, + ) -> tuple[list[Path], int, float | None, float, float, str]: + global_request_id = f"{request_prefix}_{spec.label}" + prompt = _build_prompt_for_spec(args, spec, global_request_id=global_request_id) + if save_outputs and run_index == 0 and spec.label == "item001": + print(f"---prompt---:{prompt}") + + saved_paths: 
list[Path] = [] + output_count = 0 + first_audio_elapsed: float | None = None + total_audio_duration_s = 0.0 + metrics_request_id = global_request_id + t_start = time.perf_counter() + for stage_outputs in omni.generate(prompt): + request_output = stage_outputs.request_output + if request_output is None: + continue + request_output_id = getattr(request_output, "request_id", None) + if isinstance(request_output_id, str) and request_output_id: + metrics_request_id = request_output_id + for j, mm in enumerate(_iter_request_multimodal_outputs(request_output)): + output_count += 1 + if first_audio_elapsed is None: + try: + audio_tensor = _extract_audio_tensor(mm) + if int(audio_tensor.numel()) > 0: + first_audio_elapsed = time.perf_counter() - t_start + total_audio_duration_s += float(audio_tensor.numel()) / float(_extract_sample_rate(mm)) + except ValueError: + pass + else: + try: + audio_tensor = _extract_audio_tensor(mm) + total_audio_duration_s += float(audio_tensor.numel()) / float(_extract_sample_rate(mm)) + except ValueError: + pass + if not save_outputs: + continue + save_stem = f"run{run_index + 1}_{spec.label}" if j == 0 else f"run{run_index + 1}_{spec.label}_{j}" + saved_paths.append(_save_wav(mm, output_dir, save_stem)) + + if output_count == 0: + raise RuntimeError("No output from Omni.generate") + elapsed_s = time.perf_counter() - t_start + return saved_paths, output_count, first_audio_elapsed, elapsed_s, total_audio_duration_s, metrics_request_id + + if args.warmup_runs: + warmup_specs = _get_warmup_specs(args.prompt_specs) + print( + f"Warmup: {args.warmup_runs} run(s) using the first prompt " + f"({len(warmup_specs)} prompt(s)); outputs will be discarded." + ) + for warmup_index in range(args.warmup_runs): + t_warmup = time.perf_counter() + _, output_count, first_audio_elapsed, elapsed_s, audio_duration_s, _ = _run_sync_single( + warmup_specs[0], + request_prefix=f"warmup_sync{warmup_index + 1}", + save_outputs=False, + ) + ttfp_text = f", ttfp={first_audio_elapsed:.2f}s" if first_audio_elapsed is not None else "" + rtf_text = f", rtf={elapsed_s / audio_duration_s:.3f}" if audio_duration_s > 0 else "" + print( + f"Warmup (sync) {warmup_index + 1}/{args.warmup_runs} finished: " + f"{output_count} output(s) ({time.perf_counter() - t_warmup:.2f}s{ttfp_text}{rtf_text})" + ) + + profiler_started = False + if args.enable_profiler: + profile_prefix = f"voxcpm_sync_{int(time.time())}" + stages_text = args.profiler_stages if args.profiler_stages is not None else "all-configured" + print(f"Starting profiler (sync): stages={stages_text}, dir={args.profiler_dir}") + omni.start_profile(profile_prefix=profile_prefix, stages=args.profiler_stages) + profiler_started = True + + t_total = time.perf_counter() + total_elapsed = 0.0 + saved_paths: list[Path] = [] + prompt_specs: list[PromptSpec] = args.prompt_specs + try: + for run in range(args.num_runs): + t_run = time.perf_counter() + run_paths: list[Path] = [] + for prompt_index, spec in enumerate(prompt_specs): + prompt_paths, _, first_audio_elapsed, elapsed_s, audio_duration_s, metrics_request_id = ( + _run_sync_single( + spec, + request_prefix=f"sync_run{run + 1}_{prompt_index + 1:03d}", + save_outputs=True, + run_index=run, + ) + ) + run_paths.extend(prompt_paths) + ttfp_text = f", ttfp={first_audio_elapsed:.2f}s" if first_audio_elapsed is not None else "" + rtf_text = f", rtf={elapsed_s / audio_duration_s:.3f}" if audio_duration_s > 0 else "" + print( + f"Saved (sync) run {run + 1}/{args.num_runs}, " + f"prompt {prompt_index + 
1}/{len(prompt_specs)}: {len(prompt_paths)} file(s){ttfp_text}{rtf_text}" + ) + _emit_offline_metrics( + request_id=metrics_request_id, + elapsed_s=elapsed_s, + first_audio_elapsed=first_audio_elapsed, + audio_duration_s=audio_duration_s, + ) + + saved_paths.extend(run_paths) + print( + f"Run {run + 1}/{args.num_runs} finished: {len(run_paths)} file(s) ({time.perf_counter() - t_run:.2f}s)" + ) + for path in run_paths: + print(f" {path}") + + total_elapsed = time.perf_counter() - t_total + finally: + if profiler_started: + print("Stopping profiler (sync)...") + omni.stop_profile(stages=args.profiler_stages) + if args.profiler_wait_seconds > 0: + print(f"Waiting {args.profiler_wait_seconds:.1f}s for profiler traces to flush...") + time.sleep(args.profiler_wait_seconds) + + print( + f"All sync runs finished: {args.num_runs} run(s), " + f"{len(prompt_specs)} prompt(s), {len(saved_paths)} file(s) in {total_elapsed:.2f}s total" + ) + return saved_paths + + +def main(args) -> int: + logging.basicConfig(level=logging.INFO) + profiled_stage_config_path: str | None = None + original_stage_config_path = args.stage_configs_path + if args.enable_profiler: + Path(args.profiler_dir).mkdir(parents=True, exist_ok=True) + profiled_stage_config_path = _build_profiled_stage_config( + args.stage_configs_path, + str(Path(args.profiler_dir).resolve()), + ) + args.stage_configs_path = profiled_stage_config_path + + is_streaming = _is_streaming_stage_config(args.stage_configs_path) + voice_clone_count = _count_voice_clone_prompts(args.prompt_specs) + print(f"Model: {args.model}") + print(f"Stage config: {original_stage_config_path}") + print(f"Route: {'streaming' if is_streaming else 'sync'} (from stage-configs-path)") + print(f"Prompt count: {len(args.prompt_specs)}") + print("Batch mode: sequential (aligned with native VoxCPM)") + print(f"Warmup runs: {args.warmup_runs}") + print(f"Voice cloning prompts: {voice_clone_count}/{len(args.prompt_specs)}") + if args.enable_profiler: + print(f"Profiler: enabled (dir={args.profiler_dir}, stages={args.profiler_stages or 'all-configured'})") + print(f"Profiled stage config: {args.stage_configs_path}") + if voice_clone_count: + print("Voice cloning note: --ref-text/ref_text must match the spoken content of the reference audio.") + print(f"Num runs: {args.num_runs}") + try: + if is_streaming: + asyncio.run(_run_streaming(args)) + else: + _run_sync(args) + finally: + if profiled_stage_config_path is not None and os.path.exists(profiled_stage_config_path): + os.unlink(profiled_stage_config_path) + return 0 + + +if __name__ == "__main__": + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + raise SystemExit(main(parse_args())) diff --git a/benchmarks/voxcpm/vllm_omni/bench_tts_serve.py b/benchmarks/voxcpm/vllm_omni/bench_tts_serve.py new file mode 100644 index 0000000000..816df32796 --- /dev/null +++ b/benchmarks/voxcpm/vllm_omni/bench_tts_serve.py @@ -0,0 +1,283 @@ +"""Benchmark VoxCPM via /v1/audio/speech. + +Reports TTFP (time to first packet), E2E latency, and RTF (real-time factor). 
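+Aggregate results for all requested concurrency levels are written to bench_tts_serve.json under --result-dir.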
+""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import time +from dataclasses import asdict, dataclass, field +from datetime import datetime +from pathlib import Path + +import aiohttp +import numpy as np +from tqdm.asyncio import tqdm + +DEFAULT_MODEL = "OpenBMB/VoxCPM1.5" +DEFAULT_SAMPLE_RATE = 24000 +PROMPTS = [ + "Hello, welcome to the VoxCPM speech benchmark.", + "This is a short benchmark prompt for online text-to-speech generation.", + "The quick brown fox jumps over the lazy dog near the riverbank.", + "Please remember to bring your identification documents tomorrow morning.", + "Learning a new language takes patience, practice, and curiosity.", + "This benchmark reports TTFP and RTF for the VoxCPM online serving path.", +] + + +@dataclass +class RequestResult: + success: bool = False + ttfp: float = 0.0 + e2e: float = 0.0 + audio_bytes: int = 0 + audio_duration: float = 0.0 + rtf: float = 0.0 + prompt: str = "" + error: str = "" + + +@dataclass +class BenchmarkResult: + concurrency: int = 0 + num_prompts: int = 0 + completed: int = 0 + failed: int = 0 + duration_s: float = 0.0 + mean_ttfp_ms: float = 0.0 + median_ttfp_ms: float = 0.0 + p95_ttfp_ms: float = 0.0 + mean_e2e_ms: float = 0.0 + median_e2e_ms: float = 0.0 + p95_e2e_ms: float = 0.0 + mean_rtf: float = 0.0 + median_rtf: float = 0.0 + p95_rtf: float = 0.0 + total_audio_duration_s: float = 0.0 + request_throughput: float = 0.0 + per_request: list[dict[str, float | str]] = field(default_factory=list) + + +def pcm_bytes_to_duration(num_bytes: int, sample_rate: int = DEFAULT_SAMPLE_RATE, sample_width: int = 2) -> float: + num_samples = num_bytes / sample_width + return num_samples / sample_rate + + +async def send_tts_request( + session: aiohttp.ClientSession, + api_url: str, + *, + model: str, + prompt: str, + ref_audio: str | None, + ref_text: str | None, + pbar: tqdm | None = None, +) -> RequestResult: + payload: dict[str, object] = { + "model": model, + "input": prompt, + "stream": True, + "response_format": "pcm", + } + if ref_audio is not None: + payload["ref_audio"] = ref_audio + if ref_text is not None: + payload["ref_text"] = ref_text + + result = RequestResult(prompt=prompt) + started_at = time.perf_counter() + + try: + async with session.post(api_url, json=payload) as response: + if response.status != 200: + result.error = f"HTTP {response.status}: {await response.text()}" + return result + + first_chunk = True + total_bytes = 0 + async for chunk in response.content.iter_any(): + if not chunk: + continue + if first_chunk: + result.ttfp = time.perf_counter() - started_at + first_chunk = False + total_bytes += len(chunk) + + result.e2e = time.perf_counter() - started_at + result.audio_bytes = total_bytes + result.audio_duration = pcm_bytes_to_duration(total_bytes) + if result.audio_duration > 0: + result.rtf = result.e2e / result.audio_duration + result.success = True + except Exception as e: + result.error = str(e) + result.e2e = time.perf_counter() - started_at + + if pbar is not None: + pbar.update(1) + return result + + +async def run_benchmark( + *, + host: str, + port: int, + model: str, + num_prompts: int, + max_concurrency: int, + num_warmups: int, + ref_audio: str | None, + ref_text: str | None, +) -> BenchmarkResult: + api_url = f"http://{host}:{port}/v1/audio/speech" + connector = aiohttp.TCPConnector(limit=max_concurrency, limit_per_host=max_concurrency, keepalive_timeout=60) + timeout = aiohttp.ClientTimeout(total=600) + + async with 
aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + if num_warmups > 0: + print(f" Warming up with {num_warmups} requests...") + warmup_tasks = [ + send_tts_request( + session, + api_url, + model=model, + prompt=PROMPTS[i % len(PROMPTS)], + ref_audio=ref_audio, + ref_text=ref_text, + ) + for i in range(num_warmups) + ] + await asyncio.gather(*warmup_tasks) + print(" Warmup done.") + + request_prompts = [PROMPTS[i % len(PROMPTS)] for i in range(num_prompts)] + semaphore = asyncio.Semaphore(max_concurrency) + pbar = tqdm(total=num_prompts, desc=f" concurrency={max_concurrency}") + + async def limited_request(prompt: str) -> RequestResult: + async with semaphore: + return await send_tts_request( + session, + api_url, + model=model, + prompt=prompt, + ref_audio=ref_audio, + ref_text=ref_text, + pbar=pbar, + ) + + started_at = time.perf_counter() + results = await asyncio.gather(*[asyncio.create_task(limited_request(prompt)) for prompt in request_prompts]) + duration = time.perf_counter() - started_at + pbar.close() + + succeeded = [result for result in results if result.success] + bench = BenchmarkResult( + concurrency=max_concurrency, + num_prompts=num_prompts, + completed=len(succeeded), + failed=len(results) - len(succeeded), + duration_s=duration, + ) + + if not succeeded: + return bench + + ttfps = np.array([result.ttfp * 1000 for result in succeeded], dtype=np.float64) + e2es = np.array([result.e2e * 1000 for result in succeeded], dtype=np.float64) + rtfs = np.array([result.rtf for result in succeeded], dtype=np.float64) + audio_durations = np.array([result.audio_duration for result in succeeded], dtype=np.float64) + + bench.mean_ttfp_ms = float(np.mean(ttfps)) + bench.median_ttfp_ms = float(np.median(ttfps)) + bench.p95_ttfp_ms = float(np.percentile(ttfps, 95)) + bench.mean_e2e_ms = float(np.mean(e2es)) + bench.median_e2e_ms = float(np.median(e2es)) + bench.p95_e2e_ms = float(np.percentile(e2es, 95)) + bench.mean_rtf = float(np.mean(rtfs)) + bench.median_rtf = float(np.median(rtfs)) + bench.p95_rtf = float(np.percentile(rtfs, 95)) + bench.total_audio_duration_s = float(np.sum(audio_durations)) + bench.request_throughput = len(succeeded) / duration if duration > 0 else 0.0 + bench.per_request = [ + { + "prompt": result.prompt, + "ttfp_ms": result.ttfp * 1000, + "e2e_ms": result.e2e * 1000, + "rtf": result.rtf, + "audio_duration_s": result.audio_duration, + } + for result in succeeded + ] + + return bench + + +def print_summary(result: BenchmarkResult) -> None: + width = 54 + print("") + print("=" * width) + print(f"{'VoxCPM Serving Benchmark':^{width}}") + print("=" * width) + print(f"concurrency : {result.concurrency}") + print(f"requests : {result.completed}/{result.num_prompts} succeeded") + print(f"wall time (s) : {result.duration_s:.3f}") + print(f"mean TTFP (ms) : {result.mean_ttfp_ms:.2f}") + print(f"p95 TTFP (ms) : {result.p95_ttfp_ms:.2f}") + print(f"mean E2E (ms) : {result.mean_e2e_ms:.2f}") + print(f"p95 E2E (ms) : {result.p95_e2e_ms:.2f}") + print(f"mean RTF : {result.mean_rtf:.3f}") + print(f"p95 RTF : {result.p95_rtf:.3f}") + print(f"request throughput : {result.request_throughput:.2f} req/s") + print("=" * width) + + +async def main_async(args) -> None: + result_dir = Path(args.result_dir) + result_dir.mkdir(parents=True, exist_ok=True) + + all_results: list[BenchmarkResult] = [] + for concurrency in args.max_concurrency: + result = await run_benchmark( + host=args.host, + port=args.port, + model=args.model, + num_prompts=args.num_prompts, + 
max_concurrency=concurrency, + num_warmups=args.num_warmups, + ref_audio=args.ref_audio, + ref_text=args.ref_text, + ) + print_summary(result) + all_results.append(result) + + payload = { + "model": args.model, + "created_at": datetime.utcnow().isoformat() + "Z", + "results": [asdict(result) for result in all_results], + } + result_path = result_dir / "bench_tts_serve.json" + result_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + print(f"Saved results to: {result_path}") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Benchmark VoxCPM via /v1/audio/speech") + parser.add_argument("--host", default="127.0.0.1", help="Server host") + parser.add_argument("--port", type=int, default=8091, help="Server port") + parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path") + parser.add_argument("--num-prompts", type=int, default=20, help="Number of prompts to send") + parser.add_argument("--max-concurrency", type=int, nargs="+", default=[1], help="Concurrency levels to benchmark") + parser.add_argument("--num-warmups", type=int, default=3, help="Warmup request count") + parser.add_argument("--ref-audio", default=None, help="Reference audio URL or data URL for voice cloning") + parser.add_argument("--ref-text", default=None, help="Reference audio transcript for voice cloning") + parser.add_argument("--result-dir", default="results", help="Directory to save benchmark JSON") + return parser.parse_args() + + +if __name__ == "__main__": + asyncio.run(main_async(parse_args())) diff --git a/benchmarks/voxcpm/vllm_omni/run_offline_matrix.py b/benchmarks/voxcpm/vllm_omni/run_offline_matrix.py new file mode 100644 index 0000000000..cee46c0f86 --- /dev/null +++ b/benchmarks/voxcpm/vllm_omni/run_offline_matrix.py @@ -0,0 +1,303 @@ +"""Run the full offline VoxCPM smoke matrix. + +This script keeps the old `test.py` coverage, but delegates each case to +`bench_tts_offline.py` so the benchmark runner itself stays focused on a +single execution path. +""" + +from __future__ import annotations + +import shlex +import subprocess +import sys +import time +from dataclasses import dataclass +from pathlib import Path + +from vllm.utils.argparse_utils import FlexibleArgumentParser + +REPO_ROOT = Path(__file__).resolve().parents[3] +BENCH_SCRIPT = Path(__file__).with_name("bench_tts_offline.py") +DEFAULT_STAGE_ASYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm_async_chunk.yaml" +DEFAULT_STAGE_SYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml" +DEFAULT_OUTPUT_ROOT = BENCH_SCRIPT.parents[1] / "results" / "offline_matrix" + +SINGLE_TTS_TEXT = "This is a single text-to-speech smoke test for VoxCPM on vLLM Omni." +SINGLE_CLONE_TEXT = "This sentence is synthesized with the cloned voice for validation." 
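+# Fixed prompt texts for the smoke matrix: the SINGLE_* strings drive the single-prompt cases,
+# while the BATCH_* lists below are written to txt files under <output-root>/inputs/ and passed
+# to bench_tts_offline.py via --txt-prompts.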
+BATCH_TTS_TEXTS = [ + "The first batch text-to-speech sample validates sequential batch execution.", + "The second batch text-to-speech sample checks another prompt in the same file.", + "The third batch text-to-speech sample completes the sequential batch path.", +] +BATCH_CLONE_TEXTS = [ + "The first cloned sample validates sequential batch voice cloning.", + "The second cloned sample checks the same reference voice on another prompt.", + "The third cloned sample finishes the shared-reference clone batch path.", +] + + +@dataclass(frozen=True, slots=True) +class ModeSpec: + name: str + stage_config: Path + + +@dataclass(frozen=True, slots=True) +class CaseSpec: + name: str + warmup_runs: int + prompt_kind: str + voice_clone: bool + + +@dataclass(frozen=True, slots=True) +class CaseResult: + mode: str + case: str + returncode: int + elapsed_s: float + output_dir: Path + log_path: Path + + @property + def ok(self) -> bool: + return self.returncode == 0 + + +MODE_SPECS = [ + ModeSpec(name="streaming", stage_config=DEFAULT_STAGE_ASYNC), + ModeSpec(name="sync", stage_config=DEFAULT_STAGE_SYNC), +] + +CASE_SPECS = [ + CaseSpec(name="warmup_single_tts", warmup_runs=1, prompt_kind="single", voice_clone=False), + CaseSpec(name="warmup_single_clone", warmup_runs=1, prompt_kind="single", voice_clone=True), + CaseSpec(name="warmup_batch_tts", warmup_runs=1, prompt_kind="batch", voice_clone=False), + CaseSpec(name="warmup_batch_clone", warmup_runs=1, prompt_kind="batch", voice_clone=True), + CaseSpec(name="cold_single_tts", warmup_runs=0, prompt_kind="single", voice_clone=False), + CaseSpec(name="cold_single_clone", warmup_runs=0, prompt_kind="single", voice_clone=True), +] + + +def _write_lines(path: Path, lines: list[str]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def _prepare_batch_inputs(output_root: Path) -> tuple[Path, Path]: + input_dir = output_root / "inputs" + batch_tts_path = input_dir / "batch_tts_prompts.txt" + batch_clone_path = input_dir / "batch_clone_prompts.txt" + _write_lines(batch_tts_path, BATCH_TTS_TEXTS) + _write_lines(batch_clone_path, BATCH_CLONE_TEXTS) + return batch_tts_path, batch_clone_path + + +def _base_command(args, mode: ModeSpec, output_dir: Path) -> list[str]: + cmd = [ + args.python, + str(BENCH_SCRIPT), + "--model", + args.model, + "--stage-configs-path", + str(mode.stage_config), + "--output-dir", + str(output_dir), + "--num-runs", + str(args.num_runs), + "--stage-init-timeout", + str(args.stage_init_timeout), + ] + cmd.append("--log-stats" if args.log_stats else "--no-log-stats") + cmd.extend(["--cfg-value", str(args.cfg_value)]) + cmd.extend(["--inference-timesteps", str(args.inference_timesteps)]) + cmd.extend(["--min-len", str(args.min_len)]) + cmd.extend(["--max-new-tokens", str(args.max_new_tokens)]) + if args.streaming_prefix_len is not None: + cmd.extend(["--streaming-prefix-len", str(args.streaming_prefix_len)]) + if args.enable_profiler: + profiler_dir = Path(args.profiler_dir) if args.profiler_dir is not None else (output_dir / "profiler") + cmd.append("--enable-profiler") + cmd.extend(["--profiler-dir", str(profiler_dir)]) + cmd.extend(["--profiler-wait-seconds", str(args.profiler_wait_seconds)]) + if args.profiler_stages is not None: + cmd.append("--profiler-stages") + cmd.extend(str(stage_id) for stage_id in args.profiler_stages) + return cmd + + +def _build_case_command( + args, + mode: ModeSpec, + case: CaseSpec, + *, + batch_tts_path: Path, + batch_clone_path: 
Path, + output_dir: Path, +) -> list[str]: + cmd = _base_command(args, mode, output_dir) + cmd.extend(["--warmup-runs", str(case.warmup_runs)]) + if case.prompt_kind == "single": + cmd.extend(["--text", SINGLE_CLONE_TEXT if case.voice_clone else SINGLE_TTS_TEXT]) + else: + cmd.extend(["--txt-prompts", str(batch_clone_path if case.voice_clone else batch_tts_path)]) + if case.voice_clone: + cmd.extend(["--ref-audio", args.ref_audio, "--ref-text", args.ref_text]) + return cmd + + +def _run_case( + args, + mode: ModeSpec, + case: CaseSpec, + *, + batch_tts_path: Path, + batch_clone_path: Path, + output_root: Path, +) -> CaseResult: + case_output_dir = output_root / mode.name / case.name + case_output_dir.mkdir(parents=True, exist_ok=True) + case_log_path = case_output_dir / "run.log" + cmd = _build_case_command( + args, + mode, + case, + batch_tts_path=batch_tts_path, + batch_clone_path=batch_clone_path, + output_dir=case_output_dir, + ) + + print() + print("=" * 80) + print(f"[{mode.name}] {case.name}") + print(f"Output directory: {case_output_dir}") + print(shlex.join(cmd)) + + start = time.perf_counter() + with case_log_path.open("w", encoding="utf-8") as log_fp: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + assert process.stdout is not None + for line in process.stdout: + print(line, end="") + log_fp.write(line) + process.wait() + + elapsed_s = time.perf_counter() - start + status = "PASS" if (process.returncode or 0) == 0 else f"FAIL({process.returncode})" + print(f"[{mode.name}] {case.name} -> {status} ({elapsed_s:.2f}s)") + return CaseResult( + mode=mode.name, + case=case.name, + returncode=int(process.returncode or 0), + elapsed_s=elapsed_s, + output_dir=case_output_dir, + log_path=case_log_path, + ) + + +def parse_args(): + parser = FlexibleArgumentParser(description="Run the full offline VoxCPM smoke matrix.") + parser.add_argument("--model", type=str, required=True, help="Local VoxCPM model directory.") + parser.add_argument("--ref-audio", type=str, required=True, help="Reference audio path for clone cases.") + parser.add_argument("--ref-text", type=str, required=True, help="Exact transcript spoken in --ref-audio.") + parser.add_argument("--output-root", type=str, default=str(DEFAULT_OUTPUT_ROOT), help="Root directory for outputs.") + parser.add_argument("--python", type=str, default=sys.executable, help="Python executable used to launch cases.") + parser.add_argument("--stage-init-timeout", type=int, default=600, help="Stage initialization timeout in seconds.") + parser.add_argument("--log-stats", dest="log_stats", action="store_true", help="Enable vLLM Omni stats logging.") + parser.add_argument( + "--no-log-stats", + dest="log_stats", + action="store_false", + help="Disable vLLM Omni stats logging.", + ) + parser.set_defaults(log_stats=True) + parser.add_argument("--num-runs", type=int, default=1, help="Number of measured runs per case.") + parser.add_argument("--cfg-value", type=float, default=2.0, help="Classifier-free guidance value for VoxCPM.") + parser.add_argument("--inference-timesteps", type=int, default=10, help="Number of inference timesteps.") + parser.add_argument("--min-len", type=int, default=2, help="Minimum generated token length.") + parser.add_argument("--max-new-tokens", type=int, default=4096, help="Maximum generated token length.") + parser.add_argument( + "--streaming-prefix-len", + type=int, + default=None, + help="Optional VoxCPM streaming window passed to streaming cases.", + ) 
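+    # The profiler flags below are forwarded to bench_tts_offline.py for every matrix case.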
+ parser.add_argument("--enable-profiler", action="store_true", help="Enable torch profiler for each case.") + parser.add_argument( + "--profiler-dir", + type=str, + default=None, + help="Profiler output root. Defaults to /profiler.", + ) + parser.add_argument( + "--profiler-stages", + type=int, + nargs="*", + default=None, + help="Optional stage ids to profile. Defaults to all configured stages.", + ) + parser.add_argument( + "--profiler-wait-seconds", + type=float, + default=30.0, + help="Seconds to wait after stopping profiler for traces to flush.", + ) + args = parser.parse_args() + if args.num_runs < 1: + parser.error("--num-runs must be >= 1") + return args + + +def main(args) -> int: + output_root = Path(args.output_root) + output_root.mkdir(parents=True, exist_ok=True) + batch_tts_path, batch_clone_path = _prepare_batch_inputs(output_root) + + print(f"Model: {args.model}") + print(f"Reference audio: {args.ref_audio}") + print(f"Reference text: {args.ref_text}") + print(f"Python: {args.python}") + print(f"Output root: {output_root}") + print(f"Cases: {len(MODE_SPECS) * len(CASE_SPECS)}") + + results: list[CaseResult] = [] + for mode in MODE_SPECS: + for case in CASE_SPECS: + results.append( + _run_case( + args, + mode, + case, + batch_tts_path=batch_tts_path, + batch_clone_path=batch_clone_path, + output_root=output_root, + ) + ) + + failed = [result for result in results if not result.ok] + print() + print("=" * 80) + print("Summary:") + for result in results: + status = "PASS" if result.ok else f"FAIL({result.returncode})" + print(f"- [{result.mode}] {result.case}: {status} ({result.elapsed_s:.2f}s)") + print(f" output_dir={result.output_dir}") + print(f" log={result.log_path}") + + print(f"Passed: {len(results) - len(failed)}/{len(results)}") + if failed: + print("Failed cases:") + for result in failed: + print(f"- [{result.mode}] {result.case}: see {result.log_path}") + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(parse_args())) diff --git a/examples/offline_inference/voxcpm/README.md b/examples/offline_inference/voxcpm/README.md new file mode 100644 index 0000000000..1eaea9b0db --- /dev/null +++ b/examples/offline_inference/voxcpm/README.md @@ -0,0 +1,123 @@ +# VoxCPM Offline Example + +This directory contains the minimal offline VoxCPM example for vLLM Omni. 
+ +`end2end.py` is intentionally small and only covers: + +- single text-to-speech +- single voice cloning with `ref_audio` + `ref_text` +- non-streaming with `vllm_omni/model_executor/stage_configs/voxcpm.yaml` +- streaming with `vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml` + +Advanced workflows were moved out of the getting-started example: + +- `benchmarks/voxcpm/vllm_omni/bench_tts_offline.py`: warmup, batch prompts, profiler, offline TTFP / RTF +- `benchmarks/voxcpm/vllm_omni/run_offline_matrix.py`: fixed offline smoke matrix +- `benchmarks/voxcpm/`: benchmark scripts and benchmark docs + +## Prerequisites + +Install VoxCPM in one of these ways: + +```bash +pip install voxcpm +``` + +or point vLLM Omni to the local VoxCPM source tree: + +```bash +export VLLM_OMNI_VOXCPM_CODE_PATH=/path/to/VoxCPM/src +``` + +The example writes WAV files with `soundfile`: + +```bash +pip install soundfile +``` + +## Model Path + +Pass the native VoxCPM model directory directly: + +```bash +export VOXCPM_MODEL=/path/to/voxcpm-model +``` + +If the native VoxCPM `config.json` does not contain HuggingFace metadata such as +`model_type`, prepare a persistent HF-compatible config directory and point the +stage configs to it with `VLLM_OMNI_VOXCPM_HF_CONFIG_PATH`: + +```bash +export VLLM_OMNI_VOXCPM_HF_CONFIG_PATH=/tmp/voxcpm_hf_config +mkdir -p "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH" +cp "$VOXCPM_MODEL/config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/config.json" +cp "$VOXCPM_MODEL/generation_config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/generation_config.json" 2>/dev/null || true +python3 -c 'import json, os; p=os.path.join(os.environ["VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"], "config.json"); cfg=json.load(open(p, "r", encoding="utf-8")); cfg["model_type"]="voxcpm"; cfg.setdefault("architectures", ["VoxCPMForConditionalGeneration"]); json.dump(cfg, open(p, "w", encoding="utf-8"), indent=2, ensure_ascii=False)' +``` + +If the model directory itself already has `model_type`, this extra directory is +not required. + +## Quick Start + +Single text-to-speech, non-streaming: + +```bash +python examples/offline_inference/voxcpm/end2end.py \ + --model "$VOXCPM_MODEL" \ + --text "This is a split-stage VoxCPM synthesis example running on vLLM Omni." +``` + +Single voice cloning, non-streaming: + +```bash +python examples/offline_inference/voxcpm/end2end.py \ + --model "$VOXCPM_MODEL" \ + --text "This sentence is synthesized with a cloned voice." \ + --ref-audio /path/to/reference.wav \ + --ref-text "The exact transcript spoken in reference.wav." +``` + +Streaming: + +```bash +python examples/offline_inference/voxcpm/end2end.py \ + --model "$VOXCPM_MODEL" \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \ + --text "This is a split-stage VoxCPM streaming example running on vLLM Omni." +``` + +By default, `end2end.py` writes to `output_audio/` for non-streaming and +`output_audio_streaming/` for streaming. + +## Advanced Workflows + +Use `benchmarks/voxcpm/vllm_omni/bench_tts_offline.py` when you need: + +- warmup runs +- prompt files +- batch JSONL inputs +- profiler injection +- offline TTFP / RTF emission + +Use `benchmarks/voxcpm/vllm_omni/run_offline_matrix.py` when you need the fixed offline smoke matrix that previously lived in `test.py`. + +Full matrix benchmark example: + +```bash +python benchmarks/voxcpm/vllm_omni/run_offline_matrix.py \ + --model "$VOXCPM_MODEL" \ + --ref-audio /path/to/reference.wav \ + --ref-text "The exact transcript spoken in reference.wav." 
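+# Without --output-root, results are written under benchmarks/voxcpm/results/offline_matrix.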
+``` + +For online serving examples, see [examples/online_serving/voxcpm](../../online_serving/voxcpm/README.md). + +For benchmark reporting, see [benchmarks/voxcpm](../../../benchmarks/voxcpm/README.md). + +## Notes + +- `voxcpm.yaml` is the default non-streaming stage config. +- `voxcpm_async_chunk.yaml` is the streaming stage config. +- Streaming is currently single-request oriented; the fixed smoke matrix now lives in `benchmarks/voxcpm/vllm_omni/run_offline_matrix.py`. +- `ref_text` must be the real transcript of the reference audio. Mismatched text usually causes obvious quality degradation. diff --git a/examples/offline_inference/voxcpm/end2end.py b/examples/offline_inference/voxcpm/end2end.py new file mode 100644 index 0000000000..980410feae --- /dev/null +++ b/examples/offline_inference/voxcpm/end2end.py @@ -0,0 +1,206 @@ +"""Minimal offline VoxCPM example for vLLM Omni.""" + +from __future__ import annotations + +import asyncio +import time +from pathlib import Path +from typing import Any + +import soundfile as sf +import torch +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from vllm_omni import AsyncOmni, Omni + +REPO_ROOT = Path(__file__).resolve().parents[3] +DEFAULT_SYNC_STAGE_CONFIG = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml" + + +def _build_prompt(args) -> dict[str, Any]: + additional_information: dict[str, list[Any]] = { + "text": [args.text], + "cfg_value": [args.cfg_value], + "inference_timesteps": [args.inference_timesteps], + "min_len": [args.min_len], + "max_new_tokens": [args.max_new_tokens], + } + if args.streaming_prefix_len is not None: + additional_information["streaming_prefix_len"] = [args.streaming_prefix_len] + if args.ref_audio is not None: + additional_information["ref_audio"] = [args.ref_audio] + if args.ref_text is not None: + additional_information["ref_text"] = [args.ref_text] + return { + "prompt_token_ids": [1], + "additional_information": additional_information, + } + + +def _extract_audio_tensor(mm: dict[str, Any]) -> torch.Tensor: + audio = mm.get("audio", mm.get("model_outputs")) + if audio is None: + raise ValueError("No audio output found in multimodal output.") + if isinstance(audio, list): + parts = [torch.as_tensor(item).float().cpu().reshape(-1) for item in audio] + audio = torch.cat(parts, dim=-1) if parts else torch.zeros(0) + if not isinstance(audio, torch.Tensor): + audio = torch.as_tensor(audio) + return audio.float().cpu().reshape(-1) + + +def _extract_sample_rate(mm: dict[str, Any]) -> int: + sr_raw = mm.get("sr", 24000) + if isinstance(sr_raw, list) and sr_raw: + sr_raw = sr_raw[-1] + if hasattr(sr_raw, "item"): + return int(sr_raw.item()) + return int(sr_raw) + + +def _is_streaming_stage_config(stage_config_path: str) -> bool: + return "async_chunk" in Path(stage_config_path).stem + + +def _save_audio(audio: torch.Tensor, sample_rate: int, output_dir: Path, request_id: str) -> Path: + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / f"output_{request_id}.wav" + sf.write( + output_path, + audio.float().cpu().clamp(-1.0, 1.0).numpy(), + sample_rate, + format="WAV", + subtype="PCM_16", + ) + return output_path + + +async def _run_streaming(args) -> Path: + prompt = _build_prompt(args) + output_dir = Path(args.output_dir) if args.output_dir is not None else Path("output_audio_streaming") + request_id = "streaming_example" + sample_rate = 24000 + buffered_samples = 0 + chunks: list[torch.Tensor] = [] + started = time.perf_counter() + omni = AsyncOmni( + 
model=args.model, + stage_configs_path=args.stage_configs_path, + log_stats=args.log_stats, + stage_init_timeout=args.stage_init_timeout, + ) + try: + async for stage_output in omni.generate(prompt, request_id=request_id): + mm = getattr(stage_output, "multimodal_output", None) + if not isinstance(mm, dict): + request_output = getattr(stage_output, "request_output", None) + if request_output is None: + continue + mm = getattr(request_output, "multimodal_output", None) + if not isinstance(mm, dict) and getattr(request_output, "outputs", None): + mm = getattr(request_output.outputs[0], "multimodal_output", None) + if not isinstance(mm, dict): + continue + audio = _extract_audio_tensor(mm) + if audio.numel() == 0: + continue + sample_rate = _extract_sample_rate(mm) + if audio.numel() > buffered_samples: + delta = audio[buffered_samples:] + buffered_samples = int(audio.numel()) + else: + delta = audio + buffered_samples += int(delta.numel()) + if delta.numel() > 0: + chunks.append(delta) + if not chunks: + raise RuntimeError("No streaming audio chunks received from VoxCPM.") + output_audio = torch.cat(chunks, dim=0) + output_path = _save_audio(output_audio, sample_rate, output_dir, request_id) + print(f"Saved streaming audio to: {output_path} ({time.perf_counter() - started:.2f}s)") + return output_path + finally: + omni.shutdown() + + +def _run_sync(args) -> Path: + prompt = _build_prompt(args) + output_dir = Path(args.output_dir) if args.output_dir is not None else Path("output_audio") + request_id = "sync_example" + started = time.perf_counter() + last_mm: dict[str, Any] | None = None + omni = Omni( + model=args.model, + stage_configs_path=args.stage_configs_path, + log_stats=args.log_stats, + stage_init_timeout=args.stage_init_timeout, + ) + for stage_outputs in omni.generate(prompt): + request_output = getattr(stage_outputs, "request_output", None) + if request_output is None: + continue + outputs = getattr(request_output, "outputs", None) + if outputs: + for output in outputs: + mm = getattr(output, "multimodal_output", None) + if isinstance(mm, dict): + last_mm = mm + mm = getattr(request_output, "multimodal_output", None) + if isinstance(mm, dict): + last_mm = mm + if last_mm is None: + raise RuntimeError("No audio output received from VoxCPM.") + output_path = _save_audio( + _extract_audio_tensor(last_mm), + _extract_sample_rate(last_mm), + output_dir, + request_id, + ) + print(f"Saved audio to: {output_path} ({time.perf_counter() - started:.2f}s)") + return output_path + + +def parse_args(): + parser = FlexibleArgumentParser(description="Minimal offline VoxCPM example for vLLM Omni.") + parser.add_argument("--model", type=str, required=True, help="Local VoxCPM model directory.") + parser.add_argument( + "--stage-configs-path", + type=str, + default=str(DEFAULT_SYNC_STAGE_CONFIG), + help=("Stage config path. 
Use voxcpm.yaml for non-streaming or voxcpm_async_chunk.yaml for streaming."), + ) + parser.add_argument("--text", type=str, required=True, help="Input text for synthesis.") + parser.add_argument("--ref-audio", type=str, default=None, help="Reference audio path for voice cloning.") + parser.add_argument("--ref-text", type=str, default=None, help="Transcript of the reference audio.") + parser.add_argument("--output-dir", type=str, default=None, help="Output directory for generated wav files.") + parser.add_argument("--cfg-value", type=float, default=2.0, help="Guidance value passed to VoxCPM.") + parser.add_argument("--inference-timesteps", type=int, default=10, help="Number of diffusion timesteps.") + parser.add_argument("--min-len", type=int, default=2, help="Minimum latent length.") + parser.add_argument("--max-new-tokens", type=int, default=4096, help="Maximum latent length.") + parser.add_argument( + "--streaming-prefix-len", + type=int, + default=3, + help="Streaming prefix length used by voxcpm_async_chunk.yaml.", + ) + parser.add_argument("--stage-init-timeout", type=int, default=600, help="Stage initialization timeout in seconds.") + parser.add_argument("--log-stats", action="store_true", help="Enable vLLM Omni stats logging.") + args = parser.parse_args() + if (args.ref_audio is None) != (args.ref_text is None): + raise ValueError("Voice cloning requires --ref-audio and --ref-text together.") + return args + + +def main(args) -> None: + route = "streaming" if _is_streaming_stage_config(args.stage_configs_path) else "sync" + print(f"Model: {args.model}") + print(f"Stage config: {args.stage_configs_path}") + print(f"Route: {route}") + if route == "streaming": + asyncio.run(_run_streaming(args)) + else: + _run_sync(args) + + +if __name__ == "__main__": + main(parse_args()) diff --git a/examples/online_serving/voxcpm/README.md b/examples/online_serving/voxcpm/README.md new file mode 100644 index 0000000000..78e1bf4aaa --- /dev/null +++ b/examples/online_serving/voxcpm/README.md @@ -0,0 +1,166 @@ +# VoxCPM + +## Prerequisites + +Install VoxCPM in one of these ways: + +```bash +pip install voxcpm +``` + +or point vLLM-Omni to a local VoxCPM source tree: + +```bash +export VLLM_OMNI_VOXCPM_CODE_PATH=/path/to/VoxCPM/src +``` + +If the native VoxCPM `config.json` lacks HF metadata such as `model_type`, +prepare a persistent HF-compatible config directory and export: + +```bash +export VLLM_OMNI_VOXCPM_HF_CONFIG_PATH=/tmp/voxcpm_hf_config +mkdir -p "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH" +cp "$VOXCPM_MODEL/config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/config.json" +cp "$VOXCPM_MODEL/generation_config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/generation_config.json" 2>/dev/null || true +python3 -c 'import json, os; p=os.path.join(os.environ["VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"], "config.json"); cfg=json.load(open(p, "r", encoding="utf-8")); cfg["model_type"]="voxcpm"; cfg.setdefault("architectures", ["VoxCPMForConditionalGeneration"]); json.dump(cfg, open(p, "w", encoding="utf-8"), indent=2, ensure_ascii=False)' +``` + +The VoxCPM stage configs read `VLLM_OMNI_VOXCPM_HF_CONFIG_PATH` directly. The `python3 -c` form above avoids heredoc/indentation issues in interactive shells. 
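+
+If the `python3 -c` one-liner is hard to read, the same patch can be written as a
+short script. This is only a more readable sketch of the command above (it assumes
+`VLLM_OMNI_VOXCPM_HF_CONFIG_PATH` is already exported and populated by the `cp`
+commands); it is not a separate tool shipped with vLLM-Omni:
+
+```python
+import json
+import os
+from pathlib import Path
+
+# Patch the copied config.json in place so the HF config resolution succeeds.
+config_path = Path(os.environ["VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"]) / "config.json"
+cfg = json.loads(config_path.read_text(encoding="utf-8"))
+cfg["model_type"] = "voxcpm"
+cfg.setdefault("architectures", ["VoxCPMForConditionalGeneration"])
+config_path.write_text(json.dumps(cfg, indent=2, ensure_ascii=False), encoding="utf-8")
+```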
+ +## Launch the Server + +Use the async-chunk stage config by default: + +```bash +export VOXCPM_MODEL=/path/to/voxcpm-model +cd examples/online_serving/voxcpm +./run_server.sh +``` + +Use the non-streaming stage config: + +```bash +./run_server.sh sync +``` + +You can also launch the server directly: + +```bash +vllm serve "$VOXCPM_MODEL" \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \ + --trust-remote-code \ + --enforce-eager \ + --omni \ + --port 8091 +``` + +## Send Requests + +### Basic text-to-speech + +```bash +python openai_speech_client.py \ + --model "$VOXCPM_MODEL" \ + --text "This is a VoxCPM online text-to-speech example." +``` + +### Voice cloning + +```bash +python openai_speech_client.py \ + --model "$VOXCPM_MODEL" \ + --text "This sentence is synthesized with a cloned voice." \ + --ref-audio /path/to/reference.wav \ + --ref-text "The exact transcript spoken in reference.wav." +``` + +`ref_text` must be the real transcript of the reference audio. Placeholder text or mismatched text will usually degrade quality badly. + +### Streaming PCM output + +```bash +python openai_speech_client.py \ + --model "$VOXCPM_MODEL" \ + --text "This is a streaming VoxCPM request." \ + --stream \ + --output voxcpm_stream.pcm +``` + +### Using curl + +```bash +curl -X POST http://localhost:8091/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "model": "OpenBMB/VoxCPM1.5", + "input": "Hello from VoxCPM online serving.", + "response_format": "wav" + }' --output output.wav +``` + +Voice cloning: + +```bash +curl -X POST http://localhost:8091/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "model": "OpenBMB/VoxCPM1.5", + "input": "This sentence uses a cloned voice.", + "ref_audio": "https://example.com/reference.wav", + "ref_text": "The exact transcript spoken in the reference audio.", + "response_format": "wav" + }' --output cloned.wav +``` + +Streaming PCM: + +```bash +curl -X POST http://localhost:8091/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "model": "OpenBMB/VoxCPM1.5", + "input": "This is a streaming VoxCPM request.", + "stream": true, + "response_format": "pcm" + }' --output output.pcm +``` + +## Supported Request Shape + +VoxCPM online serving currently supports: + +- plain text-to-speech +- voice cloning with `ref_audio` + `ref_text` +- `stream=true` with `response_format=pcm` or `wav` + +VoxCPM online serving does not use these generic TTS fields: + +- `voice` +- `instructions` +- `language` +- `speaker_embedding` +- `x_vector_only_mode` + +## Streaming vs Non-Streaming + +- `voxcpm_async_chunk.yaml` enables async-chunk streaming and is best for single-request streaming latency. +- `voxcpm.yaml` performs one-shot latent generation then VAE decode. + +Like native VoxCPM, the async streaming path should be treated as single-request. If you need stable throughput benchmarking, prefer `voxcpm.yaml`. + +Do not use `voxcpm_async_chunk.yaml` for concurrent online streaming or `/v1/audio/speech/batch`. For multiple requests, prefer `voxcpm.yaml`. + +## Benchmark + +The serving benchmark reports TTFP and RTF: + +```bash +python benchmarks/voxcpm/vllm_omni/bench_tts_serve.py \ + --host 127.0.0.1 \ + --port 8091 \ + --num-prompts 10 \ + --max-concurrency 1 \ + --result-dir /tmp/voxcpm_bench +``` + +For the async-chunk server, keep `--max-concurrency 1`. 
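+
+If you just want a quick TTFP / RTF sanity check without the benchmark script, a
+minimal sketch is below. It assumes the server launched by `run_server.sh` is
+listening on port 8091 and that the streamed PCM is 16-bit mono at 24 kHz (adjust
+`SAMPLE_RATE` / `BYTES_PER_SAMPLE` if your model differs); the payload mirrors the
+streaming `curl` example above:
+
+```python
+import time
+
+import httpx
+
+API_URL = "http://localhost:8091/v1/audio/speech"
+SAMPLE_RATE = 24000      # assumed VoxCPM output sample rate
+BYTES_PER_SAMPLE = 2     # assumed 16-bit mono PCM
+
+payload = {
+    "model": "OpenBMB/VoxCPM1.5",
+    "input": "This is a streaming VoxCPM request.",
+    "stream": True,
+    "response_format": "pcm",
+}
+
+start = time.perf_counter()
+ttfp = None
+total_bytes = 0
+with httpx.Client(timeout=300.0) as client:
+    with client.stream("POST", API_URL, json=payload) as response:
+        response.raise_for_status()
+        for chunk in response.iter_bytes():
+            if not chunk:
+                continue
+            if ttfp is None:
+                # Time to first PCM packet.
+                ttfp = time.perf_counter() - start
+            total_bytes += len(chunk)
+
+e2e = time.perf_counter() - start
+audio_s = total_bytes / (SAMPLE_RATE * BYTES_PER_SAMPLE)
+rtf = e2e / audio_s if audio_s > 0 else float("inf")
+print(f"TTFP={ttfp:.3f}s  E2E={e2e:.3f}s  RTF={rtf:.3f}")
+```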
diff --git a/examples/online_serving/voxcpm/openai_speech_client.py b/examples/online_serving/voxcpm/openai_speech_client.py new file mode 100644 index 0000000000..c400114e8b --- /dev/null +++ b/examples/online_serving/voxcpm/openai_speech_client.py @@ -0,0 +1,155 @@ +"""OpenAI-compatible client for VoxCPM via /v1/audio/speech. + +Examples: + # Basic text-to-speech + python openai_speech_client.py --text "Hello from VoxCPM" + + # Voice cloning + python openai_speech_client.py \ + --text "This sentence uses the cloned voice." \ + --ref-audio /path/to/reference.wav \ + --ref-text "The exact transcript spoken in the reference audio." + + # Streaming PCM output + python openai_speech_client.py \ + --text "This is a streaming VoxCPM request." \ + --stream \ + --output output.pcm +""" + +import argparse +import base64 +import os + +import httpx + +DEFAULT_API_BASE = "http://localhost:8091" +DEFAULT_API_KEY = "EMPTY" +DEFAULT_MODEL = "OpenBMB/VoxCPM1.5" + + +def encode_audio_to_base64(audio_path: str) -> str: + """Encode a local audio file to base64 data URL.""" + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + + ext = audio_path.lower().rsplit(".", 1)[-1] + mime_map = { + "wav": "audio/wav", + "mp3": "audio/mpeg", + "flac": "audio/flac", + "ogg": "audio/ogg", + } + mime_type = mime_map.get(ext, "audio/wav") + + with open(audio_path, "rb") as f: + audio_b64 = base64.b64encode(f.read()).decode("utf-8") + return f"data:{mime_type};base64,{audio_b64}" + + +def build_payload(args) -> dict[str, object]: + payload: dict[str, object] = { + "model": args.model, + "input": args.text, + "response_format": "pcm" if args.stream else args.response_format, + } + + if args.ref_audio: + if args.ref_audio.startswith(("http://", "https://", "data:")): + payload["ref_audio"] = args.ref_audio + else: + payload["ref_audio"] = encode_audio_to_base64(args.ref_audio) + if args.ref_text: + payload["ref_text"] = args.ref_text + if args.max_new_tokens is not None: + payload["max_new_tokens"] = args.max_new_tokens + if args.stream: + payload["stream"] = True + + return payload + + +def run_tts(args) -> None: + payload = build_payload(args) + api_url = f"{args.api_base}/v1/audio/speech" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {args.api_key}", + } + + print(f"Model: {args.model}") + print(f"Text: {args.text}") + if args.ref_audio: + print("Mode: voice cloning") + print(f"Reference audio: {args.ref_audio}") + else: + print("Mode: text-to-speech") + + if args.stream: + output_path = args.output or "voxcpm_output.pcm" + with httpx.Client(timeout=300.0) as client: + with client.stream("POST", api_url, json=payload, headers=headers) as response: + if response.status_code != 200: + print(f"Error: {response.status_code}") + print(response.read().decode("utf-8", errors="ignore")) + return + + total_bytes = 0 + with open(output_path, "wb") as f: + for chunk in response.iter_bytes(): + if not chunk: + continue + f.write(chunk) + total_bytes += len(chunk) + print(f"Streamed {total_bytes} bytes to: {output_path}") + return + + with httpx.Client(timeout=300.0) as client: + response = client.post(api_url, json=payload, headers=headers) + + if response.status_code != 200: + print(f"Error: {response.status_code}") + print(response.text) + return + + try: + text = response.content.decode("utf-8") + if text.startswith('{"error"'): + print(f"Error: {text}") + return + except UnicodeDecodeError: + pass + + output_path = args.output or 
"voxcpm_output.wav" + with open(output_path, "wb") as f: + f.write(response.content) + print(f"Audio saved to: {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="VoxCPM OpenAI-compatible speech client") + parser.add_argument("--api-base", default=DEFAULT_API_BASE, help="API base URL") + parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="API key") + parser.add_argument("--model", "-m", default=DEFAULT_MODEL, help="Model name or path") + parser.add_argument("--text", required=True, help="Text to synthesize") + parser.add_argument("--ref-audio", default=None, help="Reference audio path, URL, or data URL") + parser.add_argument( + "--ref-text", + default=None, + help="The exact transcript spoken in the reference audio", + ) + parser.add_argument("--stream", action="store_true", help="Enable streaming PCM output") + parser.add_argument( + "--response-format", + default="wav", + choices=["wav", "pcm", "flac", "mp3", "aac", "opus"], + help="Audio format for non-streaming mode (default: wav)", + ) + parser.add_argument("--max-new-tokens", type=int, default=None, help="Maximum tokens to generate") + parser.add_argument("--output", "-o", default=None, help="Output file path") + args = parser.parse_args() + run_tts(args) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/voxcpm/run_server.sh b/examples/online_serving/voxcpm/run_server.sh new file mode 100755 index 0000000000..ab4b6fe854 --- /dev/null +++ b/examples/online_serving/voxcpm/run_server.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Launch vLLM-Omni server for VoxCPM online speech serving. +# +# Usage: +# ./run_server.sh # default: async_chunk stage config +# ./run_server.sh async # async_chunk stage config +# ./run_server.sh sync # no-async-chunk stage config +# VOXCPM_MODEL=/path/to/model ./run_server.sh + +set -e + +MODE="${1:-async}" +MODEL="${VOXCPM_MODEL:-OpenBMB/VoxCPM1.5}" + +case "$MODE" in + async) + STAGE_CONFIG="vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml" + ;; + sync) + STAGE_CONFIG="vllm_omni/model_executor/stage_configs/voxcpm.yaml" + ;; + *) + echo "Unknown mode: $MODE" + echo "Supported: async, sync" + exit 1 + ;; +esac + +echo "Starting VoxCPM server with model: $MODEL" +echo "Stage config: $STAGE_CONFIG" + +vllm serve "$MODEL" \ + --stage-configs-path "$STAGE_CONFIG" \ + --host 0.0.0.0 \ + --port 8091 \ + --trust-remote-code \ + --enforce-eager \ + --omni diff --git a/tests/e2e/offline_inference/test_voxcpm.py b/tests/e2e/offline_inference/test_voxcpm.py new file mode 100644 index 0000000000..d7f65525e9 --- /dev/null +++ b/tests/e2e/offline_inference/test_voxcpm.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""E2E test for VoxCPM offline inference.""" + +import json +import os +from pathlib import Path +from typing import Any + +import numpy as np +import pytest +import torch + +import tests.conftest as omni_test_conftest +from tests.conftest import OmniRunner +from tests.utils import hardware_test +from vllm_omni.model_executor.models.voxcpm.voxcpm_runtime_utils import ( + prepare_voxcpm_hf_config_dir, + resolve_voxcpm_model_dir, +) + +VOXCPM_MODEL = os.environ.get("VOXCPM_MODEL", "OpenBMB/VoxCPM1.5") +STAGE_CONFIG = str( + Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml" +) +SAMPLE_RATE = 24000 + + +@pytest.fixture(autouse=True) +def _patch_npu_cleanup_for_voxcpm(monkeypatch: pytest.MonkeyPatch): 
+ """Limit the NPU cleanup workaround to this VoxCPM test module only.""" + original_cleanup = omni_test_conftest.cleanup_dist_env_and_memory + + def _safe_cleanup() -> None: + try: + original_cleanup() + except RuntimeError as exc: + if "Allocator for npu is not a DeviceAllocator" in str(exc): + return + raise + + monkeypatch.setattr(omni_test_conftest, "cleanup_dist_env_and_memory", _safe_cleanup) + + +def _build_prompt(text: str) -> dict[str, Any]: + return { + "prompt_token_ids": [1], + "additional_information": { + "text": [text], + "cfg_value": [2.0], + "inference_timesteps": [10], + "min_len": [2], + "max_new_tokens": [1024], + }, + } + + +def _extract_audio_tensor(multimodal_output: dict[str, Any]) -> torch.Tensor: + audio = multimodal_output.get("audio", multimodal_output.get("model_outputs")) + assert audio is not None, f"No audio output found, keys={list(multimodal_output.keys())}" + + if isinstance(audio, list): + parts: list[torch.Tensor] = [] + for item in audio: + if item is None: + continue + tensor = torch.as_tensor(item) + if tensor.numel() == 0: + continue + parts.append(tensor.float().cpu().reshape(-1)) + return torch.cat(parts, dim=-1) if parts else torch.zeros((0,), dtype=torch.float32) + + return torch.as_tensor(audio).float().cpu().reshape(-1) + + +def _extract_final_multimodal_output(outputs) -> dict[str, Any]: + for item in reversed(outputs): + request_output = getattr(item, "request_output", None) + if request_output is not None: + multimodal_output = getattr(request_output, "multimodal_output", None) + if isinstance(multimodal_output, dict): + return multimodal_output + completions = getattr(request_output, "outputs", None) or [] + for completion in completions: + multimodal_output = getattr(completion, "multimodal_output", None) + if isinstance(multimodal_output, dict): + return multimodal_output + + multimodal_output = getattr(item, "multimodal_output", None) + if isinstance(multimodal_output, dict): + return multimodal_output + + raise AssertionError("No multimodal audio output found in VoxCPM generate results") + + +@pytest.fixture +def voxcpm_model_path(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> str: + model_dir = resolve_voxcpm_model_dir(VOXCPM_MODEL) + + hf_config_env = os.environ.get("VLLM_OMNI_VOXCPM_HF_CONFIG_PATH") + if hf_config_env: + hf_config_dir = Path(hf_config_env).expanduser() + else: + hf_config_dir = tmp_path / "voxcpm_hf_config" + + if not (hf_config_dir / "config.json").exists(): + prepare_voxcpm_hf_config_dir(model_dir, hf_config_dir) + + monkeypatch.setenv("VLLM_OMNI_VOXCPM_HF_CONFIG_PATH", str(hf_config_dir)) + return str(model_dir) + + +def test_prepare_voxcpm_hf_config_dir(tmp_path: Path): + model_dir = tmp_path / "model" + model_dir.mkdir() + (model_dir / "config.json").write_text(json.dumps({"hidden_size": 1024}), encoding="utf-8") + (model_dir / "generation_config.json").write_text(json.dumps({"do_sample": False}), encoding="utf-8") + + hf_config_dir = prepare_voxcpm_hf_config_dir(model_dir, tmp_path / "voxcpm_hf_config") + + prepared_config = json.loads((hf_config_dir / "config.json").read_text(encoding="utf-8")) + assert prepared_config["model_type"] == "voxcpm" + assert prepared_config["architectures"] == ["VoxCPMForConditionalGeneration"] + assert (hf_config_dir / "generation_config.json").exists() + + +def test_resolve_voxcpm_model_dir_local_path(tmp_path: Path): + model_dir = tmp_path / "OpenBMB" / "VoxCPM1.5" + model_dir.mkdir(parents=True) + + assert resolve_voxcpm_model_dir(str(model_dir)) == model_dir + + 
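+# The end-to-end case below drives the full split-stage pipeline configured by
+# voxcpm.yaml through OmniRunner and sanity-checks the generated audio: minimum
+# sample count, duration bounds, peak amplitude, and RMS energy.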
+@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "L4"}, num_cards=1) +def test_voxcpm_zero_shot_001(voxcpm_model_path: str): + with OmniRunner(voxcpm_model_path, stage_configs_path=STAGE_CONFIG) as runner: + outputs = list(runner.omni.generate(_build_prompt("Hello, this is a VoxCPM offline inference test."))) + + assert outputs, "No outputs returned" + + multimodal_output = _extract_final_multimodal_output(outputs) + audio = _extract_audio_tensor(multimodal_output) + assert audio.numel() > SAMPLE_RATE // 2, f"Audio too short: {audio.numel()} samples" + + duration_s = audio.shape[0] / SAMPLE_RATE + assert 0.5 < duration_s < 30.0, f"Audio duration out of range: {duration_s:.2f}s" + + peak = float(torch.max(torch.abs(audio)).item()) if audio.numel() > 0 else 0.0 + assert peak > 0.01, "Generated audio appears to be silence" + + audio_np = audio.numpy() + rms = float(np.sqrt(np.mean(np.square(audio_np)))) if audio_np.size else 0.0 + assert rms > 1e-4, "Generated audio RMS too low" diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 35d55f1cc4..565c83c1ad 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -7,6 +7,7 @@ import argparse import inspect from types import SimpleNamespace +from unittest.mock import Mock import pytest from pydantic import ValidationError @@ -166,6 +167,24 @@ def test_stage_configs_path_field(): assert args.stage_configs_path == "/some/path.yaml" +def test_voxcpm_model_arch_injects_model_type_override(mocker): + """Ensure VoxCPM model_arch injects hf_overrides for config resolution.""" + mocker.patch.object(OmniEngineArgs, "_ensure_omni_models_registered", return_value=True) + mocker.patch.object(OmniEngineArgs, "_patch_empty_hf_config") + mocker.patch.object(EngineArgs, "create_model_config", return_value=Mock()) + mocker.patch.object(OmniModelConfig, "from_vllm_model_config", return_value=Mock()) + + args = OmniEngineArgs( + model="OpenBMB/VoxCPM1.5", + model_arch="VoxCPMForConditionalGeneration", + ) + args.create_model_config() + + assert args.hf_overrides["architectures"] == ["VoxCPMForConditionalGeneration"] + assert args.hf_overrides["model_type"] == "voxcpm" + args._patch_empty_hf_config.assert_called_once_with("voxcpm") + + def test_strip_single_engine_args(): """_strip_single_engine_args should remove EngineArgs fields but keep omni fields.""" kwargs = { diff --git a/tests/entrypoints/openai_api/test_serving_speech_voxcpm.py b/tests/entrypoints/openai_api/test_serving_speech_voxcpm.py new file mode 100644 index 0000000000..48660b6d1c --- /dev/null +++ b/tests/entrypoints/openai_api/test_serving_speech_voxcpm.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""UTs for VoxCPM OpenAI speech serving behavior.""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest +from pytest_mock import MockerFixture + +from vllm_omni.entrypoints.openai.protocol.audio import OpenAICreateSpeechRequest +from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +@pytest.fixture +def voxcpm_server(mocker: MockerFixture): + mocker.patch.object(OmniOpenAIServingSpeech, "_load_supported_speakers", return_value=set()) + mocker.patch.object(OmniOpenAIServingSpeech, "_load_codec_frame_rate", return_value=None) + + mock_engine_client = mocker.MagicMock() + mock_engine_client.errored = 
False + mock_engine_client.model_config = mocker.MagicMock(model="OpenBMB/VoxCPM1.5") + mock_engine_client.default_sampling_params_list = [SimpleNamespace(max_tokens=2048)] + mock_engine_client.tts_batch_max_items = 32 + mock_engine_client.generate = mocker.MagicMock(return_value="generator") + mock_engine_client.stage_configs = [ + SimpleNamespace( + engine_args=SimpleNamespace( + model_stage="latent_generator", + model_arch="VoxCPMForConditionalGeneration", + ), + tts_args={}, + ), + SimpleNamespace( + engine_args=SimpleNamespace(model_stage="vae"), + tts_args={}, + ), + ] + + mock_models = mocker.MagicMock() + mock_models.is_base_model.return_value = True + + return OmniOpenAIServingSpeech( + engine_client=mock_engine_client, + models=mock_models, + request_logger=mocker.MagicMock(), + ) + + +class TestVoxCPMServing: + def test_voxcpm_model_type_detection(self, voxcpm_server): + assert voxcpm_server._tts_model_type == "voxcpm" + assert voxcpm_server._is_tts is True + assert voxcpm_server.supported_speakers == set() + + @pytest.mark.parametrize( + ("request_kwargs", "expected_substring"), + [ + ({"voice": "alice"}, "voice"), + ({"instructions": "whisper"}, "instructions"), + ({"language": "en"}, "language"), + ({"task_type": "CustomVoice"}, "plain tts"), + ({"x_vector_only_mode": True}, "x_vector_only_mode"), + ({"speaker_embedding": [0.1, 0.2]}, "speaker_embedding"), + ({"initial_codec_chunk_frames": 4}, "initial_codec_chunk_frames"), + ({"ref_text": "reference"}, "ref_audio"), + ], + ) + def test_validate_voxcpm_rejects_unsupported_fields(self, voxcpm_server, request_kwargs, expected_substring): + request = OpenAICreateSpeechRequest(input="hello voxcpm", **request_kwargs) + error = voxcpm_server._validate_voxcpm_request(request) + assert error is not None + assert expected_substring in error.lower() + + def test_validate_voxcpm_accepts_plain_tts_request(self, voxcpm_server): + request = OpenAICreateSpeechRequest(input="hello voxcpm", max_new_tokens=256) + assert voxcpm_server._validate_voxcpm_request(request) is None + + def test_validate_voxcpm_accepts_voice_clone_request(self, voxcpm_server): + request = OpenAICreateSpeechRequest( + input="clone this voice", + ref_audio="data:audio/wav;base64,QUJD", + ref_text="reference transcript", + max_new_tokens=256, + ) + assert voxcpm_server._validate_voxcpm_request(request) is None + + def test_prepare_speech_generation_voxcpm_text_only(self, voxcpm_server): + request = OpenAICreateSpeechRequest(input="hello voxcpm", max_new_tokens=321) + + request_id, generator, tts_params = asyncio.run(voxcpm_server._prepare_speech_generation(request)) + + assert request_id.startswith("speech-") + assert generator == "generator" + assert tts_params == { + "text": ["hello voxcpm"], + "cfg_value": [2.0], + "inference_timesteps": [10], + "min_len": [2], + "max_new_tokens": [321], + } + + voxcpm_server.engine_client.generate.assert_called_once() + call = voxcpm_server.engine_client.generate.call_args + assert call.kwargs["prompt"] == { + "prompt_token_ids": [1], + "additional_information": tts_params, + } + assert call.kwargs["output_modalities"] == ["audio"] + + def test_prepare_speech_generation_voxcpm_voice_clone_resolves_ref_audio(self, voxcpm_server): + voxcpm_server._resolve_ref_audio = AsyncMock(return_value=([0.1, -0.1, 0.2], 16000)) + request = OpenAICreateSpeechRequest( + input="clone this voice", + ref_audio="data:audio/wav;base64,QUJD", + ref_text="reference transcript", + max_new_tokens=512, + ) + + request_id, generator, tts_params = 
asyncio.run(voxcpm_server._prepare_speech_generation(request)) + + assert request_id.startswith("speech-") + assert generator == "generator" + assert tts_params == { + "text": ["clone this voice"], + "cfg_value": [2.0], + "inference_timesteps": [10], + "min_len": [2], + "max_new_tokens": [512], + "ref_text": ["reference transcript"], + "ref_audio": [[[0.1, -0.1, 0.2], 16000]], + } + + voxcpm_server._resolve_ref_audio.assert_awaited_once_with("data:audio/wav;base64,QUJD") + call = voxcpm_server.engine_client.generate.call_args + assert call.kwargs["prompt"] == { + "prompt_token_ids": [1], + "additional_information": tts_params, + } diff --git a/tests/entrypoints/test_utils.py b/tests/entrypoints/test_utils.py index 94e254c250..248629d51d 100644 --- a/tests/entrypoints/test_utils.py +++ b/tests/entrypoints/test_utils.py @@ -310,6 +310,39 @@ def mock_exists(path): assert result is not None assert "glm_image.yaml" in result + def test_voxcpm_transformers_format_resolution(self, mocker: MockerFixture): + """Test VoxCPM transformers config resolves to the voxcpm stage config.""" + mocker.patch( + "vllm_omni.entrypoints.utils.get_config", + side_effect=ValueError("missing transformers config"), + ) + mocker.patch( + "vllm_omni.entrypoints.utils.file_or_path_exists", + side_effect=lambda _model, filename, revision=None: filename == "config.json", + ) + mocker.patch( + "vllm_omni.entrypoints.utils.get_hf_file_to_dict", + return_value={"model_type": "voxcpm"}, + ) + mocker.patch( + "vllm_omni.entrypoints.utils.current_omni_platform.get_default_stage_config_path", + return_value="vllm_omni/model_executor/stage_configs", + ) + + original_exists = os.path.exists + + def mock_exists(path): + if "voxcpm.yaml" in str(path): + return True + return original_exists(path) + + mocker.patch("os.path.exists", side_effect=mock_exists) + + result = resolve_model_config_path("OpenBMB/VoxCPM1.5") + + assert result is not None + assert "voxcpm.yaml" in result + class TestLoadAndResolveStageConfigs: def test_load_and_resolve_with_kwargs(self): diff --git a/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py new file mode 100644 index 0000000000..7d6fc6e74c --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""UTs for VoxCPM async-chunk stage input processing.""" + +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.stage_input_processors.voxcpm import ( + _VOXCPM_LATENT_MAGIC, + _coerce_finished_flag, + latent2vae_async_chunk, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _request(*, finished): + return SimpleNamespace(is_finished=lambda: finished) + + +def _decode_serialized_latent(codes: list[int]) -> torch.Tensor: + assert codes[0] == _VOXCPM_LATENT_MAGIC + latent_dim = codes[1] + time_dim = codes[2] + payload = torch.tensor(codes[3:], dtype=torch.int32).to(torch.uint16) + return payload.view(torch.bfloat16).to(torch.float32).reshape(1, latent_dim, time_dim) + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (None, False), + (False, False), + (True, True), + (torch.tensor(False), False), + (torch.tensor(True), True), + ([torch.tensor(True)], True), + (([True],), True), + ([], False), + ], +) +def test_coerce_finished_flag(value, expected): + assert _coerce_finished_flag(value) 
is expected + + +def test_latent2vae_async_chunk_serializes_latent_payload(): + latent = torch.arange(6, dtype=torch.float32).reshape(2, 3) + + payload = latent2vae_async_chunk( + transfer_manager=None, + pooling_output={"latent_audio_feat": latent}, + request=_request(finished=False), + is_finished=torch.tensor(False), + ) + + assert payload is not None + assert torch.equal(payload["finished"], torch.tensor(False, dtype=torch.bool)) + recovered = _decode_serialized_latent(payload["code_predictor_codes"]) + torch.testing.assert_close(recovered, latent.to(torch.bfloat16).to(torch.float32).unsqueeze(0)) + + +def test_latent2vae_async_chunk_returns_terminal_marker_without_latent(): + payload = latent2vae_async_chunk( + transfer_manager=None, + pooling_output=None, + request=_request(finished=[torch.tensor(True)]), + is_finished=False, + ) + + assert payload == { + "code_predictor_codes": [], + "finished": torch.tensor(True, dtype=torch.bool), + } + + +def test_latent2vae_async_chunk_returns_none_for_nonterminal_empty_chunk(): + payload = latent2vae_async_chunk( + transfer_manager=None, + pooling_output={"latent_audio_feat": torch.zeros((0,), dtype=torch.float32)}, + request=_request(finished=False), + is_finished=False, + ) + + assert payload is None diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index d61102c7e1..5b69d6b1f0 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -21,6 +21,7 @@ "CosyVoice3Model": "cosyvoice3", "OmniVoiceModel": "omnivoice", "VoxCPM2TalkerForConditionalGeneration": "voxcpm2", + "VoxCPMForConditionalGeneration": "voxcpm", } # Maps model architecture names to tokenizer subfolder paths within HF repos. @@ -41,6 +42,7 @@ def _register_omni_hf_configs() -> None: from vllm_omni.model_executor.models.voxtral_tts.configuration_voxtral_tts import ( VoxtralTTSConfig, ) + from vllm_omni.transformers_utils.configs.voxcpm import VoxCPMConfig from vllm_omni.transformers_utils.configs.voxcpm2 import VoxCPM2Config except Exception as exc: # pragma: no cover - best-effort optional registration logger.warning("Skipping omni HF config registration due to import error: %s", exc) @@ -59,6 +61,7 @@ def _register_omni_hf_configs() -> None: ("cosyvoice3", CosyVoice3Config), ("omnivoice", OmniVoiceConfig), ("voxtral_tts", VoxtralTTSConfig), + ("voxcpm", VoxCPMConfig), ("voxcpm2", VoxCPM2Config), ]: try: diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 3dc5f595d0..dc223cdb27 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -49,6 +49,7 @@ _FISH_TTS_MODEL_STAGES = {"fish_speech_slow_ar"} _COSYVOICE3_TTS_MODEL_STAGES = {"cosyvoice3_talker"} _OMNIVOICE_TTS_MODEL_STAGES = {"omnivoice_generator"} +_VOXCPM_TTS_MODEL_STAGES = {"latent_generator", "vae"} _VOXCPM2_TTS_MODEL_STAGES = {"latent_generator"} _TTS_MODEL_STAGES: set[str] = ( _VOXTRAL_TTS_MODEL_STAGES @@ -56,6 +57,7 @@ | _FISH_TTS_MODEL_STAGES | _COSYVOICE3_TTS_MODEL_STAGES | _OMNIVOICE_TTS_MODEL_STAGES + | _VOXCPM_TTS_MODEL_STAGES | _VOXCPM2_TTS_MODEL_STAGES ) _TTS_LANGUAGES: set[str] = { @@ -282,6 +284,11 @@ def _detect_tts_model_type(self) -> str | None: if self._tts_stage is None: return None model_stage = getattr(self._tts_stage.engine_args, "model_stage", None) + model_arch = getattr(self._tts_stage.engine_args, "model_arch", None) + if model_arch == "VoxCPM2TalkerForConditionalGeneration": + return "voxcpm2" + if model_arch == 
"VoxCPMForConditionalGeneration": + return "voxcpm" if model_stage in _QWEN3_TTS_MODEL_STAGES: return "qwen3_tts" if model_stage in _VOXTRAL_TTS_MODEL_STAGES: @@ -292,8 +299,12 @@ def _detect_tts_model_type(self) -> str | None: return "cosyvoice3" if model_stage in _OMNIVOICE_TTS_MODEL_STAGES: return "omnivoice" - if model_stage in _VOXCPM2_TTS_MODEL_STAGES: - return "voxcpm2" + if model_stage in (_VOXCPM_TTS_MODEL_STAGES | _VOXCPM2_TTS_MODEL_STAGES): + has_vae_stage = any( + getattr(getattr(stage, "engine_args", None), "model_stage", None) == "vae" + for stage in self.engine_client.stage_configs + ) + return "voxcpm" if has_vae_stage or model_stage == "vae" else "voxcpm2" return None def _compute_max_instructions_length(self) -> int: @@ -318,6 +329,8 @@ def _compute_max_instructions_length(self) -> int: def _load_supported_speakers(self) -> set[str]: """Load supported speakers (case-insensitive) from the model configuration.""" try: + if self._tts_model_type == "voxcpm": + return set() if self._tts_model_type == "voxtral_tts": config = self.engine_client.model_config.hf_config.audio_config else: @@ -377,6 +390,8 @@ def _estimate_ref_code_len(self, ref_audio: object) -> int | None: def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int: """Estimate prompt length so the placeholder matches model-side embeddings.""" try: + if self._tts_model_type == "voxcpm": + return 1 from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import ( Qwen3TTSTalkerForConditionalGeneration, ) @@ -791,6 +806,8 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non return self._validate_fish_tts_request(request) if self._tts_model_type == "cosyvoice3": return self._validate_cosyvoice3_request(request) + if self._tts_model_type == "voxcpm": + return self._validate_voxcpm_request(request) if self._tts_model_type == "voxcpm2": return None # VoxCPM2 accepts any text input return self._validate_qwen_tts_request(request) @@ -832,6 +849,43 @@ def _validate_voxtral_tts_request(self, request: OpenAICreateSpeechRequest) -> s return None + def _validate_voxcpm_request(self, request: OpenAICreateSpeechRequest) -> str | None: + """Validate VoxCPM request parameters. 
Returns error message or None.""" + if not request.input or not request.input.strip(): + return "Input text cannot be empty" + + if request.voice is not None: + return "'voice' is not supported for VoxCPM" + if request.instructions is not None: + return "'instructions' is not supported for VoxCPM" + if request.language is not None: + return "'language' is not supported for VoxCPM" + if request.task_type not in (None, "Base"): + return "VoxCPM only supports plain TTS or voice cloning with ref_audio/ref_text" + if request.x_vector_only_mode is not None: + return "'x_vector_only_mode' is not supported for VoxCPM" + if request.speaker_embedding is not None: + return "'speaker_embedding' is not supported for VoxCPM" + if request.initial_codec_chunk_frames is not None: + return "'initial_codec_chunk_frames' is not supported for VoxCPM" + + if request.ref_audio is not None: + fmt_err = self._validate_ref_audio_format(request.ref_audio) + if fmt_err: + return fmt_err + if not request.ref_text or not request.ref_text.strip(): + return "Voice cloning requires 'ref_text' (transcript of the reference audio)" + elif request.ref_text is not None: + return "'ref_text' requires 'ref_audio' for VoxCPM voice cloning" + + if request.max_new_tokens is not None: + if request.max_new_tokens < _TTS_MAX_NEW_TOKENS_MIN: + return f"max_new_tokens must be at least {_TTS_MAX_NEW_TOKENS_MIN}" + if request.max_new_tokens > _TTS_MAX_NEW_TOKENS_MAX: + return f"max_new_tokens cannot exceed {_TTS_MAX_NEW_TOKENS_MAX}" + + return None + def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: """Validate Qwen TTS request parameters. Returns error message or None.""" # Infer Base task when ref_audio or ref_text is provided without explicit task_type. @@ -1169,6 +1223,18 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any Processes each parameter if present, skips if not. Values are wrapped in lists as required by the model. 
""" + if self._tts_model_type == "voxcpm": + params: dict[str, Any] = { + "text": [request.input], + "cfg_value": [2.0], + "inference_timesteps": [10], + "min_len": [2], + "max_new_tokens": [request.max_new_tokens or 4096], + } + if request.ref_text is not None: + params["ref_text"] = [request.ref_text] + return params + params: dict[str, Any] = {} # Text content (always required) @@ -1481,6 +1547,8 @@ async def _prepare_speech_generation( model_type = "voxtral_tts" elif self._tts_model_type == "cosyvoice3": model_type = "cosyvoice3" + elif self._tts_model_type == "voxcpm": + model_type = "voxcpm" elif self._tts_model_type == "voxcpm2": model_type = "voxcpm2" elif self._is_tts: diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index 0894088005..3407b42869 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -145,6 +145,12 @@ "fish_speech_dac_decoder", "FishSpeechDACDecoder", ), + ## VoxCPM + "VoxCPMForConditionalGeneration": ( + "voxcpm", + "voxcpm", + "VoxCPMForConditionalGeneration", + ), ## VoxCPM2 "VoxCPM2TalkerForConditionalGeneration": ( "voxcpm2", diff --git a/vllm_omni/model_executor/models/voxcpm/__init__.py b/vllm_omni/model_executor/models/voxcpm/__init__.py new file mode 100644 index 0000000000..3b064c0f68 --- /dev/null +++ b/vllm_omni/model_executor/models/voxcpm/__init__.py @@ -0,0 +1,7 @@ +from .configuration_voxcpm import VoxCPMConfig +from .voxcpm import VoxCPMForConditionalGeneration + +__all__ = [ + "VoxCPMConfig", + "VoxCPMForConditionalGeneration", +] diff --git a/vllm_omni/model_executor/models/voxcpm/configuration_voxcpm.py b/vllm_omni/model_executor/models/voxcpm/configuration_voxcpm.py new file mode 100644 index 0000000000..ce1d809bd3 --- /dev/null +++ b/vllm_omni/model_executor/models/voxcpm/configuration_voxcpm.py @@ -0,0 +1,3 @@ +from vllm_omni.transformers_utils.configs.voxcpm import VoxCPMConfig + +__all__ = ["VoxCPMConfig"] diff --git a/vllm_omni/model_executor/models/voxcpm/voxcpm.py b/vllm_omni/model_executor/models/voxcpm/voxcpm.py new file mode 100644 index 0000000000..6fa36fc420 --- /dev/null +++ b/vllm_omni/model_executor/models/voxcpm/voxcpm.py @@ -0,0 +1,886 @@ +from __future__ import annotations + +import json +import os +import sys +import tempfile +import warnings +import wave +from collections.abc import Callable, Generator, Iterable +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from tqdm import tqdm +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +from vllm_omni.model_executor.models.output_templates import OmniOutput + +from .voxcpm_loader import ( + _build_prompt_cache_with_soundfile, + _device_to_string, + _force_cuda_available_for_npu, + _import_voxcpm_audio_vae_classes, + _import_voxcpm_base_model_class, + _is_torchcodec_load_error, + _normalize_dtype_name, + _prepare_runtime_model_dir, + _resolve_runtime_device, +) +from .voxcpm_runtime_utils import resolve_voxcpm_model_dir +from .voxcpm_stage_wrappers import _DirectVoxCPMAudioVAE, _DirectVoxCPMLatentGenerator + +logger = init_logger(__name__) +_VOXCPM_LATENT_MAGIC = 131071 + + +def _make_voxcpm_model_for_omni(base: type[Any]) -> type[Any]: + """Subclass upstream VoxCPMModel: local ``_inference`` + ``latents_only`` prompt-cache generation.""" + + from voxcpm.model.utils import get_dtype + + class 
VoxCPMModelForOmni(base): + @torch.inference_mode() + def build_prompt_cache(self, *args: Any, **kwargs: Any): + try: + return super().build_prompt_cache(*args, **kwargs) + except (ImportError, ModuleNotFoundError, RuntimeError) as exc: + if not _is_torchcodec_load_error(exc): + raise + return _build_prompt_cache_with_soundfile(self, *args, **kwargs) + + @torch.inference_mode() + def _inference( + self, + text: torch.Tensor, + text_mask: torch.Tensor, + feat: torch.Tensor, + feat_mask: torch.Tensor, + min_len: int = 2, + max_len: int = 2000, + inference_timesteps: int = 10, + cfg_value: float = 2.0, + streaming: bool = False, + streaming_prefix_len: int = 3, + ) -> Generator[tuple[torch.Tensor, torch.Tensor | list[torch.Tensor]], None, None]: + B, _, _, _ = feat.shape + + feat_embed = self.feat_encoder(feat) + feat_embed = self.enc_to_lm_proj(feat_embed) + + scale_emb = self.config.lm_config.scale_emb if self.config.lm_config.use_mup else 1.0 + text_embed = self.base_lm.embed_tokens(text) * scale_emb + combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed + + prefix_feat_cond = feat[:, -1, ...] + pred_feat_seq: list[torch.Tensor] = [] + + audio_patch_count = int(feat_mask.sum().item()) + if audio_patch_count > 0: + context_len = min(streaming_prefix_len - 1, audio_patch_count) + prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1)) + pred_feat_seq = prompt_context_patches + pred_feat_seq + + enc_outputs, kv_cache_tuple = self.base_lm( + inputs_embeds=combined_embed, + is_causal=True, + ) + self.base_lm.kv_cache.fill_caches(kv_cache_tuple) + + enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1) + lm_hidden = enc_outputs[:, -1, :] + + residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm( + inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed, + is_causal=True, + ) + self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple) + residual_hidden = residual_enc_outputs[:, -1, :] + + for step_idx in tqdm(range(max_len)): + dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden) + pred_feat = self.feat_decoder( + mu=dit_hidden, + patch_size=self.patch_size, + cond=prefix_feat_cond.transpose(1, 2).contiguous(), + n_timesteps=inference_timesteps, + cfg_value=cfg_value, + ).transpose(1, 2) + + curr_embed = self.enc_to_lm_proj(self.feat_encoder(pred_feat.unsqueeze(1))) + pred_feat_seq.append(pred_feat.unsqueeze(1)) + prefix_feat_cond = pred_feat + + if streaming: + pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1) + feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size) + yield feat_pred, pred_feat_seq + + stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item() + if step_idx > min_len and stop_flag == 1: + break + + lm_hidden = self.base_lm.forward_step( + curr_embed[:, 0, :], + torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device), + ).clone() + lm_hidden = self.fsq_layer(lm_hidden) + residual_hidden = self.residual_lm.forward_step( + lm_hidden + curr_embed[:, 0, :], + torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device), + ).clone() + + if not streaming: + pred_feat_seq_cat = torch.cat(pred_feat_seq, dim=1) + feat_pred = rearrange(pred_feat_seq_cat, "b t p d -> b d (t p)", b=B, p=self.patch_size) + yield feat_pred, pred_feat_seq_cat.squeeze(0).cpu() + + @torch.inference_mode() + def 
generate_latents_with_prompt_cache( + self, + target_text: str, + prompt_cache: dict, + min_len: int = 2, + max_len: int = 2000, + inference_timesteps: int = 10, + cfg_value: float = 2.0, + retry_badcase: bool = False, + retry_badcase_max_times: int = 3, + retry_badcase_ratio_threshold: float = 6.0, + streaming_prefix_len: int = 3, + ) -> tuple[None, torch.Tensor, torch.Tensor]: + return next( + self._generate_with_prompt_cache( + target_text=target_text, + prompt_cache=prompt_cache, + min_len=min_len, + max_len=max_len, + inference_timesteps=inference_timesteps, + cfg_value=cfg_value, + retry_badcase=retry_badcase, + retry_badcase_max_times=retry_badcase_max_times, + retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, + streaming=False, + streaming_prefix_len=streaming_prefix_len, + latents_only=True, + ) + ) + + @torch.inference_mode() + def generate_latents_with_prompt_cache_streaming( + self, + target_text: str, + prompt_cache: dict, + min_len: int = 2, + max_len: int = 2000, + inference_timesteps: int = 10, + cfg_value: float = 2.0, + retry_badcase: bool = False, + retry_badcase_max_times: int = 3, + retry_badcase_ratio_threshold: float = 6.0, + streaming_prefix_len: int = 3, + ) -> Generator[tuple[None, torch.Tensor, torch.Tensor], None, None]: + return self._generate_with_prompt_cache( + target_text=target_text, + prompt_cache=prompt_cache, + min_len=min_len, + max_len=max_len, + inference_timesteps=inference_timesteps, + cfg_value=cfg_value, + retry_badcase=retry_badcase, + retry_badcase_max_times=retry_badcase_max_times, + retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, + streaming=True, + streaming_prefix_len=streaming_prefix_len, + latents_only=True, + ) + + @torch.inference_mode() + def _generate_with_prompt_cache( + self, + target_text: str, + prompt_cache: dict, + min_len: int = 2, + max_len: int = 2000, + inference_timesteps: int = 10, + cfg_value: float = 2.0, + retry_badcase: bool = False, + retry_badcase_max_times: int = 3, + retry_badcase_ratio_threshold: float = 6.0, + streaming: bool = False, + streaming_prefix_len: int = 3, + latents_only: bool = False, + ) -> Generator[tuple[torch.Tensor | None, torch.Tensor, torch.Tensor | list[torch.Tensor]], None, None]: + if retry_badcase and streaming: + warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.") + retry_badcase = False + if prompt_cache is None: + prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32) + text = target_text + else: + prompt_audio_feat = prompt_cache["audio_feat"] + prompt_text = prompt_cache["prompt_text"] + text = prompt_text + target_text + + text_token = torch.LongTensor(self.text_tokenizer(text)) + text_token = torch.cat( + [ + text_token, + torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device), + ], + dim=-1, + ) + target_text_token = torch.LongTensor(self.text_tokenizer(target_text)) + + audio_length = prompt_audio_feat.size(0) + text_length = text_token.shape[0] + text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device) + audio_pad_feat = torch.zeros( + (text_token.shape[0], self.patch_size, self.audio_vae.latent_dim), + dtype=torch.float32, + device=text_token.device, + ) + text_token = torch.cat([text_token, text_pad_token]) + audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0) + text_mask = ( + torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device) + ) + 
audio_mask = ( + torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device) + ) + + text_token = text_token.unsqueeze(0).to(self.device) + text_mask = text_mask.unsqueeze(0).to(self.device) + audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype)) + audio_mask = audio_mask.unsqueeze(0).to(self.device) + + target_text_length = len(self.text_tokenizer(target_text)) + retry_badcase_times = 0 + while retry_badcase_times < retry_badcase_max_times: + inference_result = self._inference( + text_token, + text_mask, + audio_feat, + audio_mask, + min_len=min_len, + max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), + inference_timesteps=inference_timesteps, + cfg_value=cfg_value, + streaming=streaming, + streaming_prefix_len=streaming_prefix_len, + ) + if streaming: + patch_len = self.patch_size * self.chunk_size + for latent_pred, pred_audio_feat in inference_result: + if latents_only: + decode_audio = None + yield (decode_audio, target_text_token, latent_pred) + else: + decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) + decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu() + yield (decode_audio, target_text_token, pred_audio_feat) + break + + latent_pred, pred_audio_feat = next(inference_result) + if retry_badcase and pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: + ratio = pred_audio_feat.shape[0] / target_text_length + print(f" Badcase detected, audio_text_ratio={ratio}, retrying...", file=sys.stderr) + retry_badcase_times += 1 + continue + break + + if not streaming: + if latents_only: + decode_audio = None + else: + decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) + patch_len = self.patch_size * self.chunk_size + if audio_mask.sum().item() > 0: + decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu() + else: + decode_audio = decode_audio[..., :].squeeze(1).cpu() + yield (decode_audio, target_text_token, pred_audio_feat) + + VoxCPMModelForOmni.__name__ = "VoxCPMModelForOmni" + VoxCPMModelForOmni.__qualname__ = "VoxCPMModelForOmni" + return VoxCPMModelForOmni + + +def _import_voxcpm_model_class() -> type[Any]: + base = _import_voxcpm_base_model_class() + return _make_voxcpm_model_for_omni(base) + + +def _load_native_voxcpm_model( + model_path: str, + *, + device: torch.device, + dtype: str | None, +): + VoxCPMModel = _import_voxcpm_model_class() + model_dir = resolve_voxcpm_model_dir(model_path) + runtime_model_path = _prepare_runtime_model_dir(model_dir, target_device=device, target_dtype=dtype) + + if device.type == "npu" and hasattr(torch, "npu"): + torch.npu.set_device(device) + + with _force_cuda_available_for_npu(device): + return VoxCPMModel.from_local( + runtime_model_path, + optimize=device.type == "cuda", + ) + + +def _load_native_voxcpm_latent_generator( + model_path: str, + *, + device: torch.device, + dtype: str | None, +) -> _DirectVoxCPMLatentGenerator: + return _DirectVoxCPMLatentGenerator(_load_native_voxcpm_model(model_path, device=device, dtype=dtype)) + + +def _load_native_voxcpm_audio_vae( + model_path: str, + *, + device: torch.device, +) -> _DirectVoxCPMAudioVAE: + AudioVAE, AudioVAEConfig = _import_voxcpm_audio_vae_classes() + model_dir = resolve_voxcpm_model_dir(model_path) + runtime_model_path = _prepare_runtime_model_dir(model_dir, target_device=device, target_dtype="float32") + config_dict = json.loads((Path(runtime_model_path) / 
"config.json").read_text()) + audio_vae_config = config_dict.get("audio_vae_config") + audio_vae = AudioVAE(config=AudioVAEConfig(**audio_vae_config)) if audio_vae_config is not None else AudioVAE() + + state_dict = torch.load( + Path(runtime_model_path) / "audiovae.pth", + map_location="cpu", + weights_only=True, + )["state_dict"] + audio_vae.load_state_dict(state_dict, strict=True) + audio_vae = audio_vae.to(device=device, dtype=torch.float32).eval() + if device.type == "npu" and hasattr(torch, "npu"): + torch.npu.set_device(device) + patch_size = int(config_dict.get("patch_size", 2)) + return _DirectVoxCPMAudioVAE(audio_vae, patch_size=patch_size) + + +class VoxCPMForConditionalGeneration(nn.Module): + input_modalities = "audio" + _LATENT_STAGES = {"latent_generator", "latent", "ar_dit"} + _VAE_STAGES = {"vae", "audio_vae"} + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + del prefix + self.vllm_config = vllm_config + self.model_path = vllm_config.model_config.model + self.model_stage = getattr(vllm_config.model_config, "model_stage", "latent_generator") + self.have_multimodal_outputs = True + self.has_preprocess = False + self.has_postprocess = False + self.enable_update_additional_information = True + self.requires_raw_input_tokens = True + self.inject_omni_request_id_into_runtime_info = True + self._pipeline = None + self._latent_stream_gens: dict[str, Any] = {} + self._latent_stream_terminal_pending: dict[str, int] = {} + self._latent_stream_completed: set[str] = set() + self._next_local_stream_key = 0 + self._ar_emit_stop_token = True + + def _runner_hidden_device_dtype(self) -> tuple[torch.device, torch.dtype]: + device = _resolve_runtime_device(self.vllm_config) + model_config = getattr(self.vllm_config, "model_config", None) + dtype = getattr(model_config, "dtype", torch.float32) if model_config is not None else torch.float32 + return device, dtype + + def _ensure_model_loaded(self): + if self._pipeline is not None: + return + + target_device = _resolve_runtime_device(self.vllm_config) + model_dtype = getattr(self.vllm_config.model_config, "dtype", None) + normalized_dtype = _normalize_dtype_name(model_dtype) + if self.model_stage in self._LATENT_STAGES: + self._pipeline = _load_native_voxcpm_latent_generator( + self.model_path, + device=target_device, + dtype=normalized_dtype, + ) + elif self.model_stage in self._VAE_STAGES: + self._pipeline = _load_native_voxcpm_audio_vae( + self.model_path, + device=target_device, + ) + else: + raise ValueError( + f"Unsupported VoxCPM model_stage: {self.model_stage}. " + "pure_voxcpm only supports split-stage latent_generator/vae inference." 
+ ) + + logger.info("Loaded VoxCPM stage '%s' on %s", self.model_stage, _device_to_string(target_device)) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + del weights + self._ensure_model_loaded() + return set() + + @staticmethod + def _extract_val(info: dict[str, Any], key: str, default: Any) -> Any: + value = info.get(key, default) + if isinstance(value, list): + return value[0] if value else default + return value + + def _resolve_stream_request_key(self, info: dict[str, Any]) -> str: + request_key = info.get("__voxcpm_stream_key") + if request_key is not None: + return str(request_key) + + request_key = info.get("_omni_req_id") + if request_key is not None: + request_key = str(request_key) + info["__voxcpm_stream_key"] = request_key + return request_key + + request_key = f"voxcpm-local-{self._next_local_stream_key}" + self._next_local_stream_key += 1 + info["__voxcpm_stream_key"] = request_key + return str(request_key) + + def _recover_latent_from_input_ids(self, input_ids: torch.Tensor | None) -> torch.Tensor | None: + if input_ids is None or input_ids.numel() == 0: + return None + flat_ids = input_ids.detach().reshape(-1).to("cpu") + if flat_ids.numel() < 4 or int(flat_ids[0].item()) != _VOXCPM_LATENT_MAGIC: + return None + latent_dim = int(flat_ids[1].item()) + time_dim = int(flat_ids[2].item()) + payload = flat_ids[3:] + expected = latent_dim * time_dim + if latent_dim <= 0 or time_dim <= 0: + raise ValueError(f"Invalid VoxCPM latent header: latent_dim={latent_dim}, time_dim={time_dim}") + if int(payload.numel()) != expected: + raise ValueError( + "Invalid VoxCPM latent payload size: " + f"expected={expected}, actual={int(payload.numel())}, " + f"latent_dim={latent_dim}, time_dim={time_dim}" + ) + packed = payload.to(dtype=torch.int32).to(torch.uint16) + return packed.view(torch.bfloat16).to(torch.float32).reshape(1, latent_dim, time_dim) + + def _maybe_recover_vae_infos( + self, + infos: list[dict[str, Any]], + input_ids: torch.Tensor | None, + *, + async_chunk: bool, + ) -> list[dict[str, Any]]: + if not async_chunk: + return infos + if any(self._extract_val(info, "latent_audio_feat", None) is not None for info in infos): + return infos + recovered = self._recover_latent_from_input_ids(input_ids) + if recovered is None: + return infos + return [{"latent_audio_feat": recovered}] + + @staticmethod + def _normalize_audio_samples(samples: Any) -> np.ndarray: + if isinstance(samples, torch.Tensor): + return samples.detach().cpu().float().reshape(-1).numpy() + return np.asarray(samples, dtype=np.float32).reshape(-1) + + @classmethod + def _normalize_ref_audio(cls, ref_audio: Any) -> tuple[np.ndarray, int]: + if isinstance(ref_audio, str): + raise TypeError("String ref_audio should be handled as a path before waveform normalization.") + + if isinstance(ref_audio, dict): + sample_rate = ref_audio.get("sample_rate") or ref_audio.get("sampling_rate") or ref_audio.get("sr") + samples = None + for key in ("audio", "wav", "samples", "array", "waveform"): + if key in ref_audio and ref_audio[key] is not None: + samples = ref_audio[key] + break + if sample_rate is None or samples is None: + raise ValueError("ref_audio dict must contain waveform data and sample rate.") + return cls._normalize_audio_samples(samples), int(sample_rate) + + if isinstance(ref_audio, (list, tuple)): + if len(ref_audio) == 1: + return cls._normalize_ref_audio(ref_audio[0]) + if len(ref_audio) == 2 and np.isscalar(ref_audio[1]): + return cls._normalize_audio_samples(ref_audio[0]), 
int(ref_audio[1]) + + raise TypeError(f"Unsupported ref_audio format: {type(ref_audio)!r}") + + @staticmethod + def _write_temp_prompt_wav(waveform: np.ndarray, sample_rate: int) -> str: + prompt_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + prompt_file.close() + + wav = np.asarray(waveform, dtype=np.float32).reshape(-1) + wav = np.clip(wav, -1.0, 1.0) + pcm16 = (wav * 32767.0).astype(np.int16) + with wave.open(prompt_file.name, "wb") as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) + wav_file.setframerate(int(sample_rate)) + wav_file.writeframes(pcm16.tobytes()) + + return prompt_file.name + + @classmethod + def _resolve_prompt_inputs(cls, info: dict[str, Any]) -> tuple[str | None, str | None, str | None]: + prompt_text = cls._extract_val(info, "prompt_text", None) + prompt_wav_path = cls._extract_val(info, "prompt_wav_path", None) + if prompt_wav_path: + if prompt_text is None: + prompt_text = cls._extract_val(info, "ref_text", None) + return prompt_wav_path, prompt_text, None + + ref_audio = cls._extract_val(info, "ref_audio", None) + ref_text = cls._extract_val(info, "ref_text", None) + if ref_audio is None or ref_text is None: + return None, None, None + if isinstance(ref_audio, str): + return ref_audio, ref_text, None + + waveform, sample_rate = cls._normalize_ref_audio(ref_audio) + temp_prompt_wav = cls._write_temp_prompt_wav(waveform, sample_rate) + return temp_prompt_wav, ref_text, temp_prompt_wav + + def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: + if input_ids.numel() == 0: + return torch.empty((0, 1), device=input_ids.device, dtype=torch.float32) + return torch.zeros((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.float32) + + def _get_vocab_size(self) -> int: + model_config = getattr(self.vllm_config, "model_config", None) + if model_config is not None: + getter = getattr(model_config, "get_vocab_size", None) + if callable(getter): + try: + return int(getter()) + except Exception: + pass + hf_config = getattr(model_config, "hf_text_config", None) + if hf_config is not None and hasattr(hf_config, "vocab_size"): + return int(hf_config.vocab_size) + return 32000 + + def _make_empty_output( + self, + *, + output_key: str, + payload_factory: Callable[[], torch.Tensor], + infos: list[dict[str, Any]], + sample_rate: int, + out_device: torch.device, + out_dtype: torch.dtype, + hidden_rows: int | None = None, + ) -> OmniOutput: + if hidden_rows is None: + hidden_rows = len(infos) + return OmniOutput( + text_hidden_states=torch.zeros((hidden_rows, 1), device=out_device, dtype=out_dtype), + multimodal_outputs={ + output_key: [payload_factory() for _ in infos], + "sr": [torch.tensor(sample_rate, dtype=torch.int32) for _ in infos], + }, + ) + + def _finalize_stage_output( + self, + *, + output_key: str, + outputs: list[torch.Tensor], + sample_rates: list[torch.Tensor], + out_device: torch.device, + out_dtype: torch.dtype, + hidden_rows: int | None = None, + ) -> OmniOutput: + multimodal_outputs: dict[str, Any] = {output_key: outputs, "sr": sample_rates} + if hidden_rows is not None: + text_hidden_states = torch.zeros((hidden_rows, 1), device=out_device, dtype=out_dtype) + elif outputs: + outputs_tensor = torch.stack(outputs) + text_hidden_states = ( + outputs_tensor.unsqueeze(-1) + if outputs_tensor.ndim == 1 + else outputs_tensor.reshape(-1, outputs_tensor.shape[-1]) + ) + else: + text_hidden_states = torch.zeros((0, 1), device=out_device, dtype=out_dtype) + text_hidden_states = 
text_hidden_states.to(device=out_device, dtype=out_dtype) + return OmniOutput( + text_hidden_states=text_hidden_states, + multimodal_outputs=multimodal_outputs, + ) + + def _forward_vae_stage( + self, + infos: list[dict[str, Any]], + *, + sample_rate: int, + async_chunk: bool, + out_device: torch.device, + out_dtype: torch.dtype, + ) -> OmniOutput: + if all(self._extract_val(info, "latent_audio_feat", None) is None for info in infos): + self._ar_emit_stop_token = True + return self._make_empty_output( + output_key="model_outputs", + payload_factory=lambda: torch.zeros((0,), dtype=torch.float32), + infos=infos, + sample_rate=sample_rate, + out_device=out_device, + out_dtype=out_dtype, + ) + + outputs: list[torch.Tensor] = [] + sample_rates: list[torch.Tensor] = [] + for info in infos: + latent_audio_feat = self._extract_val(info, "latent_audio_feat", None) + audio_tensor = self._pipeline.decode(latent_audio_feat, trim_streaming_patch=async_chunk) + outputs.append(audio_tensor.float().cpu()) + sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32)) + + self._ar_emit_stop_token = True + return self._finalize_stage_output( + output_key="model_outputs", + outputs=outputs, + sample_rates=sample_rates, + out_device=out_device, + out_dtype=out_dtype, + ) + + def _forward_latent_stage( + self, + infos: list[dict[str, Any]], + *, + sample_rate: int, + async_chunk: bool, + out_device: torch.device, + out_dtype: torch.dtype, + hidden_rows: int, + ) -> OmniOutput: + texts = [self._extract_val(info, "text", "") for info in infos] + if all(not text for text in texts): + self._ar_emit_stop_token = True + return self._make_empty_output( + output_key="latent_audio_feat", + payload_factory=lambda: torch.zeros((0,), dtype=torch.float32), + infos=infos, + sample_rate=sample_rate, + out_device=out_device, + out_dtype=out_dtype, + hidden_rows=hidden_rows, + ) + + outputs: list[torch.Tensor] = [] + sample_rates: list[torch.Tensor] = [] + last_chunk_flags: list[bool] | None = [] if async_chunk else None + payload_finished_flags: list[bool] | None = [] if async_chunk else None + for info in infos: + text = self._extract_val(info, "text", "") + cfg_value = float(self._extract_val(info, "cfg_value", 2.0)) + inference_timesteps = int(self._extract_val(info, "inference_timesteps", 10)) + min_len = int(self._extract_val(info, "min_len", 2)) + max_len = int(self._extract_val(info, "max_len", self._extract_val(info, "max_new_tokens", 4096))) + retry_badcase = bool(self._extract_val(info, "retry_badcase", True)) + retry_badcase_max_times = int(self._extract_val(info, "retry_badcase_max_times", 3)) + retry_badcase_ratio_threshold = float(self._extract_val(info, "retry_badcase_ratio_threshold", 6.0)) + streaming_prefix_len = int(self._extract_val(info, "streaming_prefix_len", 3)) + + request_key = self._resolve_stream_request_key(info) + created_temp: str | None = None + + if async_chunk: + terminal_pending = self._latent_stream_terminal_pending.get(request_key, 0) + if terminal_pending > 0: + outputs.append(torch.zeros((0,), dtype=torch.float32)) + assert last_chunk_flags is not None + last_chunk_flags.append(True) + assert payload_finished_flags is not None + payload_finished_flags.append(terminal_pending == 1) + if terminal_pending == 1: + self._latent_stream_terminal_pending.pop(request_key, None) + else: + self._latent_stream_terminal_pending[request_key] = terminal_pending - 1 + sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32)) + continue + + if request_key in self._latent_stream_completed: + 
outputs.append(torch.zeros((0,), dtype=torch.float32)) + assert last_chunk_flags is not None + last_chunk_flags.append(True) + assert payload_finished_flags is not None + payload_finished_flags.append(False) + sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32)) + continue + + if request_key not in self._latent_stream_gens: + prompt_wav_path, prompt_text, temp_prompt_wav = self._resolve_prompt_inputs(info) + created_temp = temp_prompt_wav + self._latent_stream_gens[request_key] = self._pipeline.iter_latent_chunks_streaming( + text=text, + prompt_wav_path=prompt_wav_path, + prompt_text=prompt_text, + cfg_value=cfg_value, + inference_timesteps=inference_timesteps, + min_len=min_len, + max_len=max_len, + streaming_prefix_len=streaming_prefix_len, + retry_badcase=False, + retry_badcase_max_times=retry_badcase_max_times, + retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, + ) + generator = self._latent_stream_gens[request_key] + try: + chunk_latent, is_last = next(generator) + except StopIteration: + self._latent_stream_gens.pop(request_key, None) + self._latent_stream_terminal_pending[request_key] = 1 + self._latent_stream_completed.add(request_key) + outputs.append(torch.zeros((0,), dtype=torch.float32)) + assert last_chunk_flags is not None + last_chunk_flags.append(True) + assert payload_finished_flags is not None + payload_finished_flags.append(True) + else: + if is_last: + self._latent_stream_gens.pop(request_key, None) + self._latent_stream_terminal_pending[request_key] = 1 + self._latent_stream_completed.add(request_key) + outputs.append(chunk_latent.detach().float().cpu()) + assert last_chunk_flags is not None + last_chunk_flags.append(bool(is_last)) + assert payload_finished_flags is not None + payload_finished_flags.append(False) + finally: + if created_temp is not None and os.path.exists(created_temp): + os.unlink(created_temp) + sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32)) + continue + + prompt_wav_path, prompt_text, temp_prompt_wav = self._resolve_prompt_inputs(info) + try: + latent_audio_feat = self._pipeline.generate_latents( + text=text, + prompt_wav_path=prompt_wav_path, + prompt_text=prompt_text, + cfg_value=cfg_value, + inference_timesteps=inference_timesteps, + min_len=min_len, + max_len=max_len, + retry_badcase=retry_badcase, + retry_badcase_max_times=retry_badcase_max_times, + retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, + ) + outputs.append(latent_audio_feat.float().cpu()) + finally: + if temp_prompt_wav is not None and os.path.exists(temp_prompt_wav): + os.unlink(temp_prompt_wav) + + sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32)) + + self._ar_emit_stop_token = all(last_chunk_flags) if async_chunk and last_chunk_flags else True + output = self._finalize_stage_output( + output_key="latent_audio_feat", + outputs=outputs, + sample_rates=sample_rates, + out_device=out_device, + out_dtype=out_dtype, + hidden_rows=hidden_rows, + ) + if async_chunk and payload_finished_flags is not None: + output.multimodal_outputs["finished"] = [ + torch.tensor(flag, dtype=torch.bool) for flag in payload_finished_flags + ] + return output + + def compute_logits(self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None) -> torch.Tensor: + del sampling_metadata + if isinstance(hidden_states, OmniOutput): + hidden_states = hidden_states.text_hidden_states + if hidden_states is None: + device, dtype = self._runner_hidden_device_dtype() + hidden_states = torch.zeros((0, 1), device=device, 
dtype=dtype) + if hidden_states.ndim == 1: + hidden_states = hidden_states.unsqueeze(-1) + elif hidden_states.ndim > 2: + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) + + vocab_size = self._get_vocab_size() + num_rows = int(hidden_states.shape[0]) + logits = torch.zeros((num_rows, vocab_size), dtype=torch.float32, device=hidden_states.device) + eos_id = 2 if vocab_size > 2 else 0 + safe_id = 1 if vocab_size > 1 and 1 != eos_id else 0 + emit_stop = getattr(self, "_ar_emit_stop_token", True) + if num_rows > 0: + if emit_stop: + logits[:, eos_id] = 1.0e6 + else: + logits[:, eos_id] = -1.0e9 + logits[:, safe_id] = 1.0e6 + return logits + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: Any = None, + inputs_embeds: torch.Tensor | None = None, + runtime_additional_information: list[dict[str, Any]] | None = None, + model_intermediate_buffer: list[dict[str, Any]] | None = None, + **kwargs: Any, + ) -> OmniOutput: + del positions, intermediate_tensors, inputs_embeds, kwargs + self._ensure_model_loaded() + out_device, out_dtype = self._runner_hidden_device_dtype() + if input_ids is not None and input_ids.device.type == out_device.type: + out_device = input_ids.device + + infos = model_intermediate_buffer or runtime_additional_information or [{}] + hidden_rows = len(infos) + if input_ids is not None and len(input_ids.shape) > 0: + hidden_rows = max(hidden_rows, int(input_ids.shape[0])) + sample_rate = int(getattr(self._pipeline, "sample_rate", 24000)) + async_chunk = bool(getattr(self.vllm_config.model_config, "async_chunk", False)) + if self.model_stage in self._VAE_STAGES: + infos = self._maybe_recover_vae_infos(infos, input_ids, async_chunk=async_chunk) + return self._forward_vae_stage( + infos, + sample_rate=sample_rate, + async_chunk=async_chunk, + out_device=out_device, + out_dtype=out_dtype, + ) + if self.model_stage in self._LATENT_STAGES: + return self._forward_latent_stage( + infos, + sample_rate=sample_rate, + async_chunk=async_chunk, + out_device=out_device, + out_dtype=out_dtype, + hidden_rows=hidden_rows, + ) + raise ValueError(f"Unsupported VoxCPM model_stage at runtime: {self.model_stage}") + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + del batch_size, dtype, device + return {} + + +__all__ = ["VoxCPMForConditionalGeneration"] diff --git a/vllm_omni/model_executor/models/voxcpm/voxcpm_loader.py b/vllm_omni/model_executor/models/voxcpm/voxcpm_loader.py new file mode 100644 index 0000000000..dac7117cad --- /dev/null +++ b/vllm_omni/model_executor/models/voxcpm/voxcpm_loader.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +import importlib +import json +import os +import shutil +import sys +import tempfile +from contextlib import contextmanager +from hashlib import sha256 +from pathlib import Path +from typing import Any +from unittest.mock import patch + +import numpy as np +import torch +from vllm.config import VllmConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def _iter_voxcpm_src_candidates() -> list[Path]: + candidates: list[Path] = [] + env_path = os.environ.get("VLLM_OMNI_VOXCPM_CODE_PATH") + if env_path: + candidates.append(Path(env_path).expanduser()) + + repo_root = Path(__file__).resolve().parents[4] + candidates.append(repo_root.parent / "VoxCPM" / "src") + + unique_candidates: list[Path] = [] + seen: set[str] = set() + 
for candidate in candidates: + candidate_key = str(candidate) + if candidate_key in seen: + continue + seen.add(candidate_key) + unique_candidates.append(candidate) + return unique_candidates + + +def _prepend_voxcpm_src(candidate: Path) -> None: + candidate_str = str(candidate) + if candidate_str not in sys.path: + sys.path.insert(0, candidate_str) + + +def _import_voxcpm_attrs(module_name: str, *attr_names: str) -> tuple[Any, ...]: + last_exc: ImportError | None = None + for candidate in _iter_voxcpm_src_candidates(): + if not candidate.exists(): + continue + _prepend_voxcpm_src(candidate) + try: + module = importlib.import_module(module_name) + return tuple(getattr(module, attr_name) for attr_name in attr_names) + except ImportError as exc: + last_exc = exc + + try: + module = importlib.import_module(module_name) + return tuple(getattr(module, attr_name) for attr_name in attr_names) + except ImportError as exc: + last_exc = exc + + raise ImportError(f"Failed to import {module_name}.") from last_exc + + +def _import_voxcpm_base_model_class(): + """Import upstream ``VoxCPMModel`` from ``VoxCPM/src/voxcpm`` (env, sibling tree, or pip).""" + try: + (VoxCPMModel,) = _import_voxcpm_attrs("voxcpm.model.voxcpm", "VoxCPMModel") + return VoxCPMModel + except ImportError as exc: + raise ImportError( + "Failed to import VoxCPMModel. Install the `voxcpm` package or set " + "`VLLM_OMNI_VOXCPM_CODE_PATH` to the VoxCPM repository `src` directory " + "(the parent of the `voxcpm` package that contains `model/` and `modules/`)." + ) from exc + + +def _import_voxcpm_audio_vae_classes(): + try: + return _import_voxcpm_attrs("voxcpm.modules.audiovae", "AudioVAE", "AudioVAEConfig") + except ImportError as exc: + raise ImportError( + "Failed to import VoxCPM AudioVAE. Install the `voxcpm` package or set " + "`VLLM_OMNI_VOXCPM_CODE_PATH` to the VoxCPM repository `src` directory." 
+ ) from exc + + +def _device_to_string(device: torch.device) -> str: + if device.index is None: + return device.type + return f"{device.type}:{device.index}" + + +def _normalize_dtype_name(dtype: Any) -> str | None: + if dtype is None: + return None + if isinstance(dtype, torch.dtype): + mapping = { + torch.bfloat16: "bfloat16", + torch.float16: "float16", + torch.float32: "float32", + } + return mapping.get(dtype, str(dtype).removeprefix("torch.")) + dtype_str = str(dtype) + return dtype_str.removeprefix("torch.") + + +def _resolve_runtime_device(vllm_config: VllmConfig) -> torch.device: + try: + from vllm_omni.platforms import current_omni_platform + + return current_omni_platform.get_torch_device() + except Exception: + pass + + device = getattr(getattr(vllm_config, "device_config", None), "device", None) + if isinstance(device, torch.device): + return device + if device: + return torch.device(device) + return torch.device("cpu") + + +def _prepare_runtime_model_dir( + model_path: str | Path, + *, + target_device: torch.device, + target_dtype: str | None, +) -> str: + source_dir = Path(model_path) + config_path = source_dir / "config.json" + if not config_path.exists(): + return str(source_dir) + + config_text = config_path.read_text() + config_dict = json.loads(config_text) + desired_device = target_device.type + desired_dtype = target_dtype or config_dict.get("dtype") + + if config_dict.get("device") == desired_device and config_dict.get("dtype") == desired_dtype: + return str(source_dir) + + digest = sha256(f"{source_dir.resolve()}:{config_text}:{desired_device}:{desired_dtype}".encode()).hexdigest()[:16] + runtime_dir = Path(tempfile.gettempdir()) / "vllm_omni_voxcpm_runtime" / digest + runtime_dir.mkdir(parents=True, exist_ok=True) + + for entry in source_dir.iterdir(): + target = runtime_dir / entry.name + if entry.name == "config.json" or target.exists(): + continue + try: + target.symlink_to(entry, target_is_directory=entry.is_dir()) + except OSError as exc: + logger.warning( + "Falling back to copying VoxCPM runtime artifact %s into %s because symlink creation failed: %s", + entry, + runtime_dir, + exc, + ) + if entry.is_dir(): + shutil.copytree(entry, target, dirs_exist_ok=True) + else: + shutil.copy2(entry, target) + + patched_config = dict(config_dict) + patched_config["device"] = desired_device + if desired_dtype is not None: + patched_config["dtype"] = desired_dtype + (runtime_dir / "config.json").write_text(json.dumps(patched_config, indent=2, sort_keys=True)) + return str(runtime_dir) + + +@contextmanager +def _force_cuda_available_for_npu(device: torch.device): + if device.type != "npu": + yield + return + + with patch("torch.cuda.is_available", return_value=True): + yield + + +def _is_torchcodec_load_error(exc: BaseException) -> bool: + message = str(exc).lower() + return "torchcodec" in message or "load_with_torchcodec" in message + + +def _load_audio_with_soundfile( + prompt_wav_path: str, + *, + sample_rate: int, +) -> torch.Tensor: + try: + import soundfile as sf + except ImportError: + raise + + audio_np, source_sr = sf.read(prompt_wav_path, dtype="float32", always_2d=True) + audio = torch.from_numpy(np.ascontiguousarray(audio_np.T)) + + if audio.size(0) > 1: + audio = audio.mean(dim=0, keepdim=True) + + if int(source_sr) != int(sample_rate): + try: + import torchaudio + except ImportError as exc: + raise ImportError("torchaudio is required for resampling prompt audio.") from exc + audio = torchaudio.functional.resample(audio, int(source_sr), int(sample_rate)) + + 
return audio + + +def _build_prompt_cache_with_soundfile(model: Any, *args: Any, **kwargs: Any) -> dict[str, Any]: + if args: + prompt_text = args[0] + prompt_wav_path = args[1] if len(args) > 1 else kwargs.get("prompt_wav_path") + else: + prompt_text = kwargs.get("prompt_text") + prompt_wav_path = kwargs.get("prompt_wav_path") + + if not prompt_text or not prompt_wav_path: + raise ValueError("prompt_text and prompt_wav_path are required") + + audio = _load_audio_with_soundfile(prompt_wav_path, sample_rate=int(model.sample_rate)) + + patch_len = model.patch_size * model.chunk_size + if audio.size(1) % patch_len != 0: + padding_size = patch_len - audio.size(1) % patch_len + audio = torch.nn.functional.pad(audio, (padding_size, 0)) + + audio_feat = model.audio_vae.encode(audio.to(model.device), model.sample_rate).cpu() + audio_feat = audio_feat.view( + model.audio_vae.latent_dim, + -1, + model.patch_size, + ).permute(1, 2, 0) + + return { + "prompt_text": prompt_text, + "audio_feat": audio_feat, + } diff --git a/vllm_omni/model_executor/models/voxcpm/voxcpm_runtime_utils.py b/vllm_omni/model_executor/models/voxcpm/voxcpm_runtime_utils.py new file mode 100644 index 0000000000..36b4282c2d --- /dev/null +++ b/vllm_omni/model_executor/models/voxcpm/voxcpm_runtime_utils.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import json +import shutil +from pathlib import Path + + +def resolve_voxcpm_model_dir(model: str) -> Path: + model_path = Path(model).expanduser() + if model_path.exists(): + return model_path + + from huggingface_hub import snapshot_download + + return Path(snapshot_download(repo_id=model)) + + +def prepare_voxcpm_hf_config_dir(model_dir: str | Path, hf_config_dir: str | Path) -> Path: + model_dir = Path(model_dir).expanduser() + hf_config_dir = Path(hf_config_dir).expanduser() + hf_config_dir.mkdir(parents=True, exist_ok=True) + + source_config_path = model_dir / "config.json" + if not source_config_path.exists(): + raise FileNotFoundError(f"VoxCPM config.json not found under {model_dir}") + + config_path = hf_config_dir / "config.json" + shutil.copy2(source_config_path, config_path) + + source_generation_config_path = model_dir / "generation_config.json" + if source_generation_config_path.exists(): + shutil.copy2(source_generation_config_path, hf_config_dir / "generation_config.json") + + config_dict = json.loads(config_path.read_text(encoding="utf-8")) + config_dict["model_type"] = "voxcpm" + config_dict.setdefault("architectures", ["VoxCPMForConditionalGeneration"]) + config_path.write_text(json.dumps(config_dict, indent=2, ensure_ascii=False), encoding="utf-8") + return hf_config_dir + + +__all__ = [ + "prepare_voxcpm_hf_config_dir", + "resolve_voxcpm_model_dir", +] diff --git a/vllm_omni/model_executor/models/voxcpm/voxcpm_stage_wrappers.py b/vllm_omni/model_executor/models/voxcpm/voxcpm_stage_wrappers.py new file mode 100644 index 0000000000..f4446c796e --- /dev/null +++ b/vllm_omni/model_executor/models/voxcpm/voxcpm_stage_wrappers.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import os +from collections.abc import Generator +from typing import Any + +import torch +import torch.nn as nn +from einops import rearrange + + +class _DirectVoxCPMLatentGenerator: + def __init__(self, tts_model: Any): + self.tts_model = tts_model + self.sample_rate = int(getattr(tts_model, "sample_rate", 24000)) + + def generate_latents( + self, + *, + text: str, + prompt_wav_path: str | None = None, + prompt_text: str | None = None, + cfg_value: float = 2.0, + 
inference_timesteps: int = 10, + min_len: int = 2, + max_len: int = 4096, + retry_badcase: bool = True, + retry_badcase_max_times: int = 3, + retry_badcase_ratio_threshold: float = 6.0, + ) -> torch.Tensor: + if not isinstance(text, str) or not text.strip(): + raise ValueError("target text must be a non-empty string") + if (prompt_wav_path is None) != (prompt_text is None): + raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None") + if prompt_wav_path is not None and not os.path.exists(prompt_wav_path): + raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}") + + prompt_cache = None + if prompt_wav_path is not None and prompt_text is not None: + prompt_cache = self.tts_model.build_prompt_cache( + prompt_text=prompt_text, + prompt_wav_path=prompt_wav_path, + ) + + gen_kw = dict( + target_text=" ".join(text.split()), + prompt_cache=prompt_cache, + min_len=min_len, + max_len=max_len, + inference_timesteps=inference_timesteps, + cfg_value=cfg_value, + retry_badcase=retry_badcase, + retry_badcase_max_times=retry_badcase_max_times, + retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, + ) + latent_entry = getattr(self.tts_model, "generate_latents_with_prompt_cache", None) + if latent_entry is not None: + _, _, pred_audio_feat = latent_entry(**gen_kw) + else: + try: + _, _, pred_audio_feat = self.tts_model.generate_with_prompt_cache( + **gen_kw, + latents_only=True, + ) + except TypeError: + _, _, pred_audio_feat = self.tts_model.generate_with_prompt_cache(**gen_kw) + return pred_audio_feat.detach().cpu().to(torch.float32) + + def iter_latent_chunks_streaming( + self, + *, + text: str, + prompt_wav_path: str | None = None, + prompt_text: str | None = None, + cfg_value: float = 2.0, + inference_timesteps: int = 10, + min_len: int = 2, + max_len: int = 4096, + streaming_prefix_len: int = 3, + retry_badcase: bool = False, + retry_badcase_max_times: int = 3, + retry_badcase_ratio_threshold: float = 6.0, + ) -> Generator[tuple[torch.Tensor, bool], None, None]: + """Yield ``(latent_window, is_last_chunk)`` for Omni async_chunk latent to VAE.""" + if not isinstance(text, str) or not text.strip(): + raise ValueError("target text must be a non-empty string") + if (prompt_wav_path is None) != (prompt_text is None): + raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None") + if prompt_wav_path is not None and not os.path.exists(prompt_wav_path): + raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}") + + prompt_cache = None + if prompt_wav_path is not None and prompt_text is not None: + prompt_cache = self.tts_model.build_prompt_cache( + prompt_text=prompt_text, + prompt_wav_path=prompt_wav_path, + ) + + gen_kw = dict( + target_text=" ".join(text.split()), + prompt_cache=prompt_cache, + min_len=min_len, + max_len=max_len, + inference_timesteps=inference_timesteps, + cfg_value=cfg_value, + retry_badcase=retry_badcase, + retry_badcase_max_times=retry_badcase_max_times, + retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, + streaming_prefix_len=streaming_prefix_len, + ) + stream_entry = getattr(self.tts_model, "generate_latents_with_prompt_cache_streaming", None) + if stream_entry is not None: + gen = stream_entry(**gen_kw) + else: + fallback_stream_entry = getattr(self.tts_model, "generate_with_prompt_cache_streaming", None) + if fallback_stream_entry is not None: + gen = fallback_stream_entry(**gen_kw, latents_only=True) + else: + gen = 
self.tts_model._generate_with_prompt_cache(streaming=True, latents_only=True, **gen_kw) + + iterator = iter(gen) + previous = next(iterator, None) + while previous is not None: + current = next(iterator, None) + _, _target_tok, chunk_latent = previous + if not isinstance(chunk_latent, torch.Tensor): + chunk_latent = torch.as_tensor(chunk_latent) + yield chunk_latent, current is None + previous = current + + +class _DirectVoxCPMAudioVAE: + def __init__(self, audio_vae: nn.Module, *, patch_size: int = 2): + self.audio_vae = audio_vae + self.sample_rate = int(getattr(audio_vae, "sample_rate", 24000)) + self.latent_dim = int(getattr(audio_vae, "latent_dim", 64)) + self.patch_size = int(patch_size) + self._chunk_size = int(getattr(audio_vae, "chunk_size", 1)) + self._stream_audio_patch_samples = max(1, self.patch_size * self._chunk_size) + + def _prepare_latents_for_decode(self, latent_audio_feat: Any) -> torch.Tensor: + latents = latent_audio_feat + if not isinstance(latents, torch.Tensor): + latents = torch.tensor(latents, dtype=torch.float32) + latents = latents.detach().to(torch.float32) + + if latents.ndim == 3: + if latents.shape[-1] == self.latent_dim: + latents = rearrange(latents, "t p d -> 1 d (t p)") + elif latents.shape[1] == self.latent_dim: + latents = latents.contiguous() + else: + raise ValueError(f"Unsupported latent_audio_feat shape: {tuple(latents.shape)}") + elif latents.ndim == 2: + if latents.shape[0] == self.latent_dim: + latents = latents.unsqueeze(0) + elif latents.shape[1] == self.latent_dim: + latents = rearrange(latents, "t d -> 1 d t") + else: + raise ValueError(f"Unsupported latent_audio_feat shape: {tuple(latents.shape)}") + else: + raise ValueError(f"Unsupported latent_audio_feat ndim: {latents.ndim}") + + return latents + + @torch.no_grad() + def decode(self, latent_audio_feat: Any, *, trim_streaming_patch: bool = False) -> torch.Tensor: + latents = self._prepare_latents_for_decode(latent_audio_feat) + device = next(self.audio_vae.parameters()).device + raw = self.audio_vae.decode(latents.to(device=device, dtype=torch.float32)) + if isinstance(raw, dict): + audio = raw.get("audio") + if audio is None: + audio = next(v for v in raw.values() if isinstance(v, torch.Tensor)) + else: + audio = raw + if audio.dim() == 3: + stream = audio.squeeze(1) + elif audio.dim() == 2: + stream = audio + else: + stream = audio.reshape(audio.shape[0], -1) + if trim_streaming_patch: + stream = stream[..., -self._stream_audio_patch_samples :] + return stream.reshape(-1).detach().cpu().to(torch.float32) diff --git a/vllm_omni/model_executor/stage_configs/voxcpm.yaml b/vllm_omni/model_executor/stage_configs/voxcpm.yaml new file mode 100644 index 0000000000..a5f324f660 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/voxcpm.yaml @@ -0,0 +1,69 @@ +# VoxCPM two-stage (latent → VAE) without async_chunk: one-shot latent then decode. +stage_args: + - stage_id: 0 + stage_type: llm + is_comprehension: true + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + dtype: bfloat16 + model_stage: latent_generator + model_arch: VoxCPMForConditionalGeneration + # Optional persistent HF-compatible config dir for native VoxCPM models. 
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.7 + distributed_executor_backend: "mp" + max_num_batched_tokens: 4096 + max_model_len: 4096 + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 4096 + stop_token_ids: [2] + seed: 42 + detokenize: false + repetition_penalty: 1.0 + final_output: false + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + dtype: float32 + model_stage: vae + model_arch: VoxCPMForConditionalGeneration + hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.15 + distributed_executor_backend: "mp" + max_num_batched_tokens: 8192 + max_model_len: 4096 + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 1 + seed: 42 + detokenize: true + repetition_penalty: 1.0 diff --git a/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml new file mode 100644 index 0000000000..cf78d4e438 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml @@ -0,0 +1,102 @@ +# VoxCPM two-stage streaming (align with qwen3_tts.yaml async_chunk pattern). +# Stage0 (latent_generator) emits latent in time chunks; Stage1 (VAE) decodes as chunks arrive. +async_chunk: true +stage_args: + - stage_id: 0 + stage_type: llm + is_comprehension: true + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + dtype: bfloat16 + model_stage: latent_generator + model_arch: VoxCPMForConditionalGeneration + # Optional persistent HF-compatible config dir for native VoxCPM models. 
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: true + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.7 + distributed_executor_backend: "mp" + max_num_batched_tokens: 4096 + max_model_len: 4096 + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae_async_chunk + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 4096 + stop_token_ids: [2] + seed: 42 + detokenize: false + repetition_penalty: 1.0 + final_output: false + output_connectors: + to_stage_1: voxcpm_shm + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + dtype: float32 + model_stage: vae + model_arch: VoxCPMForConditionalGeneration + hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.15 + distributed_executor_backend: "mp" + max_num_batched_tokens: 8192 + max_model_len: 4096 + engine_input_source: [0] + final_output: true + final_output_type: audio + input_connectors: + from_stage_0: voxcpm_shm + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 128 + seed: 42 + detokenize: true + repetition_penalty: 1.0 + + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + connectors: + voxcpm_shm: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + # Frame-aligned codec streaming transport. + codec_streaming: true + # Connector polling / timeout (unit: loop count, sleep interval in seconds). + connector_get_sleep_s: 0.01 + connector_get_max_wait_first_chunk: 3000 + connector_get_max_wait: 300 + # Align with Omni: small chunks with sufficient context overlap. 
+ codec_chunk_frames: 1 + codec_left_context_frames: 1 + + edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_input_processors/voxcpm.py b/vllm_omni/model_executor/stage_input_processors/voxcpm.py new file mode 100644 index 0000000000..c2fcf521bf --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/voxcpm.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from typing import Any + +import torch +from vllm.inputs import TextPrompt + +from vllm_omni.inputs.data import OmniTokensPrompt + +_VOXCPM_LATENT_MAGIC = 131071 + + +def _serialize_latent_to_codes(latent: Any) -> list[int]: + latent_tensor = latent if isinstance(latent, torch.Tensor) else torch.as_tensor(latent) + latent_tensor = latent_tensor.detach().cpu().contiguous() + if latent_tensor.ndim == 3: + if latent_tensor.shape[0] != 1: + raise ValueError(f"Expected batch=1 latent tensor, got shape={tuple(latent_tensor.shape)}") + latent_tensor = latent_tensor.squeeze(0) + if latent_tensor.ndim != 2: + raise ValueError(f"Unsupported latent_audio_feat shape for async chunk: {tuple(latent_tensor.shape)}") + latent_dim, time_dim = int(latent_tensor.shape[0]), int(latent_tensor.shape[1]) + packed = latent_tensor.to(torch.bfloat16).contiguous().view(torch.uint16).reshape(-1).to(torch.int32) + return [_VOXCPM_LATENT_MAGIC, latent_dim, time_dim, *packed.tolist()] + + +def _coerce_finished_flag(value: Any) -> bool: + """Normalize VoxCPM async-chunk finished markers to a Python bool.""" + if value is None: + return False + if isinstance(value, torch.Tensor): + if value.numel() != 1: + raise ValueError(f"finished tensor must be scalar, got shape={tuple(value.shape)}") + return bool(value.detach().cpu().item()) + if isinstance(value, (list, tuple)): + if not value: + return False + if len(value) != 1: + raise ValueError(f"finished container must have one element, got len={len(value)}") + return _coerce_finished_flag(value[0]) + return bool(value) + + +def latent2vae( + stage_list: list[Any], + engine_input_source: list[int], + prompt: OmniTokensPrompt | TextPrompt | None = None, + requires_multimodal_data: bool = False, +) -> list[OmniTokensPrompt]: + del prompt, requires_multimodal_data + + if not engine_input_source: + raise ValueError("engine_input_source cannot be empty") + + source_stage_id = engine_input_source[0] + if source_stage_id >= len(stage_list): + raise IndexError(f"Invalid stage_id: {source_stage_id}") + + source_outputs = stage_list[source_stage_id].engine_outputs + if source_outputs is None: + raise RuntimeError(f"Stage {source_stage_id} has no outputs yet") + + vae_inputs: list[OmniTokensPrompt] = [] + for source_output in source_outputs: + output = source_output.outputs[0] + multimodal_output = getattr(output, "multimodal_output", None) + if not isinstance(multimodal_output, dict) or "latent_audio_feat" not in multimodal_output: + raise ValueError( + "VoxCPM latent stage output missing 'latent_audio_feat'. 
" + f"request_id={getattr(source_output, 'request_id', None)}" + ) + + additional_information = { + "latent_audio_feat": multimodal_output["latent_audio_feat"], + } + if "sr" in multimodal_output: + additional_information["sample_rate"] = [int(multimodal_output["sr"])] + + vae_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0], + additional_information=additional_information, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + + return vae_inputs + + +def latent2vae_async_chunk( + transfer_manager: Any, + pooling_output: dict[str, Any] | None, + request: Any, + is_finished: bool = False, +) -> dict[str, Any] | None: + """Stage-0 latent → stage-1 VAE under ``async_chunk`` (connector payload).""" + # Kept for callback signature compatibility with OmniChunkTransferAdapter. + _ = transfer_manager + finished_request = _coerce_finished_flag(is_finished) + if callable(getattr(request, "is_finished", None)): + finished_request = finished_request or _coerce_finished_flag(request.is_finished()) + if not isinstance(pooling_output, dict): + if finished_request: + return { + "code_predictor_codes": [], + "finished": torch.tensor(True, dtype=torch.bool), + } + return None + + latent = pooling_output.get("latent_audio_feat") + if isinstance(latent, torch.Tensor) and latent.numel() == 0: + latent = None + + if latent is None: + if finished_request: + return { + "code_predictor_codes": [], + "finished": torch.tensor(True, dtype=torch.bool), + } + return None + + serialized_codes = _serialize_latent_to_codes(latent) + out: dict[str, Any] = { + "code_predictor_codes": serialized_codes, + "finished": torch.tensor(finished_request, dtype=torch.bool), + } + return out diff --git a/vllm_omni/platforms/npu/stage_configs/voxcpm.yaml b/vllm_omni/platforms/npu/stage_configs/voxcpm.yaml new file mode 100644 index 0000000000..dcd1f40517 --- /dev/null +++ b/vllm_omni/platforms/npu/stage_configs/voxcpm.yaml @@ -0,0 +1,67 @@ +stage_args: + - stage_id: 0 + stage_type: llm + is_comprehension: true + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + dtype: bfloat16 + model_stage: latent_generator + model_arch: VoxCPMForConditionalGeneration + # Optional persistent HF-compatible config dir for native VoxCPM models. 
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.75 + distributed_executor_backend: "mp" + max_num_batched_tokens: 4096 + max_model_len: 4096 + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.0 + final_output: false + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + dtype: float32 + model_stage: vae + model_arch: VoxCPMForConditionalGeneration + hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 8192 + max_model_len: 4096 + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 1 + seed: 42 + detokenize: true + repetition_penalty: 1.0 diff --git a/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml b/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml new file mode 100644 index 0000000000..0a4ed7497d --- /dev/null +++ b/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml @@ -0,0 +1,93 @@ +async_chunk: true +stage_args: + - stage_id: 0 + stage_type: llm + is_comprehension: true + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + dtype: bfloat16 + model_stage: latent_generator + model_arch: VoxCPMForConditionalGeneration + # Optional persistent HF-compatible config dir for native VoxCPM models. 
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.75 + distributed_executor_backend: "mp" + max_num_batched_tokens: 4096 + max_model_len: 4096 + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae_async_chunk + output_connectors: + to_stage_1: connector_of_shared_memory + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.0 + final_output: false + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + dtype: float32 + model_stage: vae + model_arch: VoxCPMForConditionalGeneration + hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 8192 + max_model_len: 4096 + engine_input_source: [0] + input_connectors: + from_stage_0: connector_of_shared_memory + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 1 + seed: 42 + detokenize: true + repetition_penalty: 1.0 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + codec_streaming: false + connector_get_sleep_s: 0.01 + connector_get_max_wait_first_chunk: 3000 + connector_get_max_wait: 300 + + edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/vllm_omni/transformers_utils/configs/__init__.py b/vllm_omni/transformers_utils/configs/__init__.py index 5f957c2f6d..0aa3624f80 100644 --- a/vllm_omni/transformers_utils/configs/__init__.py +++ b/vllm_omni/transformers_utils/configs/__init__.py @@ -17,6 +17,7 @@ "FishSpeechConfig": "vllm_omni.transformers_utils.configs.fish_speech", "FishSpeechSlowARConfig": "vllm_omni.transformers_utils.configs.fish_speech", "FishSpeechFastARConfig": "vllm_omni.transformers_utils.configs.fish_speech", + "VoxCPMConfig": "vllm_omni.transformers_utils.configs.voxcpm", "VoxCPM2Config": "vllm_omni.transformers_utils.configs.voxcpm2", } @@ -28,6 +29,7 @@ "FishSpeechConfig", "FishSpeechSlowARConfig", "FishSpeechFastARConfig", + "VoxCPMConfig", "VoxCPM2Config", ] @@ -49,4 +51,5 @@ def __dir__(): # run as soon as `vllm_omni.transformers_utils.configs` is imported. 
from vllm_omni.transformers_utils.configs import fish_speech as _fish_speech # noqa: F401, E402 from vllm_omni.transformers_utils.configs import mammoth_moda2 as _mammoth_moda2 # noqa: F401, E402 +from vllm_omni.transformers_utils.configs import voxcpm as _voxcpm # noqa: F401, E402 from vllm_omni.transformers_utils.configs import voxcpm2 as _voxcpm2 # noqa: F401, E402 diff --git a/vllm_omni/transformers_utils/configs/voxcpm.py b/vllm_omni/transformers_utils/configs/voxcpm.py new file mode 100644 index 0000000000..0267838915 --- /dev/null +++ b/vllm_omni/transformers_utils/configs/voxcpm.py @@ -0,0 +1,68 @@ +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig + + +class VoxCPMConfig(PretrainedConfig): + model_type = "voxcpm" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + bos_token_id: int = 1, + eos_token_id: int = 2, + vocab_size: int = 32000, + hidden_size: int = 1024, + intermediate_size: int = 4096, + max_position_embeddings: int = 4096, + num_attention_heads: int = 16, + num_hidden_layers: int = 24, + num_key_value_heads: int = 16, + rms_norm_eps: float = 1e-6, + rope_theta: float = 10000.0, + rope_scaling: dict | None = None, + lm_config: dict | None = None, + encoder_config: dict | None = None, + dit_config: dict | None = None, + audio_vae_config: dict | None = None, + patch_size: int = 2, + feat_dim: int = 64, + residual_lm_num_layers: int = 6, + scalar_quantization_latent_dim: int = 256, + scalar_quantization_scale: int = 9, + max_length: int = 4096, + device: str = "cuda", + dtype: str = "bfloat16", + dit_mean_mode: bool = False, + **kwargs, + ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.num_key_value_heads = num_key_value_heads + self.rms_norm_eps = rms_norm_eps + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.lm_config = lm_config or {} + self.encoder_config = encoder_config or {} + self.dit_config = dit_config or {} + self.audio_vae_config = audio_vae_config + + self.patch_size = patch_size + self.feat_dim = feat_dim + self.residual_lm_num_layers = residual_lm_num_layers + self.scalar_quantization_latent_dim = scalar_quantization_latent_dim + self.scalar_quantization_scale = scalar_quantization_scale + self.max_length = max_length + self.device = device + self.dtype = dtype + self.dit_mean_mode = dit_mean_mode + + +AutoConfig.register("voxcpm", VoxCPMConfig) + +__all__ = ["VoxCPMConfig"]
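
For reference, the integer-code handoff used by `latent2vae_async_chunk` can be sanity-checked in isolation: stage 0 packs a bfloat16 latent chunk into `[MAGIC, latent_dim, time_dim, *uint16_payload]`, ships it as token ids over the connector, and the VAE stage rebuilds the tensor in `_recover_latent_from_input_ids`. The following is a minimal round-trip sketch, not part of the patch; it only reuses the magic constant and the dtype casts shown in the diff above.

```python
# Illustrative round-trip of the async-chunk latent packing
# (mirrors _serialize_latent_to_codes / _recover_latent_from_input_ids).
import torch

_VOXCPM_LATENT_MAGIC = 131071


def pack_latent(latent: torch.Tensor) -> list[int]:
    # latent: (latent_dim, time_dim); reinterpret bfloat16 bits as uint16, carry as int32 codes.
    latent_dim, time_dim = int(latent.shape[0]), int(latent.shape[1])
    payload = latent.to(torch.bfloat16).contiguous().view(torch.uint16).reshape(-1).to(torch.int32)
    return [_VOXCPM_LATENT_MAGIC, latent_dim, time_dim, *payload.tolist()]


def unpack_latent(codes: list[int]) -> torch.Tensor:
    flat = torch.tensor(codes)
    assert int(flat[0]) == _VOXCPM_LATENT_MAGIC, "missing VoxCPM latent magic header"
    latent_dim, time_dim = int(flat[1]), int(flat[2])
    payload = flat[3:].to(torch.int32).to(torch.uint16)
    # Reinterpret the uint16 payload back as bfloat16, then upcast for the VAE decode path.
    return payload.view(torch.bfloat16).to(torch.float32).reshape(1, latent_dim, time_dim)


latent = torch.randn(64, 4)  # e.g. latent_dim=64, one 4-frame chunk (hypothetical sizes)
recovered = unpack_latent(pack_latent(latent))
assert torch.allclose(recovered.squeeze(0), latent.to(torch.bfloat16).to(torch.float32))
```

Because the payload is a bit-exact reinterpretation of the bfloat16 values, the recovered latent matches the original up to the initial bfloat16 cast; the magic header and the `latent_dim * time_dim` size check are what `_recover_latent_from_input_ids` validates before handing the tensor to the VAE stage.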