diff --git a/docs/serving/audio_generate_api.md b/docs/serving/audio_generate_api.md new file mode 100644 index 00000000000..e7eaef1860b --- /dev/null +++ b/docs/serving/audio_generate_api.md @@ -0,0 +1,338 @@ +# Audio Generate API + +vLLM-Omni provides an API for text-to-audio generation using diffusion-based models such as Stable Audio. + +Unlike the [Speech API](speech_api.md) which targets text-to-speech synthesis, the Audio Generate API is designed for general-purpose audio generation from text descriptions (sound effects, music, ambient soundscapes, etc.). + +Each server instance runs a single model (specified at startup via `vllm-omni serve --omni`). + +## Quick Start + +### Start the Server + +```bash +vllm-omni serve stabilityai/stable-audio-open-1.0 \ + --host 0.0.0.0 \ + --port 8000 \ + --gpu-memory-utilization 0.9 \ + --trust-remote-code \ + --enforce-eager \ + --omni +``` + +### Generate Audio + +**Using curl:** + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "The sound of a cat purring", + "audio_length": 10.0 + }' --output cat.wav +``` + +**Using Python:** + +```python +import httpx + +response = httpx.post( + "http://localhost:8000/v1/audio/generate", + json={ + "input": "The sound of a cat purring", + "audio_length": 10.0, + }, + timeout=300.0, +) + +with open("cat.wav", "wb") as f: + f.write(response.content) +``` + +## API Reference + +### Endpoint + +``` +POST /v1/audio/generate +Content-Type: application/json +``` + +### Request Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `input` | string | **required** | Text prompt describing the audio to generate | +| `model` | string | server's model | Model to use (optional, should match server if specified) | +| `response_format` | string | "wav" | Audio format: wav, mp3, flac, pcm, aac, opus | +| `speed` | float | 1.0 | Playback speed (0.25 - 4.0) | + +#### 
Diffusion Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `audio_length` | float | null | Audio duration in seconds (default value is the max ~47s for `stable-audio-open-1.0`) | +| `audio_start` | float | 0.0 | Audio start time in seconds | +| `negative_prompt` | string | null | Text describing what to avoid in generation | +| `guidance_scale` | float | model default | Classifier-free guidance scale (higher = more adherence to prompt) | +| `num_inference_steps` | int | model default | Number of denoising steps (higher = better quality, slower) | +| `seed` | int | null | Random seed for reproducible generation | + +### Response Format + +Returns binary audio data with the appropriate `Content-Type` header: + +| `response_format` | Content-Type | +|--------------------|--------------| +| `wav` | `audio/wav` | +| `mp3` | `audio/mpeg` | +| `flac` | `audio/flac` | +| `pcm` | `audio/pcm` | +| `aac` | `audio/aac` | +| `opus` | `audio/opus` | + +## Examples + +### Basic Generation + +Generate audio with only a text prompt (model defaults for all other parameters): + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "The sound of ocean waves crashing on a beach" + }' --output ocean.wav +``` + +### Custom Duration + +Specify an explicit audio length in seconds: + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "A dog barking", + "audio_length": 5.0 + }' --output dog_5s.wav +``` + +### High Quality with Negative Prompt + +Use a negative prompt to steer generation away from undesired characteristics, and increase inference steps for higher quality: + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "A piano playing a gentle melody", + "audio_length": 10.0, + "negative_prompt": "Low quality, 
distorted, noisy", + "guidance_scale": 8.0, + "num_inference_steps": 150 + }' --output piano_hq.wav +``` + +### Reproducible Generation + +Set a `seed` to get deterministic results across runs: + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "Thunder and rain sounds", + "audio_length": 15.0, + "seed": 42 + }' --output thunder.wav +``` + +### Full Control + +Combine all parameters for precise control over generation: + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "Thunder and rain sounds", + "audio_length": 15.0, + "negative_prompt": "Low quality", + "guidance_scale": 7.0, + "num_inference_steps": 100, + "seed": 42 + }' --output thunder_rain.wav +``` + +### Quick Generation (Fewer Steps) + +For faster generation with slightly lower quality: + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "Birds chirping in a forest", + "audio_length": 8.0, + "num_inference_steps": 50 + }' --output birds_quick.wav +``` + +### Python Client + +```python +import httpx + +response = httpx.post( + "http://localhost:8000/v1/audio/generate", + json={ + "input": "Thunder and rain", + "audio_length": 15.0, + "negative_prompt": "Low quality", + "guidance_scale": 7.0, + "num_inference_steps": 100, + "seed": 42, + "response_format": "wav", + }, + timeout=300.0, +) + +with open("thunder.wav", "wb") as f: + f.write(response.content) +``` + +## Parameter Tuning Guide + +### `guidance_scale` + +Controls how closely the generated audio follows the text prompt. + +| Range | Behaviour | +|-------|-----------| +| 3 - 5 | More creative / varied output | +| 7 (default) | Balanced adherence | +| 10+ | Strict adherence to the prompt | + +### `num_inference_steps` + +Controls the number of denoising steps in the diffusion process. 
+ +| Steps | Quality | Speed | Use Case | +|-------|---------|-------|----------| +| 50 | Good | Fast | Quick previews | +| 100 | Very Good | Medium | General purpose | +| 150+ | Excellent | Slow | Final / critical audio | + +### `audio_length` + +Duration of the generated audio clip. For `stable-audio-open-1.0`, the maximum is approximately 47 seconds. If omitted, the model uses its own default length. + +### `negative_prompt` + +Describes characteristics to avoid. Common negative prompts include: + +- `"Low quality, distorted, noisy"` +- `"Silence, static"` +- `"Music"` (when generating sound effects only) + +## Supported Models + +| Model | Description | +|-------|-------------| +| `stabilityai/stable-audio-open-1.0` | Open-source audio generation model, up to ~47 seconds, 44.1 kHz stereo | + +## Error Responses + +### 400 Bad Request + +Invalid or missing parameters: + +```json +{ + "error": { + "message": "Audio generation model did not produce audio output.", + "type": "BadRequestError", + "param": null, + "code": 400 + } +} +``` + +### 404 Not Found + +Model mismatch: + +```json +{ + "error": { + "message": "The model `xxx` does not exist.", + "type": "NotFoundError", + "param": "model", + "code": 404 + } +} +``` + +### 422 Unprocessable Entity + +Pydantic validation failure (e.g. invalid `response_format`, `speed` out of range): + +```json +{ + "detail": [ + { + "type": "literal_error", + "msg": "Input should be 'wav', 'pcm', 'flac', 'mp3', 'aac' or 'opus'", + ... + } + ] +} +``` + +## Troubleshooting + +### "Audio generation model did not produce audio output" + +The model finished but returned no audio data. Verify the server started successfully and the model loaded without errors. + +### Server Not Responding + +```bash +# Check if the server is healthy +curl http://localhost:8000/health +``` + +### Audio Quality Issues + +- Increase `num_inference_steps` (e.g. 150). +- Add a negative prompt: `"Low quality, distorted, noisy"`. 
+- Increase `guidance_scale` for stronger prompt adherence. + +### Generation Timeout + +- Reduce `num_inference_steps`. +- Reduce `audio_length`. +- Check GPU memory with `nvidia-smi`. + +### Out of Memory + +- Lower `--gpu-memory-utilization` (e.g. 0.8). +- Reduce `audio_length`. + +## Development + +Enable debug logging: + +```bash +vllm-omni serve stabilityai/stable-audio-open-1.0 \ + --host 0.0.0.0 \ + --port 8000 \ + --gpu-memory-utilization 0.9 \ + --trust-remote-code \ + --enforce-eager \ + --omni \ + --uvicorn-log-level debug +``` diff --git a/docs/user_guide/examples/online_serving/text_to_audio.md b/docs/user_guide/examples/online_serving/text_to_audio.md new file mode 100644 index 00000000000..b84bde6c482 --- /dev/null +++ b/docs/user_guide/examples/online_serving/text_to_audio.md @@ -0,0 +1,193 @@ +# Text-To-Audio + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/stable_audio>. + +This example demonstrates how to deploy Stable Audio models for online text-to-audio generation using vLLM-Omni. + +## Supported Models + +| Model | Description | +|-------|-------------| +| `stabilityai/stable-audio-open-1.0` | Open-source audio generation, up to ~47 seconds, 44.1 kHz stereo | + +## Start Server + +### Basic Start + +```bash +vllm-omni serve stabilityai/stable-audio-open-1.0 \ + --host 0.0.0.0 \ + --port 8000 \ + --gpu-memory-utilization 0.9 \ + --trust-remote-code \ + --enforce-eager \ + --omni +``` + +## API Calls + +### Method 1: Using curl + +```bash +# Run all curl examples +bash curl_examples.sh + +# Or execute directly +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "The sound of a cat purring", + "audio_length": 10.0 + }' --output cat.wav +``` + +### Method 2: Using Python Client + +```bash +cd examples/online_serving/stable_audio + +# Simple generation +python stable_audio_client.py \ + --text "The sound of a cat purring" + +# With custom duration +python stable_audio_client.py \ + --text "A dog barking" \ + --audio_length 5.0 + +#
With all parameters +python stable_audio_client.py \ + --text "Thunder and rain" \ + --audio_length 15.0 \ + --negative_prompt "Low quality" \ + --guidance_scale 7.0 \ + --num_inference_steps 100 \ + --seed 42 \ + --output thunder.wav +``` + +The Python client supports the following command-line arguments: + +- `--api_url`: API endpoint URL (default: `http://localhost:8000/v1/audio/generate`) +- `--text`: Text prompt for audio generation (default: `"The sound of a cat purring"`) +- `--audio_length`: Audio length in seconds (default: `10.0`, max ~47s for `stable-audio-open-1.0`) +- `--audio_start`: Audio start time in seconds (default: `0.0`) +- `--negative_prompt`: Negative prompt for classifier-free guidance (default: `"Low quality"`) +- `--guidance_scale`: Guidance scale for diffusion (default: `7.0`) +- `--num_inference_steps`: Number of inference steps (default: `100`) +- `--seed`: Random seed for reproducibility (default: `None`) +- `--response_format`: Audio output format (default: `wav`). 
Options: `wav`, `mp3`, `flac`, `pcm` +- `--output`: Output file path (default: `stable_audio_output.wav`) + +### Method 3: Using Python httpx + +```python +import httpx + +response = httpx.post( + "http://localhost:8000/v1/audio/generate", + json={ + "input": "The sound of ocean waves crashing on a beach", + "audio_length": 10.0, + "negative_prompt": "Low quality, distorted", + "guidance_scale": 7.0, + "num_inference_steps": 100, + }, + timeout=300.0, +) + +with open("ocean.wav", "wb") as f: + f.write(response.content) +``` + +## Request Format + +### Simple Generation + +```json +{ + "input": "The sound of ocean waves" +} +``` + +### Generation with Parameters + +```json +{ + "input": "A piano playing a gentle melody", + "audio_length": 10.0, + "negative_prompt": "Low quality, distorted, noisy", + "guidance_scale": 8.0, + "num_inference_steps": 150, + "seed": 42, + "response_format": "wav" +} +``` + +## API Reference + +### Endpoint + +``` +POST /v1/audio/generate +Content-Type: application/json +``` + +### Generation Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `input` | string | **required** | Text prompt describing the audio to generate | +| `model` | string | server's model | Model to use (optional, should match server if specified) | +| `response_format` | string | "wav" | Audio format: wav, mp3, flac, pcm, aac, opus | +| `speed` | float | 1.0 | Playback speed (0.25 - 4.0) | +| `audio_length` | float | null | Audio duration in seconds (max ~47s for `stable-audio-open-1.0`) | +| `audio_start` | float | 0.0 | Audio start time in seconds | +| `negative_prompt` | string | null | Text describing what to avoid in generation | +| `guidance_scale` | float | model default | Classifier-free guidance scale (higher = more adherence to prompt) | +| `num_inference_steps` | int | model default | Number of denoising steps (higher = better quality, slower) | +| `seed` | int | null | Random seed for reproducible 
generation | + +### Response Format + +Returns binary audio data with appropriate `Content-Type` header (e.g., `audio/wav`). + +## Tuning Tips + +1. **Audio Length**: Keep under 47 seconds for `stable-audio-open-1.0`. +2. **Quality vs Speed**: + - 50 steps: Fast, decent quality (quick previews) + - 100 steps: Good balance (general purpose) + - 150+ steps: High quality, slower (final / critical audio) +3. **Guidance Scale**: + - Lower (3 - 5): More creative / varied output + - Default (7): Good balance + - Higher (10+): Strict adherence to the prompt +4. **Negative Prompts**: Use to avoid unwanted characteristics such as `"Low quality"`, `"distorted"`, `"noisy"`. +5. **Seeds**: Set a fixed seed to get deterministic, reproducible results. + +## File Description + +| File | Description | +|------|-------------| +| `curl_examples.sh` | Curl examples covering common use cases | +| `stable_audio_client.py` | Python client with full CLI argument support | + +## Troubleshooting + +1. **Audio generation model did not produce audio output**: Verify the server started successfully and the model loaded without errors. +2. **Connection refused**: Make sure the server is running on the correct port. +3. **Generation timeout**: Reduce `num_inference_steps` or `audio_length`, and check GPU memory with `nvidia-smi`. +4. **Out of memory**: Lower `--gpu-memory-utilization` or reduce `audio_length`. +5. **Audio quality issues**: Increase `num_inference_steps`, add a negative prompt, or raise `guidance_scale`. + +## Example materials + +??? abstract "stable_audio_client.py" + ``````py + --8<-- "examples/online_serving/stable_audio/stable_audio_client.py" + `````` +??? 
abstract "curl_examples.sh" + ``````sh + --8<-- "examples/online_serving/stable_audio/curl_examples.sh" + `````` diff --git a/examples/online_serving/stable_audio/README.md b/examples/online_serving/stable_audio/README.md new file mode 100644 index 00000000000..c8ffaadcaae --- /dev/null +++ b/examples/online_serving/stable_audio/README.md @@ -0,0 +1,234 @@ +# Stable Audio Online Serving + +Generate audio from text prompts using Stable Audio models via an OpenAI-compatible API endpoint. + +## Features + +- **OpenAI-compatible API**: Use `/v1/audio/generate` endpoint +- **Flexible control**: Adjust audio length, guidance scale, inference steps +- **Quality control**: Use negative prompts to avoid unwanted characteristics +- **Reproducible**: Set random seed for deterministic generation + +## Quick Start + +### 1. Start the Server + +```bash +vllm-omni serve stabilityai/stable-audio-open-1.0 \ + --host 0.0.0.0 \ + --port 8000 \ + --gpu-memory-utilization 0.9 \ + --trust-remote-code \ + --enforce-eager \ + --omni +``` + +### 2. 
Generate Audio + +#### Using curl + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "The sound of a cat purring", + "audio_length": 10.0 + }' --output cat.wav +``` + +#### Using Python Client + +```bash +python stable_audio_client.py \ + --text "The sound of a cat purring" \ + --audio_length 10.0 \ + --output cat.wav +``` + +#### Using Bash Script + +```bash +bash curl_examples.sh +``` + +## API Reference + +### Endpoint + +``` +POST /v1/audio/generate +``` + +### Request Body + +```json +{ + "input": "Text description of the audio", + "audio_length": 10.0, + "audio_start": 0.0, + "negative_prompt": "Low quality", + "guidance_scale": 7.0, + "num_inference_steps": 100, + "seed": 42, + "response_format": "wav" +} +``` + +### Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `input` | string | **required** | Text prompt describing the audio to generate | +| `audio_length` | float | ~47s | Audio duration in seconds (max ~47s for stable-audio-open-1.0) | +| `audio_start` | float | 0.0 | Audio start time in seconds | +| `negative_prompt` | string | null | Text describing what to avoid in generation | +| `guidance_scale` | float | 7.0 | Classifier-free guidance scale (higher = more adherence to prompt) | +| `num_inference_steps` | int | model default | Number of denoising steps (higher = better quality, slower) | +| `seed` | int | null | Random seed for reproducibility | +| `response_format` | string | "wav" | Output format: wav, mp3, flac, pcm, aac, opus | + +### Response + +Returns audio data in the requested format (default: WAV).
+ +## Usage Examples + +### Basic Generation + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "The sound of ocean waves" + }' --output ocean.wav +``` + +### Custom Duration + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "A dog barking", + "audio_length": 5.0 + }' --output dog_5s.wav +``` + +### High Quality with Negative Prompt + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "A piano playing a gentle melody", + "audio_length": 10.0, + "negative_prompt": "Low quality, distorted, noisy", + "guidance_scale": 8.0, + "num_inference_steps": 150 + }' --output piano_hq.wav +``` + +### Reproducible Generation + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "Thunder and rain sounds", + "audio_length": 15.0, + "seed": 42 + }' --output thunder.wav +``` + +### Quick Generation (Fewer Steps) + +For faster generation with slightly lower quality: + +```bash +curl -X POST http://localhost:8000/v1/audio/generate \ + -H "Content-Type: application/json" \ + -d '{ + "input": "Birds chirping in a forest", + "audio_length": 8.0, + "num_inference_steps": 50 + }' --output birds_quick.wav +``` + +## Python Client Examples + +### Simple Generation + +```bash +python stable_audio_client.py \ + --text "The sound of a cat purring" +``` + +### Custom Parameters + +```bash +python stable_audio_client.py \ + --text "Thunder and rain" \ + --audio_length 15.0 \ + --negative_prompt "Low quality" \ + --guidance_scale 7.0 \ + --num_inference_steps 100 \ + --seed 42 \ + --output thunder.wav +``` + +### Different Output Format + +```bash +python stable_audio_client.py \ + --text "Guitar playing" \ + --response_format mp3 \ + --output guitar.mp3 +``` + +## Tips + +1. 
#!/bin/bash
# Examples for using Stable Audio with curl via /v1/audio/generate endpoint
#
# Usage: bash curl_examples.sh
# Set API_URL to target a server other than the local default.

