diff --git a/.buildkite/test-amd-ready.yaml b/.buildkite/test-amd-ready.yaml index 597733cbb86..47b48e1cf51 100644 --- a/.buildkite/test-amd-ready.yaml +++ b/.buildkite/test-amd-ready.yaml @@ -73,6 +73,17 @@ steps: # - export VLLM_WORKER_MULTIPROC_METHOD=spawn # - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_expansion.py -m "advanced_model and diffusion and L4" --run-level advanced_model +- label: "AudioX Online Test" + agent_pool: mi325_1 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - export GPU_ARCHS=gfx942 + - export VLLM_LOGGING_LEVEL=DEBUG + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - timeout 20m pytest -s -v tests/e2e/online_serving/test_audiox_online.py + - label: "Diffusion Cache Backend Test" agent_pool: mi325_1 depends_on: amd-build diff --git a/.buildkite/test-ready.yml b/.buildkite/test-ready.yml index 080f18885ef..2d46c753a92 100644 --- a/.buildkite/test-ready.yml +++ b/.buildkite/test-ready.yml @@ -120,6 +120,23 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" + - label: "AudioX Online Test" + depends_on: upload-ready-pipeline + commands: + - timeout 20m pytest -s -v tests/e2e/online_serving/test_audiox_online.py -m "core_model and diffusion" --run-level core_model + agents: + queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + - "HF_TOKEN" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Diffusion Cache Backend Test" depends_on: upload-ready-pipeline commands: diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 6b305f16c2e..e0cc17aa1cf 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -53,6 +53,7 @@ th { | `FluxPipeline` | FLUX.1-schnell | `black-forest-labs/FLUX.1-schnell` | ✅︎ | ✅︎ | | ✅︎ | | `OmniGen2Pipeline` | OmniGen2 | `OmniGen2/OmniGen2` | ✅︎ | ✅︎ | | ✅︎ | | `StableAudioPipeline` | Stable-Audio-Open | `stabilityai/stable-audio-open-1.0` | ✅︎ | ✅︎ | | ✅︎ | +| `AudioXPipeline` | AudioX | `zhangj1an/AudioX` | ✅︎ | ✅︎ | | | | `Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-CustomVoice | `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-VoiceDesign | `Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-Base | `Qwen/Qwen3-TTS-12Hz-0.6B-Base` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | diff --git a/examples/offline_inference/audiox/README.md b/examples/offline_inference/audiox/README.md new file mode 100644 index 00000000000..a0b46f9196f --- /dev/null +++ b/examples/offline_inference/audiox/README.md @@ -0,0 +1,40 @@ +# AudioX offline inference + +Generate audio with the [AudioX](https://zeyuet.github.io/AudioX/) MMDiT diffusion +pipeline (`AudioXPipeline`). Six tasks: `t2a`, `t2m`, `v2a`, `v2m`, `tv2a`, `tv2m`. + +## Prerequisites + +Download a vLLM-Omni weight bundle (component-sharded safetensors): + +```bash +huggingface-cli download zhangj1an/AudioX --local-dir ./audiox_weights +``` + +The Hugging Face id `zhangj1an/AudioX` also works directly without prefetching. + +## Usage + +```bash +# Text-to-audio only (default uses zhangj1an/AudioX from the Hub): +python end2end.py --tasks t2a + +# All six tasks against a local bundle and a sample video for v2*/tv2*: +python end2end.py \ + --model ./audiox_weights \ + --video https://zeyuet.github.io/AudioX/static/samples/V2M/1XeBotOFqHA.mp4 + +# Subset of tasks, custom seed and steps: +python end2end.py --tasks t2a tv2a --num-inference-steps 100 --seed 0 +``` + +## Arguments + +- `--model`: HF id or local bundle path (default: `zhangj1an/AudioX`). +- `--tasks`: any subset of `t2a t2m v2a v2m tv2a tv2m` (default: all). +- `--video`: video file/URL — required for `v2*` and `tv2*`. +- `--reference-audio`: optional audio prompt (audio-conditioned generation). +- `--num-inference-steps`, `--guidance-scale`, `--seed`, `--seconds-total`, + `--sample-rate`, `--output-dir`: generation knobs. + +Outputs land in `/.wav` as 16-bit stereo WAV. diff --git a/examples/offline_inference/audiox/end2end.py b/examples/offline_inference/audiox/end2end.py new file mode 100644 index 00000000000..34b5425724e --- /dev/null +++ b/examples/offline_inference/audiox/end2end.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""End-to-end AudioX offline example covering the 6 t2*/v2*/tv2* tasks. + +Provide a directory with the **vLLM-Omni AudioX safetensors bundle** (e.g. from +``zhangj1an/AudioX`` on Hugging Face):: + + huggingface-cli download zhangj1an/AudioX --local-dir ./audiox_weights + python end2end.py --model ./audiox_weights + python end2end.py --model ./audiox_weights --tasks t2a tv2a +""" + +from __future__ import annotations + +import argparse +import time +from pathlib import Path + +import soundfile +import torch +import torchaudio.functional as TF + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.platforms import current_omni_platform + +ROOT = Path(__file__).resolve().parent + +SAMPLE_PROMPTS: dict[str, str] = { + "t2a": "Fireworks burst twice, followed by a period of silence before a clock begins ticking.", + "t2m": "Uplifting ukulele tune for a travel vlog", + "v2a": "", + "v2m": "", + "tv2a": "drum beating sound and human talking", + "tv2m": "uplifting music matching the scene", +} + +ALL_TASKS = ("t2a", "t2m", "v2a", "v2m", "tv2a", "tv2m") +VIDEO_TASKS = frozenset({"v2a", "v2m", "tv2a", "tv2m"}) +TEXT_TASKS = frozenset({"t2a", "t2m", "tv2a", "tv2m"}) + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="AudioX offline end-to-end (6 t2*/v2*/tv2* tasks).") + p.add_argument("--model", default="zhangj1an/AudioX", help="HF id or local AudioX bundle path.") + p.add_argument("--tasks", nargs="+", default=list(ALL_TASKS), choices=ALL_TASKS) + p.add_argument("--video", default="", help="Video path / URL (required for v2*/tv2*).") + p.add_argument("--reference-audio", default="", help="Optional audio prompt for audio-conditioned generation.") + p.add_argument("--output-dir", default=str(ROOT / "audiox_task_outputs")) + p.add_argument("--num-inference-steps", type=int, default=250) + p.add_argument("--seconds-total", type=float, default=10.0) + p.add_argument("--guidance-scale", type=float, default=6.0) + p.add_argument("--seed", type=int, default=42) + p.add_argument("--sample-rate", type=int, default=48000, help="Output WAV rate (resampled if != model rate).") + return p.parse_args() + + +def save_wav(audio: torch.Tensor, path: Path, sample_rate: int) -> None: + """Write 16-bit PCM WAV. ``audio`` is ``[channels, samples]`` float in [-1, 1].""" + path.parent.mkdir(parents=True, exist_ok=True) + soundfile.write(str(path), audio.clamp(-1.0, 1.0).cpu().T.numpy(), sample_rate, subtype="PCM_16") + + +def main() -> None: + args = parse_args() + + omni = Omni(model=args.model, model_class_name="AudioXPipeline") + + for task in args.tasks: + if task in VIDEO_TASKS and not args.video: + raise SystemExit(f"task={task!r} requires --video") + prompt = SAMPLE_PROMPTS[task] if task in TEXT_TASKS else "" + extra: dict = {"audiox_task": task, "seconds_start": 0.0, "seconds_total": float(args.seconds_total)} + if task in VIDEO_TASKS: + extra["video_path"] = args.video + if args.reference_audio: + extra["audio_path"] = args.reference_audio + + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + t0 = time.perf_counter() + outputs = omni.generate( + prompt, + OmniDiffusionSamplingParams( + generator=generator, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + seed=args.seed, + extra_args=extra, + ), + ) + audio = outputs[0].request_output.multimodal_output.get("audio") + if audio is None: + raise RuntimeError(f"No audio produced for task {task!r}") + audio = torch.as_tensor(audio).detach().cpu().float() + if audio.ndim == 3: + audio = audio[0] + + model_sr = int(outputs[0].request_output.multimodal_output.get("audio_sample_rate") or 44100) + if model_sr != args.sample_rate: + audio = TF.resample(audio, model_sr, args.sample_rate) + + out_path = Path(args.output_dir) / f"{task}.wav" + save_wav(audio, out_path, args.sample_rate) + print(f"[{task}] saved {out_path} ({time.perf_counter() - t0:.2f}s)") + + omni.close() + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/audiox/README.md b/examples/online_serving/audiox/README.md new file mode 100644 index 00000000000..aad65eb301e --- /dev/null +++ b/examples/online_serving/audiox/README.md @@ -0,0 +1,65 @@ +# AudioX online serving + +Launches the `AudioXPipeline` behind vLLM-Omni's OpenAI-compatible chat endpoint and provides a +minimal Python client that covers all six tasks (`t2a`, `t2m`, `v2a`, `v2m`, `tv2a`, `tv2m`). + +## Start the server + +```bash +cd examples/online_serving/audiox +bash run_server.sh # defaults: MODEL=zhangj1an/AudioX, PORT=8099 +``` + +Environment overrides: `MODEL`, `PORT`, `DIFFUSION_ATTENTION_BACKEND`. + +## Call from Python + +```bash +# text-to-audio +python openai_chat_client.py --task t2a \ + --prompt "Fireworks burst twice, followed by a period of silence before a clock begins ticking." \ + --output t2a.wav + +# text-to-music +python openai_chat_client.py --task t2m \ + --prompt "Uplifting ukulele tune for a travel vlog" \ + --output t2m.wav + +# video-to-audio (no text) +python openai_chat_client.py --task v2a --video path/to/clip.mp4 --output v2a.wav + +# text+video-to-audio +python openai_chat_client.py --task tv2a \ + --prompt "drum beating sound and human talking" \ + --video path/to/clip.mp4 \ + --output tv2a.wav +``` + +The client sends: + +- `num_inference_steps`, `guidance_scale`, `seed` as first-class OpenAI chat-completion fields +- `audiox_task`, `seconds_start`, `seconds_total`, `sigma_min`, `sigma_max` nested under + `extra_args` (a reserved dict on the request body that the server forwards verbatim into + the pipeline's `sampling_params.extra_args` — the same escape hatch `serving_video.py` exposes + as `extra_params` on /v1/videos) +- For `v2*` / `tv2*` tasks, the video as a `video_url` content item (data URI for local files) + +## curl + +```bash +curl -sS -X POST http://localhost:8099/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "zhangj1an/AudioX", + "messages": [{"role": "user", "content": [{"type": "text", "text": "Uplifting ukulele"}]}], + "num_inference_steps": 250, + "guidance_scale": 7.0, + "seed": 42, + "extra_args": { + "audiox_task": "t2m", + "seconds_total": 10.0, + "sigma_min": 0.3, + "sigma_max": 500.0 + } + }' > t2m.json +``` diff --git a/examples/online_serving/audiox/openai_chat_client.py b/examples/online_serving/audiox/openai_chat_client.py new file mode 100755 index 00000000000..8c8b572e50d --- /dev/null +++ b/examples/online_serving/audiox/openai_chat_client.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +"""AudioX OpenAI-compatible chat client. + +AudioX supports 6 tasks (t2a, t2m, v2a, v2m, tv2a, tv2m). Text-only tasks send the prompt as the +chat message; video-conditioned tasks additionally attach the video via a ``video_url`` content +item (data URI for local files). Task + generation knobs (steps, cfg, sigma range, seconds, seed) +are sent via the OpenAI SDK's ``extra_body`` as ``extra_args`` — the same pipeline-agnostic escape +hatch used by the /v1/videos endpoint's ``extra_params`` field. + +Usage: + python openai_chat_client.py --task t2a --prompt "Fireworks burst twice..." --output t2a.wav + python openai_chat_client.py --task tv2a --prompt "drum beating" --video clip.mp4 -o tv2a.wav +""" + +from __future__ import annotations + +import argparse +import base64 +import io +import mimetypes +import sys +from pathlib import Path + +import requests +import soundfile +import torch + +VIDEO_TASKS = frozenset({"v2a", "v2m", "tv2a", "tv2m"}) +TEXT_TASKS = frozenset({"t2a", "t2m", "tv2a", "tv2m"}) + + +def _to_data_url(path: str) -> str: + mime, _ = mimetypes.guess_type(path) + mime = mime or "video/mp4" + with open(path, "rb") as f: + data = base64.b64encode(f.read()).decode("ascii") + return f"data:{mime};base64,{data}" + + +def _save_wav(audio: torch.Tensor, path: Path, sample_rate: int) -> None: + audio = audio.to(torch.float32) + audio = audio / audio.abs().max().clamp(min=1e-8) + path.parent.mkdir(parents=True, exist_ok=True) + # soundfile expects channels-last (T, C); project convention is (C, T). + soundfile.write(str(path), audio.clamp(-1.0, 1.0).cpu().T.numpy(), sample_rate, subtype="PCM_16") + + +def _decode_audio_from_response(body: dict) -> tuple[torch.Tensor, int]: + for choice in body.get("choices", []): + audio_obj = choice.get("message", {}).get("audio") + if not (isinstance(audio_obj, dict) and audio_obj.get("data")): + continue + data, sr = soundfile.read(io.BytesIO(base64.b64decode(audio_obj["data"])), dtype="float32", always_2d=True) + return torch.from_numpy(data).transpose(0, 1), sr + brief = {k: v for k, v in body.items() if k != "choices"} + raise RuntimeError(f"no audio in response message.audio: {brief}") + + +def main() -> int: + p = argparse.ArgumentParser(description="AudioX OpenAI chat client") + p.add_argument("--task", required=True, choices=["t2a", "t2m", "v2a", "v2m", "tv2a", "tv2m"]) + p.add_argument("--prompt", "-p", default="", help="Text prompt (required for t2*/tv2*).") + p.add_argument("--video", help="Video path or URL (required for v2*/tv2*).") + p.add_argument("--output", "-o", default="audiox_out.wav") + p.add_argument("--server", "-s", default="http://localhost:8099") + p.add_argument("--model", default="zhangj1an/AudioX") + p.add_argument("--steps", type=int, default=250) + p.add_argument("--guidance-scale", type=float, default=7.0) + p.add_argument("--seed", type=int, default=42) + p.add_argument("--seconds-total", type=float, default=10.0) + p.add_argument("--seconds-start", type=float, default=0.0) + p.add_argument("--sigma-min", type=float, default=0.03) + p.add_argument("--sigma-max", type=float, default=1000.0) + args = p.parse_args() + + if args.task in VIDEO_TASKS and not args.video: + print(f"ERROR: task {args.task!r} requires --video", file=sys.stderr) + return 2 + if args.task in TEXT_TASKS and not args.prompt.strip() and args.task not in {"v2a", "v2m"}: + print(f"ERROR: task {args.task!r} requires --prompt", file=sys.stderr) + return 2 + + content: list[dict] = [{"type": "text", "text": args.prompt}] + if args.task in VIDEO_TASKS: + vurl = args.video if args.video.startswith(("http://", "https://")) else _to_data_url(args.video) + content.append({"type": "video_url", "video_url": {"url": vurl}}) + + payload = { + "model": args.model, + "messages": [{"role": "user", "content": content}], + "num_inference_steps": args.steps, + "guidance_scale": args.guidance_scale, + "seed": args.seed, + "extra_args": { + "audiox_task": args.task, + "seconds_start": args.seconds_start, + "seconds_total": args.seconds_total, + "sigma_min": args.sigma_min, + "sigma_max": args.sigma_max, + }, + } + + print(f"POST {args.server}/v1/chat/completions task={args.task} steps={args.steps}") + r = requests.post( + f"{args.server}/v1/chat/completions", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=600, + ) + r.raise_for_status() + audio, sr = _decode_audio_from_response(r.json()) + _save_wav(audio, Path(args.output), sr) + dur = audio.shape[-1] / sr + print(f"saved {args.output} sr={sr}Hz duration={dur:.2f}s channels={audio.shape[0]}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/online_serving/audiox/run_server.sh b/examples/online_serving/audiox/run_server.sh new file mode 100755 index 00000000000..d1a259bfd7a --- /dev/null +++ b/examples/online_serving/audiox/run_server.sh @@ -0,0 +1,16 @@ +#!/bin/bash +# AudioX online serving startup script. + +MODEL="${MODEL:-zhangj1an/AudioX}" +PORT="${PORT:-8099}" +DIFFUSION_ATTENTION_BACKEND="${DIFFUSION_ATTENTION_BACKEND:-FLASH_ATTN}" + +echo "Starting AudioX server..." +echo "Model: $MODEL" +echo "Port: $PORT" +echo "Diffusion attention backend: $DIFFUSION_ATTENTION_BACKEND" + +DIFFUSION_ATTENTION_BACKEND="$DIFFUSION_ATTENTION_BACKEND" \ + vllm serve "$MODEL" --omni \ + --model-class-name AudioXPipeline \ + --port "$PORT" diff --git a/recipes/README.md b/recipes/README.md index 01f76d1a847..0c490dace81 100644 --- a/recipes/README.md +++ b/recipes/README.md @@ -34,6 +34,8 @@ recipes/ - [`Baidu/ERNIE-Image.md`](./Baidu/ERNIE-Image.md): text-to-image serving online serving recipe for ERNIE-Image 8B on `1x RTX 4090 24GB` or `2x RTX 4090 24GB` - [`fishaudio/Fish-Speech-S2-Pro.md`](./fishaudio/Fish-Speech-S2-Pro.md): online serving recipe for TTS on `1x A800 80GB` +- [`audiox/AudioX.md`](./audiox/AudioX.md): offline + online recipe for AudioX + unified text/video→audio diffusion on `1x L4 24GB` Within a single recipe file, include different hardware support sections such as `GPU`, `ROCm`, and `NPU`, and add concrete tested configurations like diff --git a/recipes/audiox/AudioX.md b/recipes/audiox/AudioX.md new file mode 100644 index 00000000000..6365bd02136 --- /dev/null +++ b/recipes/audiox/AudioX.md @@ -0,0 +1,95 @@ +# AudioX + +> AudioX MMDiT for unified audio + music generation: t2a / t2m / v2a / v2m / tv2a / tv2m + +## Summary + +- Vendor: HKUSTAudio (project), `zhangj1an/AudioX` weight bundle +- Model: `zhangj1an/AudioX` +- Task: Text/video → audio or music. Six tasks: `t2a`, `t2m`, `v2a`, `v2m`, `tv2a`, `tv2m`. +- Mode: Offline inference + online serving (pure diffusion) +- Maintainer: Community + +## When to use this recipe + +Use this recipe to run AudioX for sound-effect (`*2a`) or music (`*2m`) generation +from a text prompt and/or video clip. AudioX is a unified diffusion transformer +that produces stereo 44.1 kHz audio up to ~10 s per call. + +## References + +- Project page: +- vLLM-Omni weight bundle: +- Pipeline: `vllm_omni.diffusion.models.audiox.pipeline_audiox.AudioXPipeline` +- Input transforms: `vllm_omni.transformers_utils.processors.audiox` +- Offline example: [`examples/offline_inference/audiox/`](../../examples/offline_inference/audiox/) +- Online example: [`examples/online_serving/audiox/`](../../examples/online_serving/audiox/) + +## Hardware Support + +## GPU + +### 1x L4 24GB + +#### Environment + +- OS: Ubuntu 22.04 +- Python: 3.10+ +- Driver / runtime: CUDA 12.4 +- vLLM version: 0.20.0 +- vLLM-Omni version: 0.1.x + +#### Command + +Offline (text-to-audio): + +```bash +huggingface-cli download zhangj1an/AudioX --local-dir ./audiox_weights +python examples/offline_inference/audiox/end2end.py \ + --model ./audiox_weights \ + --tasks t2a \ + --num-inference-steps 250 \ + --seconds-total 10 +``` + +Online: + +```bash +bash examples/online_serving/audiox/run_server.sh +python examples/online_serving/audiox/openai_chat_client.py \ + --task t2a \ + --prompt "Fireworks burst twice, followed by a clock ticking." \ + --output t2a.wav +``` + +#### Verification + +```bash +# Health check +curl http://localhost:8099/health + +# Listen to the saved file (stereo, 44.1 kHz, sigma_min=0.03, sigma_max=1000 — upstream defaults) +ffprobe t2a.wav +``` + +#### Notes + +- Memory usage: ~10 GB peak with `num_inference_steps=250`, 10 s of audio. +- Output rate: 44.1 kHz stereo, regardless of `--sample-rate` (resampled in the example + script if requested). +- Supported tasks: `t2a`, `t2m`, `v2a`, `v2m`, `tv2a`, `tv2m`. Pass via + `extra_args["audiox_task"]` (offline) or the `extra_args` field in the OpenAI + chat-completions body (online). +- Video conditioning: `v2*` and `tv2*` require a video file; the online client + attaches it as an OpenAI `video_url` content item (data URI for local files). +- Cache acceleration is **not** supported (AudioXPipeline is in `_NO_CACHE_ACCELERATION`). +- Tensor parallelism is supported via `--tensor-parallel-size` (DiT QKV is sharded with + `QKVParallelLinear`); cross-attention K/V is also TP-sharded. + +### Known limitations + +- Inference uses an inlined DPM-Solver++(3M) SDE sampler (k-diffusion port). Replacing it with + diffusers' `EDMDPMSolverMultistepScheduler` introduces a fixed ~861 Hz resonance and is not + recommended. +- Generation is fixed at 10 s (configured by the bundle's `sample_size`); longer outputs require + a different bundle. diff --git a/tests/e2e/offline_inference/test_audiox_model.py b/tests/e2e/offline_inference/test_audiox_model.py new file mode 100644 index 00000000000..f353b24513d --- /dev/null +++ b/tests/e2e/offline_inference/test_audiox_model.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import sys +from pathlib import Path + +import numpy as np +import pytest +import torch + +from tests.helpers.mark import hardware_test +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni + +# Prefer a tiny/random checkpoint for CI. +# Override in CI if needed: AUDIOX_TEST_MODEL= +models = [os.environ.get("AUDIOX_TEST_MODEL", "zhangj1an/audiox_random")] + + +@pytest.mark.core_model +@pytest.mark.diffusion +@hardware_test(res={"cuda": "L4", "xpu": "B60"}) +@pytest.mark.parametrize("model_name", models) +def test_audiox_model(model_name: str): + m = Omni(model=model_name, model_class_name="AudioXPipeline") + + # Keep runtime short for CI. + seconds_total = 2.0 + # AudioXPipeline always emits 44.1 kHz stereo (advertised via class-level + # ``audio_sample_rate``); the trimmed output should match this rate. + sample_rate = 44100 + + outputs = m.generate( + prompts={"prompt": "A dog barking in a quiet park."}, + sampling_params_list=OmniDiffusionSamplingParams( + num_inference_steps=4, + guidance_scale=6.0, + generator=torch.Generator(current_omni_platform.device_type).manual_seed(42), + num_outputs_per_prompt=1, + extra_args={ + "audiox_task": "t2a", + "seconds_start": 0.0, + "seconds_total": seconds_total, + }, + ), + ) + + assert outputs is not None + first_output = outputs[0] + assert first_output.final_output_type == "audio" + assert hasattr(first_output, "request_output") and first_output.request_output + + req_out = first_output.request_output + assert isinstance(req_out, OmniRequestOutput) + assert req_out.final_output_type == "audio" + assert hasattr(req_out, "multimodal_output") and req_out.multimodal_output + + audio = req_out.multimodal_output.get("audio") + assert isinstance(audio, np.ndarray) + # audio shape: (batch, channels, samples) + assert audio.ndim == 3 + assert audio.shape[0] == 1 + assert audio.shape[1] == 2 + assert audio.shape[2] > 0 + expected_samples = int(seconds_total * sample_rate) + assert abs(audio.shape[2] - expected_samples) <= 2 * 1024 diff --git a/tests/e2e/offline_inference/test_stable_audio_expansion.py b/tests/e2e/offline_inference/test_stable_audio_expansion.py index a5d3e6d2281..a7968aef366 100644 --- a/tests/e2e/offline_inference/test_stable_audio_expansion.py +++ b/tests/e2e/offline_inference/test_stable_audio_expansion.py @@ -56,13 +56,9 @@ def generate_stable_audio_short_clip( assert outputs is not None first_output = outputs[0] - # Outer OmniRequestOutput.final_output_type comes from get_stage_metadata. - # The nested request_output is the worker OmniRequestOutput - # (e.g. final_output_type="audio") and holds the multimodal payload. - # Follow-up: add StableAudioPipeline stage YAML, and pass model into - # _create_default_diffusion_stage_cfg so default diffusion metadata can set - # final_output_type to "audio" for future audio pipelines without YAML. - assert first_output.final_output_type == "image" + # Audio-output diffusion pipelines (those with ``support_audio_output = True``) now have + # ``final_output_type="audio"`` set on the outer stage metadata as well as the inner request. + assert first_output.final_output_type == "audio" assert hasattr(first_output, "request_output") and first_output.request_output req_out = first_output.request_output diff --git a/tests/e2e/online_serving/test_audiox_online.py b/tests/e2e/online_serving/test_audiox_online.py new file mode 100644 index 00000000000..6ef84b206e3 --- /dev/null +++ b/tests/e2e/online_serving/test_audiox_online.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E online serving test for AudioX text-to-audio diffusion. + +Mirrors `tests/e2e/online_serving/test_sd3_expansion.py` for image diffusion: spin up +`vllm-omni` with `--model-class-name AudioXPipeline` and validate via the standard +`send_diffusion_request` helper (now audio-aware). +""" + +import os + +import pytest + +from tests.helpers.mark import hardware_marks +from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data + +# Tiny / random checkpoint usable in CI; override to a real bundle locally. +AUDIOX_TEST_MODEL = os.environ.get("AUDIOX_TEST_MODEL", "zhangj1an/audiox_random") +T2A_PROMPT = "A quiet living room with soft fabric rustle and gentle cat breathing." + +SINGLE_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "L4"}) + + +def _audiox_server_cases(model: str): + return [ + pytest.param( + OmniServerParams( + model=model, + server_args=["--model-class-name", "AudioXPipeline"], + ), + id="t2a", + marks=SINGLE_CARD_FEATURE_MARKS, + ), + ] + + +@pytest.mark.core_model +@pytest.mark.diffusion +@pytest.mark.parametrize("omni_server", _audiox_server_cases(AUDIOX_TEST_MODEL), indirect=True) +def test_audiox_t2a_online(omni_server: OmniServer, openai_client: OpenAIClientHandler) -> None: + """AudioX text-to-audio: chat completion returns a non-empty WAV in `message.audio.data`.""" + request_config = { + "model": omni_server.model, + "messages": dummy_messages_from_mix_data(content_text=T2A_PROMPT), + "extra_body": { + "num_inference_steps": 4, + "guidance_scale": 6.0, + "seed": 42, + "audiox_task": "t2a", + "seconds_start": 0.0, + "seconds_total": 2.0, + }, + } + openai_client.send_diffusion_request(request_config) diff --git a/tests/helpers/assertions.py b/tests/helpers/assertions.py index 8b7ad1a2772..1ef5fbc45a8 100644 --- a/tests/helpers/assertions.py +++ b/tests/helpers/assertions.py @@ -3,6 +3,7 @@ import io import tempfile import threading +import wave from io import BytesIO from pathlib import Path from typing import Any @@ -115,8 +116,17 @@ def assert_audio_diffusion_response( ) -> None: """ Validate audio diffusion response. + + `response.audios` carries one entry per choice, each a `dict` with raw WAV + bytes (`wav_bytes`) and the OpenAI audio metadata (`id`, `expires_at`). """ - raise NotImplementedError("Audio validation is not implemented yet") + assert response.audios, "Audio response is empty" + for audio in response.audios: + wav_bytes = audio.get("wav_bytes") + assert wav_bytes, "Audio entry missing decoded WAV bytes" + with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file: + assert wav_file.getnframes() > 0, "Decoded WAV has zero frames" + assert wav_file.getframerate() > 0, "Decoded WAV has invalid sample rate" def _maybe_int(value: Any) -> int | None: diff --git a/tests/helpers/runtime.py b/tests/helpers/runtime.py index ab9b6cda9f9..0b08e350ba5 100644 --- a/tests/helpers/runtime.py +++ b/tests/helpers/runtime.py @@ -701,6 +701,7 @@ def _process_diffusion_response(self, chat_completion) -> DiffusionResponse: start_time = time.perf_counter() try: images = [] + audios = [] for choice in chat_completion.choices: content = getattr(choice.message, "content", None) if isinstance(content, list): @@ -714,8 +715,21 @@ def _process_diffusion_response(self, chat_completion) -> DiffusionResponse: if image_url and image_url.startswith("data:image"): b64_data = image_url.split(",", 1)[1] images.append(decode_b64_image(b64_data)) + + # OpenAI audio responses (e.g. AudioX text-to-audio) populate `message.audio`. + audio_obj = getattr(choice.message, "audio", None) + audio_b64 = getattr(audio_obj, "data", None) if audio_obj is not None else None + if audio_b64: + audios.append( + { + "wav_bytes": base64.b64decode(audio_b64), + "id": getattr(audio_obj, "id", None), + "expires_at": getattr(audio_obj, "expires_at", None), + } + ) result.e2e_latency = time.perf_counter() - start_time result.images = images if images else None + result.audios = audios if audios else None result.success = True except Exception as e: result.error_message = f"Diffusion response processing error: {str(e)}" diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 80ef0df666c..db5a85cfbf6 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -199,6 +199,16 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: # Handle single request or multiple requests is_audio_output = supports_audio_output(self.od_config.model_class_name) + if is_audio_output and model_audio_sample_rate is None: + model_cls = DiffusionModelRegistry._try_load_model_cls(self.od_config.model_class_name) + model_audio_sample_rate = getattr(model_cls, "audio_sample_rate", None) + + def _audio_mm(payload: Any) -> dict[str, Any]: + mm: dict[str, Any] = {"audio": payload} + if model_audio_sample_rate is not None: + mm["audio_sample_rate"] = model_audio_sample_rate + return mm + if len(request.prompts) == 1: # Single request: return single OmniRequestOutput prompt = request.prompts[0] @@ -217,7 +227,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: trajectory_timesteps=output.trajectory_timesteps, trajectory_log_probs=output.trajectory_log_probs, trajectory_decoded=output.trajectory_decoded, - multimodal_output={"audio": request_audio_payload}, + multimodal_output=_audio_mm(request_audio_payload), final_output_type="audio", stage_durations=output.stage_durations, peak_memory_mb=output.peak_memory_mb, @@ -277,7 +287,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: trajectory_timesteps=output.trajectory_timesteps, trajectory_log_probs=output.trajectory_log_probs, trajectory_decoded=output.trajectory_decoded, - multimodal_output={"audio": request_audio_payload}, + multimodal_output=_audio_mm(request_audio_payload), final_output_type="audio", stage_durations=output.stage_durations, peak_memory_mb=output.peak_memory_mb, diff --git a/vllm_omni/diffusion/layers/fourier.py b/vllm_omni/diffusion/layers/fourier.py new file mode 100644 index 00000000000..ff450f90667 --- /dev/null +++ b/vllm_omni/diffusion/layers/fourier.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import math + +import torch +from torch import nn + + +class GaussianFourierProjection(nn.Module): + """Shared Gaussian Fourier features with optional trainable frequencies.""" + + def __init__( + self, + *, + in_features: int, + embedding_size: int, + scale: float = 1.0, + trainable: bool = True, + ) -> None: + super().__init__() + self.in_features = int(in_features) + self.embedding_size = int(embedding_size) + weight = torch.randn(self.embedding_size, self.in_features) * scale + self.weight = nn.Parameter(weight, requires_grad=trainable) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.reshape(-1, self.in_features) + x_proj = 2 * math.pi * x @ self.weight.T + return torch.cat([x_proj.cos(), x_proj.sin()], dim=-1) diff --git a/vllm_omni/diffusion/models/audiox/__init__.py b/vllm_omni/diffusion/models/audiox/__init__.py new file mode 100644 index 00000000000..cec1470f249 --- /dev/null +++ b/vllm_omni/diffusion/models/audiox/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .audiox_transformer import MMDiffusionTransformer +from .pipeline_audiox import AudioXPipeline, get_audiox_post_process_func + +__all__ = [ + "AudioXPipeline", + "MMDiffusionTransformer", + "get_audiox_post_process_func", +] diff --git a/vllm_omni/diffusion/models/audiox/audiox_transformer.py b/vllm_omni/diffusion/models/audiox/audiox_transformer.py new file mode 100644 index 00000000000..9377f2b03f2 --- /dev/null +++ b/vllm_omni/diffusion/models/audiox/audiox_transformer.py @@ -0,0 +1,487 @@ +from __future__ import annotations + +import logging +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn as nn +from einops import rearrange +from torch.nn import functional as F +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader, sharded_weight_loader + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.layers.fourier import GaussianFourierProjection +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + + +class AudioXCrossAttention(nn.Module): + def __init__(self, dim: int, nheads: int, prefix: str = ""): + super().__init__() + self.dim = dim + self.nheads = nheads + head_dim = dim // nheads + self.head_dim = head_dim + + # to_kv bundle weights arrive in (head, dim, VK-index) interleaved order; pipeline's + # load_weights restacks them to [V|K] before MergedColumnParallelLinear consumes them. + self.to_q = ColumnParallelLinear(dim, dim, bias=False, gather_output=False, prefix=f"{prefix}.to_q") + self.to_kv = MergedColumnParallelLinear( + input_size=dim, output_sizes=[dim, dim], bias=False, gather_output=False, prefix=f"{prefix}.to_kv" + ) + self.q_norm = AudioXRMSNorm(head_dim) + self.k_norm = AudioXRMSNorm(head_dim) + local_nheads = nheads // get_tensor_model_parallel_world_size() + self.attn = Attention( + num_heads=local_nheads, + head_size=head_dim, + softmax_scale=head_dim**-0.5, + causal=False, + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + tp = get_tensor_model_parallel_world_size() + local_h = self.nheads // tp + d = self.head_dim + + # ROCm's ``wvSplitK`` GEMM bypasses autocast and rejects mismatched dtypes, + # so pre-match the input to the linear's weight dtype. + weight_dtype = self.to_q.weight.dtype + q_flat, _ = self.to_q(x.to(weight_dtype)) + q = rearrange(q_flat, "b n (h d) -> b n h d", h=local_h, d=d) + kv_flat, _ = self.to_kv(context.to(weight_dtype)) + v_flat, k_flat = kv_flat.chunk(2, dim=-1) + # Upstream `CrossAttention.forward` unpacks `pre_attention()` as `q, v, k = ...` (K/V + # swapped), AND only normalizes the first chunk of to_kv via k_norm. Trained weights + # depend on this quirk: first chunk normalized -> V; second chunk unnormalized -> K. + v = rearrange(v_flat, "b n (h d) -> b n h d", h=local_h, d=d) + k = rearrange(k_flat, "b n (h d) -> b n h d", h=local_h, d=d) + q = self.q_norm(q) + v = self.k_norm(v) + + out = self.attn(q.contiguous(), k.contiguous(), v.contiguous(), attn_metadata=None) + out = rearrange(out, "b n h d -> b n (h d)").contiguous() + if tp > 1: + out = tensor_model_parallel_all_gather(out, dim=-1) + return out + + +logger = logging.getLogger(__name__) + +__all__ = [ + "AudioXMMChannelLastConv1d", + "AudioXMMConvFeedForward", + "AudioXMMDiTSelfAttention", + "AudioXMMDiTBlock", + "ContinuousMMDiTTransformer", + "MMDiffusionTransformer", +] + + +class AudioXMMChannelLastConv1d(nn.Conv1d): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = rearrange(x, "b n c -> b c n") + x = super().forward(x) + x = rearrange(x, "b c n -> b n c") + return x + + +class _ColumnParallelChannelLastConv1d(AudioXMMChannelLastConv1d): + def __init__(self, in_channels: int, out_channels_total: int, **kwargs: Any): + tp_size = get_tensor_model_parallel_world_size() + assert out_channels_total % tp_size == 0, (out_channels_total, tp_size) + super().__init__(in_channels, out_channels_total // tp_size, **kwargs) + self.weight.weight_loader = sharded_weight_loader(0) + + +class _RowParallelChannelLastConv1d(AudioXMMChannelLastConv1d): + def __init__(self, in_channels_total: int, out_channels: int, **kwargs: Any): + tp_size = get_tensor_model_parallel_world_size() + assert in_channels_total % tp_size == 0, (in_channels_total, tp_size) + super().__init__(in_channels_total // tp_size, out_channels, **kwargs) + self._tp_size = tp_size + self.weight.weight_loader = sharded_weight_loader(1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = super().forward(x) + if self._tp_size > 1: + y = tensor_model_parallel_all_reduce(y) + return y + + +class AudioXMMConvFeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + kernel_size: int = 3, + padding: int = 1, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = _ColumnParallelChannelLastConv1d( + dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding + ) + self.w2 = _RowParallelChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding) + self.w3 = _ColumnParallelChannelLastConv1d( + dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding + ) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class AudioXRMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + mean_sq = torch.mean(x**2, dim=-1, keepdim=True) + scale = torch.rsqrt(mean_sq + self.eps) + return x * scale + + +class AudioXMMDiTSelfAttention(nn.Module): + def __init__(self, dim: int, nheads: int, prefix: str = ""): + super().__init__() + self.dim = dim + self.nheads = nheads + + head_dim = dim // nheads + self.head_dim = head_dim + # Bundle weights arrive in interleaved (h, d, qkv) layout; pipeline's load_weights + # restacks to [Q|K|V] before QKVParallelLinear's weight_loader consumes them. + self.qkv = QKVParallelLinear( + hidden_size=dim, + head_size=head_dim, + total_num_heads=nheads, + bias=True, + prefix=f"{prefix}.qkv", + ) + self.q_norm = AudioXRMSNorm(head_dim) + self.k_norm = AudioXRMSNorm(head_dim) + + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=self.qkv.num_heads, + head_size=head_dim, + softmax_scale=head_dim**-0.5, + causal=False, + ) + + def apply_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + out = self.attn(q.contiguous(), k.contiguous(), v.contiguous(), attn_metadata=None) + out = rearrange(out, "b n h d -> b n (h d)").contiguous() + # Downstream linear1/ffn are replicated and expect full hidden dim. + if get_tensor_model_parallel_world_size() > 1: + out = tensor_model_parallel_all_gather(out, dim=-1) + return out + + def pre_attention(self, x: torch.Tensor, rot: tuple[torch.Tensor, torch.Tensor] | None = None): + # ROCm GEMM dtype cast — see AudioXCrossAttention.forward. + qkv, _ = self.qkv(x.to(self.qkv.weight.dtype)) + local_h = self.qkv.num_heads + d = self.head_dim + q_size = local_h * d + q, k, v = qkv.split([q_size, q_size, q_size], dim=-1) + q = rearrange(q, "b n (h d) -> b n h d", h=local_h, d=d) + k = rearrange(k, "b n (h d) -> b n h d", h=local_h, d=d) + v = rearrange(v, "b n (h d) -> b n h d", h=local_h, d=d) + q = self.q_norm(q) + k = self.k_norm(k) + + if rot is not None: + cos, sin = rot + cos = cos.to(dtype=q.dtype) + sin = sin.to(dtype=q.dtype) + q = self.rope(q, cos, sin) + k = self.rope(k, cos, sin) + + return q, k, v + + def forward( + self, + x: torch.Tensor, + rot: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + q, k, v = self.pre_attention(x, rot=rot) + return self.apply_attention(q, k, v) + + +class AudioXMMDiTBlock(nn.Module): + def __init__( + self, + dim: int, + nhead: int, + mlp_ratio: float = 4.0, + prefix: str = "", + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False) + self.attn = AudioXMMDiTSelfAttention(dim, nhead, prefix=f"{prefix}.attn") + self.cross_attn = AudioXCrossAttention(dim, nhead, prefix=f"{prefix}.cross_attn") + self.linear1 = AudioXMMChannelLastConv1d(dim, dim, kernel_size=3, padding=1) + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False) + self.ffn = AudioXMMConvFeedForward(dim, int(dim * mlp_ratio), kernel_size=3, padding=1) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True)) + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: tuple[torch.Tensor, torch.Tensor] | None): + modulation = self.adaLN_modulation(c) + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = modulation.chunk(6, dim=-1) + x = self.norm1(x) * (1 + scale_msa) + shift_msa + q, k, v = self.attn.pre_attention(x, rot) + return (q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp) + + def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor, c: tuple[torch.Tensor, ...], context=None): + (gate_msa, shift_mlp, scale_mlp, gate_mlp) = c + x = x + self.linear1(attn_out) * gate_msa + + x = x + self.cross_attn(x, context=context) + + r = self.norm2(x) * (1 + scale_mlp) + shift_mlp + x = x + self.ffn(r) * gate_mlp + return x + + def forward( + self, + x: torch.Tensor, + cond: torch.Tensor, + rot: tuple[torch.Tensor, torch.Tensor] | None, + context: torch.Tensor = None, + ) -> torch.Tensor: + x_qkv, x_conditions = self.pre_attention(x, cond, rot) + attn_out = self.attn.apply_attention(*x_qkv) + x = self.post_attention(x, attn_out, x_conditions, context=context) + return x + + +class MMDiffusionTransformer(nn.Module): + """AudioX MMDiT, specialized for the published bundle (`zhangj1an/AudioX`). + + The bundle fixes patch_size=1, transformer_type="continuous_transformer", + cond_token_dim=768 (>0, project_cond_tokens=False), and never sets + prepend_cond_dim or input_concat_dim, so those code paths are removed. + """ + + def __init__( + self, + io_channels: int, + embed_dim: int, + cond_token_dim: int, + global_cond_dim: int, + depth: int, + num_heads: int, + project_cond_tokens: bool = False, + project_global_cond: bool = True, + **kwargs, + ): + super().__init__() + if kwargs: + logger.debug("MMDiffusionTransformer ignoring unused config keys: %s", sorted(kwargs.keys())) + if project_cond_tokens: + raise ValueError("AudioX bundle requires project_cond_tokens=False to match official checkpoints.") + + self.cond_token_dim = cond_token_dim + + timestep_features_dim = 256 + self.timestep_features = GaussianFourierProjection( + in_features=1, + embedding_size=timestep_features_dim // 2, + scale=1.0, + trainable=False, + ) + self.to_timestep_embed = nn.Sequential( + nn.Linear(timestep_features_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + + cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim + self.to_cond_embed = nn.Sequential( + nn.Linear(cond_token_dim, cond_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(cond_embed_dim, cond_embed_dim, bias=False), + ) + + # ``to_global_embed`` weights live in the bundle but global conditioning is always None + # at inference; kept so AutoWeightsLoader has a slot to load them into. + global_embed_dim = global_cond_dim if not project_global_cond else embed_dim + self.to_global_embed = nn.Sequential( + nn.Linear(global_cond_dim, global_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(global_embed_dim, global_embed_dim, bias=False), + ) + + self.transformer = ContinuousMMDiTTransformer( + dim=embed_dim, + depth=depth, + dim_heads=embed_dim // num_heads, + dim_in=io_channels, + dim_out=io_channels, + ) + + self.preprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False) + nn.init.zeros_(self.preprocess_conv.weight) + self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False) + nn.init.zeros_(self.postprocess_conv.weight) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _forward(self, x, t, cross_attn_cond): + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) + prepend_inputs = timestep_embed.unsqueeze(1) + prepend_length = prepend_inputs.shape[1] + + x = self.preprocess_conv(x) + x + x = rearrange(x, "b c n -> b n c") + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond) + output = rearrange(output, "b n c -> b c n")[:, :, prepend_length:] + return self.postprocess_conv(output) + output + + def forward( + self, + x, + t, + cross_attn_cond, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + cfg_scale: float = 1.0, + scale_phi: float = 0.0, + **kwargs, + ): + if cfg_scale == 1.0: + return self._forward(x, t, cross_attn_cond) + + # Classifier-free guidance: batch the conditional + unconditional pass. + null_embed = torch.zeros_like(cross_attn_cond) + if negative_cross_attn_cond is not None and negative_cross_attn_mask is not None: + mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) + negative_cross_attn_cond = torch.where(mask, negative_cross_attn_cond, null_embed) + uncond = negative_cross_attn_cond if negative_cross_attn_cond is not None else null_embed + + batch_output = self._forward( + torch.cat([x, x], dim=0), + torch.cat([t, t], dim=0), + torch.cat([cross_attn_cond, uncond], dim=0), + ) + cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0) + cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale + + if scale_phi == 0.0: + return cfg_output + cond_std = cond_output.std(dim=1, keepdim=True) + cfg_std = cfg_output.std(dim=1, keepdim=True) + return scale_phi * (cfg_output * (cond_std / cfg_std)) + (1 - scale_phi) * cfg_output + + +class ContinuousMMDiTTransformer(nn.Module): + def __init__( + self, + dim, + depth, + *, + dim_in=None, + dim_out=None, + dim_heads=64, + _latent_seq_len=237, + ): + super().__init__() + + self.dim = dim + self.depth = depth + + self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity() + self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity() + + hidden_dim = dim + num_heads = dim_heads + mlp_ratio = 4.0 + self._latent_seq_len = _latent_seq_len + + self.layers = nn.ModuleList( + [ + AudioXMMDiTBlock( + hidden_dim, + num_heads, + mlp_ratio=mlp_ratio, + prefix=f"layers.{i}", + ) + for i in range(depth) + ] + ) + self.proj_mm_tokens = nn.Linear(768, hidden_dim) if dim != 768 else nn.Identity() + self.proj_mm_seq_len = nn.Linear(384, self._latent_seq_len) if self._latent_seq_len != 384 else nn.Identity() + + # AudioX RoPE: interleaved (GPT-J) pair layout, theta=10000. + head_dim = hidden_dim // num_heads + pos = torch.arange(self._latent_seq_len, dtype=torch.float32, device=self.device) + inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=self.device) / head_dim)) + ang = torch.outer(pos, inv_freq) + self.register_buffer("latent_rope_cos", torch.cos(ang), persistent=False) + self.register_buffer("latent_rope_sin", torch.sin(ang), persistent=False) + + @property + def device(self): + return next(self.parameters()).device + + def forward( + self, + x, + prepend_embeds=None, + context=None, + ): + x = self.project_in(x) + + if prepend_embeds is not None: + prepend_dim = prepend_embeds.shape[-1] + assert prepend_dim == x.shape[-1], "prepend dimension must match sequence dimension" + x = torch.cat((prepend_embeds, x), dim=-2) + + time_cond = prepend_embeds.squeeze(1) + mm_tokens = context + + mm_tokens = self.proj_mm_tokens(mm_tokens) + mm_tokens = rearrange(mm_tokens, "b s d -> b d s") + mm_tokens = self.proj_mm_seq_len(mm_tokens) + mm_tokens = rearrange(mm_tokens, "b d s -> b s d") + + time_cond = time_cond.unsqueeze(1) + rot = ( + self.latent_rope_cos.to(device=x.device, dtype=x.dtype), + self.latent_rope_sin.to(device=x.device, dtype=x.dtype), + ) + for block in self.layers: + x = block(x, mm_tokens, rot, context=time_cond) + + x = self.project_out(x) + return x + + +def __getattr__(name: str): + if name in __all__: + return globals()[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm_omni/diffusion/models/audiox/pipeline_audiox.py b/vllm_omni/diffusion/models/audiox/pipeline_audiox.py new file mode 100644 index 00000000000..b2035b7a8d0 --- /dev/null +++ b/vllm_omni/diffusion/models/audiox/pipeline_audiox.py @@ -0,0 +1,922 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import json +import math +import os +from collections.abc import Iterable +from typing import Any, ClassVar + +import torch +import torch.nn.functional as F +import torchsde +from diffusers import AutoencoderOobleck +from einops import rearrange +from torch import einsum, nn +from torchvision import transforms +from transformers import AutoConfig, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.audiox.audiox_transformer import MMDiffusionTransformer +from vllm_omni.diffusion.models.interface import SupportAudioOutput +from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.transformers_utils.processors import audiox as _audiox_transforms + +_VIDEO_ONLY_TASKS = _audiox_transforms.VIDEO_ONLY_TASKS +_TEXT_VIDEO_TASKS = _audiox_transforms.TEXT_VIDEO_TASKS +_VIDEO_CONDITIONED_TASKS = _audiox_transforms.VIDEO_CONDITIONED_TASKS +_normalize_prompts = _audiox_transforms.normalize_prompts +prepare_audio_reference = _audiox_transforms.prepare_audio_reference +prepare_video_reference = _audiox_transforms.prepare_video_reference + +# Polyexponential sigma schedule defaults; mirror upstream AudioX gradio interface +# (``audiox/interface/gradio.py`` ``generate_cond`` defaults: ``sigma_min=0.03``, ``sigma_max=1000``). +_DEFAULT_UPSTREAM_SIGMA_MIN = 0.03 +_DEFAULT_UPSTREAM_SIGMA_MAX = 1000.0 + +logger = init_logger(__name__) + + +def _load_audiox_bundle_config(model_root: str) -> dict[str, Any]: + with open(os.path.join(os.path.abspath(model_root), "config.json"), encoding="utf-8") as f: + return json.load(f) + + +def _audio_conditioning_input_samples(model_config: dict[str, Any]) -> int: + configs = model_config["model"]["conditioning"]["configs"] + cfg = next(c["config"] for c in configs if c["id"] == "audio_prompt") + return int(cfg["latent_seq_len"]) * int(cfg["pretransform_config"]["config"]["downsampling_ratio"]) + + +def get_audiox_post_process_func(od_config: OmniDiffusionConfig): + """Convert the pipeline's float audio tensor to a CPU numpy array for serving.""" + + def post_process_func(audio: torch.Tensor) -> Any: + if isinstance(audio, torch.Tensor): + return audio.detach().cpu().float().numpy() + return audio + + return post_process_func + + +class SA_PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + + +class SA_FeedForward(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + # Dropout p=0 preserves upstream ``net.{2,4}`` state-dict keys so the upstream weights + # load into the right slots; inference is identical to no dropout. + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(0.0), + nn.Linear(hidden_dim, dim), + nn.Dropout(0.0), + ) + + def forward(self, x): + return self.net(x) + + +# Manual einsum+softmax only. SDPA/diffusion Attention here degrades conditioning vs upstream. +class SA_Attention(nn.Module): + def __init__(self, dim, heads=8, dim_head=64): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + + self.to_out = ( + nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(0.0), + ) + if project_out + else nn.Identity() + ) + + def forward(self, x): + h = self.heads + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in qkv) + + dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale + + attn = dots.softmax(dim=-1) + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + out = self.to_out(out) + return out + + +class SA_Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim): + super().__init__() + self.layers = nn.ModuleList([]) + self.norm = nn.LayerNorm(dim) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + SA_PreNorm(dim, SA_Attention(dim, heads=heads, dim_head=dim_head)), + SA_PreNorm(dim, SA_FeedForward(dim, mlp_dim)), + ] + ) + ) + + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return self.norm(x) + + +_AUDIOX_OOBLECK_CONFIG = { + "audio_channels": 2, + "channel_multiples": [1, 2, 4, 8, 16], + "decoder_channels": 128, + "decoder_input_channels": 64, + "downsampling_ratios": [2, 4, 4, 8, 8], + "encoder_hidden_size": 128, + "sampling_rate": 44100, +} + + +def _build_audiox_oobleck(scaling_factor: float = 1.0) -> AutoencoderOobleck: + vae = AutoencoderOobleck(**_AUDIOX_OOBLECK_CONFIG) + vae.audiox_scaling_factor = float(scaling_factor) # type: ignore[attr-defined] + return vae.eval().requires_grad_(False) + + +class AudioVaePromptAdapter(nn.Module): + def __init__(self, *, cond_dim: int, latent_seq_len: int = 215): + super().__init__() + self.pretransform = _build_audiox_oobleck() + in_ch = int(self.pretransform.config.decoder_input_channels) + self.proj_features_128 = nn.Linear(latent_seq_len, 128) + self.proj_out = nn.Linear(in_ch, cond_dim) if in_ch != cond_dim else nn.Identity() + + def forward(self, audio: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + z = self.pretransform.encode(audio, return_dict=True).latent_dist.sample() + latents = z / float(self.pretransform.audiox_scaling_factor) + latents = rearrange(self.proj_features_128(latents), "b c s -> b s c") + latents = self.proj_out(latents) + ones = torch.ones(latents.shape[0], latents.shape[2], device=latents.device) + return latents, ones + + +class _MAFCrossAttentionBlock(nn.Module): + def __init__(self, dim: int, num_heads: int): + super().__init__() + assert dim % num_heads == 0, "dim must be divisible by num_heads" + self._num_heads = num_heads + self._head_dim = dim // num_heads + self._scale = self._head_dim**-0.5 + self.to_q = nn.Linear(dim, dim, bias=True) + self.to_kv = nn.Linear(dim, dim * 2, bias=True) + self.to_out = nn.Linear(dim, dim, bias=True) + nn.init.zeros_(self.to_out.weight) + if self.to_out.bias is not None: + nn.init.zeros_(self.to_out.bias) + + def forward(self, experts: torch.Tensor, full_context: torch.Tensor) -> torch.Tensor: + nh, hd = self._num_heads, self._head_dim + q = rearrange(self.to_q(experts), "b n (h d) -> b h n d", h=nh, d=hd) + k, v = (rearrange(t, "b n (h d) -> b h n d", h=nh, d=hd) for t in self.to_kv(full_context).chunk(2, dim=-1)) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False, scale=self._scale) + return self.to_out(rearrange(out, "b h n d -> b n (h d)")) + + +class _MAFFusionBlock(nn.Module): + def __init__(self, dim: int, num_heads: int, mlp_ratio: float): + super().__init__() + assert dim % num_heads == 0, "dim must be divisible by num_heads" + head_dim = dim // num_heads + self._num_heads = num_heads + hidden = int(dim * mlp_ratio) + self.norm1 = nn.LayerNorm(dim) + self.to_qkv = nn.Linear(dim, dim * 3, bias=True) + self.self_attn = Attention( + num_heads=num_heads, + head_size=head_dim, + softmax_scale=head_dim**-0.5, + causal=False, + ) + self.out_proj = nn.Linear(dim, dim, bias=True) + self.norm2 = nn.LayerNorm(dim) + self.ff = nn.Sequential( + nn.Linear(dim, hidden), + nn.GELU(), + nn.Linear(hidden, dim), + ) + nn.init.zeros_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.zeros_(self.out_proj.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + q, k, v = self.to_qkv(h).chunk(3, dim=-1) + nh = self._num_heads + q_bsn = rearrange(q, "b n (h d) -> b n h d", h=nh).contiguous() + k_bsn = rearrange(k, "b n (h d) -> b n h d", h=nh).contiguous() + v_bsn = rearrange(v, "b n (h d) -> b n h d", h=nh).contiguous() + out = self.self_attn(q_bsn, k_bsn, v_bsn, attn_metadata=None) + out = rearrange(out, "b n h d -> b n (h d)") + out = self.out_proj(out) + x = x + out + x = x + self.ff(self.norm2(x)) + return x + + +class MAF_Block(nn.Module): + DIM = 768 + MLP_RATIO = 4.0 + + def __init__( + self, + *, + dim: int = 768, + num_experts_per_modality: int = 64, + num_heads: int = 24, + num_fusion_layers: int = 8, + mlp_ratio: float = 4.0, + ): + super().__init__() + total_experts = num_experts_per_modality * 3 + + self.gating_network = nn.Sequential( + nn.Linear(dim * 3, dim), + nn.GELU(), + nn.Linear(dim, 3), + nn.Sigmoid(), + ) + + self.unified_experts = nn.Parameter(torch.randn(total_experts, dim)) + + self.cross_block = _MAFCrossAttentionBlock(dim, num_heads) + self.norm1 = nn.LayerNorm(dim) + self.fusion_blocks = nn.ModuleList( + [_MAFFusionBlock(dim, num_heads, mlp_ratio) for _ in range(num_fusion_layers)] + ) + + self.norm_v2 = nn.LayerNorm(dim) + self.norm_t2 = nn.LayerNorm(dim) + self.norm_a2 = nn.LayerNorm(dim) + self.bypass_gate_v = nn.Parameter(torch.tensor(-10.0)) + self.bypass_gate_t = nn.Parameter(torch.tensor(-10.0)) + self.bypass_gate_a = nn.Parameter(torch.tensor(-10.0)) + + def forward( + self, + video_tokens: torch.Tensor, + text_tokens: torch.Tensor, + audio_tokens: torch.Tensor, + ) -> dict[str, torch.Tensor]: + batch_size = video_tokens.shape[0] + + v_global = video_tokens.mean(dim=1) + t_global = text_tokens.mean(dim=1) + a_global = audio_tokens.mean(dim=1) + + all_global = torch.cat([v_global, t_global, a_global], dim=1) + gates = self.gating_network(all_global) + w_v, w_t, w_a = gates.chunk(3, dim=-1) + + gated_v = video_tokens * w_v.unsqueeze(-1) + gated_t = text_tokens * w_t.unsqueeze(-1) + gated_a = audio_tokens * w_a.unsqueeze(-1) + + full_context = torch.cat([gated_v, gated_t, gated_a], dim=1) + + experts = self.unified_experts.unsqueeze(0).expand(batch_size, -1, -1) + cross_out = self.cross_block(experts, full_context) + updated_experts = self.norm1(experts + cross_out) + + for blk in self.fusion_blocks: + updated_experts = blk(updated_experts) + + fused_v_experts, fused_t_experts, fused_a_experts = updated_experts.chunk(3, dim=1) + + refinement_v = fused_v_experts.mean(dim=1) + refinement_t = fused_t_experts.mean(dim=1) + refinement_a = fused_a_experts.mean(dim=1) + + alpha_v = torch.sigmoid(self.bypass_gate_v) + alpha_t = torch.sigmoid(self.bypass_gate_t) + alpha_a = torch.sigmoid(self.bypass_gate_a) + + final_v = video_tokens + alpha_v * self.norm_v2(refinement_v).unsqueeze(1) + final_t = text_tokens + alpha_t * self.norm_t2(refinement_t).unsqueeze(1) + final_a = audio_tokens + alpha_a * self.norm_a2(refinement_a).unsqueeze(1) + + return { + "video": final_v, + "text": final_t, + "audio": final_a, + } + + +class _BrownianTreeNoiseSampler: + """Brownian-tree noise sampler for DPM-Solver++ SDE, ported from k-diffusion. + + Returns a scaled Brownian increment between two sigma levels; the tree is indexed by + transformed ``sigma`` (here linear, same as k-diffusion's default), so re-querying + identical (sigma, sigma_next) pairs returns the same noise. Entropy must come from + the user's seed — otherwise the unseeded global RNG would make sampling non-deterministic. + """ + + def __init__(self, x: torch.Tensor, sigma_min: torch.Tensor, sigma_max: torch.Tensor, entropy: int): + t0 = torch.as_tensor(sigma_min) + t1 = torch.as_tensor(sigma_max) + self._tree = torchsde.BrownianTree(t0, torch.zeros_like(x), t1, entropy=entropy) + + def __call__(self, sigma: torch.Tensor, sigma_next: torch.Tensor) -> torch.Tensor: + # AudioX denoises from sigma_max → 0, so sigma > sigma_next throughout the loop. + t0 = torch.as_tensor(sigma_next) + t1 = torch.as_tensor(sigma) + return -self._tree(t0, t1) / (t1 - t0).sqrt() + + +class AudioXPipeline(nn.Module, SupportAudioOutput, DiffusionPipelineProfilerMixin): + support_audio_output: ClassVar[bool] = True + audio_sample_rate: ClassVar[int] = 44100 + audio_channels: ClassVar[int] = 2 + _PROFILER_TARGETS: ClassVar[list[str]] = ["diffuse"] + _CLIP_SYNC_DURATION_SEC: ClassVar[float] = 10.0 + _VIDEO_SYNC_FRAME_COUNT: ClassVar[int] = 240 + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.device = get_local_device() + if od_config.model is None: + raise ValueError( + "AudioXPipeline requires od_config.model (directory with unified safetensors; " + "see https://huggingface.co/zhangj1an/AudioX)." + ) + + if os.path.exists(od_config.model): + self._model_root = os.path.abspath(od_config.model) + else: + from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific + + self._model_root = download_weights_from_hf_specific(od_config.model, None, ["*"]) + self._model_config = _load_audiox_bundle_config(self._model_root) + + model_config = self._model_config["model"] + diffusion_config = model_config["diffusion"] + + self.model = MMDiffusionTransformer(**dict(diffusion_config["config"])) + + cond_configs = {c["id"]: c["config"] for c in model_config["conditioning"]["configs"]} + self.audio_vae_adapter = AudioVaePromptAdapter( + cond_dim=int(model_config["conditioning"]["cond_dim"]), + latent_seq_len=int(cond_configs["audio_prompt"]["latent_seq_len"]), + ) + + t5_name = cond_configs["text_prompt"]["t5_model_name"] + self._t5_max_length = int(cond_configs["text_prompt"]["max_length"]) + self.tokenizer = T5TokenizerFast.from_pretrained(t5_name) + t5_config = AutoConfig.from_pretrained(t5_name) + self.text_encoder = T5EncoderModel(t5_config).train(False).requires_grad_(False).to(torch.float16) + + clip_name = cond_configs["video_prompt"]["clip_model_name"] + clip_config = AutoConfig.from_pretrained(clip_name) + self.clip_encoder = CLIPVisionModelWithProjection(clip_config.vision_config) + _CLIP_PATCH_TOKENS, _VIDEO_FPS, _DURATION_SEC, _DIM = 50, 5, 10, 768 + _in_features = _CLIP_PATCH_TOKENS * _VIDEO_FPS * _DURATION_SEC + self._clip_in_features = _in_features + self._clip_out_features = 128 + self.clip_proj = nn.Linear(_in_features, self._clip_out_features) + self.clip_proj_sync = nn.Linear(240, self._clip_out_features) + self.clip_sync_weight = nn.Parameter(torch.tensor(0.0)) + self.clip_temp_transformer = SA_Transformer(_DIM, depth=4, heads=16, dim_head=64, mlp_dim=_DIM * 4) + self.clip_temp_pos_embedding = nn.Parameter(torch.randn(1, _VIDEO_FPS * _DURATION_SEC, _DIM)) + self.clip_empty_visual_feat = nn.Parameter(torch.zeros(1, self._clip_out_features, _DIM), requires_grad=False) + _CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) + _CLIP_STD = (0.26862954, 0.26130258, 0.27577711) + self._clip_normalize = transforms.Compose([transforms.Normalize(mean=list(_CLIP_MEAN), std=list(_CLIP_STD))]) + + self.pretransform = _build_audiox_oobleck( + scaling_factor=float(model_config["pretransform"].get("scale", 1.0)), + ) + + self.io_channels = model_config["io_channels"] + self.diffusion_objective = "v" + + gate_type_config = diffusion_config["gate_type_config"] + self.maf_block = MAF_Block( + dim=768, + num_experts_per_modality=int(gate_type_config["num_experts_per_modality"]), + num_heads=int(gate_type_config["num_heads"]), + num_fusion_layers=int(gate_type_config["num_fusion_layers"]), + ) + + logger.debug("AudioX model built from %s", self._model_root) + + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=self._model_root, + subfolder="transformer", + revision=getattr(od_config, "revision", None), + prefix="", + ), + ] + sample_rate = int(self._model_config.get("sample_rate", 48000)) + self._sample_rate = sample_rate + self._sample_size = int(self._model_config.get("sample_size", sample_rate * 10)) + self._target_fps = int(self._model_config.get("video_fps", 5)) + self._audio_conditioning_samples = _audio_conditioning_input_samples(self._model_config) + + self.setup_diffusion_pipeline_profiler( + profiler_targets=list(self._PROFILER_TARGETS), + enable_diffusion_pipeline_profiler=od_config.enable_diffusion_pipeline_profiler, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + _legacy_prefix = "conditioner.conditioners.audio_prompt." + + # DiT self-attn QKV: bundle stores weights with last dim (h, d, qkv) interleaved + # (i.e. Q/K/V rows for the same head sit next to each other), whereas + # QKVParallelLinear expects stacked [Q|K|V] blocks along the output axis. + # Reshape once at load time so the parallel loader can consume them. + nheads = int(self._model_config["model"]["diffusion"]["config"]["num_heads"]) + embed_dim = int(self._model_config["model"]["diffusion"]["config"]["embed_dim"]) + total_heads = embed_dim // nheads + head_dim = embed_dim // total_heads + qkv_mid = ".attn.qkv." + to_kv_mid = ".cross_attn.to_kv." + + def _restack_interleaved(tensor: torch.Tensor, n_slots: int) -> torch.Tensor: + """Turn bundle's (h, d, slot) interleaved-last-dim layout into stacked [slot|slot|...].""" + if tensor.dim() == 2: # weight: (out, in) + out, inp = tensor.shape + assert out == n_slots * total_heads * head_dim, (out, n_slots, total_heads, head_dim) + return ( + tensor.view(total_heads, head_dim, n_slots, inp).permute(2, 0, 1, 3).reshape(out, inp).contiguous() + ) + out = tensor.shape[0] + assert out == n_slots * total_heads * head_dim + return tensor.view(total_heads, head_dim, n_slots).permute(2, 0, 1).reshape(out).contiguous() + + def _remap(items): + for name, tensor in items: + if name.startswith(_legacy_prefix): + name = "audio_vae_adapter." + name[len(_legacy_prefix) :] + if qkv_mid in name and (name.endswith(".weight") or name.endswith(".bias")): + tensor = _restack_interleaved(tensor, 3) + elif to_kv_mid in name and (name.endswith(".weight") or name.endswith(".bias")): + tensor = _restack_interleaved(tensor, 2) + yield name, tensor + + loaded = AutoWeightsLoader(self).load_weights(_remap(weights)) + + self.to(torch.float32) + self.eval().requires_grad_(False) + + return loaded + + def _conditioning_dtype(self) -> torch.dtype: + p = next(self.model.parameters()) + return p.dtype if p.dtype.is_floating_point else torch.float32 + + @staticmethod + def _normalize_task(task: str | None) -> str | None: + if task is None: + return None + t = str(task).strip().lower() + return t or None + + @staticmethod + def _text_for_task(task_norm: str | None, prompt: str) -> str: + if task_norm in _VIDEO_ONLY_TASKS: + return "" + return prompt + + @staticmethod + def _ensure_text_video_prompts(task_norm: str | None, prompts: list[str]) -> None: + if task_norm not in _TEXT_VIDEO_TASKS: + return + for i, p in enumerate(prompts): + if not str(p).strip(): + raise ValueError( + f"audiox_task={task_norm!r} requires a non-empty text prompt for item {i}; " + "use v2a/v2m for video-only generation." + ) + + def _audio_prompt_tensors( + self, + *, + raw_prompts: list[Any], + extra: dict[str, Any], + device: torch.device, + cond_dtype: torch.dtype = torch.float32, + ) -> list[torch.Tensor]: + target_len = self._audio_conditioning_samples + sample_rate = self._sample_rate + seconds_start = float(extra.get("seconds_start", 0.0)) + seconds_total = float(target_len) / float(sample_rate) + out: list[torch.Tensor] = [] + for raw in raw_prompts: + src = extra.get("audio_path") + if src is None: + out.append(torch.zeros(2, target_len, device=device, dtype=cond_dtype)) + continue + wav = prepare_audio_reference( + src, + model_sample_rate=sample_rate, + seconds_start=seconds_start, + seconds_total=seconds_total, + device=device, + ) + out.append(wav.to(dtype=cond_dtype)) + return out + + def _video_feature_tensors( + self, + *, + task_norm: str | None, + raw_prompts: list[Any], + extra: dict[str, Any], + seconds_start: float, + target_fps: int, + device: torch.device, + cond_dtype: torch.dtype = torch.float32, + ) -> list[torch.Tensor]: + clip_frames = int(round(self._CLIP_SYNC_DURATION_SEC * target_fps)) + if task_norm not in _VIDEO_CONDITIONED_TASKS: + empty = torch.zeros(clip_frames, 3, 224, 224, device=device, dtype=cond_dtype) + return [empty for _ in raw_prompts] + + tensors: list[torch.Tensor] = [] + for _ in raw_prompts: + src = extra.get("video_path") + if src is None: + raise ValueError(f"audiox_task={task_norm!r} requires video input: set extra_args['video_path'].") + vt = prepare_video_reference( + src, + duration=float(self._CLIP_SYNC_DURATION_SEC), + target_fps=target_fps, + seek_time=seconds_start, + ) + tensors.append(vt.to(device=device, dtype=cond_dtype)) + return tensors + + def get_conditioning_inputs(self, conditioning_tensors: dict[str, Any], negative: bool = False) -> dict[str, Any]: + video_feature, video_mask = conditioning_tensors["video_prompt"] + text_feature, text_mask = conditioning_tensors["text_prompt"] + audio_feature, audio_mask = conditioning_tensors["audio_prompt"] + + refined = self.maf_block(text_feature, video_feature, audio_feature) + fused = torch.cat(list(refined.values()), dim=1) + masks = torch.cat([video_mask, text_mask, audio_mask], dim=1) + + if negative: + return {"negative_cross_attn_cond": fused, "negative_cross_attn_mask": masks} + return {"cross_attn_cond": fused} + + def diffuse( + self, + *, + steps: int, + guidance_scale: float, + conditioning_tensors: dict[str, Any], + negative_conditioning_tensors: dict[str, Any] | None, + batch_size: int, + sigma_min: float, + sigma_max: float, + generator: torch.Generator, + cfg_rescale: float, + ) -> torch.Tensor: + device = self.device + model_dtype = next(self.model.parameters()).dtype + + # Match upstream AudioX: disable TF32 matmul + fp16 reduced precision + cudnn benchmark + # for numerical parity with audiox/inference/generation.py:152-156. + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + torch.backends.cudnn.benchmark = False + + latent_len = self._sample_size // int(self.pretransform.hop_length) + noise = torch.randn( + [batch_size, self.io_channels, latent_len], device=device, generator=generator, dtype=model_dtype + ) + + def _cast(d: dict[str, Any]) -> dict[str, Any]: + return {k: (v.type(model_dtype) if isinstance(v, torch.Tensor) else v) for k, v in d.items()} + + cond = _cast(self.get_conditioning_inputs(conditioning_tensors)) + neg = ( + _cast(self.get_conditioning_inputs(negative_conditioning_tensors, negative=True)) + if negative_conditioning_tensors is not None + else {} + ) + + # Inlined k-diffusion VDenoiser + sample_dpmpp_3m_sde, matching upstream AudioX exactly. + # diffusers' EDMDPMSolverMultistepScheduler uses different v-prediction preconditioning + # and a different stochastic update rule, which here produces a fixed ~861 Hz resonance + # in the decoded audio regardless of conditioning. + def denoise(x_in: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + s2 = sigma * sigma + c_skip = 1.0 / (s2 + 1.0) + c_out = -sigma / (s2 + 1.0).sqrt() + c_in = 1.0 / (s2 + 1.0).sqrt() + t_cond = sigma.atan() * (2.0 / math.pi) + v = self.model( + x_in * c_in, + t_cond, + cross_attn_cond=cond["cross_attn_cond"], + negative_cross_attn_cond=neg.get("negative_cross_attn_cond"), + negative_cross_attn_mask=neg.get("negative_cross_attn_mask"), + cfg_scale=guidance_scale, + scale_phi=cfg_rescale, + ) + return v * c_out + x_in * c_skip + + ramp = torch.linspace(1.0, 0.0, steps, device=device) + sigmas = torch.cat( + [ + torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min)), + torch.zeros(1, device=device), + ] + ) + x = noise * sigmas[0] + + # Match upstream AudioX: sampler runs under fp16 autocast (see audiox/inference/sampling.py:184). + with torch.cuda.amp.autocast(): + if steps <= 1: + s_in = x.new_ones([x.shape[0]]) + sampled = denoise(x, sigmas[0] * s_in) + else: + # DPM-Solver++(3M) SDE loop (k-diffusion sample_dpmpp_3m_sde), eta=1.0, s_noise=1.0. + noise_sampler = _BrownianTreeNoiseSampler( + x, + sigmas[sigmas > 0].min(), + sigmas.max(), + entropy=generator.initial_seed(), + ) + s_in = x.new_ones([x.shape[0]]) + denoised_1 = denoised_2 = None + h_1 = h_2 = None + eta = 1.0 + for i in range(len(sigmas) - 1): + denoised = denoise(x, sigmas[i] * s_in) + if sigmas[i + 1] == 0: + x = denoised + else: + t_, s_ = -sigmas[i].log(), -sigmas[i + 1].log() + h = s_ - t_ + h_eta = h * (eta + 1) + x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised + if h_2 is not None: + r0 = h_1 / h + r1 = h_2 / h + d1_0 = (denoised - denoised_1) / r0 + d1_1 = (denoised_1 - denoised_2) / r1 + d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1) + d2 = (d1_0 - d1_1) / (r0 + r1) + phi_2 = h_eta.neg().expm1() / h_eta + 1 + phi_3 = phi_2 / h_eta - 0.5 + x = x + phi_2 * d1 - phi_3 * d2 + elif h_1 is not None: + r = h_1 / h + d = (denoised - denoised_1) / r + phi_2 = h_eta.neg().expm1() / h_eta + 1 + x = x + phi_2 * d + x = ( + x + + noise_sampler(sigmas[i], sigmas[i + 1]) + * sigmas[i + 1] + * (-2 * h * eta).expm1().neg().sqrt() + ) + denoised_1, denoised_2 = denoised, denoised_1 + h_1, h_2 = h, h_1 + sampled = x + + vae = self.pretransform.to(device=sampled.device, dtype=torch.float32).eval() + return vae.decode(sampled.to(torch.float32) * float(vae.audiox_scaling_factor), return_dict=True).sample + + def _encode_text(self, texts: list[str], device: torch.device) -> list[torch.Tensor]: + self.text_encoder.to(device) + encoded = self.tokenizer( + texts, + truncation=True, + max_length=self._t5_max_length, + padding="max_length", + return_tensors="pt", + ) + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + self.text_encoder.eval() + with torch.no_grad(): + embeddings = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"] + + embeddings = embeddings.float() * attention_mask.unsqueeze(-1).float() + return [embeddings, attention_mask] + + def _encode_video(self, video_list: list[dict], device: torch.device) -> list[torch.Tensor]: + self.clip_encoder.to(device).eval() + + video_tensors = [item["video_tensors"] for item in video_list] + video_sync_frames = torch.cat([item["video_sync_frames"] for item in video_list], dim=0).to(device) + + original_videos = torch.cat(video_tensors, dim=0).to(device) + batch_size, time_length, _, _, _ = original_videos.size() + is_zero = torch.all(original_videos == 0, dim=(1, 2, 3, 4)) + + frames = original_videos.flatten(0, 1) + pixel_values = self._clip_normalize(frames).to(device) + + with torch.no_grad(): + outputs = self.clip_encoder(pixel_values=pixel_values) + hidden = outputs.last_hidden_state + hidden = rearrange(hidden, "(b t) p d -> (b p) t d", b=batch_size, t=time_length) + hidden = hidden + self.clip_temp_pos_embedding + hidden = self.clip_temp_transformer(hidden) + hidden = rearrange(hidden, "(b p) t d -> b (t p) d", b=batch_size) + hidden = self.clip_proj(hidden.view(-1, self._clip_in_features)) + hidden = hidden.view(batch_size, self._clip_out_features, -1) + + sync = self.clip_proj_sync(video_sync_frames.view(-1, 240)) + sync = sync.view(batch_size, self._clip_out_features, -1) + hidden = hidden + self.clip_sync_weight * sync + + empty = self.clip_empty_visual_feat.expand(batch_size, -1, -1) + hidden = torch.where(is_zero.view(batch_size, 1, 1), empty, hidden) + return [hidden, torch.ones(batch_size, 1, device=device)] + + def _encode_conditioning_tensors(self, batch_metadata: list[dict[str, Any]]) -> dict[str, Any]: + device = self.device + audio = torch.cat([item["audio_prompt"] for item in batch_metadata], dim=0).to(device) + return { + "audio_prompt": list(self.audio_vae_adapter(audio)), + "text_prompt": self._encode_text([item["text_prompt"] for item in batch_metadata], device), + "video_prompt": self._encode_video([item["video_prompt"] for item in batch_metadata], device), + } + + def _build_conditioning_batch( + self, + *, + texts: list[str], + video_tensors_list: list[torch.Tensor], + audio_prompt_list: list[torch.Tensor], + sync_features: torch.Tensor, + seconds_start: float, + seconds_model: float, + num_outputs_per_prompt: int, + task_norm: str | None, + ) -> list[dict[str, Any]]: + batch: list[dict[str, Any]] = [] + for i, text in enumerate(texts): + for _ in range(num_outputs_per_prompt): + batch.append( + { + "video_prompt": { + "video_tensors": video_tensors_list[i].unsqueeze(0), + "video_sync_frames": sync_features, + }, + "text_prompt": self._text_for_task(task_norm, text), + "audio_prompt": audio_prompt_list[i].unsqueeze(0), + "seconds_start": seconds_start, + "seconds_total": seconds_model, + } + ) + return batch + + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: + if req.prompts is None or len(req.prompts) == 0: + raise ValueError("AudioXPipeline requires at least one prompt.") + normalized_prompts = _normalize_prompts(list(req.prompts)) + prompts = [p["prompt"] for p in normalized_prompts] + + sampling_params = req.sampling_params + if sampling_params.num_inference_steps is None: + raise ValueError("AudioXPipeline requires sampling_params.num_inference_steps.") + num_inference_steps = int(sampling_params.num_inference_steps) + extra_args = sampling_params.extra_args or {} + task_norm = self._normalize_task(extra_args.get("audiox_task")) + self._ensure_text_video_prompts(task_norm, prompts) + + neg: list[str] | None = None + if not all(p.get("negative_prompt") is None for p in normalized_prompts): + neg = [str(p.get("negative_prompt") or "") for p in normalized_prompts] + + guidance_scale = float(sampling_params.guidance_scale) + if sampling_params.num_outputs_per_prompt <= 0: + raise ValueError("AudioXPipeline requires sampling_params.num_outputs_per_prompt > 0.") + num_outputs_per_prompt = int(sampling_params.num_outputs_per_prompt) + batch_size = len(prompts) * num_outputs_per_prompt + + seconds_start = float(extra_args.get("seconds_start", 0.0)) + seconds_model = self._sample_size / self._sample_rate + seconds_total = float(extra_args.get("seconds_total", seconds_model)) + sigma_min = float(extra_args.get("sigma_min", _DEFAULT_UPSTREAM_SIGMA_MIN)) + sigma_max = float(extra_args.get("sigma_max", _DEFAULT_UPSTREAM_SIGMA_MAX)) + cfg_rescale = float(extra_args.get("cfg_rescale", 0.0)) + device = self.device + generator = sampling_params.generator + if generator is None: + raise ValueError("AudioXPipeline requires sampling_params.generator.") + target_fps = self._target_fps + cond_dtype = self._conditioning_dtype() + + sync_features = torch.zeros(1, self._VIDEO_SYNC_FRAME_COUNT, 768, device=device, dtype=cond_dtype) + + audio_prompt_list = self._audio_prompt_tensors( + raw_prompts=normalized_prompts, + extra=extra_args, + device=device, + cond_dtype=cond_dtype, + ) + + video_tensors_list = self._video_feature_tensors( + task_norm=task_norm, + raw_prompts=normalized_prompts, + extra=extra_args, + seconds_start=seconds_start, + target_fps=target_fps, + device=device, + cond_dtype=cond_dtype, + ) + + conditioning_batch = self._build_conditioning_batch( + texts=prompts, + video_tensors_list=video_tensors_list, + audio_prompt_list=audio_prompt_list, + sync_features=sync_features, + seconds_start=seconds_start, + seconds_model=seconds_model, + num_outputs_per_prompt=num_outputs_per_prompt, + task_norm=task_norm, + ) + + negative_conditioning_batch: list[dict[str, Any]] | None = None + if neg is not None and guidance_scale > 1.0: + negative_conditioning_batch = self._build_conditioning_batch( + texts=neg, + video_tensors_list=video_tensors_list, + audio_prompt_list=audio_prompt_list, + sync_features=sync_features, + seconds_start=seconds_start, + seconds_model=seconds_model, + num_outputs_per_prompt=num_outputs_per_prompt, + task_norm=task_norm, + ) + + conditioning_tensors = self._encode_conditioning_tensors(conditioning_batch) + negative_conditioning_tensors: dict[str, Any] | None = None + if negative_conditioning_batch is not None: + negative_conditioning_tensors = self._encode_conditioning_tensors(negative_conditioning_batch) + + audio = self.diffuse( + steps=num_inference_steps, + guidance_scale=guidance_scale, + conditioning_tensors=conditioning_tensors, + negative_conditioning_tensors=negative_conditioning_tensors, + batch_size=batch_size, + sigma_min=sigma_min, + sigma_max=sigma_max, + generator=generator, + cfg_rescale=cfg_rescale, + ) + + # Trim decoded audio to the requested duration (matches upstream AudioX sample script). + if 0.0 < seconds_total < seconds_model: + target_samples = int(seconds_total * self._sample_rate) + audio = audio[..., :target_samples] + + return DiffusionOutput( + output=audio, + custom_output={"audiox_task": task_norm}, + stage_durations=self.stage_durations + if getattr(self, "enable_diffusion_pipeline_profiler", False) + else None, + ) diff --git a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py index a3d4dc517f7..7422f320f67 100644 --- a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py +++ b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py @@ -5,7 +5,6 @@ Stable Audio DiT Model for vLLM-Omni. """ -import math from collections.abc import Iterable import torch @@ -18,6 +17,7 @@ from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.diffusion.distributed.hsdp_utils import is_transformer_block_module +from vllm_omni.diffusion.layers.fourier import GaussianFourierProjection logger = init_logger(__name__) @@ -56,7 +56,7 @@ def apply_rotary_emb_stable_audio( return torch.cat([x_rot, x_pass], dim=-1) -class StableAudioGaussianFourierProjection(nn.Module): +class StableAudioGaussianFourierProjection(GaussianFourierProjection): """Gaussian Fourier embeddings for noise levels. Matches diffusers StableAudioGaussianFourierProjection with: @@ -65,15 +65,12 @@ class StableAudioGaussianFourierProjection(nn.Module): """ def __init__(self, embedding_size: int = 256, scale: float = 1.0): - super().__init__() - self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + super().__init__(in_features=1, embedding_size=embedding_size, scale=scale, trainable=False) def forward(self, x: torch.Tensor) -> torch.Tensor: # x shape: [batch] or [batch, 1] - # Output: [batch, embedding_size * 2] - x_proj = 2 * math.pi * x[:, None] @ self.weight[None, :] - # flip_sin_to_cos=True means cos comes first - return torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + # Output: [batch, embedding_size * 2], with cos first. + return super().forward(x) class StableAudioSelfAttention(nn.Module): diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index d7b058ef21a..44beb4bc1d6 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -205,6 +205,11 @@ "pipeline_dreamid_omni", "DreamIDOmniPipeline", ), + "AudioXPipeline": ( + "audiox", + "pipeline_audiox", + "AudioXPipeline", + ), "HunyuanVideo15Pipeline": ( "hunyuan_video", "pipeline_hunyuan_video_1_5", @@ -251,6 +256,7 @@ _NO_CACHE_ACCELERATION = { # Pipelines that do not support cache acceleration (cache_dit / tea_cache). "NextStep11Pipeline", + "AudioXPipeline", } @@ -426,6 +432,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "LTX23Pipeline": "get_ltx2_post_process_func", "LTX23ImageToVideoPipeline": "get_ltx2_post_process_func", "StableAudioPipeline": "get_stable_audio_post_process_func", + "AudioXPipeline": "get_audiox_post_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_post_process_func", "WanT2VDMD2Pipeline": "get_wan22_post_process_func", "WanI2VDMD2Pipeline": "get_wan22_i2v_post_process_func", diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 54c9d32d9ea..66096de7c0c 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -36,6 +36,7 @@ from vllm_omni.config.stage_config import strip_parent_engine_args from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.diffusion.diffusion_engine import supports_audio_output from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient from vllm_omni.diffusion.stage_diffusion_proc import ( complete_diffusion_handshake, @@ -1324,6 +1325,8 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: num_devices = max(1, int(parallel_config.world_size)) devices = ",".join(str(i) for i in range(num_devices)) + model_class_name = kwargs.get("model_class_name", None) + final_output_type = "audio" if model_class_name and supports_audio_output(model_class_name) else "image" stage_engine_args = { "max_num_seqs": 1, @@ -1380,7 +1383,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "engine_args": stage_engine_args, "default_sampling_params": stage_default_sampling_params, "final_output": True, - "final_output_type": "image", + "final_output_type": final_output_type, } ] default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion" diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 7f36e38fd85..4952e69294a 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -2392,8 +2392,10 @@ async def _create_diffusion_chat_completion( else: messages.append({"role": getattr(msg, "role", "user"), "content": getattr(msg, "content", "")}) - # Extract prompt and images from messages - prompt, reference_images = self._extract_diffusion_prompt_and_images(messages) + # Extract prompt and multimodal inputs from messages + prompt, reference_images, reference_videos, reference_audios = self._extract_diffusion_prompt_and_media( + messages + ) # Extract generation parameters from extra_body (preferred) # Reference: text_to_image.py and text_to_video.py for supported parameters @@ -2484,6 +2486,14 @@ async def _create_diffusion_chat_completion( if resolution is not None: gen_params.resolution = resolution + # Pipeline-agnostic escape hatch (mirrors ``extra_params`` on the /v1/videos + # endpoint in ``serving_video.py``): a single reserved ``extra_args`` dict in + # ``extra_body`` flows straight into ``gen_params.extra_args``, with no keys + # hardcoded here. + extra_args_body = extra_body.get("extra_args") + if isinstance(extra_args_body, dict): + gen_params.extra_args.update(extra_args_body) + # Parse per-request LoRA. if lora_body and isinstance(lora_body, dict): try: @@ -2517,7 +2527,12 @@ async def _create_diffusion_chat_completion( status_code=400, ) - # Generate image + if reference_videos: + gen_params.extra_args["video_path"] = reference_videos[0] + if reference_audios: + gen_params.extra_args["audio_path"] = reference_audios[0] + + # Generate image or audio (e.g. AudioX) via AsyncOmni diffusion_engine = cast(AsyncOmni, self._diffusion_engine) stage_configs = list(getattr(diffusion_engine, "stage_configs", []) or []) sampling_params_list = build_stage_sampling_params_list( @@ -2539,54 +2554,112 @@ async def _create_diffusion_chat_completion( result = output if result is None: return self._create_error_response("No output generated from AsyncOmni") - # Extract images from result + final_output_type = getattr(result, "final_output_type", "image") # Handle nested OmniRequestOutput structure where images might be in request_output images = getattr(result.request_output, "images", []) + multimodal_output = getattr(result, "multimodal_output", {}) or {} stage_durations = result.stage_durations peak_memory_mb = result.peak_memory_mb - # Convert images to base64 content - image_contents: list[dict[str, Any]] = [] - flat_images = [] - for item in images: - if isinstance(item, list): - flat_images.extend(item) + if final_output_type == "audio": + sample_rate = 48000 + for key in ("audio_sample_rate", "sample_rate", "sampling_rate", "sr"): + raw_rate = multimodal_output.get(key) + try: + if raw_rate is not None: + sample_rate = int(raw_rate) + break + except (TypeError, ValueError): + pass + + audio_payload = multimodal_output.get("audio") + if isinstance(audio_payload, list): + if len(audio_payload) == 0: + audio_payload = None + elif len(audio_payload) == 1: + audio_payload = audio_payload[0] + if audio_payload is None: + return self._create_error_response("Audio generation completed but no audio was produced.") + + if isinstance(audio_payload, torch.Tensor): + audio_tensor = audio_payload.detach().cpu().float() else: - flat_images.append(item) - - for img in flat_images: - with BytesIO() as buffer: - img.save(buffer, format="PNG") - img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") - image_contents.append( - { - "type": "image_url", - "image_url": { - "url": f"data:image/png;base64,{img_base64}", - }, - "stage_durations": stage_durations, - "peak_memory_mb": peak_memory_mb, - } + audio_tensor = torch.as_tensor(audio_payload).detach().cpu().float() + # Pipelines deliver audio as (C, T), (T,), or (B, C, T) in channels-first + # convention (torch default). Drop a leading batch dim, then transpose to + # (T, C) for soundfile / CreateAudio. Flattening here would corrupt stereo. + if audio_tensor.ndim == 3: + audio_tensor = audio_tensor[0] + if audio_tensor.ndim == 2: + audio_tensor = audio_tensor.transpose(0, 1).contiguous() + elif audio_tensor.ndim > 3: + raise ValueError(f"Unexpected audio tensor rank {audio_tensor.ndim}; expected 1-3 dims.") + audio_array = audio_tensor.numpy() + + audio_obj = CreateAudio( + audio_tensor=audio_array, + sample_rate=sample_rate, + response_format="wav", + speed=1.0, + stream_format="audio", + base64_encode=True, + ) + audio_response: AudioResponse = self.create_audio(audio_obj) + audio_base64 = audio_response.audio_data + audio_id = f"audio-{uuid.uuid4().hex[:16]}" + expires_at = int((datetime.now(timezone.utc) + timedelta(hours=24)).timestamp()) + message = ChatMessage( + role="assistant", + audio=OpenAIChatCompletionAudio( + id=audio_id, + data=audio_base64, + expires_at=expires_at, + transcript="", + ), ) - - # Build response - if not image_contents: - content = "Image generation completed but no images were produced." else: - content = image_contents - - # Use model_construct to bypass validation for multimodal content - # (ChatMessage.content only accepts str, but we need list for images) - # Then use object.__setattr__ to directly set the field, bypassing Pydantic's type checking - import warnings as warnings_module - - with warnings_module.catch_warnings(): - warnings_module.filterwarnings("ignore", category=UserWarning, module="pydantic") - message = ChatMessage.model_construct(role="assistant") - object.__setattr__(message, "content", content) - # Mark content as set in fields_set to ensure proper serialization - if hasattr(message, "__pydantic_fields_set__"): - message.__pydantic_fields_set__.add("content") + # Convert images to base64 content + image_contents: list[dict[str, Any]] = [] + flat_images = [] + for item in images: + if isinstance(item, list): + flat_images.extend(item) + else: + flat_images.append(item) + + for img in flat_images: + with BytesIO() as buffer: + img.save(buffer, format="PNG") + img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") + image_contents.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{img_base64}", + }, + "stage_durations": stage_durations, + "peak_memory_mb": peak_memory_mb, + } + ) + + # Build response + if not image_contents: + content = "Image generation completed but no images were produced." + else: + content = image_contents + + # Use model_construct to bypass validation for multimodal content + # (ChatMessage.content only accepts str, but we need list for images) + # Then use object.__setattr__ to directly set the field, bypassing Pydantic's type checking + import warnings as warnings_module + + with warnings_module.catch_warnings(): + warnings_module.filterwarnings("ignore", category=UserWarning, module="pydantic") + message = ChatMessage.model_construct(role="assistant") + object.__setattr__(message, "content", content) + # Mark content as set in fields_set to ensure proper serialization + if hasattr(message, "__pydantic_fields_set__"): + message.__pydantic_fields_set__.add("content") choice = ChatCompletionResponseChoice.model_construct( index=0, message=message, @@ -2608,8 +2681,9 @@ async def _create_diffusion_chat_completion( ) logger.info( - "Diffusion chat completed for request %s: %d images", + "Diffusion chat completed for request %s: output_type=%s, image_count=%d", request_id, + final_output_type, len(images), ) @@ -2622,20 +2696,22 @@ async def _create_diffusion_chat_completion( status_code=500, ) - def _extract_diffusion_prompt_and_images( + def _extract_diffusion_prompt_and_media( self, messages: list[dict[str, Any]], - ) -> tuple[str, list[str]]: - """Extract text prompt and base64 images from chat messages. + ) -> tuple[str, list[str], list[str], list[str]]: + """Extract text prompt and multimodal inputs from chat messages. Args: messages: List of chat messages Returns: - Tuple of (prompt_text, list_of_base64_images) + Tuple of (prompt_text, list_of_base64_images, list_of_video_urls, list_of_audio_urls) """ prompt_parts: list[str] = [] images: list[str] = [] + videos: list[str] = [] + audios: list[str] = [] for message in messages: role = message.get("role", "") @@ -2670,19 +2746,32 @@ def _extract_diffusion_prompt_and_images( images.append(b64_data) except ValueError: logger.warning("Invalid data URL format") + elif item.get("type") == "video_url": + url = item.get("video_url", {}).get("url", "") + if isinstance(url, str) and url: + videos.append(url) + elif item.get("type") == "audio_url": + url = item.get("audio_url", {}).get("url", "") + if isinstance(url, str) and url: + audios.append(url) # Handle {"image": "base64..."} format elif "image" in item: images.append(item["image"]) + elif "video" in item and isinstance(item["video"], str): + videos.append(item["video"]) + elif "audio" in item and isinstance(item["audio"], str): + audios.append(item["audio"]) prompt = " ".join(prompt_parts).strip() - return prompt, images + return prompt, images, videos, audios def _extract_diffusion_prompt_and_images_from_messages( self, messages: list[Any], ) -> tuple[str, list[str]]: """Normalize mixed message types and extract prompt + reference images once.""" - return self._extract_diffusion_prompt_and_images(self._messages_to_dicts(messages)) + prompt, images, _videos, _audios = self._extract_diffusion_prompt_and_media(self._messages_to_dicts(messages)) + return prompt, images @staticmethod def _messages_to_dicts(messages: list[Any]) -> list[dict[str, Any]]: diff --git a/vllm_omni/transformers_utils/processors/audiox.py b/vllm_omni/transformers_utils/processors/audiox.py new file mode 100644 index 00000000000..b0e53cfc7c5 --- /dev/null +++ b/vllm_omni/transformers_utils/processors/audiox.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Input transform utilities for the AudioX diffusion pipeline. + +Loads and normalizes the raw audio/video conditioning signals (file path / URL / +``data:`` URI / ``np.ndarray`` / ``torch.Tensor``) into the (channels, samples) and +[T, C, H, W] tensors the pipeline needs. The pipeline itself stays focused on model +forward + sampling logic. +""" + +from __future__ import annotations + +import base64 +import os +import tempfile +from typing import Any +from urllib.parse import urlparse +from urllib.request import urlopen + +import av +import numpy as np +import soundfile +import torch +import torch.nn.functional as F +import torchaudio.functional as taF +from einops import rearrange +from torchvision.io import read_image + +# AudioX task taxonomy. Tasks beginning with "v" require a video input; tasks containing +# "t" carry a text prompt. tv2*/v2* share the same conditioning pathways. +VIDEO_ONLY_TASKS = frozenset({"v2a", "v2m"}) +TEXT_VIDEO_TASKS = frozenset({"tv2a", "tv2m"}) +VIDEO_CONDITIONED_TASKS = VIDEO_ONLY_TASKS | TEXT_VIDEO_TASKS + +_IMAGE_EXTS = frozenset({".jpg", ".jpeg", ".png"}) + + +def normalize_prompts(prompts: list[Any]) -> list[dict[str, Any]]: + """Coerce raw prompt entries into ``{"prompt": str, ...}`` dicts (preserves extras).""" + out: list[dict[str, Any]] = [] + for i, raw in enumerate(prompts): + if isinstance(raw, str): + out.append({"prompt": raw.strip()}) + elif isinstance(raw, dict): + p = dict(raw) + p["prompt"] = str(p.get("prompt") or "").strip() + out.append(p) + else: + raise TypeError(f"AudioX prompt {i} must be str or dict, got {type(raw)!r}") + return out + + +def materialize_media_source(source: str) -> str: + """Return a local filesystem path for ``source``. + + Accepts a local path, a ``data:;base64,...`` URI, or an ``http(s)://`` URL. + Anything non-local is fetched into a NamedTemporaryFile and that path is returned; + callers don't need to clean the tempfile up (the OS does on exit). + """ + if source.startswith("data:"): + _, _, payload = source.partition(",") + raw = base64.b64decode(payload) + f = tempfile.NamedTemporaryFile(prefix="audiox_media_", suffix=".bin", delete=False) + f.write(raw) + f.close() + return f.name + parsed = urlparse(source) + if parsed.scheme in ("http", "https"): + with urlopen(source) as resp: + data = resp.read() + f = tempfile.NamedTemporaryFile(prefix="audiox_media_", suffix=".bin", delete=False) + f.write(data) + f.close() + return f.name + return source + + +def _load_video_path_pyav( + path: str, + *, + target_fps: int, + duration: float, + seek_time: float, +) -> torch.Tensor: + path = materialize_media_source(path) + seek_time = float(seek_time) + duration = float(duration) + end_time = seek_time + duration if duration > 0 else None + + with av.open(path) as container: + stream = container.streams.video[0] + time_base = float(stream.time_base) + src_fps = float(stream.average_rate or target_fps) + + if seek_time > 0: + container.seek(int(seek_time / time_base), stream=stream) + + frames: list[np.ndarray] = [] + for frame in container.decode(stream): + t = float(frame.pts) * time_base + if t < seek_time: + continue + if end_time is not None and t >= end_time: + break + frames.append(frame.to_ndarray(format="rgb24")) + + if not frames: + raise ValueError(f"No frames in range seek_time={seek_time!r}, duration={duration!r} for {path!r}") + + # PyAV gives [H, W, C] uint8 RGB per frame; AudioX expects [T, C, H, W]. + video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).contiguous() + n_target = max(1, int(round(video.shape[0] * float(target_fps) / src_fps))) + if n_target >= video.shape[0]: + return video + indices = torch.linspace(0, video.shape[0] - 1, n_target).round().long() + return video[indices] + + +def load_video_source( + source: Any, + *, + target_fps: int, + duration: float, + seek_time: float = 0.0, +) -> torch.Tensor: + if isinstance(source, str): + ext = os.path.splitext(source)[1].lower() + if ext in _IMAGE_EXTS: + return read_image(materialize_media_source(source)).unsqueeze(0) + return _load_video_path_pyav( + source, + target_fps=target_fps, + duration=duration, + seek_time=seek_time, + ) + + if isinstance(source, torch.Tensor): + return source + if isinstance(source, np.ndarray): + return torch.from_numpy(source) + raise TypeError(f"Unsupported video source type: {type(source)!r}") + + +def normalize_video_tensor(frames: torch.Tensor, size: int = 224) -> torch.Tensor: + if frames.dim() != 4: + raise ValueError(f"Expected [T, C, H, W], got {tuple(frames.shape)}") + + frames = frames.float() + if frames.max() > 1.5: + frames = frames / 255.0 + + if frames.shape[-2:] != (size, size): + frames = F.interpolate(frames, size=(size, size), mode="bicubic", align_corners=False) + return frames + + +def adjust_video_duration(frames: torch.Tensor, duration: float, target_fps: int) -> torch.Tensor: + target_t = int(duration * target_fps) + cur_t = frames.shape[0] + + if cur_t > target_t: + return frames[:target_t] + if cur_t < target_t: + last = frames[-1:].repeat(target_t - cur_t, 1, 1, 1) + return torch.cat([frames, last], dim=0) + return frames + + +def prepare_video_reference( + source: Any, + *, + duration: float, + target_fps: int, + seek_time: float = 0.0, +) -> torch.Tensor: + """Decode a video clip (or single image) into the AudioX [T, 3, 224, 224] form.""" + frames = load_video_source( + source, + target_fps=target_fps, + duration=duration, + seek_time=seek_time, + ) + + if frames.dim() == 4 and frames.shape[-1] == 3: + frames = rearrange(frames, "t h w c -> t c h w") + + frames = normalize_video_tensor(frames, size=224) + if duration > 0: + frames = adjust_video_duration(frames, duration, target_fps) + return frames + + +def prepare_audio_reference( + source: Any, + *, + model_sample_rate: int, + seconds_start: float, + seconds_total: float, + device: torch.device, +) -> torch.Tensor: + """Decode an audio source into a stereo (2, samples) tensor at the model's rate.""" + target_len = int(model_sample_rate * seconds_total) + start = int(model_sample_rate * seconds_start) + if isinstance(source, str): + data, sr = soundfile.read(materialize_media_source(source), dtype="float32", always_2d=True) + # soundfile returns channels-last (T, C); project convention is (C, T). + wav = torch.from_numpy(data).transpose(0, 1).contiguous() + if sr != model_sample_rate: + wav = taF.resample(wav, sr, model_sample_rate) + elif isinstance(source, torch.Tensor): + wav = source.float() + elif isinstance(source, np.ndarray): + wav = torch.from_numpy(source).float() + else: + raise TypeError(f"Unsupported audio source type: {type(source)!r}") + if wav.dim() == 1: + wav = wav.unsqueeze(0) + if wav.shape[0] == 1: + wav = wav.repeat(2, 1) + elif wav.shape[0] > 2: + wav = wav[:2] + wav = wav[:, start : start + target_len] + if wav.shape[1] < target_len: + wav = F.pad(wav, (target_len - wav.shape[1], 0)) + return wav.to(device=device, dtype=torch.float32) + + +__all__ = [ + "VIDEO_ONLY_TASKS", + "TEXT_VIDEO_TASKS", + "VIDEO_CONDITIONED_TASKS", + "normalize_prompts", + "materialize_media_source", + "load_video_source", + "normalize_video_tensor", + "adjust_video_duration", + "prepare_video_reference", + "prepare_audio_reference", +]