# Abort on the first failing request so the closing message is only printed
# when every example actually succeeded.
set -euo pipefail

API_URL="${API_URL:-http://localhost:8000/v1/audio/generate}"

# generate <output-file> <json-body>
# POST the JSON body to the endpoint and save the returned audio.
# --fail makes curl exit non-zero on HTTP 4xx/5xx instead of writing the
# JSON error body into a file with an audio extension.
generate() {
    local output_file="$1"
    local body="$2"
    curl --fail -X POST "${API_URL}" \
        -H "Content-Type: application/json" \
        -d "${body}" \
        --output "${output_file}"
}

# Example 1: Simple request with default parameters
echo "Example 1: Simple request with default parameters"
generate stadium.wav '{
    "input": "The sound audience clapping and cheering in a stadium"
}'

# Example 2: Request with custom audio_length
echo "Example 2: Custom audio length (5 seconds)"
generate dog_5s.wav '{
    "input": "The sound of a dog barking",
    "audio_length": 5.0
}'

# Example 3: Request with negative prompt for quality control
echo "Example 3: With negative prompt"
generate piano.wav '{
    "input": "A piano playing a gentle melody",
    "audio_length": 10.0,
    "negative_prompt": "Low quality, distorted, noisy"
}'

# Example 4: Full control with all parameters
echo "Example 4: Full control (custom length, guidance, steps, seed)"
generate thunder_rain.wav '{
    "input": "Thunder and rain sounds",
    "audio_length": 15.0,
    "negative_prompt": "Low quality",
    "guidance_scale": 7.0,
    "num_inference_steps": 100,
    "seed": 42
}'

# Example 5: Quick generation with fewer steps (faster but lower quality)
echo "Example 5: Quick generation (fewer steps)"
generate ocean.wav '{
    "input": "Ocean waves crashing on a beach",
    "audio_length": 8.0,
    "num_inference_steps": 50
}'

echo "All examples completed!"
def parse_args(argv=None):
    """Parse command-line options for the Stable Audio example client.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour;
            passing a list lets callers and tests drive the parser directly.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Generate audio with Stable Audio via OpenAI-compatible API")
    parser.add_argument(
        "--api_url",
        default="http://localhost:8000/v1/audio/generate",
        help="API endpoint URL",
    )
    parser.add_argument(
        "--text",
        default="The sound of a cat purring",
        help="Text prompt for audio generation",
    )
    parser.add_argument(
        "--audio_length",
        type=float,
        default=10.0,
        help="Audio length in seconds (max ~47s for stable-audio-open-1.0)",
    )
    parser.add_argument(
        "--audio_start",
        type=float,
        default=0.0,
        help="Audio start time in seconds",
    )
    parser.add_argument(
        "--negative_prompt",
        default="Low quality",
        help="Negative prompt for classifier-free guidance",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=7.0,
        help="Guidance scale for diffusion (higher = more adherence to prompt)",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=100,
        help="Number of inference steps (higher = better quality, slower)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--output",
        default="stable_audio_output.wav",
        help="Output file path",
    )
    parser.add_argument(
        "--response_format",
        default="wav",
        choices=["wav", "mp3", "flac", "pcm"],
        help="Audio output format",
    )
    return parser.parse_args(argv)


def _build_payload(args):
    """Translate parsed CLI options into the JSON body for the generate endpoint."""
    payload = {
        "input": args.text,
        "audio_length": args.audio_length,
        "audio_start": args.audio_start,
        "response_format": args.response_format,
    }

    # An empty negative prompt carries no information, so it is left out.
    if args.negative_prompt:
        payload["negative_prompt"] = args.negative_prompt
    # Numeric options are compared against None instead of being tested for
    # truthiness: 0 / 0.0 are values the user explicitly asked for and must be
    # sent, otherwise the request would silently omit them while the summary
    # printed by generate_audio still shows the requested value.
    if args.guidance_scale is not None:
        payload["guidance_scale"] = args.guidance_scale
    if args.num_inference_steps is not None:
        payload["num_inference_steps"] = args.num_inference_steps
    if args.seed is not None:
        payload["seed"] = args.seed
    return payload


def generate_audio(args):
    """Send one generation request and write the returned audio to disk.

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        ``True`` when the server answered with status 200 and the body was
        written to ``args.output``; ``False`` on a non-200 status, a timeout,
        a connection failure, or any other exception. A message is printed
        in every failure case.
    """
    payload = _build_payload(args)

    print(f"\n{'=' * 60}")
    print("Stable Audio - Text-to-Audio Generation")
    print(f"{'=' * 60}")
    print(f"API URL: {args.api_url}")
    print(f"Prompt: {args.text}")
    print(f"Audio length: {args.audio_length}s")
    print(f"Negative prompt: {args.negative_prompt}")
    print(f"Guidance scale: {args.guidance_scale}")
    print(f"Inference steps: {args.num_inference_steps}")
    if args.seed is not None:
        print(f"Seed: {args.seed}")
    print(f"Output: {args.output}")
    print(f"{'=' * 60}\n")

    try:
        # Make the API request
        print("Generating audio...")
        response = requests.post(
            args.api_url,
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=300,  # 5 minute timeout for long generations
        )

        # A non-200 body is an error description, not audio: report it
        # instead of writing it to the output file.
        if response.status_code != 200:
            print(f"Error: API returned status code {response.status_code}")
            print(f"Response: {response.text}")
            return False

        # Save the audio
        with open(args.output, "wb") as f:
            f.write(response.content)

        print(f"✓ Audio saved to {args.output}")
        print(f" File size: {len(response.content) / 1024:.1f} KB")
        return True

    except requests.exceptions.Timeout:
        print("Error: Request timed out. Try reducing inference steps or audio length.")
        return False
    except requests.exceptions.ConnectionError:
        print(f"Error: Could not connect to {args.api_url}")
        print("Make sure the server is running.")
        return False
    except Exception as e:
        # CLI boundary: report anything unexpected and signal failure through
        # the return value so main() can set the exit status.
        print(f"Error: {e}")
        return False


def main():
    """Run the client and exit with status 0 on success, 1 on failure."""
    args = parse_args()
    success = generate_audio(args)
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
def _make_engine_client(*, audio_key: str = "audio", sample_rate: int = 44100):
    """Return a MagicMock engine whose ``generate`` yields one fake audio output.

    The keyword-only arguments are forwarded to ``create_mock_audio_output``
    and choose the key the audio is stored under and the reported sample rate.
    """

    async def _fake_generate(*args, **kwargs):
        # Single-item async generator standing in for the engine's output stream.
        yield create_mock_audio_output(
            request_id=kwargs.get("request_id", "audiogen-mock"),
            sample_rate=sample_rate,
            audio_key=audio_key,
        )

    engine = MagicMock()
    engine.generate = MagicMock(side_effect=_fake_generate)
    engine.default_sampling_params_list = [{}]
    engine.model_type = "StableAudioPipeline"
    engine.errored = False
    return engine


def _make_server(engine_client=None):
    """Construct the serving object under test around mocked collaborators.

    A fresh mock engine is created when ``engine_client`` is not supplied.
    """
    models_stub = MagicMock()
    models_stub.is_base_model.return_value = True

    chosen_engine = _make_engine_client() if engine_client is None else engine_client
    return OmniOpenAIServingAudioGenerate(
        engine_client=chosen_engine,
        models=models_stub,
        request_logger=MagicMock(),
    )


@pytest.fixture
def test_app():
    """FastAPI app routing POST /v1/audio/generate to a wrapped handler.

    The wrapper advertises the handler's signature minus ``raw_request``
    (presumably so FastAPI does not try to bind it from the request).
    """
    server = _make_server()
    handler = server.create_audio_generate

    handler_sig = signature(handler)
    visible_params = []
    for param_name, param in handler_sig.parameters.items():
        if param_name == "raw_request":
            continue
        visible_params.append(param)

    async def patched_create_audio_generate(*args, **kwargs):
        return await handler(*args, **kwargs)

    patched_create_audio_generate.__signature__ = Signature(
        parameters=visible_params,
        return_annotation=handler_sig.return_annotation,
    )
    server.create_audio_generate = patched_create_audio_generate

    app = FastAPI()
    app.add_api_route(
        "/v1/audio/generate",
        server.create_audio_generate,
        methods=["POST"],
        response_model=None,
    )
    return app
# Request Validation (Pydantic model)
class TestRequestValidation:
    """Field defaults and constraints of OpenAICreateAudioGenerateRequest."""

    def test_valid_minimal_request(self):
        minimal = OpenAICreateAudioGenerateRequest(input="A calm piano melody")
        assert minimal.input == "A calm piano melody"
        assert minimal.response_format == "wav"
        assert minimal.speed == 1.0

    def test_fields_are_wired_correctly(self):
        # Every supplied field must be readable back unchanged.
        supplied = {
            "input": "rain sounds",
            "model": "stable-audio",
            "response_format": "flac",
            "speed": 1.5,
            "audio_length": 10.0,
            "audio_start": 2.0,
            "negative_prompt": "noise",
            "guidance_scale": 7.5,
            "num_inference_steps": 100,
            "seed": 42,
        }
        req = OpenAICreateAudioGenerateRequest(**supplied)
        for field_name, expected_value in supplied.items():
            assert getattr(req, field_name) == expected_value

    def test_invalid_response_format(self):
        with pytest.raises(Exception):
            OpenAICreateAudioGenerateRequest(input="test", response_format="invalid_format")

    def test_speed_lower_bound(self):
        with pytest.raises(Exception):
            OpenAICreateAudioGenerateRequest(input="test", speed=0.1)

    def test_speed_upper_bound(self):
        with pytest.raises(Exception):
            OpenAICreateAudioGenerateRequest(input="test", speed=5.0)

    def test_speed_at_boundaries(self):
        # Both ends of the accepted range are inclusive.
        for boundary in (0.25, 4.0):
            accepted = OpenAICreateAudioGenerateRequest(input="test", speed=boundary)
            assert accepted.speed == boundary

    def test_stream_format_sse_rejected(self):
        with pytest.raises(Exception):
            OpenAICreateAudioGenerateRequest(input="test", stream_format="sse")

    def test_stream_format_audio_accepted(self):
        accepted = OpenAICreateAudioGenerateRequest(input="test", stream_format="audio")
        assert accepted.stream_format == "audio"


# Constructor & Class Methods
class TestConstructor:
    """Construction paths and model-type detection of the serving class."""

    def test_default_init(self):
        assert _make_server().diffusion_mode is False

    def test_for_diffusion_factory(self):
        models_stub = MagicMock()
        models_stub.is_base_model.return_value = True

        diffusion_server = OmniOpenAIServingAudioGenerate.for_diffusion(
            engine_client=_make_engine_client(),
            models=models_stub,
            request_logger=MagicMock(),
        )
        assert diffusion_server.diffusion_mode is True

    def test_is_stable_audio_model_true(self):
        assert _make_server()._is_stable_audio_model() is True

    def test_is_stable_audio_model_false(self):
        other_engine = _make_engine_client()
        other_engine.model_type = "SomeOtherModel"
        assert _make_server(engine_client=other_engine)._is_stable_audio_model() is False
distortion" + + @pytest.mark.asyncio + async def test_negative_prompt_absent(self, server_and_engine): + server, engine = server_and_engine + req = OpenAICreateAudioGenerateRequest(input="a calm ocean") + await server.create_audio_generate(req) + + call_kwargs = engine.generate.call_args[1] + assert "negative_prompt" not in call_kwargs["prompt"] + + @pytest.mark.asyncio + async def test_guidance_scale_wiring(self, server_and_engine): + server, engine = server_and_engine + req = OpenAICreateAudioGenerateRequest(input="test", guidance_scale=12.0) + await server.create_audio_generate(req) + + call_kwargs = engine.generate.call_args[1] + sp = call_kwargs["sampling_params_list"][0] + assert isinstance(sp, OmniDiffusionSamplingParams) + assert sp.guidance_scale == 12.0 + + @pytest.mark.asyncio + async def test_num_inference_steps_wiring(self, server_and_engine): + server, engine = server_and_engine + req = OpenAICreateAudioGenerateRequest(input="test", num_inference_steps=200) + await server.create_audio_generate(req) + + sp = engine.generate.call_args[1]["sampling_params_list"][0] + assert sp.num_inference_steps == 200 + + @pytest.mark.asyncio + async def test_seed_creates_generator(self, server_and_engine): + server, engine = server_and_engine + req = OpenAICreateAudioGenerateRequest(input="test", seed=42) + + with patch("vllm_omni.entrypoints.openai.serving_audio_generate.torch") as mock_torch: + mock_gen = MagicMock() + mock_gen.manual_seed.return_value = mock_gen + mock_torch.Generator.return_value = mock_gen + + await server.create_audio_generate(req) + + mock_torch.Generator.assert_called_once() + mock_gen.manual_seed.assert_called_once_with(42) + + @pytest.mark.asyncio + async def test_seed_none_skips_generator(self, server_and_engine): + server, engine = server_and_engine + req = OpenAICreateAudioGenerateRequest(input="test") + + await server.create_audio_generate(req) + + sp = engine.generate.call_args[1]["sampling_params_list"][0] + assert sp.generator is None 
+ + @pytest.mark.asyncio + async def test_audio_length_wiring(self, server_and_engine): + server, engine = server_and_engine + req = OpenAICreateAudioGenerateRequest(input="test", audio_length=10.0, audio_start=2.0) + await server.create_audio_generate(req) + + sp = engine.generate.call_args[1]["sampling_params_list"][0] + assert sp.extra_args["audio_start_in_s"] == 2.0 + assert sp.extra_args["audio_end_in_s"] == 12.0 # start + length + + @pytest.mark.asyncio + async def test_audio_length_default_start(self, server_and_engine): + server, engine = server_and_engine + req = OpenAICreateAudioGenerateRequest(input="test", audio_length=5.0) + await server.create_audio_generate(req) + + sp = engine.generate.call_args[1]["sampling_params_list"][0] + assert sp.extra_args["audio_start_in_s"] == 0.0 + assert sp.extra_args["audio_end_in_s"] == 5.0 + + @pytest.mark.asyncio + async def test_no_audio_length_skips_extra_args(self, server_and_engine): + server, engine = server_and_engine + req = OpenAICreateAudioGenerateRequest(input="test") + await server.create_audio_generate(req) + + sp = engine.generate.call_args[1]["sampling_params_list"][0] + assert sp.extra_args == {} + + @pytest.mark.asyncio + async def test_defaults_not_set_when_omitted(self, server_and_engine): + """Guidance scale and num_inference_steps keep dataclass defaults when not in request.""" + server, engine = server_and_engine + req = OpenAICreateAudioGenerateRequest(input="test") + await server.create_audio_generate(req) + + sp = engine.generate.call_args[1]["sampling_params_list"][0] + defaults = OmniDiffusionSamplingParams() + assert sp.guidance_scale == defaults.guidance_scale + assert sp.num_inference_steps == defaults.num_inference_steps + + +# Audio Response Format +class TestAudioResponseFormat: + def test_wav_response(self, client): + payload = {"input": "a gentle rain", "response_format": "wav"} + response = client.post("/v1/audio/generate", json=payload) + assert response.status_code == 200 + assert 
response.headers["content-type"] == "audio/wav" + assert len(response.content) > 0 + + def test_mp3_response(self, client): + payload = {"input": "a gentle rain", "response_format": "mp3"} + response = client.post("/v1/audio/generate", json=payload) + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/mpeg" + assert len(response.content) > 0 + + def test_flac_response(self, client): + payload = {"input": "a gentle rain", "response_format": "flac"} + response = client.post("/v1/audio/generate", json=payload) + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/flac" + assert len(response.content) > 0 + + def test_invalid_format_rejected(self, client): + payload = {"input": "test", "response_format": "banana"} + response = client.post("/v1/audio/generate", json=payload) + assert response.status_code == 422 + + @patch("vllm_omni.entrypoints.openai.serving_audio_generate.OmniOpenAIServingAudioGenerate.create_audio") + def test_speed_parameter_forwarded(self, mock_create_audio, test_app): + mock_audio_response = MagicMock() + mock_audio_response.audio_data = b"dummy_audio" + mock_audio_response.media_type = "audio/wav" + mock_create_audio.return_value = mock_audio_response + + c = TestClient(test_app) + payload = {"input": "test", "response_format": "wav", "speed": 2.5} + c.post("/v1/audio/generate", json=payload) + + mock_create_audio.assert_called_once() + audio_obj = mock_create_audio.call_args[0][0] + assert isinstance(audio_obj, CreateAudio) + assert audio_obj.speed == 2.5 + + @patch("vllm_omni.entrypoints.openai.serving_audio_generate.OmniOpenAIServingAudioGenerate.create_audio") + def test_sample_rate_from_output(self, mock_create_audio, test_app): + mock_audio_response = MagicMock() + mock_audio_response.audio_data = b"dummy" + mock_audio_response.media_type = "audio/wav" + mock_create_audio.return_value = mock_audio_response + + c = TestClient(test_app) + payload = {"input": "test"} + 
# Error Handling
class TestErrorHandling:
    """Failure paths of ``create_audio_generate``.

    Covers an engine that yields nothing, output without audio, a dead
    engine, the ``model_outputs`` fallback key, and exceptions raised while
    generating.
    """

    @staticmethod
    async def _request_with_generate(generate_fn):
        """Replace the mock engine's ``generate`` with *generate_fn* and issue one request."""
        engine = _make_engine_client()
        engine.generate = MagicMock(side_effect=generate_fn)
        server = _make_server(engine_client=engine)
        return await server.create_audio_generate(OpenAICreateAudioGenerateRequest(input="test"))

    @pytest.mark.asyncio
    async def test_no_output_returns_error(self):
        async def empty_gen(*args, **kwargs):
            return
            yield  # noqa: unreachable – makes this an async generator

        resp = await self._request_with_generate(empty_gen)

        # create_error_response returns an ErrorResponse with .error.message
        assert "No output generated" in resp.error.message

    @pytest.mark.asyncio
    async def test_no_audio_in_output_returns_error(self):
        async def gen_without_audio(*args, **kwargs):
            # The output exists, but its multimodal payload has no audio key.
            yield OmniRequestOutput.from_diffusion(
                request_id="test",
                images=[],
                prompt=None,
                metrics={},
                multimodal_output={},
            )

        resp = await self._request_with_generate(gen_without_audio)

        assert "did not produce audio" in resp.error.message

    @pytest.mark.asyncio
    async def test_engine_errored_raises(self):
        dead_engine = _make_engine_client()
        dead_engine.errored = True
        dead_engine.dead_error = RuntimeError("engine is dead")
        server = _make_server(engine_client=dead_engine)
        req = OpenAICreateAudioGenerateRequest(input="test")

        # A dead engine is surfaced as an exception, not as an error response.
        with pytest.raises(RuntimeError, match="engine is dead"):
            await server.create_audio_generate(req)

    @pytest.mark.asyncio
    async def test_model_outputs_key_fallback(self):
        """Audio data under 'model_outputs' key should be accepted."""
        server = _make_server(engine_client=_make_engine_client(audio_key="model_outputs"))
        resp = await server.create_audio_generate(OpenAICreateAudioGenerateRequest(input="test"))

        # Should succeed and return a Response with audio bytes
        assert hasattr(resp, "body")
        assert len(resp.body) > 0

    @pytest.mark.asyncio
    async def test_value_error_returns_error_response(self):
        async def gen_value_error(*args, **kwargs):
            raise ValueError("bad value")
            yield  # noqa: unreachable

        resp = await self._request_with_generate(gen_value_error)

        assert "bad value" in resp.error.message

    @pytest.mark.asyncio
    async def test_generic_exception_returns_error_response(self):
        async def gen_runtime_error(*args, **kwargs):
            raise RuntimeError("something went wrong")
            yield  # noqa: unreachable

        resp = await self._request_with_generate(gen_runtime_error)

        assert "Audio generation failed" in resp.error.message
response = client.post("/v1/audio/generate", json=payload) + assert response.status_code == 422 + + def test_extra_unknown_fields_ignored(self, client): + payload = {"input": "test", "unknown_field": "value"} + response = client.post("/v1/audio/generate", json=payload) + # Pydantic v2 ignores extra fields by default + assert response.status_code == 200 diff --git a/tests/entrypoints/test_omni_diffusion.py b/tests/entrypoints/test_omni_diffusion.py index 90b0fd05326..2fbe7a8b42b 100644 --- a/tests/entrypoints/test_omni_diffusion.py +++ b/tests/entrypoints/test_omni_diffusion.py @@ -612,7 +612,7 @@ def test_initialize_stage_configs_called_when_none( """Test that stage configs are auto-loaded when stage_configs_path is None.""" def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -687,7 +687,7 @@ def test_generate_raises_on_length_mismatch(monkeypatch: pytest.MonkeyPatch, moc """Test that generate raises ValueError when sampling_params_list length doesn't match.""" def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -742,7 +742,7 @@ def test_generate_pipeline_and_final_outputs(monkeypatch: pytest.MonkeyPatch, mo stage_cfg1["processed_input"] = ["processed-for-stage-1"] def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -848,7 +848,7 @@ def test_generate_pipeline_with_batch_input(monkeypatch: pytest.MonkeyPatch, moc stage_cfg1["stage_id"] = 1 def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -969,7 +969,7 @@ def test_generate_no_final_output_returns_empty( stage_cfg1["final_output"] = False def _fake_loader( - model: str, + 
config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -1062,7 +1062,7 @@ def test_generate_sampling_params_none_use_default( stage_cfg1["final_output"] = False def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -1140,7 +1140,7 @@ def test_wait_for_stages_ready_timeout(monkeypatch: pytest.MonkeyPatch, mocker: """Test that _wait_for_stages_ready handles timeout correctly.""" def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -1202,7 +1202,7 @@ def test_generate_handles_error_messages(monkeypatch: pytest.MonkeyPatch, mocker """Test that generate handles error messages from stages correctly.""" def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -1283,7 +1283,7 @@ def test_close_sends_shutdown_signal(monkeypatch: pytest.MonkeyPatch, mocker: Mo """Test that close() sends shutdown signal to all input queues.""" def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py index b11b14c8939..0e1745c6833 100644 --- a/tests/entrypoints/test_omni_llm.py +++ b/tests/entrypoints/test_omni_llm.py @@ -577,7 +577,7 @@ def test_initialize_stage_configs_called_when_none( """Test that stage configs are auto-loaded when stage_configs_path is None.""" def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -654,7 +654,7 @@ def test_generate_raises_on_length_mismatch(monkeypatch: 
pytest.MonkeyPatch, moc """Test that generate raises ValueError when sampling_params_list length doesn't match.""" def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -711,7 +711,7 @@ def test_generate_pipeline_and_final_outputs(monkeypatch: pytest.MonkeyPatch, mo stage_cfg1["processed_input"] = ["processed-for-stage-1"] def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -819,7 +819,7 @@ def test_generate_no_final_output_returns_empty( stage_cfg1["final_output"] = False def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -911,7 +911,7 @@ def test_generate_sampling_params_none_use_default( stage_cfg1["final_output"] = False def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -988,7 +988,7 @@ def test_wait_for_stages_ready_timeout(monkeypatch: pytest.MonkeyPatch, mocker: """Test that _wait_for_stages_ready handles timeout correctly.""" def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -1052,7 +1052,7 @@ def test_generate_handles_error_messages(monkeypatch: pytest.MonkeyPatch, mocker """Test that generate handles error messages from stages correctly.""" def _fake_loader( - model: str, + config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -1135,7 +1135,7 @@ def test_close_sends_shutdown_signal(monkeypatch: pytest.MonkeyPatch, mocker: Mo """Test that close() sends shutdown signal to all input queues.""" def _fake_loader( - model: str, + 
config_path: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, diff --git a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py index 22dfc06c5a8..a1eb74eaf18 100644 --- a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py +++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py @@ -356,7 +356,6 @@ def forward( negative_prompt: str | list[str] | None = None, audio_end_in_s: float | None = None, audio_start_in_s: float = 0.0, - num_inference_steps: int = 100, guidance_scale: float = 7.0, num_waveforms_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -374,7 +373,6 @@ def forward( negative_prompt: Negative prompt for CFG audio_end_in_s: Audio end time in seconds (max ~47s for stable-audio-open-1.0) audio_start_in_s: Audio start time in seconds - num_inference_steps: Number of denoising steps guidance_scale: CFG scale num_waveforms_per_prompt: Number of audio outputs per prompt generator: Random generator for reproducibility @@ -395,7 +393,7 @@ def forward( elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index 27f470b4b5e..91c1c6bed1d 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -28,6 +28,7 @@ load_stage_configs_from_model, load_stage_configs_from_yaml, resolve_model_config_path, + resolve_model_type, ) logger = init_logger(__name__) @@ -84,12 +85,13 @@ def __init__( self.worker_backend = 
kwargs.get("worker_backend", "multi_process") self.ray_address = kwargs.get("ray_address", None) self.batch_timeout = batch_timeout + self.model_type = resolve_model_type(model) self.log_stats: bool = bool(log_stats) # Load stage configurations if stage_configs_path is None: - self.config_path = resolve_model_config_path(model) - self.stage_configs = load_stage_configs_from_model(model) + self.config_path = resolve_model_config_path(self.model_type) + self.stage_configs = load_stage_configs_from_model(config_path=self.config_path) else: self.config_path = stage_configs_path self.stage_configs = load_stage_configs_from_yaml(stage_configs_path) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 6db97dd2ddb..bb4649bf148 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -87,7 +87,10 @@ encode_image_base64, parse_size, ) -from vllm_omni.entrypoints.openai.protocol.audio import OpenAICreateSpeechRequest +from vllm_omni.entrypoints.openai.protocol.audio import ( + OpenAICreateAudioGenerateRequest, + OpenAICreateSpeechRequest, +) from vllm_omni.entrypoints.openai.protocol.images import ( ImageData, ImageGenerationRequest, @@ -97,6 +100,7 @@ VideoGenerationRequest, VideoGenerationResponse, ) +from vllm_omni.entrypoints.openai.serving_audio_generate import OmniOpenAIServingAudioGenerate from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech from vllm_omni.entrypoints.openai.serving_video import OmniOpenAIServingVideo @@ -175,8 +179,29 @@ class _DiffusionServingModels: provide a lightweight fallback. 
""" + class _NullModelConfig: + def __getattr__(self, name): + return None + + class _Unsupported: + def __init__(self, name: str): + self.name = name + + def __call__(self, *args, **kwargs): + raise NotImplementedError(f"{self.name} is not supported in diffusion mode") + + def __getattr__(self, attr): + raise NotImplementedError(f"{self.name}.{attr} is not supported in diffusion mode") + def __init__(self, base_model_paths: list[BaseModelPath]) -> None: self._base_model_paths = base_model_paths + self.model_config = self._NullModelConfig() + + def __getattr__(self, name): + """Return a sentinel that raises NotImplementedError if called or + accessed, so any use of unsupported OpenAIServingModels features in + diffusion mode fails loudly with a descriptive message.""" + return self._Unsupported(name) async def show_available_models(self) -> ModelList: return ModelList( @@ -434,10 +459,10 @@ async def omni_init_app_state( # For omni models state.stage_configs = engine_client.stage_configs if hasattr(engine_client, "stage_configs") else None + model_name = served_model_names[0] if served_model_names else args.model # Pure Diffusion mode: use simplified initialization logic if is_pure_diffusion: - model_name = served_model_names[0] if served_model_names else args.model state.vllm_config = None state.diffusion_engine = engine_client state.openai_serving_models = _DiffusionServingModels(base_model_paths) @@ -450,6 +475,16 @@ async def omni_init_app_state( model_name=model_name, ) + # audio related + state.openai_serving_speech = None + state.openai_serving_audio_generate = OmniOpenAIServingAudioGenerate.for_diffusion( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + model_name=model_name, + ) + + # video related diffusion_stage_configs = engine_client.stage_configs if hasattr(engine_client, "stage_configs") else None state.openai_serving_video = OmniOpenAIServingVideo.for_diffusion( diffusion_engine=engine_client, # type: ignore @@ 
@router.post(
    "/v1/audio/generate",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"audio/*": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_audio_generate(request: OpenAICreateAudioGenerateRequest, raw_request: Request):
    """Generate audio from a text prompt and return the encoded audio bytes.

    Args:
        request: Validated request body.
        raw_request: Incoming HTTP request, used to look up the serving
            handler stored on the application state.

    Returns:
        The handler's response on success, or a JSON error body sent with the
        status code carried by the ``ErrorResponse``.

    Raises:
        HTTPException: 404 when no handler and no fallback server are
            configured; 500 when the handler raises an unexpected exception.
    """
    # Local import so this route does not depend on the module's import block.
    from fastapi.responses import JSONResponse

    handler = OmniAudioGenerate(raw_request)
    if handler is None:
        base_server = getattr(raw_request.app.state, "openai_serving_tokenization", None)
        if base_server is None:
            raise HTTPException(
                status_code=HTTPStatus.NOT_FOUND.value,
                detail="The model does not support Audio Generate API",
            )
        result = base_server.create_error_response(message="The model does not support Audio Generate API")
    else:
        try:
            result = await handler.create_audio_generate(request, raw_request)
        except HTTPException:
            # Keep the status code of a deliberately raised HTTP error instead
            # of rewrapping it as a 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e

    if isinstance(result, ErrorResponse):
        # Returning the model directly would serialize it with HTTP 200; send
        # it with the status code recorded in the error instead.
        return JSONResponse(content=result.model_dump(), status_code=result.error.code)
    return result
Stable Audio).""" + + input: str = Field( + description="Text prompt describing the audio to generate", + ) + model: str | None = None + response_format: Literal["wav", "pcm", "flac", "mp3", "aac", "opus"] = "wav" + speed: float | None = Field( + default=1.0, + ge=0.25, + le=4.0, + ) + stream_format: Literal["sse", "audio"] | None = "audio" + audio_length: float | None = Field( + default=None, + description="Audio length in seconds", + ) + audio_start: float | None = Field( + default=0.0, + description="Audio start time in seconds", + ) + negative_prompt: str | None = Field( + default=None, + description="Negative prompt for classifier-free guidance", + ) + guidance_scale: float | None = Field( + default=None, + description="Guidance scale for diffusion models", + ) + num_inference_steps: int | None = Field( + default=None, + description="Number of inference steps", + ) + seed: int | None = Field( + default=None, + description="Random seed for reproducibility", + ) + + @field_validator("stream_format") + @classmethod + def validate_stream_format(cls, v: str) -> str: + if v == "sse": + raise ValueError("'sse' is not a supported stream_format yet. 
class OmniOpenAIServingAudioGenerate(OpenAIServing, AudioMixin):
    """Serving class for audio generation via diffusion models (e.g. Stable Audio).

    Turns an ``OpenAICreateAudioGenerateRequest`` into one diffusion request
    on the engine client and encodes the resulting waveform with
    ``create_audio``.
    """

    # Sample rate used when the engine output carries no usable "sr" entry
    # (default sample rate for Stable Audio).
    _DEFAULT_SAMPLE_RATE = 44100
    # Keys under which the waveform may be stored, in lookup order.
    _AUDIO_KEYS = ("audio", "model_outputs")

    def __init__(self, *args, **kwargs):
        # "model_name" is consumed here; everything else is forwarded unchanged.
        self.model_name = kwargs.pop("model_name", None)
        super().__init__(*args, **kwargs)
        self.diffusion_mode = False

    @classmethod
    def for_diffusion(cls, *args, **kwargs) -> "OmniOpenAIServingAudioGenerate":
        """Create an instance configured to run in diffusion mode."""
        instance = cls(*args, **kwargs)
        instance.diffusion_mode = True
        return instance

    def _is_stable_audio_model(self) -> bool:
        """Return True when the engine client reports the Stable Audio pipeline."""
        return self.engine_client.model_type == "StableAudioPipeline"

    @staticmethod
    def _validate_request(request: OpenAICreateAudioGenerateRequest) -> str | None:
        """Return a message describing the first invalid parameter, or None.

        Catches values that cannot describe a valid generation before they
        reach the engine, so the client gets a specific message.
        """
        if request.audio_length is not None:
            if request.audio_length <= 0:
                return "audio_length must be greater than 0."
            # audio_start is only used together with audio_length.
            if request.audio_start is not None and request.audio_start < 0:
                return "audio_start must be greater than or equal to 0."
        if request.num_inference_steps is not None and request.num_inference_steps < 1:
            return "num_inference_steps must be at least 1."
        return None

    @staticmethod
    def _build_sampling_params(request: OpenAICreateAudioGenerateRequest) -> OmniDiffusionSamplingParams:
        """Build sampling params, overriding defaults only for fields set in the request."""
        params = OmniDiffusionSamplingParams(num_outputs_per_prompt=1)

        if request.seed is not None:
            from vllm_omni.platforms import current_omni_platform

            params.generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(request.seed)

        if request.guidance_scale is not None:
            # NOTE(review): the Stable Audio pipeline reads guidance_scale only
            # when sampling_params.guidance_scale_provided is set; confirm that
            # assigning the attribute after construction sets that flag.
            params.guidance_scale = request.guidance_scale

        if request.num_inference_steps is not None:
            params.num_inference_steps = request.num_inference_steps

        # The time window is sent only when a length is requested; otherwise
        # extra_args is left untouched so the pipeline applies its own default.
        if request.audio_length is not None:
            audio_start = request.audio_start if request.audio_start is not None else 0.0
            params.extra_args = {
                "audio_start_in_s": audio_start,
                "audio_end_in_s": audio_start + request.audio_length,
            }
        return params

    @staticmethod
    def _extract_multimodal_output(final_output: OmniRequestOutput) -> dict | None:
        """Return the multimodal payload from the output or its nested request output."""
        payload = getattr(final_output, "multimodal_output", None)
        if payload:
            return payload
        nested = getattr(final_output, "request_output", None)
        if nested and hasattr(nested, "multimodal_output"):
            return nested.multimodal_output
        return None

    @classmethod
    def _resolve_sample_rate(cls, raw) -> int:
        """Normalize the engine-reported sample rate to an int.

        Accepts a plain number, an object exposing ``item()``, or a list or
        tuple of either (the last entry is used, as in the speech serving
        path). Falls back to the default when the value is missing or empty.
        """
        if isinstance(raw, (list, tuple)):
            raw = raw[-1] if raw else None
        if raw is None:
            return cls._DEFAULT_SAMPLE_RATE
        if hasattr(raw, "item"):
            raw = raw.item()
        return int(raw)

    async def create_audio_generate(
        self,
        request: OpenAICreateAudioGenerateRequest,
        raw_request: Request | None = None,
    ):
        """
        Generate audio using diffusion-based models (e.g. Stable Audio).

        This endpoint is designed for audio generation models as
        opposed to TTS models that specifically generate speech.

        Args:
            request: Generation request. ``audio_start`` only takes effect
                when ``audio_length`` is also given.
            raw_request: Originating HTTP request; not used by this method.

        Returns:
            A ``Response`` with the encoded audio on success, otherwise the
            error response produced by ``create_error_response`` (or by the
            model check).

        Raises:
            Whatever ``engine_client.dead_error`` holds when the engine
            client reports itself as errored.
        """

        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        if self.engine_client.errored:
            raise self.engine_client.dead_error

        validation_error = self._validate_request(request)
        if validation_error is not None:
            return self.create_error_response(validation_error)

        request_id = f"audiogen-{random_uuid()}"

        try:
            # Build prompt for diffusion audio generation
            prompt = {"prompt": request.input}
            if request.negative_prompt:
                prompt["negative_prompt"] = request.negative_prompt

            sampling_params_list = [self._build_sampling_params(request)]

            logger.info(
                "Audio generation request %s: prompt=%r",
                request_id,
                request.input[:50] + "..." if len(request.input) > 50 else request.input,
            )

            result_stream = self.engine_client.generate(
                prompt=prompt,
                request_id=request_id,
                sampling_params_list=sampling_params_list,
                output_modalities=["audio"],
            )

            # Only the last item yielded by the engine is used.
            final_output: OmniRequestOutput | None = None
            async for res in result_stream:
                final_output = res

            if final_output is None:
                return self.create_error_response("No output generated from the model.")

            audio_output = self._extract_multimodal_output(final_output)
            audio_key = None
            if audio_output:
                audio_key = next((key for key in self._AUDIO_KEYS if key in audio_output), None)

            if not audio_output or audio_key is None:
                return self.create_error_response("Audio generation model did not produce audio output.")

            audio_tensor = audio_output[audio_key]
            sample_rate = self._resolve_sample_rate(audio_output.get("sr", self._DEFAULT_SAMPLE_RATE))

            # Convert tensor to numpy
            if hasattr(audio_tensor, "float"):
                audio_tensor = audio_tensor.float().detach().cpu().numpy()

            # squeeze() drops every size-1 axis (batch, and the channel axis of
            # mono audio); a 2-channel axis is kept.
            if audio_tensor.ndim > 1:
                audio_tensor = audio_tensor.squeeze()

            audio_obj = CreateAudio(
                audio_tensor=audio_tensor,
                sample_rate=sample_rate,
                response_format=request.response_format or "wav",
                speed=request.speed or 1.0,
                stream_format=request.stream_format,
                base64_encode=False,
            )

            audio_response: AudioResponse = self.create_audio(audio_obj)
            return Response(content=audio_response.audio_data, media_type=audio_response.media_type)

        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            return self.create_error_response(e)
        except Exception as e:
            logger.exception("Audio generation failed: %s", e)
            return self.create_error_response(f"Audio generation failed: {e}")
diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 69ce73fc47a..beec074d89a 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -169,22 +169,46 @@ def _convert_dataclasses_to_dict(obj: Any) -> Any: return obj -def resolve_model_config_path(model: str) -> str: - """Resolve the stage config file path from the model name. +def resolve_model_config_path(model_type: str) -> str | None: + """Resolve the stage config file path from the model type. Resolves stage configuration path based on the model type and device type. First tries to find a device-specific YAML file from stage_configs/{device_type}/ directory. If not found, falls back to the default config file. Args: - model: Model name or path (used to determine model_type) + model_type: Model type string Returns: - String path to the stage configuration file + String path to the stage configuration file if found, None otherwise + """ + default_config_path = current_omni_platform.get_default_stage_config_path() + config_file_name = f"{model_type}.yaml" + complete_config_path = PROJECT_ROOT / default_config_path / config_file_name + if os.path.exists(complete_config_path): + return str(complete_config_path) + + # Fall back to default config + stage_config_file = f"vllm_omni/model_executor/stage_configs/{model_type}.yaml" + stage_config_path = PROJECT_ROOT / stage_config_file + if not os.path.exists(stage_config_path): + return None + return str(stage_config_path) + + +def resolve_model_type(model: str) -> str: + """Resolve the model type from the model name. + + Args: + model: Model name or path + + Returns: + Model type string (e.g. ``"Qwen3TTSForConditionalGeneration"``, + ``"StableAudioPipeline"``). Raises: - ValueError: If model_type cannot be determined - FileNotFoundError: If no stage config file exists for the model type + ValueError: If the model type cannot be determined from any + available configuration file. 
""" # Try to get config from standard transformers format first try: @@ -217,42 +241,26 @@ def resolve_model_config_path(model: str) -> str: f"Please ensure the model has proper configuration files with 'model_type' field" ) - default_config_path = current_omni_platform.get_default_stage_config_path() - model_type_str = f"{model_type}.yaml" - complete_config_path = PROJECT_ROOT / default_config_path / model_type_str - if os.path.exists(complete_config_path): - return str(complete_config_path) - - # Fall back to default config - stage_config_file = f"vllm_omni/model_executor/stage_configs/{model_type}.yaml" - stage_config_path = PROJECT_ROOT / stage_config_file - if not os.path.exists(stage_config_path): - return None - return str(stage_config_path) - + return model_type -def load_stage_configs_from_model(model: str, base_engine_args: dict | None = None) -> list: - """Load stage configurations from model's default config file. - Loads stage configurations based on the model type and device type. - First tries to load a device-specific YAML file from stage_configs/{device_type}/ - directory. If not found, falls back to the default config file. +def load_stage_configs_from_model(config_path: str | None, base_engine_args: dict | None = None) -> list: + """Load stage configurations from a resolved config file path. Args: - model: Model name or path (used to determine model_type) + config_path: Path to the YAML configuration file, or None. + When None, returns an empty list. + base_engine_args: Optional engine arguments to merge with stage configs. Returns: - List of stage configuration dictionaries - - Raises: - FileNotFoundError: If no stage config file exists for the model type + List of stage configuration dictionaries, or empty list if + config_path is None. 
""" if base_engine_args is None: base_engine_args = {} - stage_config_path = resolve_model_config_path(model) - if stage_config_path is None: + if config_path is None: return [] - stage_configs = load_stage_configs_from_yaml(config_path=stage_config_path, base_engine_args=base_engine_args) + stage_configs = load_stage_configs_from_yaml(config_path=config_path, base_engine_args=base_engine_args) return stage_configs @@ -306,9 +314,10 @@ def load_and_resolve_stage_configs( Returns: Tuple of (config_path, stage_configs) """ + model_type = resolve_model_type(model) if stage_configs_path is None: - config_path = resolve_model_config_path(model) - stage_configs = load_stage_configs_from_model(model, base_engine_args=kwargs) + config_path = resolve_model_config_path(model_type) + stage_configs = load_stage_configs_from_model(config_path, base_engine_args=kwargs) if not stage_configs: if default_stage_cfg_factory is not None: default_stage_cfg = default_stage_cfg_factory()